zuko

Normalizing flows in PyTorch (version 1.4.0 on PyPI, 1 maintainer)


Zuko - Normalizing flows in PyTorch

Zuko is a Python package that implements normalizing flows in PyTorch. It relies as much as possible on the distributions and transformations already provided by PyTorch. Unfortunately, the Distribution and Transform classes of torch are not subclasses of torch.nn.Module, which means you cannot send their internal tensors to the GPU with .to('cuda') or retrieve their parameters with .parameters(). Worse, the concepts of conditional distribution and conditional transformation, which are essential for probabilistic inference, cannot be expressed with these classes.
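The limitation can be observed with plain PyTorch alone. In this minimal sketch (it assumes only `torch`; the `Holder` class is a hypothetical name used for illustration), a `torch.distributions.Normal` built from a trainable `nn.Parameter` is not an `nn.Module`, so it has no `.parameters()` method to expose that tensor, whereas registering the same tensor on a module makes it discoverable.

```python
import torch
from torch import nn
from torch.distributions import Normal

# A trainable location tensor held directly by a torch Distribution
loc = nn.Parameter(torch.zeros(3))
dist = Normal(loc, torch.ones(3))

# Distribution is not a Module, so it has no .parameters() (and no .to())
is_module = isinstance(dist, nn.Module)
has_parameters = hasattr(dist, "parameters")

# Hypothetical wrapper: assigning the Parameter to a Module registers it
class Holder(nn.Module):
    def __init__(self):
        super().__init__()
        self.loc = loc

holder = Holder()
n_params = len(list(holder.parameters()))
print(is_module, has_parameters, n_params)  # False False 1
```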

To solve these problems, zuko defines two concepts: the LazyDistribution and the LazyTransform, that is, any module whose forward pass returns a Distribution or a Transform, respectively. Because the creation of the actual distribution or transformation is delayed until the forward pass, an optional condition (context) can easily be taken into account. This design enables lazy distributions, including normalizing flows, to act like distributions while retaining features inherent to modules, such as trainable parameters. It also makes the implementations easy to understand and extend.
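The lazy pattern can be illustrated without zuko. The sketch below is a hypothetical `ConditionalDiagNormal` module written in plain PyTorch, not zuko's own `LazyDistribution` class: its forward pass receives a context `c` and only then builds a `torch.distributions` object, while the network that produces the distribution's parameters stays an ordinary module with trainable parameters.

```python
import torch
from torch import nn
from torch.distributions import Distribution, Independent, Normal

class ConditionalDiagNormal(nn.Module):
    """Hypothetical lazy distribution: forward(c) builds p(x | c) on demand."""

    def __init__(self, features: int, context: int):
        super().__init__()
        # Maps a context c to the mean and log-scale of x
        self.net = nn.Linear(context, 2 * features)

    def forward(self, c: torch.Tensor) -> Distribution:
        loc, log_scale = self.net(c).chunk(2, dim=-1)
        # Independent treats the last dimension as the event dimension
        return Independent(Normal(loc, log_scale.exp()), 1)

lazy = ConditionalDiagNormal(features=3, context=5)
c = torch.randn(8, 5)
x = torch.randn(8, 3)

dist = lazy(c)               # an ordinary torch Distribution, conditioned on c
log_p = dist.log_prob(x)     # shape (8,)
samples = dist.sample((4,))  # shape (4, 8, 3)

# Being a Module, the lazy distribution exposes its parameters (and supports .to)
n_params = sum(p.numel() for p in lazy.parameters())
print(log_p.shape, samples.shape, n_params)
```

Here `n_params` counts the weights and biases of the single linear layer (5 × 6 + 6 = 36).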

In the Avatar cartoon, Zuko is a powerful firebender 🔥

Acknowledgements

Zuko takes significant inspiration from nflows and Stefan Webb's work in Pyro and FlowTorch.

Installation

The zuko package is available on PyPI, which means it is installable via pip.

pip install zuko

Alternatively, if you need the latest features, you can install it from the repository.

pip install git+https://github.com/probabilists/zuko

Getting started

Normalizing flows are provided in the zuko.flows module. To build one, supply the number of sample and context features as well as the transformations' hyperparameters. Then, feeding a context $c$ to the flow returns a conditional distribution $p(x | c)$ which can be evaluated and sampled from.

import torch
import zuko

# Neural spline flow (NSF) with 3 sample features and 5 context features
flow = zuko.flows.NSF(3, 5, transforms=3, hidden_features=[128] * 3)

# Placeholder data (hypothetical): 10 batches of 256 (x, c) pairs
trainset = [(torch.randn(256, 3), torch.randn(256, 5)) for _ in range(10)]

# Train to maximize the log-likelihood
optimizer = torch.optim.Adam(flow.parameters(), lr=1e-3)

for x, c in trainset:
    loss = -flow(c).log_prob(x)  # -log p(x | c)
    loss = loss.mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Sample 64 points x ~ p(x | c*) for a hypothetical context c*
c_star = torch.randn(5)
x = flow(c_star).sample((64,))
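For reference, the quantity minimized in the loop above is the negative of the standard change-of-variables log-density. Assuming the flow's transformation $f_c$, conditioned on $c$, maps a sample $x$ to a latent $z = f_c(x)$ distributed according to the base density $p_Z$, the log-likelihood is

$$
\log p(x | c) = \log p_Z\big(f_c(x)\big) + \log \left| \det \frac{\partial f_c(x)}{\partial x} \right|
$$

and `loss.mean()` averages its negative over a batch of $N$ pairs:

$$
\mathcal{L} = -\frac{1}{N} \sum_{i=1}^{N} \log p(x_i | c_i)
$$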

Alternatively, flows can be built as custom Flow objects.

import torch
from zuko.flows import Flow, UnconditionalDistribution, UnconditionalTransform
from zuko.flows.autoregressive import MaskedAutoregressiveTransform
from zuko.distributions import DiagNormal
from zuko.transforms import RotationTransform

flow = Flow(
    transform=[
        MaskedAutoregressiveTransform(3, 5, hidden_features=(64, 64)),
        UnconditionalTransform(RotationTransform, torch.randn(3, 3)),
        MaskedAutoregressiveTransform(3, 5, hidden_features=(64, 64)),
    ],
    base=UnconditionalDistribution(
        DiagNormal,
        torch.zeros(3),
        torch.ones(3),
        buffer=True,
    ),
)
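To see what stacking transformations on top of a base distribution computes, here is a plain-PyTorch analogue built with `torch.distributions.TransformedDistribution` and `ComposeTransform`; it is not zuko's `Flow`. Two elementwise affine maps with fixed, arbitrary coefficients stand in for the learned autoregressive and rotation layers (an assumption made for brevity). In torch's convention the transforms map base samples $z$ to data $x$, so the log-density subtracts the log-determinant of that map, which the sketch checks by hand.

```python
import torch
from torch.distributions import (
    AffineTransform,
    ComposeTransform,
    Independent,
    Normal,
    TransformedDistribution,
)

torch.manual_seed(0)

# Base distribution: standard diagonal Gaussian over 3 features
base = Independent(Normal(torch.zeros(3), torch.ones(3)), 1)

# Illustrative fixed layers standing in for learned transformations
t1 = AffineTransform(
    loc=torch.tensor([1.0, -1.0, 0.5]), scale=torch.tensor([2.0, 0.5, 1.5]), event_dim=1
)
t2 = AffineTransform(
    loc=torch.zeros(3), scale=torch.tensor([3.0, 1.0, 0.25]), event_dim=1
)
transform = ComposeTransform([t1, t2])  # z -> t2(t1(z)) = x

pushforward = TransformedDistribution(base, [transform])
x = pushforward.sample((5,))  # shape (5, 3)

# Change of variables by hand: log p(x) = log p_Z(z) - log|det dx/dz|, z = T^{-1}(x)
z = transform.inv(x)
log_det = transform.log_abs_det_jacobian(z, x)  # shape (5,) because event_dim=1
manual = base.log_prob(z) - log_det

print(torch.allclose(pushforward.log_prob(x), manual))  # True
```

Because both layers are affine, the log-determinant is the same for every sample: the log of the product of all scales, $\log(2 \cdot 0.5 \cdot 1.5 \cdot 3 \cdot 1 \cdot 0.25) = \log 1.125$.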

For more information, check out the documentation and tutorials at zuko.readthedocs.io.

Available flows

| Class | Year | Reference |
|-------|------|-----------|
| GMM | - | Gaussian Mixture Model |
| NICE | 2014 | Non-linear Independent Components Estimation |
| MAF | 2017 | Masked Autoregressive Flow for Density Estimation |
| NSF | 2019 | Neural Spline Flows |
| NCSF | 2020 | Normalizing Flows on Tori and Spheres |
| SOSPF | 2019 | Sum-of-Squares Polynomial Flow |
| NAF | 2018 | Neural Autoregressive Flows |
| UNAF | 2019 | Unconstrained Monotonic Neural Networks |
| CNF | 2018 | Neural Ordinary Differential Equations |
| GF | 2020 | Gaussianization Flows |
| BPF | 2020 | Bernstein-Polynomial Normalizing Flows |

Contributing

If you have a question or an issue, or would like to contribute, please read our contributing guidelines.
