
dataclass-array
Dataclasses that behave like numpy arrays (with indexing, slicing, vectorization).
A DataclassArray is a dataclass that behaves like a numpy-like array (it can be
batched, reshaped, sliced, ...) and is compatible with Jax, TensorFlow, and numpy (with
torch support planned).
This reduces boilerplate and improves readability. See the motivating examples section below.
To view an example of dataclass arrays used in practice, see visu3d.
To create a dca.DataclassArray, take a frozen dataclass and:
- Inherit from dca.DataclassArray.
- Annotate the fields with dataclass_array.typing to specify the inner shape
  and dtype of the array (see below for static or nested dataclass fields).
  The array types are an alias from etils.array_types.

import dataclass_array as dca
from dataclass_array.typing import FloatArray
class Ray(dca.DataclassArray):
  pos: FloatArray['*batch_shape 3']
  dir: FloatArray['*batch_shape 3']
Afterwards, the dataclass can be used as a numpy array:
ray = Ray(pos=jnp.zeros((3, 3)), dir=jnp.eye(3))
ray.shape == (3,) # 3 rays batched together
ray.pos.shape == (3, 3) # Individual fields still available
# Numpy slicing/indexing/masking
ray = ray[..., 1:2]
ray = ray[norm(ray.dir) > 1e-7]
# Shape transformation
ray = ray.reshape((1, 3))
ray = ray.reshape('h w -> w h') # Native einops support
ray = ray.flatten()
# Stack multiple dataclass arrays together
ray = dca.stack([ray0, ray1, ...])
# Supports TF, Jax, Numpy (torch planned) and can be easily converted
ray = ray.as_jax() # as_np(), as_tf()
ray.xnp == jax.numpy # `numpy`, `jax.numpy`, `tf.experimental.numpy`
# Compatibility with `jax.tree_util`, `jax.vmap`, ...
ray = jax.tree_util.tree_map(lambda x: x+1, ray)
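The relationship between field shapes and `ray.shape` shown above can be sketched in plain Python. This is only an illustration of the idea that the batch shape is what remains once each field's inner shape is stripped. `infer_batch_shape` is a hypothetical helper, not part of `dataclass_array`, and the library's actual validation may differ (for example, it may be more lenient than requiring identical batch shapes).

```python
def infer_batch_shape(field_shapes, inner_shapes):
  """Returns the batch shape shared by all array fields.

  Hypothetical helper for illustration only (not a dataclass_array API).

  Args:
    field_shapes: dict mapping field name -> full array shape.
    inner_shapes: dict mapping field name -> declared inner shape, e.g. (3,).
  """
  batch_shapes = set()
  for name, shape in field_shapes.items():
    shape = tuple(shape)
    inner = tuple(inner_shapes[name])
    split = len(shape) - len(inner)
    # The trailing axes must match the declared inner shape.
    if split < 0 or shape[split:] != inner:
      raise ValueError(f'{name}: shape {shape} does not end with {inner}')
    batch_shapes.add(shape[:split])
  # This sketch requires every field to agree on the batch shape.
  if len(batch_shapes) != 1:
    raise ValueError(f'Inconsistent batch shapes: {sorted(batch_shapes)}')
  return batch_shapes.pop()


ray_batch = infer_batch_shape(
    field_shapes={'pos': (3, 3), 'dir': (3, 3)},
    inner_shapes={'pos': (3,), 'dir': (3,)},
)
print(ray_batch)  # (3,)
```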
A DataclassArray has 2 types of fields:
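- Array fields: fields batched like numpy arrays (reshaped, sliced, ...). Can be xnp.ndarray or nested dca.DataclassArray.
- Static fields: all other fields. They are not modified by batch operations such as reshaping, and are ignored by jax.tree_map.

class MyArray(dca.DataclassArray):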
  # Array fields
  a: FloatArray['*batch_shape 3']  # Defined by `etils.array_types`
  b: FloatArray['*batch_shape _ _']  # Dynamic shape
  c: Ray  # Nested DataclassArray (equivalent to `Ray['*batch_shape']`)
  d: Ray['*batch_shape 6']

  # Array fields explicitly defined
  e: Any = dca.field(shape=(3,), dtype=np.float32)
  f: Any = dca.field(shape=(None, None), dtype=np.float32)  # Dynamic shape
  g: Ray = dca.field(shape=(3,), dtype=Ray)  # Nested DataclassArray

  # Static field (everything not defined as above)
  static0: float
  static1: np.array
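The two ways of declaring an array field above (a shape string such as `'*batch_shape _ _'`, or an explicit `shape=(None, None)`) appear to describe the same inner shape. The sketch below shows one way that correspondence could be read. `parse_inner_shape` is a hypothetical helper, not the parser used by `etils` or `dataclass_array`, and it only handles the tokens that appear in this document: integers and `_`.

```python
def parse_inner_shape(spec):
  """Converts '*batch_shape 4 4' into (4, 4), with `_` mapped to None.

  Hypothetical helper for illustration only. The real shape grammar is
  richer than the integer and `_` tokens handled here.
  """
  tokens = spec.split()
  if not tokens or tokens[0] != '*batch_shape':
    raise ValueError(f'Expected spec to start with *batch_shape: {spec!r}')
  inner = []
  for token in tokens[1:]:
    # `_` marks a dynamic axis, written as None in the explicit form.
    inner.append(None if token == '_' else int(token))
  return tuple(inner)


print(parse_inner_shape('*batch_shape 3'))    # (3,)
print(parse_inner_shape('*batch_shape _ _'))  # (None, None)
print(parse_inner_shape('*batch_shape'))      # ()
```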
@dca.vectorize_method allows your dataclass methods to automatically support
batching:
1. Implement the method as if self.shape == ().
2. Decorate the method with dca.vectorize_method.

class Camera(dca.DataclassArray):
  K: FloatArray['*batch_shape 4 4']
  resolution: tuple[int, int]  # Static field

  @dca.vectorize_method
  def rays(self) -> Ray:
    # Inside `@dca.vectorize_method`, shape is always guaranteed to be `()`
    assert self.shape == ()
    assert self.K.shape == (4, 4)

    # Compute the rays as if there was only a single camera
    return Ray(pos=..., dir=...)

Afterward, we can generate rays for multiple cameras batched together:

cams = Camera(K=K, resolution=(h, w))  # K.shape == (num_cams, 4, 4)
rays = cams.rays()  # Generate the rays for all the cameras

cams.shape == (num_cams,)
rays.shape == (num_cams, h, w)
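The shape relationship above (output shape equals the batch shape followed by the per-element output shape) can be illustrated without any array library. The following is a minimal sketch using nested lists. `map_over_batch`, `nested_shape`, and `rays_for_one_camera` are hypothetical names introduced for illustration, and this says nothing about how `dca.vectorize_method` is actually implemented.

```python
def map_over_batch(fn, batch, batch_ndim):
  """Applies `fn` to every element of a nested list with `batch_ndim` axes."""
  if batch_ndim == 0:
    return fn(batch)  # A single element: what the method body would see.
  return [map_over_batch(fn, item, batch_ndim - 1) for item in batch]


def nested_shape(x):
  """Returns the shape of a rectangular nested list."""
  shape = []
  while isinstance(x, list):
    shape.append(len(x))
    x = x[0]
  return tuple(shape)


def rays_for_one_camera(focal, h=2, w=3):
  """Hypothetical per-camera computation returning an (h, w) grid."""
  return [[focal * (row + col) for col in range(w)] for row in range(h)]


cams = [[1.0, 2.0], [3.0, 4.0]]  # Batch shape (2, 2)
rays = map_over_batch(rays_for_one_camera, cams, batch_ndim=2)
print(nested_shape(cams))  # (2, 2)
print(nested_shape(rays))  # (2, 2, 2, 3)
```

The function written for one camera never sees the batch axes; they are added back around its result.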
@dca.vectorize_method is similar to jax.vmap but:
- It only works on dca.DataclassArray methods.
- Instead of vectorizing a single axis, @dca.vectorize_method will vectorize
  over *self.shape (not just self.shape[0]). This is as if vmap was
  applied to self.flatten().
- When there are multiple arguments, axes with dimension 1 are broadcasted.

For example, with __matmul__(self, x: T) -> T:
() @ (*x,) -> (*x,)
(b,) @ (b, *x) -> (b, *x)
(b,) @ (1, *x) -> (b, *x)
(1,) @ (b, *x) -> (b, *x)
(b, h, w) @ (b, h, w, *x) -> (b, h, w, *x)
(1, h, w) @ (b, 1, 1, *x) -> (b, h, w, *x)
(a, *x) @ (b, *x) -> Error: Incompatible a != b
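One reading of the table above is that the first `len(self.shape)` axes of the argument are matched against `self.shape`, with axes of size 1 broadcast on either side, and the remaining axes `*x` passed through. The sketch below encodes that reading in plain Python. `vectorized_result_shape` is a hypothetical helper, not a `dataclass_array` function, and the library's real behavior may differ in cases the table does not cover.

```python
def vectorized_result_shape(self_shape, arg_shape):
  """Returns the result shape for `self @ x` under the broadcasting table.

  Hypothetical helper for illustration only.
  """
  self_shape = tuple(self_shape)
  arg_shape = tuple(arg_shape)
  n = len(self_shape)
  if len(arg_shape) < n:
    raise ValueError(f'Argument shape {arg_shape} has fewer than {n} axes')
  batch = []
  for s, a in zip(self_shape, arg_shape[:n]):
    if s == a or a == 1:
      batch.append(s)  # Same size, or the argument axis is broadcast.
    elif s == 1:
      batch.append(a)  # The `self` axis is broadcast.
    else:
      raise ValueError(f'Incompatible {s} != {a}')
  # Trailing axes (`*x` in the table) are left untouched.
  return tuple(batch) + arg_shape[n:]


print(vectorized_result_shape((), (5,)))                  # (5,)
print(vectorized_result_shape((4,), (1, 5)))              # (4, 5)
print(vectorized_result_shape((1, 2, 3), (4, 1, 1, 5)))   # (4, 2, 3, 5)
```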
To test on Colab, see the visu3d dataclass
Colab tutorial.
dca.DataclassArray improves readability by simplifying common patterns:
Reshaping all fields of a dataclass:
Before (rays is a simple dataclass):
num_rays = math.prod(rays.origins.shape[:-1])
rays = jax.tree_map(lambda r: r.reshape((num_rays, -1)), rays)
After (rays is a DataclassArray):
rays = rays.flatten() # (b, h, w) -> (b*h*w,)
Rendering a video:
Before (cams: list[Camera]):
img = cams[0].render(scene)
imgs = np.stack([cam.render(scene) for cam in cams[::2]])
imgs = np.stack([cam.render(scene) for cam in cams])
After (cams: Camera with cams.shape == (num_cams,)):
img = cams[0].render(scene) # Render only the first camera (to debug)
imgs = cams[::2].render(scene) # Render 1/2 frames (for quicker iteration)
imgs = cams.render(scene) # Render all cameras at once
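For the reshaping pattern above, the effect of `rays.flatten()` on each field can be sketched at the level of shapes: the batch axes collapse into one and the inner axes are kept. This is an assumption based on the `(b, h, w) -> (b*h*w,)` comment. `flatten_field_shapes` is a hypothetical helper, not a `dataclass_array` API.

```python
import math


def flatten_field_shapes(field_shapes, batch_ndim):
  """Collapses the leading `batch_ndim` axes of every field shape into one.

  Hypothetical helper for illustration only.
  """
  flat = {}
  for name, shape in field_shapes.items():
    batch, inner = tuple(shape[:batch_ndim]), tuple(shape[batch_ndim:])
    # (b, h, w, *inner) -> (b*h*w, *inner)
    flat[name] = (math.prod(batch),) + inner
  return flat


shapes = {'pos': (2, 4, 5, 3), 'dir': (2, 4, 5, 3)}
print(flatten_field_shapes(shapes, batch_ndim=3))
# {'pos': (40, 3), 'dir': (40, 3)}
```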
pip install dataclass_array
This is not an official Google product