# smallperm
Small library to generate permutations of a list of elements using pseudo-random permutations (PRP). Uses O(1) memory and O(1) time to generate the next element of the permutation.
```python
>>> from smallperm import PseudoRandomPermutation
>>> list(PseudoRandomPermutation(42, 0xDEADBEEF))
[30, 11, 23, 21, 39, 9, 26, 5, 27, 38, 15, 37, 31, 35, 6, 13, 34, 10, 7, 0, 12, 22, 33, 17, 41, 29, 18, 20, 3, 40, 25, 4, 19, 24, 32, 16, 36, 14, 1, 28, 2, 8]
```
## Motivation
In ML training, it is common to see things like:
```python
import numpy as np

sample_indices = np.arange(1_000_000)
np.random.shuffle(sample_indices)
for i in sample_indices:
    ...
```
Or to do an online Fisher-Yates shuffle:
```python
import numpy as np

N = 1_000_000
sample_indices = np.arange(N)
for i in range(N):
    j = np.random.randint(i, N)
    sample_indices[i], sample_indices[j] = sample_indices[j], sample_indices[i]
    ...
```
The problem with either of these approaches is that they require O(n) memory to store the shuffled indices, and the offline shuffle has a bad "time-to-first-sample" problem as we approach the scale of one billion data points. This library provides a way to generate a permutation of [0, n) using O(1) memory and O(1) time per element.
```bash
pip install smallperm
```
```python
import numpy as np
from smallperm import PseudoRandomPermutation as PRP

N = 1_000_000
prp = PRP(N, np.random.randint(0, np.iinfo(np.int64).max + 1))
print(prp[0], prp[50])  # O(1) random access
assert 50 == prp.backward(prp[50])  # O(1) inverse

for ix in prp:  # iterate the full permutation
    ...
```
For most ML use cases this should be Pareto-optimal: it is faster than Fisher-Yates, uses much less memory, and has a much better time-to-first-sample than an offline shuffle. In other words, we use O(1) time and O(1) space to achieve the effect of `arr = np.arange(N); np.random.shuffle(arr)`, which is kind of magical. The cost is some shuffling quality, but in ML training, when we constantly have > 1M data points, our PRNG seeds cannot represent the entire space of permutations anyway.
## API
- **Initialization:** `PseudoRandomPermutation(length: int, seed: int)` generates a permutation of `[0, length)` using `seed`. We impose no restriction on `length` (except that it fits in an unsigned 128-bit integer).
- **Usage:** iterate over the instance to get the next element of the permutation. Example: `list(PseudoRandomPermutation(42, 0xDEADBEEF))`.
- **O(1) forward/backward mapping:** `forward(i: int) -> int` returns the `i`-th element of the permutation (regardless of the current state of the iterator); `backward(el: int) -> int` returns the index of `el` in the permutation.
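The forward/backward contract can be illustrated with a pure-Python reference model. This is *not* how `smallperm` works internally (it stores a list, using O(n) memory, and `ReferencePermutation` is a hypothetical name), but it shows the semantics the API guarantees:

```python
import random

class ReferencePermutation:
    """List-backed model of the forward/backward API contract
    (O(n) memory, unlike smallperm, but the same semantics)."""

    def __init__(self, length: int, seed: int):
        rng = random.Random(seed)
        self._fwd = list(range(length))
        rng.shuffle(self._fwd)
        # Invert the permutation once so backward() is also O(1).
        self._bwd = [0] * length
        for i, el in enumerate(self._fwd):
            self._bwd[el] = i

    def forward(self, i: int) -> int:
        return self._fwd[i]

    def backward(self, el: int) -> int:
        return self._bwd[el]

    def __iter__(self):
        return iter(self._fwd)

prp = ReferencePermutation(42, 0xDEADBEEF)
assert all(prp.backward(prp.forward(i)) == i for i in range(42))
assert sorted(prp) == list(range(42))  # it is a permutation of [0, 42)
```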
## Features
- Hardware-independent (i.e., reproducible across different machines with the same seed) shuffling. Barring major bugs, this repo will not change the permutation generated by a given seed; if it ever must, we will do a major version bump.
- Extremely fast. On my MBP, iterating through the permutation is only 2x-3x slower than iterating through an `arange(N)` array.
## How
We use a (somewhat) weak albeit fast symmetric cipher to generate the permutation. The resulting shuffle quality is not as high as a Fisher-Yates shuffle, but it is extremely efficient. Compared to Fisher-Yates, we use O(1) memory (as opposed to O(n), with n the length of the shuffle). Fix a permutation $\sigma$ (i.e., `PseudoRandomPermutation(n, seed)`) which maps $\{0, 1, \ldots, n-1\}$ to itself; we compute both $\sigma(x)$ and $\sigma^{-1}(y)$ in $O(1)$, which can be very desirable properties in distributed ML training.
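To make the cipher idea concrete, here is a minimal, self-contained sketch of one standard construction: a toy balanced Feistel network combined with cycle-walking to restrict the domain to $[0, n)$. The round function, round count, and key schedule below are illustrative assumptions; the actual cipher used by this library differs, but the invertibility and O(1)-per-element structure are the same:

```python
import hashlib

def _round_fn(half: int, key: int, r: int, bits: int) -> int:
    # Toy round function: hash (half, key, round index) down to `bits` bits.
    digest = hashlib.blake2b(f"{half}:{key}:{r}".encode(), digest_size=8).digest()
    return int.from_bytes(digest, "big") & ((1 << bits) - 1)

def _feistel(x: int, key: int, half_bits: int, rounds: int = 4, inverse: bool = False) -> int:
    # Balanced Feistel network over 2 * half_bits bits; a bijection by construction.
    left, right = x >> half_bits, x & ((1 << half_bits) - 1)
    order = range(rounds - 1, -1, -1) if inverse else range(rounds)
    for r in order:
        if inverse:
            left, right = right ^ _round_fn(left, key, r, half_bits), left
        else:
            left, right = right, left ^ _round_fn(right, key, r, half_bits)
    return (left << half_bits) | right

def _half_bits(n: int) -> int:
    # Smallest balanced Feistel domain 2^(2*hb) that covers [0, n).
    bits = max(2, (n - 1).bit_length())
    return (bits + 1) // 2

def sigma(x: int, n: int, key: int) -> int:
    # Forward map: cycle-walk until the ciphertext lands back in [0, n).
    hb = _half_bits(n)
    y = _feistel(x, key, hb)
    while y >= n:
        y = _feistel(y, key, hb)
    return y

def sigma_inv(y: int, n: int, key: int) -> int:
    # Inverse map: cycle-walk with the inverted cipher.
    hb = _half_bits(n)
    x = _feistel(y, key, hb, inverse=True)
    while x >= n:
        x = _feistel(x, key, hb, inverse=True)
    return x

n, key = 1000, 0xDEADBEEF
perm = [sigma(i, n, key) for i in range(n)]
assert sorted(perm) == list(range(n))  # sigma is a permutation of [0, n)
assert all(sigma_inv(sigma(i, n, key), n, key) == i for i in range(n))
```

Cycle-walking terminates because the cipher is a bijection on its power-of-two domain, so repeatedly applying it from a point in $[0, n)$ must eventually return to $[0, n)$; each element costs O(1) expected work when the domain is at most a constant factor larger than $n$.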
## More examples
### Shuffle non-`[0, n)` sequences
`PRP` is a primitive yet powerful object, and can be composed.
```python
from smallperm import PseudoRandomPermutation as PRP
from itertools import product

deck = list(product(range(4), range(13)))  # 52 cards: (suit, rank)
prp = PRP(len(deck), 0xDEADBEEF)
shuffled_deck = [deck[i] for i in prp]
```
### Replacing two `random.shuffle` calls
It is common to have two `random.shuffle` calls when there is a global shuffle and a "within-shard" shuffle.
```python
from smallperm import PseudoRandomPermutation as PRP

seeds = (42, 0)
N = 1_000_000
local_indices = range(2, N, 8)  # this shard's slice of the global indices
local_n = len(local_indices)

global_shuffle = PRP(N, seeds[0])
local_shuffle = PRP(local_n, seeds[1])
for i in range(local_n):
    print(global_shuffle[local_indices[local_shuffle[i]]])
```
### Infinite PRP with new shuffles per epoch
```python
from itertools import chain, count
import numpy as np
from smallperm import PseudoRandomPermutation as PRP

N = 10_000_000
rng = np.random.default_rng(0xDEADBEEF)
uint32_upper_bound = 0x100000000

# A fresh permutation of [0, N) per epoch, chained into one infinite stream.
infinite_prp = chain.from_iterable(
    PRP(N, rng.integers(0, uint32_upper_bound)) for _ in count()
)
```
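The epoch-chaining pattern itself can be exercised without `smallperm`, using `random.Random(seed).sample` as a stand-in for a fresh per-epoch permutation (the `epoch_perm` helper below is hypothetical, purely for illustration). Each slice of `N` elements from the chained iterator is a complete, differently-seeded shuffle:

```python
from itertools import chain, count, islice
import random

N = 1_000

def epoch_perm(seed: int):
    # Stand-in for PRP(N, seed): a fresh length-N permutation per epoch.
    return random.Random(seed).sample(range(N), N)

seed_rng = random.Random(0xDEADBEEF)
infinite_perm = chain.from_iterable(
    epoch_perm(seed_rng.randrange(2**32)) for _ in count()
)

epoch0 = list(islice(infinite_perm, N))  # first epoch
epoch1 = list(islice(infinite_perm, N))  # second epoch, different seed
assert sorted(epoch0) == sorted(epoch1) == list(range(N))
assert epoch0 != epoch1  # independently seeded shuffles differ
```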
## Acknowledgements
Gratefully modifies and reuses code from https://github.com/asimihsan/permutation-iterator-rs, which does most of the heavy lifting. Because the heavy lifting is done in Rust, this library is very efficient.