
pytreeclass


Visualize, create, and operate on JAX PyTree in the most intuitive way possible.

  • 0.9.2

Installation | Description | Quick Example | Stateful Computations | Benchmarks | Acknowledgements


🛠️ Installation

pip install pytreeclass

Install development version

pip install git+https://github.com/ASEM000/pytreeclass

📖 Description

pytreeclass is a JAX-compatible class builder to create and operate on stateful JAX PyTrees in a performant and intuitive way, by building on familiar concepts found in numpy, dataclasses, and others.

See the documentation and 🍳 Common recipes to check whether this library is a good fit for your work. If you find the package useful, consider giving it a 🌟.

⏩ Quick Example

import jax
import jax.numpy as jnp
import pytreeclass as tc

@tc.autoinit
class Tree(tc.TreeClass):
    a: float = 1.0
    b: tuple[float, float] = (2.0, 3.0)
    c: jax.Array = jnp.array([4.0, 5.0, 6.0])

    def __call__(self, x):
        return self.a + self.b[0] + self.c + x


tree = Tree()
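# boolean mask with the same structure as tree (jax.tree_map is deprecated in recent JAX)
mask = jax.tree_util.tree_map(lambda x: x > 5, tree)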
tree = tree\
       .at["a"].set(100.0)\
       .at["b"][0].set(10.0)\
       .at[mask].set(100.0)

print(tree)
# Tree(a=100.0, b=(10.0, 3.0), c=[  4.   5. 100.])

print(tc.tree_diagram(tree))
# Tree
# ├── .a=100.0
# ├── .b:tuple
# │   ├── [0]=10.0
# │   └── [1]=3.0
# └── .c=f32[3](μ=36.33, σ=45.02, ∈[4.00,100.00])

print(tc.tree_summary(tree))
# ┌─────┬──────┬─────┬──────┐
# │Name │Type  │Count│Size  │
# ├─────┼──────┼─────┼──────┤
# │.a   │float │1    │      │
# ├─────┼──────┼─────┼──────┤
# │.b[0]│float │1    │      │
# ├─────┼──────┼─────┼──────┤
# │.b[1]│float │1    │      │
# ├─────┼──────┼─────┼──────┤
# │.c   │f32[3]│3    │12.00B│
# ├─────┼──────┼─────┼──────┤
# │Σ    │Tree  │6    │12.00B│
# └─────┴──────┴─────┴──────┘

# ** pass it to jax transformations **
# works with jit, grad, vmap, etc.

@jax.jit
@jax.grad
def sum_tree(tree: Tree, x):
    return sum(tree(x))

print(sum_tree(tree, 1.0))
# Tree(a=3.0, b=(3.0, 0.0), c=[1. 1. 1.])
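
The masked update in the example above (`.at[mask].set(100.0)`) can be pictured with a small pure-Python sketch. This is only an illustration of the idea, not pytreeclass or JAX code: `tree_map` and `set_where` here are hypothetical helpers written for this sketch, they handle only dicts, tuples, and lists, and each list element is treated as its own leaf (in the real example `c` is a single array leaf masked elementwise).

```python
# Pure-Python illustration only; pytreeclass and JAX support many more node types.

def tree_map(func, tree, *rest):
    """Apply func to every leaf of nested dicts/tuples/lists, keeping the structure."""
    if isinstance(tree, dict):
        return {k: tree_map(func, v, *(r[k] for r in rest)) for k, v in tree.items()}
    if isinstance(tree, (tuple, list)):
        mapped = (tree_map(func, *items) for items in zip(tree, *rest))
        return type(tree)(mapped)
    return func(tree, *rest)


def set_where(tree, mask, value):
    """Return a new tree with `value` at every leaf whose mask entry is True."""
    return tree_map(lambda leaf, m: value if m else leaf, tree, mask)


tree = {"a": 1.0, "b": (2.0, 3.0), "c": [4.0, 5.0, 6.0]}
mask = tree_map(lambda x: x > 5, tree)   # same structure, boolean leaves
new_tree = set_where(tree, mask, 100.0)

print(new_tree)   # {'a': 1.0, 'b': (2.0, 3.0), 'c': [4.0, 5.0, 100.0]}
print(tree["c"])  # [4.0, 5.0, 6.0]  (the original is left untouched)
```

The point of the sketch is that the update returns a new tree instead of modifying the old one, which is the behavior `.at[...].set(...)` shows in the example.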

📜 Stateful computations

Under jax.jit, JAX requires state to be explicit. For an ordinary class instance, this means the variables need to be separated from the class and passed explicitly. With TreeClass there is no need to separate the instance variables; instead, the whole instance is passed as the state.
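
For contrast, the explicit-state style that TreeClass lets you avoid might look like the following plain-Python sketch. It uses no JAX at all and only shows the shape of the pattern; `CounterLogic` and the `state` dict are hypothetical names introduced for this illustration.

```python
# Plain-Python sketch of explicit state; all names here are hypothetical.

class CounterLogic:
    """Holds no mutable state; it only knows how to compute the next state."""

    def increment(self, state):
        # Return a new state dict instead of mutating the input.
        return {**state, "calls": state["calls"] + 1}


logic = CounterLogic()
state = {"calls": 0}  # the state lives outside the class

for _ in range(10):
    state = logic.increment(state)  # state is threaded through every call

print(state["calls"])  # 10
```

Here the caller has to carry `state` around separately from the object that defines the behavior, which is the bookkeeping the paragraph above refers to.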

With the following pattern, state can be updated functionally under jax.jit:

import jax
import pytreeclass as tc

class Counter(tc.TreeClass):
    def __init__(self, calls: int = 0):
        self.calls = calls

    def increment(self):
        self.calls += 1


counter = Counter()  # Counter(calls=0)

Here, we define the update function. Since the increment method mutates the internal state, we need to update the state functionally by using .at. To achieve this, we can use .at[method_name].__call__(*args, **kwargs). This functional call returns the value of the call together with a new instance that holds the updated state.

@jax.jit
def update(counter):
    value, new_counter = counter.at["increment"]()
    return new_counter

for i in range(10):
    counter = update(counter)

print(counter.calls) # 10
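
To make the "value plus new instance" contract concrete, here is a minimal pure-Python sketch that mimics the observable behavior of such a functional call. It is an assumption-laden illustration, not how pytreeclass implements `.at`: `functional_call` and `PlainCounter` are hypothetical names, and the copy is made with `copy.deepcopy` for simplicity.

```python
import copy


def functional_call(obj, method_name, *args, **kwargs):
    """Call a mutating method on a copy of obj; return (value, updated copy)."""
    new_obj = copy.deepcopy(obj)  # the original instance is never touched
    value = getattr(new_obj, method_name)(*args, **kwargs)
    return value, new_obj


class PlainCounter:
    def __init__(self, calls=0):
        self.calls = calls

    def increment(self):
        self.calls += 1  # mutates in place and returns None


counter = PlainCounter()
value, new_counter = functional_call(counter, "increment")

print(value)              # None
print(counter.calls)      # 0
print(new_counter.calls)  # 1
```

Because the original object is unchanged and the updated one is returned, the call behaves like a pure function, which is the property jax.jit needs.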

➕ Benchmarks

Benchmark flatten/unflatten compared to Flax and Equinox

[Figure: flatten/unflatten benchmark plots, panels: CPU, GPU]
Benchmark simple training against flax and equinox

Training a simple sequential linear model, benchmarked against flax and equinox:

Num of layers | Flax/tc time | Equinox/tc time
10            | 1.427        | 6.671
100           | 1.1130       | 2.714

📙 Acknowledgements
