
Gradient Cache is a production-ready PyTorch extension that reduces GPU memory usage by 90%+ during neural network training through intelligent gradient compression and CPU offloading.
| Model | Parameters | Memory Saved | Compression |
|---|---|---|---|
| GPT-2 Small | 124M | 479 MB/step | 100x |
| GPT-2 Medium | 350M | ~1.3 GB/step | 100x |
| Custom NN | 50M | 144 MB/step | 100x |
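These numbers follow from simple arithmetic: FP32 gradients cost 4 bytes per parameter, so a 124M-parameter model carries roughly 500 MB of gradients per step, which lines up with the ~479 MB/step saved once the ~1% of entries that are kept (plus their indices) is subtracted. The sketch below illustrates the general idea of magnitude-based top-k compression with CPU offload; it is a simplified illustration, not the package's actual implementation.

```python
import torch

def compress(grad: torch.Tensor, ratio: int = 100):
    """Keep the largest-magnitude 1/ratio of entries and park them on the CPU."""
    flat = grad.flatten()
    k = max(1, flat.numel() // ratio)
    _, idx = torch.topk(flat.abs(), k)        # indices of the biggest entries
    vals = flat[idx]                          # their signed values
    return vals.cpu(), idx.cpu(), grad.shape  # the dense GPU copy can now be freed

def decompress(vals, idx, shape, device):
    """Scatter the kept entries back into a dense tensor of zeros."""
    flat = torch.zeros(shape, device=device).view(-1)
    flat[idx.to(device)] = vals.to(device)
    return flat.view(shape)

# Round trip on a toy gradient: 100x compression keeps 10,000 of 1M entries
g = torch.randn(1000, 1000)
vals, idx, shape = compress(g, ratio=100)
g_hat = decompress(vals, idx, shape, device=g.device)
```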
Install from PyPI:

```bash
pip install gradient-cache
```
Or install from source:
```bash
git clone https://github.com/your-username/gradient-cache
cd gradient-cache
pip install -e .
```
Add Gradient Cache to any PyTorch training loop with just 3 lines:
```python
import torch
import gradient_cache

# Create your model
model = create_your_model().cuda()

# Add gradient cache (1 line)
hook_manager = gradient_cache.create_gradient_cache(model, compression_ratio=100)

# Normal training loop
optimizer = torch.optim.Adam(model.parameters())
for batch in dataloader:
    loss = model(batch).mean()
    loss.backward()

    # Compress gradients and free their GPU memory (1 line)
    hook_manager.compress_and_free_gradients()

    # Restore gradients and update (1 line)
    hook_manager.apply_gradients()
    optimizer.step()
    optimizer.zero_grad()
```
Use the decorator for automatic integration:
```python
import torch
import gradient_cache
from metaflow import FlowSpec, step

class MyTrainingFlow(FlowSpec):
    @step
    @gradient_cache.optimize(compression_ratio=100)
    def train(self):
        # Your training code - no changes needed!
        model = create_model()
        optimizer = torch.optim.Adam(model.parameters())
        # ... rest of training
```
The same hooks slot into PyTorch Lightning's lifecycle methods:

```python
import pytorch_lightning as pl
import gradient_cache

class MyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = create_model()
        self.hook_manager = gradient_cache.create_gradient_cache(self.model)

    def training_step(self, batch, batch_idx):
        loss = self.model(batch).mean()
        return loss

    def on_after_backward(self):
        # Gradients exist now (loss.backward() has run), so compress them
        self.hook_manager.compress_and_free_gradients()

    def optimizer_step(self, *args, **kwargs):
        # Restore gradients just before the weight update
        self.hook_manager.apply_gradients()
        super().optimizer_step(*args, **kwargs)
```
The compression ratio controls how much of each gradient survives; tune it to trade memory savings against gradient fidelity:

```python
# Conservative - 10x compression (keep 10%)
hook_manager = gradient_cache.create_gradient_cache(model, compression_ratio=10)

# Aggressive - 1000x compression (keep 0.1%)
hook_manager = gradient_cache.create_gradient_cache(model, compression_ratio=1000)
```
```python
# Don't compress embeddings or output layers
hook_manager = gradient_cache.GradientCacheHookManager(
    model,
    compression_ratio=100,
    exclude_layers=['embedding', 'lm_head'],
)
```
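Excluding these layers is a common precaution: embedding and output-projection gradients tend to be sparse or disproportionately important to convergence, so keeping them uncompressed trades a small amount of memory for training stability.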
```python
# Enable verbose mode
hook_manager = gradient_cache.create_gradient_cache(model, verbose=True)

# Get compression statistics
stats = hook_manager.get_compression_summary()
print(f"Compression ratio: {stats['overall_compression_ratio']:.1f}x")
print(f"Memory saved: {stats['memory_saved_mb']:.1f} MB")
```
Run the test suite:

```bash
python tests/test_gradient_cache.py
```
If you use Gradient Cache in your research, please cite:
```bibtex
@software{gradient_cache,
  title  = {Gradient Cache: GPU Memory-Efficient Training},
  author = {Gradient Cache Contributors},
  year   = {2024},
  url    = {https://github.com/gradient-cache/gradient-cache}
}
```
Apache License 2.0 - see LICENSE for details.
We welcome contributions! Please submit issues and pull requests on GitHub.
Built with ❤️ for the ML community