# correctionlib-gradients

A JAX-friendly, auto-differentiable, Python-only implementation of `correctionlib` correction evaluations.
## Table of Contents

- [Installation](#installation)
- [Usage](#usage)
- [Supported types of corrections](#supported-types-of-corrections)
- [Known limitations](#known-limitations)
- [License](#license)

## Installation

```console
pip install correctionlib-gradients
```
## Usage

1. Construct a `CorrectionWithGradient` object from a `correctionlib.schemav2.Correction`.
2. There is no point 2: you can use `CorrectionWithGradient.evaluate` as a normal JAX-friendly, auto-differentiable function.
### Example

```python
import jax
import jax.numpy as jnp

from correctionlib import schemav2
from correctionlib_gradients import CorrectionWithGradient

# a simple correction that evaluates the formula x * x
formula_schema = schemav2.Correction(
    name="x squared",
    version=2,
    inputs=[schemav2.Variable(name="x", type="real")],
    output=schemav2.Variable(name="a scale", type="real"),
    data=schemav2.Formula(
        nodetype="formula",
        expression="x * x",
        parser="TFormula",
        variables=["x"],
    ),
)

c = CorrectionWithGradient(formula_schema)

# c.evaluate behaves like a normal JAX-friendly, auto-differentiable function
value, grad = jax.value_and_grad(c.evaluate)(3.0)
assert jnp.isclose(value, 9.0)
assert jnp.isclose(grad, 6.0)

# jit and vmap also work: evaluate values and gradients over an array of inputs
xs = jnp.array([3.0, 4.0])
values, grads = jax.vmap(jax.jit(jax.value_and_grad(c.evaluate)))(xs)
assert jnp.allclose(values, jnp.array([9.0, 16.0]))
assert jnp.allclose(grads, jnp.array([6.0, 8.0]))
```
## Supported types of corrections

Currently the following corrections from `correctionlib.schemav2` are supported:

- `Formula`, including parametrical formulas
- `Binning` with uniform or non-uniform bin edges and `flow="clamp"`; bin contents can be either:
  - all scalar values
  - all `Formula` or `FormulaRef`
  - scalar constants
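To illustrate what `flow="clamp"` means for a `Binning` correction, here is a minimal numpy sketch (the edges, contents, and function name are hypothetical, not part of the library's API): inputs that fall outside the bin range are clamped to the first or last bin instead of being treated as errors.

```python
import numpy as np

# hypothetical Binning data: non-uniform bin edges, one scalar value per bin
edges = np.array([0.0, 1.0, 2.0, 4.0])
contents = np.array([10.0, 20.0, 30.0])

def binning_clamp(x):
    # searchsorted finds the bin index; clip implements flow="clamp",
    # mapping underflow to the first bin and overflow to the last bin
    idx = np.clip(np.searchsorted(edges, x, side="right") - 1, 0, len(contents) - 1)
    return contents[idx]

print(binning_clamp(0.5))   # 10.0: falls in the first bin
print(binning_clamp(5.0))   # 30.0: overflow, clamped to the last bin
print(binning_clamp(-1.0))  # 10.0: underflow, clamped to the first bin
```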
## Known limitations

Only the evaluation of `Formula` corrections is fully JAX traceable. For other
corrections, e.g. `Binning`, gradients can be computed (`jax.grad` works), but
since JAX cannot trace the computation, utilities such as `jax.jit` and
`jax.vmap` will not work. `np.vectorize` can be used as an alternative to
`jax.vmap` in these cases.
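As a sketch of the `np.vectorize` workaround, the scalar function below stands in for a non-traceable evaluation (the function itself is hypothetical; only the `np.vectorize` usage is the point):

```python
import numpy as np

# hypothetical stand-in for a scalar evaluation that JAX cannot trace,
# e.g. CorrectionWithGradient.evaluate on a Binning correction
def evaluate_scalar(x):
    return x * x

# np.vectorize applies the scalar function elementwise via a Python loop,
# so it works where jax.vmap cannot trace the computation; note that it
# is a convenience wrapper, not a performance optimization
evaluate_vec = np.vectorize(evaluate_scalar)
result = evaluate_vec(np.array([3.0, 4.0]))  # array([ 9., 16.])
```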
## License

`correctionlib-gradients` is distributed under the terms of the BSD 3-Clause license.