You're Invited: Meet the Socket Team at RSAC and BSidesSF 2026, March 23–26. RSVP
Socket
Book a DemoSign in
Socket

efax

Package Overview
Dependencies
Maintainers
1
Versions
115
Alerts
File Explorer

Advanced tools

Socket logo

Install Socket

Detect and block malicious and high-risk dependencies

Install

efax - pypi Package Compare versions

Comparing version
1.22.0
to
1.22.1
+1
tests/match_scipy/__init__.py
"""These tests ensure that our distributions match scipy's."""
from __future__ import annotations
from numpy.random import Generator
from numpy.testing import assert_allclose
from efax import HasEntropyEP, HasEntropyNP
from ..create_info import ChiInfo, ChiSquareInfo
from ..distribution_info import DistributionInfo
def test_nat_entropy(generator: Generator,
                     entropy_distribution_info: DistributionInfo
                     ) -> None:
    """Check that entropy computed from natural parameters agrees with scipy's."""
    batch_shape = (7, 13)
    nat_parameters = entropy_distribution_info.nat_parameter_generator(generator,
                                                                       shape=batch_shape)
    assert isinstance(nat_parameters, HasEntropyNP)
    # Build the matching scipy distribution and compare entropies elementwise.
    reference = entropy_distribution_info.nat_to_scipy_distribution(nat_parameters)
    assert_allclose(nat_parameters.entropy(), reference.entropy(), rtol=2e-5)
def test_exp_entropy(generator: Generator,
                     entropy_distribution_info: DistributionInfo
                     ) -> None:
    """Check that entropy computed from expectation parameters agrees with scipy's."""
    batch_shape = (7, 13)
    exp_parameters = entropy_distribution_info.exp_parameter_generator(generator,
                                                                       shape=batch_shape)
    assert isinstance(exp_parameters, HasEntropyEP)
    reference = entropy_distribution_info.exp_to_scipy_distribution(exp_parameters)
    # Chi and chi-square need a slightly looser relative tolerance.
    needs_loose_tolerance = isinstance(entropy_distribution_info, ChiInfo | ChiSquareInfo)
    rtol = 1e-5 if needs_loose_tolerance else 1e-6
    assert_allclose(exp_parameters.entropy(), reference.entropy(), rtol=rtol)
from __future__ import annotations
from functools import partial
from typing import Any
import numpy as np
from jax import Array
from numpy.random import Generator
from tjax import assert_tree_allclose
from efax import (MaximumLikelihoodEstimator, NaturalParametrization, flat_dict_of_observations,
flat_dict_of_parameters, parameter_map, unflatten_mapping)
from ..create_info import (ComplexCircularlySymmetricNormalInfo, IsotropicNormalInfo,
MultivariateFixedVarianceNormalInfo, MultivariateNormalInfo,
VonMisesFisherInfo)
from ..distribution_info import DistributionInfo
def test_maximum_likelihood_estimation(
        generator: Generator,
        distribution_info: DistributionInfo[NaturalParametrization]
        ) -> None:
    """Test maximum likelihood estimation using SciPy.

    Test that maximum likelihood estimation from scipy-generated variates produce the same
    distribution from which they were drawn.
    """
    rtol = 2e-2
    # Noisier families get a looser absolute tolerance.
    loose_tolerance_infos = (ComplexCircularlySymmetricNormalInfo | MultivariateNormalInfo
                             | VonMisesFisherInfo | MultivariateFixedVarianceNormalInfo)
    if isinstance(distribution_info, loose_tolerance_infos):
        atol = 1e-2
    elif isinstance(distribution_info, IsotropicNormalInfo):
        atol = 1e-3
    else:
        atol = 1e-6
    sample_count = 70000
    # Draw a distribution with expectation parameters.
    exp_parameters = distribution_info.exp_parameter_generator(generator, shape=())
    # Sample variates from the corresponding scipy distribution.
    scipy_distribution = distribution_info.exp_to_scipy_distribution(exp_parameters)
    scipy_x = scipy_distribution.rvs(random_state=generator, size=sample_count)
    # Translate the variates into efax observations and clamp each one to its support.
    efax_x = distribution_info.scipy_to_exp_family_observation(scipy_x)
    flat_observations = flat_dict_of_observations(efax_x)
    flat_parameters = flat_dict_of_parameters(exp_parameters)
    clamped = {path: flat_parameters[path].domain_support().clamp(value)
               for path, value in flat_observations.items()}
    observations: Array | dict[str, Any]
    if clamped.keys() == {()}:
        observations = clamped[()]
    else:
        observations = unflatten_mapping(clamped)
    estimator = MaximumLikelihoodEstimator.create_estimator(exp_parameters)
    sufficient_stats = estimator.sufficient_statistics(observations)
    # The mean of the sufficient statistics should recover the expectation parameters.
    calculated_parameters = parameter_map(partial(np.mean, axis=0),  # type: ignore[arg-type]
                                          sufficient_stats)
    assert_tree_allclose(exp_parameters, calculated_parameters, rtol=rtol, atol=atol)
from __future__ import annotations
from typing import Any
import numpy as np
from jax import Array
from numpy.random import Generator
from numpy.testing import assert_allclose
from tjax import JaxComplexArray
from efax import JointDistributionN, Multidimensional, NaturalParametrization, SimpleDistribution
from ..create_info import MultivariateDiagonalNormalInfo
from ..distribution_info import DistributionInfo
def _check_observation_shape(nat_parameters: NaturalParametrization,
                             efax_x: JaxComplexArray | dict[str, Any],
                             distribution_shape: tuple[int, ...],
                             ) -> None:
    """Verify that the sufficient statistics have the right shape."""
    if isinstance(nat_parameters, JointDistributionN):
        # Joint distributions carry one observation per named sub-distribution; recurse.
        assert isinstance(efax_x, dict)
        for name, sub_distribution in nat_parameters.sub_distributions().items():
            _check_observation_shape(sub_distribution, efax_x[name], distribution_shape)
        return
    assert isinstance(nat_parameters, SimpleDistribution)  # type: ignore[unreachable]
    assert isinstance(efax_x, Array)  # type: ignore[unreachable]
    if isinstance(nat_parameters, Multidimensional):
        dimensions = nat_parameters.dimensions()
    else:
        dimensions = 0
    expected_shape = distribution_shape + nat_parameters.domain_support().shape(dimensions)
    assert efax_x.shape == expected_shape
def test_pdf(generator: Generator, distribution_info: DistributionInfo) -> None:
    """Test that the density/mass function calculation matches scipy's."""
    distribution_shape = (10,)
    nat_parameters = distribution_info.nat_parameter_generator(generator,
                                                               shape=distribution_shape)
    scipy_distribution = distribution_info.nat_to_scipy_distribution(nat_parameters)
    scipy_x = scipy_distribution.rvs(random_state=generator)
    efax_x = distribution_info.scipy_to_exp_family_observation(scipy_x)
    _check_observation_shape(nat_parameters, efax_x, distribution_shape)
    # Compare the efax density against scipy's density (or mass for discrete families).
    computed_density = np.asarray(nat_parameters.pdf(efax_x), dtype=np.float64)
    try:
        reference_density = scipy_distribution.pdf(scipy_x)
    except AttributeError:
        # Discrete scipy distributions expose pmf rather than pdf.
        reference_density = scipy_distribution.pmf(scipy_x)
    atol = 1e-5
    rtol = 3e-4 if isinstance(distribution_info, MultivariateDiagonalNormalInfo) else 1e-4
    assert_allclose(computed_density, reference_density, rtol=rtol, atol=atol)
from __future__ import annotations
import pytest
from numpy.random import Generator
from tjax import assert_tree_allclose, print_generic, tree_allclose
from efax import parameters
from .create_info import GeneralizedDirichletInfo
from .distribution_info import DistributionInfo
def test_conversion(generator: Generator,
                    distribution_info: DistributionInfo
                    ) -> None:
    """Test that the conversion between the different parametrizations are consistent."""
    if isinstance(distribution_info, GeneralizedDirichletInfo):
        pytest.skip()
    sample_count = 30
    original_np = distribution_info.nat_parameter_generator(generator, shape=(sample_count,))
    intermediate_ep = original_np.to_exp()
    final_np = intermediate_ep.to_nat()
    # Round trip: NP -> EP -> NP should reproduce the original; on failure, print the
    # first offending element for debugging before failing the test.
    if not tree_allclose(final_np, original_np):
        for index in range(sample_count):
            if tree_allclose(final_np[index], original_np[index]):
                continue
            print_generic({"original": original_np[index],
                           "intermediate": intermediate_ep[index],
                           "final": final_np[index]})
            pytest.fail("Conversion failure")
    # Fixed parameters must survive both conversions unchanged.
    original_fixed = parameters(original_np, fixed=True)
    assert_tree_allclose(original_fixed, parameters(intermediate_ep, fixed=True))
    assert_tree_allclose(original_fixed, parameters(final_np, fixed=True))
from __future__ import annotations
from collections.abc import Callable
from typing import TypeAlias
from array_api_compat import array_namespace
from jax import grad, jvp, vjp
from jax.custom_derivatives import zero_from_primal
from numpy.random import Generator
from numpy.testing import assert_allclose
from tjax import JaxRealArray, assert_tree_allclose, jit
from efax import NaturalParametrization, Structure, parameters
from .distribution_info import DistributionInfo
_LogNormalizer: TypeAlias = Callable[[NaturalParametrization], JaxRealArray]
def _prelude(generator: Generator,
             distribution_info: DistributionInfo
             ) -> tuple[_LogNormalizer, _LogNormalizer]:
    """Return the original and the optimized log-normalizer of the natural class."""
    nat_cls = distribution_info.nat_class()
    return nat_cls._original_log_normalizer, nat_cls.log_normalizer
def test_primals(generator: Generator, distribution_info: DistributionInfo) -> None:
    """Tests that the gradient log-normalizer equals the gradient of the log-normalizer.

    For each of 20 randomly generated natural parametrizations, the gradient of both the
    original and the optimized log-normalizer is reinterpreted as expectation parameters and
    compared with the result of the regular to_exp transformation.
    """
    original_ln, optimized_ln = _prelude(generator, distribution_info)
    original_gln = jit(grad(original_ln, allow_int=True))
    optimized_gln = jit(grad(optimized_ln, allow_int=True))
    for _ in range(20):
        generated_np = distribution_info.nat_parameter_generator(generator, shape=())
        generated_ep = generated_np.to_exp()  # Regular transformation.
        generated_parameters = parameters(generated_ep, fixed=False)
        structure_ep = Structure.create(generated_ep)
        # Original GLN.
        original_gln_np = original_gln(generated_np)
        original_gln_ep = structure_ep.reinterpret(original_gln_np)
        original_gln_parameters = parameters(original_gln_ep, fixed=False)
        # Optimized GLN.
        optimized_gln_np = optimized_gln(generated_np)
        optimized_gln_ep = structure_ep.reinterpret(optimized_gln_np)
        optimized_gln_parameters = parameters(optimized_gln_ep, fixed=False)
        # Test primal evaluation: both gradients must equal the expectation parameters.
        assert_tree_allclose(generated_parameters, original_gln_parameters, rtol=1e-5)
        assert_tree_allclose(generated_parameters, optimized_gln_parameters, rtol=1e-5)
def _unit_tangent(nat_parameters: NaturalParametrization
                  ) -> NaturalParametrization:
    """Build a tangent with ones for variable parameters and zeros for fixed ones."""
    xp = array_namespace(nat_parameters)
    tangent_entries: dict = {}
    # Variable parameters receive a unit tangent.
    for path, value in parameters(nat_parameters, fixed=False).items():
        tangent_entries[path] = xp.ones_like(value)
    # Fixed parameters receive a zero tangent derived from their primal.
    for path, value in parameters(nat_parameters, fixed=True).items():
        tangent_entries[path] = zero_from_primal(value, symbolic_zeros=False)
    return Structure.create(nat_parameters).assemble(tangent_entries)
def test_jvp(generator: Generator, distribution_info: DistributionInfo) -> None:
    """Tests that the gradient log-normalizer equals the gradient of the log-normalizer."""
    original_ln, optimized_ln = _prelude(generator, distribution_info)
    for _ in range(20):
        nat_parameters = distribution_info.nat_parameter_generator(generator, shape=())
        tangent = _unit_tangent(nat_parameters)
        # Forward-mode derivatives of both implementations must agree on primal and tangent.
        primal_original, jvp_original = jvp(original_ln, (nat_parameters,), (tangent,))
        primal_optimized, jvp_optimized = jvp(optimized_ln, (nat_parameters,), (tangent,))
        assert_allclose(primal_original, primal_optimized, rtol=1.5e-5)
        assert_allclose(jvp_original, jvp_optimized, rtol=1.5e-5)
def test_vjp(generator: Generator, distribution_info: DistributionInfo) -> None:
    """Tests that the gradient log-normalizer equals the gradient of the log-normalizer."""
    original_ln, optimized_ln = _prelude(generator, distribution_info)
    for _ in range(20):
        nat_parameters = distribution_info.nat_parameter_generator(generator, shape=())
        tangent = _unit_tangent(nat_parameters)
        # Forward-mode primal for cross-checking against the vjp primal.
        forward_primal, _ = jvp(original_ln, (nat_parameters,), (tangent,))
        # Reverse-mode of both implementations with a unit cotangent.
        original_primal, original_vjp = vjp(original_ln, nat_parameters)
        original_cotangent, = original_vjp(1.0)
        optimized_primal, optimized_vjp = vjp(optimized_ln, nat_parameters)
        optimized_cotangent, = optimized_vjp(1.0)
        assert_allclose(original_primal, optimized_primal, rtol=1e-5)
        assert_allclose(forward_primal, original_primal, rtol=1e-5)
        assert_tree_allclose(parameters(original_cotangent, fixed=False),
                             parameters(optimized_cotangent, fixed=False),
                             rtol=1e-5)
+9
-9
from __future__ import annotations
from typing import Any, Self
from typing import Self

@@ -46,3 +46,3 @@ import jax.random as jr

def log_normalizer(self) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
return xp.logaddexp(self.log_odds, xp.asarray(0.0))

@@ -56,3 +56,3 @@

def carrier_measure(self, x: JaxRealArray) -> JaxRealArray:
xp = self.array_namespace(x)
xp = array_namespace(self, x)
return xp.zeros(x.shape)

@@ -67,3 +67,3 @@

def nat_to_probability(self) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
p = jss.expit(self.log_odds)

@@ -74,3 +74,3 @@ final_p = 1.0 - p

def nat_to_surprisal(self) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
total_p = self.nat_to_probability()

@@ -117,3 +117,3 @@ return -xp.log(total_p)

def expected_carrier_measure(self) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
return xp.zeros(self.shape)

@@ -123,3 +123,3 @@

def sample(self, key: KeyArray, shape: Shape | None = None) -> JaxBooleanArray:
xp = self.array_namespace()
xp = array_namespace(self)
shape = self.shape if shape is None else shape + self.shape

@@ -131,3 +131,3 @@ grow = (xp.newaxis,) * (len(shape) - len(self.shape))

def conjugate_prior_distribution(self, n: JaxRealArray) -> BetaNP:
xp = self.array_namespace()
xp = array_namespace(self)
reshaped_n = n[..., np.newaxis]

@@ -138,3 +138,3 @@ return BetaNP(reshaped_n * xp.stack([self.probability, (1.0 - self.probability)], axis=-1))

@override
def from_conjugate_prior_distribution(cls, cp: NaturalParametrization[Any, Any]
def from_conjugate_prior_distribution(cls, cp: NaturalParametrization
) -> tuple[Self, JaxRealArray]:

@@ -141,0 +141,0 @@ assert isinstance(cp, BetaNP)

@@ -63,3 +63,3 @@ from __future__ import annotations

def carrier_measure(self, x: JaxRealArray) -> JaxRealArray:
xp = self.array_namespace(x)
xp = array_namespace(self, x)
return xp.log(2.0 * x) - xp.square(x) * 0.5

@@ -66,0 +66,0 @@

@@ -6,2 +6,3 @@ from __future__ import annotations

import jax.random as jr
from array_api_compat import array_namespace
from tjax import JaxArray, JaxComplexArray, JaxRealArray, KeyArray, Shape, outer_product

@@ -53,3 +54,3 @@ from tjax.dataclasses import dataclass

def log_normalizer(self) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
log_det_s = xp.log(xp.real(xp.linalg.det(-self.negative_precision)))

@@ -60,3 +61,3 @@ return -log_det_s + self.dimensions() * math.log(math.pi)

def to_exp(self) -> ComplexCircularlySymmetricNormalEP:
xp = self.array_namespace()
xp = array_namespace(self)
return ComplexCircularlySymmetricNormalEP(

@@ -67,3 +68,3 @@ xp.conj(xp.linalg.inv(-self.negative_precision)))

def carrier_measure(self, x: JaxComplexArray) -> JaxRealArray:
xp = self.array_namespace(x)
xp = array_namespace(self, x)
return xp.zeros(x.shape[:-1])

@@ -118,3 +119,3 @@

def to_nat(self) -> ComplexCircularlySymmetricNormalNP:
xp = self.array_namespace()
xp = array_namespace(self)
return ComplexCircularlySymmetricNormalNP(xp.conj(-xp.linalg.inv(self.variance)))

@@ -124,3 +125,3 @@

def expected_carrier_measure(self) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
return xp.zeros(self.shape)

@@ -144,3 +145,3 @@

"""Return the mean of a corresponding real distribution with double the size."""
xp = self.array_namespace()
xp = array_namespace(self)
n = self.dimensions()

@@ -151,3 +152,3 @@ return xp.zeros((*self.shape, n * 2))

"""Return the covariance of a corresponding real distribution with double the size."""
xp = self.array_namespace()
xp = array_namespace(self)
gamma_r = 0.5 * xp.real(self.variance)

@@ -154,0 +155,0 @@ gamma_i = 0.5 * xp.imag(self.variance)

@@ -6,2 +6,3 @@ from __future__ import annotations

import jax.random as jr
from array_api_compat import array_namespace
from tjax import JaxArray, JaxComplexArray, JaxRealArray, KeyArray, Shape, abs_square

@@ -51,3 +52,3 @@ from tjax.dataclasses import dataclass

def log_normalizer(self) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
mean_conjugate = self.two_mean_conjugate * 0.5

@@ -58,3 +59,3 @@ return xp.sum(abs_square(mean_conjugate), axis=-1) + self.dimensions() * math.log(math.pi)

def to_exp(self) -> ComplexMultivariateUnitVarianceNormalEP:
xp = self.array_namespace()
xp = array_namespace(self)
return ComplexMultivariateUnitVarianceNormalEP(xp.conj(self.two_mean_conjugate) * 0.5)

@@ -64,3 +65,3 @@

def carrier_measure(self, x: JaxComplexArray) -> JaxRealArray:
xp = self.array_namespace(x)
xp = array_namespace(self, x)
return -xp.sum(abs_square(x), axis=-1)

@@ -114,3 +115,3 @@

def to_nat(self) -> ComplexMultivariateUnitVarianceNormalNP:
xp = self.array_namespace()
xp = array_namespace(self)
return ComplexMultivariateUnitVarianceNormalNP(xp.conj(self.mean) * 2.0)

@@ -120,3 +121,3 @@

def expected_carrier_measure(self) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
# The second moment of a normal distribution with the given mean.

@@ -127,3 +128,3 @@ return -(xp.sum(abs_square(self.mean), axis=-1) + self.dimensions())

def sample(self, key: KeyArray, shape: Shape | None = None) -> JaxComplexArray:
xp = self.array_namespace()
xp = array_namespace(self)
shape = self.mean.shape if shape is None else shape + self.mean.shape

@@ -130,0 +131,0 @@ grow = (xp.newaxis,) * (len(shape) - len(self.mean.shape))

@@ -54,3 +54,3 @@ from __future__ import annotations

def log_normalizer(self) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
_, s, mu = self._r_s_mu()

@@ -67,3 +67,3 @@ det_s = s

def to_exp(self) -> ComplexNormalEP:
xp = self.array_namespace()
xp = array_namespace(self)
r, s, mu = self._r_s_mu()

@@ -75,3 +75,3 @@ u = xp.conj(r * s)

def carrier_measure(self, x: JaxComplexArray) -> JaxRealArray:
xp = self.array_namespace(x)
xp = array_namespace(self, x)
return xp.zeros(x.shape)

@@ -87,3 +87,3 @@

def _r_s_mu(self) -> tuple[JaxComplexArray, JaxRealArray, JaxComplexArray]:
xp = self.array_namespace()
xp = array_namespace(self)
r = -self.pseudo_precision / self.negative_precision

@@ -143,3 +143,3 @@ s = xp.reciprocal((abs_square(r) - 1.0) * self.negative_precision)

def to_nat(self) -> ComplexNormalNP:
xp = self.array_namespace()
xp = array_namespace(self)
variance = self.second_moment - abs_square(self.mean)

@@ -159,3 +159,3 @@ pseudo_variance = self.pseudo_second_moment - xp.square(self.mean)

def expected_carrier_measure(self) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
return xp.zeros(self.shape)

@@ -173,3 +173,3 @@

def _multivariate_normal_cov(self) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
variance = self.second_moment - abs_square(self.mean)

@@ -187,3 +187,3 @@ pseudo_variance = self.pseudo_second_moment - xp.square(self.mean)

shape = self.shape if shape is None else shape + self.shape
xp = self.array_namespace()
xp = array_namespace(self)
mn_mean = xp.stack([xp.real(self.mean), xp.imag(self.mean)], axis=-1)

@@ -190,0 +190,0 @@ mn_cov = self._multivariate_normal_cov()

@@ -6,2 +6,3 @@ from __future__ import annotations

import jax.random as jr
from array_api_compat import array_namespace
from tjax import JaxArray, JaxComplexArray, JaxRealArray, KeyArray, Shape, abs_square

@@ -54,3 +55,3 @@ from tjax.dataclasses import dataclass

def to_exp(self) -> ComplexUnitVarianceNormalEP:
xp = self.array_namespace()
xp = array_namespace(self)
return ComplexUnitVarianceNormalEP(xp.conj(self.two_mean_conjugate) * 0.5)

@@ -103,3 +104,3 @@

def to_nat(self) -> ComplexUnitVarianceNormalNP:
xp = self.array_namespace()
xp = array_namespace(self)
return ComplexUnitVarianceNormalNP(xp.conj(self.mean) * 2.0)

@@ -114,3 +115,3 @@

def sample(self, key: KeyArray, shape: Shape | None = None) -> JaxComplexArray:
xp = self.array_namespace()
xp = array_namespace(self)
shape = self.shape if shape is None else shape + self.shape

@@ -117,0 +118,0 @@ grow = (xp.newaxis,) * (len(shape) - len(self.shape))

@@ -7,2 +7,3 @@ from __future__ import annotations

import jax.scipy.special as jss
from array_api_compat import array_namespace
from tjax import JaxRealArray, KeyArray, Shape

@@ -38,3 +39,3 @@ from tjax.dataclasses import dataclass

def log_normalizer(self) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
q = self.alpha_minus_one

@@ -55,3 +56,3 @@ return (xp.sum(jss.gammaln(q + 1.0), axis=-1)

def _exp_helper(self) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
q = self.alpha_minus_one

@@ -79,3 +80,3 @@ return jss.digamma(q + 1.0) - jss.digamma(xp.sum(q, axis=-1, keepdims=True) + q.shape[-1])

def expected_carrier_measure(self) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
return xp.zeros(self.shape)

@@ -82,0 +83,0 @@

@@ -41,3 +41,3 @@ from __future__ import annotations

def carrier_measure(self, x: JaxRealArray) -> JaxRealArray:
xp = self.array_namespace(x)
xp = array_namespace(self, x)
return xp.zeros(x.shape[:len(x.shape) - 1])

@@ -44,0 +44,0 @@

from __future__ import annotations
from typing import Any, Self
from typing import Self
import jax.random as jr
from array_api_compat import array_namespace
from tjax import JaxArray, JaxRealArray, KeyArray, Shape

@@ -41,3 +42,3 @@ from tjax.dataclasses import dataclass

def log_normalizer(self) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
return -xp.log(-self.negative_rate)

@@ -47,3 +48,3 @@

def to_exp(self) -> ExponentialEP:
xp = self.array_namespace()
xp = array_namespace(self)
return ExponentialEP(-xp.reciprocal(self.negative_rate))

@@ -53,3 +54,3 @@

def carrier_measure(self, x: JaxRealArray) -> JaxRealArray:
xp = self.array_namespace(x)
xp = array_namespace(self, x)
return xp.zeros(x.shape)

@@ -65,3 +66,3 @@

def sample(self, key: KeyArray, shape: Shape | None = None) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
shape = self.shape if shape is None else shape + self.shape

@@ -100,3 +101,3 @@ grow = (xp.newaxis,) * (len(shape) - len(self.shape))

def to_nat(self) -> ExponentialNP:
xp = self.array_namespace()
xp = array_namespace(self)
return ExponentialNP(-xp.reciprocal(self.mean))

@@ -106,3 +107,3 @@

def expected_carrier_measure(self) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
return xp.zeros(self.shape)

@@ -112,3 +113,3 @@

def sample(self, key: KeyArray, shape: Shape | None = None) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
shape = self.shape if shape is None else shape + self.shape

@@ -124,3 +125,3 @@ grow = (xp.newaxis,) * (len(shape) - len(self.shape))

@override
def from_conjugate_prior_distribution(cls, cp: NaturalParametrization[Any, Any]
def from_conjugate_prior_distribution(cls, cp: NaturalParametrization
) -> tuple[Self, JaxRealArray]:

@@ -127,0 +128,0 @@ assert isinstance(cp, GammaNP)

@@ -48,3 +48,3 @@ from __future__ import annotations

def log_normalizer(self) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
shape = self.shape_minus_one + 1.0

@@ -55,3 +55,3 @@ return jss.gammaln(shape) - shape * xp.log(-self.negative_rate)

def to_exp(self) -> GammaEP:
xp = self.array_namespace()
xp = array_namespace(self)
shape = self.shape_minus_one + 1.0

@@ -63,3 +63,3 @@ return GammaEP(-shape / self.negative_rate,

def carrier_measure(self, x: JaxRealArray) -> JaxRealArray:
xp = self.array_namespace(x)
xp = array_namespace(self, x)
return xp.zeros(x.shape)

@@ -76,3 +76,3 @@

def sample(self, key: KeyArray, shape: Shape | None = None) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
shape = self.shape if shape is None else shape + self.shape

@@ -90,3 +90,3 @@ grow = (xp.newaxis,) * (len(shape) - len(self.shape))

def to_approximate_log_normal(self) -> LogNormalNP:
xp = self.array_namespace()
xp = array_namespace(self)
shape = self.shape_minus_one + 1.0

@@ -133,3 +133,3 @@ rate = -self.negative_rate

def expected_carrier_measure(self) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
return xp.zeros(self.shape)

@@ -149,3 +149,3 @@

def search_gradient(self, search_parameters: JaxRealArray) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
shape = softplus(search_parameters[..., 0])

@@ -159,3 +159,3 @@ log_mean_minus_mean_log = xp.log(self.mean) - self.mean_log

def initial_search_parameters(self) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
log_mean_minus_mean_log = xp.log(self.mean) - self.mean_log

@@ -162,0 +162,0 @@ initial_shape: JaxRealArray = (

@@ -44,3 +44,3 @@ """The generalized Dirichlet distribution.

def log_normalizer(self) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
alpha, beta = self.alpha_beta()

@@ -56,3 +56,3 @@ return xp.sum(jss.betaln(alpha, beta), axis=-1)

# = jss.digamma(beta) - jss.digamma(alpha + beta)
xp = self.array_namespace()
xp = array_namespace(self)
alpha, beta = self.alpha_beta()

@@ -79,3 +79,3 @@ digamma_sum = jss.digamma(alpha + beta)

def carrier_measure(self, x: JaxRealArray) -> JaxRealArray:
xp = self.array_namespace(x)
xp = array_namespace(self, x)
return xp.zeros(x.shape[:len(x.shape) - 1])

@@ -88,3 +88,3 @@

def alpha_beta(self) -> tuple[JaxRealArray, JaxRealArray]:
xp = self.array_namespace()
xp = array_namespace(self)
alpha = self.alpha_minus_one + 1.0

@@ -135,3 +135,3 @@ # cs_alpha[i] = sum_{j>=i} alpha[j]

def expected_carrier_measure(self) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
return xp.zeros(self.shape)

@@ -138,0 +138,0 @@

from __future__ import annotations
import jax.random as jr
from array_api_compat import array_namespace
from tjax import JaxArray, JaxRealArray, KeyArray, Shape

@@ -77,3 +78,3 @@ from tjax.dataclasses import dataclass

def expected_carrier_measure(self) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
return xp.zeros(self.mean.shape)

@@ -85,3 +86,3 @@

shape += self.shape
xp = self.array_namespace()
xp = array_namespace(self)
p = xp.reciprocal(self.mean)

@@ -88,0 +89,0 @@ return jr.geometric(key, p, shape)

@@ -59,3 +59,3 @@ from __future__ import annotations

def carrier_measure(self, x: JaxRealArray) -> JaxRealArray:
xp = self.array_namespace(x)
xp = array_namespace(self, x)
return -2.0 * xp.log(x)

@@ -65,3 +65,3 @@

def sample(self, key: KeyArray, shape: Shape | None = None) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
return xp.reciprocal(self.base_distribution().sample(key, shape))

@@ -109,3 +109,3 @@

def sample(self, key: KeyArray, shape: Shape | None = None) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
return xp.reciprocal(self.base_distribution().sample(key, shape))

@@ -44,3 +44,3 @@ from __future__ import annotations

def log_normalizer(self) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
return (-0.5 * xp.log(-2.0 * self.negative_lambda_over_two)

@@ -52,3 +52,3 @@ - 2.0 * xp.sqrt(self.negative_lambda_over_two_mu_squared

def to_exp(self) -> InverseGaussianEP:
xp = self.array_namespace()
xp = array_namespace(self)
mu = xp.sqrt(self.negative_lambda_over_two / self.negative_lambda_over_two_mu_squared)

@@ -61,3 +61,3 @@ lambda_ = -2.0 * self.negative_lambda_over_two

def carrier_measure(self, x: JaxRealArray) -> JaxRealArray:
xp = self.array_namespace(x)
xp = array_namespace(self, x)
return -0.5 * (xp.log(2.0 * xp.pi) + 3.0 * xp.log(x))

@@ -84,3 +84,3 @@

def sample(self, key: KeyArray, shape: Shape | None = None) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
shape = self.shape if shape is None else shape + self.shape

@@ -136,3 +136,3 @@ grow = (xp.newaxis,) * (len(shape) - len(self.shape))

def to_nat(self) -> InverseGaussianNP:
xp = self.array_namespace()
xp = array_namespace(self)
lambda_ = xp.reciprocal(self.mean_reciprocal - xp.reciprocal(self.mean))

@@ -139,0 +139,0 @@ eta2 = -0.5 * lambda_

@@ -64,3 +64,3 @@ from __future__ import annotations

def carrier_measure(self, x: JaxRealArray) -> JaxRealArray:
xp = self.array_namespace(x)
xp = array_namespace(self, x)
return -xp.log(x)

@@ -70,3 +70,3 @@

def sample(self, key: KeyArray, shape: Shape | None = None) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
shape = self.shape if shape is None else shape + self.shape

@@ -121,3 +121,3 @@ grow = (xp.newaxis,) * (len(shape) - len(self.shape))

def sample(self, key: KeyArray, shape: Shape | None = None) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
shape = self.shape if shape is None else shape + self.shape

@@ -124,0 +124,0 @@ grow = (xp.newaxis,) * (len(shape) - len(self.shape))

from __future__ import annotations
from typing import Any, Self
from typing import Self

@@ -65,3 +65,3 @@ import jax.random as jr

def carrier_measure(self, x: JaxRealArray) -> JaxRealArray:
xp = self.array_namespace(x)
xp = array_namespace(self, x)
return -0.5 * xp.square(xp.log(x)) - xp.log(x)

@@ -111,3 +111,3 @@

def expected_carrier_measure(self) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
return -self.mean - 0.5 * (xp.square(self.mean) + 1.0)

@@ -117,3 +117,3 @@

def sample(self, key: KeyArray, shape: Shape | None = None) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
shape = self.shape if shape is None else shape + self.shape

@@ -130,3 +130,3 @@ grow = (xp.newaxis,) * (len(shape) - len(self.shape))

@override
def from_conjugate_prior_distribution(cls, cp: NaturalParametrization[Any, Any]
def from_conjugate_prior_distribution(cls, cp: NaturalParametrization
) -> tuple[Self, JaxRealArray]:

@@ -133,0 +133,0 @@ uvn, n = UnitVarianceNormalEP.from_conjugate_prior_distribution(cp)

from __future__ import annotations
from array_api_compat import array_namespace
from tjax import JaxArray, JaxRealArray, Shape

@@ -40,3 +41,3 @@ from tjax.dataclasses import dataclass

def log_normalizer(self) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
return xp.log(-xp.log1p(-xp.exp(self.log_probability)))

@@ -46,3 +47,3 @@

def to_exp(self) -> LogarithmicEP:
xp = self.array_namespace()
xp = array_namespace(self)
probability = xp.exp(self.log_probability)

@@ -59,3 +60,3 @@ chi = xp.where(self.log_probability < log_probability_floor,

def carrier_measure(self, x: JaxRealArray) -> JaxRealArray:
xp = self.array_namespace(x)
xp = array_namespace(self, x)
return -xp.log(x)

@@ -100,3 +101,3 @@

def to_nat(self) -> LogarithmicNP:
xp = self.array_namespace()
xp = array_namespace(self)
z: LogarithmicNP = super().to_nat()

@@ -103,0 +104,0 @@ return LogarithmicNP(xp.where(self.chi < 1.0,

from __future__ import annotations
from typing import Any, Self
from typing import Self

@@ -9,3 +9,3 @@ import array_api_extra as xpx

import scipy.special as sc
from jax.nn import one_hot
from array_api_compat import array_namespace
from tjax import JaxArray, JaxRealArray, KeyArray, Shape

@@ -48,3 +48,3 @@ from tjax.dataclasses import dataclass

def log_normalizer(self) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
max_q = xp.maximum(0.0, xp.max(self.log_odds, axis=-1))

@@ -57,3 +57,3 @@ q_minus_max_q = self.log_odds - max_q[..., np.newaxis]

def to_exp(self) -> MultinomialEP:
xp = self.array_namespace()
xp = array_namespace(self)
max_q = xp.maximum(0.0, xp.max(self.log_odds, axis=-1))

@@ -66,3 +66,3 @@ q_minus_max_q = self.log_odds - max_q[..., np.newaxis]

def carrier_measure(self, x: JaxRealArray) -> JaxRealArray:
xp = self.array_namespace(x)
xp = array_namespace(self, x)
return xp.zeros(x.shape[:-1])

@@ -80,4 +80,5 @@

shape += self.shape
return one_hot(jr.categorical(key, self.log_odds, shape=shape),
self.dimensions())
retval = xpx.one_hot(jr.categorical(key, self.log_odds, shape=shape), self.dimensions())
assert isinstance(retval, JaxRealArray)
return retval

@@ -89,3 +90,3 @@ @override

def nat_to_probability(self) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
max_q = xp.maximum(0.0, xp.max(self.log_odds, axis=-1))

@@ -99,3 +100,3 @@ q_minus_max_q = self.log_odds - max_q[..., np.newaxis]

def nat_to_surprisal(self) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
total_p = self.nat_to_probability()

@@ -135,3 +136,3 @@ return -xp.log(total_p)

def to_nat(self) -> MultinomialNP:
xp = self.array_namespace()
xp = array_namespace(self)
p_k = 1.0 - xp.sum(self.probability, axis=-1, keepdims=True)

@@ -142,3 +143,3 @@ return MultinomialNP(xp.log(self.probability / p_k))

def expected_carrier_measure(self) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
return xp.zeros(self.shape)

@@ -148,3 +149,3 @@

def conjugate_prior_distribution(self, n: JaxRealArray) -> DirichletNP:
xp = self.array_namespace()
xp = array_namespace(self)
reshaped_n = n[..., np.newaxis]

@@ -156,3 +157,3 @@ final_p = 1.0 - xp.sum(self.probability, axis=-1, keepdims=True)

@override
def from_conjugate_prior_distribution(cls, cp: NaturalParametrization[Any, Any],
def from_conjugate_prior_distribution(cls, cp: NaturalParametrization,
) -> tuple[Self, JaxRealArray]:

@@ -164,3 +165,3 @@ assert isinstance(cp, GeneralizedDirichletNP)

def generalized_conjugate_prior_distribution(self, n: JaxRealArray) -> GeneralizedDirichletNP:
xp = self.array_namespace()
xp = array_namespace(self)
final_p = 1.0 - xp.sum(self.probability, axis=-1, keepdims=True)

@@ -167,0 +168,0 @@ all_p = xp.concat((self.probability, final_p), axis=-1)

@@ -7,2 +7,3 @@ from __future__ import annotations

import jax.random as jr
from array_api_compat import array_namespace
from tjax import (JaxArray, JaxRealArray, KeyArray, Shape, matrix_dot_product, matrix_vector_mul,

@@ -47,3 +48,3 @@ outer_product)

def log_normalizer(self) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
eta = self.mean_times_precision

@@ -60,3 +61,3 @@ h_inv = xp.linalg.inv(self.negative_half_precision)

def variance(self) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
h_inv: JaxRealArray = xp.linalg.inv(self.negative_half_precision)

@@ -76,3 +77,3 @@ return -0.5 * h_inv

def to_exp(self) -> MultivariateNormalEP:
xp = self.array_namespace()
xp = array_namespace(self)
h_inv = xp.linalg.inv(self.negative_half_precision)

@@ -87,3 +88,3 @@ h_inv = cast('JaxRealArray', h_inv)

def carrier_measure(self, x: JaxRealArray) -> JaxRealArray:
xp = self.array_namespace(x)
xp = array_namespace(self, x)
return xp.zeros(x.shape[:-1])

@@ -127,3 +128,3 @@

def to_nat(self) -> MultivariateNormalNP:
xp = self.array_namespace()
xp = array_namespace(self)
precision = xp.linalg.inv(self.variance())

@@ -136,3 +137,3 @@ precision = cast('JaxRealArray', precision)

def expected_carrier_measure(self) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
return xp.zeros(self.shape)

@@ -139,0 +140,0 @@

@@ -48,3 +48,3 @@ from __future__ import annotations

def log_normalizer(self) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
components = (-xp.square(self.mean_times_precision) / (4.0 * self.negative_half_precision)

@@ -56,3 +56,3 @@ + 0.5 * xp.log(-np.pi / self.negative_half_precision))

def to_exp(self) -> MultivariateDiagonalNormalEP:
xp = self.array_namespace()
xp = array_namespace(self)
mean = -self.mean_times_precision / (2.0 * self.negative_half_precision)

@@ -64,3 +64,3 @@ second_moment = xp.square(mean) - 0.5 / self.negative_half_precision

def carrier_measure(self, x: JaxRealArray) -> JaxRealArray:
xp = self.array_namespace(x)
xp = array_namespace(self, x)
return xp.zeros(x.shape[:-1])

@@ -121,3 +121,3 @@

def expected_carrier_measure(self) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
return xp.zeros(self.shape)

@@ -134,3 +134,3 @@

def variance(self) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
return self.second_moment - xp.square(self.mean)

@@ -165,3 +165,3 @@

def sample(self, key: KeyArray, shape: Shape | None = None) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
shape = self.mean.shape if shape is None else shape + self.mean.shape

@@ -180,3 +180,3 @@ grow = (xp.newaxis,) * (len(shape) - len(self.mean.shape))

def to_exp(self) -> MultivariateDiagonalNormalEP:
xp = self.array_namespace()
xp = array_namespace(self)
second_moment = self.variance + xp.square(self.mean)

@@ -186,3 +186,3 @@ return MultivariateDiagonalNormalEP(self.mean, second_moment)

def to_nat(self) -> MultivariateDiagonalNormalNP:
xp = self.array_namespace()
xp = array_namespace(self)
precision = xp.reciprocal(self.variance)

@@ -189,0 +189,0 @@ mean_times_precision = self.mean * precision

from __future__ import annotations
import math
from typing import Any, Self
from typing import Self
import jax.random as jr
from array_api_compat import array_namespace
from tjax import JaxArray, JaxRealArray, KeyArray, Shape

@@ -51,3 +52,3 @@ from tjax.dataclasses import dataclass

def log_normalizer(self) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
eta = self.mean_times_precision

@@ -59,3 +60,3 @@ return 0.5 * (xp.sum(xp.square(eta), axis=-1) * self.variance

def to_exp(self) -> MultivariateFixedVarianceNormalEP:
xp = self.array_namespace()
xp = array_namespace(self)
return MultivariateFixedVarianceNormalEP(

@@ -67,3 +68,3 @@ self.mean_times_precision * self.variance[..., xp.newaxis],

def carrier_measure(self, x: JaxRealArray) -> JaxRealArray:
xp = self.array_namespace(x)
xp = array_namespace(self, x)
return -0.5 * xp.sum(xp.square(x), axis=-1) / self.variance

@@ -121,3 +122,3 @@

def to_nat(self) -> MultivariateFixedVarianceNormalNP:
xp = self.array_namespace()
xp = array_namespace(self)
return MultivariateFixedVarianceNormalNP(self.mean / self.variance[..., xp.newaxis],

@@ -128,3 +129,3 @@ variance=self.variance)

def expected_carrier_measure(self) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
return -0.5 * (xp.sum(xp.square(self.mean), axis=-1) / self.variance + self.dimensions())

@@ -134,3 +135,3 @@

def sample(self, key: KeyArray, shape: Shape | None = None) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
shape = self.mean.shape if shape is None else shape + self.mean.shape

@@ -143,3 +144,3 @@ grow = (xp.newaxis,) * (len(shape) - len(self.mean.shape))

def conjugate_prior_distribution(self, n: JaxRealArray) -> IsotropicNormalNP:
xp = self.array_namespace()
xp = array_namespace(self)
n_over_variance = n / self.variance

@@ -152,3 +153,3 @@ negative_half_precision = -0.5 * n_over_variance

@override
def from_conjugate_prior_distribution(cls, cp: NaturalParametrization[Any, Any],
def from_conjugate_prior_distribution(cls, cp: NaturalParametrization,
variance: JaxRealArray | None = None

@@ -158,3 +159,3 @@ ) -> tuple[Self, JaxRealArray]:

assert variance is not None
xp = cp.array_namespace()
xp = array_namespace(cp)
n_over_variance = -2.0 * cp.negative_half_precision

@@ -168,3 +169,3 @@ n = n_over_variance * variance

) -> MultivariateDiagonalNormalNP:
xp = self.array_namespace()
xp = array_namespace(self)
n_over_variance = n / self.variance[..., xp.newaxis]

@@ -171,0 +172,0 @@ negative_half_precision = -0.5 * n_over_variance

@@ -44,3 +44,3 @@ from __future__ import annotations

def log_normalizer(self) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
eta = self.mean_times_precision

@@ -52,3 +52,3 @@ return 0.5 * (-0.5 * xp.sum(xp.square(eta), axis=-1) / self.negative_half_precision

def to_exp(self) -> IsotropicNormalEP:
xp = self.array_namespace()
xp = array_namespace(self)
precision = -2.0 * self.negative_half_precision

@@ -111,3 +111,3 @@ mean = self.mean_times_precision / precision[..., xp.newaxis]

def to_nat(self) -> IsotropicNormalNP:
xp = self.array_namespace()
xp = array_namespace(self)
variance = self.variance()

@@ -120,3 +120,3 @@ negative_half_precision = -0.5 / variance

def expected_carrier_measure(self) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
return xp.zeros(self.shape)

@@ -126,3 +126,3 @@

def sample(self, key: KeyArray, shape: Shape | None = None) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
shape = self.mean.shape if shape is None else shape + self.mean.shape

@@ -138,4 +138,4 @@ grow = (xp.newaxis,) * (len(shape) - len(self.mean.shape))

def variance(self) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
dimensions = self.dimensions()
return (self.total_second_moment - xp.sum(xp.square(self.mean), axis=-1)) / dimensions
from __future__ import annotations
import math
from typing import Any, Self
from typing import Self
import jax.random as jr
from array_api_compat import array_namespace
from tjax import JaxArray, JaxRealArray, KeyArray, Shape

@@ -48,3 +49,3 @@ from tjax.dataclasses import dataclass

def log_normalizer(self) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
return 0.5 * (xp.sum(xp.square(self.mean), axis=-1)

@@ -60,3 +61,3 @@ + self.dimensions() * math.log(math.pi * 2.0))

# The second moment of a delta distribution at x.
xp = self.array_namespace(x)
xp = array_namespace(self, x)
return -0.5 * xp.sum(xp.square(x), axis=-1)

@@ -72,3 +73,3 @@

def sample(self, key: KeyArray, shape: Shape | None = None) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
shape = self.mean.shape if shape is None else shape + self.mean.shape

@@ -120,3 +121,3 @@ grow = (xp.newaxis,) * (len(shape) - len(self.mean.shape))

# The second moment of a normal distribution with the given mean.
xp = self.array_namespace()
xp = array_namespace(self)
return -0.5 * (xp.sum(xp.square(self.mean), axis=-1) + self.dimensions())

@@ -130,3 +131,3 @@

def conjugate_prior_distribution(self, n: JaxRealArray) -> IsotropicNormalNP:
xp = self.array_namespace()
xp = array_namespace(self)
negative_half_precision = -0.5 * n

@@ -137,6 +138,6 @@ return IsotropicNormalNP(n[..., xp.newaxis] * self.mean, negative_half_precision)

@override
def from_conjugate_prior_distribution(cls, cp: NaturalParametrization[Any, Any]
def from_conjugate_prior_distribution(cls, cp: NaturalParametrization
) -> tuple[Self, JaxRealArray]:
assert isinstance(cp, IsotropicNormalNP)
xp = cp.array_namespace()
xp = array_namespace(cp)
n = -2.0 * cp.negative_half_precision

@@ -143,0 +144,0 @@ mean = cp.mean_times_precision / n[..., xp.newaxis]

@@ -7,2 +7,3 @@ from __future__ import annotations

import jax.scipy.special as jss
from array_api_compat import array_namespace
from tjax import JaxIntegralArray, JaxRealArray, Shape

@@ -37,3 +38,3 @@ from tjax.dataclasses import dataclass

def log_normalizer(self) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
return -self._failures() * xp.log1p(-xp.exp(self.log_not_p))

@@ -48,3 +49,3 @@

def _mean(self) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
return self._failures() / xp.expm1(-self.log_not_p)

@@ -79,3 +80,3 @@

def _log_not_p(self) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
return -xp.log1p(self._failures() / self.mean)

@@ -82,0 +83,0 @@

@@ -44,3 +44,3 @@ from __future__ import annotations

def log_normalizer(self) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
return (-xp.square(self.mean_times_precision) / (4.0 * self.negative_half_precision)

@@ -51,3 +51,3 @@ + 0.5 * xp.log(-np.pi / self.negative_half_precision))

def to_exp(self) -> NormalEP:
xp = self.array_namespace()
xp = array_namespace(self)
mean = -self.mean_times_precision / (2.0 * self.negative_half_precision)

@@ -59,3 +59,3 @@ second_moment = xp.square(mean) - 0.5 / self.negative_half_precision

def carrier_measure(self, x: JaxRealArray) -> JaxRealArray:
xp = self.array_namespace(x)
xp = array_namespace(self, x)
return xp.zeros(x.shape)

@@ -115,3 +115,3 @@

def expected_carrier_measure(self) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
return xp.zeros(self.shape)

@@ -124,3 +124,3 @@

def variance(self) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
return self.second_moment - xp.square(self.mean)

@@ -167,3 +167,3 @@

def to_exp(self) -> NormalEP:
xp = self.array_namespace()
xp = array_namespace(self)
second_moment = self.variance + xp.square(self.mean)

@@ -173,3 +173,3 @@ return NormalEP(self.mean, second_moment)

def to_nat(self) -> NormalNP:
xp = self.array_namespace()
xp = array_namespace(self)
precision = xp.reciprocal(self.variance)

@@ -181,3 +181,3 @@ mean_times_precision = self.mean * precision

def to_deviation_parametrization(self) -> NormalDP:
xp = self.array_namespace()
xp = array_namespace(self)
return NormalDP(self.mean, xp.sqrt(self.variance))

@@ -209,3 +209,3 @@

def sample(self, key: KeyArray, shape: Shape | None = None) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
shape = self.shape if shape is None else shape + self.shape

@@ -228,3 +228,3 @@ grow = (xp.newaxis,) * (len(shape) - len(self.shape))

def to_variance_parametrization(self) -> NormalVP:
xp = self.array_namespace()
xp = array_namespace(self)
return NormalVP(self.mean, xp.square(self.deviation))
from __future__ import annotations
import math
from typing import Any, Self
from typing import Self
import jax.random as jr
from array_api_compat import array_namespace
from tjax import JaxArray, JaxRealArray, KeyArray, Shape

@@ -44,3 +45,3 @@ from tjax.dataclasses import dataclass

def log_normalizer(self) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
return 0.5 * (xp.square(self.mean) + math.log(math.pi * 2.0))

@@ -55,3 +56,3 @@

# The second moment of a delta distribution at x.
xp = self.array_namespace(x)
xp = array_namespace(self, x)
return -0.5 * xp.square(x)

@@ -106,3 +107,3 @@

# The second moment of a normal distribution with the given mean.
xp = self.array_namespace()
xp = array_namespace(self)
return -0.5 * (xp.square(self.mean) + 1.0)

@@ -112,3 +113,3 @@

def sample(self, key: KeyArray, shape: Shape | None = None) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
shape = self.shape if shape is None else shape + self.shape

@@ -125,3 +126,3 @@ grow = (xp.newaxis,) * (len(shape) - len(self.shape))

@override
def from_conjugate_prior_distribution(cls, cp: NaturalParametrization[Any, Any]
def from_conjugate_prior_distribution(cls, cp: NaturalParametrization
) -> tuple[Self, JaxRealArray]:

@@ -128,0 +129,0 @@ assert isinstance(cp, NormalNP)

from __future__ import annotations
from typing import Any, Self
from typing import Self
import jax.random as jr
import jax.scipy.special as jss
from array_api_compat import array_namespace
from tjax import JaxArray, JaxRealArray, KeyArray, Shape

@@ -43,3 +44,3 @@ from tjax.dataclasses import dataclass

def log_normalizer(self) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
return xp.exp(self.log_mean)

@@ -49,3 +50,3 @@

def to_exp(self) -> PoissonEP:
xp = self.array_namespace()
xp = array_namespace(self)
return PoissonEP(xp.exp(self.log_mean))

@@ -97,3 +98,3 @@

def to_nat(self) -> PoissonNP:
xp = self.array_namespace()
xp = array_namespace(self)
return PoissonNP(xp.log(self.mean))

@@ -115,3 +116,3 @@

@override
def from_conjugate_prior_distribution(cls, cp: NaturalParametrization[Any, Any]
def from_conjugate_prior_distribution(cls, cp: NaturalParametrization
) -> tuple[Self, JaxRealArray]:

@@ -118,0 +119,0 @@ assert isinstance(cp, GammaNP)

@@ -64,3 +64,3 @@ from __future__ import annotations

def carrier_measure(self, x: JaxRealArray) -> JaxRealArray:
xp = self.array_namespace(x)
xp = array_namespace(self, x)
return xp.log(x) + math.log(2.0)

@@ -72,3 +72,3 @@

shape += self.shape
xp = self.array_namespace()
xp = array_namespace(self)
sigma = xp.sqrt(-0.5 / self.eta)

@@ -115,3 +115,3 @@ return jr.rayleigh(key, sigma, shape)

def expected_carrier_measure(self) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
return 0.5 * xp.log(self.chi * 0.5) + (1.5 * math.log(2.0) - 0.5 * np.euler_gamma)

@@ -123,4 +123,4 @@

shape += self.shape
xp = self.array_namespace()
xp = array_namespace(self)
sigma = xp.sqrt(0.5 * self.chi)
return jr.rayleigh(key, sigma, shape)

@@ -6,2 +6,3 @@ from __future__ import annotations

import jax.random as jr
from array_api_compat import array_namespace
from tjax import Array, JaxArray, JaxRealArray, KeyArray, Shape, inverse_softplus, softplus

@@ -60,3 +61,3 @@ from tjax.dataclasses import dataclass

def carrier_measure(self, x: JaxRealArray) -> JaxRealArray:
xp = self.array_namespace(x)
xp = array_namespace(self, x)
return -xp.log1p(-xp.exp(-x))

@@ -66,3 +67,3 @@

def sample(self, key: KeyArray, shape: Shape | None = None) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
shape = self.shape if shape is None else shape + self.shape

@@ -108,3 +109,3 @@ grow = (xp.newaxis,) * (len(shape) - len(self.shape))

def sample(self, key: KeyArray, shape: Shape | None = None) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
shape = self.shape if shape is None else shape + self.shape

@@ -111,0 +112,0 @@ grow = (xp.newaxis,) * (len(shape) - len(self.mean.shape))

from __future__ import annotations
from typing import Any, Self, cast
from typing import Self, cast
import jax.random as jr
from array_api_compat import array_namespace
from tjax import Array, JaxArray, JaxRealArray, KeyArray, Shape, inverse_softplus, softplus

@@ -61,3 +62,3 @@ from tjax.dataclasses import dataclass

def carrier_measure(self, x: JaxRealArray) -> JaxRealArray:
xp = self.array_namespace(x)
xp = array_namespace(self, x)
return -xp.log1p(-xp.exp(-x)) - 0.5 * xp.square(inverse_softplus(x))

@@ -106,3 +107,3 @@

def sample(self, key: KeyArray, shape: Shape | None = None) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
shape = self.shape if shape is None else shape + self.shape

@@ -119,3 +120,3 @@ grow = (xp.newaxis,) * (len(shape) - len(self.shape))

@override
def from_conjugate_prior_distribution(cls, cp: NaturalParametrization[Any, Any]
def from_conjugate_prior_distribution(cls, cp: NaturalParametrization
) -> tuple[Self, JaxRealArray]:

@@ -122,0 +123,0 @@ uvn, n = UnitVarianceNormalEP.from_conjugate_prior_distribution(cp)

@@ -6,2 +6,3 @@ from __future__ import annotations

from array_api_compat import array_namespace
from tjax import JaxArray, JaxRealArray, Shape, inverse_softplus, softplus

@@ -42,3 +43,3 @@ from tjax.dataclasses import dataclass

def log_normalizer(self) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
half_k = self.dimensions() * 0.5

@@ -53,3 +54,3 @@ kappa = xp.linalg.vector_norm(self.mean_times_concentration, axis=-1)

def to_exp(self) -> VonMisesFisherEP:
xp = self.array_namespace()
xp = array_namespace(self)
q = self.mean_times_concentration

@@ -64,3 +65,3 @@ kappa: JaxRealArray = xp.linalg.vector_norm(q, axis=-1, keepdims=True)

def carrier_measure(self, x: JaxRealArray) -> JaxRealArray:
xp = self.array_namespace(x)
xp = array_namespace(self, x)
return xp.zeros(self.shape)

@@ -79,3 +80,3 @@

def kappa(self) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
return xp.linalg.vector_norm(self.mean_times_concentration, axis=-1)

@@ -86,3 +87,3 @@

raise ValueError
xp = self.array_namespace()
xp = array_namespace(self)
kappa = self.kappa()

@@ -124,3 +125,3 @@ angle = xp.where(kappa == 0.0,

def expected_carrier_measure(self) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
return xp.zeros(self.shape)

@@ -130,3 +131,3 @@

def initial_search_parameters(self) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
mu: JaxRealArray = xp.linalg.vector_norm(self.mean, axis=-1)

@@ -142,3 +143,3 @@ # 0 <= mu <= 1.0

def search_to_natural(self, search_parameters: JaxRealArray) -> VonMisesFisherNP:
xp = self.array_namespace()
xp = array_namespace(self)
kappa = softplus(search_parameters)

@@ -151,3 +152,3 @@ mu = xp.linalg.vector_norm(self.mean, axis=-1, keepdims=True)

def search_gradient(self, search_parameters: JaxRealArray) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
kappa = softplus(search_parameters)

@@ -154,0 +155,0 @@ mu = xp.linalg.vector_norm(self.mean, axis=-1)

@@ -44,3 +44,3 @@ from __future__ import annotations

def log_normalizer(self) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
return -xp.log(-self.eta) - xp.log(self.concentration)

@@ -50,3 +50,3 @@

def to_exp(self) -> WeibullEP:
xp = self.array_namespace()
xp = array_namespace(self)
return WeibullEP(self.concentration, -xp.reciprocal(self.eta))

@@ -56,3 +56,3 @@

def carrier_measure(self, x: JaxRealArray) -> JaxRealArray:
xp = self.array_namespace(x)
xp = array_namespace(self, x)
return (self.concentration - 1.0) * xp.log(x)

@@ -106,3 +106,3 @@

def to_nat(self) -> WeibullNP:
xp = self.array_namespace()
xp = array_namespace(self)
return WeibullNP(self.concentration, -xp.reciprocal(self.chi))

@@ -112,3 +112,3 @@

def expected_carrier_measure(self) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
k = self.concentration

@@ -120,3 +120,3 @@ one_minus_one_over_k = 1.0 - xp.reciprocal(k)

def sample(self, key: KeyArray, shape: Shape | None = None) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
shape = self.shape if shape is None else shape + self.shape

@@ -123,0 +123,0 @@ grow = (xp.newaxis,) * (len(shape) - len(self.shape))

from __future__ import annotations
from abc import abstractmethod
from typing import Any, Generic, TypeVar, final
from typing import Any, Generic, final
from array_api_compat import array_namespace
from tjax import JaxRealArray, jit
from typing_extensions import TypeVar

@@ -12,3 +14,3 @@ from .natural_parametrization import NaturalParametrization

NP = TypeVar('NP', bound=NaturalParametrization[Any, Any])
NP = TypeVar('NP', bound=NaturalParametrization, default=Any)

@@ -49,3 +51,3 @@

self_nat = self.to_nat()
xp = self.array_namespace()
xp = array_namespace(self)
difference = parameter_map(xp.subtract, self_nat, q)

@@ -52,0 +54,0 @@ return (parameter_dot_product(difference, self)

from __future__ import annotations
from abc import abstractmethod
from typing import Any, Self
from typing import Self

@@ -13,6 +13,6 @@ from tjax import JaxComplexArray, JaxRealArray

class HasConjugatePrior(ExpectationParametrization[Any]):
class HasConjugatePrior(ExpectationParametrization):
@abstractmethod
def conjugate_prior_distribution(self, n: JaxRealArray
) -> NaturalParametrization[Any, Any]:
) -> NaturalParametrization:
"""The conjugate prior distribution.

@@ -27,3 +27,3 @@

@abstractmethod
def from_conjugate_prior_distribution(cls, cp: NaturalParametrization[Any, Any]
def from_conjugate_prior_distribution(cls, cp: NaturalParametrization
) -> tuple[Self, JaxRealArray]:

@@ -49,3 +49,3 @@ """Given a conjugate prior distribution, find the distribution and observation count.

def generalized_conjugate_prior_distribution(self, n: JaxRealArray
) -> NaturalParametrization[Any, Any]:
) -> NaturalParametrization:
"""A generalization of the conjugate prior distribution.

@@ -52,0 +52,0 @@

@@ -6,2 +6,3 @@ from collections.abc import Iterable, Mapping

from tjax import JaxComplexArray
from typing_extensions import TypeIs

@@ -13,2 +14,8 @@ from .parameter import Support

def is_string_mapping(x: object) -> TypeIs[Mapping[str, object]]:
if not isinstance(x, Mapping):
return False
return all(isinstance(xi, str) for xi in x)
def flatten_mapping(m: Mapping[str, Any], /) -> dict[Path, Any]:

@@ -18,8 +25,10 @@ """Flatten a nested mapping."""

def _flatten(m: Mapping[str, Mapping[str, Any] | Any], prefix: Path) -> None:
def _flatten(m: Mapping[str, Any], prefix: Path) -> None:
for key, value in m.items():
path = (*prefix, key)
if isinstance(value, Mapping):
if is_string_mapping(value):
_flatten(value, path)
continue
if isinstance(value, Mapping):
raise TypeError
result[path] = value

@@ -26,0 +35,0 @@ _flatten(m, ())

from __future__ import annotations
from dataclasses import KW_ONLY, field
from typing import Any, Generic, Self, TypeAlias, TypeVar
from typing import Any, Generic, Self, TypeAlias
from array_api_compat import array_namespace
from jax import vmap
from tjax import JaxRealArray, jit
from tjax.dataclasses import dataclass
from typing_extensions import override
from typing_extensions import TypeVar, override

@@ -16,3 +17,3 @@ from ...expectation_parametrization import ExpectationParametrization

NP = TypeVar('NP', bound=NaturalParametrization[Any, Any])
NP = TypeVar('NP', bound=NaturalParametrization, default=Any)
SP: TypeAlias = JaxRealArray

@@ -22,3 +23,3 @@

class ExpToNatMinimizer:
def solve(self, exp_to_nat: ExpToNat[Any]) -> SP:
def solve(self, exp_to_nat: ExpToNat) -> SP:
raise NotImplementedError

@@ -68,3 +69,3 @@

_, flattened = Flattener.flatten(self)
xp = self.array_namespace()
xp = array_namespace(self)
return xp.zeros_like(flattened)

@@ -71,0 +72,0 @@

from typing import Any, TypeAlias, TypeVar
import optimistix as optx
from array_api_compat import array_namespace
from jax import jit

@@ -27,4 +28,4 @@ from tjax import JaxRealArray

@override
def solve(self, exp_to_nat: ExpToNat[Any]) -> JaxRealArray:
xp = exp_to_nat.array_namespace()
def solve(self, exp_to_nat: ExpToNat) -> JaxRealArray:
xp = array_namespace(exp_to_nat)

@@ -31,0 +32,0 @@ @jit

from __future__ import annotations
from abc import abstractmethod
from typing import Any, Generic, TypeVar, final, override
from typing import Any, Generic, final, override
from jax.lax import stop_gradient
from tjax import JaxAbstractClass, JaxRealArray, abstract_jit, jit
from tjax import JaxAbstractClass, JaxRealArray, abstract_jit, jit, stop_gradient
from typing_extensions import TypeVar

@@ -14,3 +14,3 @@ from ..expectation_parametrization import ExpectationParametrization

NP = TypeVar('NP', bound=NaturalParametrization[Any, Any])
NP = TypeVar('NP', bound=NaturalParametrization, default=Any)

@@ -60,6 +60,6 @@

EP = TypeVar('EP', bound=HasEntropyEP[Any])
EP = TypeVar('EP', bound=HasEntropyEP, default=Any)
class HasEntropyNP(NaturalParametrization[EP, Any],
class HasEntropyNP(NaturalParametrization[EP],
HasEntropy,

@@ -66,0 +66,0 @@ Generic[EP]):

@@ -5,7 +5,8 @@ from __future__ import annotations

from functools import partial
from typing import Any, Generic, TypeVar, cast
from typing import Any, Generic, cast
from array_api_compat import array_namespace
from jax import jacobian, vmap
from tjax import JaxArray, JaxComplexArray, JaxRealArray, Shape
from typing_extensions import override
from typing_extensions import TypeVar, override

@@ -16,5 +17,5 @@ from ..expectation_parametrization import ExpectationParametrization

TEP = TypeVar('TEP', bound=ExpectationParametrization[Any])
NP = TypeVar('NP', bound=NaturalParametrization[Any, Any])
Domain = TypeVar('Domain', bound=JaxComplexArray)
TEP = TypeVar('TEP', bound=ExpectationParametrization, default=Any)
NP = TypeVar('NP', bound=NaturalParametrization, default=Any)
Domain = TypeVar('Domain', bound=JaxComplexArray, default=JaxComplexArray)

@@ -61,3 +62,3 @@

def carrier_measure(self, x: JaxRealArray) -> JaxRealArray:
xp = self.array_namespace(x)
xp = array_namespace(self, x)
casted_x = cast('Domain', x)

@@ -95,3 +96,3 @@ fixed_parameters = parameters(self, fixed=True, recurse=False)

TNP = TypeVar('TNP', bound=TransformedNaturalParametrization[Any, Any, Any, Any])
TNP = TypeVar('TNP', bound=TransformedNaturalParametrization, default=Any)

@@ -98,0 +99,0 @@

from __future__ import annotations
from abc import abstractmethod
from typing import TYPE_CHECKING, Any, Generic, Self, TypeVar, final, get_type_hints
from typing import TYPE_CHECKING, Any, Generic, Self, final, get_type_hints
from array_api_compat import array_namespace
from jax import grad, jacfwd, vjp, vmap

@@ -10,2 +11,3 @@ from tjax import (JaxAbstractClass, JaxArray, JaxComplexArray, JaxRealArray, abstract_custom_jvp,

from tjax.dataclasses import dataclass
from typing_extensions import TypeVar

@@ -23,9 +25,9 @@ from .iteration import parameters

EP = TypeVar('EP', bound='ExpectationParametrization[Any]')
Domain = TypeVar('Domain', bound=JaxComplexArray | dict[str, Any])
EP = TypeVar('EP', bound='ExpectationParametrization', default=Any)
Domain = TypeVar('Domain', bound=JaxComplexArray | dict[str, Any], default=Any)
def log_normalizer_jvp(primals: tuple[NaturalParametrization[Any, Any]],
tangents: tuple[NaturalParametrization[Any, Any]],
) -> tuple[JaxRealArray, JaxRealArray]:
def _log_normalizer_jvp(primals: tuple[NaturalParametrization],
tangents: tuple[NaturalParametrization],
) -> tuple[JaxRealArray, JaxRealArray]:
"""The log-normalizer's special JVP vastly improves numerical stability."""

@@ -49,3 +51,3 @@ q, = primals

"""
@abstract_custom_jvp(log_normalizer_jvp)
@abstract_custom_jvp(_log_normalizer_jvp)
@abstract_jit

@@ -100,3 +102,3 @@ @abstractmethod

"""
xp = self.array_namespace()
xp = array_namespace(self)
return xp.exp(self.log_pdf(x))

@@ -125,3 +127,3 @@

"""
xp = self.array_namespace()
xp = array_namespace(self)
flattener, _ = Flattener.flatten(self)

@@ -141,3 +143,3 @@ fisher_matrix = self._fisher_information_matrix()

"""
xp = self.array_namespace()
xp = array_namespace(self)
fisher_information_diagonal = self.fisher_information_diagonal()

@@ -161,3 +163,3 @@ structure = Structure.create(self)

def jeffreys_prior(self) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
fisher_matrix = self._fisher_information_matrix()

@@ -164,0 +166,0 @@ return xp.sqrt(xp.linalg.det(fisher_matrix))

@@ -7,2 +7,3 @@ from __future__ import annotations

import array_api_extra as xpx
import jax.scipy.special as jss

@@ -53,3 +54,3 @@ import numpy as np

def general_array_namespace(x: RealNumeric) -> ModuleType:
def _general_array_namespace(x: RealNumeric) -> ModuleType:
if isinstance(x, float):

@@ -60,5 +61,5 @@ return np

def canonical_float_epsilon(xp: ModuleType) -> float:
dtype = xp.empty((), dtype=float).dtype # For Jax, this is canonicalize_dtype(float).
return float(xp.finfo(dtype).eps)
def _canonical_float_epsilon(xp: ModuleType) -> float:
dtype = xpx.default_dtype(xp)
return xp.finfo(dtype).eps

@@ -76,4 +77,4 @@

if self.min_open and self.minimum is not None:
xp = general_array_namespace(self.minimum)
eps = canonical_float_epsilon(xp)
xp = _general_array_namespace(self.minimum)
eps = _canonical_float_epsilon(xp)
self.minimum = xp.asarray(xp.maximum(

@@ -83,4 +84,4 @@ self.minimum + eps,

if self.max_open and self.maximum is not None:
xp = general_array_namespace(self.maximum)
eps = canonical_float_epsilon(xp)
xp = _general_array_namespace(self.maximum)
eps = _canonical_float_epsilon(xp)
self.maximum = xp.asarray(xp.minimum(

@@ -217,3 +218,3 @@ self.maximum - eps,

xp = array_namespace(x)
return xp.asarray(x, dtype=float)
return xp.astype(x, float)

@@ -223,7 +224,7 @@ @override

xp = array_namespace(y)
return xp.asarray(y, dtype=xp.bool_)
return xp.astype(y, xp.bool_)
@override
def generate(self, xp: Namespace, rng: Generator, shape: Shape, safety: float) -> JaxRealArray:
return xp.asarray(rng.binomial(1, 0.5, shape), dtype=xp.bool_)
return xp.astype(rng.binomial(1, 0.5, shape), xp.bool_)

@@ -243,3 +244,3 @@

xp = array_namespace(x)
return xp.asarray(x, dtype=float)
return xp.astype(x, float)

@@ -249,3 +250,3 @@ @override

xp = array_namespace(y)
return xp.asarray(y, dtype=int)
return xp.astype(y, int)

@@ -252,0 +253,0 @@ @override

@@ -12,2 +12,3 @@ from __future__ import annotations

from numpy.random import Generator
from opt_einsum import contract
from tjax import JaxArray, JaxRealArray, Shape

@@ -234,3 +235,3 @@ from typing_extensions import override

# Perform QR decomposition to obtain an orthogonal matrix Q
q, _ = np.linalg.qr(m)
q, _ = xp.linalg.qr(m)
# Generate Eigenvalues.

@@ -242,3 +243,3 @@ assert isinstance(self.ring, RealField | ComplexField)

# Return Q.T @ diag(eig) @ Q.
return xp.einsum('...ji,...j,...jk->...ik', xp.conj(q) if self.hermitian else q, eig, q)
return contract('...ji,...j,...jk->...ik', xp.conj(q) if self.hermitian else q, eig, q)
mt = xp.matrix_transpose(m)

@@ -245,0 +246,0 @@ if self.hermitian:

@@ -41,3 +41,3 @@ from __future__ import annotations

def array_namespace(self, *x: Any) -> ModuleType:
def __array_namespace__(self, api_version: str | None = None) -> ModuleType: # noqa: PLW3201
from .iteration import parameters # noqa: PLC0415

@@ -44,0 +44,0 @@ values = parameters(self).values()

@@ -120,3 +120,3 @@ from __future__ import annotations

for i in np.ndindex(*self.shape):
objects[i] = self.objects[i].as_multivariate_normal()
objects[i] = self.objects[i].as_multivariate_normal() # pyright: ignore
return ScipyMultivariateNormal(self.shape, self.rvs_shape, self.real_dtype, objects)

@@ -135,3 +135,3 @@

for i in np.ndindex(*self.shape):
retval[i] = self.objects[i].variance
retval[i] = self.objects[i].variance # pyright: ignore
return retval

@@ -143,3 +143,3 @@

for i in np.ndindex(*self.shape):
retval[i] = self.objects[i].pseudo_variance
retval[i] = self.objects[i].pseudo_variance # pyright: ignore
return retval

@@ -117,3 +117,3 @@ from __future__ import annotations

for i in np.ndindex(*self.shape):
objects[i] = self.objects[i].as_multivariate_normal()
objects[i] = self.objects[i].as_multivariate_normal() # pyright: ignore
return ScipyMultivariateNormal(self.shape, self.rvs_shape, self.real_dtype, objects)

@@ -132,3 +132,3 @@

for i in np.ndindex(*self.shape):
retval[i] = self.objects[i].variance
retval[i] = self.objects[i].variance # pyright: ignore
return retval

@@ -140,3 +140,3 @@

for i in np.ndindex(*self.shape):
retval[i] = self.objects[i].pseudo_variance
retval[i] = self.objects[i].pseudo_variance # pyright: ignore
return retval

@@ -40,3 +40,3 @@ from __future__ import annotations

for i in np.ndindex(*self.shape):
retval[i] = self.objects[i].rvs(size=size, random_state=random_state)
retval[i] = self.objects[i].rvs(size=size, random_state=random_state) # pyright: ignore
return retval

@@ -47,3 +47,3 @@

for i in np.ndindex(*self.shape):
value = self.objects[i].pdf(x[i])
value = self.objects[i].pdf(x[i]) # pyright: ignore
if i == ():

@@ -57,3 +57,3 @@ return value

for i in np.ndindex(*self.shape):
retval[i] = self.objects[i].entropy()
retval[i] = self.objects[i].entropy() # pyright: ignore
return retval

@@ -60,0 +60,0 @@

@@ -40,3 +40,3 @@ from __future__ import annotations

for i in np.ndindex(*shape):
objects[i] = ss.vonmises_fisher(mu[i], kappa[i])
objects[i] = ss.vonmises_fisher(mu[i], kappa[i]) # pyright: ignore
super().__init__(shape, rvs_shape, dtype, objects)

@@ -16,2 +16,3 @@ from dataclasses import fields

from ..natural_parametrization import NaturalParametrization
NP = TypeVar('NP', bound=NaturalParametrization)

@@ -33,5 +34,5 @@ T = TypeVar('T')

def create_simple_estimator(cls,
type_p: type[SimpleDistribution],
type_p: type[SP],
**fixed_parameters: JaxArray
) -> 'MaximumLikelihoodEstimator[Any]':
) -> 'MaximumLikelihoodEstimator[SP]':
"""Create an estimator for a simple expectation parametrization class.

@@ -57,8 +58,7 @@

@classmethod
def create_estimator_from_natural(cls, p: 'NaturalParametrization[Any, Any]'
) -> 'MaximumLikelihoodEstimator[Any]':
def create_estimator_from_natural(cls, p: 'NP') -> 'MaximumLikelihoodEstimator[NP]':
"""Create an estimator for a natural parametrization."""
infos = MaximumLikelihoodEstimator.create(p).to_exp().infos
fixed_parameters = parameters(p, fixed=True)
return cls(infos, fixed_parameters)
return MaximumLikelihoodEstimator(infos, fixed_parameters)

@@ -69,3 +69,3 @@ def sufficient_statistics(self, x: dict[str, Any] | JaxComplexArray) -> P:

from ..transform.joint import JointDistributionE # noqa: PLC0415
constructed: dict[Path, ExpectationParametrization[Any]] = {}
constructed: dict[Path, ExpectationParametrization] = {}

@@ -106,3 +106,3 @@ def g(info: SubDistributionInfo, x: JaxComplexArray) -> None:

def from_conjugate_prior_distribution(self,
cp: 'NaturalParametrization[Any, Any]'
cp: 'NaturalParametrization'
) -> tuple[P, JaxRealArray]:

@@ -109,0 +109,0 @@ from ..interfaces.conjugate_prior import HasConjugatePrior # noqa: PLC0415

from dataclasses import replace
from functools import partial
from typing import Any, Self, TypeVar, cast, overload
from typing import Any, Self, cast, overload

@@ -8,2 +8,3 @@ from array_api_compat import array_namespace

from tjax.dataclasses import dataclass, field
from typing_extensions import TypeVar

@@ -17,4 +18,4 @@ from ..iteration import parameters

P = TypeVar('P', bound=Distribution)
SP = TypeVar('SP', bound=SimpleDistribution)
P = TypeVar('P', bound=Distribution, default=Any)
SP = TypeVar('SP', bound=SimpleDistribution, default=Any)

@@ -97,3 +98,3 @@

"""
xp = p.array_namespace()
xp = array_namespace(p)
arrays = [x

@@ -100,0 +101,0 @@ for xs in cls._walk(partial(cls._make_flat, map_to_plane=map_to_plane), p)

@@ -13,3 +13,3 @@ from collections.abc import Generator, Iterable

) -> Generator[str]:
"""Return the parameter names in a distribution.
"""The parameter names in a distribution.

@@ -21,3 +21,3 @@ Args:

Returns:
The name of each parameter.
The parameter names.
"""

@@ -24,0 +24,0 @@ def _parameters(q: type[Distribution],

@@ -26,3 +26,3 @@ from collections.abc import Generator

) -> Generator[tuple[str, Support, ValueReceptacle]]:
"""Return the parameter supports in a distribution.
"""The parameter supports in a distribution.

@@ -29,0 +29,0 @@ Args:

from collections.abc import Callable, Iterable, Mapping
from typing import Any, Generic, TypeVar, cast
from typing import Any, Generic, cast
from numpy.random import Generator
from tjax import JaxComplexArray, Shape
from tjax import JaxComplexArray, JaxRealArray, Shape
from tjax.dataclasses import dataclass, field
from typing_extensions import TypeVar

@@ -26,4 +27,3 @@ from ..iteration import parameters

T = TypeVar('T')
P = TypeVar('P', bound=Distribution)
SP = TypeVar('SP', bound=SimpleDistribution)
P = TypeVar('P', bound=Distribution, default=Any)

@@ -51,3 +51,3 @@

def to_nat(self) -> 'Structure[Any]':
def to_nat(self) -> 'Structure':
from ..expectation_parametrization import ExpectationParametrization # noqa: PLC0415

@@ -61,5 +61,5 @@ infos = []

def to_exp(self) -> 'Structure[Any]':
def to_exp(self) -> 'Structure':
from ..natural_parametrization import NaturalParametrization # noqa: PLC0415
infos = []
infos: list[SubDistributionInfo] = []
for info in self.infos:

@@ -105,3 +105,3 @@ assert issubclass(info.type_, NaturalParametrization)

"""Generate a random distribution."""
path_and_values = {}
path_and_values: dict[tuple[str, ...], JaxRealArray] = {}
for info in self.infos:

@@ -108,0 +108,0 @@ for name, support, value_receptacle in parameter_supports(info.type_):

@@ -21,3 +21,3 @@ from __future__ import annotations

@jit
def parameter_dot_product(x: NaturalParametrization[Any, Any], y: Any, /) -> JaxRealArray:
def parameter_dot_product(x: NaturalParametrization, y: Any, /) -> JaxRealArray:
"""Return the vectorized dot product over all of the variable parameters."""

@@ -43,3 +43,3 @@ def dotted_fields() -> Iterable[JaxRealArray]:

"""Return the mean of the parameters (fixed and variable)."""
xp = x.array_namespace()
xp = array_namespace(x)
structure = Structure.create(x)

@@ -46,0 +46,0 @@ p = parameters(x)

@@ -5,2 +5,3 @@ from collections.abc import Callable, Mapping

from array_api_compat import array_namespace
from tjax import JaxArray, JaxComplexArray, JaxRealArray, KeyArray, RngStream, Shape

@@ -59,6 +60,6 @@ from tjax.dataclasses import dataclass

HasEntropyEP['JointDistributionN']):
_sub_distributions: Mapping[str, ExpectationParametrization[Any]]
_sub_distributions: Mapping[str, ExpectationParametrization]
@override
def sub_distributions(self) -> Mapping[str, ExpectationParametrization[Any]]:
def sub_distributions(self) -> Mapping[str, ExpectationParametrization]:
return self._sub_distributions

@@ -78,5 +79,5 @@

def expected_carrier_measure(self) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)
def f(x: ExpectationParametrization[Any]) -> JaxRealArray:
def f(x: ExpectationParametrization) -> JaxRealArray:
assert isinstance(x, HasEntropyEP)

@@ -92,6 +93,6 @@ return x.expected_carrier_measure()

NaturalParametrization[JointDistributionE, dict[str, Any]]):
_sub_distributions: Mapping[str, NaturalParametrization[Any, Any]]
_sub_distributions: Mapping[str, NaturalParametrization]
@override
def sub_distributions(self) -> Mapping[str, NaturalParametrization[Any, Any]]:
def sub_distributions(self) -> Mapping[str, NaturalParametrization]:
return self._sub_distributions

@@ -101,3 +102,3 @@

def log_normalizer(self) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)

@@ -114,3 +115,3 @@ return reduce(xp.add,

def carrier_measure(self, x: dict[str, Any]) -> JaxRealArray:
xp = self.array_namespace()
xp = array_namespace(self)

@@ -117,0 +118,0 @@ joined = join_mappings(sub=self._sub_distributions, x=x)

Metadata-Version: 2.4
Name: efax
Version: 1.22.0
Version: 1.22.1
Summary: Exponential families for JAX
Project-URL: repository, https://github.com/NeilGirdhar/efax
Project-URL: source, https://github.com/NeilGirdhar/efax
Author: Neil Girdhar

@@ -26,5 +26,6 @@ Author-email: mistersheik@gmail.com

Requires-Dist: array-api-compat>=1.10
Requires-Dist: array-api-extra>=0.7
Requires-Dist: array-api-extra>=0.8
Requires-Dist: jax>=0.6.1
Requires-Dist: numpy>=1.25
Requires-Dist: opt-einsum>=3.4
Requires-Dist: optimistix>=0.0.9

@@ -35,3 +36,3 @@ Requires-Dist: optype>=0.8.0

Requires-Dist: tensorflow-probability>=0.15
Requires-Dist: tjax>=1.3.1
Requires-Dist: tjax>=1.3.10
Requires-Dist: typing-extensions>=4.8

@@ -44,6 +45,6 @@ Provides-Extra: dev

Requires-Dist: pylint>=3.3; extra == 'dev'
Requires-Dist: pyright>=0.0.13; extra == 'dev'
Requires-Dist: pyright>=1.1.401; extra == 'dev'
Requires-Dist: pytest-ordering; extra == 'dev'
Requires-Dist: pytest-xdist[psutil]>=3; extra == 'dev'
Requires-Dist: pytest>=8; extra == 'dev'
Requires-Dist: pytest>=8.4; extra == 'dev'
Requires-Dist: ruff>=0.9.10; extra == 'dev'

@@ -50,0 +51,0 @@ Description-Content-Type: text/x-rst

@@ -7,3 +7,3 @@ [build-system]

name = "efax"
version = "1.22.0"
version = "1.22.1"
description = "Exponential families for JAX"

@@ -31,5 +31,6 @@ readme = "README.rst"

"array_api_compat>=1.10",
"array_api_extra>=0.7",
"array_api_extra>=0.8",
"jax>=0.6.1",
"numpy>=1.25",
"opt-einsum>=3.4",
"optimistix>=0.0.9",

@@ -40,3 +41,3 @@ "optype>=0.8.0",

"tensorflow_probability>=0.15",
"tjax>=1.3.1",
"tjax>=1.3.10",
"typing_extensions>=4.8",

@@ -52,6 +53,6 @@ ]

"pylint>=3.3",
"pyright>=0.0.13",
"pyright>=1.1.401",
"pytest-ordering",
"pytest-xdist[psutil]>=3",
"pytest>=8",
"pytest>=8.4",
"ruff>=0.9.10",

@@ -61,3 +62,3 @@ ]

[project.urls]
repository = "https://github.com/NeilGirdhar/efax"
source = "https://github.com/NeilGirdhar/efax"

@@ -64,0 +65,0 @@ [tool.isort]

@@ -62,6 +62,6 @@ from __future__ import annotations

def distribution_name(request: pytest.FixtureRequest) -> str | None:
return request.config.getoption("--distribution") # pyright: ignore
return request.config.getoption("--distribution")
def supports(s: Structure[Any], abc: type[Any]) -> bool:
def supports(s: Structure, abc: type[Any]) -> bool:
return all(issubclass(info.type_, abc) or issubclass(info.type_, JointDistribution)

@@ -71,3 +71,3 @@ for info in s.infos)

def any_integral_supports(structure: Structure[Any]) -> bool:
def any_integral_supports(structure: Structure) -> bool:
return any(isinstance(s.ring, BooleanRing | IntegralRing)

@@ -74,0 +74,0 @@ for s in structure.domain_support().values())

@@ -309,3 +309,3 @@ from __future__ import annotations

class JointInfo(DistributionInfo[JointDistributionN, JointDistributionE, dict[str, Any]]):
def __init__(self, infos: Mapping[str, DistributionInfo[Any, Any, Any]]) -> None:
def __init__(self, infos: Mapping[str, DistributionInfo]) -> None:
super().__init__()

@@ -396,3 +396,3 @@ self.infos = dict(infos)

variance = np.asarray(p.variance())
covariance = xpx.create_diagonal(variance) # type: ignore[arg-type]
covariance = xpx.create_diagonal(variance) # type: ignore[arg-type] # pyright: ignore
assert isinstance(covariance, np.ndarray) # type: ignore[unreachable]

@@ -643,3 +643,3 @@ return ScipyMultivariateNormal.from_mc( # type: ignore[unreachable]

def create_infos() -> list[DistributionInfo[Any, Any, Any]]:
def create_infos() -> list[DistributionInfo]:
return [

@@ -646,0 +646,0 @@ BernoulliInfo(),

from __future__ import annotations
from collections.abc import Callable
from typing import Any, Generic, TypeVar, final
from typing import Any, Generic, final

@@ -10,9 +10,9 @@ import jax.numpy as jnp

from tjax import JaxComplexArray, NumpyComplexArray, Shape
from typing_extensions import override
from typing_extensions import TypeVar, override
from efax import ExpectationParametrization, NaturalParametrization, Structure, SubDistributionInfo
NP = TypeVar('NP', bound=NaturalParametrization[Any, Any])
EP = TypeVar('EP', bound=ExpectationParametrization[Any])
Domain = TypeVar('Domain', bound=NumpyComplexArray | dict[str, Any])
NP = TypeVar('NP', bound=NaturalParametrization, default=Any)
EP = TypeVar('EP', bound=ExpectationParametrization, default=Any)
Domain = TypeVar('Domain', bound=NumpyComplexArray | dict[str, Any], default=Any)

@@ -19,0 +19,0 @@

@@ -13,8 +13,8 @@ from __future__ import annotations

# Tools --------------------------------------------------------------------------------------------
def random_complex_array(generator: Generator, shape: Shape = ()) -> NumpyComplexArray:
def _random_complex_array(generator: Generator, shape: Shape = ()) -> NumpyComplexArray:
return np.asarray(sum(x * generator.normal(size=shape) for x in (0.5, 0.5j)))
def build_uvcn(generator: Generator, shape: Shape) -> ScipyComplexNormal:
mean = random_complex_array(generator, shape)
def _build_uvcn(generator: Generator, shape: Shape) -> ScipyComplexNormal:
mean = _random_complex_array(generator, shape)
variance = generator.exponential(size=shape)

@@ -27,11 +27,11 @@ pseudo_variance = (variance

def build_mvcn(generator: Generator,
shape: Shape,
dimensions: int,
polarization: float = 0.98
) -> ScipyComplexMultivariateNormal:
def _build_mvcn(generator: Generator,
shape: Shape,
dimensions: int,
polarization: float = 0.98
) -> ScipyComplexMultivariateNormal:
directions = 3
weights = np.asarray(range(directions)) + 1.5
mean = random_complex_array(generator, (*shape, dimensions))
z = random_complex_array(generator, (*shape, dimensions, directions))
mean = _random_complex_array(generator, (*shape, dimensions))
z = _random_complex_array(generator, (*shape, dimensions, directions))
regularizer = np.tile(np.eye(dimensions), (*shape, 1, 1))

@@ -56,3 +56,3 @@ variance = (

rvs_axes = tuple(range(-len(rvs_shape), 0))
dist = build_uvcn(generator, shape)
dist = _build_uvcn(generator, shape)
rvs = dist.rvs(random_state=generator, size=rvs_shape)

@@ -79,3 +79,3 @@ assert rvs.shape == shape + rvs_shape

rvs_axes2 = tuple(range(-len(rvs_shape) - 2, -2))
dist = build_mvcn(generator, shape, dimensions)
dist = _build_mvcn(generator, shape, dimensions)
rvs = dist.rvs(random_state=generator, size=rvs_shape)

@@ -99,3 +99,3 @@ assert rvs.shape == shape + rvs_shape + (dimensions,)

def test_univariate_multivariate_consistency(generator: Generator) -> None:
mv = build_mvcn(generator, (), 1, polarization=0.5)
mv = _build_mvcn(generator, (), 1, polarization=0.5)
component = mv.access_object(())

@@ -106,3 +106,3 @@ mean: NumpyComplexArray = np.asarray(component.mean[0])

uv = ScipyComplexNormal(mean, variance, pseudo_variance)
x = random_complex_array(generator)
x = _random_complex_array(generator)
assert_allclose(mv.pdf(x[np.newaxis]), uv.pdf(x))
from __future__ import annotations
from typing import Any
import jax.numpy as jnp

@@ -17,3 +15,3 @@ from jax import grad, vmap

def test_conjugate_prior(generator: Generator,
cp_distribution_info: DistributionInfo[Any, Any, Any],
cp_distribution_info: DistributionInfo,
distribution_name: str | None) -> None:

@@ -50,3 +48,3 @@ """Test that the conjugate prior actually matches the distribution."""

def test_from_conjugate_prior(generator: Generator,
cp_distribution_info: DistributionInfo[Any, Any, Any],
cp_distribution_info: DistributionInfo,
distribution_name: str | None) -> None:

@@ -71,3 +69,3 @@ """Test that the conjugate prior is reversible."""

def test_generalized_conjugate_prior(generator: Generator,
gcp_distribution_info: DistributionInfo[Any, Any, Any],
gcp_distribution_info: DistributionInfo,
distribution_name: str | None

@@ -74,0 +72,0 @@ ) -> None:

@@ -1,2 +0,2 @@

"""These tests are related to entropy."""
"""These tests verify entropy gradients."""
from __future__ import annotations

@@ -23,3 +23,3 @@

@jit
def sum_entropy(flattened: JaxRealArray, flattener: Flattener[Any]) -> JaxRealArray:
def _sum_entropy(flattened: JaxRealArray, flattener: Flattener) -> JaxRealArray:
x = flattener.unflatten(flattened)

@@ -30,11 +30,11 @@ return jnp.sum(x.entropy())

@jit
def all_finite(some_tree: Any, /) -> JaxBooleanArray:
def _all_finite(some_tree: Any, /) -> JaxBooleanArray:
return dynamic_tree_all(tree.map(lambda x: jnp.all(jnp.isfinite(x)), some_tree))
def check_entropy_gradient(distribution: HasEntropy, /) -> None:
def _check_entropy_gradient(distribution: HasEntropy, /) -> None:
flattener, flattened = Flattener.flatten(distribution, map_to_plane=False)
p_sum_entropy = partial(sum_entropy, flattener=flattener)
p_sum_entropy = partial(_sum_entropy, flattener=flattener)
calculated_gradient = grad(p_sum_entropy)(flattened)
if not all_finite(calculated_gradient):
if not _all_finite(calculated_gradient):
indices = jnp.argwhere(jnp.isnan(calculated_gradient))

@@ -51,15 +51,15 @@ bad_distributions = [distribution[tuple(index)[:-1]] for index in indices]

def test_nat_entropy_gradient(generator: Generator,
entropy_distribution_info: DistributionInfo[Any, Any, Any],
entropy_distribution_info: DistributionInfo,
) -> None:
shape = (7, 13)
nat_parameters = entropy_distribution_info.nat_parameter_generator(generator, shape=shape)
check_entropy_gradient(nat_parameters)
_check_entropy_gradient(nat_parameters)
def test_exp_entropy_gradient(generator: Generator,
entropy_distribution_info: DistributionInfo[Any, Any, Any],
entropy_distribution_info: DistributionInfo,
) -> None:
shape = (7, 13)
exp_parameters = entropy_distribution_info.exp_parameter_generator(generator, shape=shape)
check_entropy_gradient(exp_parameters)
_check_entropy_gradient(exp_parameters)

@@ -66,0 +66,0 @@

from __future__ import annotations
from typing import Any
import jax.numpy as jnp

@@ -57,3 +55,3 @@ import numpy as np

def test_fisher_information_is_convex(generator: Generator,
distribution_info: DistributionInfo[Any, Any, Any]) -> None:
distribution_info: DistributionInfo) -> None:
shape = (3, 2)

@@ -60,0 +58,0 @@ nat_parameters = distribution_info.nat_parameter_generator(generator, shape=shape)

from __future__ import annotations
from typing import Any
import jax.numpy as jnp

@@ -18,3 +16,3 @@ import numpy as np

def test_flatten(generator: Generator,
distribution_info: DistributionInfo[Any, Any, Any],
distribution_info: DistributionInfo,
*,

@@ -21,0 +19,0 @@ natural: bool,

"""These tests apply to only samplable distributions."""
from __future__ import annotations
from typing import Any
import jax.numpy as jnp

@@ -18,3 +16,3 @@ import jax.random as jr

def _sample_using_flattened(flattened_parameters: JaxRealArray,
flattener: Flattener[Any],
flattener: Flattener,
key: KeyArray,

@@ -77,3 +75,3 @@ ) -> JaxArray:

key: KeyArray,
sampling_wc_distribution_info: DistributionInfo[Any, Any, Any],
sampling_wc_distribution_info: DistributionInfo,
*,

@@ -80,0 +78,0 @@ distribution_name: str | None,

from __future__ import annotations
from typing import Any
import numpy as np

@@ -19,6 +17,6 @@ from numpy.linalg import det, inv

def prelude(generator: Generator,
distribution_info_kl: DistributionInfo[Any, Any, Any],
distribution_name: str | None
) -> tuple[ExpectationParametrization[Any], NaturalParametrization[Any, Any],
def _prelude(generator: Generator,
distribution_info_kl: DistributionInfo,
distribution_name: str | None
) -> tuple[ExpectationParametrization, NaturalParametrization,
JaxRealArray]:

@@ -35,3 +33,3 @@ shape = (3, 2)

"""Test the KL divergence."""
x, y, my_kl = prelude(generator, NormalInfo(), distribution_name)
x, y, my_kl = _prelude(generator, NormalInfo(), distribution_name)
assert isinstance(x, NormalEP)

@@ -50,3 +48,3 @@ assert isinstance(y, NormalNP)

"""Test the KL divergence."""
x, y, my_kl = prelude(generator, MultivariateNormalInfo(dimensions=4), distribution_name)
x, y, my_kl = _prelude(generator, MultivariateNormalInfo(dimensions=4), distribution_name)
assert isinstance(x, MultivariateNormalEP)

@@ -69,3 +67,3 @@ assert isinstance(y, MultivariateNormalNP)

"""Test the KL divergence."""
x, y, my_kl = prelude(generator, GammaInfo(), distribution_name)
x, y, my_kl = _prelude(generator, GammaInfo(), distribution_name)
assert isinstance(x, GammaEP)

@@ -72,0 +70,0 @@ assert isinstance(y, GammaNP)

@@ -19,15 +19,15 @@ """These tests apply to only samplable distributions."""

Path: TypeAlias = tuple[str, ...]
_Path: TypeAlias = tuple[str, ...]
def produce_samples(generator: Generator,
key: KeyArray,
sampling_distribution_info: DistributionInfo[NaturalParametrization[Any, Any],
ExpectationParametrization[Any],
Any],
distribution_shape: Shape,
sample_shape: Shape,
*,
natural: bool) -> tuple[ExpectationParametrization[Any],
dict[str, Any] | JaxComplexArray]:
def _produce_samples(generator: Generator,
key: KeyArray,
sampling_distribution_info: DistributionInfo[NaturalParametrization,
ExpectationParametrization,
Any],
distribution_shape: Shape,
sample_shape: Shape,
*,
natural: bool) -> tuple[ExpectationParametrization,
dict[str, Any] | JaxComplexArray]:
sampling_object: Distribution

@@ -52,7 +52,7 @@ if natural:

def verify_sample_shape(distribution_shape: Shape,
sample_shape: Shape,
structure: Structure[ExpectationParametrization[Any]],
flat_map_of_samples: dict[Path, Any]
) -> None:
def _verify_sample_shape(distribution_shape: Shape,
sample_shape: Shape,
structure: Structure[ExpectationParametrization],
flat_map_of_samples: dict[_Path, Any]
) -> None:
ideal_samples_shape = {info.path: (*sample_shape, *distribution_shape,

@@ -66,9 +66,9 @@ *info.type_.domain_support().shape(info.dimensions))

def verify_maximum_likelihood_estimate(
sampling_distribution_info: DistributionInfo[NaturalParametrization[Any, Any],
ExpectationParametrization[Any],
def _verify_maximum_likelihood_estimate(
sampling_distribution_info: DistributionInfo[NaturalParametrization,
ExpectationParametrization,
Any],
sample_shape: Shape,
structure: Structure[ExpectationParametrization[Any]],
exp_parameters: ExpectationParametrization[Any],
structure: Structure[ExpectationParametrization],
exp_parameters: ExpectationParametrization,
samples: dict[str, Any] | JaxComplexArray

@@ -96,3 +96,3 @@ ) -> None:

key: KeyArray,
sampling_distribution_info: DistributionInfo[Any, Any, Any],
sampling_distribution_info: DistributionInfo,
*,

@@ -109,9 +109,9 @@ distribution_name: str | None,

sample_shape = (1024, 64) # The number of samples that are taken to do the estimation.
exp_parameters, samples = produce_samples(generator, key, sampling_distribution_info,
distribution_shape, sample_shape, natural=natural)
exp_parameters, samples = _produce_samples(generator, key, sampling_distribution_info,
distribution_shape, sample_shape, natural=natural)
flat_map_of_samples = flat_dict_of_observations(samples)
structure = Structure.create(exp_parameters)
flat_map_of_samples = flatten_mapping(samples) if isinstance(samples, dict) else {(): samples}
verify_sample_shape(distribution_shape, sample_shape, structure, flat_map_of_samples)
verify_maximum_likelihood_estimate(sampling_distribution_info, sample_shape, structure,
exp_parameters, samples)
_verify_sample_shape(distribution_shape, sample_shape, structure, flat_map_of_samples)
_verify_maximum_likelihood_estimate(sampling_distribution_info, sample_shape, structure,
exp_parameters, samples)
from __future__ import annotations
from typing import Any
import numpy as np
from numpy.random import Generator

@@ -13,3 +10,3 @@

def test_shapes(generator: Generator, distribution_info: DistributionInfo[Any, Any, Any]) -> None:
def test_shapes(generator: Generator, distribution_info: DistributionInfo) -> None:
"""Test that the methods produce the correct shapes."""

@@ -43,7 +40,1 @@ shape = (3, 4)

assert q.pdf(x).shape == shape
def test_types(distribution_info: DistributionInfo[Any, Any, Any]) -> None:
    """Ensure the parameter generator yields a number or ndarray, never a tuple."""
    generated = distribution_info.exp_parameter_generator(np.random.default_rng(), ())
    if isinstance(generated, tuple):
        msg = "This should return a number or an ndarray"
        raise TypeError(msg)
from __future__ import annotations
from collections.abc import Callable
from typing import Any, TypeAlias
import pytest
from jax import grad, jvp, vjp
from jax.custom_derivatives import zero_from_primal
from numpy.random import Generator
from numpy.testing import assert_allclose
from tjax import JaxRealArray, assert_tree_allclose, jit, print_generic, tree_allclose
from efax import NaturalParametrization, Structure, parameters
from .create_info import GeneralizedDirichletInfo
from .distribution_info import DistributionInfo
LogNormalizer: TypeAlias = Callable[[NaturalParametrization[Any, Any]], JaxRealArray]
def test_conversion(generator: Generator,
                    distribution_info: DistributionInfo[Any, Any, Any]
                    ) -> None:
    """Verify that natural -> expectation -> natural round-trips consistently."""
    if isinstance(distribution_info, GeneralizedDirichletInfo):
        pytest.skip()
    count = 30
    nat_start = distribution_info.nat_parameter_generator(generator, shape=(count,))
    exp_mid = nat_start.to_exp()
    nat_end = exp_mid.to_nat()
    # Round trip: the final natural parameters must match the originals.
    if not tree_allclose(nat_end, nat_start):
        # Print every element that disagrees before failing, to aid debugging.
        for index in range(count):
            if tree_allclose(nat_end[index], nat_start[index]):
                continue
            print_generic({"original": nat_start[index],
                           "intermediate": exp_mid[index],
                           "final": nat_end[index]})
        pytest.fail("Conversion failure")
    # Fixed (non-variable) parameters must survive both conversions untouched.
    fixed_start = parameters(nat_start, fixed=True)
    assert_tree_allclose(fixed_start, parameters(exp_mid, fixed=True))
    assert_tree_allclose(fixed_start, parameters(nat_end, fixed=True))
def prelude(generator: Generator,
            distribution_info: DistributionInfo[Any, Any, Any]
            ) -> tuple[LogNormalizer, LogNormalizer]:
    """Return the reference and optimized log-normalizers of the natural class."""
    nat_cls = distribution_info.nat_class()
    return nat_cls._original_log_normalizer, nat_cls.log_normalizer
def test_gradient_log_normalizer_primals(generator: Generator,
                                         distribution_info: DistributionInfo[Any, Any, Any]
                                         ) -> None:
    """Tests that the gradient log-normalizer equals the gradient of the log-normalizer."""
    reference_ln, fast_ln = prelude(generator, distribution_info)
    gradient_fns = [jit(grad(reference_ln, allow_int=True)),
                    jit(grad(fast_ln, allow_int=True))]
    for _ in range(20):
        nat = distribution_info.nat_parameter_generator(generator, shape=())
        exp = nat.to_exp()  # Regular transformation.
        expected = parameters(exp, fixed=False)
        structure = Structure.create(exp)
        for gradient_fn in gradient_fns:
            # grad of the log-normalizer in natural coordinates should be the
            # expectation parameters; reinterpret and compare the primals.
            gradient_as_exp = structure.reinterpret(gradient_fn(nat))
            assert_tree_allclose(expected,
                                 parameters(gradient_as_exp, fixed=False),
                                 rtol=1e-5)
def unit_tangent(nat_parameters: NaturalParametrization[Any, Any]
                 ) -> NaturalParametrization[Any, Any]:
    """Build a tangent with ones for variable parameters and zeros for fixed ones."""
    xp = nat_parameters.array_namespace()
    tangent_values = {path: xp.ones_like(value)
                      for path, value in parameters(nat_parameters, fixed=False).items()}
    for path, value in parameters(nat_parameters, fixed=True).items():
        # Fixed parameters carry no perturbation; use a concrete (non-symbolic) zero.
        tangent_values[path] = zero_from_primal(value, symbolic_zeros=False)
    return Structure.create(nat_parameters).assemble(tangent_values)
def test_gradient_log_normalizer_jvp(generator: Generator,
                                     distribution_info: DistributionInfo[Any, Any, Any]
                                     ) -> None:
    """Tests that the gradient log-normalizer equals the gradient of the log-normalizer."""
    reference_ln, fast_ln = prelude(generator, distribution_info)
    for _ in range(20):
        point = distribution_info.nat_parameter_generator(generator, shape=())
        # Forward-mode: both implementations must agree on primal and tangent out.
        tangent = unit_tangent(point)
        reference_primal, reference_tangent_out = jvp(reference_ln, (point,), (tangent,))
        fast_primal, fast_tangent_out = jvp(fast_ln, (point,), (tangent,))
        assert_allclose(reference_primal, fast_primal, rtol=1.5e-5)
        assert_allclose(reference_tangent_out, fast_tangent_out, rtol=1.5e-5)
def test_gradient_log_normalizer_vjp(generator: Generator,
                                     distribution_info: DistributionInfo[Any, Any, Any]
                                     ) -> None:
    """Tests that the gradient log-normalizer equals the gradient of the log-normalizer."""
    reference_ln, fast_ln = prelude(generator, distribution_info)
    for _ in range(20):
        point = distribution_info.nat_parameter_generator(generator, shape=())
        tangent = unit_tangent(point)
        # Forward-mode primal, used below to cross-check the reverse-mode primal.
        jvp_primal, _ = jvp(reference_ln, (point,), (tangent,))
        reference_primal, reference_pullback = vjp(reference_ln, point)
        reference_cotangent, = reference_pullback(1.0)
        fast_primal, fast_pullback = vjp(fast_ln, point)
        fast_cotangent, = fast_pullback(1.0)
        assert_allclose(reference_primal, fast_primal, rtol=1e-5)
        assert_allclose(jvp_primal, reference_primal, rtol=1e-5)
        assert_tree_allclose(parameters(reference_cotangent, fixed=False),
                             parameters(fast_cotangent, fixed=False),
                             rtol=1e-5)
"""These tests ensure that our distributions match scipy's."""
from __future__ import annotations
from functools import partial
from typing import Any
import numpy as np
from jax import Array
from numpy.random import Generator
from numpy.testing import assert_allclose
from tjax import JaxComplexArray, assert_tree_allclose
from efax import (HasEntropyEP, HasEntropyNP, JointDistributionN, MaximumLikelihoodEstimator,
Multidimensional, NaturalParametrization, SimpleDistribution,
flat_dict_of_observations, flat_dict_of_parameters, parameter_map,
unflatten_mapping)
from .create_info import (ChiInfo, ChiSquareInfo, ComplexCircularlySymmetricNormalInfo,
IsotropicNormalInfo, MultivariateDiagonalNormalInfo,
MultivariateFixedVarianceNormalInfo, MultivariateNormalInfo,
VonMisesFisherInfo)
from .distribution_info import DistributionInfo
def test_nat_entropy(generator: Generator,
                     entropy_distribution_info: DistributionInfo[Any, Any, Any]
                     ) -> None:
    """Test that the entropy calculation matches scipy's."""
    nat_parameters = entropy_distribution_info.nat_parameter_generator(generator,
                                                                       shape=(7, 13))
    assert isinstance(nat_parameters, HasEntropyNP)
    reference = entropy_distribution_info.nat_to_scipy_distribution(nat_parameters)
    assert_allclose(nat_parameters.entropy(), reference.entropy(), rtol=2e-5)
def test_exp_entropy(generator: Generator,
                     entropy_distribution_info: DistributionInfo[Any, Any, Any]
                     ) -> None:
    """Check that entropy in the expectation parametrization agrees with scipy's."""
    info = entropy_distribution_info
    expectation = info.exp_parameter_generator(generator, shape=(7, 13))
    assert isinstance(expectation, HasEntropyEP)
    reference = info.exp_to_scipy_distribution(expectation)
    # The chi and chi-square entropies are slightly less accurate, so loosen rtol for them.
    if isinstance(entropy_distribution_info, ChiInfo | ChiSquareInfo):
        rtol = 1e-5
    else:
        rtol = 1e-6
    assert_allclose(expectation.entropy(), reference.entropy(), rtol=rtol)
def check_observation_shape(nat_parameters: NaturalParametrization[Any, Any],
                            efax_x: JaxComplexArray | dict[str, Any],
                            distribution_shape: tuple[int, ...],
                            ) -> None:
    """Verify that the sufficient statistics have the right shape."""
    if isinstance(nat_parameters, JointDistributionN):
        # A joint distribution carries a mapping of observations: recurse per component.
        assert isinstance(efax_x, dict)
        for name, sub_distribution in nat_parameters.sub_distributions().items():
            check_observation_shape(sub_distribution, efax_x[name], distribution_shape)
        return
    # Leaf case: a simple distribution paired with an array observation.
    assert isinstance(nat_parameters, SimpleDistribution)  # type: ignore[unreachable]
    assert isinstance(efax_x, Array)  # type: ignore[unreachable]
    if isinstance(nat_parameters, Multidimensional):
        dimensions = nat_parameters.dimensions()
    else:
        dimensions = 0
    expected_shape = distribution_shape + nat_parameters.domain_support().shape(dimensions)
    assert efax_x.shape == expected_shape
def test_pdf(generator: Generator, distribution_info: DistributionInfo[Any, Any, Any]) -> None:
    """Test that the density/mass function calculation matches scipy's."""
    distribution_shape = (10,)
    nat_parameters = distribution_info.nat_parameter_generator(generator, shape=distribution_shape)
    scipy_distribution = distribution_info.nat_to_scipy_distribution(nat_parameters)
    scipy_x = scipy_distribution.rvs(random_state=generator)
    efax_x = distribution_info.scipy_to_exp_family_observation(scipy_x)
    check_observation_shape(nat_parameters, efax_x, distribution_shape)
    # Verify that the density matches scipy.  Discrete scipy distributions expose pmf
    # rather than pdf, hence the AttributeError fallback.
    efax_density = np.asarray(nat_parameters.pdf(efax_x), dtype=np.float64)
    try:
        scipy_density = scipy_distribution.pdf(scipy_x)
    except AttributeError:
        scipy_density = scipy_distribution.pmf(scipy_x)
    # The absolute tolerance is the same for every distribution; only the relative
    # tolerance is loosened for the multivariate diagonal normal.  (The original
    # if/else duplicated atol = 1e-5 in both branches.)
    atol = 1e-5
    rtol = 3e-4 if isinstance(distribution_info, MultivariateDiagonalNormalInfo) else 1e-4
    assert_allclose(efax_density, scipy_density, rtol=rtol, atol=atol)
def test_maximum_likelihood_estimation(
        generator: Generator,
        distribution_info: DistributionInfo[NaturalParametrization[Any, Any], Any, Any]
        ) -> None:
    """Test maximum likelihood estimation using SciPy.

    Test that maximum likelihood estimation from scipy-generated variates produce the same
    distribution from which they were drawn.
    """
    sample_count = 70000
    rtol = 2e-2
    # Looser absolute tolerances for the distributions whose estimates converge slowly.
    if isinstance(distribution_info, (ComplexCircularlySymmetricNormalInfo,
                                      MultivariateNormalInfo,
                                      VonMisesFisherInfo,
                                      MultivariateFixedVarianceNormalInfo)):
        atol = 1e-2
    elif isinstance(distribution_info, IsotropicNormalInfo):
        atol = 1e-3
    else:
        atol = 1e-6
    # Generate a distribution with expectation parameters.
    expectation = distribution_info.exp_parameter_generator(generator, shape=())
    # Draw variates from the corresponding scipy distribution.
    scipy_distribution = distribution_info.exp_to_scipy_distribution(expectation)
    variates = scipy_distribution.rvs(random_state=generator, size=sample_count)
    # Convert the variates to observations, clamped into each parameter's domain support.
    observations = distribution_info.scipy_to_exp_family_observation(variates)
    flat_observations = flat_dict_of_observations(observations)
    flat_expectation = flat_dict_of_parameters(expectation)
    clamped = {path: flat_expectation[path].domain_support().clamp(value)
               for path, value in flat_observations.items()}
    clamped_observations: Array | dict[str, Any] = (
            clamped[()] if clamped.keys() == {()} else unflatten_mapping(clamped))
    estimator = MaximumLikelihoodEstimator.create_estimator(expectation)
    sufficient_stats = estimator.sufficient_statistics(clamped_observations)
    # The mean of the sufficient statistics should recover the expectation parameters.
    estimated = parameter_map(partial(np.mean, axis=0),  # type: ignore[arg-type]
                              sufficient_stats)
    assert_tree_allclose(expectation, estimated, rtol=rtol, atol=atol)

Sorry, the diff of this file is too big to display