
Equinox is your one-stop JAX library for everything you need that isn't already in core JAX: neural networks with PyTorch-like syntax, filtered transformations, PyTree manipulation, and runtime errors.
Best of all, Equinox isn't a framework: everything you write in Equinox is compatible with anything else in JAX or the ecosystem.
If you're completely new to JAX, then start with this CNN on MNIST example.
Coming from Flax or Haiku? The main difference is that Equinox (a) offers a lot of advanced features not found in these libraries, like PyTree manipulation or runtime errors; (b) has a simpler way of building models: they're just PyTrees, so they can pass across JIT/grad/etc. boundaries smoothly.
Requires Python 3.10+.
pip install equinox
Equinox is also available through a community-supported build on conda-forge.
Documentation is available at https://docs.kidger.site/equinox.
Models are defined using PyTorch-like syntax:
import equinox as eqx
import jax

class Linear(eqx.Module):
    weight: jax.Array
    bias: jax.Array

    def __init__(self, in_size, out_size, key):
        wkey, bkey = jax.random.split(key)
        self.weight = jax.random.normal(wkey, (out_size, in_size))
        self.bias = jax.random.normal(bkey, (out_size,))

    def __call__(self, x):
        return self.weight @ x + self.bias
and are fully compatible with normal JAX operations:
@jax.jit
@jax.grad
def loss_fn(model, x, y):
    pred_y = jax.vmap(model)(x)
    return jax.numpy.mean((y - pred_y) ** 2)

batch_size, in_size, out_size = 32, 2, 3
model = Linear(in_size, out_size, key=jax.random.PRNGKey(0))
x = jax.numpy.zeros((batch_size, in_size))
y = jax.numpy.zeros((batch_size, out_size))
grads = loss_fn(model, x, y)
Finally, there's no magic behind the scenes. All eqx.Module does is register your class as a PyTree. From that point onwards, JAX already knows how to work with PyTrees.
If you found this library to be useful in academic work, then please cite: (arXiv link)
@article{kidger2021equinox,
    author={Patrick Kidger and Cristian Garcia},
    title={{E}quinox: neural networks in {JAX} via callable {P}y{T}rees and filtered transformations},
    year={2021},
    journal={Differentiable Programming workshop at Neural Information Processing Systems 2021}
}
(Also consider starring the project on GitHub.)
Always useful
jaxtyping: type annotations for shape/dtype of arrays.
Deep learning
Optax: first-order gradient (SGD, Adam, ...) optimisers.
Orbax: checkpointing (async/multi-host/multi-device).
Levanter: scalable+reliable training of foundation models (e.g. LLMs).
paramax: parameterizations and constraints for PyTrees.
Scientific computing
Diffrax: numerical differential equation solvers.
Optimistix: root finding, minimisation, fixed points, and least squares.
Lineax: linear solvers.
BlackJAX: probabilistic+Bayesian sampling.
sympy2jax: SymPy<->JAX conversion; train symbolic expressions via gradient descent.
PySR: symbolic regression. (Non-JAX honourable mention!)
Awesome JAX
Awesome JAX: a longer list of other JAX projects.