jax-ai-stack

  • 2024.11.1
  • PyPI
JAX AI Stack


JAX is a Python package for array-oriented computation and program transformation. Built around it is a growing ecosystem of packages for specialized numerical computing across a range of domains; an up-to-date list of such projects can be found at Awesome JAX.

Though JAX is often compared to neural network libraries like PyTorch, the JAX core package itself contains very little that is specific to neural network models. Instead, JAX encourages modularity, where domain-specific libraries are developed separately from the core package: this helps drive innovation as researchers and other users explore what is possible.

Within this larger, distributed ecosystem, there are a number of projects that Google researchers and engineers have found useful for implementing and deploying the models behind generative AI tools like Imagen, Gemini, and more. The JAX AI stack serves as a single point of entry for this suite of libraries, so you can install and begin using many of the same open source packages that Google developers use in their everyday work.

To get started with the JAX AI stack, you can check out Getting started with JAX. This is still a work in progress; please check back for more documentation and tutorials in the coming weeks!

Installing the stack

The stack can be installed with the following command:

pip install jax-ai-stack

This pins particular versions of the component projects that are known to work correctly together, as verified by the integration tests in this repository. Packages include:

  • JAX: the core JAX package, which includes array operations and program transformations like jit, vmap, grad, etc.
  • flax: build neural networks with JAX.
  • ml_dtypes: NumPy dtype extensions for machine learning.
  • optax: gradient processing and optimization in JAX.
  • orbax: checkpointing and persistence utilities for JAX.
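
As a rough illustration of the program transformations named in the list above, here is a minimal sketch that combines jit, vmap, and grad from the core JAX package. The linear model, the loss function, and the helper names (loss, grad_loss, batched_grad, mean_grad) are hypothetical and chosen only for this example; they are not part of the stack.

```python
import jax
import jax.numpy as jnp


def loss(w, x, y):
    """Squared error of a linear model on a single example (illustrative only)."""
    pred = jnp.dot(w, x)
    return (pred - y) ** 2


# grad: differentiate loss with respect to its first argument, w.
grad_loss = jax.grad(loss)

# vmap: apply grad_loss across a batch of (x, y) pairs while sharing w.
batched_grad = jax.vmap(grad_loss, in_axes=(None, 0, 0))


# jit: compile the averaged-gradient computation.
@jax.jit
def mean_grad(w, xs, ys):
    return batched_grad(w, xs, ys).mean(axis=0)


w = jnp.zeros(3)
xs = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
ys = jnp.array([1.0, 2.0])

g = mean_grad(w, xs, ys)
print(g)
```

Each transformation wraps a plain Python function and returns a new function, which is why they compose in this way. The other packages in the stack build on this same core.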

Optional packages

Additionally, there are optional packages you can install with pip extras. The following command:

pip install jax-ai-stack[grain]

will install a compatible version of the grain data loader (currently Linux-only).

Similarly, the following command:

pip install jax-ai-stack[tfds]

will install a compatible version of tensorflow and tensorflow-datasets.
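
To see which component versions a given install actually resolved to, one option is to query the installed distribution metadata using only the Python standard library. The sketch below assumes the PyPI distribution names listed in CORE and OPTIONAL; in particular, orbax-checkpoint as the distribution behind orbax is an assumption, so adjust the names if your environment differs. The helper installed_versions is hypothetical and not part of the stack.

```python
from importlib.metadata import PackageNotFoundError, version

# Assumed PyPI distribution names; adjust if your environment differs.
CORE = ["jax-ai-stack", "jax", "flax", "ml_dtypes", "optax", "orbax-checkpoint"]
OPTIONAL = ["grain", "tensorflow", "tensorflow-datasets"]


def installed_versions(names):
    """Map each distribution name to its installed version, or None if absent."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


for name, ver in installed_versions(CORE + OPTIONAL).items():
    print(f"{name}: {ver or 'not installed'}")
```

Optional packages that were not requested through a pip extra are reported as not installed rather than raising an error.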
