elegy 0.8.6 (PyPI)

Elegy is a Neural Networks framework based on JAX and Haiku.

Elegy



A High Level API for Deep Learning in JAX

Main Features
  • 😀 Easy-to-use: Elegy provides a Keras-like high-level API that makes it very easy to use for most common tasks.
  • 💪 Flexible: Elegy provides a PyTorch Lightning-like low-level API that offers maximum flexibility when needed.
  • 🔌 Compatible: Elegy supports various frameworks and data sources including Flax & Haiku Modules, Optax Optimizers, TensorFlow Datasets, PyTorch DataLoaders, and more.

Elegy is built on top of Treex and Treeo and re-exports their APIs for convenience.

Getting Started | Examples | Documentation

What is included?

  • A Model class with an Estimator-like API.
  • A callbacks module with common Keras callbacks.

From Treex

  • A Module class.
  • An nn module with common layers.
  • A losses module with common loss functions.
  • A metrics module with common metrics.

Installation

Install using pip:

pip install elegy

For Windows users, we recommend the Windows Subsystem for Linux 2 (WSL2), since JAX does not support Windows yet.

Quick Start: High-level API

Elegy's high-level API provides a straightforward interface you can use by implementing the following steps:

1. Define the architecture inside a Module:

import jax
import elegy as eg

class MLP(eg.Module):
    @eg.compact
    def __call__(self, x):
        x = eg.Linear(300)(x)
        x = jax.nn.relu(x)
        x = eg.Linear(10)(x)
        return x

2. Create a Model from this module and specify additional things like losses, metrics, and optimizers:

import optax
import elegy as eg

model = eg.Model(
    module=MLP(),
    loss=[
        eg.losses.Crossentropy(),
        eg.regularizers.L2(l=1e-5),
    ],
    metrics=eg.metrics.Accuracy(),
    optimizer=optax.rmsprop(1e-3),
)

3. Train the model using the fit method:

model.fit(
    inputs=X_train,
    labels=y_train,
    epochs=100,
    steps_per_epoch=200,
    batch_size=64,
    validation_data=(X_test, y_test),
    shuffle=True,
    callbacks=[eg.callbacks.TensorBoard("summaries")]
)
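The epochs, steps_per_epoch, batch_size and shuffle arguments together decide which examples each training step sees. The generator below is a hypothetical sketch of one plausible batching scheme, not Elegy's implementation: each step takes the next batch_size indices, a new (optionally reshuffled) pass over the data starts whenever the current one is exhausted, and steps_per_epoch defaults to one full pass. How Elegy itself treats a final partial batch is not covered by this sketch.

```python
import random
from typing import Iterator, List, Optional


def batch_indices(
    num_examples: int,
    batch_size: int,
    epochs: int,
    steps_per_epoch: Optional[int] = None,
    shuffle: bool = True,
    seed: int = 0,
) -> Iterator[List[int]]:
    """Yield one list of example indices per training step (hypothetical helper)."""
    rng = random.Random(seed)
    order = list(range(num_examples))
    position = len(order)  # forces a new pass (and shuffle) before the first batch
    full_pass = -(-num_examples // batch_size)  # ceil division: batches in one pass
    steps = steps_per_epoch if steps_per_epoch is not None else full_pass
    for _ in range(epochs):
        for _ in range(steps):
            if position >= len(order):
                # Start a new pass over the data, optionally in a new random order.
                if shuffle:
                    rng.shuffle(order)
                position = 0
            batch = order[position : position + batch_size]
            position += batch_size
            yield batch


batches = list(batch_indices(num_examples=10, batch_size=4, epochs=2, shuffle=False))
print([len(batch) for batch in batches])  # [4, 4, 2, 4, 4, 2]
```

With 10 examples and a batch size of 4, one pass takes three steps (two full batches and a partial one); passing steps_per_epoch decouples the length of an "epoch" from the dataset size, which is useful for very large or infinite data sources.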
Using Flax

To use Flax with Elegy, just create a flax.linen.Module and pass it to Model.

import jax
import elegy as eg
import optax
import flax.linen as nn

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x, training: bool):
        x = nn.Dense(300)(x)
        x = jax.nn.relu(x)
        x = nn.Dense(10)(x)
        return x


model = eg.Model(
    module=MLP(),
    loss=[
        eg.losses.Crossentropy(),
        eg.regularizers.L2(l=1e-5),
    ],
    metrics=eg.metrics.Accuracy(),
    optimizer=optax.rmsprop(1e-3),
)

As shown here, Flax Modules can optionally request a training argument to __call__ which will be provided by Elegy / Treex.
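One way a wrapper could supply training only to modules that ask for it is to inspect the callable's signature with the standard-library inspect module. The sketch below illustrates that idea; call_with_optional_training and the two toy forward functions are hypothetical, and this is not claimed to be how Elegy or Treex actually detect the argument.

```python
import inspect
from typing import Any, Callable


def call_with_optional_training(fn: Callable[..., Any], x: Any, training: bool) -> Any:
    """Call fn(x), adding training=... only if fn declares a parameter named training."""
    params = inspect.signature(fn).parameters
    if "training" in params:
        return fn(x, training=training)
    return fn(x)


def forward_with_flag(x: float, training: bool) -> float:
    # Stand-in for a module that behaves differently during training (e.g. dropout).
    return x * 0.5 if training else x


def forward_without_flag(x: float) -> float:
    return x + 1.0


print(call_with_optional_training(forward_with_flag, 4.0, training=True))     # 2.0
print(call_with_optional_training(forward_without_flag, 4.0, training=True))  # 5.0
```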

Using Haiku

To use Haiku with Elegy do the following:

  • Create a forward function.
  • Create a TransformedWithState object by feeding forward to hk.transform_with_state.
  • Pass your TransformedWithState to Model.

You can also optionally create your own hk.Module and use it in forward if needed. Putting everything together should look like this:

import jax
import elegy as eg
import optax
import haiku as hk


def forward(x, training: bool):
    x = hk.Linear(300)(x)
    x = jax.nn.relu(x)
    x = hk.Linear(10)(x)
    return x


model = eg.Model(
    module=hk.transform_with_state(forward),
    loss=[
        eg.losses.Crossentropy(),
        eg.regularizers.L2(l=1e-5),
    ],
    metrics=eg.metrics.Accuracy(),
    optimizer=optax.rmsprop(1e-3),
)

As shown here, forward can optionally request a training argument which will be provided by Elegy / Treex.
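hk.transform_with_state turns forward, which creates its parameters implicitly through hk.Linear, into a pair of pure functions: init, which builds the parameters and state from an RNG key and sample inputs, and apply, which runs the computation given them. The toy transform below imitates only that init/apply split in plain Python. Everything in it (get_param, transform, Transformed) is hypothetical and far simpler than Haiku: there is no state, no RNG argument to apply, no module name scoping, and parameters are single floats.

```python
import random
from typing import Any, Callable, Dict, NamedTuple, Optional

# Parameter store that get_param sees while a transformed function is running.
_frame: Optional[Dict[str, Any]] = None


def get_param(name: str) -> float:
    """Return parameter name, creating it (uniform in [-1, 1]) only during init."""
    if _frame is None:
        raise RuntimeError("get_param must be called inside init or apply")
    params = _frame["params"]
    if name not in params:
        if not _frame["initializing"]:
            raise KeyError(f"missing parameter {name!r}")
        params[name] = _frame["rng"].uniform(-1.0, 1.0)
    return params[name]


class Transformed(NamedTuple):
    init: Callable[..., Dict[str, float]]
    apply: Callable[..., Any]


def transform(fn: Callable[..., Any]) -> Transformed:
    """Toy analogue of the init/apply split that Haiku's transforms produce."""

    def run(params: Dict[str, float], rng: Optional[random.Random], initializing: bool, *args: Any) -> Any:
        global _frame
        _frame = {"params": params, "rng": rng, "initializing": initializing}
        try:
            return fn(*args)
        finally:
            _frame = None

    def init(seed: int, *args: Any) -> Dict[str, float]:
        params: Dict[str, float] = {}
        run(params, random.Random(seed), True, *args)
        return params

    def apply(params: Dict[str, float], *args: Any) -> Any:
        # Work on a copy so apply never changes the caller's parameters.
        return run(dict(params), None, False, *args)

    return Transformed(init, apply)


def forward(x: float) -> float:
    # Like hk.Linear inside forward, parameters are requested implicitly by name.
    w = get_param("w")
    b = get_param("b")
    return max(0.0, w * x + b)


model_fn = transform(forward)
params = model_fn.init(0, 2.0)  # a sample input drives parameter creation
print(sorted(params))  # ['b', 'w']
print(model_fn.apply(params, 2.0))
```

The design point this mirrors is that forward itself never receives parameters as arguments; the transform supplies them behind the scenes, which is what lets Model treat the resulting object as a pure init/apply pair.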

Quick Start: Low-level API

Elegy's low-level API lets you explicitly define what goes on during training, testing, and inference. Let's define our own custom Model to implement a LinearClassifier with pure JAX:

1. Define a custom init_step method:

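import math

import jax
import jax.numpy as jnp
import optax
import elegy as eg

class LinearClassifier(eg.Model):
    # use treex's API to declare parameter nodes
    w: jnp.ndarray = eg.Parameter.node()
    b: jnp.ndarray = eg.Parameter.node()

    def init_step(self, key: jnp.ndarray, inputs: jnp.ndarray):
        # test_step flattens every non-batch dimension, so infer the input size the same way
        features_in = math.prod(inputs.shape[1:])

        self.w = jax.random.uniform(
            key=key,
            shape=(features_in, 10),
        )
        self.b = jnp.zeros([10])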

        self.optimizer = self.optimizer.init(self)

        return self

Here we declared the parameters w and b using Treex's Parameter.node() for pedagogical reasons; normally you don't have to do this, since you would typically use a sub-Module instead.

2. Define a custom test_step method:

    def test_step(self, inputs, labels):
        # flatten + scale
        inputs = jnp.reshape(inputs, (inputs.shape[0], -1)) / 255

        # forward
        logits = jnp.dot(inputs, self.w) + self.b

        # crossentropy loss
        target = jax.nn.one_hot(labels["target"], 10)
        loss = optax.softmax_cross_entropy(logits, target).mean()

        # metrics
        logs = dict(
            acc=jnp.mean(jnp.argmax(logits, axis=-1) == labels["target"]),
            loss=loss,
        )

        return loss, logs, self

3. Instantiate our LinearClassifier with an optimizer:

model = LinearClassifier(
    optimizer=optax.rmsprop(1e-3),
)

4. Train the model using the fit method:

model.fit(
    inputs=X_train,
    labels=y_train,
    epochs=100,
    steps_per_epoch=200,
    batch_size=64,
    validation_data=(X_test, y_test),
    shuffle=True,
    callbacks=[eg.callbacks.TensorBoard("summaries")]
)
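To see how the four steps fit together without installing anything, here is a dependency-free sketch of the same idea: a linear classifier whose init_step infers features_in from a sample batch, whose test_step returns the mean softmax cross-entropy and the argmax accuracy, and whose train_step applies one full-batch RMSProp update. All of it is an illustrative assumption rather than Elegy's code: ToyLinearClassifier and its methods are hypothetical, Python lists stand in for JAX arrays, the gradient is written out by hand (in JAX one would normally use automatic differentiation), and the RMSProp rule is the textbook form with an assumed decay of 0.9, which may differ in detail from optax.rmsprop.

```python
import math
import random
from typing import List, Tuple

Vec = List[float]


def log_softmax(z: Vec) -> Vec:
    """log(softmax(z)) using the log-sum-exp trick for numerical stability."""
    m = max(z)
    lse = m + math.log(sum(math.exp(v - m) for v in z))
    return [v - lse for v in z]


class ToyLinearClassifier:
    """Hypothetical pure-Python analogue of LinearClassifier."""

    def __init__(self, lr: float = 1e-3, decay: float = 0.9, eps: float = 1e-8):
        self.lr, self.decay, self.eps = lr, decay, eps

    def init_step(self, seed: int, inputs: List[Vec], num_classes: int = 10) -> "ToyLinearClassifier":
        rng = random.Random(seed)
        features_in = len(inputs[0])  # inferred from the sample batch
        self.w = [[rng.random() for _ in range(num_classes)] for _ in range(features_in)]
        self.b = [0.0] * num_classes
        # RMSProp keeps one running mean of squared gradients per parameter.
        self.nu_w = [[0.0] * num_classes for _ in range(features_in)]
        self.nu_b = [0.0] * num_classes
        return self

    def logits(self, x: Vec) -> Vec:
        return [sum(xi * row[k] for xi, row in zip(x, self.w)) + self.b[k] for k in range(len(self.b))]

    def test_step(self, inputs: List[Vec], labels: List[int]) -> Tuple[float, float]:
        """Mean softmax cross-entropy and argmax accuracy over the batch."""
        loss, hits = 0.0, 0
        for x, y in zip(inputs, labels):
            z = self.logits(x)
            loss -= log_softmax(z)[y]
            hits += int(max(range(len(z)), key=z.__getitem__) == y)
        return loss / len(inputs), hits / len(inputs)

    def train_step(self, inputs: List[Vec], labels: List[int]) -> None:
        """One full-batch update; d(loss)/d(logits) = softmax(z) - one_hot(y)."""
        n, k = len(inputs), len(self.b)
        grad_w = [[0.0] * k for _ in self.w]
        grad_b = [0.0] * k
        for x, y in zip(inputs, labels):
            probs = [math.exp(v) for v in log_softmax(self.logits(x))]
            for c in range(k):
                delta = (probs[c] - (1.0 if c == y else 0.0)) / n
                grad_b[c] += delta
                for i, xi in enumerate(x):
                    grad_w[i][c] += xi * delta
        for c in range(k):
            self.b[c], self.nu_b[c] = self._rmsprop(self.b[c], grad_b[c], self.nu_b[c])
            for i in range(len(self.w)):
                self.w[i][c], self.nu_w[i][c] = self._rmsprop(self.w[i][c], grad_w[i][c], self.nu_w[i][c])

    def _rmsprop(self, p: float, g: float, nu: float) -> Tuple[float, float]:
        """Textbook RMSProp: scale the step by a running RMS of past gradients."""
        nu = self.decay * nu + (1.0 - self.decay) * g * g
        return p - self.lr * g / (math.sqrt(nu) + self.eps), nu


X = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]
y = [0, 0, 1, 1]
model = ToyLinearClassifier(lr=0.05).init_step(seed=0, inputs=X, num_classes=2)
print("before:", model.test_step(X, y))
for _ in range(200):
    model.train_step(X, y)
print("after:", model.test_step(X, y))
```

On this tiny separable dataset the loss drops and the accuracy reaches 1.0; the point is only to show which quantities each step produces and consumes, not to reproduce Elegy's training loop.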
Using other JAX Frameworks

Other functional JAX libraries are straightforward to integrate with this low-level API. Here is an example with Flax:

from typing import Any, Mapping

import jax
import jax.numpy as jnp
import optax
import elegy as eg
import flax.linen as nn

class LinearClassifier(eg.Model):
    params: Mapping[str, Any] = eg.Parameter.node()
    batch_stats: Mapping[str, Any] = eg.BatchStat.node()
    next_key: eg.KeySeq

    def __init__(self, module: nn.Module, **kwargs):
        self.flax_module = module
        super().__init__(**kwargs)

    def init_step(self, key, inputs):
        self.next_key = eg.KeySeq(key)

        variables = self.flax_module.init(
            {"params": self.next_key(), "dropout": self.next_key()}, inputs
        )
        self.params = variables["params"]
        self.batch_stats = variables["batch_stats"]

        self.optimizer = self.optimizer.init(self.parameters())

        return self

    def test_step(self, inputs, labels):
        # forward
        variables = dict(
            params=self.params,
            batch_stats=self.batch_stats,
        )
        logits, variables = self.flax_module.apply(
            variables,
            inputs,
            rngs={"dropout": self.next_key()},
            mutable=True,
        )
        self.batch_stats = variables["batch_stats"]

        # loss
        target = jax.nn.one_hot(labels["target"], 10)
        loss = optax.softmax_cross_entropy(logits, target).mean()

        # logs
        accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == labels["target"])
        logs = dict(
            accuracy=accuracy,
            loss=loss,
        )
        return loss, logs, self
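The Flax example threads batch_stats explicitly: apply returns updated variables and test_step stores them back on the model, because pure functions cannot mutate state in place. Flax's BatchNorm, for instance, keeps its running averages in the batch_stats collection. The plain-Python sketch below shows the same pattern with a hypothetical apply_batch_norm_stats function that returns its output together with a new running mean (an exponential moving average with an assumed momentum of 0.9, chosen for readability) instead of modifying the old state.

```python
from typing import Dict, List, Tuple

State = Dict[str, float]


def apply_batch_norm_stats(
    state: State, batch: List[float], training: bool, momentum: float = 0.9
) -> Tuple[List[float], State]:
    """Center a batch and return (outputs, new_state) without mutating state."""
    if training:
        batch_mean = sum(batch) / len(batch)
        # Exponential moving average of batch means, kept for use at inference time.
        new_mean = momentum * state["mean"] + (1.0 - momentum) * batch_mean
        new_state = {**state, "mean": new_mean}
        center = batch_mean
    else:
        # Inference: use the stored running mean and leave the state unchanged.
        new_state = state
        center = state["mean"]
    return [x - center for x in batch], new_state


state: State = {"mean": 0.0}
for batch in ([1.0, 3.0], [5.0, 7.0]):
    outputs, state = apply_batch_norm_stats(state, batch, training=True)  # thread state forward
print(state)
```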

Examples

Check out the examples directory for some inspiration. To run an example, first install its requirements:

pip install -r examples/requirements.txt

Then run it with Python, e.g.:

python examples/flax/mnist_vae.py

Contributing

If you are interested in helping improve Elegy, check out the Contributing Guide.


Citing Elegy

BibTeX

@software{elegy2020repository,
	title        = {Elegy: A High Level API for Deep Learning in JAX},
	author       = {PoetsAI},
	year         = 2021,
	url          = {https://github.com/poets-ai/elegy},
	version      = {0.8.1}
}
