Schedule-Free Optimizers in PyTorch
Preprint: The Road Less Scheduled
Authors: Aaron Defazio, Xingyu (Alice) Yang, Harsh Mehta, Konstantin Mishchenko, Ahmed Khaled, Ashok Cutkosky
TL;DR: Faster training without schedules - no need to specify the stopping time/steps in advance!
pip install schedulefree
We provide several Schedule-Free optimizer implementations:
- SGDScheduleFree and SGDScheduleFreeReference: Schedule-free variants of SGD
- AdamWScheduleFree and AdamWScheduleFreeReference: Schedule-free variants of AdamW
- RAdamScheduleFree: Schedule-free variant of RAdam, which eliminates the need for both learning rate scheduling and warmup (implementation community contributed)
- ScheduleFreeWrapper: combines Schedule-Free learning with other base optimizers

The ScheduleFreeReference versions have a simplified implementation but use more memory. There are also ScheduleFreeClosure versions, which can be used with PyTorch's optimizer step closures.
A JAX implementation is available as part of Optax.
Schedule-Free learning replaces the momentum of an underlying optimizer with a combination of interpolation and averaging. In the case of gradient descent, the basic Schedule-Free update is:
$$ \begin{align*} y_{t} & = (1-\beta)z_{t} + \beta x_{t}, \\ z_{t+1} & = z_{t}-\gamma\nabla f(y_{t}), \\ x_{t+1} & = \left(1-\frac{1}{t+1}\right)x_{t}+\frac{1}{t+1}z_{t+1}. \end{align*} $$
Here $x$ is the sequence at which evaluations of test/val loss should occur, which differs from the primary iterates $z$ and the gradient evaluation locations $y$. The updates to $z$ correspond to the underlying optimizer, in this case a simple gradient step.
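To make the update concrete, here is a minimal sketch of the gradient-descent form above, written for clarity rather than efficiency (the function and variable names are ours; the actual schedulefree optimizers apply the update in-place to model parameters):

import torch

def schedule_free_sgd(f, x0, lr=0.1, beta=0.9, steps=100):
    x = x0.clone()  # averaged sequence: evaluate test/val loss at x
    z = x0.clone()  # primary iterates: updated by the base optimizer step
    for t in range(1, steps + 1):
        # Interpolate to get y, the point where gradients are evaluated.
        y = ((1 - beta) * z + beta * x).requires_grad_(True)
        grad, = torch.autograd.grad(f(y), y)
        z = z - lr * grad                  # base optimizer step (plain gradient step)
        c = 1.0 / (t + 1)
        x = (1 - c) * x + c * z            # running average of the z sequence
    return x

# Example: minimize f(w) = ||w||^2 from a random starting point.
w_final = schedule_free_sgd(lambda w: (w ** 2).sum(), torch.randn(10))

The sketch keeps all three sequences explicit for readability; as noted below, only two actually need to be stored.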
As the name suggests, Schedule-Free learning does not require a decreasing learning rate schedule, yet typically outperforms, or at worst matches, SOTA schedules such as cosine decay and linear decay. Only two sequences need to be stored at a time (the third can be computed from the other two on the fly), so this method has the same memory requirements as the base optimizer (parameter buffer + momentum).
We provide both AdamW and SGD versions in this repo, as well as an experimental wrapper version that can be used with any base optimizer.
Since our optimizer uses two different points for gradient calls and test/val loss calculations, it's necessary to switch the parameter buffer between the two during training. This is done by calling optimizer.train() at the same place you call model.train() and optimizer.eval() at the same place you call model.eval(). The optimizer should also be placed in eval mode when storing checkpoints.
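To illustrate the mode switching, here is a minimal training-loop sketch using AdamWScheduleFree. The toy model, the data loaders (train_loader, val_loader), and the hyperparameter values are placeholders of ours, not recommendations from the package:

import torch
import torch.nn.functional as F
import schedulefree

model = torch.nn.Linear(784, 10)                 # toy model for illustration
optimizer = schedulefree.AdamWScheduleFree(model.parameters(), lr=1e-3)

for epoch in range(10):
    model.train()
    optimizer.train()    # parameters switch to y, the gradient-evaluation point
    for inputs, targets in train_loader:
        optimizer.zero_grad()
        loss = F.cross_entropy(model(inputs), targets)
        loss.backward()
        optimizer.step()

    model.eval()
    optimizer.eval()     # parameters switch to x, the averaged point used for eval and checkpoints
    with torch.no_grad():
        for inputs, targets in val_loader:
            val_loss = F.cross_entropy(model(inputs), targets)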
If your code supports PyTorch optimizer step closures, you can use the closure forms of the optimizers, which do not require the .train() and .eval() calls.
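The closure form follows the usual PyTorch closure pattern. A sketch, assuming the closure variants follow the same naming pattern as above (e.g. AdamWScheduleFreeClosure; check the package for the exact class names), with train_loader again a placeholder:

import torch
import torch.nn.functional as F
import schedulefree

model = torch.nn.Linear(784, 10)
optimizer = schedulefree.AdamWScheduleFreeClosure(model.parameters(), lr=1e-3)

for inputs, targets in train_loader:
    def closure():
        # The closure recomputes the loss and gradients; the optimizer calls it
        # internally, so no .train()/.eval() switching is required.
        optimizer.zero_grad()
        loss = F.cross_entropy(model(inputs), targets)
        loss.backward()
        return loss
    optimizer.step(closure)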
If you use Schedule-Free training in your work, please cite our preprint as:
@misc{defazio2024road,
title={The Road Less Scheduled},
author={Aaron Defazio and Xingyu Yang and Harsh Mehta and Konstantin Mishchenko and Ahmed Khaled and Ashok Cutkosky},
year={2024},
eprint={2405.15682},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
New: Version 1.4 adds an RAdam implementation contributed by nhamanasu.
Version 1.3 changes the behavior of weight decay during learning rate warmup to improve stability and be more consistent with the behavior of standard AdamW in PyTorch. The previous implementation is still available as AdamWScheduleFreePaper.
Examples of using the schedulefree package can be found in the examples folder. These include:
*Example is modified from the PyTorch Examples repo.
If your model uses BatchNorm, its running statistics are computed at $y$ during training while evaluation happens at $x$, so the statistics should be recalibrated before evaluating. One way to do this is to run a few forward passes with the model in train mode and the optimizer in eval mode:

import itertools

model.train()          # let BatchNorm update its statistics
optimizer.eval()       # but with the parameters set to the averaged point x
with torch.no_grad():
    for batch in itertools.islice(train_loader, 50):
        model(batch)
model.eval()
This will replace the training_mean/training_var cache (which is updated in each forward pass when in model.train() mode) with values calculated at $x$ instead of $y$. Using PreciseBN will also avoid this issue.
Learning rate warmup is supported via the warmup_steps parameter.
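For illustration (hyperparameter values are placeholders, not recommendations):

optimizer = schedulefree.AdamWScheduleFree(model.parameters(), lr=1e-3, warmup_steps=1000)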
We offer a highly experimental wrapper version, ScheduleFreeWrapper, which can wrap any base optimizer. When using this version, you can disable the base optimizer's momentum, as it's no longer necessary when using our wrapper's momentum (although you can use both types of momentum if you want).
Example usage:
import torch
from schedulefree import ScheduleFreeWrapper

# Wrap any base optimizer; the wrapper supplies the Schedule-Free momentum.
base_optimizer = torch.optim.RMSprop(model.parameters(), lr=0.0025)
optimizer = ScheduleFreeWrapper(
    base_optimizer, momentum=0.9, weight_decay_at_y=0.1)
If you set weight decay on the base optimizer, it computes weight decay at $z$. We offer the option to compute weight decay at $y$, via the weight_decay_at_y parameter, which seems to give better results in our experiments.
We also include a ScheduleFreeWrapperReference version, which uses more memory but is more numerically stable; we recommend this version for early experimentation or research work.
See the License file.
Schedule-Free learning can be seen as an interpolation between primal averaging ($\beta=1$) and Polyak-Ruppert averaging ($\beta=0$). The advantage of this interpolation is that it allows us to get the best of both worlds. We can achieve the fast early stage convergence of Polyak-Ruppert averaging (since the $z$ sequence moves quicker than the $x$ sequence), without the $x$ sequence straying too far from the $z$ sequence, which causes instability.
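Writing out the two endpoints from the update equations above makes this concrete (a restatement of those equations, not additional material from the preprint): with $\beta=0$ the gradient is taken at the fast-moving $z$ sequence, recovering averaged gradient descent (Polyak-Ruppert), while with $\beta=1$ it is taken at the average itself, recovering primal averaging:

$$ \beta=0:\; y_{t}=z_{t} \;\Rightarrow\; z_{t+1}=z_{t}-\gamma\nabla f(z_{t}), \qquad \beta=1:\; y_{t}=x_{t} \;\Rightarrow\; z_{t+1}=z_{t}-\gamma\nabla f(x_{t}), $$

with the same averaging step $x_{t+1}=\left(1-\frac{1}{t+1}\right)x_{t}+\frac{1}{t+1}z_{t+1}$ in both cases.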
Our method is also related to Nesterov's accelerated method (Nesterov, 1983) in AC-SA form (Ghadimi & Lan 2010):
$$ \begin{align*} y_{t} & = (1-2/(t+1))x_{t} + (2/(t+1))z_{t}, \\ z_{t+1} & = z_{t}-\frac{t}{2L}\nabla f(y_{t}), \\ x_{t+1} & = (1-2/(t+1))x_{t}+(2/(t+1))z_{t+1}. \end{align*} $$
Our approach has the same three sequences, but uses very different weights, and crucially, does not include an increasing learning rate over time, which is essential for accelerated rates with Nesterov's method. We also use different weight sequences for the interpolation operation versus the averaging operation.
Tail averaging approaches such as Stochastic Weight Averaging (Izmailov et al., 2018) and LAtest Weight Averaging (Kaddour, 2022; Sanyal et al., 2023) combine averaging with large or cyclic learning rates. They still require the use of a schedule, introduce additional hyper-parameters to tune, and require additional memory compared to our technique. It is also possible to use SWA and LAWA on top of our approach, potentially giving further gains.
Portes et al. (2022) use cyclic learning rate schedules with increasing cycle periods to give a method that explores multiple points along the Pareto frontier of training time vs eval performance. Each point at the end of a cycle is an approximation to the model from a tuned schedule ending at that time. Our method gives the entire frontier, rather than just a few points along the path.
Exponential moving averages (EMA) of the iterate sequence are used in the popular Lookahead optimizer (Zhang et al., 2019). The Lookahead method can be seen as the EMA version of primal averaging, just as exponential weight averaging is the EMA version of Polyak-Ruppert averaging. Our extra interpolation step can potentially be used in combination with the lookahead optimizer also.