github.com/rwkv/rwkv-infctx-trainer
If you are new to RWKV, we recommend first learning more about us via our wiki: https://wiki.rwkv.com/
An RWKV trainer with no training context limit.
With this implementation you can train on arbitrarily long context with (near) constant VRAM consumption: usage should increase by only about 2MB per 1024/2048 tokens in the training sample (depending on your chosen ctx_len, with RWKV 7B as an example), which enables training on sequences of over 1M tokens.
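As a rough illustration of that scaling, here is a minimal sketch that splits a long sample into ctx_len-sized chunks (run through the model one at a time, with the recurrent state carried from chunk to chunk) and applies the ~2MB-per-chunk figure linearly. The helper names and the linear-scaling assumption are ours for illustration; this is not the trainer's actual code.

```python
import math

# Figure quoted above: roughly 2 MB of extra VRAM per ctx_len-sized chunk
# (RWKV 7B example). Assumed here to scale linearly with the chunk count.
MB_PER_CHUNK = 2.0


def split_into_chunks(tokens, ctx_len):
    """Split a long token sequence into consecutive ctx_len-sized chunks.

    Only one chunk is processed at a time; the recurrent state produced by
    one chunk is fed into the next, so peak memory tracks ctx_len, not the
    full sample length.
    """
    return [tokens[i:i + ctx_len] for i in range(0, len(tokens), ctx_len)]


def estimate_extra_vram_mb(num_tokens, ctx_len, mb_per_chunk=MB_PER_CHUNK):
    """Rough extra VRAM (MB) for one sample of num_tokens tokens."""
    num_chunks = math.ceil(num_tokens / ctx_len)
    return num_chunks * mb_per_chunk


if __name__ == "__main__":
    chunks = split_into_chunks(list(range(10_000)), ctx_len=4096)
    print([len(c) for c in chunks])
    print(estimate_extra_vram_mb(1_000_000, ctx_len=2048))
```

Under these assumptions, a 1M-token sample at ctx_len 2048 becomes 489 chunks, i.e. under 1GB of additional VRAM on top of the model itself.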
The training code has also been heavily refactored to use PyTorch 2, Lightning 2 and DeepSpeed, and the launch script now relies on LightningCLI. As a result, config-example.yaml contains all the switches: mostly standard ones that Lightning processes by itself, plus new ones for RWKV and the dataset parser.
To use this repo, go into the RWKV-v4neo directory and run:
python3 lightning_trainer.py fit -c {your_config}.yaml
Remember to modify the configuration for your own needs.
See RWKV-v4neo/config-example.yaml for documentation on the various options.
NOTE: Because the current implementation is incomplete (there is no state gradient), bptt_truncate is forced to true.
Note: At the time of writing, there is a known issue with CUDA 12.0 and multi-GPU setups. Upgrade to at least CUDA 12.1 or 12.2, or downgrade to 11.8.
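To check whether an environment is affected by the CUDA 12.0 multi-GPU issue, a small helper like the one below can inspect the CUDA version that PyTorch was built against (torch.version.cuda, which is a string such as "12.1", or None on CPU-only builds). The function name is hypothetical, and it only encodes the single known-bad version stated in the note above.

```python
from typing import Optional


def cuda_ok_for_multi_gpu(cuda_version: Optional[str]) -> bool:
    """Return False only for CUDA 12.0, flagged above for multi-GPU runs.

    cuda_version uses the torch.version.cuda format, e.g. "12.1" or "11.8".
    None (CPU-only build) is treated as "nothing to flag".
    """
    if not cuda_version:
        return True
    parts = cuda_version.split(".")
    major = int(parts[0])
    minor = int(parts[1]) if len(parts) > 1 else 0
    return (major, minor) != (12, 0)


if __name__ == "__main__":
    try:
        import torch  # optional: only used to read the build's CUDA version
        version = torch.version.cuda
    except ImportError:
        version = None
    print(version, cuda_ok_for_multi_gpu(version))
```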
The following venv setup uses conda; modify it for your use case as needed.
# ninja-build is required for the new trainer
sudo apt-get install ninja-build
# Update conda & its package listings
conda update conda
# Virtual env, with python 3.10
# python 3.11 has issues with torch.compile / H100s,
# and if you want to use 3.11, you will need to do a nightly build install
conda create -n rwkv-infctx python=3.10 pip
conda activate rwkv-infctx
# Install pytorch (>=2.1.2)
conda install -y pytorch==2.1.2 torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
python -m pip install lightning==2.1.3 deepspeed==0.12.6
# Currently, for torch.compile + python 3.11 to work on some platforms, you will need the nightly build
# if so, you may need to try the following instead (this is considered highly "unstable")
# ---
# conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch-nightly -c nvidia
# python -m pip install lightning==2.0.5 deepspeed==0.10.0
# Verify your pytorch version
python -c "import torch; print(torch.__version__)"
# Install all the other various dependencies
# PS: We use python -m pip instead of pip directly, as it resolves issues with the venv not loading the right pip
python -m pip install datasets transformers
python -m pip install ninja numexpr jsonargparse 'jsonargparse[signatures]'
python -m pip install lm-dataformat ftfy sentencepiece tokenizers wandb
# Optional dependencies, useful for running notebooks, etc
python -m pip install papermill
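Beyond printing the PyTorch version, you can confirm that all three pinned packages from the setup above are installed at the expected versions. This sketch uses the standard library's importlib.metadata; the check_pins helper is hypothetical, and it assumes the conda-installed PyTorch registers its distribution metadata under the name "torch" (which is typical, but worth confirming on your setup).

```python
from importlib.metadata import PackageNotFoundError, version

# Versions pinned in the setup commands above.
PINNED = {"torch": "2.1.2", "lightning": "2.1.3", "deepspeed": "0.12.6"}


def check_pins(pins=PINNED, lookup=version):
    """Map each package to (installed_version_or_None, matches_pin)."""
    report = {}
    for name, wanted in pins.items():
        try:
            installed = lookup(name)
        except PackageNotFoundError:
            report[name] = (None, False)
            continue
        # Drop local build tags such as "+cu121" before comparing.
        report[name] = (installed, installed.split("+")[0] == wanted)
    return report


if __name__ == "__main__":
    for name, (installed, ok) in check_pins().items():
        status = "OK" if ok else "MISMATCH"
        print(f"{name:10} {installed or 'not installed':15} {status}")
```

The lookup parameter exists only so the check can be exercised without the real packages installed.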
Alternatively, you could use requirements.txt (this may not install pytorch-cuda properly, and has been found to be incompatible with conda environments):
python3 -m pip install -r requirements.txt
Due to issues with DeepSpeed on Windows, only Linux environments are supported. WSL2 on Windows is not recommended, due to heavy performance penalties (DeepSpeed offload cannot be used, and training is ~50% slower).
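If you are unsure which environment a machine is running, a quick check like this one can flag WSL2 before you start a long run. It relies on the common heuristic that WSL kernels include "microsoft" in their release string; that is an assumption rather than an official API, and detect_environment is a hypothetical helper name.

```python
import platform


def detect_environment(system=None, release=None):
    """Classify the host as "linux", "wsl" or "unsupported".

    WSL kernels usually report a release string containing "microsoft"
    (e.g. "5.15.90.1-microsoft-standard-WSL2"). Arguments override the
    values from platform.uname(), which makes the function easy to test.
    """
    uname = platform.uname()
    system = system if system is not None else uname.system
    release = release if release is not None else uname.release
    if system != "Linux":
        return "unsupported"  # native Windows / macOS
    if "microsoft" in release.lower():
        return "wsl"  # works, but expect the slowdown described above
    return "linux"


if __name__ == "__main__":
    print(detect_environment())
```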
In summary, run the following from the trainer directory (e.g. RWKV-v4neo):
# Initialize the blank model (or download a pretrained model)
python3 init_model.py --n_layer {number-of-layers} --n_embd {embedding-size} --vocab_size {vocab-size/neox/world} --skip-if-exists ../model/file/path.pth
# Preload your dataset
python3 preload_datapath.py {your_config}.yaml
# Run the training process
python3 lightning_trainer.py fit -c {your_config}.yaml
# Export the checkpoint to a model file
python3 export_checkpoint.py ../path/to/checkpoint/last.ckpt/ ../path/to/export/model.pth
# Quick test the model with the dragon prompt
python3 dragon_test.py ../path/to/export/model.pth
# @TODO: convert the model to bf16 format (instead of the much larger fp32 format used now)
# for now, you will have to use the RWKV pip package to do this with python code:
# https://pypi.org/project/rwkv/
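The fp32-vs-bf16 difference mentioned in the TODO above is easy to estimate. This sketch approximates the parameter count from the same n_layer, n_embd and vocab_size values passed to init_model.py, assuming the RWKV-v4 block layout (four n_embd x n_embd projections in time-mix, a 4x-wide channel-mix FFN plus a receptance projection) and ignoring small per-layer vectors and LayerNorms. It then converts that count to raw weight size at 4 bytes (fp32) and 2 bytes (bf16) per parameter. The function names and the example config are hypothetical, and the code does not perform the conversion itself.

```python
def approx_rwkv4_params(n_layer: int, n_embd: int, vocab_size: int) -> int:
    """Approximate RWKV-4 parameter count (ignores LayerNorms and mix vectors).

    Per layer (assumed layout):
      time-mix:    key, value, receptance, output     -> 4 * n_embd^2
      channel-mix: key (n_embd -> 4*n_embd),
                   value (4*n_embd -> n_embd),
                   receptance (n_embd -> n_embd)      -> 9 * n_embd^2
    Plus input embedding and output head: 2 * vocab_size * n_embd.
    """
    per_layer = 13 * n_embd * n_embd
    return n_layer * per_layer + 2 * vocab_size * n_embd


def checkpoint_gb(num_params: int, bytes_per_param: int) -> float:
    """Raw weight size in GB (1e9 bytes): fp32 = 4 bytes, bf16 = 2 bytes."""
    return num_params * bytes_per_param / 1e9


if __name__ == "__main__":
    # Hypothetical 7B-class config: 32 layers, 4096 embedding, 50277 vocab.
    params = approx_rwkv4_params(32, 4096, 50277)
    print(f"{params / 1e9:.2f}B params")
    print(f"fp32: {checkpoint_gb(params, 4):.1f} GB, "
          f"bf16: {checkpoint_gb(params, 2):.1f} GB")
```

Whatever the exact parameter count, a bf16 copy holds the same weights in half the bytes, so it is roughly half the size of the fp32 export.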
You can find notebooks and examples at the following ...
For configuration issues, please review the examples listed above before asking questions on Discord.
You can find the training channel on our Discord here: https://discord.com/channels/992359628979568762/992362252269256815
- step in the progress bar means 1 data sample PER GPU.
- substep in wandb means a single data sample.
- trainer/global_step in wandb: (accumulate_grad_batches * gpu count) substeps = 1 trainer/global_step.
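The relationship between the counters above is simple arithmetic. Here it is as two small functions, with accumulate_grad_batches standing for the gradient-accumulation setting in your config. The function names are hypothetical and only illustrate how the counters line up.

```python
def substeps_per_global_step(accumulate_grad_batches: int, num_gpus: int) -> int:
    """Data samples (wandb substeps) consumed by one trainer/global_step."""
    return accumulate_grad_batches * num_gpus


def global_step_for_substep(substep: int, accumulate_grad_batches: int,
                            num_gpus: int) -> int:
    """Number of completed global steps after `substep` data samples."""
    return substep // substeps_per_global_step(accumulate_grad_batches, num_gpus)


if __name__ == "__main__":
    # Example: gradient accumulation of 8 across 4 GPUs.
    print(substeps_per_global_step(8, 4))
    print(global_step_for_substep(100, 8, 4))
```

For example, with accumulation of 8 on 4 GPUs, each GPU's progress bar advances 8 steps per global step, while wandb records 32 substeps for that same global step.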
Generally, if you are training a foundation model from scratch with a fixed context size, and you need the absolute highest throughput across multiple nodes (e.g. 10 nodes filled with A100 servers), the official trainer will perform much better (roughly 2x faster, depending on the settings).
If you need DeepSpeed stage 3 support, or you work with dynamic datasets, this trainer is much more flexible, as it is for nearly all other use cases.
Over time, as we optimize the infctx trainer, the gap to the official trainer should shrink; however, this is not the highest priority (a working infctx > absolute speed).
#rwkv-x development, so that the entire train-test-validation cycle for design changes can be done in this repository.
The following features are not yet supported (though they may exist in BlinkDL's original repo):
@picocreator is the current maintainer of the project; you can ping him on the RWKV Discord if you have any questions about this project.
This project is intentionally a hard fork, as it has too many changes that conflict with the official RWKV-LM repo.