efax
Advanced tools
| [*.py] | ||
| max_line_length = 80 |
| import jax._src.xla_bridge as xb # noqa: PLC2701 | ||
| import efax # noqa: F401 | ||
| def jax_is_initialized() -> bool: | ||
| return bool(xb._backends) # noqa: SLF001 # pyright: ignore | ||
| def test_jax_not_initialized() -> None: | ||
| assert not jax_is_initialized() |
@@ -7,3 +7,2 @@ from __future__ import annotations | ||
| import jax.numpy as jnp | ||
| import jax.scipy.special as jss | ||
@@ -13,3 +12,4 @@ import numpy as np | ||
| from numpy.random import Generator | ||
| from tjax import JaxArray, JaxComplexArray, JaxRealArray, Shape, inverse_softplus, softplus | ||
| from tjax import (JaxArray, JaxComplexArray, JaxRealArray, RealNumeric, Shape, inverse_softplus, | ||
| softplus) | ||
| from typing_extensions import override | ||
@@ -20,3 +20,3 @@ | ||
| def _fix_bound(bound: JaxArray | float | None, x: JaxArray) -> JaxArray | None: | ||
| def _fix_bound(bound: RealNumeric | None, x: JaxArray) -> JaxArray | None: | ||
| xp = array_namespace(x) | ||
@@ -55,3 +55,3 @@ if bound is None: | ||
| def general_array_namespace(x: JaxRealArray | float) -> ModuleType: | ||
| def general_array_namespace(x: RealNumeric) -> ModuleType: | ||
| if isinstance(x, float): | ||
@@ -62,6 +62,11 @@ return np | ||
| def canonical_float_epsilon(xp: ModuleType) -> float: | ||
| dtype = xp.empty((), dtype=float).dtype # For Jax, this is canonicalize_dtype(float). | ||
| return float(xp.finfo(dtype).eps) | ||
| @dataclass | ||
| class RealField(Ring): | ||
| minimum: float | JaxRealArray | None = None | ||
| maximum: float | JaxRealArray | None = None | ||
| minimum: RealNumeric | None = None | ||
| maximum: RealNumeric | None = None | ||
| generation_scale: float = 1.0 # Scale the generated random numbers to improve random testing. | ||
@@ -72,7 +77,6 @@ min_open: bool = True # Open interval | ||
| def __post_init__(self) -> None: | ||
| dtype = jnp.empty((), dtype=float).dtype # This is canonicalize_dtype(float). | ||
| eps = float(np.finfo(dtype).eps) | ||
| if self.min_open and self.minimum is not None: | ||
| xp = general_array_namespace(self.minimum) | ||
| self.minimum = jnp.asarray(jnp.maximum( | ||
| eps = canonical_float_epsilon(xp) | ||
| self.minimum = xp.asarray(xp.maximum( | ||
| self.minimum + eps, | ||
@@ -82,3 +86,4 @@ self.minimum * (1.0 + xp.copysign(eps, self.minimum)))) | ||
| xp = general_array_namespace(self.maximum) | ||
| self.maximum = jnp.asarray(jnp.minimum( | ||
| eps = canonical_float_epsilon(xp) | ||
| self.maximum = xp.asarray(xp.minimum( | ||
| self.maximum - eps, | ||
@@ -148,4 +153,4 @@ self.maximum * (1.0 + xp.copysign(eps, -self.maximum)))) | ||
| class ComplexField(Ring): | ||
| minimum_modulus: float | JaxRealArray = 0.0 | ||
| maximum_modulus: float | JaxRealArray | None = None | ||
| minimum_modulus: RealNumeric = 0.0 | ||
| maximum_modulus: RealNumeric | None = None | ||
@@ -152,0 +157,0 @@ @override |
| """Bayesian evidence combination. | ||
| This example is based on section 1.2.1 from expfam.pdf, entitled Bayesian evidence combination. | ||
| This example is based on section 1.2.1 from expfam.pdf, entitled Bayesian | ||
| evidence combination. | ||
| Suppose you have a prior, and a set of likelihoods, and you want to combine all of the evidence | ||
| into one distribution. | ||
| Suppose you have a prior, and a set of likelihoods, and you want to combine all | ||
| of the evidence into one distribution. | ||
| """ | ||
@@ -24,4 +25,4 @@ from operator import add | ||
| # Sum. We use parameter_map to ensure that we don't accidentally add "fixed" parameters, e.g., the | ||
| # failure count of a negative binomial distribution. | ||
| # Sum. We use parameter_map to ensure that we don't accidentally add "fixed" | ||
| # parameters, e.g., the failure count of a negative binomial distribution. | ||
| posterior_np = parameter_map(add, prior_np, likelihood_np) | ||
@@ -31,7 +32,19 @@ | ||
| posterior = posterior_np.to_variance_parametrization() | ||
| print_generic(posterior) | ||
| # MultivariateDiagonalNormalVP[dataclass] | ||
| print_generic(prior=prior, | ||
| likelihood=likelihood, | ||
| posterior=posterior) | ||
| # likelihood=MultivariateDiagonalNormalVP[dataclass] | ||
| # ├── mean=Jax Array (2,) float32 | ||
| # │ └── 0.0355 │ -0.2000 | ||
| # │ └── 1.1000 │ -2.2000 | ||
| # └── variance=Jax Array (2,) float32 | ||
| # └── 0.0968 │ 0.0909 | ||
| # └── 3.0000 │ 1.0000 | ||
| # posterior=MultivariateDiagonalNormalVP[dataclass] | ||
| # ├── mean=Jax Array (2,) float32 | ||
| # │ └── 0.8462 │ -2.0000 | ||
| # └── variance=Jax Array (2,) float32 | ||
| # └── 2.3077 │ 0.9091 | ||
| # prior=MultivariateDiagonalNormalVP[dataclass] | ||
| # ├── mean=Jax Array (2,) float32 | ||
| # │ └── 0.0000 │ 0.0000 | ||
| # └── variance=Jax Array (2,) float32 | ||
| # └── 10.0000 │ 10.0000 |
| """Cross-entropy. | ||
| This example is based on section 1.4.1 from expfam.pdf, entitled Information theoretic statistics. | ||
| This example is based on section 1.4.1 from expfam.pdf, entitled Information | ||
| theoretic statistics. | ||
| """ | ||
@@ -10,8 +11,8 @@ import jax.numpy as jnp | ||
| # p is the expectation parameters of three Bernoulli distributions having probabilities 0.4, 0.5, | ||
| # and 0.6. | ||
| # p is the expectation parameters of three Bernoulli distributions having | ||
| # probabilities 0.4, 0.5, and 0.6. | ||
| p = BernoulliEP(jnp.asarray([0.4, 0.5, 0.6])) | ||
| # q is the natural parameters of three Bernoulli distributions having log-odds 0, which is | ||
| # probability 0.5. | ||
| # q is the natural parameters of three Bernoulli distributions having log-odds | ||
| # 0, which is probability 0.5. | ||
| q = BernoulliNP(jnp.zeros(3)) | ||
@@ -23,10 +24,11 @@ | ||
| # q2 is natural parameters of Bernoulli distributions having a probability of 0.3. | ||
| # q2 is natural parameters of Bernoulli distributions having a probability of | ||
| # 0.3. | ||
| p2 = BernoulliEP(0.3 * jnp.ones(3)) | ||
| q2 = p2.to_nat() | ||
| # A Bernoulli distribution with probability 0.3 predicts a Bernoulli observation | ||
| # with probability 0.4 better than the other observations. | ||
| print_generic(p.cross_entropy(q2)) | ||
| # Jax Array (3,) float32 | ||
| # └── 0.6956 │ 0.7803 │ 0.8651 | ||
| # A Bernoulli distribution with probability 0.3 predicts a Bernoulli observation with probability | ||
| # 0.4 better than the other observations. |
| """Maximum likelihood estimation. | ||
| This example is based on section 1.3.2 from expfam.pdf, entitled Maximum likelihood estimation. | ||
| This example is based on section 1.3.2 from expfam.pdf, entitled Maximum | ||
| likelihood estimation. | ||
| Suppose you have some samples from a distribution family with unknown parameters, and you want to | ||
| estimate the maximum likelihood parameters of the distribution. | ||
| Suppose you have some samples from a distribution family with unknown | ||
| parameters, and you want to estimate the maximum likelihood parameters of the | ||
| distribution. | ||
| """ | ||
@@ -12,2 +14,3 @@ import jax.numpy as jnp | ||
| from efax import DirichletEP, DirichletNP, MaximumLikelihoodEstimator, parameter_mean | ||
| from tjax import print_generic | ||
@@ -27,3 +30,4 @@ # Consider a Dirichlet distribution with a given alpha. | ||
| ss = estimator.sufficient_statistics(samples) | ||
| # ss has type DirichletEP. This is similar to the conjugate prior of the Dirichlet distribution. | ||
| # ss has type DirichletEP. This is similar to the conjugate prior of the | ||
| # Dirichlet distribution. | ||
@@ -35,2 +39,9 @@ # Take the mean over the first axis. | ||
| estimated_distribution = ss_mean.to_nat() | ||
| print(estimated_distribution.alpha_minus_one + 1.0) # [1.9849904 3.0065458 3.963935 ] # noqa: T201 | ||
| print_generic(estimated_distribution=estimated_distribution, | ||
| source_distribution=source_distribution) | ||
| # estimated_distribution=DirichletNP[dataclass] | ||
| # └── alpha_minus_one=Jax Array (3,) float32 | ||
| # └── 0.9797 │ 1.9539 │ 2.9763 | ||
| # source_distribution=DirichletNP[dataclass] | ||
| # └── alpha_minus_one=Jax Array (3,) float32 | ||
| # └── 1.0000 │ 2.0000 │ 3.0000 |
+26
-20
| """Optimization. | ||
| This example illustrates how this library fits in a typical machine learning context. Suppose we | ||
| have an unknown target value, and a loss function based on the cross-entropy between the target | ||
| value and a predictive distribution. We will optimize the predictive distribution by a small | ||
| fraction of its cotangent. | ||
| This example illustrates how this library fits in a typical machine learning | ||
| context. Suppose we have an unknown target value, and a loss function based on | ||
| the cross-entropy between the target value and a predictive distribution. We | ||
| will optimize the predictive distribution by a small fraction of its cotangent. | ||
| """ | ||
@@ -37,28 +37,34 @@ import jax.numpy as jnp | ||
| # The target_distribution is represented as the expectation parameters of a Bernoulli distribution | ||
| # corresponding to probabilities 0.3, 0.4, and 0.7. | ||
| # The target_distribution is represented as the expectation parameters of a | ||
| # Bernoulli distribution corresponding to probabilities 0.3, 0.4, and 0.7. | ||
| target_distribution = BernoulliEP(jnp.asarray([0.3, 0.4, 0.7])) | ||
| # The initial predictive distribution is represented as the natural parameters of a Bernoulli | ||
| # distribution corresponding to log-odds 0, which is probability 0.5. | ||
| # The initial predictive distribution is represented as the natural parameters | ||
| # of a Bernoulli distribution corresponding to log-odds 0, which is probability | ||
| # 0.5. | ||
| initial_predictive_distribution = BernoulliNP(jnp.zeros(3)) | ||
| # Optimize the predictive distribution iteratively, and output the natural parameters of the | ||
| # prediction. | ||
| predictive_distribution = lax.while_loop(cond_fun, body_fun, initial_predictive_distribution) | ||
| print_generic(predictive_distribution) | ||
| # BernoulliNP | ||
| # Optimize the predictive distribution iteratively. | ||
| predictive_distribution = lax.while_loop(cond_fun, body_fun, | ||
| initial_predictive_distribution) | ||
| # Compare the optimized predictive distribution with the target value in the | ||
| # same natural parametrization. | ||
| print_generic(predictive_distribution=predictive_distribution, | ||
| target_distribution=target_distribution.to_nat()) | ||
| # predictive_distribution=BernoulliNP[dataclass] | ||
| # └── log_odds=Jax Array (3,) float32 | ||
| # └── -0.8440 │ -0.4047 │ 0.8440 | ||
| # Compare the optimized predictive distribution with the target value in the same parametrization. | ||
| print_generic(target_distribution.to_nat()) | ||
| # BernoulliNP | ||
| # target_distribution=BernoulliNP[dataclass] | ||
| # └── log_odds=Jax Array (3,) float32 | ||
| # └── -0.8473 │ -0.4055 │ 0.8473 | ||
| # Print the optimized natural parameters as expectation parameters. | ||
| print_generic(predictive_distribution.to_exp()) | ||
| # BernoulliEP | ||
| # Do the same in the expectation parametrization. | ||
| print_generic(predictive_distribution=predictive_distribution.to_exp(), | ||
| target_distribution=target_distribution) | ||
| # predictive_distribution=BernoulliEP[dataclass] | ||
| # └── probability=Jax Array (3,) float32 | ||
| # └── 0.3007 │ 0.4002 │ 0.6993 | ||
| # target_distribution=BernoulliEP[dataclass] | ||
| # └── probability=Jax Array (3,) float32 | ||
| # └── 0.3000 │ 0.4000 │ 0.7000 |
+129
-39
| Metadata-Version: 2.4 | ||
| Name: efax | ||
| Version: 1.21.1 | ||
| Version: 1.21.2 | ||
| Summary: Exponential families for JAX | ||
@@ -275,28 +275,88 @@ Project-URL: repository, https://github.com/NeilGirdhar/efax | ||
| from __future__ import annotations | ||
| """Cross-entropy. | ||
| This example is based on section 1.4.1 from expfam.pdf, entitled Information | ||
| theoretic statistics. | ||
| """ | ||
| import jax.numpy as jnp | ||
| from tjax import print_generic | ||
| from efax import BernoulliEP, BernoulliNP | ||
| # p is the expectation parameters of three Bernoulli distributions having probabilities 0.4, 0.5, | ||
| # and 0.6. | ||
| # p is the expectation parameters of three Bernoulli distributions having | ||
| # probabilities 0.4, 0.5, and 0.6. | ||
| p = BernoulliEP(jnp.asarray([0.4, 0.5, 0.6])) | ||
| # q is the natural parameters of three Bernoulli distributions having log-odds 0, which is | ||
| # probability 0.5. | ||
| # q is the natural parameters of three Bernoulli distributions having log-odds | ||
| # 0, which is probability 0.5. | ||
| q = BernoulliNP(jnp.zeros(3)) | ||
| print(p.cross_entropy(q)) # noqa: T201 | ||
| # [0.6931472 0.6931472 0.6931472] | ||
| print_generic(p.cross_entropy(q)) | ||
| # Jax Array (3,) float32 | ||
| # └── 0.6931 │ 0.6931 │ 0.6931 | ||
| # q2 is natural parameters of Bernoulli distributions having a probability of 0.3. | ||
| # q2 is natural parameters of Bernoulli distributions having a probability of | ||
| # 0.3. | ||
| p2 = BernoulliEP(0.3 * jnp.ones(3)) | ||
| q2 = p2.to_nat() | ||
| print(p.cross_entropy(q2)) # noqa: T201 | ||
| # [0.6955941 0.78032386 0.86505365] | ||
| # A Bernoulli distribution with probability 0.3 predicts a Bernoulli observation with probability | ||
| # 0.4 better than the other observations. | ||
| # A Bernoulli distribution with probability 0.3 predicts a Bernoulli observation | ||
| # with probability 0.4 better than the other observations. | ||
| print_generic(p.cross_entropy(q2)) | ||
| # Jax Array (3,) float32 | ||
| # └── 0.6956 │ 0.7803 │ 0.8651 | ||
| Evidence combination: | ||
| .. code:: python | ||
| """Bayesian evidence combination. | ||
| This example is based on section 1.2.1 from expfam.pdf, entitled Bayesian | ||
| evidence combination. | ||
| Suppose you have a prior, and a set of likelihoods, and you want to combine all | ||
| of the evidence into one distribution. | ||
| """ | ||
| from operator import add | ||
| import jax.numpy as jnp | ||
| from tjax import print_generic | ||
| from efax import MultivariateDiagonalNormalVP, parameter_map | ||
| prior = MultivariateDiagonalNormalVP(mean=jnp.zeros(2), | ||
| variance=10 * jnp.ones(2)) | ||
| likelihood = MultivariateDiagonalNormalVP(mean=jnp.asarray([1.1, -2.2]), | ||
| variance=jnp.asarray([3.0, 1.0])) | ||
| # Convert to the natural parametrization. | ||
| prior_np = prior.to_nat() | ||
| likelihood_np = likelihood.to_nat() | ||
| # Sum. We use parameter_map to ensure that we don't accidentally add "fixed" | ||
| # parameters, e.g., the failure count of a negative binomial distribution. | ||
| posterior_np = parameter_map(add, prior_np, likelihood_np) | ||
| # Convert to the source parametrization. | ||
| posterior = posterior_np.to_variance_parametrization() | ||
| print_generic(prior=prior, | ||
| likelihood=likelihood, | ||
| posterior=posterior) | ||
| # likelihood=MultivariateDiagonalNormalVP[dataclass] | ||
| # ├── mean=Jax Array (2,) float32 | ||
| # │ └── 1.1000 │ -2.2000 | ||
| # └── variance=Jax Array (2,) float32 | ||
| # └── 3.0000 │ 1.0000 | ||
| # posterior=MultivariateDiagonalNormalVP[dataclass] | ||
| # ├── mean=Jax Array (2,) float32 | ||
| # │ └── 0.8462 │ -2.0000 | ||
| # └── variance=Jax Array (2,) float32 | ||
| # └── 2.3077 │ 0.9091 | ||
| # prior=MultivariateDiagonalNormalVP[dataclass] | ||
| # ├── mean=Jax Array (2,) float32 | ||
| # │ └── 0.0000 │ 0.0000 | ||
| # └── variance=Jax Array (2,) float32 | ||
| # └── 10.0000 │ 10.0000 | ||
| Optimization | ||
@@ -308,4 +368,9 @@ ------------ | ||
| from __future__ import annotations | ||
| """Optimization. | ||
| This example illustrates how this library fits in a typical machine learning | ||
| context. Suppose we have an unknown target value, and a loss function based on | ||
| the cross-entropy between the target value and a predictive distribution. We | ||
| will optimize the predictive distribution by a small fraction of its cotangent. | ||
| """ | ||
| import jax.numpy as jnp | ||
@@ -322,3 +387,3 @@ from jax import grad, lax | ||
| gce = jit(grad(cross_entropy_loss, 1)) | ||
| gradient_cross_entropy = jit(grad(cross_entropy_loss, 1)) | ||
@@ -331,3 +396,3 @@ | ||
| def body_fun(q: BernoulliNP) -> BernoulliNP: | ||
| q_bar = gce(some_p, q) | ||
| q_bar = gradient_cross_entropy(target_distribution, q) | ||
| return parameter_map(apply, q, q_bar) | ||
@@ -337,3 +402,3 @@ | ||
| def cond_fun(q: BernoulliNP) -> JaxBooleanArray: | ||
| q_bar = gce(some_p, q) | ||
| q_bar = gradient_cross_entropy(target_distribution, q) | ||
| total = jnp.sum(parameter_dot_product(q_bar, q_bar)) | ||
@@ -343,29 +408,35 @@ return total > 1e-6 # noqa: PLR2004 | ||
| # some_p are expectation parameters of a Bernoulli distribution corresponding | ||
| # to probabilities 0.3, 0.4, and 0.7. | ||
| some_p = BernoulliEP(jnp.asarray([0.3, 0.4, 0.7])) | ||
| # The target_distribution is represented as the expectation parameters of a | ||
| # Bernoulli distribution corresponding to probabilities 0.3, 0.4, and 0.7. | ||
| target_distribution = BernoulliEP(jnp.asarray([0.3, 0.4, 0.7])) | ||
| # some_q are natural parameters of a Bernoulli distribution corresponding to | ||
| # log-odds 0, which is probability 0.5. | ||
| some_q = BernoulliNP(jnp.zeros(3)) | ||
| # The initial predictive distribution is represented as the natural parameters | ||
| # of a Bernoulli distribution corresponding to log-odds 0, which is probability | ||
| # 0.5. | ||
| initial_predictive_distribution = BernoulliNP(jnp.zeros(3)) | ||
| # Optimize the predictive distribution iteratively, and output the natural parameters of the | ||
| # prediction. | ||
| optimized_q = lax.while_loop(cond_fun, body_fun, some_q) | ||
| print_generic(optimized_q) | ||
| # BernoulliNP | ||
| # Optimize the predictive distribution iteratively. | ||
| predictive_distribution = lax.while_loop(cond_fun, body_fun, | ||
| initial_predictive_distribution) | ||
| # Compare the optimized predictive distribution with the target value in the | ||
| # same natural parametrization. | ||
| print_generic(predictive_distribution=predictive_distribution, | ||
| target_distribution=target_distribution.to_nat()) | ||
| # predictive_distribution=BernoulliNP[dataclass] | ||
| # └── log_odds=Jax Array (3,) float32 | ||
| # └── -0.8440 │ -0.4047 │ 0.8440 | ||
| # Compare with the true value. | ||
| print_generic(some_p.to_nat()) | ||
| # BernoulliNP | ||
| # target_distribution=BernoulliNP[dataclass] | ||
| # └── log_odds=Jax Array (3,) float32 | ||
| # └── -0.8473 │ -0.4055 │ 0.8473 | ||
| # Print optimized natural parameters as expectation parameters. | ||
| print_generic(optimized_q.to_exp()) | ||
| # BernoulliEP | ||
| # Do the same in the expectation parametrization. | ||
| print_generic(predictive_distribution=predictive_distribution.to_exp(), | ||
| target_distribution=target_distribution) | ||
| # predictive_distribution=BernoulliEP[dataclass] | ||
| # └── probability=Jax Array (3,) float32 | ||
| # └── 0.3007 │ 0.4002 │ 0.6993 | ||
| # target_distribution=BernoulliEP[dataclass] | ||
| # └── probability=Jax Array (3,) float32 | ||
| # └── 0.3000 │ 0.4000 │ 0.7000 | ||
@@ -380,6 +451,16 @@ Maximum likelihood estimation | ||
| """Maximum likelihood estimation. | ||
| This example is based on section 1.3.2 from expfam.pdf, entitled Maximum | ||
| likelihood estimation. | ||
| Suppose you have some samples from a distribution family with unknown | ||
| parameters, and you want to estimate the maximum likelihood parameters of the | ||
| distribution. | ||
| """ | ||
| import jax.numpy as jnp | ||
| import jax.random as jr | ||
| from efax import DirichletNP, parameter_mean | ||
| from efax import DirichletEP, DirichletNP, MaximumLikelihoodEstimator, parameter_mean | ||
| from tjax import print_generic | ||
@@ -397,4 +478,6 @@ # Consider a Dirichlet distribution with a given alpha. | ||
| # First, convert the samples to their sufficient statistics. | ||
| ss = DirichletNP.sufficient_statistics(samples) | ||
| # ss has type DirichletEP. This is similar to the conjugate prior of the Dirichlet distribution. | ||
| estimator = MaximumLikelihoodEstimator.create_simple_estimator(DirichletEP) | ||
| ss = estimator.sufficient_statistics(samples) | ||
| # ss has type DirichletEP. This is similar to the conjugate prior of the | ||
| # Dirichlet distribution. | ||
@@ -406,3 +489,10 @@ # Take the mean over the first axis. | ||
| estimated_distribution = ss_mean.to_nat() | ||
| print(estimated_distribution.alpha_minus_one + 1.0) # [1.9849904 3.0065458 3.963935 ] # noqa: T201 | ||
| print_generic(estimated_distribution=estimated_distribution, | ||
| source_distribution=source_distribution) | ||
| # estimated_distribution=DirichletNP[dataclass] | ||
| # └── alpha_minus_one=Jax Array (3,) float32 | ||
| # └── 0.9797 │ 1.9539 │ 2.9763 | ||
| # source_distribution=DirichletNP[dataclass] | ||
| # └── alpha_minus_one=Jax Array (3,) float32 | ||
| # └── 1.0000 │ 2.0000 │ 3.0000 | ||
@@ -409,0 +499,0 @@ Contribution guidelines |
+1
-1
@@ -7,3 +7,3 @@ [build-system] | ||
| name = "efax" | ||
| version = "1.21.1" | ||
| version = "1.21.2" | ||
| description = "Exponential families for JAX" | ||
@@ -10,0 +10,0 @@ readme = "README.rst" |
+128
-38
@@ -228,28 +228,88 @@ .. role:: bash(code) | ||
| from __future__ import annotations | ||
| """Cross-entropy. | ||
| This example is based on section 1.4.1 from expfam.pdf, entitled Information | ||
| theoretic statistics. | ||
| """ | ||
| import jax.numpy as jnp | ||
| from tjax import print_generic | ||
| from efax import BernoulliEP, BernoulliNP | ||
| # p is the expectation parameters of three Bernoulli distributions having probabilities 0.4, 0.5, | ||
| # and 0.6. | ||
| # p is the expectation parameters of three Bernoulli distributions having | ||
| # probabilities 0.4, 0.5, and 0.6. | ||
| p = BernoulliEP(jnp.asarray([0.4, 0.5, 0.6])) | ||
| # q is the natural parameters of three Bernoulli distributions having log-odds 0, which is | ||
| # probability 0.5. | ||
| # q is the natural parameters of three Bernoulli distributions having log-odds | ||
| # 0, which is probability 0.5. | ||
| q = BernoulliNP(jnp.zeros(3)) | ||
| print(p.cross_entropy(q)) # noqa: T201 | ||
| # [0.6931472 0.6931472 0.6931472] | ||
| print_generic(p.cross_entropy(q)) | ||
| # Jax Array (3,) float32 | ||
| # └── 0.6931 │ 0.6931 │ 0.6931 | ||
| # q2 is natural parameters of Bernoulli distributions having a probability of 0.3. | ||
| # q2 is natural parameters of Bernoulli distributions having a probability of | ||
| # 0.3. | ||
| p2 = BernoulliEP(0.3 * jnp.ones(3)) | ||
| q2 = p2.to_nat() | ||
| print(p.cross_entropy(q2)) # noqa: T201 | ||
| # [0.6955941 0.78032386 0.86505365] | ||
| # A Bernoulli distribution with probability 0.3 predicts a Bernoulli observation with probability | ||
| # 0.4 better than the other observations. | ||
| # A Bernoulli distribution with probability 0.3 predicts a Bernoulli observation | ||
| # with probability 0.4 better than the other observations. | ||
| print_generic(p.cross_entropy(q2)) | ||
| # Jax Array (3,) float32 | ||
| # └── 0.6956 │ 0.7803 │ 0.8651 | ||
| Evidence combination: | ||
| .. code:: python | ||
| """Bayesian evidence combination. | ||
| This example is based on section 1.2.1 from expfam.pdf, entitled Bayesian | ||
| evidence combination. | ||
| Suppose you have a prior, and a set of likelihoods, and you want to combine all | ||
| of the evidence into one distribution. | ||
| """ | ||
| from operator import add | ||
| import jax.numpy as jnp | ||
| from tjax import print_generic | ||
| from efax import MultivariateDiagonalNormalVP, parameter_map | ||
| prior = MultivariateDiagonalNormalVP(mean=jnp.zeros(2), | ||
| variance=10 * jnp.ones(2)) | ||
| likelihood = MultivariateDiagonalNormalVP(mean=jnp.asarray([1.1, -2.2]), | ||
| variance=jnp.asarray([3.0, 1.0])) | ||
| # Convert to the natural parametrization. | ||
| prior_np = prior.to_nat() | ||
| likelihood_np = likelihood.to_nat() | ||
| # Sum. We use parameter_map to ensure that we don't accidentally add "fixed" | ||
| # parameters, e.g., the failure count of a negative binomial distribution. | ||
| posterior_np = parameter_map(add, prior_np, likelihood_np) | ||
| # Convert to the source parametrization. | ||
| posterior = posterior_np.to_variance_parametrization() | ||
| print_generic(prior=prior, | ||
| likelihood=likelihood, | ||
| posterior=posterior) | ||
| # likelihood=MultivariateDiagonalNormalVP[dataclass] | ||
| # ├── mean=Jax Array (2,) float32 | ||
| # │ └── 1.1000 │ -2.2000 | ||
| # └── variance=Jax Array (2,) float32 | ||
| # └── 3.0000 │ 1.0000 | ||
| # posterior=MultivariateDiagonalNormalVP[dataclass] | ||
| # ├── mean=Jax Array (2,) float32 | ||
| # │ └── 0.8462 │ -2.0000 | ||
| # └── variance=Jax Array (2,) float32 | ||
| # └── 2.3077 │ 0.9091 | ||
| # prior=MultivariateDiagonalNormalVP[dataclass] | ||
| # ├── mean=Jax Array (2,) float32 | ||
| # │ └── 0.0000 │ 0.0000 | ||
| # └── variance=Jax Array (2,) float32 | ||
| # └── 10.0000 │ 10.0000 | ||
| Optimization | ||
@@ -261,4 +321,9 @@ ------------ | ||
| from __future__ import annotations | ||
| """Optimization. | ||
| This example illustrates how this library fits in a typical machine learning | ||
| context. Suppose we have an unknown target value, and a loss function based on | ||
| the cross-entropy between the target value and a predictive distribution. We | ||
| will optimize the predictive distribution by a small fraction of its cotangent. | ||
| """ | ||
| import jax.numpy as jnp | ||
@@ -275,3 +340,3 @@ from jax import grad, lax | ||
| gce = jit(grad(cross_entropy_loss, 1)) | ||
| gradient_cross_entropy = jit(grad(cross_entropy_loss, 1)) | ||
@@ -284,3 +349,3 @@ | ||
| def body_fun(q: BernoulliNP) -> BernoulliNP: | ||
| q_bar = gce(some_p, q) | ||
| q_bar = gradient_cross_entropy(target_distribution, q) | ||
| return parameter_map(apply, q, q_bar) | ||
@@ -290,3 +355,3 @@ | ||
| def cond_fun(q: BernoulliNP) -> JaxBooleanArray: | ||
| q_bar = gce(some_p, q) | ||
| q_bar = gradient_cross_entropy(target_distribution, q) | ||
| total = jnp.sum(parameter_dot_product(q_bar, q_bar)) | ||
@@ -296,29 +361,35 @@ return total > 1e-6 # noqa: PLR2004 | ||
| # some_p are expectation parameters of a Bernoulli distribution corresponding | ||
| # to probabilities 0.3, 0.4, and 0.7. | ||
| some_p = BernoulliEP(jnp.asarray([0.3, 0.4, 0.7])) | ||
| # The target_distribution is represented as the expectation parameters of a | ||
| # Bernoulli distribution corresponding to probabilities 0.3, 0.4, and 0.7. | ||
| target_distribution = BernoulliEP(jnp.asarray([0.3, 0.4, 0.7])) | ||
| # some_q are natural parameters of a Bernoulli distribution corresponding to | ||
| # log-odds 0, which is probability 0.5. | ||
| some_q = BernoulliNP(jnp.zeros(3)) | ||
| # The initial predictive distribution is represented as the natural parameters | ||
| # of a Bernoulli distribution corresponding to log-odds 0, which is probability | ||
| # 0.5. | ||
| initial_predictive_distribution = BernoulliNP(jnp.zeros(3)) | ||
| # Optimize the predictive distribution iteratively, and output the natural parameters of the | ||
| # prediction. | ||
| optimized_q = lax.while_loop(cond_fun, body_fun, some_q) | ||
| print_generic(optimized_q) | ||
| # BernoulliNP | ||
| # Optimize the predictive distribution iteratively. | ||
| predictive_distribution = lax.while_loop(cond_fun, body_fun, | ||
| initial_predictive_distribution) | ||
| # Compare the optimized predictive distribution with the target value in the | ||
| # same natural parametrization. | ||
| print_generic(predictive_distribution=predictive_distribution, | ||
| target_distribution=target_distribution.to_nat()) | ||
| # predictive_distribution=BernoulliNP[dataclass] | ||
| # └── log_odds=Jax Array (3,) float32 | ||
| # └── -0.8440 │ -0.4047 │ 0.8440 | ||
| # Compare with the true value. | ||
| print_generic(some_p.to_nat()) | ||
| # BernoulliNP | ||
| # target_distribution=BernoulliNP[dataclass] | ||
| # └── log_odds=Jax Array (3,) float32 | ||
| # └── -0.8473 │ -0.4055 │ 0.8473 | ||
| # Print optimized natural parameters as expectation parameters. | ||
| print_generic(optimized_q.to_exp()) | ||
| # BernoulliEP | ||
| # Do the same in the expectation parametrization. | ||
| print_generic(predictive_distribution=predictive_distribution.to_exp(), | ||
| target_distribution=target_distribution) | ||
| # predictive_distribution=BernoulliEP[dataclass] | ||
| # └── probability=Jax Array (3,) float32 | ||
| # └── 0.3007 │ 0.4002 │ 0.6993 | ||
| # target_distribution=BernoulliEP[dataclass] | ||
| # └── probability=Jax Array (3,) float32 | ||
| # └── 0.3000 │ 0.4000 │ 0.7000 | ||
@@ -333,6 +404,16 @@ Maximum likelihood estimation | ||
| """Maximum likelihood estimation. | ||
| This example is based on section 1.3.2 from expfam.pdf, entitled Maximum | ||
| likelihood estimation. | ||
| Suppose you have some samples from a distribution family with unknown | ||
| parameters, and you want to estimate the maximum likelihood parameters of the | ||
| distribution. | ||
| """ | ||
| import jax.numpy as jnp | ||
| import jax.random as jr | ||
| from efax import DirichletNP, parameter_mean | ||
| from efax import DirichletEP, DirichletNP, MaximumLikelihoodEstimator, parameter_mean | ||
| from tjax import print_generic | ||
@@ -350,4 +431,6 @@ # Consider a Dirichlet distribution with a given alpha. | ||
| # First, convert the samples to their sufficient statistics. | ||
| ss = DirichletNP.sufficient_statistics(samples) | ||
| # ss has type DirichletEP. This is similar to the conjugate prior of the Dirichlet distribution. | ||
| estimator = MaximumLikelihoodEstimator.create_simple_estimator(DirichletEP) | ||
| ss = estimator.sufficient_statistics(samples) | ||
| # ss has type DirichletEP. This is similar to the conjugate prior of the | ||
| # Dirichlet distribution. | ||
@@ -359,3 +442,10 @@ # Take the mean over the first axis. | ||
| estimated_distribution = ss_mean.to_nat() | ||
| print(estimated_distribution.alpha_minus_one + 1.0) # [1.9849904 3.0065458 3.963935 ] # noqa: T201 | ||
| print_generic(estimated_distribution=estimated_distribution, | ||
| source_distribution=source_distribution) | ||
| # estimated_distribution=DirichletNP[dataclass] | ||
| # └── alpha_minus_one=Jax Array (3,) float32 | ||
| # └── 0.9797 │ 1.9539 │ 2.9763 | ||
| # source_distribution=DirichletNP[dataclass] | ||
| # └── alpha_minus_one=Jax Array (3,) float32 | ||
| # └── 1.0000 │ 2.0000 │ 3.0000 | ||
@@ -362,0 +452,0 @@ Contribution guidelines |
Sorry, the diff of this file is too big to display
Alert delta unavailable
Currently unable to show alert delta for PyPI packages.
956775
1.02%118
1.72%8027
0.51%