Amazon S3 Connector for PyTorch
The Amazon S3 Connector for PyTorch delivers high throughput for PyTorch training jobs that access or store data in
Amazon S3. Using the S3 Connector for PyTorch
automatically optimizes performance when downloading training data from and writing checkpoints to Amazon S3,
eliminating the need to write your own code to list S3 buckets and manage concurrent requests.
Amazon S3 Connector for PyTorch provides implementations of PyTorch's
dataset primitives that you can use to load
training data from Amazon S3.
It supports both map-style datasets for random data
access patterns and iterable-style datasets for
streaming sequential data access patterns.
The S3 Connector for PyTorch also includes a checkpointing interface to save checkpoints directly to and load them
directly from Amazon S3, without first saving to local storage.
Getting Started
Prerequisites
- Python 3.8 or greater is installed (Note: Python 3.12+ is not recommended because PyTorch does not support it).
- PyTorch >= 2.0 (TODO: Check with PyTorch 1.x)
Installation
pip install s3torchconnector
Pre-built wheels for Amazon S3 Connector for PyTorch are currently available via pip only for Linux and macOS.
For other platforms, see DEVELOPMENT for build instructions.
Configuration
To use s3torchconnector, AWS credentials must be provided through one of the following methods:
- If you are using this library on an EC2 instance, specify an IAM role and then give the EC2 instance access to
that role.
- Install and configure awscli and run aws configure.
- Set credentials in the AWS credentials profile file on the local system, located at ~/.aws/credentials on Unix or macOS.
- Set the AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment variables.
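As a quick local pre-flight check, the sketch below reports which of the file-based or environment-based methods listed above appear to be configured. The function name `find_credential_sources` and its return labels are illustrative assumptions, not part of s3torchconnector, and the check says nothing about how the connector itself resolves credentials. It does not contact AWS and cannot detect an IAM role attached to an EC2 instance.

```python
import os
from pathlib import Path


def find_credential_sources(env=None, home=None):
    """Return labels for the local credential sources that look configured.

    Hypothetical helper for illustration only: it inspects the environment
    and the home directory, and never validates the credentials themselves.
    """
    env = os.environ if env is None else env
    home = Path.home() if home is None else Path(home)

    sources = []
    # Both variables must be present for environment-based credentials.
    if env.get("AWS_ACCESS_KEY_ID") and env.get("AWS_SECRET_ACCESS_KEY"):
        sources.append("environment")
    # `aws configure` writes to this same file, so one check covers both methods.
    if (home / ".aws" / "credentials").is_file():
        sources.append("credentials-file")
    return sources


if __name__ == "__main__":
    print(find_credential_sources() or "no local credentials found")
```

An empty result on an EC2 instance is not necessarily a problem, since an instance role supplies credentials without any local file or variable.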
Examples
The API docs describe the public components.
End-to-end examples of how to use s3torchconnector can be found under the examples directory.
Sample Examples
The simplest way to use the S3 Connector for PyTorch is to construct a dataset, either a map-style or iterable-style
dataset, by specifying an S3 URI (a bucket and optional prefix) and the region the bucket is located in:
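from s3torchconnector import S3MapDataset, S3IterableDataset

# Replace <BUCKET> and <PREFIX> with your own values.
DATASET_URI = "s3://<BUCKET>/<PREFIX>"
REGION = "us-east-1"

# Iterable-style dataset: stream the objects under the prefix sequentially.
iterable_dataset = S3IterableDataset.from_prefix(DATASET_URI, region=REGION)
for item in iterable_dataset:
    print(item.key)

# Map-style dataset: access objects by index.
map_dataset = S3MapDataset.from_prefix(DATASET_URI, region=REGION)
item = map_dataset[0]
bucket = item.bucket
key = item.key
content = item.read()  # read the object's contents
len(content)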
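To make the "bucket and optional prefix" structure of the S3 URI concrete, here is a minimal sketch that splits a URI into its two parts. The helper `split_s3_uri` is a hypothetical name introduced for illustration; it is not the connector's own parser, and the bucket name in the usage lines is made up.

```python
def split_s3_uri(uri):
    """Split 's3://bucket/prefix' into (bucket, prefix); the prefix may be empty."""
    scheme = "s3://"
    if not uri.startswith(scheme):
        raise ValueError(f"not an S3 URI: {uri!r}")
    # Everything up to the first '/' is the bucket; the rest is the key prefix.
    bucket, _, prefix = uri[len(scheme):].partition("/")
    if not bucket:
        raise ValueError(f"missing bucket name: {uri!r}")
    return bucket, prefix


print(split_s3_uri("s3://my-bucket/train/images/"))  # ('my-bucket', 'train/images/')
print(split_s3_uri("s3://my-bucket"))                # ('my-bucket', '')
```

With an empty prefix, the dataset would cover every object in the bucket, so a prefix is usually worth specifying for large buckets.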
In addition to data loading primitives, the S3 Connector for PyTorch also provides an interface for saving and loading
model checkpoints directly to and from an S3 bucket.
from s3torchconnector import S3Checkpoint

import torchvision
import torch

# Replace <BUCKET> and <KEY> with your own values.
CHECKPOINT_URI = "s3://<BUCKET>/<KEY>/"
REGION = "us-east-1"
checkpoint = S3Checkpoint(region=REGION)

model = torchvision.models.resnet18()

# Save the checkpoint to S3
with checkpoint.writer(CHECKPOINT_URI + "epoch0.ckpt") as writer:
    torch.save(model.state_dict(), writer)

# Load the checkpoint from S3
with checkpoint.reader(CHECKPOINT_URI + "epoch0.ckpt") as reader:
    state_dict = torch.load(reader)

model.load_state_dict(state_dict)
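The example above writes a single checkpoint named epoch0.ckpt. When saving one checkpoint per epoch, a naming helper along the following lines may be convenient. Both `checkpoint_uri` and `latest_epoch` are hypothetical helpers that assume the `epochN.ckpt` naming used above; the list of keys could come, for instance, from iterating a dataset over the checkpoint prefix and collecting `item.key`.

```python
import re

CHECKPOINT_URI = "s3://<BUCKET>/<KEY>/"  # placeholder, as in the example above


def checkpoint_uri(epoch, base=CHECKPOINT_URI):
    """Build the URI for one epoch's checkpoint, e.g. '<base>epoch3.ckpt'."""
    return f"{base}epoch{epoch}.ckpt"


def latest_epoch(keys):
    """Return the highest N among keys ending in 'epochN.ckpt', or None if there are none."""
    pattern = re.compile(r"epoch(\d+)\.ckpt$")
    epochs = [int(m.group(1)) for m in map(pattern.search, keys) if m]
    return max(epochs, default=None)
```

Parsing the epoch as an integer matters here: a plain string sort would rank epoch9.ckpt after epoch10.ckpt.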
Using datasets or checkpoints with Amazon S3 Express One Zone directory buckets requires only updating the URI to
follow the base-name--azid--x-s3 bucket name format.
For example, for the directory bucket my-test-bucket--usw2-az1--x-s3 with the Availability Zone ID usw2-az1,
the URI is s3://my-test-bucket--usw2-az1--x-s3/<PREFIX> (note that the prefix for Amazon S3 Express One Zone
should end with '/'), paired with region us-west-2.
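The two rules in this paragraph (the base-name--azid--x-s3 bucket name format and the trailing '/' on the prefix) can be checked before constructing a dataset or checkpoint. The sketch below is a hypothetical helper, not part of the connector. The pattern used for the Availability Zone ID is an assumption inferred from the single example usw2-az1 and may be narrower than what S3 accepts.

```python
import re

# Assumed shape: base-name--azid--x-s3, with an azid that looks like "usw2-az1".
_DIRECTORY_BUCKET = re.compile(r"^(?P<base>.+)--(?P<azid>[a-z0-9]+-az\d+)--x-s3$")


def availability_zone_id(bucket):
    """Extract the Availability Zone ID from a directory bucket name."""
    match = _DIRECTORY_BUCKET.match(bucket)
    if match is None:
        raise ValueError(f"not a directory bucket name: {bucket!r}")
    return match.group("azid")


def express_uri(bucket, prefix=""):
    """Build the S3 URI for a directory bucket, ending a non-empty prefix with '/'."""
    availability_zone_id(bucket)  # raises ValueError if the name is malformed
    if prefix and not prefix.endswith("/"):
        prefix += "/"
    return f"s3://{bucket}/{prefix}"


print(availability_zone_id("my-test-bucket--usw2-az1--x-s3"))  # usw2-az1
print(express_uri("my-test-bucket--usw2-az1--x-s3", "train"))
# s3://my-test-bucket--usw2-az1--x-s3/train/
```

The helper deliberately does not map the Availability Zone ID to a region; the region (us-west-2 in the example) still has to be supplied separately.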
Distributed checkpoints
Overview
Amazon S3 Connector for PyTorch provides robust support for PyTorch distributed checkpoints. This feature includes:
- S3StorageWriter and S3StorageReader: Implementations of PyTorch's StorageWriter and StorageReader interfaces.
- S3FileSystem: An implementation of PyTorch's FileSystemBase.
These tools enable seamless integration of Amazon S3 with
PyTorch Distributed Checkpoints,
allowing efficient storage and retrieval of distributed model checkpoints.
Prerequisites and Installation
PyTorch 2.3 or newer is required.
To use the distributed checkpoints feature, install S3 Connector for PyTorch with the dcp
extra:
pip install s3torchconnector[dcp]
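Because the distributed checkpoint feature needs PyTorch 2.3 or newer, a script may want to verify the installed version before relying on it. The following is a minimal sketch using only the standard library; `meets_minimum` and `torch_supports_dcp` are hypothetical helper names, and the check compares only the leading major.minor part of the version string.

```python
import re
from importlib import metadata


def meets_minimum(version, minimum=(2, 3)):
    """Compare the leading 'major.minor' of a version string against a minimum."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return (int(match.group(1)), int(match.group(2))) >= tuple(minimum)


def torch_supports_dcp():
    """Return True if the installed torch is at least 2.3, False if older or missing."""
    try:
        return meets_minimum(metadata.version("torch"))
    except metadata.PackageNotFoundError:
        return False


if __name__ == "__main__":
    print("PyTorch >= 2.3 installed:", torch_supports_dcp())
```

Comparing integer tuples rather than strings keeps versions such as 2.10 ordered correctly after 2.3.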
Sample Example
End-to-end examples for using distributed checkpoints with S3 Connector for PyTorch
can be found in the examples/dcp directory.
from s3torchconnector import S3StorageWriter, S3StorageReader

import torchvision
import torch.distributed.checkpoint as DCP

# Replace <BUCKET> and <KEY> with your own values.
CHECKPOINT_URI = "s3://<BUCKET>/<KEY>/"
REGION = "us-east-1"

model = torchvision.models.resnet18()

# Save the distributed checkpoint to S3
s3_storage_writer = S3StorageWriter(region=REGION, path=CHECKPOINT_URI)
DCP.save(
    state_dict=model.state_dict(),
    storage_writer=s3_storage_writer,
)

# Load the distributed checkpoint from S3 into a fresh model
model = torchvision.models.resnet18()
model_state_dict = model.state_dict()
s3_storage_reader = S3StorageReader(region=REGION, path=CHECKPOINT_URI)
DCP.load(
    state_dict=model_state_dict,
    storage_reader=s3_storage_reader,
)
model.load_state_dict(model_state_dict)
Parallel/Distributed Training
Amazon S3 Connector for PyTorch provides support for parallel and distributed training with PyTorch,
allowing you to leverage multiple processes and nodes for efficient data loading and training.
Both S3IterableDataset and S3MapDataset can be used for this purpose.
S3IterableDataset
The S3IterableDataset can be directly passed to PyTorch's DataLoader for parallel and distributed training.
By default, all worker processes will share the same list of training objects. However,
if you need each worker to have access to a unique portion of the dataset for better parallelization,
you can enable dataset sharding using the enable_sharding
parameter.
from torch.utils.data import DataLoader
from s3torchconnector import S3IterableDataset

dataset = S3IterableDataset.from_prefix(DATASET_URI, region=REGION, enable_sharding=True)
dataloader = DataLoader(dataset, num_workers=4)
When enable_sharding is set to True, the dataset is automatically sharded across the available workers.
This sharding mechanism supports both parallel training on a single host and distributed training across multiple hosts.
Each worker, regardless of its host, will load and process a distinct subset of the dataset.
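To illustrate what "a distinct subset per worker, regardless of host" can mean, here is a conceptual round-robin sketch. It is not the connector's actual sharding implementation, which this document does not describe; the function `shard` and the parameters `rank`, `world_size`, `worker_id`, and `num_workers` are assumptions chosen for illustration.

```python
from itertools import islice


def shard(keys, rank, world_size, worker_id, num_workers):
    """Return an iterator over the keys assigned to one worker on one host."""
    # Every (host, worker) pair gets a unique slot among world_size * num_workers.
    slot = rank * num_workers + worker_id
    total = world_size * num_workers
    # Take every `total`-th key, starting at this worker's slot.
    return islice(keys, slot, None, total)


keys = [f"obj{i}" for i in range(8)]
# Two hosts with two workers each: four slots in total.
print(list(shard(keys, rank=0, world_size=2, worker_id=0, num_workers=2)))  # ['obj0', 'obj4']
print(list(shard(keys, rank=1, world_size=2, worker_id=0, num_workers=2)))  # ['obj2', 'obj6']
```

In this scheme the subsets are disjoint and together cover every key exactly once, which is the property the paragraph above describes.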
S3MapDataset
For the S3MapDataset, you need to pass it to DataLoader along with a DistributedSampler wrapped around it.
The DistributedSampler ensures that each worker or node receives a unique subset of the dataset,
enabling efficient parallel and distributed training.
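from torch.utils.data import DataLoader, DistributedSampler
from s3torchconnector import S3MapDataset

dataset = S3MapDataset.from_prefix(DATASET_URI, region=REGION)
# DistributedSampler reads rank and world size from the initialized process group.
sampler = DistributedSampler(dataset)
dataloader = DataLoader(dataset, sampler=sampler, num_workers=4)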
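The index assignment performed by DistributedSampler can be sketched in plain Python. The sketch below mirrors its behavior with shuffling disabled and drop_last=False; by default the sampler also shuffles the indices each epoch, which is omitted here. The function name `sampler_indices` is a hypothetical stand-in, not a PyTorch API.

```python
import math


def sampler_indices(dataset_len, num_replicas, rank):
    """Indices one rank receives, without shuffling and without dropping samples."""
    if dataset_len == 0:
        return []
    num_samples = math.ceil(dataset_len / num_replicas)
    total_size = num_samples * num_replicas
    indices = list(range(dataset_len))
    # Pad by repeating from the start so every rank gets the same number of samples.
    padding = total_size - len(indices)
    indices += (indices * math.ceil(padding / len(indices)))[:padding]
    # Each rank takes every num_replicas-th index, starting at its own rank.
    return indices[rank:total_size:num_replicas]


for rank in range(4):
    print(rank, sampler_indices(10, 4, rank))
# 0 [0, 4, 8]
# 1 [1, 5, 9]
# 2 [2, 6, 0]
# 3 [3, 7, 1]
```

Note that when the dataset size is not divisible by the number of replicas, the padding means a few samples (indices 0 and 1 here) are seen by more than one rank.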
Lightning Integration
Amazon S3 Connector for PyTorch includes an integration for PyTorch Lightning, featuring S3LightningCheckpoint, an
implementation of Lightning's CheckpointIO. This allows users to make use of Amazon S3 Connector for PyTorch's S3
checkpointing functionality with PyTorch Lightning.
Getting Started
Installation
pip install s3torchconnector[lightning]
Examples
End-to-end examples for the PyTorch Lightning integration can be found in the examples/lightning
directory.
from lightning import Trainer
from s3torchconnector.lightning import S3LightningCheckpoint

...

s3_checkpoint_io = S3LightningCheckpoint("us-east-1")
trainer = Trainer(
    plugins=[s3_checkpoint_io],
    default_root_dir="s3://bucket_name/key_prefix/"
)
trainer.fit(model)
Using S3 Versioning to Manage Checkpoints
When working with model checkpoints, you can use the S3 Versioning feature to preserve, retrieve, and restore every version of your checkpoint objects. With versioning, you can recover more easily from unintended overwrites or deletions of existing checkpoint files due to incorrect configuration or multiple hosts accessing the same storage path.
When versioning is enabled on an S3 bucket, deletions insert a delete marker instead of removing the object permanently. The delete marker becomes the current object version. If you overwrite an object, it results in a new object version in the bucket. You can always restore the previous version. See Deleting object versions from a versioning-enabled bucket for more details on managing object versions.
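The version semantics described above can be illustrated with a toy in-memory model of a single object key. This is only a conceptual sketch: `VersionedKey` and `DELETE_MARKER` are hypothetical names, and nothing here calls S3 or reflects its API.

```python
DELETE_MARKER = object()  # stand-in for an S3 delete marker


class VersionedKey:
    """Toy model of one object key in a versioning-enabled bucket."""

    def __init__(self):
        self.versions = []  # oldest first; the last entry is the current version

    def put(self, data):
        # An overwrite never replaces data; it adds a new current version.
        self.versions.append(data)

    def delete(self):
        # A delete adds a marker on top instead of removing earlier versions.
        self.versions.append(DELETE_MARKER)

    def get(self):
        if not self.versions or self.versions[-1] is DELETE_MARKER:
            raise KeyError("object not found")
        return self.versions[-1]

    def restore_previous(self):
        # Removing the current version (or marker) makes the one below it current.
        self.versions.pop()


ckpt = VersionedKey()
ckpt.put(b"epoch0 weights")
ckpt.put(b"corrupted weights")  # unintended overwrite
ckpt.restore_previous()
print(ckpt.get())  # b'epoch0 weights'
```

The same recovery path applies to an unintended deletion: removing the delete marker makes the earlier checkpoint version current again.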
To enable versioning on an S3 bucket, see Enabling versioning on buckets. Normal Amazon S3 rates apply for every version of an object stored and transferred. To customize your data retention approach and control storage costs for earlier versions of objects, use object versioning with S3 Lifecycle.
S3 Versioning and S3 Lifecycle are not supported by S3 Express One Zone.
Contributing
We welcome contributions to Amazon S3 Connector for PyTorch. Please see CONTRIBUTING for more
information on how to report bugs or submit pull requests.
Development
See DEVELOPMENT for information about code style, development process, and guidelines.
Compatibility with other storage services
S3 Connector for PyTorch delivers high throughput for PyTorch training jobs that access or store data in Amazon S3.
While it may work with other storage services that use S3-like APIs, that compatibility may inadvertently break when we
make changes to better support Amazon S3. We welcome contributions of minor compatibility fixes or performance
improvements for these services if the changes can be tested against Amazon S3.
Security issue notifications
If you discover a potential security issue in this project, we ask that you notify AWS Security via our
vulnerability reporting page.
Code of conduct
This project has adopted the Amazon Open Source Code of Conduct.
See CODE_OF_CONDUCT.md for more details.
License
Amazon S3 Connector for PyTorch has a BSD 3-Clause License, as found in the LICENSE file.