
torchdistill
A Modular, Configuration-Driven Framework for Knowledge Distillation. Trained models, training logs, and configurations are available to ensure reproducibility.
torchdistill (formerly kdkit) offers various state-of-the-art knowledge distillation methods and enables you to design (new) experiments simply by editing a declarative YAML config file instead of Python code. Even when you need to extract intermediate representations from teacher/student models, you do NOT need to reimplement the models, which often requires changing the interface of their forward functions; instead, you specify the module path(s) in the YAML file. Refer to these papers for more details.
In addition to knowledge distillation, this framework helps you design and perform general deep learning experiments (WITHOUT coding) for reproducible deep learning studies. For instance, it enables you to train models without teachers simply by excluding teacher entries from a declarative YAML config file. You can find such examples below and in configs/sample/.
In December 2023, torchdistill officially joined PyTorch Ecosystem.
When you refer to torchdistill in your paper, please cite these papers instead of this GitHub repository.
If you use torchdistill as part of your work, your citation is appreciated and motivates me to maintain and upgrade this framework!
You can find the API documentation and research projects that leverage torchdistill at https://yoshitomo-matsubara.net/torchdistill/
Using ForwardHookManager, you can extract intermediate representations from a model without modifying the interface of its forward function.
This example notebook will give you a better idea of the usage, such as knowledge distillation and analysis of intermediate representations.
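One practical note before the example below: the module paths passed to ForwardHookManager (or listed in a YAML config) are the dotted submodule names that PyTorch assigns automatically. A minimal sketch for discovering them in a torchvision ResNet-18 (this uses the weights=None API of torchvision >= 0.13; older versions use pretrained=False as in the example below):

from torchvision import models

# List the dotted module paths (e.g., 'conv1', 'layer1.0.bn2', 'fc') that
# ForwardHookManager calls and YAML configs refer to
model = models.resnet18(weights=None)
for name, module in model.named_modules():
    if name:  # the root module itself has an empty name; skip it
        print(name, '->', module.__class__.__name__)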
E.g., extract intermediate representations (feature maps) of ResNet-18 for a random input batch
import torch
from torchvision import models
from torchdistill.core.forward_hook import ForwardHookManager
# Define a model and choose torch device
model = models.resnet18(pretrained=False)
device = torch.device('cpu')
# Register forward hooks for modules of your interest
forward_hook_manager = ForwardHookManager(device)
forward_hook_manager.add_hook(model, 'conv1', requires_input=True, requires_output=False)
forward_hook_manager.add_hook(model, 'layer1.0.bn2', requires_input=True, requires_output=True)
forward_hook_manager.add_hook(model, 'fc', requires_input=False, requires_output=True)
# Define a random input batch and run the model
x = torch.rand(32, 3, 224, 224)
y = model(x)
# Extract input and/or output of the modules
io_dict = forward_hook_manager.pop_io_dict()
conv1_input = io_dict['conv1']['input']
layer1_0_bn2_input = io_dict['layer1.0.bn2']['input']
layer1_0_bn2_output = io_dict['layer1.0.bn2']['output']
fc_output = io_dict['fc']['output']
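To illustrate why extracting such intermediate representations matters for distillation: once feature maps are available from both a teacher and a student, a feature-matching loss (e.g., MSE, as in hint-based distillation) can be computed directly on them. The following is a minimal sketch using plain tensors, not torchdistill's built-in loss registry:

import torch
import torch.nn.functional as F

def feature_matching_loss(student_feat, teacher_feat):
    # MSE between student and teacher intermediate features (hint-style loss);
    # assumes the two feature maps already share the same shape
    return F.mse_loss(student_feat, teacher_feat.detach())

# Dummy feature maps shaped like ResNet-18's layer1 output for a batch of 32
student_feat = torch.rand(32, 64, 56, 56, requires_grad=True)
teacher_feat = torch.rand(32, 64, 56, 56)
loss = feature_matching_loss(student_feat, teacher_feat)
loss.backward()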
In torchdistill, many components and PyTorch modules are abstracted, e.g., models, datasets, optimizers, losses, and more! You can define them in a declarative PyYAML config file, which can then be seen as a summary of your experiment, and in many cases, you will NOT need to write Python code at all. Take a look at some configurations available in configs/. You'll see which modules are abstracted and how they are defined in a declarative PyYAML config file to design an experiment.
E.g., instantiate CIFAR-10 datasets with a declarative PyYAML config file
from torchdistill.common import yaml_util
config = yaml_util.load_yaml_file('./test.yaml')
train_dataset = config['datasets']['cifar10/train']
test_dataset = config['datasets']['cifar10/test']
test.yaml
datasets:
  cifar10/train: !import_call
    key: 'torchvision.datasets.CIFAR10'
    init:
      kwargs:
        root: &root_dir '~/datasets/cifar10'
        train: True
        download: True
        transform: !import_call
          key: 'torchvision.transforms.Compose'
          init:
            kwargs:
              transforms:
                - !import_call
                  key: 'torchvision.transforms.RandomCrop'
                  init:
                    kwargs:
                      size: 32
                      padding: 4
                - !import_call
                  key: 'torchvision.transforms.RandomHorizontalFlip'
                  init:
                    kwargs:
                      p: 0.5
                - !import_call
                  key: 'torchvision.transforms.ToTensor'
                  init:
                - !import_call
                  key: 'torchvision.transforms.Normalize'
                  init:
                    kwargs: &normalize_kwargs
                      mean: [0.49139968, 0.48215841, 0.44653091]
                      std: [0.24703223, 0.24348513, 0.26158784]
  cifar10/test: !import_call
    key: 'torchvision.datasets.CIFAR10'
    init:
      kwargs:
        root: *root_dir
        train: False
        download: True
        transform: !import_call
          key: 'torchvision.transforms.Compose'
          init:
            kwargs:
              transforms:
                - !import_call
                  key: 'torchvision.transforms.ToTensor'
                  init:
                - !import_call
                  key: 'torchvision.transforms.Normalize'
                  init:
                    kwargs: *normalize_kwargs
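Because the entries above instantiate real torchvision.datasets.CIFAR10 objects at load time, they can be wrapped in standard PyTorch DataLoaders as-is. A minimal sketch, assuming the config above is saved as ./test.yaml:

from torch.utils.data import DataLoader
from torchdistill.common import yaml_util

# !import_call entries are instantiated when the config is loaded
config = yaml_util.load_yaml_file('./test.yaml')
train_loader = DataLoader(config['datasets']['cifar10/train'], batch_size=128, shuffle=True, num_workers=2)
test_loader = DataLoader(config['datasets']['cifar10/test'], batch_size=256, shuffle=False, num_workers=2)

images, labels = next(iter(train_loader))
print(images.shape, labels.shape)  # expected: torch.Size([128, 3, 32, 32]) torch.Size([128])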
If you want to use your own modules (models, loss functions, datasets, etc.) with this framework, you can do so without editing code in the local package torchdistill/.
See the official documentation and Discussions for more details.
Top-1 validation accuracy for ILSVRC 2012 (ImageNet)
Executable code can be found in examples/.
For CIFAR-10 and CIFAR-100, some models are reimplemented and available as pretrained models in torchdistill. More details can be found here.
Some Transformer models fine-tuned by torchdistill for GLUE tasks are available at Hugging Face Model Hub. Sample GLUE benchmark results and details can be found here.
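As an illustration of how such a fine-tuned checkpoint can be used with the transformers library (the repository name below is an assumed example; take the actual names from the Model Hub page linked above):

from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Assumed/illustrative model ID; check the Hugging Face Model Hub for the real ones
model_id = 'yoshitomo-matsubara/bert-base-uncased-sst2'
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(model_id)

# Run a single sentence through the fine-tuned classifier
inputs = tokenizer('a thoughtful and well-made film', return_tensors='pt')
pred = model(**inputs).logits.argmax(dim=-1)
print(pred)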
The following examples are available in demo/. Note that these examples are for Google Colab users and are compatible with Amazon SageMaker Studio Lab. If you have your own GPU(s), examples/ is usually a better reference.
These examples write out test prediction files so that you can check the test performance on the GLUE leaderboard system.
If you find models on PyTorch Hub or GitHub repositories supporting PyTorch Hub, you can import them as teacher/student models simply by editing a declarative yaml config file.
E.g., if you use a pretrained ResNeSt-50 available in huggingface/pytorch-image-models (aka timm) as a teacher model for the ImageNet dataset, you can import the model via PyTorch Hub with the following entry in your declarative YAML config file.
models:
  teacher_model:
    key: 'resnest50d'
    repo_or_dir: 'huggingface/pytorch-image-models'
    kwargs:
      num_classes: 1000
      pretrained: True
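For reference, the entry above corresponds roughly to the following direct PyTorch Hub call, which torchdistill performs for you based on the config (a sketch; keyword pass-through depends on the timm version installed):

import torch

# Load ResNeSt-50 from huggingface/pytorch-image-models (timm) via PyTorch Hub
teacher_model = torch.hub.load(
    'huggingface/pytorch-image-models',  # repo_or_dir
    'resnest50d',                        # key
    num_classes=1000,
    pretrained=True,
)
teacher_model.eval()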
Install torchdistill from PyPI:
pip3 install torchdistill
# or use pipenv
pipenv install torchdistill
Or install it as an editable package from the repository:
git clone https://github.com/yoshitomo-matsubara/torchdistill.git
cd torchdistill/
pip3 install -e .
# or use pipenv
pipenv install "-e ."
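A quick sanity check after either installation method (a minimal sketch; the printed version depends on the release you installed):

from importlib import metadata

# Confirm the package is importable and report the installed version
import torchdistill
print(metadata.version('torchdistill'))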
Feel free to create an issue if you find a bug.
If you have either a question or feature request, start a new discussion here.
Please search through Issues and Discussions and make sure your issue/question/request has not been addressed yet.
Pull requests are welcome. Please start with an issue and discuss solutions with me rather than starting with a pull request.
If you use torchdistill in your research, please cite the following papers:
[Paper] [Preprint]
@inproceedings{matsubara2021torchdistill,
  title={{torchdistill: A Modular, Configuration-Driven Framework for Knowledge Distillation}},
  author={Matsubara, Yoshitomo},
  booktitle={International Workshop on Reproducible Research in Pattern Recognition},
  pages={24--44},
  year={2021},
  organization={Springer}
}
[Paper] [OpenReview] [Preprint]
@inproceedings{matsubara2023torchdistill,
  title={{torchdistill Meets Hugging Face Libraries for Reproducible, Coding-Free Deep Learning Studies: A Case Study on NLP}},
  author={Matsubara, Yoshitomo},
  booktitle={Proceedings of the 3rd Workshop for Natural Language Processing Open Source Software (NLP-OSS 2023)},
  publisher={Empirical Methods in Natural Language Processing},
  pages={153--164},
  year={2023}
}
This project has been supported by Travis CI's OSS credits and JetBrains' Free License Programs (Open Source) since November 2021 and June 2022, respectively.