# Muon: An optimizer for the hidden layers of neural networks
This repo contains an implementation of the Muon optimizer, originally described in this thread and [this writeup](https://kellerjordan.github.io/posts/muon/).
## Installation

```bash
pip install git+https://github.com/KellerJordan/Muon
```
## Usage
Muon is intended to optimize only the internal ≥2D parameters of a network.
Embeddings, classifier heads, and internal gains/biases should be optimized using AdamW.
```python
from muon import MuonWithAuxAdam

hidden_weights = [p for p in model.body.parameters() if p.ndim >= 2]
hidden_gains_biases = [p for p in model.body.parameters() if p.ndim < 2]
exterior_weights = [*model.head.parameters(), *model.embed.parameters()]
muon_group = dict(params=hidden_weights, lr=0.02, weight_decay=0.01, use_muon=True)
adam_group = dict(params=hidden_gains_biases+exterior_weights, lr=3e-4,
                  betas=(0.9, 0.95), weight_decay=0.01, use_muon=False)
optimizer = MuonWithAuxAdam([muon_group, adam_group])
```
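The resulting optimizer follows the standard `torch.optim` interface, so the training step looks the same as with any other PyTorch optimizer. A minimal sketch, where `model`, `dataloader`, and `loss_fn` are placeholders rather than part of this repo:

```python
# Standard PyTorch training step; MuonWithAuxAdam is a drop-in
# torch.optim-style optimizer, so no special handling is needed.
for inputs, targets in dataloader:
    optimizer.zero_grad()
    loss = loss_fn(model(inputs), targets)
    loss.backward()
    optimizer.step()
```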
You'll have to replace `model.body`, `model.head`, and `model.embed` with whatever subset is appropriate for your model.
E.g., for a ConvNet, Muon should optimize all the convolutional filters except the first one, and AdamW should optimize everything else.
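As a concrete illustration, here is a hedged sketch of that split for a hypothetical ConvNet. The `model.conv1`, `model.features`, and `model.classifier` attributes are made-up names for this example; adapt them to your architecture:

```python
from muon import MuonWithAuxAdam

# Hypothetical ConvNet layout: model.conv1 is the first conv layer,
# model.features holds the remaining conv blocks, model.classifier is the head.
first_conv = list(model.conv1.parameters())
hidden_convs = [p for p in model.features.parameters() if p.ndim >= 2]
gains_biases = [p for p in model.features.parameters() if p.ndim < 2]
adam_params = first_conv + gains_biases + list(model.classifier.parameters())

optimizer = MuonWithAuxAdam([
    dict(params=hidden_convs, lr=0.02, weight_decay=0.01, use_muon=True),
    dict(params=adam_params, lr=3e-4, betas=(0.9, 0.95),
         weight_decay=0.01, use_muon=False),
])
```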
## Example usage

* Example use in the NanoGPT speedrun
* Example use in the CIFAR-10 speedrun
## Hyperparameter tuning
Typically, the default values of momentum (0.95), nesterov (True), and ns_steps (5) work well. The only hyperparameter that must be tuned is the learning rate.
It should exhibit constant muP scaling; that is, as you scale up the model size, you shouldn't need to retune the learning rate.
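If you do want to make the momentum explicit or adjust it, one possibility is to pass it as a per-group option. This is a sketch only: it assumes the muon group accepts a `momentum` key, so check the implementation for the exact options it supports:

```python
# Same grouping as above, but with momentum spelled out
# (assumed to be a supported group key).
muon_group = dict(params=hidden_weights, lr=0.02, momentum=0.95,
                  weight_decay=0.01, use_muon=True)
```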
## Benchmarks
For a comparison between AdamW, Shampoo, SOAP, and Muon for training a 124M-parameter transformer, see here.
## Accomplishments
## More learning resources and results about Muon
## Citation
```bibtex
@misc{jordan2024muon,
  author = {Keller Jordan and Yuchen Jin and Vlado Boza and You Jiacheng and
            Franz Cesista and Laker Newhouse and Jeremy Bernstein},
  title  = {Muon: An optimizer for hidden layers in neural networks},
  year   = {2024},
  url    = {https://kellerjordan.github.io/posts/muon/}
}
```