Maximal Update Parametrization (μP) and Hyperparameter Transfer (μTransfer)
In Tensor Programs V: Tuning Large Neural Networks via Zero-Shot Hyperparameter Transfer, we show that hyperparameters become stable across neural network sizes when we parametrize the model in maximal update parametrization (μP).
This can be used to tune extremely large neural networks such as large pretrained transformers, as we have done in our work.
More generally, μP reduces the fragility and uncertainty when transitioning from exploration to scaling up, which are not often talked about explicitly in the deep learning literature.
Figure above: Training loss against learning rate on Transformers of varying d_model trained with Adam.
μP turns out to be the unique "natural" parametrization that has this hyperparameter stability property across width, as empirically verified in the gif below on MLPs trained with SGD. Here, across time, we interpolate between PyTorch default and μP's learning rate and initialization scalings (right), and we scale up the width-256 model (log2(width)=8) to width 2^13 = 8192 using this interpolated scaling rule (left).
This repo contains the source code for the mup package, our tool that makes the implementation of μP in PyTorch models effortless and less error-prone.
Installation
pip install mup
Install From Source
Clone this repo, change to its directory, and do
pip install -r requirements.txt
pip install -e .
Basic Usage
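import torch.nn as nn
import mup
from mup import MuReadout, make_base_shapes, set_base_shapes, MuSGD, MuAdam

class MyModel(nn.Module):
    def __init__(self, width, ...):
        ...
        # Replace the output layer (e.g. nn.Linear(width, d_out)) with MuReadout
        readout = MuReadout(width, d_out)
        ...
    def forward(self, ...):
        ...
        # In a Transformer, use 1/d instead of 1/sqrt(d) attention scaling;
        # 8/d coincides with 1/sqrt(d) at d=64
        attention_scores = query @ key.T * 8 / d
        ...

# Base and delta models only provide parameter shapes; the delta model must
# differ from the base model in every dimension ("width") you want to scale
base_model = MyModel(width=1)
delta_model = MyModel(width=2)
# The target model you actually want to train
model = MyModel(width=100)
# Set base shapes before any re-initialization or optimizer creation
set_base_shapes(model, base_model, delta=delta_model)
# Replace torch.nn.init functions with their mup.init counterparts
for param in model.parameters():
    mup.init.uniform_(param, -0.1, 0.1)
# Use the optimizers from mup instead of torch.optim
optimizer = MuSGD(model.parameters(), lr=0.1)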
Note that the base and delta models do not need to be trained; we are only extracting parameter shape information from them.
Ideally, we could do so without instantiating the model parameters at all, like in JAX, but unfortunately we currently can't do that in PyTorch.
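The attention line in the usage snippet multiplies the logits by 8 / d. As a minimal sketch of what that constant does, the following compares it with the conventional 1/sqrt(d) scaling. The function names are illustrative and not part of mup.

```python
import math

def sp_attention_scale(d):
    """Conventional 1/sqrt(d) attention scaling."""
    return 1.0 / math.sqrt(d)

def mup_attention_scale(d):
    """The 8/d scaling from the usage snippet (a 1/d rule with constant 8)."""
    return 8.0 / d

# Print both scalings for a few head dimensions
for d in (64, 256, 1024):
    print(d, sp_attention_scale(d), mup_attention_scale(d))
```

The two scalings agree at d = 64, where both equal 0.125, and 8 / d shrinks faster than 1/sqrt(d) for larger d.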
How mup Works Under the Hood
By invoking set_base_shapes(model, ...), each parameter tensor p of model gets a p.infshape attribute that stores, for each of its dimensions, the corresponding base dimension and whether that dimension should be considered infinite (i.e. will be scaled up/down, e.g., d_model of a Transformer) or finite (i.e. will be fixed, e.g., vocabulary size).
This information is used in the initializers and optimizers to automatically scale the parameters or learning rates to be compliant with μP.
For example, the Adam learning rate of hidden weights p is calculated as globalLR / p.infshape.width_mult(), where p.infshape.width_mult() essentially calculates fan_in / base_fan_in.
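The hidden-weight Adam rule above can be sketched in plain Python. This is only an illustration of the arithmetic, not the mup implementation; the helper names width_mult and hidden_adam_lr are hypothetical stand-ins for what p.infshape.width_mult() and the optimizer do internally.

```python
def width_mult(fan_in, base_fan_in):
    """Ratio of the actual fan-in to the base model's fan-in (hypothetical helper)."""
    return fan_in / base_fan_in

def hidden_adam_lr(global_lr, fan_in, base_fan_in):
    """Per-parameter Adam learning rate for a hidden weight: globalLR / width_mult."""
    return global_lr / width_mult(fan_in, base_fan_in)

# A hidden weight whose fan-in is 4x the base fan-in gets 1/4 of the global LR
print(hidden_adam_lr(1e-3, fan_in=1024, base_fan_in=256))
```

When fan_in equals base_fan_in the multiplier is 1, so the learning rate is unchanged at the base width.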
Current Limitations
- set_base_shapes(model, ...) assumes that model has just been randomly initialized in the standard way and rescales its parameters using the base shape information so the model is in μP.
- If you want data parallelism, please use torch.nn.parallel.DistributedDataParallel instead of torch.nn.DataParallel. This is because the latter removes the attributes the mup package adds to each parameter tensor of the model. Also, for performance, PyTorch recommends the former anyway.
- We scale the learning rate according to μP by explicitly creating refined parameter groups from what is passed to the mup optimizer and by manipulating the lr attribute in those groups. This means that if your code modifies the lr in the optimizer param_groups dynamically after the creation of the optimizer, then mup might not work as expected.
- By default, any parameter matrix that has 2 "infinite" dimensions (i.e. dimensions that are different from base dimensions) is considered by mup to have shape (fan_out, fan_in), i.e., in the forward pass, this matrix multiplies its input on the right. This is the case with all nn.Linear weights from PyTorch. If you have a custom parameter, say W, that violates this convention, you can manually set W.infshape.main_idx = 0; W.infshape.main = W.infshape[0] to let mup know that its shape corresponds to (fan_in, fan_out). A similar discussion applies if you have a parameter tensor with many dimensions but exactly 2 "infinite" dimensions, for which the first is fan_in and the second is fan_out.
- Currently, torch.save does not save the infshape objects attached to each parameter tensor. Before this is fixed, you would have to set base shape manually after loading a model checkpoint like so:
model = torch.load('my/model/path.pt')
set_base_shapes(model, 'my/base/shape/path.bsh', rescale_params=False)
(set_base_shapes by default rescales the parameters of model, assuming it's freshly initialized by PyTorch, to be consistent with μP.
The rescale_params=False flag turns off this behavior.)
Checking Correctness of Parametrization
Coord Check
Just like gradient checking is a simple way of verifying the correctness of an autograd implementation, coordinate checking is a simple way to verify you have implemented μP correctly: calculate the average size (which we denote on the y-axis below by l1) of the coordinates of each activation vector in, and output of, the model, for a few steps of training and a few different widths.
If implemented correctly, this l1 should stay stable over many widths; otherwise, the l1 can blow up or shrink to 0 with width.
(We are essentially checking desideratum 1 described below.)
(The l1 calculates x.abs().mean() for each activation vector x and is just one measure of the "average size" of x's entries; one can also use analogously defined l2, l4, etc., though they may exhibit greater fluctuation with random seeds.)
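As a minimal sketch of these size measures, the function below computes l1 as the mean absolute coordinate, matching x.abs().mean(). The generalization to other p as (mean |x|^p)^(1/p) is an assumption about what "analogously defined" means here, and coord_size is a hypothetical name, not a mup function.

```python
def coord_size(x, p=1):
    """Average coordinate size of x: (mean(|x_i|^p))^(1/p).

    p=1 reproduces x.abs().mean(); the p>1 form is an assumed analogue.
    """
    if not x:
        raise ValueError("x must be non-empty")
    return (sum(abs(v) ** p for v in x) / len(x)) ** (1.0 / p)

activations = [1.0, -1.0, 2.0, -2.0]
print(coord_size(activations, p=1))
print(coord_size(activations, p=2))
print(coord_size(activations, p=4))
```

Under this definition, higher p weights large coordinates more heavily, which is consistent with the remark that l2 and l4 may fluctuate more across random seeds.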
For example, in the following, we plot width vs l1 for 2 steps of training, where t=1 means at initialization, before any gradient update.
Each curve corresponds to an (pre-)activation vector of a layer or the output of the network.
The first set of 3 plots shows an MLP in standard parametrization (SP), trained by Adam.
We see that after 1 step of update, the activation/output l1 values explode with width.
This means SP is "incorrect."
We now do the same for an MLP in maximal update parametrization (μP) (including using mup.optim.MuAdam instead of torch.optim.Adam).
In contrast to the above, all curves stay horizontal, indicating that μP is implemented correctly.
We call this way of checking implementation correctness a coord check, short for "coordinate check."
Making Your Own Coord Check Plots
We provide an easy way to implement this check via functions in the mup.coord_check module.
The workflow typically looks like the following.
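from mup import set_base_shapes
from mup.coord_check import get_coord_data, plot_coord_data

# Build a dict of lazily constructed μP models of differing widths
def lazy_model(width):
    # set_base_shapes returns the model
    return lambda: set_base_shapes(MyMuModel(width), 'my/base/shape/path.bsh')
models = {64: lazy_model(64), ..., 1024: lazy_model(1024)}
# A dataloader with a small batch size is enough for this check
dataloader = ...
# Record activation statistics over a few steps of training
df = get_coord_data(models, dataloader)
# Save the coord check plots to filename
plot_coord_data(df, save_to=filename)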
For example, the mup.coord_check.example_plot_coord_check function is implemented this way for toy MLP and CNN models.
If you see the curves blow up or shrink to 0 with width after a few steps of training, then there's a bug in your μP implementation (did you forget to vary some dimension, like d_ffn, in the delta model?).
If instead you see the curves converge as width grows (toward the right of the plot), then most likely your implementation is correct.
However, there are two typical exceptions to this; the following can shrink to 0 at initialization in μP (at a 1/sqrt(width) rate):
- the network output
- the attention logits in a Transformer
These are transient, and after a few steps their curves should be roughly flat.
Nevertheless, to remove the discrepancy at init, we recommend
- initializing the output layer (should be a MuReadout instance) weights to be 0 via the readout_zero_init=True option and
- initializing the query matrix in a Transformer to 0 (this has to be done manually). If symmetry-breaking is desired in the attention logits at init, initialize the (relative) position biases with nonzero variance.
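One hedged way to turn "flat versus blowing up or shrinking" into a number is to fit the slope of log2(l1) against log2(width) for each curve at a fixed training step. The helper below is a hypothetical sketch and is not part of mup.coord_check. It assumes all sizes are strictly positive, so exact zeros (e.g. the output at init after zero-initializing the readout) must be excluded before calling it.

```python
import math

def loglog_slope(widths, sizes):
    """Least-squares slope of log2(size) vs log2(width).

    Near 0: size is stable in width. About +0.5 or +1: blowing up.
    About -0.5: shrinking at a 1/sqrt(width) rate.
    Assumes every width and size is strictly positive.
    """
    xs = [math.log2(w) for w in widths]
    ys = [math.log2(s) for s in sizes]
    mean_x = sum(xs) / len(xs)
    mean_y = sum(ys) / len(ys)
    num = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    den = sum((x - mean_x) ** 2 for x in xs)
    return num / den

widths = [64, 128, 256, 512, 1024]
# Synthetic curves for illustration only, not measured data
print(loglog_slope(widths, [1.0] * len(widths)))
print(loglog_slope(widths, [w ** -0.5 for w in widths]))
```

A slope close to -0.5 at initialization for the network output or attention logits would match the two transient exceptions listed above, while a clearly nonzero slope after a few steps would point to a bug.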
Wider is Always Better
Another sign that μP has not been implemented correctly is if going wider does worse (on training loss) after some width, at some point during training.
The figure above illustrates this in a collection of training curves: (left) the correct implementation should always see performance improve with width, at any point in training; (middle) if you used standard parametrization (SP), sometimes you may see performance improve with width up to some point and then suddenly it becomes worse with wider models; (right) or you may immediately see worsening performance even for narrow models.
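The "wider is always better" criterion can also be checked programmatically once training curves for several widths are logged. The sketch below is a hypothetical helper, not a mup function: it takes a dict mapping width to a list of per-step training losses and reports the first place where a wider model is worse than the next narrower one. The tolerance parameter is an assumption added to absorb seed-to-seed noise.

```python
def first_width_regression(losses_by_width, tol=0.0):
    """Return (narrow, wide, step) for the first step at which the wider of two
    adjacent widths has a higher training loss than the narrower one by more
    than tol, or None if wider is never worse."""
    widths = sorted(losses_by_width)
    for narrow, wide in zip(widths, widths[1:]):
        pairs = zip(losses_by_width[narrow], losses_by_width[wide])
        for step, (narrow_loss, wide_loss) in enumerate(pairs):
            if wide_loss > narrow_loss + tol:
                return narrow, wide, step
    return None

# Made-up loss curves, for illustration only
curves = {
    256: [2.0, 1.5, 1.2],
    512: [2.0, 1.4, 1.1],
    1024: [1.9, 1.45, 1.0],
}
print(first_width_regression(curves))
```

Here the width-1024 curve is worse than the width-512 curve at step 1, so the helper returns that pair and step; only adjacent widths are compared, which is a simplification.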
Examples
See the MLP, Transformer, and ResNet folders inside examples/ as well as the tests in mup/test for examples.
People familiar with Huggingface Transformers may also find the examples/mutransformers submodule instructive (obtained via git submodule update --init), which is also available standalone at https://github.com/microsoft/mutransformers.
Running Tests
To run tests, do
python -m mup.test
The Basic Math
μP is designed so as to satisfy the following desiderata:
At any time during training:
1. Every (pre)activation vector in a network should have Θ(1)-sized coordinates.
2. Neural network output should be O(1).
3. All parameters should be updated as much as possible (in terms of scaling in width) without leading to divergence.
It turns out these desiderata uniquely single out μP.
To derive μP from them, one needs to carefully consider how the coordinate size of a vector Av, resulting from a square matrix A multiplying vector v, depends on those of A and v, when A and v are "correlated".
Here you can think of A as weights and v as an activation vector.
This in turn depends on what kind of matrix A is and what kind of vector v is.
In the context of training a wide neural network, it turns out we only need to consider vectors that have approximately iid coordinates, and two kinds of matrices: 1) those that look like outer products of such vectors, and 2) random iid matrices.
Those of type 1 cover things like weight gradients; those of type 2 cover things like weight initialization.
Then, if A and v both have entry size Θ(1) and they are correlated in ways that arise naturally during training, then we have the following table.
| | outer product A (type 1) | iid A (type 2) |
|---|---|---|
| Entry size of Av | Θ(n) | Θ(sqrt(n)) |
Given this table, one can then trace the forward and backward computation of a network to derive μP straightforwardly.
See our blog post for a gentle primer and our paper for details.
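The two columns of the table can be illustrated numerically with a small pure-Python sketch. It makes simplifying assumptions that are not from the text above: all entries are ±1 (so entry size is exactly 1), the outer-product case uses A = u v^T applied to the same v (the fully correlated case), and the iid case draws A independently of v. The function names are illustrative.

```python
import random

def l1(vec):
    """Mean absolute coordinate."""
    return sum(abs(x) for x in vec) / len(vec)

def rademacher(n, rng):
    """Vector of n iid +/-1 entries."""
    return [rng.choice((-1.0, 1.0)) for _ in range(n)]

def outer_product_av(n, rng):
    """Av for A = u v^T and the same v: (u v^T) v = u * (v . v)."""
    u = rademacher(n, rng)
    v = rademacher(n, rng)
    vv = sum(x * x for x in v)
    return [ui * vv for ui in u]

def iid_av(n, rng):
    """Av for an n x n matrix A of iid +/-1 entries, built one row at a time."""
    v = rademacher(n, rng)
    return [sum(a * x for a, x in zip(rademacher(n, rng), v)) for _ in range(n)]

rng = random.Random(0)
for n in (16, 64, 256):
    print(n, l1(outer_product_av(n, rng)), l1(iid_av(n, rng)))
```

With ±1 entries the outer-product column is exactly n, since v . v = n, while the iid column grows roughly like sqrt(n) because each coordinate of Av is a sum of n independent signs.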
Contributing
This project welcomes contributions and suggestions. Most contributions require you to agree to a
Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
When you submit a pull request, a CLA bot will automatically determine whether you need to provide
a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
provided by the bot. You will only need to do this once across all repos using our CLA.
This project has adopted the Microsoft Open Source Code of Conduct.
For more information see the Code of Conduct FAQ or
contact opencode@microsoft.com with any additional questions or comments.
Trademarks
This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft
trademarks or logos is subject to and must follow
Microsoft's Trademark & Brand Guidelines.
Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship.
Any use of third-party trademarks or logos is subject to those third parties' policies.