cut2min-bucket
A PyTorch batch sampler that buckets by input length and cuts each batch to the minimum sequence length in that batch

This package provides two utilities:

- cut2min_bucket.DatasetWrapper, which eliminates padding by cutting every sequence in a batch down to the minimum length in that batch, and
- cut2min_bucket.BucketBatchSampler, a batch sampler that buckets examples by input length.

In addition, a Distributed Data Parallel version of the batch sampler is provided as cut2min_bucket.DistributedBucketBatchSampler (a usage sketch follows the example below).
A detailed motivation for this package can be found on my blog.
Simple example:
import cut2min_bucket
import torch
import numpy as np

# Build 10,000 random sequences with lengths between 2 and 999.
X = []
for _ in range(10000):
    length = int(torch.randint(size=(), low=2, high=1000))
    X.append(torch.tensor(np.random.randn(length)))
seqlens = torch.tensor([len(x) for x in X])

# Pad everything to the longest sequence, as you normally would.
X = torch.nn.utils.rnn.pad_sequence(X, batch_first=True)
y = (torch.rand(10000) > 0.5).int()

dataset = torch.utils.data.TensorDataset(X, y)

# Wrap the dataset; index_or_key=0 points the wrapper at the padded
# sequences (element 0 of each (X, y) item).
dataset = cut2min_bucket.DatasetWrapper(
    dataset, seqlens,
    index_or_key=0,
)

# Bucket by sequence length so each batch contains similarly sized inputs.
batch_sampler = cut2min_bucket.BucketBatchSampler(
    dataset,
    seqlens,
    batch_size=8,
    n_partitions=5,
)

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_sampler=batch_sampler,
    collate_fn=dataset.collate_fn,
)

next(iter(dataloader))
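
Each batch that comes out of the loader has been cut by DatasetWrapper's collate_fn, so the sequence dimension equals the shortest sequence in that particular batch rather than the globally padded length. A quick way to check, assuming the wrapper returns batches in the same (x, y) order as the underlying TensorDataset:

x_batch, y_batch = next(iter(dataloader))
print(x_batch.shape)  # (8, L): L is the minimum seqlen of this batch, not X.shape[1]
print(y_batch.shape)  # (8,)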
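
For multi-GPU training, cut2min_bucket.DistributedBucketBatchSampler is the Distributed Data Parallel counterpart of BucketBatchSampler mentioned above. The sketch below is a guess at how it might slot into a standard DDP loop; the num_replicas/rank arguments and the set_epoch call are assumptions borrowed from torch.utils.data.distributed.DistributedSampler, not confirmed against this package, so check the class's actual signature before relying on them.

import torch
import torch.distributed as dist
import cut2min_bucket

# Assumes launch via torchrun, which sets the RANK/WORLD_SIZE environment
# variables that init_process_group reads.
dist.init_process_group(backend="nccl")
rank = dist.get_rank()
world_size = dist.get_world_size()

# `dataset` (a cut2min_bucket.DatasetWrapper) and `seqlens` are built exactly
# as in the simple example above.
batch_sampler = cut2min_bucket.DistributedBucketBatchSampler(
    dataset,
    seqlens,
    batch_size=8,
    n_partitions=5,
    num_replicas=world_size,  # assumption: mirrors DistributedSampler's arguments
    rank=rank,                # assumption: mirrors DistributedSampler's arguments
)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_sampler=batch_sampler,
    collate_fn=dataset.collate_fn,
)

num_epochs = 3
for epoch in range(num_epochs):
    # Assumption: a per-epoch reshuffle hook like DistributedSampler.set_epoch;
    # guarded with hasattr so the loop still runs if no such hook exists.
    if hasattr(batch_sampler, "set_epoch"):
        batch_sampler.set_epoch(epoch)
    for x, y in dataloader:
        ...  # forward/backward on the DDP-wrapped model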