PyTorch/XLA
PyTorch/XLA is a Python package that uses the XLA deep learning
compiler to connect the PyTorch deep learning
framework and Cloud
TPUs. You can try it right now, for free, on a
single Cloud TPU VM with
Kaggle!
Take a look at one of our Kaggle notebooks to get started.
Installation
TPU
To install the PyTorch/XLA stable build on a new TPU VM:
Note: Builds are available for Python 3.10 to 3.13; please use one of the supported versions.
pip install torch==2.9.0 'torch_xla[tpu]==2.9.0'
# Optional: install the Pallas dependencies if you use custom kernels
pip install --pre 'torch_xla[pallas]' --index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html
As of 07/16/2025, starting with the PyTorch/XLA 2.8 release, PyTorch/XLA
provides nightly and release wheels for Python 3.10 to 3.13.
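Wheel filenames in this README encode the Python version as a CPython tag (for example cp310 for Python 3.10 and cp312 for Python 3.12). As a minimal sketch, assuming the supported range stated above (3.10 to 3.13), you can check your interpreter and derive the tag to look for before picking a wheel. The helper `wheel_python_tag` is hypothetical and not part of torch_xla.

```python
import sys

# Supported range stated in this README for nightly and release wheels;
# treat it as an assumption that may change in later releases.
SUPPORTED = {(3, 10), (3, 11), (3, 12), (3, 13)}


def wheel_python_tag(version=None):
    """Return the CPython wheel tag (e.g. 'cp312') for a supported version.

    `version` is a (major, minor) tuple and defaults to the running
    interpreter. Raises ValueError for versions outside SUPPORTED.
    """
    major, minor = version if version is not None else sys.version_info[:2]
    if (major, minor) not in SUPPORTED:
        raise ValueError(f"Python {major}.{minor} is not supported; use 3.10 to 3.13")
    return f"cp{major}{minor}"


if __name__ == "__main__":
    try:
        print("Matching wheel tag:", wheel_python_tag())
    except ValueError as err:
        print(err)
```

The tag appears twice in each wheel filename (once as the Python tag, once as the ABI tag), so a Python 3.12 interpreter needs a wheel containing cp312-cp312.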
To install the PyTorch/XLA nightly build on a new TPU VM (the wheel below is for Python 3.12):
pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu
pip install 'torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.9.0.dev-cp312-cp312-linux_x86_64.whl' \
-f https://storage.googleapis.com/libtpu-wheels/index.html
C++11 ABI builds
As of 03/18/2025, starting with the PyTorch/XLA 2.7 release, C++11 ABI builds
are the default and we no longer provide wheels built with the pre-C++11 ABI.
In PyTorch/XLA 2.6, wheels and docker images are provided in two C++ ABI
flavors: C++11 and pre-C++11. Pre-C++11 is the 2.6 default, to align with
PyTorch upstream, but the C++11 ABI wheels and docker images have better lazy
tensor tracing performance.
To install the C++11 ABI-flavored 2.6 wheels (Python 3.10 example):
pip install torch==2.6.0+cpu.cxx11.abi \
https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0%2Bcxx11-cp310-cp310-manylinux_2_28_x86_64.whl \
'torch_xla[tpu]' \
-f https://storage.googleapis.com/libtpu-releases/index.html \
-f https://storage.googleapis.com/libtpu-wheels/index.html \
-f https://download.pytorch.org/whl/torch
The command above works for Python 3.10. We additionally provide Python 3.9 and
3.11 wheels.
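The 2.6 C++11 ABI wheels are distinguished by a +cxx11 local version label, which appears URL-encoded as %2Bcxx11 in the download URL above. As a sketch, assuming the standard wheel filename layout (name-version-pythontag-abitag-platform.whl, with no optional build tag), you can decode a URL to see which Python version it targets and whether it carries that label. Since 2.7 all wheels use the C++11 ABI without the label, so the check is only meaningful for 2.6. The helper `describe_wheel` is hypothetical, not a torch_xla API.

```python
from urllib.parse import unquote, urlparse


def describe_wheel(url):
    """Split a wheel URL into its filename fields.

    Assumes exactly five dash-separated fields (no optional build tag),
    as in the torch_xla wheels listed in this README.
    """
    # Take the last path component and decode %2B back to '+'.
    filename = unquote(urlparse(url).path.rsplit("/", 1)[-1])
    if not filename.endswith(".whl"):
        raise ValueError(f"not a wheel: {filename}")
    name, version, python_tag, abi_tag, platform = filename[:-4].split("-")
    return {
        "name": name,
        "version": version,
        "python_tag": python_tag,
        "abi_tag": abi_tag,
        "platform": platform,
        # Only the 2.6 C++11 flavor carries this label; 2.7+ wheels are
        # C++11 ABI by default without it.
        "cxx11_label": version.endswith("+cxx11"),
    }


url = ("https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/"
       "torch_xla-2.6.0%2Bcxx11-cp310-cp310-manylinux_2_28_x86_64.whl")
print(describe_wheel(url))
```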
To use the C++11 ABI-flavored docker image:
us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.6.0_3.10_tpuvm_cxx11
If your model is tracing-bound (e.g. the host CPU is busy tracing the model
while the TPUs are idle), switching to the C++11 ABI wheels or docker images
can improve performance. Mixtral 8x7B benchmarking results on v5p-256, global
batch size 1024:
- Pre-C++11 ABI MFU: 33%
- C++11 ABI MFU: 39%
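To put the two MFU (model FLOPs utilization) figures in perspective: both runs use the same model and hardware, so MFU is proportional to achieved training throughput, and the relative gain is the ratio of the two numbers rather than their 6-point difference. A small sketch using only the figures above:

```python
# MFU figures quoted above for Mixtral 8x7B on v5p-256 (global batch size 1024).
pre_cxx11_mfu = 0.33
cxx11_mfu = 0.39

# Difference in utilization, expressed as a fraction of peak FLOPs.
absolute_gain = cxx11_mfu - pre_cxx11_mfu
# Gain relative to the pre-C++11 baseline, i.e. the throughput improvement.
relative_gain = cxx11_mfu / pre_cxx11_mfu - 1.0

print(f"absolute: {absolute_gain * 100:.0f} points")  # absolute: 6 points
print(f"relative: {relative_gain * 100:.1f}%")        # relative: 18.2%
```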
GitHub Doc Map
Our GitHub repository contains many useful docs on working with different aspects of PyTorch/XLA. Here is a list of useful docs spread around the repository:
- docs/source/learn: docs for learning concepts associated with XLA, troubleshooting, PJRT, eager mode, and dynamic shape.
- docs/source/accelerators: references to TPU accelerator documents.
- docs/source/perf: documentation about performance-specific aspects of PyTorch/XLA such as AMP, DDP, Dynamo, Fori loop, FSDP, quantization, recompilation, and SPMD.
- docs/source/features: documentation on distributed torch, Pallas, scan, and StableHLO.
- docs/source/contribute: documents on setting up PyTorch for development, and guides for lowering operations.
- PJRT plugins:
- torchax/docs: torchax documents
Getting Started
The following guides cover two modes:
- Single process: one Python interpreter controlling a single TPU at a time
- Multi-process: N Python interpreters are launched, one for each of the N TPUs found on the system
Another mode is SPMD, where one Python interpreter controls all N TPUs found on
the system. Multi-processing is more complex and is not compatible with SPMD.
This tutorial does not dive into SPMD. For more on that, check our
SPMD guide.
Simple single process
To update your existing training loop, make the following changes:
+import torch_xla

 def train(model, train_loader, ...):
   ...
   for inputs, labels in train_loader:
+    with torch_xla.step():
+      inputs, labels = inputs.to('xla'), labels.to('xla')
       optimizer.zero_grad()
       outputs = model(inputs)
       loss = loss_fn(outputs, labels)
       loss.backward()
       optimizer.step()

+  torch_xla.sync()
   ...

 if __name__ == '__main__':
   ...
+  # Move the model parameters to your XLA device
+  model.to('xla')
   train(model, train_loader, ...)
   ...
The changes above should get your model to train on the TPU.
Multi-processing
To update your existing training loop, make the following changes:
-import torch.multiprocessing as mp
+import torch_xla
+import torch_xla.core.xla_model as xm

 def _mp_fn(index):
   ...

+  # Move the model parameters to your XLA device
+  model.to('xla')

   for inputs, labels in train_loader:
+    with torch_xla.step():
+      # Transfer data to the XLA device. This happens asynchronously.
+      inputs, labels = inputs.to('xla'), labels.to('xla')
       optimizer.zero_grad()
       outputs = model(inputs)
       loss = loss_fn(outputs, labels)
       loss.backward()
-      optimizer.step()
+      # `xm.optimizer_step` combines gradients across replicas
+      xm.optimizer_step(optimizer)

 if __name__ == '__main__':
-  mp.spawn(_mp_fn, args=(), nprocs=world_size)
+  # torch_xla.launch automatically selects the correct world size
+  torch_xla.launch(_mp_fn, args=())
If you're using DistributedDataParallel, make the following changes:
 import torch.distributed as dist
 from torch.nn.parallel import DistributedDataParallel as DDP
-import torch.multiprocessing as mp
+import torch_xla
+# Importing xla_backend registers the "xla" process group backend
+import torch_xla.distributed.xla_backend

 def _mp_fn(rank):
   ...

-  os.environ['MASTER_ADDR'] = 'localhost'
-  os.environ['MASTER_PORT'] = '12355'
-  dist.init_process_group("gloo", rank=rank, world_size=world_size)
+  # Rank and world size are inferred from the XLA device runtime
+  dist.init_process_group("xla", init_method='xla://')
+
+  model.to('xla')
+  ddp_model = DDP(model, gradient_as_bucket_view=True)

-  model = model.to(rank)
-  ddp_model = DDP(model, device_ids=[rank])

   for inputs, labels in train_loader:
+    with torch_xla.step():
+      inputs, labels = inputs.to('xla'), labels.to('xla')
       optimizer.zero_grad()
       outputs = ddp_model(inputs)
       loss = loss_fn(outputs, labels)
       loss.backward()
       optimizer.step()

 if __name__ == '__main__':
-  mp.spawn(_mp_fn, args=(), nprocs=world_size)
+  torch_xla.launch(_mp_fn, args=())
Additional information on PyTorch/XLA, including a description of its semantics
and functions, is available at PyTorch.org. See the
API Guide for best practices when writing networks that run on
XLA devices (TPU, CPU, and others).
Our comprehensive user guides are available at:
- Documentation for the latest release
- Documentation for master branch
- PyTorch/XLA tutorials
Reference implementations
The AI-Hypercomputer/tpu-recipes
repository contains examples for training and serving many LLM and diffusion models.
Available docker images and wheels
Python packages
Starting with version r2.1, PyTorch/XLA releases are available on PyPI. You can
install the main build with pip install torch_xla. To also install the Cloud TPU
plugin that corresponds to your installed torch_xla, install the optional tpu
dependencies after installing the main build:
pip install 'torch_xla[tpu]'
TPU nightly builds are available in our public GCS bucket.
| Version | Cloud TPU VMs Wheel |
|---------|---------------------|
| nightly (Python 3.11) | https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.9.0.dev-cp311-cp311-linux_x86_64.whl |
| nightly (Python 3.12) | https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.9.0.dev-cp312-cp312-linux_x86_64.whl |
| nightly (Python 3.13) | https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.9.0.dev-cp313-cp313-linux_x86_64.whl |
Use nightly build
To get the nightly wheel for a specific date, append the date in yyyymmdd
format to the dev version, as in torch_xla-2.9.0.devyyyymmdd (or use the latest
dev version). Here is an example:
pip3 install torch==2.9.0.dev20250423+cpu --index-url https://download.pytorch.org/whl/nightly/cpu
pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.9.0.dev20250423-cp310-cp310-linux_x86_64.whl
The torch wheel version 2.9.0.dev20250423+cpu can be found at https://download.pytorch.org/whl/nightly/torch/.
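The dated example above follows a fixed naming pattern: torch is pinned as 2.9.0.devYYYYMMDD+cpu and torch_xla as torch_xla-2.9.0.devYYYYMMDD-<tag>-<tag>-linux_x86_64.whl. A minimal sketch that builds both pip commands for a chosen date, assuming that pattern holds (a wheel may not exist for every date); `nightly_install_commands` is a hypothetical helper.

```python
from datetime import date

XLA_WHEEL_BASE = "https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm"
TORCH_NIGHTLY_INDEX = "https://download.pytorch.org/whl/nightly/cpu"


def nightly_install_commands(day, python_tag="cp310", base_version="2.9.0"):
    """Return pip commands that pin torch and torch_xla to the same nightly.

    Assumes the naming pattern of the example above; no wheel is
    guaranteed to exist for `day`.
    """
    stamp = day.strftime("%Y%m%d")
    torch_cmd = (f"pip3 install torch=={base_version}.dev{stamp}+cpu "
                 f"--index-url {TORCH_NIGHTLY_INDEX}")
    wheel = (f"torch_xla-{base_version}.dev{stamp}-"
             f"{python_tag}-{python_tag}-linux_x86_64.whl")
    xla_cmd = f"pip3 install {XLA_WHEEL_BASE}/{wheel}"
    return torch_cmd, xla_cmd


for cmd in nightly_install_commands(date(2025, 4, 23)):
    print(cmd)
```

Keeping both dates identical matters because a torch_xla nightly is built against the torch nightly from the same day.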
Older versions
| Version | Cloud TPU VMs Wheel |
|---------|---------------------|
| 2.9 (Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.9.0-cp310-cp310-manylinux_2_28_x86_64.whl |
| 2.8 (Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0-cp310-cp310-manylinux_2_28_x86_64.whl |
| 2.7 (Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0-cp310-cp310-manylinux_2_28_x86_64.whl |
| 2.6 (Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0-cp310-cp310-manylinux_2_28_x86_64.whl |
| 2.5 (Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.5.0-cp310-cp310-manylinux_2_28_x86_64.whl |
| 2.4 (Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.4.0-cp310-cp310-manylinux_2_28_x86_64.whl |
| 2.3 (Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.3.0-cp310-cp310-manylinux_2_28_x86_64.whl |
Docker
NOTE: Since PyTorch/XLA 2.7, all builds use the C++11 ABI by default.
| Version | Cloud TPU VMs Docker |
|---------|----------------------|
| 2.9 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.9.0_3.10_tpuvm |
| 2.8 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.8.0_3.10_tpuvm |
| 2.7 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.7.0_3.10_tpuvm |
| 2.6 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.6.0_3.10_tpuvm |
| 2.6 (C++11 ABI) | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.6.0_3.10_tpuvm_cxx11 |
| 2.5 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.5.0_3.10_tpuvm |
| 2.4 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.4.0_3.10_tpuvm |
| 2.3 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.3.0_3.10_tpuvm |
| nightly python | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm |
To use the Docker images above, pass --privileged --net host --shm-size=16G to docker run. Here is an example:
docker run --privileged --net host --shm-size=16G -it us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm /bin/bash
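The image tags in the table follow the pattern r<version>_<python>_tpuvm (with a _cxx11 suffix for the 2.6 C++11 ABI image) or nightly_<python>_tpuvm. As a sketch, assuming that pattern, you can build the image reference and the docker run argument list with the required flags; `image_ref` and `docker_run_argv` are hypothetical helpers, and the command is only printed here, not executed.

```python
import shlex

REGISTRY = "us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla"


def image_ref(version="nightly", python="3.10", cxx11=False):
    """Build an image reference following the tag pattern in the table above."""
    prefix = "nightly" if version == "nightly" else f"r{version}"
    tag = f"{prefix}_{python}_tpuvm" + ("_cxx11" if cxx11 else "")
    return f"{REGISTRY}:{tag}"


def docker_run_argv(image, command="/bin/bash"):
    """Return the argv for an interactive container with the required flags."""
    return ["docker", "run", "--privileged", "--net", "host",
            "--shm-size=16G", "-it", image, command]


argv = docker_run_argv(image_ref())
print(shlex.join(argv))
```

Building the command as a list (rather than a single string) lets you pass it directly to subprocess.run(argv) on a TPU VM without shell-quoting issues.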
Troubleshooting
If PyTorch/XLA isn't performing as expected, see the troubleshooting
guide, which has suggestions for debugging and optimizing
your network(s).
Providing Feedback
The PyTorch/XLA team is always happy to hear from users and OSS contributors!
The best way to reach out is by filing an issue in this GitHub repository. Questions,
bug reports, feature requests, build issues, etc. are all welcome!
Contributing
See the contribution guide.
Disclaimer
This repository is jointly operated and maintained by Google, Meta and a
number of individual contributors listed in the
CONTRIBUTORS file. For
questions directed at Meta, please send an email to opensource@fb.com. For
questions directed at Google, please send an email to
pytorch-xla@googlegroups.com. For all other questions, please open an issue
in this repository.
Additional Reads
You can find additional useful reading materials in
Related Projects