from array_split import array_split, shape_split
import numpy as np
# Usage example for the array_split package.
#
# Part 1 splits a 1D array into 4 pieces with `array_split`. It also shows
# `shape_split`, which returns slice objects instead of sub-arrays.
# Part 2 reshapes the array to 2D (4, 9) and uses `shape_split` to build a
# 2 x 3 grid of slice tuples. It then creates sub-array views from them.
#
# The expected results in the comments come from the package's documented
# interactive session. Whitespace in numpy reprs may differ slightly by
# version. The asserts check the documented values, so the example fails
# loudly if behaviour changes.

ary = np.arange(0, 4 * 9)

# 1D split into 4 sections (like numpy.array_split).
parts = array_split(ary, 4)
print(parts)
# Expected:
# [array([0, 1, 2, 3, 4, 5, 6, 7, 8]),
#  array([ 9, 10, 11, 12, 13, 14, 15, 16, 17]),
#  array([18, 19, 20, 21, 22, 23, 24, 25, 26]),
#  array([27, 28, 29, 30, 31, 32, 33, 34, 35])]
assert len(parts) == 4
assert all(
    np.array_equal(part, np.arange(9 * i, 9 * (i + 1)))
    for i, part in enumerate(parts)
)

# 1D split into 4 parts, returning slice objects rather than sub-arrays.
slices_1d = shape_split(ary.shape, 4)
print(slices_1d)
# Expected:
# array([(slice(0, 9, None),), (slice(9, 18, None),),
#        (slice(18, 27, None),), (slice(27, 36, None),)],
#       dtype=[('0', 'O')])

ary = ary.reshape(4, 9)  # Make ary 2D.
split = shape_split(ary.shape, axis=(2, 3))  # 2D split into 2*3=6 sections.
print(split.shape)
# Expected: (2, 3)
assert split.shape == (2, 3)

print(split)
# Expected:
# array([[(slice(0, 2, None), slice(0, 3, None)),
#         (slice(0, 2, None), slice(3, 6, None)),
#         (slice(0, 2, None), slice(6, 9, None))],
#        [(slice(2, 4, None), slice(0, 3, None)),
#         (slice(2, 4, None), slice(3, 6, None)),
#         (slice(2, 4, None), slice(6, 9, None))]],
#       dtype=[('0', 'O'), ('1', 'O')])

# Create sub-array views from the slice tuples. Each element of `split` is a
# structured record (dtype shown above), not a Python tuple. Convert it with
# tuple(), because modern numpy does not treat a non-tuple object as a
# multi-dimensional index.
sub_arys = [ary[tuple(tup)] for tup in split.flatten()]
print(sub_arys)
# Expected:
# [array([[ 0,  1,  2], [ 9, 10, 11]]),
#  array([[ 3,  4,  5], [12, 13, 14]]),
#  array([[ 6,  7,  8], [15, 16, 17]]),
#  array([[18, 19, 20], [27, 28, 29]]),
#  array([[21, 22, 23], [30, 31, 32]]),
#  array([[24, 25, 26], [33, 34, 35]])]
assert len(sub_arys) == 6
assert np.array_equal(sub_arys[0], [[0, 1, 2], [9, 10, 11]])
assert np.array_equal(sub_arys[-1], [[24, 25, 26], [33, 34, 35]])