You're Invited: Meet the Socket Team at RSAC and BSidesSF 2026, March 23–26. RSVP
Socket
Book a DemoSign in
Socket

efax

Package Overview
Dependencies
Maintainers
1
Versions
115
Alerts
File Explorer

Advanced tools

Socket logo

Install Socket

Detect and block malicious and high-risk dependencies

Install

efax - pypi Package Compare versions

Comparing version
1.21.1
to
1.21.2
+2
examples/.editorconfig
[*.py]
max_line_length = 80
import jax._src.xla_bridge as xb # noqa: PLC2701
import efax # noqa: F401
def jax_is_initialized() -> bool:
return bool(xb._backends) # noqa: SLF001 # pyright: ignore
def test_jax_not_initialized() -> None:
assert not jax_is_initialized()
+17
-12

@@ -7,3 +7,2 @@ from __future__ import annotations

import jax.numpy as jnp
import jax.scipy.special as jss

@@ -13,3 +12,4 @@ import numpy as np

from numpy.random import Generator
from tjax import JaxArray, JaxComplexArray, JaxRealArray, Shape, inverse_softplus, softplus
from tjax import (JaxArray, JaxComplexArray, JaxRealArray, RealNumeric, Shape, inverse_softplus,
softplus)
from typing_extensions import override

@@ -20,3 +20,3 @@

def _fix_bound(bound: JaxArray | float | None, x: JaxArray) -> JaxArray | None:
def _fix_bound(bound: RealNumeric | None, x: JaxArray) -> JaxArray | None:
xp = array_namespace(x)

@@ -55,3 +55,3 @@ if bound is None:

def general_array_namespace(x: JaxRealArray | float) -> ModuleType:
def general_array_namespace(x: RealNumeric) -> ModuleType:
if isinstance(x, float):

@@ -62,6 +62,11 @@ return np

def canonical_float_epsilon(xp: ModuleType) -> float:
dtype = xp.empty((), dtype=float).dtype # For Jax, this is canonicalize_dtype(float).
return float(xp.finfo(dtype).eps)
@dataclass
class RealField(Ring):
minimum: float | JaxRealArray | None = None
maximum: float | JaxRealArray | None = None
minimum: RealNumeric | None = None
maximum: RealNumeric | None = None
generation_scale: float = 1.0 # Scale the generated random numbers to improve random testing.

@@ -72,7 +77,6 @@ min_open: bool = True # Open interval

def __post_init__(self) -> None:
dtype = jnp.empty((), dtype=float).dtype # This is canonicalize_dtype(float).
eps = float(np.finfo(dtype).eps)
if self.min_open and self.minimum is not None:
xp = general_array_namespace(self.minimum)
self.minimum = jnp.asarray(jnp.maximum(
eps = canonical_float_epsilon(xp)
self.minimum = xp.asarray(xp.maximum(
self.minimum + eps,

@@ -82,3 +86,4 @@ self.minimum * (1.0 + xp.copysign(eps, self.minimum))))

xp = general_array_namespace(self.maximum)
self.maximum = jnp.asarray(jnp.minimum(
eps = canonical_float_epsilon(xp)
self.maximum = xp.asarray(xp.minimum(
self.maximum - eps,

@@ -148,4 +153,4 @@ self.maximum * (1.0 + xp.copysign(eps, -self.maximum))))

class ComplexField(Ring):
minimum_modulus: float | JaxRealArray = 0.0
maximum_modulus: float | JaxRealArray | None = None
minimum_modulus: RealNumeric = 0.0
maximum_modulus: RealNumeric | None = None

@@ -152,0 +157,0 @@ @override

"""Bayesian evidence combination.
This example is based on section 1.2.1 from expfam.pdf, entitled Bayesian evidence combination.
This example is based on section 1.2.1 from expfam.pdf, entitled Bayesian
evidence combination.
Suppose you have a prior, and a set of likelihoods, and you want to combine all of the evidence
into one distribution.
Suppose you have a prior, and a set of likelihoods, and you want to combine all
of the evidence into one distribution.
"""

@@ -24,4 +25,4 @@ from operator import add

# Sum. We use parameter_map to ensure that we don't accidentally add "fixed" parameters, e.g., the
# failure count of a negative binomial distribution.
# Sum. We use parameter_map to ensure that we don't accidentally add "fixed"
# parameters, e.g., the failure count of a negative binomial distribution.
posterior_np = parameter_map(add, prior_np, likelihood_np)

@@ -31,7 +32,19 @@

posterior = posterior_np.to_variance_parametrization()
print_generic(posterior)
# MultivariateDiagonalNormalVP[dataclass]
print_generic(prior=prior,
likelihood=likelihood,
posterior=posterior)
# likelihood=MultivariateDiagonalNormalVP[dataclass]
# ├── mean=Jax Array (2,) float32
# │ └── 0.0355 │ -0.2000
# │ └── 1.1000 │ -2.2000
# └── variance=Jax Array (2,) float32
# └── 0.0968 │ 0.0909
# └── 3.0000 │ 1.0000
# posterior=MultivariateDiagonalNormalVP[dataclass]
# ├── mean=Jax Array (2,) float32
# │ └── 0.8462 │ -2.0000
# └── variance=Jax Array (2,) float32
# └── 2.3077 │ 0.9091
# prior=MultivariateDiagonalNormalVP[dataclass]
# ├── mean=Jax Array (2,) float32
# │ └── 0.0000 │ 0.0000
# └── variance=Jax Array (2,) float32
# └── 10.0000 │ 10.0000
"""Cross-entropy.
This example is based on section 1.4.1 from expfam.pdf, entitled Information theoretic statistics.
This example is based on section 1.4.1 from expfam.pdf, entitled Information
theoretic statistics.
"""

@@ -10,8 +11,8 @@ import jax.numpy as jnp

# p is the expectation parameters of three Bernoulli distributions having probabilities 0.4, 0.5,
# and 0.6.
# p is the expectation parameters of three Bernoulli distributions having
# probabilities 0.4, 0.5, and 0.6.
p = BernoulliEP(jnp.asarray([0.4, 0.5, 0.6]))
# q is the natural parameters of three Bernoulli distributions having log-odds 0, which is
# probability 0.5.
# q is the natural parameters of three Bernoulli distributions having log-odds
# 0, which is probability 0.5.
q = BernoulliNP(jnp.zeros(3))

@@ -23,10 +24,11 @@

# q2 is natural parameters of Bernoulli distributions having a probability of 0.3.
# q2 is natural parameters of Bernoulli distributions having a probability of
# 0.3.
p2 = BernoulliEP(0.3 * jnp.ones(3))
q2 = p2.to_nat()
# A Bernoulli distribution with probability 0.3 predicts a Bernoulli observation
# with probability 0.4 better than the other observations.
print_generic(p.cross_entropy(q2))
# Jax Array (3,) float32
# └── 0.6956 │ 0.7803 │ 0.8651
# A Bernoulli distribution with probability 0.3 predicts a Bernoulli observation with probability
# 0.4 better than the other observations.
"""Maximum likelihood estimation.
This example is based on section 1.3.2 from expfam.pdf, entitled Maximum likelihood estimation.
This example is based on section 1.3.2 from expfam.pdf, entitled Maximum
likelihood estimation.
Suppose you have some samples from a distribution family with unknown parameters, and you want to
estimate the maximum likelihood parameters of the distribution.
Suppose you have some samples from a distribution family with unknown
parameters, and you want to estimate the maximum likelihood parameters of the
distribution.
"""

@@ -12,2 +14,3 @@ import jax.numpy as jnp

from efax import DirichletEP, DirichletNP, MaximumLikelihoodEstimator, parameter_mean
from tjax import print_generic

@@ -27,3 +30,4 @@ # Consider a Dirichlet distribution with a given alpha.

ss = estimator.sufficient_statistics(samples)
# ss has type DirichletEP. This is similar to the conjugate prior of the Dirichlet distribution.
# ss has type DirichletEP. This is similar to the conjugate prior of the
# Dirichlet distribution.

@@ -35,2 +39,9 @@ # Take the mean over the first axis.

estimated_distribution = ss_mean.to_nat()
print(estimated_distribution.alpha_minus_one + 1.0) # [1.9849904 3.0065458 3.963935 ] # noqa: T201
print_generic(estimated_distribution=estimated_distribution,
source_distribution=source_distribution)
# estimated_distribution=DirichletNP[dataclass]
# └── alpha_minus_one=Jax Array (3,) float32
# └── 0.9797 │ 1.9539 │ 2.9763
# source_distribution=DirichletNP[dataclass]
# └── alpha_minus_one=Jax Array (3,) float32
# └── 1.0000 │ 2.0000 │ 3.0000
"""Optimization.
This example illustrates how this library fits in a typical machine learning context. Suppose we
have an unknown target value, and a loss function based on the cross-entropy between the target
value and a predictive distribution. We will optimize the predictive distribution by a small
fraction of its cotangent.
This example illustrates how this library fits in a typical machine learning
context. Suppose we have an unknown target value, and a loss function based on
the cross-entropy between the target value and a predictive distribution. We
will optimize the predictive distribution by a small fraction of its cotangent.
"""

@@ -37,28 +37,34 @@ import jax.numpy as jnp

# The target_distribution is represented as the expectation parameters of a Bernoulli distribution
# corresponding to probabilities 0.3, 0.4, and 0.7.
# The target_distribution is represented as the expectation parameters of a
# Bernoulli distribution corresponding to probabilities 0.3, 0.4, and 0.7.
target_distribution = BernoulliEP(jnp.asarray([0.3, 0.4, 0.7]))
# The initial predictive distribution is represented as the natural parameters of a Bernoulli
# distribution corresponding to log-odds 0, which is probability 0.5.
# The initial predictive distribution is represented as the natural parameters
# of a Bernoulli distribution corresponding to log-odds 0, which is probability
# 0.5.
initial_predictive_distribution = BernoulliNP(jnp.zeros(3))
# Optimize the predictive distribution iteratively, and output the natural parameters of the
# prediction.
predictive_distribution = lax.while_loop(cond_fun, body_fun, initial_predictive_distribution)
print_generic(predictive_distribution)
# BernoulliNP
# Optimize the predictive distribution iteratively.
predictive_distribution = lax.while_loop(cond_fun, body_fun,
initial_predictive_distribution)
# Compare the optimized predictive distribution with the target value in the
# same natural parametrization.
print_generic(predictive_distribution=predictive_distribution,
target_distribution=target_distribution.to_nat())
# predictive_distribution=BernoulliNP[dataclass]
# └── log_odds=Jax Array (3,) float32
# └── -0.8440 │ -0.4047 │ 0.8440
# Compare the optimized predictive distribution with the target value in the same parametrization.
print_generic(target_distribution.to_nat())
# BernoulliNP
# target_distribution=BernoulliNP[dataclass]
# └── log_odds=Jax Array (3,) float32
# └── -0.8473 │ -0.4055 │ 0.8473
# Print the optimized natural parameters as expectation parameters.
print_generic(predictive_distribution.to_exp())
# BernoulliEP
# Do the same in the expectation parametrization.
print_generic(predictive_distribution=predictive_distribution.to_exp(),
target_distribution=target_distribution)
# predictive_distribution=BernoulliEP[dataclass]
# └── probability=Jax Array (3,) float32
# └── 0.3007 │ 0.4002 │ 0.6993
# target_distribution=BernoulliEP[dataclass]
# └── probability=Jax Array (3,) float32
# └── 0.3000 │ 0.4000 │ 0.7000
+129
-39
Metadata-Version: 2.4
Name: efax
Version: 1.21.1
Version: 1.21.2
Summary: Exponential families for JAX

@@ -275,28 +275,88 @@ Project-URL: repository, https://github.com/NeilGirdhar/efax

from __future__ import annotations
"""Cross-entropy.
This example is based on section 1.4.1 from expfam.pdf, entitled Information
theoretic statistics.
"""
import jax.numpy as jnp
from tjax import print_generic
from efax import BernoulliEP, BernoulliNP
# p is the expectation parameters of three Bernoulli distributions having probabilities 0.4, 0.5,
# and 0.6.
# p is the expectation parameters of three Bernoulli distributions having
# probabilities 0.4, 0.5, and 0.6.
p = BernoulliEP(jnp.asarray([0.4, 0.5, 0.6]))
# q is the natural parameters of three Bernoulli distributions having log-odds 0, which is
# probability 0.5.
# q is the natural parameters of three Bernoulli distributions having log-odds
# 0, which is probability 0.5.
q = BernoulliNP(jnp.zeros(3))
print(p.cross_entropy(q)) # noqa: T201
# [0.6931472 0.6931472 0.6931472]
print_generic(p.cross_entropy(q))
# Jax Array (3,) float32
# └── 0.6931 │ 0.6931 │ 0.6931
# q2 is natural parameters of Bernoulli distributions having a probability of 0.3.
# q2 is natural parameters of Bernoulli distributions having a probability of
# 0.3.
p2 = BernoulliEP(0.3 * jnp.ones(3))
q2 = p2.to_nat()
print(p.cross_entropy(q2)) # noqa: T201
# [0.6955941 0.78032386 0.86505365]
# A Bernoulli distribution with probability 0.3 predicts a Bernoulli observation with probability
# 0.4 better than the other observations.
# A Bernoulli distribution with probability 0.3 predicts a Bernoulli observation
# with probability 0.4 better than the other observations.
print_generic(p.cross_entropy(q2))
# Jax Array (3,) float32
# └── 0.6956 │ 0.7803 │ 0.8651
Evidence combination:
.. code:: python
"""Bayesian evidence combination.
This example is based on section 1.2.1 from expfam.pdf, entitled Bayesian
evidence combination.
Suppose you have a prior, and a set of likelihoods, and you want to combine all
of the evidence into one distribution.
"""
from operator import add
import jax.numpy as jnp
from tjax import print_generic
from efax import MultivariateDiagonalNormalVP, parameter_map
prior = MultivariateDiagonalNormalVP(mean=jnp.zeros(2),
variance=10 * jnp.ones(2))
likelihood = MultivariateDiagonalNormalVP(mean=jnp.asarray([1.1, -2.2]),
variance=jnp.asarray([3.0, 1.0]))
# Convert to the natural parametrization.
prior_np = prior.to_nat()
likelihood_np = likelihood.to_nat()
# Sum. We use parameter_map to ensure that we don't accidentally add "fixed"
# parameters, e.g., the failure count of a negative binomial distribution.
posterior_np = parameter_map(add, prior_np, likelihood_np)
# Convert to the source parametrization.
posterior = posterior_np.to_variance_parametrization()
print_generic(prior=prior,
likelihood=likelihood,
posterior=posterior)
# likelihood=MultivariateDiagonalNormalVP[dataclass]
# ├── mean=Jax Array (2,) float32
# │ └── 1.1000 │ -2.2000
# └── variance=Jax Array (2,) float32
# └── 3.0000 │ 1.0000
# posterior=MultivariateDiagonalNormalVP[dataclass]
# ├── mean=Jax Array (2,) float32
# │ └── 0.8462 │ -2.0000
# └── variance=Jax Array (2,) float32
# └── 2.3077 │ 0.9091
# prior=MultivariateDiagonalNormalVP[dataclass]
# ├── mean=Jax Array (2,) float32
# │ └── 0.0000 │ 0.0000
# └── variance=Jax Array (2,) float32
# └── 10.0000 │ 10.0000
Optimization

@@ -308,4 +368,9 @@ ------------

from __future__ import annotations
"""Optimization.
This example illustrates how this library fits in a typical machine learning
context. Suppose we have an unknown target value, and a loss function based on
the cross-entropy between the target value and a predictive distribution. We
will optimize the predictive distribution by a small fraction of its cotangent.
"""
import jax.numpy as jnp

@@ -322,3 +387,3 @@ from jax import grad, lax

gce = jit(grad(cross_entropy_loss, 1))
gradient_cross_entropy = jit(grad(cross_entropy_loss, 1))

@@ -331,3 +396,3 @@

def body_fun(q: BernoulliNP) -> BernoulliNP:
q_bar = gce(some_p, q)
q_bar = gradient_cross_entropy(target_distribution, q)
return parameter_map(apply, q, q_bar)

@@ -337,3 +402,3 @@

def cond_fun(q: BernoulliNP) -> JaxBooleanArray:
q_bar = gce(some_p, q)
q_bar = gradient_cross_entropy(target_distribution, q)
total = jnp.sum(parameter_dot_product(q_bar, q_bar))

@@ -343,29 +408,35 @@ return total > 1e-6 # noqa: PLR2004

# some_p are expectation parameters of a Bernoulli distribution corresponding
# to probabilities 0.3, 0.4, and 0.7.
some_p = BernoulliEP(jnp.asarray([0.3, 0.4, 0.7]))
# The target_distribution is represented as the expectation parameters of a
# Bernoulli distribution corresponding to probabilities 0.3, 0.4, and 0.7.
target_distribution = BernoulliEP(jnp.asarray([0.3, 0.4, 0.7]))
# some_q are natural parameters of a Bernoulli distribution corresponding to
# log-odds 0, which is probability 0.5.
some_q = BernoulliNP(jnp.zeros(3))
# The initial predictive distribution is represented as the natural parameters
# of a Bernoulli distribution corresponding to log-odds 0, which is probability
# 0.5.
initial_predictive_distribution = BernoulliNP(jnp.zeros(3))
# Optimize the predictive distribution iteratively, and output the natural parameters of the
# prediction.
optimized_q = lax.while_loop(cond_fun, body_fun, some_q)
print_generic(optimized_q)
# BernoulliNP
# Optimize the predictive distribution iteratively.
predictive_distribution = lax.while_loop(cond_fun, body_fun,
initial_predictive_distribution)
# Compare the optimized predictive distribution with the target value in the
# same natural parametrization.
print_generic(predictive_distribution=predictive_distribution,
target_distribution=target_distribution.to_nat())
# predictive_distribution=BernoulliNP[dataclass]
# └── log_odds=Jax Array (3,) float32
# └── -0.8440 │ -0.4047 │ 0.8440
# Compare with the true value.
print_generic(some_p.to_nat())
# BernoulliNP
# target_distribution=BernoulliNP[dataclass]
# └── log_odds=Jax Array (3,) float32
# └── -0.8473 │ -0.4055 │ 0.8473
# Print optimized natural parameters as expectation parameters.
print_generic(optimized_q.to_exp())
# BernoulliEP
# Do the same in the expectation parametrization.
print_generic(predictive_distribution=predictive_distribution.to_exp(),
target_distribution=target_distribution)
# predictive_distribution=BernoulliEP[dataclass]
# └── probability=Jax Array (3,) float32
# └── 0.3007 │ 0.4002 │ 0.6993
# target_distribution=BernoulliEP[dataclass]
# └── probability=Jax Array (3,) float32
# └── 0.3000 │ 0.4000 │ 0.7000

@@ -380,6 +451,16 @@ Maximum likelihood estimation

"""Maximum likelihood estimation.
This example is based on section 1.3.2 from expfam.pdf, entitled Maximum
likelihood estimation.
Suppose you have some samples from a distribution family with unknown
parameters, and you want to estimate the maximum likelihood parameters of the
distribution.
"""
import jax.numpy as jnp
import jax.random as jr
from efax import DirichletNP, parameter_mean
from efax import DirichletEP, DirichletNP, MaximumLikelihoodEstimator, parameter_mean
from tjax import print_generic

@@ -397,4 +478,6 @@ # Consider a Dirichlet distribution with a given alpha.

# First, convert the samples to their sufficient statistics.
ss = DirichletNP.sufficient_statistics(samples)
# ss has type DirichletEP. This is similar to the conjugate prior of the Dirichlet distribution.
estimator = MaximumLikelihoodEstimator.create_simple_estimator(DirichletEP)
ss = estimator.sufficient_statistics(samples)
# ss has type DirichletEP. This is similar to the conjugate prior of the
# Dirichlet distribution.

@@ -406,3 +489,10 @@ # Take the mean over the first axis.

estimated_distribution = ss_mean.to_nat()
print(estimated_distribution.alpha_minus_one + 1.0) # [1.9849904 3.0065458 3.963935 ] # noqa: T201
print_generic(estimated_distribution=estimated_distribution,
source_distribution=source_distribution)
# estimated_distribution=DirichletNP[dataclass]
# └── alpha_minus_one=Jax Array (3,) float32
# └── 0.9797 │ 1.9539 │ 2.9763
# source_distribution=DirichletNP[dataclass]
# └── alpha_minus_one=Jax Array (3,) float32
# └── 1.0000 │ 2.0000 │ 3.0000

@@ -409,0 +499,0 @@ Contribution guidelines

@@ -7,3 +7,3 @@ [build-system]

name = "efax"
version = "1.21.1"
version = "1.21.2"
description = "Exponential families for JAX"

@@ -10,0 +10,0 @@ readme = "README.rst"

+128
-38

@@ -228,28 +228,88 @@ .. role:: bash(code)

from __future__ import annotations
"""Cross-entropy.
This example is based on section 1.4.1 from expfam.pdf, entitled Information
theoretic statistics.
"""
import jax.numpy as jnp
from tjax import print_generic
from efax import BernoulliEP, BernoulliNP
# p is the expectation parameters of three Bernoulli distributions having probabilities 0.4, 0.5,
# and 0.6.
# p is the expectation parameters of three Bernoulli distributions having
# probabilities 0.4, 0.5, and 0.6.
p = BernoulliEP(jnp.asarray([0.4, 0.5, 0.6]))
# q is the natural parameters of three Bernoulli distributions having log-odds 0, which is
# probability 0.5.
# q is the natural parameters of three Bernoulli distributions having log-odds
# 0, which is probability 0.5.
q = BernoulliNP(jnp.zeros(3))
print(p.cross_entropy(q)) # noqa: T201
# [0.6931472 0.6931472 0.6931472]
print_generic(p.cross_entropy(q))
# Jax Array (3,) float32
# └── 0.6931 │ 0.6931 │ 0.6931
# q2 is natural parameters of Bernoulli distributions having a probability of 0.3.
# q2 is natural parameters of Bernoulli distributions having a probability of
# 0.3.
p2 = BernoulliEP(0.3 * jnp.ones(3))
q2 = p2.to_nat()
print(p.cross_entropy(q2)) # noqa: T201
# [0.6955941 0.78032386 0.86505365]
# A Bernoulli distribution with probability 0.3 predicts a Bernoulli observation with probability
# 0.4 better than the other observations.
# A Bernoulli distribution with probability 0.3 predicts a Bernoulli observation
# with probability 0.4 better than the other observations.
print_generic(p.cross_entropy(q2))
# Jax Array (3,) float32
# └── 0.6956 │ 0.7803 │ 0.8651
Evidence combination:
.. code:: python
"""Bayesian evidence combination.
This example is based on section 1.2.1 from expfam.pdf, entitled Bayesian
evidence combination.
Suppose you have a prior, and a set of likelihoods, and you want to combine all
of the evidence into one distribution.
"""
from operator import add
import jax.numpy as jnp
from tjax import print_generic
from efax import MultivariateDiagonalNormalVP, parameter_map
prior = MultivariateDiagonalNormalVP(mean=jnp.zeros(2),
variance=10 * jnp.ones(2))
likelihood = MultivariateDiagonalNormalVP(mean=jnp.asarray([1.1, -2.2]),
variance=jnp.asarray([3.0, 1.0]))
# Convert to the natural parametrization.
prior_np = prior.to_nat()
likelihood_np = likelihood.to_nat()
# Sum. We use parameter_map to ensure that we don't accidentally add "fixed"
# parameters, e.g., the failure count of a negative binomial distribution.
posterior_np = parameter_map(add, prior_np, likelihood_np)
# Convert to the source parametrization.
posterior = posterior_np.to_variance_parametrization()
print_generic(prior=prior,
likelihood=likelihood,
posterior=posterior)
# likelihood=MultivariateDiagonalNormalVP[dataclass]
# ├── mean=Jax Array (2,) float32
# │ └── 1.1000 │ -2.2000
# └── variance=Jax Array (2,) float32
# └── 3.0000 │ 1.0000
# posterior=MultivariateDiagonalNormalVP[dataclass]
# ├── mean=Jax Array (2,) float32
# │ └── 0.8462 │ -2.0000
# └── variance=Jax Array (2,) float32
# └── 2.3077 │ 0.9091
# prior=MultivariateDiagonalNormalVP[dataclass]
# ├── mean=Jax Array (2,) float32
# │ └── 0.0000 │ 0.0000
# └── variance=Jax Array (2,) float32
# └── 10.0000 │ 10.0000
Optimization

@@ -261,4 +321,9 @@ ------------

from __future__ import annotations
"""Optimization.
This example illustrates how this library fits in a typical machine learning
context. Suppose we have an unknown target value, and a loss function based on
the cross-entropy between the target value and a predictive distribution. We
will optimize the predictive distribution by a small fraction of its cotangent.
"""
import jax.numpy as jnp

@@ -275,3 +340,3 @@ from jax import grad, lax

gce = jit(grad(cross_entropy_loss, 1))
gradient_cross_entropy = jit(grad(cross_entropy_loss, 1))

@@ -284,3 +349,3 @@

def body_fun(q: BernoulliNP) -> BernoulliNP:
q_bar = gce(some_p, q)
q_bar = gradient_cross_entropy(target_distribution, q)
return parameter_map(apply, q, q_bar)

@@ -290,3 +355,3 @@

def cond_fun(q: BernoulliNP) -> JaxBooleanArray:
q_bar = gce(some_p, q)
q_bar = gradient_cross_entropy(target_distribution, q)
total = jnp.sum(parameter_dot_product(q_bar, q_bar))

@@ -296,29 +361,35 @@ return total > 1e-6 # noqa: PLR2004

# some_p are expectation parameters of a Bernoulli distribution corresponding
# to probabilities 0.3, 0.4, and 0.7.
some_p = BernoulliEP(jnp.asarray([0.3, 0.4, 0.7]))
# The target_distribution is represented as the expectation parameters of a
# Bernoulli distribution corresponding to probabilities 0.3, 0.4, and 0.7.
target_distribution = BernoulliEP(jnp.asarray([0.3, 0.4, 0.7]))
# some_q are natural parameters of a Bernoulli distribution corresponding to
# log-odds 0, which is probability 0.5.
some_q = BernoulliNP(jnp.zeros(3))
# The initial predictive distribution is represented as the natural parameters
# of a Bernoulli distribution corresponding to log-odds 0, which is probability
# 0.5.
initial_predictive_distribution = BernoulliNP(jnp.zeros(3))
# Optimize the predictive distribution iteratively, and output the natural parameters of the
# prediction.
optimized_q = lax.while_loop(cond_fun, body_fun, some_q)
print_generic(optimized_q)
# BernoulliNP
# Optimize the predictive distribution iteratively.
predictive_distribution = lax.while_loop(cond_fun, body_fun,
initial_predictive_distribution)
# Compare the optimized predictive distribution with the target value in the
# same natural parametrization.
print_generic(predictive_distribution=predictive_distribution,
target_distribution=target_distribution.to_nat())
# predictive_distribution=BernoulliNP[dataclass]
# └── log_odds=Jax Array (3,) float32
# └── -0.8440 │ -0.4047 │ 0.8440
# Compare with the true value.
print_generic(some_p.to_nat())
# BernoulliNP
# target_distribution=BernoulliNP[dataclass]
# └── log_odds=Jax Array (3,) float32
# └── -0.8473 │ -0.4055 │ 0.8473
# Print optimized natural parameters as expectation parameters.
print_generic(optimized_q.to_exp())
# BernoulliEP
# Do the same in the expectation parametrization.
print_generic(predictive_distribution=predictive_distribution.to_exp(),
target_distribution=target_distribution)
# predictive_distribution=BernoulliEP[dataclass]
# └── probability=Jax Array (3,) float32
# └── 0.3007 │ 0.4002 │ 0.6993
# target_distribution=BernoulliEP[dataclass]
# └── probability=Jax Array (3,) float32
# └── 0.3000 │ 0.4000 │ 0.7000

@@ -333,6 +404,16 @@ Maximum likelihood estimation

"""Maximum likelihood estimation.
This example is based on section 1.3.2 from expfam.pdf, entitled Maximum
likelihood estimation.
Suppose you have some samples from a distribution family with unknown
parameters, and you want to estimate the maximum likelihood parameters of the
distribution.
"""
import jax.numpy as jnp
import jax.random as jr
from efax import DirichletNP, parameter_mean
from efax import DirichletEP, DirichletNP, MaximumLikelihoodEstimator, parameter_mean
from tjax import print_generic

@@ -350,4 +431,6 @@ # Consider a Dirichlet distribution with a given alpha.

# First, convert the samples to their sufficient statistics.
ss = DirichletNP.sufficient_statistics(samples)
# ss has type DirichletEP. This is similar to the conjugate prior of the Dirichlet distribution.
estimator = MaximumLikelihoodEstimator.create_simple_estimator(DirichletEP)
ss = estimator.sufficient_statistics(samples)
# ss has type DirichletEP. This is similar to the conjugate prior of the
# Dirichlet distribution.

@@ -359,3 +442,10 @@ # Take the mean over the first axis.

estimated_distribution = ss_mean.to_nat()
print(estimated_distribution.alpha_minus_one + 1.0) # [1.9849904 3.0065458 3.963935 ] # noqa: T201
print_generic(estimated_distribution=estimated_distribution,
source_distribution=source_distribution)
# estimated_distribution=DirichletNP[dataclass]
# └── alpha_minus_one=Jax Array (3,) float32
# └── 0.9797 │ 1.9539 │ 2.9763
# source_distribution=DirichletNP[dataclass]
# └── alpha_minus_one=Jax Array (3,) float32
# └── 1.0000 │ 2.0000 │ 3.0000

@@ -362,0 +452,0 @@ Contribution guidelines

Sorry, the diff of this file is too big to display