xmmutablemap
JAX-compatible Immutable Map
JAX prefers immutable objects, but neither Python nor JAX provides an immutable
dictionary. 😢
This repository defines a lightweight immutable map
(lower-level than a dict) that JAX understands as a PyTree. 🎉 🕶️
Installation
pip install xmmutablemap
Documentation
xmmutablemap provides the class ImmutableMap, which is a full implementation
of Python's Mapping ABC.
If you've used a dict, then you already know how to use ImmutableMap! ImmutableMap
adds two things: 1) immutability (and related benefits like hashability) and
2) compatibility with JAX.
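from xmmutablemap import ImmutableMap

# Build from keyword arguments, just like dict(a=1, ...)
print(ImmutableMap(a=1, b=2, c=3))
# ...or from an existing dict (values of mixed types are fine)
print(ImmutableMap({"a": 1, "b": 2.0, "c": "3"}))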
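To make the two properties above concrete, here is a minimal sketch of how an immutable, hashable Mapping can be built on Python's Mapping ABC. It is not xmmutablemap's actual source. The class name FrozenMapSketch is hypothetical, and the tree_flatten/tree_unflatten pair follows the method convention that jax.tree_util.register_pytree_node_class expects. That decorator is left as a comment so the sketch runs with only the standard library.

```python
from collections.abc import Iterator, Mapping
from typing import Any, Hashable


# Hypothetical illustration only, not xmmutablemap's implementation.
# With JAX installed, you could decorate this class with
# @jax.tree_util.register_pytree_node_class to make it a PyTree.
class FrozenMapSketch(Mapping):
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Accept the same arguments as dict(): a mapping/iterable and/or keywords.
        # object.__setattr__ bypasses our own __setattr__, which forbids writes.
        object.__setattr__(self, "_data", dict(*args, **kwargs))

    # The three abstract methods required by the Mapping ABC; Mapping then
    # supplies __contains__, keys, items, values, get, __eq__ and __ne__.
    def __getitem__(self, key: Hashable) -> Any:
        return self._data[key]

    def __iter__(self) -> Iterator[Hashable]:
        return iter(self._data)

    def __len__(self) -> int:
        return len(self._data)

    # Immutability: block attribute writes and deletions. There is no
    # __setitem__, so m[key] = value raises TypeError automatically.
    def __setattr__(self, name: str, value: Any) -> None:
        raise AttributeError(f"{type(self).__name__} is immutable")

    def __delattr__(self, name: str) -> None:
        raise AttributeError(f"{type(self).__name__} is immutable")

    # Hashability: order-independent, consistent with Mapping.__eq__.
    # This assumes every value is itself hashable.
    def __hash__(self) -> int:
        return hash(frozenset(self._data.items()))

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self._data!r})"

    # PyTree protocol: values are the leaves (children); keys are static aux data.
    def tree_flatten(self) -> tuple[tuple[Any, ...], tuple[Hashable, ...]]:
        return tuple(self._data.values()), tuple(self._data.keys())

    @classmethod
    def tree_unflatten(cls, aux_data, children) -> "FrozenMapSketch":
        return cls(zip(aux_data, children))


m = FrozenMapSketch(a=1, b=2)
print(m["a"], len(m), dict(m))
print({m: "usable as a dict key"})
```

Keeping the keys in the auxiliary data and only the values as children is what lets JAX transform the values (for example, under jit or grad) while treating the key structure as static.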
Development
We welcome contributions!