Haliax
Though you don’t seem to be much for listening, it’s best to be careful. If you managed to catch hold of even just a piece of my name, you’d have all manner of power over me.
— Patrick Rothfuss, The Name of the Wind
Haliax is a JAX library for building neural networks with named tensors, in the tradition of Alexander Rush's Tensor Considered Harmful.
Named tensors improve the legibility and compositionality of tensor programs by using named axes instead of the positional indices typically used in NumPy, PyTorch, etc.
Despite the focus on legibility, Haliax is also fast: it is typically about as fast as "pure" JAX code.
Haliax is also built to scale: it supports Fully-Sharded Data Parallelism (FSDP) and Tensor Parallelism with just a few lines of code. Haliax powers Levanter,
our companion library for training large language models and other foundation models, with scale proven up to 20B parameters
and up to a TPU v3-256 pod slice.
Example: Attention
Here's a minimal attention module implementation in Haliax. For a more detailed introduction,
please see the Haliax tutorial.
(We use the excellent Equinox library for its module system and tree transformations.)
```python
import equinox as eqx
import jax
import jax.numpy as jnp

import haliax as hax
import haliax.nn as hnn

Pos = hax.Axis("position", 1024)   # sequence length
KPos = Pos.alias("key_position")   # the same axis, renamed for attention keys
Head = hax.Axis("head", 8)         # number of attention heads
Key = hax.Axis("key", 64)          # size of each attention head
Embed = hax.Axis("embed", 512)     # model embedding dimension


def attention_scores(Key, KPos, query, key, mask):
    # compute the similarity of each query to each key
    scores = hax.dot(query, key, axis=Key) / jnp.sqrt(Key.size)

    if mask is not None:
        scores -= 1e9 * (1.0 - mask)

    # normalize over the key_position axis to get attention weights
    scores = hnn.softmax(scores, axis=KPos)
    return scores


def attention(Key, KPos, query, key, value, mask):
    scores = attention_scores(Key, KPos, query, key, mask)
    answers = hax.dot(scores, value, axis=KPos)
    return answers


# causal mask: each position may attend only to itself and earlier positions
causal_mask = hax.arange(Pos).broadcast_axis(KPos) >= hax.arange(KPos)


class Attention(eqx.Module):
    proj_q: hnn.Linear
    proj_k: hnn.Linear
    proj_v: hnn.Linear
    proj_answer: hnn.Linear

    @staticmethod
    def init(Embed, Head, Key, *, key):
        k_q, k_k, k_v, k_ans = jax.random.split(key, 4)
        proj_q = hnn.Linear.init(In=Embed, Out=(Head, Key), key=k_q)
        proj_k = hnn.Linear.init(In=Embed, Out=(Head, Key), key=k_k)
        proj_v = hnn.Linear.init(In=Embed, Out=(Head, Key), key=k_v)
        proj_answer = hnn.Linear.init(In=(Head, Key), Out=Embed, key=k_ans)
        return Attention(proj_q, proj_k, proj_v, proj_answer)

    def __call__(self, x, mask=None):
        q = self.proj_q(x)
        # rename the position axis of keys and values so it doesn't clash
        # with the query's position axis
        k = self.proj_k(x).rename({"position": "key_position"})
        v = self.proj_v(x).rename({"position": "key_position"})
        mask = causal_mask if mask is None else mask
        answers = attention(Key, KPos, q, k, v, mask)
        x = self.proj_answer(answers)
        return x
```
Haliax was created by the research engineering team at Stanford's Center for Research on Foundation Models (CRFM).
You can find us in the #levanter channel on the unofficial Jax LLM Discord.
Documentation
Tutorials
These are some tutorials to get you started with Haliax. They are available as Colab notebooks.
API Reference
Haliax's API documentation is available at haliax.readthedocs.io.
Contributing
We welcome contributions! Please see CONTRIBUTING.md for more information.
We also have a list of good first issues
to help you get started. (If those don't appeal, don't hesitate to reach out to us on Discord!)
License
Haliax is licensed under the Apache License, Version 2.0. See LICENSE for the full license text.