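jax_dataclasses provides a simple wrapper around dataclasses.dataclass for use in JAX, which enables automatic support for pytree registration and for serialization via flax.serialization.
Distinguishing features include an annotation-based interface for marking static fields and a context manager for modifying nested structures.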
In Python >=3.7:
pip install jax_dataclasses
We can then import:
import jax_dataclasses as jdc
jax_dataclasses is meant to provide a drop-in replacement for dataclasses.dataclass: jdc.pytree_dataclass has the same interface as dataclasses.dataclass, but also registers the target class as a pytree node.
We also provide several aliases: jdc.[field, asdict, astuple, is_dataclass, replace] are identical to their counterparts in the standard dataclasses library.
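Since the aliases are described as identical to the standard library functions, their behavior can be sketched with the standard dataclasses module alone. The Config class below is a hypothetical example, not part of jax_dataclasses; swapping dataclasses.replace for jdc.replace (and so on) is expected to give the same results.

```python
import dataclasses


# Hypothetical example class, used only for illustration.
@dataclasses.dataclass(frozen=True)
class Config:
    learning_rate: float
    steps: int = dataclasses.field(default=100)


cfg = Config(learning_rate=1e-3)

# replace() returns a new instance with the given fields changed.
cfg_fast = dataclasses.replace(cfg, steps=10)

# asdict() / astuple() convert an instance to plain containers.
as_dict = dataclasses.asdict(cfg_fast)
as_tuple = dataclasses.astuple(cfg_fast)

# is_dataclass() accepts both dataclass types and instances.
is_dc = dataclasses.is_dataclass(cfg_fast)
```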
To mark a field as static (in this context: constant at compile-time), we can wrap its type with jdc.Static[]:
import jax
import jax_dataclasses as jdc

@jdc.pytree_dataclass
class A:
    a: jax.Array
    b: jdc.Static[bool]
In a pytree node, static fields will be treated as part of the treedef instead of as a child of the node; all fields that are not explicitly marked static should contain arrays or child nodes.
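To make the treedef/child distinction concrete, here is a minimal sketch that registers a pytree node by hand with jax.tree_util.register_pytree_node. It is not how jax_dataclasses is implemented; the ManualA class and its helper functions are hypothetical and only illustrate where a static value ends up.

```python
import dataclasses

import jax
import jax.numpy as jnp


# Hypothetical hand-registered counterpart of the class above.
@dataclasses.dataclass(frozen=True)
class ManualA:
    a: jax.Array  # child: becomes a leaf of the pytree
    b: bool       # static: stored in the treedef as auxiliary data


def _flatten(obj):
    children = (obj.a,)
    aux_data = (obj.b,)
    return children, aux_data


def _unflatten(aux_data, children):
    return ManualA(a=children[0], b=aux_data[0])


jax.tree_util.register_pytree_node(ManualA, _flatten, _unflatten)

x = ManualA(a=jnp.zeros(3), b=True)
leaves, treedef = jax.tree_util.tree_flatten(x)

# Changing the array keeps the treedef; changing the static value does not.
same_static = jax.tree_util.tree_structure(ManualA(a=jnp.ones(3), b=True))
other_static = jax.tree_util.tree_structure(ManualA(a=jnp.ones(3), b=False))
```

Only the array appears among the leaves, while the boolean is compared as part of the tree structure. This is why non-static fields should hold arrays or child nodes.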
Bonus: if you like jdc.Static[], we also introduce jdc.jit(). This enables use in function signatures, for example:
@jdc.jit
def f(a: jax.Array, b: jdc.Static[bool]) -> jax.Array:
    ...
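For comparison, plain JAX marks static arguments by name or position when calling jax.jit. The sketch below shows that baseline, which the annotation-based signature stands in for; it does not describe how jdc.jit works internally. The function scale and its parameter negate are hypothetical.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Baseline JAX API: static arguments are listed separately from the signature.
@partial(jax.jit, static_argnames=("negate",))
def scale(x: jax.Array, negate: bool) -> jax.Array:
    # `negate` is an ordinary Python bool during tracing, so a regular
    # `if` statement can branch on it.
    if negate:
        return -2.0 * x
    return 2.0 * x


out_pos = scale(jnp.ones(3), negate=False)
out_neg = scale(jnp.ones(3), negate=True)
```

Each distinct value of a static argument triggers a separate compilation, so static arguments are best reserved for values that rarely change.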
All dataclasses are automatically marked as frozen and thus immutable (even when no frozen= parameter is passed in). To make changes to nested structures easier, jdc.copy_and_mutate (a) makes a copy of a pytree and (b) returns a context in which any of that copy's contained dataclasses are temporarily mutable:
import jax
from jax import numpy as jnp
import jax_dataclasses as jdc

@jdc.pytree_dataclass
class Node:
    child: jax.Array

obj = Node(child=jnp.zeros(3))

with jdc.copy_and_mutate(obj) as obj_updated:
    # Make mutations to the dataclass. This is primarily useful for nested
    # dataclasses.
    #
    # Does input validation by default: if the treedef, leaf shapes, or dtypes
    # of `obj` and `obj_updated` don't match, an AssertionError will be raised.
    # This can be disabled with a `validate=False` argument.
    obj_updated.child = jnp.ones(3)

print(obj)
print(obj_updated)
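The immutability that copy_and_mutate works around is the ordinary frozen-dataclass behavior. As a rough illustration using only the standard library (the Point class is hypothetical), direct assignment on a frozen instance is rejected and updates have to go through a copy:

```python
import dataclasses


# Hypothetical frozen dataclass, standing in for any immutable dataclass.
@dataclasses.dataclass(frozen=True)
class Point:
    x: float


p = Point(x=1.0)

try:
    p.x = 2.0  # direct assignment is rejected on a frozen dataclass
    mutated = True
except dataclasses.FrozenInstanceError:
    mutated = False

# The usual alternative: build a modified copy instead.
p_moved = dataclasses.replace(p, x=2.0)
```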
A few other solutions exist for automatically integrating dataclass-style objects into pytree structures. Great ones include: chex.dataclass, flax.struct, and tjax.dataclass. These all influenced this library.
The main differentiators of jax_dataclasses are:
- Static analysis support. tjax has a custom mypy plugin to enable type checking, but isn't supported by other tools. flax.struct implements the dataclass_transform spec proposed by pyright, but isn't supported by other tools. Because @jdc.pytree_dataclass has the same API as @dataclasses.dataclass, it can include pytree registration behavior at runtime while being treated as the standard decorator during static analysis. This means that all static checkers, language servers, and autocomplete engines that support the standard dataclasses library should work out of the box with jax_dataclasses.
- Nested dataclasses. Making replacements/modifications in deeply nested dataclasses can be really frustrating. The three alternatives all introduce a .replace(self, ...) method to dataclasses that's a bit more convenient than the traditional dataclasses.replace(obj, ...) API for shallow changes, but still becomes really cumbersome to use when dataclasses are nested. jdc.copy_and_mutate() is introduced to address this.
- Static field support. Parameters that should not be traced in JAX should be marked as static. This is supported in flax, tjax, and jax_dataclasses, but not chex.
- Serialization. When working with flax, being able to serialize dataclasses is really handy. This is supported in flax.struct (naturally) and jax_dataclasses, but not chex or tjax.
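To illustrate the nested-dataclass point, here is a small sketch using only the standard dataclasses.replace API. The Inner, Middle, and Outer classes are hypothetical; the point is that every level of nesting needs its own replace() call.

```python
import dataclasses


# Hypothetical three-level nesting, for illustration only.
@dataclasses.dataclass(frozen=True)
class Inner:
    value: float


@dataclasses.dataclass(frozen=True)
class Middle:
    inner: Inner


@dataclasses.dataclass(frozen=True)
class Outer:
    middle: Middle


obj = Outer(middle=Middle(inner=Inner(value=0.0)))

# Updating one deeply nested field requires rebuilding each enclosing level.
obj_updated = dataclasses.replace(
    obj,
    middle=dataclasses.replace(
        obj.middle,
        inner=dataclasses.replace(obj.middle.inner, value=1.0),
    ),
)
```

Assuming the classes were instead declared with @jdc.pytree_dataclass and held array fields, the same update inside a jdc.copy_and_mutate context would be a single attribute assignment on the nested field.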
You can also eschew the dataclass-style interface entirely; see how brax registers pytrees. This is a reasonable thing to prefer: it requires some floating strings and breaks things that I care about but you may not (like immutability and __post_init__), but gives more flexibility with custom __init__ methods.
jax_dataclasses was originally written for and factored out of jaxfg, where Nick Heppert provided valuable feedback.