PyTorch implementation of low-rank adaptation (LoRA), a parameter-efficient approach to adapting a large pre-trained deep learning model that obtains performance on par with full fine-tuning.
(For the radio communication technique, see LoRa.)
This repo contains the source code of the Python package loralib and several examples of how to integrate it with PyTorch models, such as those in Hugging Face.
We only support PyTorch for now.
See our paper for a detailed description of LoRA.
LoRA: Low-Rank Adaptation of Large Language Models
Edward J. Hu*, Yelong Shen*, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen
Paper: https://arxiv.org/abs/2106.09685
Update 2/2023: LoRA is now supported by the state-of-the-art Parameter-Efficient Fine-Tuning (PEFT) library by Hugging Face.
LoRA reduces the number of trainable parameters by learning pairs of rank-decomposition matrices while freezing the original weights. This vastly reduces the storage requirement for large language models adapted to specific tasks and enables efficient task-switching during deployment, all without introducing inference latency. LoRA also outperforms several other adaptation methods, including adapter, prefix-tuning, and fine-tuning.
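To make the mechanism concrete, here is a minimal from-scratch sketch in plain PyTorch. It does not use loralib, and the class name LoRALinearSketch is hypothetical. It follows the paper's formulation: the frozen weight W0 is augmented with a trainable update (alpha / r) · B A, where A gets a random Gaussian initialization and B starts at zero, so the adapted layer initially reproduces the pre-trained one. Folding B A into W0 when switching to evaluation mode is what avoids extra inference latency; the std of 0.02 for A is an arbitrary choice for illustration.

```python
import torch
import torch.nn as nn


class LoRALinearSketch(nn.Module):
    """Hypothetical minimal LoRA layer: y = W0 x + b + (alpha / r) * B A x."""

    def __init__(self, in_features: int, out_features: int, r: int = 8, alpha: float = 8.0):
        super().__init__()
        # Stand-in for a pre-trained layer; its weight and bias are frozen.
        self.base = nn.Linear(in_features, out_features)
        self.base.requires_grad_(False)
        # Trainable rank-decomposition pair: A is (r x in_features), B is (out_features x r).
        self.lora_A = nn.Parameter(torch.randn(r, in_features) * 0.02)  # random Gaussian init
        self.lora_B = nn.Parameter(torch.zeros(out_features, r))  # zero init, so B A = 0 at start
        self.scaling = alpha / r
        self.merged = False

    def delta_weight(self) -> torch.Tensor:
        # The low-rank update (alpha / r) * B A has the same shape as W0.
        return (self.lora_B @ self.lora_A) * self.scaling

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.base(x)
        if not self.merged:
            # Extra low-rank path, computed as x A^T B^T without forming B A.
            out = out + (x @ self.lora_A.T @ self.lora_B.T) * self.scaling
        return out

    def train(self, mode: bool = True) -> "LoRALinearSketch":
        super().train(mode)
        with torch.no_grad():
            if not mode and not self.merged:
                self.base.weight += self.delta_weight()  # eval(): fold B A into W0
                self.merged = True
            elif mode and self.merged:
                self.base.weight -= self.delta_weight()  # train(): undo the merge
                self.merged = False
        return self


torch.manual_seed(0)
layer = LoRALinearSketch(16, 4, r=2, alpha=4.0)
x = torch.randn(3, 16)
same_at_init = torch.allclose(layer(x), layer.base(x))  # B == 0, so nothing changes yet
n_trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
with torch.no_grad():
    layer.lora_B.normal_()  # pretend training has moved B away from zero
y_train = layer(x)
layer.eval()  # merges, so the forward pass is a single matmul again
y_eval = layer(x)
print(same_at_init, n_trainable, torch.allclose(y_train, y_eval, atol=1e-5))  # True 40 True
```

Only the 2 × 16 + 4 × 2 = 40 entries of A and B are trainable, and the merged layer produces the same outputs (up to floating-point rounding) as the unmerged one.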
We obtain results comparable or superior to full fine-tuning on the GLUE benchmark using RoBERTa (Liu et al., 2019) base and large and DeBERTa (He et al., 2020) XXL 1.5B, while training and storing only a fraction of the parameters. Click the numbers below to download the RoBERTa and DeBERTa LoRA checkpoints.
| | RoBERTa base Fine-tune | RoBERTa base LoRA | DeBERTa XXL Fine-tune | DeBERTa XXL LoRA |
|---|---|---|---|---|
| # of Trainable Params. | 125M | 0.8M | 1.5B | 4.7M |
| MNLI (m-Acc/mm-Acc) | 87.6 | 87.5±.3/86.9±.3 | 91.7/91.9 | 91.9±.1/91.9±.2 |
| SST2 (Acc) | 94.8 | 95.1±.2 | 97.2 | 96.9±.2 |
| MRPC (Acc) | 90.2 | 89.7±.7 | 92.0 | 92.6±.6 |
| CoLA (Matthews Corr) | 63.6 | 63.4±1.2 | 72.0 | 72.4±1.1 |
| QNLI (Acc) | 92.8 | 93.3±.3 | 96.0 | 96.0±.1 |
| QQP (Acc) | 91.9 | 90.8±.1 | 92.7 | 92.9±.1 |
| RTE (Acc) | 78.7 | 86.6±.7 | 93.9 | 94.9±.4 |
| STSB (Pearson/Spearman Corr) | 91.2 | 91.5±.2/91.3±.2 | 92.9/92.6 | 93.0±.2/92.9±.3 |
| Average | 86.40 | 87.24 | 91.06 | 91.32 |
Note: You still need the original pre-trained checkpoint from Hugging Face to use the LoRA checkpoints.
Fine-tuning numbers are taken from Liu et al. (2019) and He et al. (2020). We include confidence intervals on results from our experiments. Please follow the instructions in examples/NLU/ to reproduce our results.
On GPT-2, LoRA compares favorably to both full fine-tuning and other efficient tuning methods, such as adapter (Houlsby et al., 2019) and prefix tuning (Li and Liang, 2021). We evaluated on the E2E NLG Challenge, DART, and WebNLG:
| Method | # of Trainable Params | E2E (BLEU) | DART (BLEU) | WebNLG (BLEU-U/S/A) |
|---|---|---|---|---|
| GPT-2 M (Fine-Tune) | 354.92M | 68.2 | 46.0 | 30.4/63.2/47.6 |
| GPT-2 M (Adapter) | 0.37M | 66.3 | 42.4 | 45.1/54.5/50.2 |
| GPT-2 M (Prefix) | 0.35M | 69.7 | 45.7 | 44.1/63.1/54.4 |
| GPT-2 M (LoRA) | 0.35M | 70.4±.1 | 47.1±.2 | 46.7±.4/62.1±.2/55.3±.2 |
| GPT-2 L (Fine-Tune) | 774.03M | 68.5 | 46.5 | 41.7/64.6/54.2 |
| GPT-2 L (Adapter) | 0.88M | 69.1±.1 | 45.7±.1 | 49.8±.0/61.1±.0/56.0±.0 |
| GPT-2 L (Prefix) | 0.77M | 70.3 | 46.5 | 47.0/64.2/56.4 |
| GPT-2 L (LoRA) | 0.77M | 70.4±.1 | 47.5±.1 | 48.4±.3/64.0±.3/57.0±.1 |
Non-LoRA baselines, except for adapter on GPT-2 large, are taken from Li and Liang (2021). We include confidence intervals on results from our experiments.
Download the GPT-2 LoRA checkpoints:
Please follow the instructions in examples/NLG/ to reproduce our results.
(The initial release of this repo has been archived in the branch "snapshot-9-15-2021")
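There are several directories in this repo:
loralib, which contains the source code of the package loralib and needs to be installed to run the examples we provide;
examples/NLG/ and examples/NLU/, which show how loralib is used in GPT-2, RoBERTa, and DeBERTa v2.
Installing loralib is simply:
pip install loralib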
# Alternatively
# pip install git+https://github.com/microsoft/LoRA
Layers are adapted by replacing them with their counterparts implemented in loralib. We only support nn.Linear, nn.Embedding, and nn.Conv2d for now. We also support a MergedLinear for cases where a single nn.Linear represents more than one layer, such as in some implementations of the attention qkv projection (see the note on query, key, and value projections below).
# ===== Before =====
# layer = nn.Linear(in_features, out_features)
# ===== After =====
import loralib as lora
in_features, out_features = 768, 768  # example sizes
# Add a pair of low-rank adaptation matrices with rank r=16
layer = lora.Linear(in_features, out_features, r=16)
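As a rough illustration of the savings, the arithmetic below counts trainable parameters for a single d_out × d_in weight matrix. It assumes, as in the paper's formulation, that the adaptation pair consists of A with shape r × d_in and B with shape d_out × r, and it ignores the bias; the helper names and the 768 × 768 size are hypothetical examples.

```python
def full_finetune_params(d_in: int, d_out: int) -> int:
    # Every entry of the d_out x d_in weight matrix is trainable.
    return d_in * d_out


def lora_params(d_in: int, d_out: int, r: int) -> int:
    # A is r x d_in and B is d_out x r; the original weight stays frozen.
    return r * d_in + d_out * r


d = 768  # example hidden size
full = full_finetune_params(d, d)
low_rank = lora_params(d, d, r=16)
print(full, low_rank, f"{low_rank / full:.2%}")  # 589824 24576 4.17%
```

For a square d × d matrix the ratio is 2r / d, so the relative savings grow as layers get wider.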
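Before the training loop begins, mark only the LoRA parameters as trainable:
import loralib as lora
model = BigModel()  # BigModel and dataloader are placeholders for your own model and data
# This sets requires_grad to False for all parameters without the string "lora_" in their names
lora.mark_only_lora_as_trainable(model)
# Training loop
for batch in dataloader:
    ...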
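The comment above states the rule mark_only_lora_as_trainable applies: parameters whose names lack "lora_" get requires_grad=False. A minimal plain-PyTorch sketch of that rule follows; freeze_all_but_lora and ToyAdapted are hypothetical, and loralib's real function also handles the bias= option. Passing only the trainable parameters to the optimizer keeps its state small.

```python
import torch
import torch.nn as nn


class ToyAdapted(nn.Module):
    """Hypothetical toy model: one pre-trained linear layer plus a LoRA pair."""

    def __init__(self, d: int = 8, r: int = 2):
        super().__init__()
        self.proj = nn.Linear(d, d)
        self.lora_A = nn.Parameter(torch.randn(r, d) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(d, r))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x) + x @ self.lora_A.T @ self.lora_B.T


def freeze_all_but_lora(model: nn.Module) -> None:
    # Same naming rule as described above: only names containing "lora_" stay trainable.
    for name, param in model.named_parameters():
        param.requires_grad = "lora_" in name


model = ToyAdapted()
freeze_all_but_lora(model)
trainable = [n for n, p in model.named_parameters() if p.requires_grad]
# Hand only the trainable parameters to the optimizer.
optimizer = torch.optim.AdamW((p for p in model.parameters() if p.requires_grad), lr=1e-3)
print(trainable)  # ['lora_A', 'lora_B']
```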
When saving a checkpoint, generate a state_dict that only contains LoRA parameters:
# ===== Before =====
# torch.save(model.state_dict(), checkpoint_path)
# ===== After =====
torch.save(lora.lora_state_dict(model), checkpoint_path)
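Conceptually, a LoRA-only state_dict is the full state_dict filtered down to the LoRA entries, which is why the resulting checkpoint is small. The sketch below shows that filtering in plain PyTorch; lora_only_state_dict is a hypothetical helper that keeps keys containing "lora_" and, unlike loralib's lora_state_dict, ignores the bias= option.

```python
import torch
import torch.nn as nn


def lora_only_state_dict(model: nn.Module) -> dict:
    # Keep only entries whose key contains "lora_" (no bias handling in this sketch).
    return {k: v for k, v in model.state_dict().items() if "lora_" in k}


# Hypothetical toy model: a 512 x 512 base layer plus a rank-2 LoRA pair.
model = nn.Module()
model.proj = nn.Linear(512, 512)
model.lora_A = nn.Parameter(torch.zeros(2, 512))
model.lora_B = nn.Parameter(torch.zeros(512, 2))

full_sd = model.state_dict()
lora_sd = lora_only_state_dict(model)
full_count = sum(t.numel() for t in full_sd.values())
lora_count = sum(t.numel() for t in lora_sd.values())
print(sorted(lora_sd), full_count, lora_count)  # ['lora_A', 'lora_B'] 264704 2048
```

The filtered dictionary can be passed to torch.save exactly like a full state_dict.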
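When loading a checkpoint using load_state_dict, be sure to set strict=False: the pre-trained checkpoint lacks the LoRA matrices and the LoRA checkpoint lacks the pre-trained weights, so neither matches the full model on its own.
# Load the pretrained checkpoint first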
model.load_state_dict(torch.load('ckpt_pretrained.pt'), strict=False)
# Then load the LoRA checkpoint
model.load_state_dict(torch.load('ckpt_lora.pt'), strict=False)
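The two-step load can be checked without any files. In this sketch the toy model and the in-memory dictionaries are hypothetical stand-ins for ckpt_pretrained.pt and ckpt_lora.pt; with strict=False, load_state_dict reports the complementary keys as missing instead of raising, while the default strict=True rejects the pre-trained checkpoint alone.

```python
import torch
import torch.nn as nn


def make_model() -> nn.Module:
    # Hypothetical toy model: a pre-trained linear layer plus a LoRA pair.
    m = nn.Module()
    m.proj = nn.Linear(4, 4)
    m.lora_A = nn.Parameter(torch.randn(2, 4))
    m.lora_B = nn.Parameter(torch.zeros(4, 2))
    return m


source = make_model()
with torch.no_grad():
    source.lora_B.fill_(0.5)  # pretend LoRA training happened
# In-memory stand-ins for the two checkpoint files.
pretrained_sd = {k: v for k, v in source.state_dict().items() if "lora_" not in k}
lora_sd = {k: v for k, v in source.state_dict().items() if "lora_" in k}

model = make_model()
# Each call loads what it can and reports the rest as missing.
r1 = model.load_state_dict(pretrained_sd, strict=False)
r2 = model.load_state_dict(lora_sd, strict=False)
print(sorted(r1.missing_keys), sorted(r2.missing_keys))

strict_failed = False
try:
    make_model().load_state_dict(pretrained_sd)  # strict=True is the default
except RuntimeError:
    strict_failed = True
print(strict_failed)
```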
While our examples focus on a simple yet effective setup, namely adapting only the q and v projections in a Transformer, LoRA can be applied to any subset of pre-trained weights. We encourage you to explore different configurations, such as adapting the embedding layer by replacing nn.Embedding with lora.Embedding and/or adapting the MLP layers. The optimal configuration very likely varies across model architectures and tasks.
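One way to try different configurations without rewriting a model class is to swap selected nn.Linear submodules by name. The sketch below is hypothetical throughout: LowRankWrapper is a minimal from-scratch stand-in (not lora.Linear), and the names q_proj, k_proj, and v_proj are assumptions about how a model names its attention projections. The targets set is the knob you would vary.

```python
import torch
import torch.nn as nn


class LowRankWrapper(nn.Module):
    """Hypothetical stand-in for a LoRA layer: frozen nn.Linear plus a rank-r update."""

    def __init__(self, base: nn.Linear, r: int = 8):
        super().__init__()
        self.base = base.requires_grad_(False)
        self.lora_A = nn.Parameter(torch.randn(r, base.in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(base.out_features, r))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + x @ self.lora_A.T @ self.lora_B.T


def adapt_linear_layers(model: nn.Module, targets: set, r: int = 8) -> list:
    """Wrap every nn.Linear whose attribute name is in `targets`; return the full names."""
    replaced = []
    # Snapshot the module tree first so the wrappers added below are not revisited.
    for parent_name, parent in list(model.named_modules()):
        for child_name, child in list(parent.named_children()):
            if isinstance(child, nn.Linear) and child_name in targets:
                setattr(parent, child_name, LowRankWrapper(child, r))
                replaced.append(f"{parent_name}.{child_name}" if parent_name else child_name)
    return replaced


class ToyAttention(nn.Module):
    """Hypothetical attention block with separate q/k/v projections."""

    def __init__(self, d: int = 16):
        super().__init__()
        self.q_proj = nn.Linear(d, d)
        self.k_proj = nn.Linear(d, d)
        self.v_proj = nn.Linear(d, d)

    def forward(self, x: torch.Tensor):
        return self.q_proj(x), self.k_proj(x), self.v_proj(x)


model = nn.ModuleList([ToyAttention(), ToyAttention()])
replaced = adapt_linear_layers(model, {"q_proj", "v_proj"}, r=4)
print(replaced)  # ['0.q_proj', '0.v_proj', '1.q_proj', '1.v_proj']
```

Adding "k_proj" or the names of MLP layers to targets is all it takes to try a different configuration in this sketch.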
Some Transformer implementations use a single nn.Linear for the query, key, and value projection matrices. If one wishes to constrain the rank of the updates to the individual matrices, one has to either break it up into three separate matrices or use lora.MergedLinear. Make sure to modify the checkpoint accordingly if you choose to break up the layer.
# ===== Before =====
# qkv_proj = nn.Linear(d_model, 3*d_model)
# ===== After =====
# Break it up (remember to modify the pretrained checkpoint accordingly)
q_proj = lora.Linear(d_model, d_model, r=8)
k_proj = nn.Linear(d_model, d_model)
v_proj = lora.Linear(d_model, d_model, r=8)
# Alternatively, use lora.MergedLinear (recommended)
qkv_proj = lora.MergedLinear(d_model, 3*d_model, r=8, enable_lora=[True, False, True])
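The enable_lora=[True, False, True] argument gives the fused projection low-rank updates on its query and value slices only, leaving the key slice untouched. The following from-scratch sketch reproduces that effect on a fused (3·d_model) × d_model weight; it is not loralib's implementation, and FusedQKVWithLoRA and its parameter layout are hypothetical.

```python
import torch
import torch.nn as nn


class FusedQKVWithLoRA(nn.Module):
    """Hypothetical sketch: one fused qkv nn.Linear, LoRA only on enabled slices."""

    def __init__(self, d_model: int, r: int, enable_lora=(True, False, True)):
        super().__init__()
        self.d_model = d_model
        self.qkv = nn.Linear(d_model, 3 * d_model).requires_grad_(False)
        # One (A, B) pair per enabled slice; disabled slices get no parameters.
        self.lora_A = nn.ParameterDict()
        self.lora_B = nn.ParameterDict()
        for i, enabled in enumerate(enable_lora):
            if enabled:
                self.lora_A[str(i)] = nn.Parameter(torch.randn(r, d_model) * 0.01)
                self.lora_B[str(i)] = nn.Parameter(torch.zeros(d_model, r))

    def delta_weight(self) -> torch.Tensor:
        # Stack per-slice updates into a (3*d_model x d_model) matrix; disabled slices stay zero.
        blocks = []
        for i in range(3):
            key = str(i)
            if key in self.lora_A:
                blocks.append(self.lora_B[key] @ self.lora_A[key])
            else:
                blocks.append(self.qkv.weight.new_zeros(self.d_model, self.d_model))
        return torch.cat(blocks, dim=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.qkv(x) + x @ self.delta_weight().T


torch.manual_seed(0)
d_model = 8
layer = FusedQKVWithLoRA(d_model, r=2, enable_lora=(True, False, True))
with torch.no_grad():
    for p in layer.lora_B.values():
        p.normal_()  # pretend training has moved B away from zero
x = torch.randn(5, d_model)
q, k, v = layer(x).split(d_model, dim=-1)
q0, k0, v0 = layer.qkv(x).split(d_model, dim=-1)
print(torch.allclose(k, k0), torch.allclose(q, q0))  # True False
```

The key output matches the frozen projection exactly, while the query and value outputs carry the low-rank updates.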
loralib also makes it easy to train bias vectors alongside the LoRA parameters. You can mark some biases as trainable by passing "all" or "lora_only" to bias= when calling mark_only_lora_as_trainable. Remember to pass the corresponding bias= argument to lora_state_dict when saving a checkpoint.
# ===== Before =====
# lora.mark_only_lora_as_trainable(model) # Not training any bias vectors
# ===== After =====
# Training all bias vectors associated with modules we apply LoRA to
lora.mark_only_lora_as_trainable(model, bias='lora_only')
# Alternatively, we can train *all* bias vectors in the model, including LayerNorm biases
lora.mark_only_lora_as_trainable(model, bias='all')
# When saving a checkpoint, use the same bias= ('all' or 'lora_only')
torch.save(lora.lora_state_dict(model, bias='all'), checkpoint_path)
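To see which parameters each bias= setting leaves trainable, here is a hypothetical pure-Python sketch over parameter names. It follows the comments in the code above ("lora_only" covers biases of modules LoRA is applied to, "all" covers every bias, including LayerNorm biases), but it infers which modules carry LoRA from their parameter names, which is an assumption made for illustration; trainable_names and the example names are hypothetical.

```python
def trainable_names(param_names, bias="none"):
    """Hypothetical sketch of which parameters stay trainable for each bias= mode."""
    if bias not in ("none", "lora_only", "all"):
        raise ValueError(f"unknown bias mode: {bias!r}")
    # Module prefixes that own LoRA matrices, e.g. "attn.q" for "attn.q.lora_A".
    lora_modules = {n.rsplit(".", 1)[0] for n in param_names if n.endswith("lora_A")}
    keep = []
    for name in param_names:
        prefix, _, leaf = name.rpartition(".")
        if "lora_" in name:
            keep.append(name)  # LoRA matrices are always trainable
        elif bias == "all" and leaf == "bias":
            keep.append(name)  # every bias, including LayerNorm biases
        elif bias == "lora_only" and leaf == "bias" and prefix in lora_modules:
            keep.append(name)  # only biases of modules that carry LoRA matrices
    return keep


names = [
    "attn.q.weight", "attn.q.bias", "attn.q.lora_A", "attn.q.lora_B",
    "attn.k.weight", "attn.k.bias",
    "ln.weight", "ln.bias",
]
for mode in ("none", "lora_only", "all"):
    print(mode, trainable_names(names, mode))
```

Whatever mode is used for training, saving with the same bias= value keeps those bias vectors in the checkpoint.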
Calling model.eval() will trigger the merging of LoRA parameters with the corresponding pre-trained ones, which eliminates additional latency for subsequent forward passes. Calling model.train() again will undo the merge. This can be disabled by passing merge_weights=False to LoRA layers.
Please contact us or post an issue if you have any questions.
For questions related to the package loralib:
The GPT-2 example:
The RoBERTa/DeBERTa example:
We thank in alphabetical order Jianfeng Gao, Jade Huang, Jiayuan Huang, Lisa Xiang Li, Xiaodong Liu, Yabin Liu, Benjamin Van Durme, Luis Vargas, Haoran Wei, Peter Welinder, and Greg Yang for providing valuable feedback.
@inproceedings{
hu2022lora,
title={Lo{RA}: Low-Rank Adaptation of Large Language Models},
author={Edward J Hu and Yelong Shen and Phillip Wallis and Zeyuan Allen-Zhu and Yuanzhi Li and Shean Wang and Lu Wang and Weizhu Chen},
booktitle={International Conference on Learning Representations},
year={2022},
url={https://openreview.net/forum?id=nZeVKeeFYf9}
}
This project welcomes contributions and suggestions. Most contributions require you to agree to a Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
When you submit a pull request, a CLA bot will automatically determine whether you need to provide a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions provided by the bot. You will only need to do this once across all repos using our CLA.
This project has adopted the Microsoft Open Source Code of Conduct. For more information see the Code of Conduct FAQ or contact opencode@microsoft.com with any additional questions or comments.