efax
Advanced tools
| """These tests ensure that our distributions match scipy's.""" |
| from __future__ import annotations | ||
| from numpy.random import Generator | ||
| from numpy.testing import assert_allclose | ||
| from efax import HasEntropyEP, HasEntropyNP | ||
| from ..create_info import ChiInfo, ChiSquareInfo | ||
| from ..distribution_info import DistributionInfo | ||
| def test_nat_entropy(generator: Generator, | ||
| entropy_distribution_info: DistributionInfo | ||
| ) -> None: | ||
| """Test that the entropy calculation matches scipy's.""" | ||
| shape = (7, 13) | ||
| nat_parameters = entropy_distribution_info.nat_parameter_generator(generator, shape=shape) | ||
| assert isinstance(nat_parameters, HasEntropyNP) | ||
| scipy_distribution = entropy_distribution_info.nat_to_scipy_distribution(nat_parameters) | ||
| rtol = 2e-5 | ||
| my_entropy = nat_parameters.entropy() | ||
| scipy_entropy = scipy_distribution.entropy() | ||
| assert_allclose(my_entropy, scipy_entropy, rtol=rtol) | ||
| def test_exp_entropy(generator: Generator, | ||
| entropy_distribution_info: DistributionInfo | ||
| ) -> None: | ||
| """Test that the entropy calculation matches scipy's.""" | ||
| shape = (7, 13) | ||
| exp_parameters = entropy_distribution_info.exp_parameter_generator(generator, shape=shape) | ||
| assert isinstance(exp_parameters, HasEntropyEP) | ||
| scipy_distribution = entropy_distribution_info.exp_to_scipy_distribution(exp_parameters) | ||
| rtol = (1e-5 | ||
| if isinstance(entropy_distribution_info, ChiInfo | ChiSquareInfo) | ||
| else 1e-6) | ||
| my_entropy = exp_parameters.entropy() | ||
| scipy_entropy = scipy_distribution.entropy() | ||
| assert_allclose(my_entropy, scipy_entropy, rtol=rtol) |
| from __future__ import annotations | ||
| from functools import partial | ||
| from typing import Any | ||
| import numpy as np | ||
| from jax import Array | ||
| from numpy.random import Generator | ||
| from tjax import assert_tree_allclose | ||
| from efax import (MaximumLikelihoodEstimator, NaturalParametrization, flat_dict_of_observations, | ||
| flat_dict_of_parameters, parameter_map, unflatten_mapping) | ||
| from ..create_info import (ComplexCircularlySymmetricNormalInfo, IsotropicNormalInfo, | ||
| MultivariateFixedVarianceNormalInfo, MultivariateNormalInfo, | ||
| VonMisesFisherInfo) | ||
| from ..distribution_info import DistributionInfo | ||
| def test_maximum_likelihood_estimation( | ||
| generator: Generator, | ||
| distribution_info: DistributionInfo[NaturalParametrization] | ||
| ) -> None: | ||
| """Test maximum likelihood estimation using SciPy. | ||
| Test that maximum likelihood estimation from scipy-generated variates produce the same | ||
| distribution from which they were drawn. | ||
| """ | ||
| rtol = 2e-2 | ||
| if isinstance(distribution_info, | ||
| ComplexCircularlySymmetricNormalInfo | MultivariateNormalInfo | ||
| | VonMisesFisherInfo | MultivariateFixedVarianceNormalInfo): | ||
| atol = 1e-2 | ||
| elif isinstance(distribution_info, IsotropicNormalInfo): | ||
| atol = 1e-3 | ||
| else: | ||
| atol = 1e-6 | ||
| n = 70000 | ||
| # Generate a distribution with expectation parameters. | ||
| exp_parameters = distribution_info.exp_parameter_generator(generator, shape=()) | ||
| # Generate variates from the corresponding scipy distribution. | ||
| scipy_distribution = distribution_info.exp_to_scipy_distribution(exp_parameters) | ||
| scipy_x = scipy_distribution.rvs(random_state=generator, size=n) | ||
| # Convert the variates to sufficient statistics. | ||
| efax_x = distribution_info.scipy_to_exp_family_observation(scipy_x) | ||
| flat_efax_x = flat_dict_of_observations(efax_x) | ||
| flat_parameters = flat_dict_of_parameters(exp_parameters) | ||
| flat_efax_x_clamped = {path: flat_parameters[path].domain_support().clamp(value) | ||
| for path, value in flat_efax_x.items()} | ||
| efax_x_clamped: Array | dict[str, Any] = (flat_efax_x_clamped[()] | ||
| if flat_efax_x_clamped.keys() == {()} | ||
| else unflatten_mapping(flat_efax_x_clamped)) | ||
| estimator = MaximumLikelihoodEstimator.create_estimator(exp_parameters) | ||
| sufficient_stats = estimator.sufficient_statistics(efax_x_clamped) | ||
| # Verify that the mean of the sufficient statistics equals the expectation parameters. | ||
| calculated_parameters = parameter_map(partial(np.mean, axis=0), # type: ignore[arg-type] | ||
| sufficient_stats) | ||
| assert_tree_allclose(exp_parameters, calculated_parameters, rtol=rtol, atol=atol) |
| from __future__ import annotations | ||
| from typing import Any | ||
| import numpy as np | ||
| from jax import Array | ||
| from numpy.random import Generator | ||
| from numpy.testing import assert_allclose | ||
| from tjax import JaxComplexArray | ||
| from efax import JointDistributionN, Multidimensional, NaturalParametrization, SimpleDistribution | ||
| from ..create_info import MultivariateDiagonalNormalInfo | ||
| from ..distribution_info import DistributionInfo | ||
| def _check_observation_shape(nat_parameters: NaturalParametrization, | ||
| efax_x: JaxComplexArray | dict[str, Any], | ||
| distribution_shape: tuple[int, ...], | ||
| ) -> None: | ||
| """Verify that the sufficient statistics have the right shape.""" | ||
| if isinstance(nat_parameters, JointDistributionN): | ||
| assert isinstance(efax_x, dict) | ||
| for name, value in nat_parameters.sub_distributions().items(): | ||
| _check_observation_shape(value, efax_x[name], distribution_shape) | ||
| return | ||
| assert isinstance(nat_parameters, SimpleDistribution) # type: ignore[unreachable] | ||
| assert isinstance(efax_x, Array) # type: ignore[unreachable] | ||
| dimensions = (nat_parameters.dimensions() | ||
| if isinstance(nat_parameters, Multidimensional) | ||
| else 0) | ||
| ideal_shape = distribution_shape + nat_parameters.domain_support().shape(dimensions) | ||
| assert efax_x.shape == ideal_shape | ||
| def test_pdf(generator: Generator, distribution_info: DistributionInfo) -> None: | ||
| """Test that the density/mass function calculation matches scipy's.""" | ||
| distribution_shape = (10,) | ||
| nat_parameters = distribution_info.nat_parameter_generator(generator, shape=distribution_shape) | ||
| scipy_distribution = distribution_info.nat_to_scipy_distribution(nat_parameters) | ||
| scipy_x = scipy_distribution.rvs(random_state=generator) | ||
| efax_x = distribution_info.scipy_to_exp_family_observation(scipy_x) | ||
| _check_observation_shape(nat_parameters, efax_x, distribution_shape) | ||
| # Verify that the density matches scipy. | ||
| efax_density = np.asarray(nat_parameters.pdf(efax_x), dtype=np.float64) | ||
| try: | ||
| scipy_density = scipy_distribution.pdf(scipy_x) | ||
| except AttributeError: | ||
| scipy_density = scipy_distribution.pmf(scipy_x) | ||
| if isinstance(distribution_info, MultivariateDiagonalNormalInfo): | ||
| atol = 1e-5 | ||
| rtol = 3e-4 | ||
| else: | ||
| atol = 1e-5 | ||
| rtol = 1e-4 | ||
| assert_allclose(efax_density, scipy_density, rtol=rtol, atol=atol) |
| from __future__ import annotations | ||
| import pytest | ||
| from numpy.random import Generator | ||
| from tjax import assert_tree_allclose, print_generic, tree_allclose | ||
| from efax import parameters | ||
| from .create_info import GeneralizedDirichletInfo | ||
| from .distribution_info import DistributionInfo | ||
| def test_conversion(generator: Generator, | ||
| distribution_info: DistributionInfo | ||
| ) -> None: | ||
| """Test that the conversion between the different parametrizations are consistent.""" | ||
| if isinstance(distribution_info, GeneralizedDirichletInfo): | ||
| pytest.skip() | ||
| n = 30 | ||
| shape = (n,) | ||
| original_np = distribution_info.nat_parameter_generator(generator, shape=shape) | ||
| intermediate_ep = original_np.to_exp() | ||
| final_np = intermediate_ep.to_nat() | ||
| # Check round trip. | ||
| if not tree_allclose(final_np, original_np): | ||
| for i in range(n): | ||
| if not tree_allclose(final_np[i], original_np[i]): | ||
| print_generic({"original": original_np[i], | ||
| "intermediate": intermediate_ep[i], | ||
| "final": final_np[i]}) | ||
| pytest.fail("Conversion failure") | ||
| # Check fixed parameters. | ||
| original_fixed = parameters(original_np, fixed=True) | ||
| intermediate_fixed = parameters(intermediate_ep, fixed=True) | ||
| final_fixed = parameters(final_np, fixed=True) | ||
| assert_tree_allclose(original_fixed, intermediate_fixed) | ||
| assert_tree_allclose(original_fixed, final_fixed) |
| from __future__ import annotations | ||
| from collections.abc import Callable | ||
| from typing import TypeAlias | ||
| from array_api_compat import array_namespace | ||
| from jax import grad, jvp, vjp | ||
| from jax.custom_derivatives import zero_from_primal | ||
| from numpy.random import Generator | ||
| from numpy.testing import assert_allclose | ||
| from tjax import JaxRealArray, assert_tree_allclose, jit | ||
| from efax import NaturalParametrization, Structure, parameters | ||
| from .distribution_info import DistributionInfo | ||
| _LogNormalizer: TypeAlias = Callable[[NaturalParametrization], JaxRealArray] | ||
| def _prelude(generator: Generator, | ||
| distribution_info: DistributionInfo | ||
| ) -> tuple[_LogNormalizer, _LogNormalizer]: | ||
| cls = distribution_info.nat_class() | ||
| original_ln = cls._original_log_normalizer | ||
| optimized_ln = cls.log_normalizer | ||
| return original_ln, optimized_ln | ||
| def test_primals(generator: Generator, distribution_info: DistributionInfo) -> None: | ||
| """Tests that the gradient log-normalizer equals the gradient of the log-normalizer.""" | ||
| original_ln, optimized_ln = _prelude(generator, distribution_info) | ||
| original_gln = jit(grad(original_ln, allow_int=True)) | ||
| optimized_gln = jit(grad(optimized_ln, allow_int=True)) | ||
| for _ in range(20): | ||
| generated_np = distribution_info.nat_parameter_generator(generator, shape=()) | ||
| generated_ep = generated_np.to_exp() # Regular transformation. | ||
| generated_parameters = parameters(generated_ep, fixed=False) | ||
| structure_ep = Structure.create(generated_ep) | ||
| # Original GLN. | ||
| original_gln_np = original_gln(generated_np) | ||
| original_gln_ep = structure_ep.reinterpret(original_gln_np) | ||
| original_gln_parameters = parameters(original_gln_ep, fixed=False) | ||
| # Optimized GLN. | ||
| optimized_gln_np = optimized_gln(generated_np) | ||
| optimized_gln_ep = structure_ep.reinterpret(optimized_gln_np) | ||
| optimized_gln_parameters = parameters(optimized_gln_ep, fixed=False) | ||
| # Test primal evaluation. | ||
| # parameters(generated_ep, fixed=False) | ||
| assert_tree_allclose(generated_parameters, original_gln_parameters, rtol=1e-5) | ||
| assert_tree_allclose(generated_parameters, optimized_gln_parameters, rtol=1e-5) | ||
| def _unit_tangent(nat_parameters: NaturalParametrization | ||
| ) -> NaturalParametrization: | ||
| xp = array_namespace(nat_parameters) | ||
| new_variable_parameters = {path: xp.ones_like(value) | ||
| for path, value in parameters(nat_parameters, fixed=False).items()} | ||
| new_fixed_parameters = {path: zero_from_primal(value, symbolic_zeros=False) | ||
| for path, value in parameters(nat_parameters, fixed=True).items()} | ||
| structure = Structure.create(nat_parameters) | ||
| return structure.assemble({**new_variable_parameters, **new_fixed_parameters}) | ||
| def test_jvp(generator: Generator, distribution_info: DistributionInfo) -> None: | ||
| """Tests that the gradient log-normalizer equals the gradient of the log-normalizer.""" | ||
| original_ln, optimized_ln = _prelude(generator, distribution_info) | ||
| for _ in range(20): | ||
| nat_parameters = distribution_info.nat_parameter_generator(generator, shape=()) | ||
| # Test JVP. | ||
| nat_tangent = _unit_tangent(nat_parameters) | ||
| original_ln_of_nat, original_jvp = jvp(original_ln, (nat_parameters,), (nat_tangent,)) | ||
| optimized_ln_of_nat, optimized_jvp = jvp(optimized_ln, (nat_parameters,), (nat_tangent,)) | ||
| assert_allclose(original_ln_of_nat, optimized_ln_of_nat, rtol=1.5e-5) | ||
| assert_allclose(original_jvp, optimized_jvp, rtol=1.5e-5) | ||
| def test_vjp(generator: Generator, distribution_info: DistributionInfo) -> None: | ||
| """Tests that the gradient log-normalizer equals the gradient of the log-normalizer.""" | ||
| original_ln, optimized_ln = _prelude(generator, distribution_info) | ||
| for _ in range(20): | ||
| nat_parameters = distribution_info.nat_parameter_generator(generator, shape=()) | ||
| nat_tangent = _unit_tangent(nat_parameters) | ||
| original_ln_of_nat, _ = jvp(original_ln, (nat_parameters,), (nat_tangent,)) | ||
| original_ln_of_nat_b, original_vjp = vjp(original_ln, nat_parameters) | ||
| original_gln_of_nat, = original_vjp(1.0) | ||
| optimized_ln_of_nat_b, optimized_vjp = vjp(optimized_ln, nat_parameters) | ||
| optimized_gln_of_nat, = optimized_vjp(1.0) | ||
| assert_allclose(original_ln_of_nat_b, optimized_ln_of_nat_b, rtol=1e-5) | ||
| assert_allclose(original_ln_of_nat, original_ln_of_nat_b, rtol=1e-5) | ||
| assert_tree_allclose(parameters(original_gln_of_nat, fixed=False), | ||
| parameters(optimized_gln_of_nat, fixed=False), | ||
| rtol=1e-5) |
| from __future__ import annotations | ||
| from typing import Any, Self | ||
| from typing import Self | ||
@@ -46,3 +46,3 @@ import jax.random as jr | ||
| def log_normalizer(self) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| return xp.logaddexp(self.log_odds, xp.asarray(0.0)) | ||
@@ -56,3 +56,3 @@ | ||
| def carrier_measure(self, x: JaxRealArray) -> JaxRealArray: | ||
| xp = self.array_namespace(x) | ||
| xp = array_namespace(self, x) | ||
| return xp.zeros(x.shape) | ||
@@ -67,3 +67,3 @@ | ||
| def nat_to_probability(self) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| p = jss.expit(self.log_odds) | ||
@@ -74,3 +74,3 @@ final_p = 1.0 - p | ||
| def nat_to_surprisal(self) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| total_p = self.nat_to_probability() | ||
@@ -117,3 +117,3 @@ return -xp.log(total_p) | ||
| def expected_carrier_measure(self) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| return xp.zeros(self.shape) | ||
@@ -123,3 +123,3 @@ | ||
| def sample(self, key: KeyArray, shape: Shape | None = None) -> JaxBooleanArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| shape = self.shape if shape is None else shape + self.shape | ||
@@ -131,3 +131,3 @@ grow = (xp.newaxis,) * (len(shape) - len(self.shape)) | ||
| def conjugate_prior_distribution(self, n: JaxRealArray) -> BetaNP: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| reshaped_n = n[..., np.newaxis] | ||
@@ -138,3 +138,3 @@ return BetaNP(reshaped_n * xp.stack([self.probability, (1.0 - self.probability)], axis=-1)) | ||
| @override | ||
| def from_conjugate_prior_distribution(cls, cp: NaturalParametrization[Any, Any] | ||
| def from_conjugate_prior_distribution(cls, cp: NaturalParametrization | ||
| ) -> tuple[Self, JaxRealArray]: | ||
@@ -141,0 +141,0 @@ assert isinstance(cp, BetaNP) |
@@ -63,3 +63,3 @@ from __future__ import annotations | ||
| def carrier_measure(self, x: JaxRealArray) -> JaxRealArray: | ||
| xp = self.array_namespace(x) | ||
| xp = array_namespace(self, x) | ||
| return xp.log(2.0 * x) - xp.square(x) * 0.5 | ||
@@ -66,0 +66,0 @@ |
@@ -6,2 +6,3 @@ from __future__ import annotations | ||
| import jax.random as jr | ||
| from array_api_compat import array_namespace | ||
| from tjax import JaxArray, JaxComplexArray, JaxRealArray, KeyArray, Shape, outer_product | ||
@@ -53,3 +54,3 @@ from tjax.dataclasses import dataclass | ||
| def log_normalizer(self) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| log_det_s = xp.log(xp.real(xp.linalg.det(-self.negative_precision))) | ||
@@ -60,3 +61,3 @@ return -log_det_s + self.dimensions() * math.log(math.pi) | ||
| def to_exp(self) -> ComplexCircularlySymmetricNormalEP: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| return ComplexCircularlySymmetricNormalEP( | ||
@@ -67,3 +68,3 @@ xp.conj(xp.linalg.inv(-self.negative_precision))) | ||
| def carrier_measure(self, x: JaxComplexArray) -> JaxRealArray: | ||
| xp = self.array_namespace(x) | ||
| xp = array_namespace(self, x) | ||
| return xp.zeros(x.shape[:-1]) | ||
@@ -118,3 +119,3 @@ | ||
| def to_nat(self) -> ComplexCircularlySymmetricNormalNP: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| return ComplexCircularlySymmetricNormalNP(xp.conj(-xp.linalg.inv(self.variance))) | ||
@@ -124,3 +125,3 @@ | ||
| def expected_carrier_measure(self) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| return xp.zeros(self.shape) | ||
@@ -144,3 +145,3 @@ | ||
| """Return the mean of a corresponding real distribution with double the size.""" | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| n = self.dimensions() | ||
@@ -151,3 +152,3 @@ return xp.zeros((*self.shape, n * 2)) | ||
| """Return the covariance of a corresponding real distribution with double the size.""" | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| gamma_r = 0.5 * xp.real(self.variance) | ||
@@ -154,0 +155,0 @@ gamma_i = 0.5 * xp.imag(self.variance) |
@@ -6,2 +6,3 @@ from __future__ import annotations | ||
| import jax.random as jr | ||
| from array_api_compat import array_namespace | ||
| from tjax import JaxArray, JaxComplexArray, JaxRealArray, KeyArray, Shape, abs_square | ||
@@ -51,3 +52,3 @@ from tjax.dataclasses import dataclass | ||
| def log_normalizer(self) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| mean_conjugate = self.two_mean_conjugate * 0.5 | ||
@@ -58,3 +59,3 @@ return xp.sum(abs_square(mean_conjugate), axis=-1) + self.dimensions() * math.log(math.pi) | ||
| def to_exp(self) -> ComplexMultivariateUnitVarianceNormalEP: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| return ComplexMultivariateUnitVarianceNormalEP(xp.conj(self.two_mean_conjugate) * 0.5) | ||
@@ -64,3 +65,3 @@ | ||
| def carrier_measure(self, x: JaxComplexArray) -> JaxRealArray: | ||
| xp = self.array_namespace(x) | ||
| xp = array_namespace(self, x) | ||
| return -xp.sum(abs_square(x), axis=-1) | ||
@@ -114,3 +115,3 @@ | ||
| def to_nat(self) -> ComplexMultivariateUnitVarianceNormalNP: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| return ComplexMultivariateUnitVarianceNormalNP(xp.conj(self.mean) * 2.0) | ||
@@ -120,3 +121,3 @@ | ||
| def expected_carrier_measure(self) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| # The second moment of a normal distribution with the given mean. | ||
@@ -127,3 +128,3 @@ return -(xp.sum(abs_square(self.mean), axis=-1) + self.dimensions()) | ||
| def sample(self, key: KeyArray, shape: Shape | None = None) -> JaxComplexArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| shape = self.mean.shape if shape is None else shape + self.mean.shape | ||
@@ -130,0 +131,0 @@ grow = (xp.newaxis,) * (len(shape) - len(self.mean.shape)) |
@@ -54,3 +54,3 @@ from __future__ import annotations | ||
| def log_normalizer(self) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| _, s, mu = self._r_s_mu() | ||
@@ -67,3 +67,3 @@ det_s = s | ||
| def to_exp(self) -> ComplexNormalEP: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| r, s, mu = self._r_s_mu() | ||
@@ -75,3 +75,3 @@ u = xp.conj(r * s) | ||
| def carrier_measure(self, x: JaxComplexArray) -> JaxRealArray: | ||
| xp = self.array_namespace(x) | ||
| xp = array_namespace(self, x) | ||
| return xp.zeros(x.shape) | ||
@@ -87,3 +87,3 @@ | ||
| def _r_s_mu(self) -> tuple[JaxComplexArray, JaxRealArray, JaxComplexArray]: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| r = -self.pseudo_precision / self.negative_precision | ||
@@ -143,3 +143,3 @@ s = xp.reciprocal((abs_square(r) - 1.0) * self.negative_precision) | ||
| def to_nat(self) -> ComplexNormalNP: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| variance = self.second_moment - abs_square(self.mean) | ||
@@ -159,3 +159,3 @@ pseudo_variance = self.pseudo_second_moment - xp.square(self.mean) | ||
| def expected_carrier_measure(self) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| return xp.zeros(self.shape) | ||
@@ -173,3 +173,3 @@ | ||
| def _multivariate_normal_cov(self) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| variance = self.second_moment - abs_square(self.mean) | ||
@@ -187,3 +187,3 @@ pseudo_variance = self.pseudo_second_moment - xp.square(self.mean) | ||
| shape = self.shape if shape is None else shape + self.shape | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| mn_mean = xp.stack([xp.real(self.mean), xp.imag(self.mean)], axis=-1) | ||
@@ -190,0 +190,0 @@ mn_cov = self._multivariate_normal_cov() |
@@ -6,2 +6,3 @@ from __future__ import annotations | ||
| import jax.random as jr | ||
| from array_api_compat import array_namespace | ||
| from tjax import JaxArray, JaxComplexArray, JaxRealArray, KeyArray, Shape, abs_square | ||
@@ -54,3 +55,3 @@ from tjax.dataclasses import dataclass | ||
| def to_exp(self) -> ComplexUnitVarianceNormalEP: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| return ComplexUnitVarianceNormalEP(xp.conj(self.two_mean_conjugate) * 0.5) | ||
@@ -103,3 +104,3 @@ | ||
| def to_nat(self) -> ComplexUnitVarianceNormalNP: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| return ComplexUnitVarianceNormalNP(xp.conj(self.mean) * 2.0) | ||
@@ -114,3 +115,3 @@ | ||
| def sample(self, key: KeyArray, shape: Shape | None = None) -> JaxComplexArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| shape = self.shape if shape is None else shape + self.shape | ||
@@ -117,0 +118,0 @@ grow = (xp.newaxis,) * (len(shape) - len(self.shape)) |
@@ -7,2 +7,3 @@ from __future__ import annotations | ||
| import jax.scipy.special as jss | ||
| from array_api_compat import array_namespace | ||
| from tjax import JaxRealArray, KeyArray, Shape | ||
@@ -38,3 +39,3 @@ from tjax.dataclasses import dataclass | ||
| def log_normalizer(self) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| q = self.alpha_minus_one | ||
@@ -55,3 +56,3 @@ return (xp.sum(jss.gammaln(q + 1.0), axis=-1) | ||
| def _exp_helper(self) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| q = self.alpha_minus_one | ||
@@ -79,3 +80,3 @@ return jss.digamma(q + 1.0) - jss.digamma(xp.sum(q, axis=-1, keepdims=True) + q.shape[-1]) | ||
| def expected_carrier_measure(self) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| return xp.zeros(self.shape) | ||
@@ -82,0 +83,0 @@ |
@@ -41,3 +41,3 @@ from __future__ import annotations | ||
| def carrier_measure(self, x: JaxRealArray) -> JaxRealArray: | ||
| xp = self.array_namespace(x) | ||
| xp = array_namespace(self, x) | ||
| return xp.zeros(x.shape[:len(x.shape) - 1]) | ||
@@ -44,0 +44,0 @@ |
| from __future__ import annotations | ||
| from typing import Any, Self | ||
| from typing import Self | ||
| import jax.random as jr | ||
| from array_api_compat import array_namespace | ||
| from tjax import JaxArray, JaxRealArray, KeyArray, Shape | ||
@@ -41,3 +42,3 @@ from tjax.dataclasses import dataclass | ||
| def log_normalizer(self) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| return -xp.log(-self.negative_rate) | ||
@@ -47,3 +48,3 @@ | ||
| def to_exp(self) -> ExponentialEP: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| return ExponentialEP(-xp.reciprocal(self.negative_rate)) | ||
@@ -53,3 +54,3 @@ | ||
| def carrier_measure(self, x: JaxRealArray) -> JaxRealArray: | ||
| xp = self.array_namespace(x) | ||
| xp = array_namespace(self, x) | ||
| return xp.zeros(x.shape) | ||
@@ -65,3 +66,3 @@ | ||
| def sample(self, key: KeyArray, shape: Shape | None = None) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| shape = self.shape if shape is None else shape + self.shape | ||
@@ -100,3 +101,3 @@ grow = (xp.newaxis,) * (len(shape) - len(self.shape)) | ||
| def to_nat(self) -> ExponentialNP: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| return ExponentialNP(-xp.reciprocal(self.mean)) | ||
@@ -106,3 +107,3 @@ | ||
| def expected_carrier_measure(self) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| return xp.zeros(self.shape) | ||
@@ -112,3 +113,3 @@ | ||
| def sample(self, key: KeyArray, shape: Shape | None = None) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| shape = self.shape if shape is None else shape + self.shape | ||
@@ -124,3 +125,3 @@ grow = (xp.newaxis,) * (len(shape) - len(self.shape)) | ||
| @override | ||
| def from_conjugate_prior_distribution(cls, cp: NaturalParametrization[Any, Any] | ||
| def from_conjugate_prior_distribution(cls, cp: NaturalParametrization | ||
| ) -> tuple[Self, JaxRealArray]: | ||
@@ -127,0 +128,0 @@ assert isinstance(cp, GammaNP) |
@@ -48,3 +48,3 @@ from __future__ import annotations | ||
| def log_normalizer(self) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| shape = self.shape_minus_one + 1.0 | ||
@@ -55,3 +55,3 @@ return jss.gammaln(shape) - shape * xp.log(-self.negative_rate) | ||
| def to_exp(self) -> GammaEP: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| shape = self.shape_minus_one + 1.0 | ||
@@ -63,3 +63,3 @@ return GammaEP(-shape / self.negative_rate, | ||
| def carrier_measure(self, x: JaxRealArray) -> JaxRealArray: | ||
| xp = self.array_namespace(x) | ||
| xp = array_namespace(self, x) | ||
| return xp.zeros(x.shape) | ||
@@ -76,3 +76,3 @@ | ||
| def sample(self, key: KeyArray, shape: Shape | None = None) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| shape = self.shape if shape is None else shape + self.shape | ||
@@ -90,3 +90,3 @@ grow = (xp.newaxis,) * (len(shape) - len(self.shape)) | ||
| def to_approximate_log_normal(self) -> LogNormalNP: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| shape = self.shape_minus_one + 1.0 | ||
@@ -133,3 +133,3 @@ rate = -self.negative_rate | ||
| def expected_carrier_measure(self) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| return xp.zeros(self.shape) | ||
@@ -149,3 +149,3 @@ | ||
| def search_gradient(self, search_parameters: JaxRealArray) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| shape = softplus(search_parameters[..., 0]) | ||
@@ -159,3 +159,3 @@ log_mean_minus_mean_log = xp.log(self.mean) - self.mean_log | ||
| def initial_search_parameters(self) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| log_mean_minus_mean_log = xp.log(self.mean) - self.mean_log | ||
@@ -162,0 +162,0 @@ initial_shape: JaxRealArray = ( |
@@ -44,3 +44,3 @@ """The generalized Dirichlet distribution. | ||
| def log_normalizer(self) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| alpha, beta = self.alpha_beta() | ||
@@ -56,3 +56,3 @@ return xp.sum(jss.betaln(alpha, beta), axis=-1) | ||
| # = jss.digamma(beta) - jss.digamma(alpha + beta) | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| alpha, beta = self.alpha_beta() | ||
@@ -79,3 +79,3 @@ digamma_sum = jss.digamma(alpha + beta) | ||
| def carrier_measure(self, x: JaxRealArray) -> JaxRealArray: | ||
| xp = self.array_namespace(x) | ||
| xp = array_namespace(self, x) | ||
| return xp.zeros(x.shape[:len(x.shape) - 1]) | ||
@@ -88,3 +88,3 @@ | ||
| def alpha_beta(self) -> tuple[JaxRealArray, JaxRealArray]: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| alpha = self.alpha_minus_one + 1.0 | ||
@@ -135,3 +135,3 @@ # cs_alpha[i] = sum_{j>=i} alpha[j] | ||
| def expected_carrier_measure(self) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| return xp.zeros(self.shape) | ||
@@ -138,0 +138,0 @@ |
| from __future__ import annotations | ||
| import jax.random as jr | ||
| from array_api_compat import array_namespace | ||
| from tjax import JaxArray, JaxRealArray, KeyArray, Shape | ||
@@ -77,3 +78,3 @@ from tjax.dataclasses import dataclass | ||
| def expected_carrier_measure(self) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| return xp.zeros(self.mean.shape) | ||
@@ -85,3 +86,3 @@ | ||
| shape += self.shape | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| p = xp.reciprocal(self.mean) | ||
@@ -88,0 +89,0 @@ return jr.geometric(key, p, shape) |
@@ -59,3 +59,3 @@ from __future__ import annotations | ||
| def carrier_measure(self, x: JaxRealArray) -> JaxRealArray: | ||
| xp = self.array_namespace(x) | ||
| xp = array_namespace(self, x) | ||
| return -2.0 * xp.log(x) | ||
@@ -65,3 +65,3 @@ | ||
| def sample(self, key: KeyArray, shape: Shape | None = None) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| return xp.reciprocal(self.base_distribution().sample(key, shape)) | ||
@@ -109,3 +109,3 @@ | ||
| def sample(self, key: KeyArray, shape: Shape | None = None) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| return xp.reciprocal(self.base_distribution().sample(key, shape)) |
@@ -44,3 +44,3 @@ from __future__ import annotations | ||
| def log_normalizer(self) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| return (-0.5 * xp.log(-2.0 * self.negative_lambda_over_two) | ||
@@ -52,3 +52,3 @@ - 2.0 * xp.sqrt(self.negative_lambda_over_two_mu_squared | ||
| def to_exp(self) -> InverseGaussianEP: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| mu = xp.sqrt(self.negative_lambda_over_two / self.negative_lambda_over_two_mu_squared) | ||
@@ -61,3 +61,3 @@ lambda_ = -2.0 * self.negative_lambda_over_two | ||
| def carrier_measure(self, x: JaxRealArray) -> JaxRealArray: | ||
| xp = self.array_namespace(x) | ||
| xp = array_namespace(self, x) | ||
| return -0.5 * (xp.log(2.0 * xp.pi) + 3.0 * xp.log(x)) | ||
@@ -84,3 +84,3 @@ | ||
| def sample(self, key: KeyArray, shape: Shape | None = None) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| shape = self.shape if shape is None else shape + self.shape | ||
@@ -136,3 +136,3 @@ grow = (xp.newaxis,) * (len(shape) - len(self.shape)) | ||
| def to_nat(self) -> InverseGaussianNP: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| lambda_ = xp.reciprocal(self.mean_reciprocal - xp.reciprocal(self.mean)) | ||
@@ -139,0 +139,0 @@ eta2 = -0.5 * lambda_ |
@@ -64,3 +64,3 @@ from __future__ import annotations | ||
| def carrier_measure(self, x: JaxRealArray) -> JaxRealArray: | ||
| xp = self.array_namespace(x) | ||
| xp = array_namespace(self, x) | ||
| return -xp.log(x) | ||
@@ -70,3 +70,3 @@ | ||
| def sample(self, key: KeyArray, shape: Shape | None = None) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| shape = self.shape if shape is None else shape + self.shape | ||
@@ -121,3 +121,3 @@ grow = (xp.newaxis,) * (len(shape) - len(self.shape)) | ||
| def sample(self, key: KeyArray, shape: Shape | None = None) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| shape = self.shape if shape is None else shape + self.shape | ||
@@ -124,0 +124,0 @@ grow = (xp.newaxis,) * (len(shape) - len(self.shape)) |
| from __future__ import annotations | ||
| from typing import Any, Self | ||
| from typing import Self | ||
@@ -65,3 +65,3 @@ import jax.random as jr | ||
| def carrier_measure(self, x: JaxRealArray) -> JaxRealArray: | ||
| xp = self.array_namespace(x) | ||
| xp = array_namespace(self, x) | ||
| return -0.5 * xp.square(xp.log(x)) - xp.log(x) | ||
@@ -111,3 +111,3 @@ | ||
| def expected_carrier_measure(self) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| return -self.mean - 0.5 * (xp.square(self.mean) + 1.0) | ||
@@ -117,3 +117,3 @@ | ||
| def sample(self, key: KeyArray, shape: Shape | None = None) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| shape = self.shape if shape is None else shape + self.shape | ||
@@ -130,3 +130,3 @@ grow = (xp.newaxis,) * (len(shape) - len(self.shape)) | ||
| @override | ||
| def from_conjugate_prior_distribution(cls, cp: NaturalParametrization[Any, Any] | ||
| def from_conjugate_prior_distribution(cls, cp: NaturalParametrization | ||
| ) -> tuple[Self, JaxRealArray]: | ||
@@ -133,0 +133,0 @@ uvn, n = UnitVarianceNormalEP.from_conjugate_prior_distribution(cp) |
| from __future__ import annotations | ||
| from array_api_compat import array_namespace | ||
| from tjax import JaxArray, JaxRealArray, Shape | ||
@@ -40,3 +41,3 @@ from tjax.dataclasses import dataclass | ||
| def log_normalizer(self) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| return xp.log(-xp.log1p(-xp.exp(self.log_probability))) | ||
@@ -46,3 +47,3 @@ | ||
| def to_exp(self) -> LogarithmicEP: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| probability = xp.exp(self.log_probability) | ||
@@ -59,3 +60,3 @@ chi = xp.where(self.log_probability < log_probability_floor, | ||
| def carrier_measure(self, x: JaxRealArray) -> JaxRealArray: | ||
| xp = self.array_namespace(x) | ||
| xp = array_namespace(self, x) | ||
| return -xp.log(x) | ||
@@ -100,3 +101,3 @@ | ||
| def to_nat(self) -> LogarithmicNP: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| z: LogarithmicNP = super().to_nat() | ||
@@ -103,0 +104,0 @@ return LogarithmicNP(xp.where(self.chi < 1.0, |
| from __future__ import annotations | ||
| from typing import Any, Self | ||
| from typing import Self | ||
@@ -9,3 +9,3 @@ import array_api_extra as xpx | ||
| import scipy.special as sc | ||
| from jax.nn import one_hot | ||
| from array_api_compat import array_namespace | ||
| from tjax import JaxArray, JaxRealArray, KeyArray, Shape | ||
@@ -48,3 +48,3 @@ from tjax.dataclasses import dataclass | ||
| def log_normalizer(self) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| max_q = xp.maximum(0.0, xp.max(self.log_odds, axis=-1)) | ||
@@ -57,3 +57,3 @@ q_minus_max_q = self.log_odds - max_q[..., np.newaxis] | ||
| def to_exp(self) -> MultinomialEP: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| max_q = xp.maximum(0.0, xp.max(self.log_odds, axis=-1)) | ||
@@ -66,3 +66,3 @@ q_minus_max_q = self.log_odds - max_q[..., np.newaxis] | ||
| def carrier_measure(self, x: JaxRealArray) -> JaxRealArray: | ||
| xp = self.array_namespace(x) | ||
| xp = array_namespace(self, x) | ||
| return xp.zeros(x.shape[:-1]) | ||
@@ -80,4 +80,5 @@ | ||
| shape += self.shape | ||
| return one_hot(jr.categorical(key, self.log_odds, shape=shape), | ||
| self.dimensions()) | ||
| retval = xpx.one_hot(jr.categorical(key, self.log_odds, shape=shape), self.dimensions()) | ||
| assert isinstance(retval, JaxRealArray) | ||
| return retval | ||
@@ -89,3 +90,3 @@ @override | ||
| def nat_to_probability(self) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| max_q = xp.maximum(0.0, xp.max(self.log_odds, axis=-1)) | ||
@@ -99,3 +100,3 @@ q_minus_max_q = self.log_odds - max_q[..., np.newaxis] | ||
| def nat_to_surprisal(self) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| total_p = self.nat_to_probability() | ||
@@ -135,3 +136,3 @@ return -xp.log(total_p) | ||
| def to_nat(self) -> MultinomialNP: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| p_k = 1.0 - xp.sum(self.probability, axis=-1, keepdims=True) | ||
@@ -142,3 +143,3 @@ return MultinomialNP(xp.log(self.probability / p_k)) | ||
| def expected_carrier_measure(self) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| return xp.zeros(self.shape) | ||
@@ -148,3 +149,3 @@ | ||
| def conjugate_prior_distribution(self, n: JaxRealArray) -> DirichletNP: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| reshaped_n = n[..., np.newaxis] | ||
@@ -156,3 +157,3 @@ final_p = 1.0 - xp.sum(self.probability, axis=-1, keepdims=True) | ||
| @override | ||
| def from_conjugate_prior_distribution(cls, cp: NaturalParametrization[Any, Any], | ||
| def from_conjugate_prior_distribution(cls, cp: NaturalParametrization, | ||
| ) -> tuple[Self, JaxRealArray]: | ||
@@ -164,3 +165,3 @@ assert isinstance(cp, GeneralizedDirichletNP) | ||
| def generalized_conjugate_prior_distribution(self, n: JaxRealArray) -> GeneralizedDirichletNP: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| final_p = 1.0 - xp.sum(self.probability, axis=-1, keepdims=True) | ||
@@ -167,0 +168,0 @@ all_p = xp.concat((self.probability, final_p), axis=-1) |
@@ -7,2 +7,3 @@ from __future__ import annotations | ||
| import jax.random as jr | ||
| from array_api_compat import array_namespace | ||
| from tjax import (JaxArray, JaxRealArray, KeyArray, Shape, matrix_dot_product, matrix_vector_mul, | ||
@@ -47,3 +48,3 @@ outer_product) | ||
| def log_normalizer(self) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| eta = self.mean_times_precision | ||
@@ -60,3 +61,3 @@ h_inv = xp.linalg.inv(self.negative_half_precision) | ||
| def variance(self) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| h_inv: JaxRealArray = xp.linalg.inv(self.negative_half_precision) | ||
@@ -76,3 +77,3 @@ return -0.5 * h_inv | ||
| def to_exp(self) -> MultivariateNormalEP: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| h_inv = xp.linalg.inv(self.negative_half_precision) | ||
@@ -87,3 +88,3 @@ h_inv = cast('JaxRealArray', h_inv) | ||
| def carrier_measure(self, x: JaxRealArray) -> JaxRealArray: | ||
| xp = self.array_namespace(x) | ||
| xp = array_namespace(self, x) | ||
| return xp.zeros(x.shape[:-1]) | ||
@@ -127,3 +128,3 @@ | ||
| def to_nat(self) -> MultivariateNormalNP: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| precision = xp.linalg.inv(self.variance()) | ||
@@ -136,3 +137,3 @@ precision = cast('JaxRealArray', precision) | ||
| def expected_carrier_measure(self) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| return xp.zeros(self.shape) | ||
@@ -139,0 +140,0 @@ |
@@ -48,3 +48,3 @@ from __future__ import annotations | ||
| def log_normalizer(self) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| components = (-xp.square(self.mean_times_precision) / (4.0 * self.negative_half_precision) | ||
@@ -56,3 +56,3 @@ + 0.5 * xp.log(-np.pi / self.negative_half_precision)) | ||
| def to_exp(self) -> MultivariateDiagonalNormalEP: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| mean = -self.mean_times_precision / (2.0 * self.negative_half_precision) | ||
@@ -64,3 +64,3 @@ second_moment = xp.square(mean) - 0.5 / self.negative_half_precision | ||
| def carrier_measure(self, x: JaxRealArray) -> JaxRealArray: | ||
| xp = self.array_namespace(x) | ||
| xp = array_namespace(self, x) | ||
| return xp.zeros(x.shape[:-1]) | ||
@@ -121,3 +121,3 @@ | ||
| def expected_carrier_measure(self) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| return xp.zeros(self.shape) | ||
@@ -134,3 +134,3 @@ | ||
| def variance(self) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| return self.second_moment - xp.square(self.mean) | ||
@@ -165,3 +165,3 @@ | ||
| def sample(self, key: KeyArray, shape: Shape | None = None) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| shape = self.mean.shape if shape is None else shape + self.mean.shape | ||
@@ -180,3 +180,3 @@ grow = (xp.newaxis,) * (len(shape) - len(self.mean.shape)) | ||
| def to_exp(self) -> MultivariateDiagonalNormalEP: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| second_moment = self.variance + xp.square(self.mean) | ||
@@ -186,3 +186,3 @@ return MultivariateDiagonalNormalEP(self.mean, second_moment) | ||
| def to_nat(self) -> MultivariateDiagonalNormalNP: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| precision = xp.reciprocal(self.variance) | ||
@@ -189,0 +189,0 @@ mean_times_precision = self.mean * precision |
| from __future__ import annotations | ||
| import math | ||
| from typing import Any, Self | ||
| from typing import Self | ||
| import jax.random as jr | ||
| from array_api_compat import array_namespace | ||
| from tjax import JaxArray, JaxRealArray, KeyArray, Shape | ||
@@ -51,3 +52,3 @@ from tjax.dataclasses import dataclass | ||
| def log_normalizer(self) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| eta = self.mean_times_precision | ||
@@ -59,3 +60,3 @@ return 0.5 * (xp.sum(xp.square(eta), axis=-1) * self.variance | ||
| def to_exp(self) -> MultivariateFixedVarianceNormalEP: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| return MultivariateFixedVarianceNormalEP( | ||
@@ -67,3 +68,3 @@ self.mean_times_precision * self.variance[..., xp.newaxis], | ||
| def carrier_measure(self, x: JaxRealArray) -> JaxRealArray: | ||
| xp = self.array_namespace(x) | ||
| xp = array_namespace(self, x) | ||
| return -0.5 * xp.sum(xp.square(x), axis=-1) / self.variance | ||
@@ -121,3 +122,3 @@ | ||
| def to_nat(self) -> MultivariateFixedVarianceNormalNP: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| return MultivariateFixedVarianceNormalNP(self.mean / self.variance[..., xp.newaxis], | ||
@@ -128,3 +129,3 @@ variance=self.variance) | ||
| def expected_carrier_measure(self) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| return -0.5 * (xp.sum(xp.square(self.mean), axis=-1) / self.variance + self.dimensions()) | ||
@@ -134,3 +135,3 @@ | ||
| def sample(self, key: KeyArray, shape: Shape | None = None) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| shape = self.mean.shape if shape is None else shape + self.mean.shape | ||
@@ -143,3 +144,3 @@ grow = (xp.newaxis,) * (len(shape) - len(self.mean.shape)) | ||
| def conjugate_prior_distribution(self, n: JaxRealArray) -> IsotropicNormalNP: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| n_over_variance = n / self.variance | ||
@@ -152,3 +153,3 @@ negative_half_precision = -0.5 * n_over_variance | ||
| @override | ||
| def from_conjugate_prior_distribution(cls, cp: NaturalParametrization[Any, Any], | ||
| def from_conjugate_prior_distribution(cls, cp: NaturalParametrization, | ||
| variance: JaxRealArray | None = None | ||
@@ -158,3 +159,3 @@ ) -> tuple[Self, JaxRealArray]: | ||
| assert variance is not None | ||
| xp = cp.array_namespace() | ||
| xp = array_namespace(cp) | ||
| n_over_variance = -2.0 * cp.negative_half_precision | ||
@@ -168,3 +169,3 @@ n = n_over_variance * variance | ||
| ) -> MultivariateDiagonalNormalNP: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| n_over_variance = n / self.variance[..., xp.newaxis] | ||
@@ -171,0 +172,0 @@ negative_half_precision = -0.5 * n_over_variance |
@@ -44,3 +44,3 @@ from __future__ import annotations | ||
| def log_normalizer(self) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| eta = self.mean_times_precision | ||
@@ -52,3 +52,3 @@ return 0.5 * (-0.5 * xp.sum(xp.square(eta), axis=-1) / self.negative_half_precision | ||
| def to_exp(self) -> IsotropicNormalEP: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| precision = -2.0 * self.negative_half_precision | ||
@@ -111,3 +111,3 @@ mean = self.mean_times_precision / precision[..., xp.newaxis] | ||
| def to_nat(self) -> IsotropicNormalNP: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| variance = self.variance() | ||
@@ -120,3 +120,3 @@ negative_half_precision = -0.5 / variance | ||
| def expected_carrier_measure(self) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| return xp.zeros(self.shape) | ||
@@ -126,3 +126,3 @@ | ||
| def sample(self, key: KeyArray, shape: Shape | None = None) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| shape = self.mean.shape if shape is None else shape + self.mean.shape | ||
@@ -138,4 +138,4 @@ grow = (xp.newaxis,) * (len(shape) - len(self.mean.shape)) | ||
| def variance(self) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| dimensions = self.dimensions() | ||
| return (self.total_second_moment - xp.sum(xp.square(self.mean), axis=-1)) / dimensions |
| from __future__ import annotations | ||
| import math | ||
| from typing import Any, Self | ||
| from typing import Self | ||
| import jax.random as jr | ||
| from array_api_compat import array_namespace | ||
| from tjax import JaxArray, JaxRealArray, KeyArray, Shape | ||
@@ -48,3 +49,3 @@ from tjax.dataclasses import dataclass | ||
| def log_normalizer(self) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| return 0.5 * (xp.sum(xp.square(self.mean), axis=-1) | ||
@@ -60,3 +61,3 @@ + self.dimensions() * math.log(math.pi * 2.0)) | ||
| # The second moment of a delta distribution at x. | ||
| xp = self.array_namespace(x) | ||
| xp = array_namespace(self, x) | ||
| return -0.5 * xp.sum(xp.square(x), axis=-1) | ||
@@ -72,3 +73,3 @@ | ||
| def sample(self, key: KeyArray, shape: Shape | None = None) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| shape = self.mean.shape if shape is None else shape + self.mean.shape | ||
@@ -120,3 +121,3 @@ grow = (xp.newaxis,) * (len(shape) - len(self.mean.shape)) | ||
| # The second moment of a normal distribution with the given mean. | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| return -0.5 * (xp.sum(xp.square(self.mean), axis=-1) + self.dimensions()) | ||
@@ -130,3 +131,3 @@ | ||
| def conjugate_prior_distribution(self, n: JaxRealArray) -> IsotropicNormalNP: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| negative_half_precision = -0.5 * n | ||
@@ -137,6 +138,6 @@ return IsotropicNormalNP(n[..., xp.newaxis] * self.mean, negative_half_precision) | ||
| @override | ||
| def from_conjugate_prior_distribution(cls, cp: NaturalParametrization[Any, Any] | ||
| def from_conjugate_prior_distribution(cls, cp: NaturalParametrization | ||
| ) -> tuple[Self, JaxRealArray]: | ||
| assert isinstance(cp, IsotropicNormalNP) | ||
| xp = cp.array_namespace() | ||
| xp = array_namespace(cp) | ||
| n = -2.0 * cp.negative_half_precision | ||
@@ -143,0 +144,0 @@ mean = cp.mean_times_precision / n[..., xp.newaxis] |
@@ -7,2 +7,3 @@ from __future__ import annotations | ||
| import jax.scipy.special as jss | ||
| from array_api_compat import array_namespace | ||
| from tjax import JaxIntegralArray, JaxRealArray, Shape | ||
@@ -37,3 +38,3 @@ from tjax.dataclasses import dataclass | ||
| def log_normalizer(self) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| return -self._failures() * xp.log1p(-xp.exp(self.log_not_p)) | ||
@@ -48,3 +49,3 @@ | ||
| def _mean(self) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| return self._failures() / xp.expm1(-self.log_not_p) | ||
@@ -79,3 +80,3 @@ | ||
| def _log_not_p(self) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| return -xp.log1p(self._failures() / self.mean) | ||
@@ -82,0 +83,0 @@ |
@@ -44,3 +44,3 @@ from __future__ import annotations | ||
| def log_normalizer(self) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| return (-xp.square(self.mean_times_precision) / (4.0 * self.negative_half_precision) | ||
@@ -51,3 +51,3 @@ + 0.5 * xp.log(-np.pi / self.negative_half_precision)) | ||
| def to_exp(self) -> NormalEP: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| mean = -self.mean_times_precision / (2.0 * self.negative_half_precision) | ||
@@ -59,3 +59,3 @@ second_moment = xp.square(mean) - 0.5 / self.negative_half_precision | ||
| def carrier_measure(self, x: JaxRealArray) -> JaxRealArray: | ||
| xp = self.array_namespace(x) | ||
| xp = array_namespace(self, x) | ||
| return xp.zeros(x.shape) | ||
@@ -115,3 +115,3 @@ | ||
| def expected_carrier_measure(self) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| return xp.zeros(self.shape) | ||
@@ -124,3 +124,3 @@ | ||
| def variance(self) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| return self.second_moment - xp.square(self.mean) | ||
@@ -167,3 +167,3 @@ | ||
| def to_exp(self) -> NormalEP: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| second_moment = self.variance + xp.square(self.mean) | ||
@@ -173,3 +173,3 @@ return NormalEP(self.mean, second_moment) | ||
| def to_nat(self) -> NormalNP: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| precision = xp.reciprocal(self.variance) | ||
@@ -181,3 +181,3 @@ mean_times_precision = self.mean * precision | ||
| def to_deviation_parametrization(self) -> NormalDP: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| return NormalDP(self.mean, xp.sqrt(self.variance)) | ||
@@ -209,3 +209,3 @@ | ||
| def sample(self, key: KeyArray, shape: Shape | None = None) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| shape = self.shape if shape is None else shape + self.shape | ||
@@ -228,3 +228,3 @@ grow = (xp.newaxis,) * (len(shape) - len(self.shape)) | ||
| def to_variance_parametrization(self) -> NormalVP: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| return NormalVP(self.mean, xp.square(self.deviation)) |
| from __future__ import annotations | ||
| import math | ||
| from typing import Any, Self | ||
| from typing import Self | ||
| import jax.random as jr | ||
| from array_api_compat import array_namespace | ||
| from tjax import JaxArray, JaxRealArray, KeyArray, Shape | ||
@@ -44,3 +45,3 @@ from tjax.dataclasses import dataclass | ||
| def log_normalizer(self) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| return 0.5 * (xp.square(self.mean) + math.log(math.pi * 2.0)) | ||
@@ -55,3 +56,3 @@ | ||
| # The second moment of a delta distribution at x. | ||
| xp = self.array_namespace(x) | ||
| xp = array_namespace(self, x) | ||
| return -0.5 * xp.square(x) | ||
@@ -106,3 +107,3 @@ | ||
| # The second moment of a normal distribution with the given mean. | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| return -0.5 * (xp.square(self.mean) + 1.0) | ||
@@ -112,3 +113,3 @@ | ||
| def sample(self, key: KeyArray, shape: Shape | None = None) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| shape = self.shape if shape is None else shape + self.shape | ||
@@ -125,3 +126,3 @@ grow = (xp.newaxis,) * (len(shape) - len(self.shape)) | ||
| @override | ||
| def from_conjugate_prior_distribution(cls, cp: NaturalParametrization[Any, Any] | ||
| def from_conjugate_prior_distribution(cls, cp: NaturalParametrization | ||
| ) -> tuple[Self, JaxRealArray]: | ||
@@ -128,0 +129,0 @@ assert isinstance(cp, NormalNP) |
| from __future__ import annotations | ||
| from typing import Any, Self | ||
| from typing import Self | ||
| import jax.random as jr | ||
| import jax.scipy.special as jss | ||
| from array_api_compat import array_namespace | ||
| from tjax import JaxArray, JaxRealArray, KeyArray, Shape | ||
@@ -43,3 +44,3 @@ from tjax.dataclasses import dataclass | ||
| def log_normalizer(self) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| return xp.exp(self.log_mean) | ||
@@ -49,3 +50,3 @@ | ||
| def to_exp(self) -> PoissonEP: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| return PoissonEP(xp.exp(self.log_mean)) | ||
@@ -97,3 +98,3 @@ | ||
| def to_nat(self) -> PoissonNP: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| return PoissonNP(xp.log(self.mean)) | ||
@@ -115,3 +116,3 @@ | ||
| @override | ||
| def from_conjugate_prior_distribution(cls, cp: NaturalParametrization[Any, Any] | ||
| def from_conjugate_prior_distribution(cls, cp: NaturalParametrization | ||
| ) -> tuple[Self, JaxRealArray]: | ||
@@ -118,0 +119,0 @@ assert isinstance(cp, GammaNP) |
@@ -64,3 +64,3 @@ from __future__ import annotations | ||
| def carrier_measure(self, x: JaxRealArray) -> JaxRealArray: | ||
| xp = self.array_namespace(x) | ||
| xp = array_namespace(self, x) | ||
| return xp.log(x) + math.log(2.0) | ||
@@ -72,3 +72,3 @@ | ||
| shape += self.shape | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| sigma = xp.sqrt(-0.5 / self.eta) | ||
@@ -115,3 +115,3 @@ return jr.rayleigh(key, sigma, shape) | ||
| def expected_carrier_measure(self) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| return 0.5 * xp.log(self.chi * 0.5) + (1.5 * math.log(2.0) - 0.5 * np.euler_gamma) | ||
@@ -123,4 +123,4 @@ | ||
| shape += self.shape | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| sigma = xp.sqrt(0.5 * self.chi) | ||
| return jr.rayleigh(key, sigma, shape) |
@@ -6,2 +6,3 @@ from __future__ import annotations | ||
| import jax.random as jr | ||
| from array_api_compat import array_namespace | ||
| from tjax import Array, JaxArray, JaxRealArray, KeyArray, Shape, inverse_softplus, softplus | ||
@@ -60,3 +61,3 @@ from tjax.dataclasses import dataclass | ||
| def carrier_measure(self, x: JaxRealArray) -> JaxRealArray: | ||
| xp = self.array_namespace(x) | ||
| xp = array_namespace(self, x) | ||
| return -xp.log1p(-xp.exp(-x)) | ||
@@ -66,3 +67,3 @@ | ||
| def sample(self, key: KeyArray, shape: Shape | None = None) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| shape = self.shape if shape is None else shape + self.shape | ||
@@ -108,3 +109,3 @@ grow = (xp.newaxis,) * (len(shape) - len(self.shape)) | ||
| def sample(self, key: KeyArray, shape: Shape | None = None) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| shape = self.shape if shape is None else shape + self.shape | ||
@@ -111,0 +112,0 @@ grow = (xp.newaxis,) * (len(shape) - len(self.mean.shape)) |
| from __future__ import annotations | ||
| from typing import Any, Self, cast | ||
| from typing import Self, cast | ||
| import jax.random as jr | ||
| from array_api_compat import array_namespace | ||
| from tjax import Array, JaxArray, JaxRealArray, KeyArray, Shape, inverse_softplus, softplus | ||
@@ -61,3 +62,3 @@ from tjax.dataclasses import dataclass | ||
| def carrier_measure(self, x: JaxRealArray) -> JaxRealArray: | ||
| xp = self.array_namespace(x) | ||
| xp = array_namespace(self, x) | ||
| return -xp.log1p(-xp.exp(-x)) - 0.5 * xp.square(inverse_softplus(x)) | ||
@@ -106,3 +107,3 @@ | ||
| def sample(self, key: KeyArray, shape: Shape | None = None) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| shape = self.shape if shape is None else shape + self.shape | ||
@@ -119,3 +120,3 @@ grow = (xp.newaxis,) * (len(shape) - len(self.shape)) | ||
| @override | ||
| def from_conjugate_prior_distribution(cls, cp: NaturalParametrization[Any, Any] | ||
| def from_conjugate_prior_distribution(cls, cp: NaturalParametrization | ||
| ) -> tuple[Self, JaxRealArray]: | ||
@@ -122,0 +123,0 @@ uvn, n = UnitVarianceNormalEP.from_conjugate_prior_distribution(cp) |
@@ -6,2 +6,3 @@ from __future__ import annotations | ||
| from array_api_compat import array_namespace | ||
| from tjax import JaxArray, JaxRealArray, Shape, inverse_softplus, softplus | ||
@@ -42,3 +43,3 @@ from tjax.dataclasses import dataclass | ||
| def log_normalizer(self) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| half_k = self.dimensions() * 0.5 | ||
@@ -53,3 +54,3 @@ kappa = xp.linalg.vector_norm(self.mean_times_concentration, axis=-1) | ||
| def to_exp(self) -> VonMisesFisherEP: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| q = self.mean_times_concentration | ||
@@ -64,3 +65,3 @@ kappa: JaxRealArray = xp.linalg.vector_norm(q, axis=-1, keepdims=True) | ||
| def carrier_measure(self, x: JaxRealArray) -> JaxRealArray: | ||
| xp = self.array_namespace(x) | ||
| xp = array_namespace(self, x) | ||
| return xp.zeros(self.shape) | ||
@@ -79,3 +80,3 @@ | ||
| def kappa(self) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| return xp.linalg.vector_norm(self.mean_times_concentration, axis=-1) | ||
@@ -86,3 +87,3 @@ | ||
| raise ValueError | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| kappa = self.kappa() | ||
@@ -124,3 +125,3 @@ angle = xp.where(kappa == 0.0, | ||
| def expected_carrier_measure(self) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| return xp.zeros(self.shape) | ||
@@ -130,3 +131,3 @@ | ||
| def initial_search_parameters(self) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| mu: JaxRealArray = xp.linalg.vector_norm(self.mean, axis=-1) | ||
@@ -142,3 +143,3 @@ # 0 <= mu <= 1.0 | ||
| def search_to_natural(self, search_parameters: JaxRealArray) -> VonMisesFisherNP: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| kappa = softplus(search_parameters) | ||
@@ -151,3 +152,3 @@ mu = xp.linalg.vector_norm(self.mean, axis=-1, keepdims=True) | ||
| def search_gradient(self, search_parameters: JaxRealArray) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| kappa = softplus(search_parameters) | ||
@@ -154,0 +155,0 @@ mu = xp.linalg.vector_norm(self.mean, axis=-1) |
@@ -44,3 +44,3 @@ from __future__ import annotations | ||
| def log_normalizer(self) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| return -xp.log(-self.eta) - xp.log(self.concentration) | ||
@@ -50,3 +50,3 @@ | ||
| def to_exp(self) -> WeibullEP: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| return WeibullEP(self.concentration, -xp.reciprocal(self.eta)) | ||
@@ -56,3 +56,3 @@ | ||
| def carrier_measure(self, x: JaxRealArray) -> JaxRealArray: | ||
| xp = self.array_namespace(x) | ||
| xp = array_namespace(self, x) | ||
| return (self.concentration - 1.0) * xp.log(x) | ||
@@ -106,3 +106,3 @@ | ||
| def to_nat(self) -> WeibullNP: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| return WeibullNP(self.concentration, -xp.reciprocal(self.chi)) | ||
@@ -112,3 +112,3 @@ | ||
| def expected_carrier_measure(self) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| k = self.concentration | ||
@@ -120,3 +120,3 @@ one_minus_one_over_k = 1.0 - xp.reciprocal(k) | ||
| def sample(self, key: KeyArray, shape: Shape | None = None) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| shape = self.shape if shape is None else shape + self.shape | ||
@@ -123,0 +123,0 @@ grow = (xp.newaxis,) * (len(shape) - len(self.shape)) |
| from __future__ import annotations | ||
| from abc import abstractmethod | ||
| from typing import Any, Generic, TypeVar, final | ||
| from typing import Any, Generic, final | ||
| from array_api_compat import array_namespace | ||
| from tjax import JaxRealArray, jit | ||
| from typing_extensions import TypeVar | ||
@@ -12,3 +14,3 @@ from .natural_parametrization import NaturalParametrization | ||
| NP = TypeVar('NP', bound=NaturalParametrization[Any, Any]) | ||
| NP = TypeVar('NP', bound=NaturalParametrization, default=Any) | ||
@@ -49,3 +51,3 @@ | ||
| self_nat = self.to_nat() | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| difference = parameter_map(xp.subtract, self_nat, q) | ||
@@ -52,0 +54,0 @@ return (parameter_dot_product(difference, self) |
| from __future__ import annotations | ||
| from abc import abstractmethod | ||
| from typing import Any, Self | ||
| from typing import Self | ||
@@ -13,6 +13,6 @@ from tjax import JaxComplexArray, JaxRealArray | ||
| class HasConjugatePrior(ExpectationParametrization[Any]): | ||
| class HasConjugatePrior(ExpectationParametrization): | ||
| @abstractmethod | ||
| def conjugate_prior_distribution(self, n: JaxRealArray | ||
| ) -> NaturalParametrization[Any, Any]: | ||
| ) -> NaturalParametrization: | ||
| """The conjugate prior distribution. | ||
@@ -27,3 +27,3 @@ | ||
| @abstractmethod | ||
| def from_conjugate_prior_distribution(cls, cp: NaturalParametrization[Any, Any] | ||
| def from_conjugate_prior_distribution(cls, cp: NaturalParametrization | ||
| ) -> tuple[Self, JaxRealArray]: | ||
@@ -49,3 +49,3 @@ """Given a conjugate prior distribution, find the distribution and observation count. | ||
| def generalized_conjugate_prior_distribution(self, n: JaxRealArray | ||
| ) -> NaturalParametrization[Any, Any]: | ||
| ) -> NaturalParametrization: | ||
| """A generalization of the conjugate prior distribution. | ||
@@ -52,0 +52,0 @@ |
@@ -6,2 +6,3 @@ from collections.abc import Iterable, Mapping | ||
| from tjax import JaxComplexArray | ||
| from typing_extensions import TypeIs | ||
@@ -13,2 +14,8 @@ from .parameter import Support | ||
| def is_string_mapping(x: object) -> TypeIs[Mapping[str, object]]: | ||
| if not isinstance(x, Mapping): | ||
| return False | ||
| return all(isinstance(xi, str) for xi in x) | ||
| def flatten_mapping(m: Mapping[str, Any], /) -> dict[Path, Any]: | ||
@@ -18,8 +25,10 @@ """Flatten a nested mapping.""" | ||
| def _flatten(m: Mapping[str, Mapping[str, Any] | Any], prefix: Path) -> None: | ||
| def _flatten(m: Mapping[str, Any], prefix: Path) -> None: | ||
| for key, value in m.items(): | ||
| path = (*prefix, key) | ||
| if isinstance(value, Mapping): | ||
| if is_string_mapping(value): | ||
| _flatten(value, path) | ||
| continue | ||
| if isinstance(value, Mapping): | ||
| raise TypeError | ||
| result[path] = value | ||
@@ -26,0 +35,0 @@ _flatten(m, ()) |
| from __future__ import annotations | ||
| from dataclasses import KW_ONLY, field | ||
| from typing import Any, Generic, Self, TypeAlias, TypeVar | ||
| from typing import Any, Generic, Self, TypeAlias | ||
| from array_api_compat import array_namespace | ||
| from jax import vmap | ||
| from tjax import JaxRealArray, jit | ||
| from tjax.dataclasses import dataclass | ||
| from typing_extensions import override | ||
| from typing_extensions import TypeVar, override | ||
@@ -16,3 +17,3 @@ from ...expectation_parametrization import ExpectationParametrization | ||
| NP = TypeVar('NP', bound=NaturalParametrization[Any, Any]) | ||
| NP = TypeVar('NP', bound=NaturalParametrization, default=Any) | ||
| SP: TypeAlias = JaxRealArray | ||
@@ -22,3 +23,3 @@ | ||
| class ExpToNatMinimizer: | ||
| def solve(self, exp_to_nat: ExpToNat[Any]) -> SP: | ||
| def solve(self, exp_to_nat: ExpToNat) -> SP: | ||
| raise NotImplementedError | ||
@@ -68,3 +69,3 @@ | ||
| _, flattened = Flattener.flatten(self) | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| return xp.zeros_like(flattened) | ||
@@ -71,0 +72,0 @@ |
| from typing import Any, TypeAlias, TypeVar | ||
| import optimistix as optx | ||
| from array_api_compat import array_namespace | ||
| from jax import jit | ||
@@ -27,4 +28,4 @@ from tjax import JaxRealArray | ||
| @override | ||
| def solve(self, exp_to_nat: ExpToNat[Any]) -> JaxRealArray: | ||
| xp = exp_to_nat.array_namespace() | ||
| def solve(self, exp_to_nat: ExpToNat) -> JaxRealArray: | ||
| xp = array_namespace(exp_to_nat) | ||
@@ -31,0 +32,0 @@ @jit |
| from __future__ import annotations | ||
| from abc import abstractmethod | ||
| from typing import Any, Generic, TypeVar, final, override | ||
| from typing import Any, Generic, final, override | ||
| from jax.lax import stop_gradient | ||
| from tjax import JaxAbstractClass, JaxRealArray, abstract_jit, jit | ||
| from tjax import JaxAbstractClass, JaxRealArray, abstract_jit, jit, stop_gradient | ||
| from typing_extensions import TypeVar | ||
@@ -14,3 +14,3 @@ from ..expectation_parametrization import ExpectationParametrization | ||
| NP = TypeVar('NP', bound=NaturalParametrization[Any, Any]) | ||
| NP = TypeVar('NP', bound=NaturalParametrization, default=Any) | ||
@@ -60,6 +60,6 @@ | ||
| EP = TypeVar('EP', bound=HasEntropyEP[Any]) | ||
| EP = TypeVar('EP', bound=HasEntropyEP, default=Any) | ||
| class HasEntropyNP(NaturalParametrization[EP, Any], | ||
| class HasEntropyNP(NaturalParametrization[EP], | ||
| HasEntropy, | ||
@@ -66,0 +66,0 @@ Generic[EP]): |
@@ -5,7 +5,8 @@ from __future__ import annotations | ||
| from functools import partial | ||
| from typing import Any, Generic, TypeVar, cast | ||
| from typing import Any, Generic, cast | ||
| from array_api_compat import array_namespace | ||
| from jax import jacobian, vmap | ||
| from tjax import JaxArray, JaxComplexArray, JaxRealArray, Shape | ||
| from typing_extensions import override | ||
| from typing_extensions import TypeVar, override | ||
@@ -16,5 +17,5 @@ from ..expectation_parametrization import ExpectationParametrization | ||
| TEP = TypeVar('TEP', bound=ExpectationParametrization[Any]) | ||
| NP = TypeVar('NP', bound=NaturalParametrization[Any, Any]) | ||
| Domain = TypeVar('Domain', bound=JaxComplexArray) | ||
| TEP = TypeVar('TEP', bound=ExpectationParametrization, default=Any) | ||
| NP = TypeVar('NP', bound=NaturalParametrization, default=Any) | ||
| Domain = TypeVar('Domain', bound=JaxComplexArray, default=JaxComplexArray) | ||
@@ -61,3 +62,3 @@ | ||
| def carrier_measure(self, x: JaxRealArray) -> JaxRealArray: | ||
| xp = self.array_namespace(x) | ||
| xp = array_namespace(self, x) | ||
| casted_x = cast('Domain', x) | ||
@@ -95,3 +96,3 @@ fixed_parameters = parameters(self, fixed=True, recurse=False) | ||
| TNP = TypeVar('TNP', bound=TransformedNaturalParametrization[Any, Any, Any, Any]) | ||
| TNP = TypeVar('TNP', bound=TransformedNaturalParametrization, default=Any) | ||
@@ -98,0 +99,0 @@ |
| from __future__ import annotations | ||
| from abc import abstractmethod | ||
| from typing import TYPE_CHECKING, Any, Generic, Self, TypeVar, final, get_type_hints | ||
| from typing import TYPE_CHECKING, Any, Generic, Self, final, get_type_hints | ||
| from array_api_compat import array_namespace | ||
| from jax import grad, jacfwd, vjp, vmap | ||
@@ -10,2 +11,3 @@ from tjax import (JaxAbstractClass, JaxArray, JaxComplexArray, JaxRealArray, abstract_custom_jvp, | ||
| from tjax.dataclasses import dataclass | ||
| from typing_extensions import TypeVar | ||
@@ -23,9 +25,9 @@ from .iteration import parameters | ||
| EP = TypeVar('EP', bound='ExpectationParametrization[Any]') | ||
| Domain = TypeVar('Domain', bound=JaxComplexArray | dict[str, Any]) | ||
| EP = TypeVar('EP', bound='ExpectationParametrization', default=Any) | ||
| Domain = TypeVar('Domain', bound=JaxComplexArray | dict[str, Any], default=Any) | ||
| def log_normalizer_jvp(primals: tuple[NaturalParametrization[Any, Any]], | ||
| tangents: tuple[NaturalParametrization[Any, Any]], | ||
| ) -> tuple[JaxRealArray, JaxRealArray]: | ||
| def _log_normalizer_jvp(primals: tuple[NaturalParametrization], | ||
| tangents: tuple[NaturalParametrization], | ||
| ) -> tuple[JaxRealArray, JaxRealArray]: | ||
| """The log-normalizer's special JVP vastly improves numerical stability.""" | ||
@@ -49,3 +51,3 @@ q, = primals | ||
| """ | ||
| @abstract_custom_jvp(log_normalizer_jvp) | ||
| @abstract_custom_jvp(_log_normalizer_jvp) | ||
| @abstract_jit | ||
@@ -100,3 +102,3 @@ @abstractmethod | ||
| """ | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| return xp.exp(self.log_pdf(x)) | ||
@@ -125,3 +127,3 @@ | ||
| """ | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| flattener, _ = Flattener.flatten(self) | ||
@@ -141,3 +143,3 @@ fisher_matrix = self._fisher_information_matrix() | ||
| """ | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| fisher_information_diagonal = self.fisher_information_diagonal() | ||
@@ -161,3 +163,3 @@ structure = Structure.create(self) | ||
| def jeffreys_prior(self) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| fisher_matrix = self._fisher_information_matrix() | ||
@@ -164,0 +166,0 @@ return xp.sqrt(xp.linalg.det(fisher_matrix)) |
@@ -7,2 +7,3 @@ from __future__ import annotations | ||
| import array_api_extra as xpx | ||
| import jax.scipy.special as jss | ||
@@ -53,3 +54,3 @@ import numpy as np | ||
| def general_array_namespace(x: RealNumeric) -> ModuleType: | ||
| def _general_array_namespace(x: RealNumeric) -> ModuleType: | ||
| if isinstance(x, float): | ||
@@ -60,5 +61,5 @@ return np | ||
| def canonical_float_epsilon(xp: ModuleType) -> float: | ||
| dtype = xp.empty((), dtype=float).dtype # For Jax, this is canonicalize_dtype(float). | ||
| return float(xp.finfo(dtype).eps) | ||
| def _canonical_float_epsilon(xp: ModuleType) -> float: | ||
| dtype = xpx.default_dtype(xp) | ||
| return xp.finfo(dtype).eps | ||
@@ -76,4 +77,4 @@ | ||
| if self.min_open and self.minimum is not None: | ||
| xp = general_array_namespace(self.minimum) | ||
| eps = canonical_float_epsilon(xp) | ||
| xp = _general_array_namespace(self.minimum) | ||
| eps = _canonical_float_epsilon(xp) | ||
| self.minimum = xp.asarray(xp.maximum( | ||
@@ -83,4 +84,4 @@ self.minimum + eps, | ||
| if self.max_open and self.maximum is not None: | ||
| xp = general_array_namespace(self.maximum) | ||
| eps = canonical_float_epsilon(xp) | ||
| xp = _general_array_namespace(self.maximum) | ||
| eps = _canonical_float_epsilon(xp) | ||
| self.maximum = xp.asarray(xp.minimum( | ||
@@ -217,3 +218,3 @@ self.maximum - eps, | ||
| xp = array_namespace(x) | ||
| return xp.asarray(x, dtype=float) | ||
| return xp.astype(x, float) | ||
@@ -223,7 +224,7 @@ @override | ||
| xp = array_namespace(y) | ||
| return xp.asarray(y, dtype=xp.bool_) | ||
| return xp.astype(y, xp.bool_) | ||
| @override | ||
| def generate(self, xp: Namespace, rng: Generator, shape: Shape, safety: float) -> JaxRealArray: | ||
| return xp.asarray(rng.binomial(1, 0.5, shape), dtype=xp.bool_) | ||
| return xp.astype(rng.binomial(1, 0.5, shape), xp.bool_) | ||
@@ -243,3 +244,3 @@ | ||
| xp = array_namespace(x) | ||
| return xp.asarray(x, dtype=float) | ||
| return xp.astype(x, float) | ||
@@ -249,3 +250,3 @@ @override | ||
| xp = array_namespace(y) | ||
| return xp.asarray(y, dtype=int) | ||
| return xp.astype(y, int) | ||
@@ -252,0 +253,0 @@ @override |
@@ -12,2 +12,3 @@ from __future__ import annotations | ||
| from numpy.random import Generator | ||
| from opt_einsum import contract | ||
| from tjax import JaxArray, JaxRealArray, Shape | ||
@@ -234,3 +235,3 @@ from typing_extensions import override | ||
| # Perform QR decomposition to obtain an orthogonal matrix Q | ||
| q, _ = np.linalg.qr(m) | ||
| q, _ = xp.linalg.qr(m) | ||
| # Generate Eigenvalues. | ||
@@ -242,3 +243,3 @@ assert isinstance(self.ring, RealField | ComplexField) | ||
| # Return Q.T @ diag(eig) @ Q. | ||
| return xp.einsum('...ji,...j,...jk->...ik', xp.conj(q) if self.hermitian else q, eig, q) | ||
| return contract('...ji,...j,...jk->...ik', xp.conj(q) if self.hermitian else q, eig, q) | ||
| mt = xp.matrix_transpose(m) | ||
@@ -245,0 +246,0 @@ if self.hermitian: |
@@ -41,3 +41,3 @@ from __future__ import annotations | ||
| def array_namespace(self, *x: Any) -> ModuleType: | ||
| def __array_namespace__(self, api_version: str | None = None) -> ModuleType: # noqa: PLW3201 | ||
| from .iteration import parameters # noqa: PLC0415 | ||
@@ -44,0 +44,0 @@ values = parameters(self).values() |
@@ -120,3 +120,3 @@ from __future__ import annotations | ||
| for i in np.ndindex(*self.shape): | ||
| objects[i] = self.objects[i].as_multivariate_normal() | ||
| objects[i] = self.objects[i].as_multivariate_normal() # pyright: ignore | ||
| return ScipyMultivariateNormal(self.shape, self.rvs_shape, self.real_dtype, objects) | ||
@@ -135,3 +135,3 @@ | ||
| for i in np.ndindex(*self.shape): | ||
| retval[i] = self.objects[i].variance | ||
| retval[i] = self.objects[i].variance # pyright: ignore | ||
| return retval | ||
@@ -143,3 +143,3 @@ | ||
| for i in np.ndindex(*self.shape): | ||
| retval[i] = self.objects[i].pseudo_variance | ||
| retval[i] = self.objects[i].pseudo_variance # pyright: ignore | ||
| return retval |
@@ -117,3 +117,3 @@ from __future__ import annotations | ||
| for i in np.ndindex(*self.shape): | ||
| objects[i] = self.objects[i].as_multivariate_normal() | ||
| objects[i] = self.objects[i].as_multivariate_normal() # pyright: ignore | ||
| return ScipyMultivariateNormal(self.shape, self.rvs_shape, self.real_dtype, objects) | ||
@@ -132,3 +132,3 @@ | ||
| for i in np.ndindex(*self.shape): | ||
| retval[i] = self.objects[i].variance | ||
| retval[i] = self.objects[i].variance # pyright: ignore | ||
| return retval | ||
@@ -140,3 +140,3 @@ | ||
| for i in np.ndindex(*self.shape): | ||
| retval[i] = self.objects[i].pseudo_variance | ||
| retval[i] = self.objects[i].pseudo_variance # pyright: ignore | ||
| return retval |
@@ -40,3 +40,3 @@ from __future__ import annotations | ||
| for i in np.ndindex(*self.shape): | ||
| retval[i] = self.objects[i].rvs(size=size, random_state=random_state) | ||
| retval[i] = self.objects[i].rvs(size=size, random_state=random_state) # pyright: ignore | ||
| return retval | ||
@@ -47,3 +47,3 @@ | ||
| for i in np.ndindex(*self.shape): | ||
| value = self.objects[i].pdf(x[i]) | ||
| value = self.objects[i].pdf(x[i]) # pyright: ignore | ||
| if i == (): | ||
@@ -57,3 +57,3 @@ return value | ||
| for i in np.ndindex(*self.shape): | ||
| retval[i] = self.objects[i].entropy() | ||
| retval[i] = self.objects[i].entropy() # pyright: ignore | ||
| return retval | ||
@@ -60,0 +60,0 @@ |
@@ -40,3 +40,3 @@ from __future__ import annotations | ||
| for i in np.ndindex(*shape): | ||
| objects[i] = ss.vonmises_fisher(mu[i], kappa[i]) | ||
| objects[i] = ss.vonmises_fisher(mu[i], kappa[i]) # pyright: ignore | ||
| super().__init__(shape, rvs_shape, dtype, objects) |
@@ -16,2 +16,3 @@ from dataclasses import fields | ||
| from ..natural_parametrization import NaturalParametrization | ||
| NP = TypeVar('NP', bound=NaturalParametrization) | ||
@@ -33,5 +34,5 @@ T = TypeVar('T') | ||
| def create_simple_estimator(cls, | ||
| type_p: type[SimpleDistribution], | ||
| type_p: type[SP], | ||
| **fixed_parameters: JaxArray | ||
| ) -> 'MaximumLikelihoodEstimator[Any]': | ||
| ) -> 'MaximumLikelihoodEstimator[SP]': | ||
| """Create an estimator for a simple expectation parametrization class. | ||
@@ -57,8 +58,7 @@ | ||
| @classmethod | ||
| def create_estimator_from_natural(cls, p: 'NaturalParametrization[Any, Any]' | ||
| ) -> 'MaximumLikelihoodEstimator[Any]': | ||
| def create_estimator_from_natural(cls, p: 'NP') -> 'MaximumLikelihoodEstimator[NP]': | ||
| """Create an estimator for a natural parametrization.""" | ||
| infos = MaximumLikelihoodEstimator.create(p).to_exp().infos | ||
| fixed_parameters = parameters(p, fixed=True) | ||
| return cls(infos, fixed_parameters) | ||
| return MaximumLikelihoodEstimator(infos, fixed_parameters) | ||
@@ -69,3 +69,3 @@ def sufficient_statistics(self, x: dict[str, Any] | JaxComplexArray) -> P: | ||
| from ..transform.joint import JointDistributionE # noqa: PLC0415 | ||
| constructed: dict[Path, ExpectationParametrization[Any]] = {} | ||
| constructed: dict[Path, ExpectationParametrization] = {} | ||
@@ -106,3 +106,3 @@ def g(info: SubDistributionInfo, x: JaxComplexArray) -> None: | ||
| def from_conjugate_prior_distribution(self, | ||
| cp: 'NaturalParametrization[Any, Any]' | ||
| cp: 'NaturalParametrization' | ||
| ) -> tuple[P, JaxRealArray]: | ||
@@ -109,0 +109,0 @@ from ..interfaces.conjugate_prior import HasConjugatePrior # noqa: PLC0415 |
| from dataclasses import replace | ||
| from functools import partial | ||
| from typing import Any, Self, TypeVar, cast, overload | ||
| from typing import Any, Self, cast, overload | ||
@@ -8,2 +8,3 @@ from array_api_compat import array_namespace | ||
| from tjax.dataclasses import dataclass, field | ||
| from typing_extensions import TypeVar | ||
@@ -17,4 +18,4 @@ from ..iteration import parameters | ||
| P = TypeVar('P', bound=Distribution) | ||
| SP = TypeVar('SP', bound=SimpleDistribution) | ||
| P = TypeVar('P', bound=Distribution, default=Any) | ||
| SP = TypeVar('SP', bound=SimpleDistribution, default=Any) | ||
@@ -97,3 +98,3 @@ | ||
| """ | ||
| xp = p.array_namespace() | ||
| xp = array_namespace(p) | ||
| arrays = [x | ||
@@ -100,0 +101,0 @@ for xs in cls._walk(partial(cls._make_flat, map_to_plane=map_to_plane), p) |
@@ -13,3 +13,3 @@ from collections.abc import Generator, Iterable | ||
| ) -> Generator[str]: | ||
| """Return the parameter names in a distribution. | ||
| """The parameter names in a distribution. | ||
@@ -21,3 +21,3 @@ Args: | ||
| Returns: | ||
| The name of each parameter. | ||
| The parameter names. | ||
| """ | ||
@@ -24,0 +24,0 @@ def _parameters(q: type[Distribution], |
@@ -26,3 +26,3 @@ from collections.abc import Generator | ||
| ) -> Generator[tuple[str, Support, ValueReceptacle]]: | ||
| """Return the parameter supports in a distribution. | ||
| """The parameter supports in a distribution. | ||
@@ -29,0 +29,0 @@ Args: |
| from collections.abc import Callable, Iterable, Mapping | ||
| from typing import Any, Generic, TypeVar, cast | ||
| from typing import Any, Generic, cast | ||
| from numpy.random import Generator | ||
| from tjax import JaxComplexArray, Shape | ||
| from tjax import JaxComplexArray, JaxRealArray, Shape | ||
| from tjax.dataclasses import dataclass, field | ||
| from typing_extensions import TypeVar | ||
@@ -26,4 +27,3 @@ from ..iteration import parameters | ||
| T = TypeVar('T') | ||
| P = TypeVar('P', bound=Distribution) | ||
| SP = TypeVar('SP', bound=SimpleDistribution) | ||
| P = TypeVar('P', bound=Distribution, default=Any) | ||
@@ -51,3 +51,3 @@ | ||
| def to_nat(self) -> 'Structure[Any]': | ||
| def to_nat(self) -> 'Structure': | ||
| from ..expectation_parametrization import ExpectationParametrization # noqa: PLC0415 | ||
@@ -61,5 +61,5 @@ infos = [] | ||
| def to_exp(self) -> 'Structure[Any]': | ||
| def to_exp(self) -> 'Structure': | ||
| from ..natural_parametrization import NaturalParametrization # noqa: PLC0415 | ||
| infos = [] | ||
| infos: list[SubDistributionInfo] = [] | ||
| for info in self.infos: | ||
@@ -105,3 +105,3 @@ assert issubclass(info.type_, NaturalParametrization) | ||
| """Generate a random distribution.""" | ||
| path_and_values = {} | ||
| path_and_values: dict[tuple[str, ...], JaxRealArray] = {} | ||
| for info in self.infos: | ||
@@ -108,0 +108,0 @@ for name, support, value_receptacle in parameter_supports(info.type_): |
@@ -21,3 +21,3 @@ from __future__ import annotations | ||
| @jit | ||
| def parameter_dot_product(x: NaturalParametrization[Any, Any], y: Any, /) -> JaxRealArray: | ||
| def parameter_dot_product(x: NaturalParametrization, y: Any, /) -> JaxRealArray: | ||
| """Return the vectorized dot product over all of the variable parameters.""" | ||
@@ -43,3 +43,3 @@ def dotted_fields() -> Iterable[JaxRealArray]: | ||
| """Return the mean of the parameters (fixed and variable).""" | ||
| xp = x.array_namespace() | ||
| xp = array_namespace(x) | ||
| structure = Structure.create(x) | ||
@@ -46,0 +46,0 @@ p = parameters(x) |
@@ -5,2 +5,3 @@ from collections.abc import Callable, Mapping | ||
| from array_api_compat import array_namespace | ||
| from tjax import JaxArray, JaxComplexArray, JaxRealArray, KeyArray, RngStream, Shape | ||
@@ -59,6 +60,6 @@ from tjax.dataclasses import dataclass | ||
| HasEntropyEP['JointDistributionN']): | ||
| _sub_distributions: Mapping[str, ExpectationParametrization[Any]] | ||
| _sub_distributions: Mapping[str, ExpectationParametrization] | ||
| @override | ||
| def sub_distributions(self) -> Mapping[str, ExpectationParametrization[Any]]: | ||
| def sub_distributions(self) -> Mapping[str, ExpectationParametrization]: | ||
| return self._sub_distributions | ||
@@ -78,5 +79,5 @@ | ||
| def expected_carrier_measure(self) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
| def f(x: ExpectationParametrization[Any]) -> JaxRealArray: | ||
| def f(x: ExpectationParametrization) -> JaxRealArray: | ||
| assert isinstance(x, HasEntropyEP) | ||
@@ -92,6 +93,6 @@ return x.expected_carrier_measure() | ||
| NaturalParametrization[JointDistributionE, dict[str, Any]]): | ||
| _sub_distributions: Mapping[str, NaturalParametrization[Any, Any]] | ||
| _sub_distributions: Mapping[str, NaturalParametrization] | ||
| @override | ||
| def sub_distributions(self) -> Mapping[str, NaturalParametrization[Any, Any]]: | ||
| def sub_distributions(self) -> Mapping[str, NaturalParametrization]: | ||
| return self._sub_distributions | ||
@@ -101,3 +102,3 @@ | ||
| def log_normalizer(self) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
@@ -114,3 +115,3 @@ return reduce(xp.add, | ||
| def carrier_measure(self, x: dict[str, Any]) -> JaxRealArray: | ||
| xp = self.array_namespace() | ||
| xp = array_namespace(self) | ||
@@ -117,0 +118,0 @@ joined = join_mappings(sub=self._sub_distributions, x=x) |
+7
-6
| Metadata-Version: 2.4 | ||
| Name: efax | ||
| Version: 1.22.0 | ||
| Version: 1.22.1 | ||
| Summary: Exponential families for JAX | ||
| Project-URL: repository, https://github.com/NeilGirdhar/efax | ||
| Project-URL: source, https://github.com/NeilGirdhar/efax | ||
| Author: Neil Girdhar | ||
@@ -26,5 +26,6 @@ Author-email: mistersheik@gmail.com | ||
| Requires-Dist: array-api-compat>=1.10 | ||
| Requires-Dist: array-api-extra>=0.7 | ||
| Requires-Dist: array-api-extra>=0.8 | ||
| Requires-Dist: jax>=0.6.1 | ||
| Requires-Dist: numpy>=1.25 | ||
| Requires-Dist: opt-einsum>=3.4 | ||
| Requires-Dist: optimistix>=0.0.9 | ||
@@ -35,3 +36,3 @@ Requires-Dist: optype>=0.8.0 | ||
| Requires-Dist: tensorflow-probability>=0.15 | ||
| Requires-Dist: tjax>=1.3.1 | ||
| Requires-Dist: tjax>=1.3.10 | ||
| Requires-Dist: typing-extensions>=4.8 | ||
@@ -44,6 +45,6 @@ Provides-Extra: dev | ||
| Requires-Dist: pylint>=3.3; extra == 'dev' | ||
| Requires-Dist: pyright>=0.0.13; extra == 'dev' | ||
| Requires-Dist: pyright>=1.1.401; extra == 'dev' | ||
| Requires-Dist: pytest-ordering; extra == 'dev' | ||
| Requires-Dist: pytest-xdist[psutil]>=3; extra == 'dev' | ||
| Requires-Dist: pytest>=8; extra == 'dev' | ||
| Requires-Dist: pytest>=8.4; extra == 'dev' | ||
| Requires-Dist: ruff>=0.9.10; extra == 'dev' | ||
@@ -50,0 +51,0 @@ Description-Content-Type: text/x-rst |
+7
-6
@@ -7,3 +7,3 @@ [build-system] | ||
| name = "efax" | ||
| version = "1.22.0" | ||
| version = "1.22.1" | ||
| description = "Exponential families for JAX" | ||
@@ -31,5 +31,6 @@ readme = "README.rst" | ||
| "array_api_compat>=1.10", | ||
| "array_api_extra>=0.7", | ||
| "array_api_extra>=0.8", | ||
| "jax>=0.6.1", | ||
| "numpy>=1.25", | ||
| "opt-einsum>=3.4", | ||
| "optimistix>=0.0.9", | ||
@@ -40,3 +41,3 @@ "optype>=0.8.0", | ||
| "tensorflow_probability>=0.15", | ||
| "tjax>=1.3.1", | ||
| "tjax>=1.3.10", | ||
| "typing_extensions>=4.8", | ||
@@ -52,6 +53,6 @@ ] | ||
| "pylint>=3.3", | ||
| "pyright>=0.0.13", | ||
| "pyright>=1.1.401", | ||
| "pytest-ordering", | ||
| "pytest-xdist[psutil]>=3", | ||
| "pytest>=8", | ||
| "pytest>=8.4", | ||
| "ruff>=0.9.10", | ||
@@ -61,3 +62,3 @@ ] | ||
| [project.urls] | ||
| repository = "https://github.com/NeilGirdhar/efax" | ||
| source = "https://github.com/NeilGirdhar/efax" | ||
@@ -64,0 +65,0 @@ [tool.isort] |
@@ -62,6 +62,6 @@ from __future__ import annotations | ||
| def distribution_name(request: pytest.FixtureRequest) -> str | None: | ||
| return request.config.getoption("--distribution") # pyright: ignore | ||
| return request.config.getoption("--distribution") | ||
| def supports(s: Structure[Any], abc: type[Any]) -> bool: | ||
| def supports(s: Structure, abc: type[Any]) -> bool: | ||
| return all(issubclass(info.type_, abc) or issubclass(info.type_, JointDistribution) | ||
@@ -71,3 +71,3 @@ for info in s.infos) | ||
| def any_integral_supports(structure: Structure[Any]) -> bool: | ||
| def any_integral_supports(structure: Structure) -> bool: | ||
| return any(isinstance(s.ring, BooleanRing | IntegralRing) | ||
@@ -74,0 +74,0 @@ for s in structure.domain_support().values()) |
@@ -309,3 +309,3 @@ from __future__ import annotations | ||
| class JointInfo(DistributionInfo[JointDistributionN, JointDistributionE, dict[str, Any]]): | ||
| def __init__(self, infos: Mapping[str, DistributionInfo[Any, Any, Any]]) -> None: | ||
| def __init__(self, infos: Mapping[str, DistributionInfo]) -> None: | ||
| super().__init__() | ||
@@ -396,3 +396,3 @@ self.infos = dict(infos) | ||
| variance = np.asarray(p.variance()) | ||
| covariance = xpx.create_diagonal(variance) # type: ignore[arg-type] | ||
| covariance = xpx.create_diagonal(variance) # type: ignore[arg-type] # pyright: ignore | ||
| assert isinstance(covariance, np.ndarray) # type: ignore[unreachable] | ||
@@ -643,3 +643,3 @@ return ScipyMultivariateNormal.from_mc( # type: ignore[unreachable] | ||
| def create_infos() -> list[DistributionInfo[Any, Any, Any]]: | ||
| def create_infos() -> list[DistributionInfo]: | ||
| return [ | ||
@@ -646,0 +646,0 @@ BernoulliInfo(), |
| from __future__ import annotations | ||
| from collections.abc import Callable | ||
| from typing import Any, Generic, TypeVar, final | ||
| from typing import Any, Generic, final | ||
@@ -10,9 +10,9 @@ import jax.numpy as jnp | ||
| from tjax import JaxComplexArray, NumpyComplexArray, Shape | ||
| from typing_extensions import override | ||
| from typing_extensions import TypeVar, override | ||
| from efax import ExpectationParametrization, NaturalParametrization, Structure, SubDistributionInfo | ||
| NP = TypeVar('NP', bound=NaturalParametrization[Any, Any]) | ||
| EP = TypeVar('EP', bound=ExpectationParametrization[Any]) | ||
| Domain = TypeVar('Domain', bound=NumpyComplexArray | dict[str, Any]) | ||
| NP = TypeVar('NP', bound=NaturalParametrization, default=Any) | ||
| EP = TypeVar('EP', bound=ExpectationParametrization, default=Any) | ||
| Domain = TypeVar('Domain', bound=NumpyComplexArray | dict[str, Any], default=Any) | ||
@@ -19,0 +19,0 @@ |
@@ -13,8 +13,8 @@ from __future__ import annotations | ||
| # Tools -------------------------------------------------------------------------------------------- | ||
| def random_complex_array(generator: Generator, shape: Shape = ()) -> NumpyComplexArray: | ||
| def _random_complex_array(generator: Generator, shape: Shape = ()) -> NumpyComplexArray: | ||
| return np.asarray(sum(x * generator.normal(size=shape) for x in (0.5, 0.5j))) | ||
| def build_uvcn(generator: Generator, shape: Shape) -> ScipyComplexNormal: | ||
| mean = random_complex_array(generator, shape) | ||
| def _build_uvcn(generator: Generator, shape: Shape) -> ScipyComplexNormal: | ||
| mean = _random_complex_array(generator, shape) | ||
| variance = generator.exponential(size=shape) | ||
@@ -27,11 +27,11 @@ pseudo_variance = (variance | ||
| def build_mvcn(generator: Generator, | ||
| shape: Shape, | ||
| dimensions: int, | ||
| polarization: float = 0.98 | ||
| ) -> ScipyComplexMultivariateNormal: | ||
| def _build_mvcn(generator: Generator, | ||
| shape: Shape, | ||
| dimensions: int, | ||
| polarization: float = 0.98 | ||
| ) -> ScipyComplexMultivariateNormal: | ||
| directions = 3 | ||
| weights = np.asarray(range(directions)) + 1.5 | ||
| mean = random_complex_array(generator, (*shape, dimensions)) | ||
| z = random_complex_array(generator, (*shape, dimensions, directions)) | ||
| mean = _random_complex_array(generator, (*shape, dimensions)) | ||
| z = _random_complex_array(generator, (*shape, dimensions, directions)) | ||
| regularizer = np.tile(np.eye(dimensions), (*shape, 1, 1)) | ||
@@ -56,3 +56,3 @@ variance = ( | ||
| rvs_axes = tuple(range(-len(rvs_shape), 0)) | ||
| dist = build_uvcn(generator, shape) | ||
| dist = _build_uvcn(generator, shape) | ||
| rvs = dist.rvs(random_state=generator, size=rvs_shape) | ||
@@ -79,3 +79,3 @@ assert rvs.shape == shape + rvs_shape | ||
| rvs_axes2 = tuple(range(-len(rvs_shape) - 2, -2)) | ||
| dist = build_mvcn(generator, shape, dimensions) | ||
| dist = _build_mvcn(generator, shape, dimensions) | ||
| rvs = dist.rvs(random_state=generator, size=rvs_shape) | ||
@@ -99,3 +99,3 @@ assert rvs.shape == shape + rvs_shape + (dimensions,) | ||
| def test_univariate_multivariate_consistency(generator: Generator) -> None: | ||
| mv = build_mvcn(generator, (), 1, polarization=0.5) | ||
| mv = _build_mvcn(generator, (), 1, polarization=0.5) | ||
| component = mv.access_object(()) | ||
@@ -106,3 +106,3 @@ mean: NumpyComplexArray = np.asarray(component.mean[0]) | ||
| uv = ScipyComplexNormal(mean, variance, pseudo_variance) | ||
| x = random_complex_array(generator) | ||
| x = _random_complex_array(generator) | ||
| assert_allclose(mv.pdf(x[np.newaxis]), uv.pdf(x)) |
| from __future__ import annotations | ||
| from typing import Any | ||
| import jax.numpy as jnp | ||
@@ -17,3 +15,3 @@ from jax import grad, vmap | ||
| def test_conjugate_prior(generator: Generator, | ||
| cp_distribution_info: DistributionInfo[Any, Any, Any], | ||
| cp_distribution_info: DistributionInfo, | ||
| distribution_name: str | None) -> None: | ||
@@ -50,3 +48,3 @@ """Test that the conjugate prior actually matches the distribution.""" | ||
| def test_from_conjugate_prior(generator: Generator, | ||
| cp_distribution_info: DistributionInfo[Any, Any, Any], | ||
| cp_distribution_info: DistributionInfo, | ||
| distribution_name: str | None) -> None: | ||
@@ -71,3 +69,3 @@ """Test that the conjugate prior is reversible.""" | ||
| def test_generalized_conjugate_prior(generator: Generator, | ||
| gcp_distribution_info: DistributionInfo[Any, Any, Any], | ||
| gcp_distribution_info: DistributionInfo, | ||
| distribution_name: str | None | ||
@@ -74,0 +72,0 @@ ) -> None: |
@@ -1,2 +0,2 @@ | ||
| """These tests are related to entropy.""" | ||
| """These tests verify entropy gradients.""" | ||
| from __future__ import annotations | ||
@@ -23,3 +23,3 @@ | ||
| @jit | ||
| def sum_entropy(flattened: JaxRealArray, flattener: Flattener[Any]) -> JaxRealArray: | ||
| def _sum_entropy(flattened: JaxRealArray, flattener: Flattener) -> JaxRealArray: | ||
| x = flattener.unflatten(flattened) | ||
@@ -30,11 +30,11 @@ return jnp.sum(x.entropy()) | ||
| @jit | ||
| def all_finite(some_tree: Any, /) -> JaxBooleanArray: | ||
| def _all_finite(some_tree: Any, /) -> JaxBooleanArray: | ||
| return dynamic_tree_all(tree.map(lambda x: jnp.all(jnp.isfinite(x)), some_tree)) | ||
| def check_entropy_gradient(distribution: HasEntropy, /) -> None: | ||
| def _check_entropy_gradient(distribution: HasEntropy, /) -> None: | ||
| flattener, flattened = Flattener.flatten(distribution, map_to_plane=False) | ||
| p_sum_entropy = partial(sum_entropy, flattener=flattener) | ||
| p_sum_entropy = partial(_sum_entropy, flattener=flattener) | ||
| calculated_gradient = grad(p_sum_entropy)(flattened) | ||
| if not all_finite(calculated_gradient): | ||
| if not _all_finite(calculated_gradient): | ||
| indices = jnp.argwhere(jnp.isnan(calculated_gradient)) | ||
@@ -51,15 +51,15 @@ bad_distributions = [distribution[tuple(index)[:-1]] for index in indices] | ||
| def test_nat_entropy_gradient(generator: Generator, | ||
| entropy_distribution_info: DistributionInfo[Any, Any, Any], | ||
| entropy_distribution_info: DistributionInfo, | ||
| ) -> None: | ||
| shape = (7, 13) | ||
| nat_parameters = entropy_distribution_info.nat_parameter_generator(generator, shape=shape) | ||
| check_entropy_gradient(nat_parameters) | ||
| _check_entropy_gradient(nat_parameters) | ||
| def test_exp_entropy_gradient(generator: Generator, | ||
| entropy_distribution_info: DistributionInfo[Any, Any, Any], | ||
| entropy_distribution_info: DistributionInfo, | ||
| ) -> None: | ||
| shape = (7, 13) | ||
| exp_parameters = entropy_distribution_info.exp_parameter_generator(generator, shape=shape) | ||
| check_entropy_gradient(exp_parameters) | ||
| _check_entropy_gradient(exp_parameters) | ||
@@ -66,0 +66,0 @@ |
| from __future__ import annotations | ||
| from typing import Any | ||
| import jax.numpy as jnp | ||
@@ -57,3 +55,3 @@ import numpy as np | ||
| def test_fisher_information_is_convex(generator: Generator, | ||
| distribution_info: DistributionInfo[Any, Any, Any]) -> None: | ||
| distribution_info: DistributionInfo) -> None: | ||
| shape = (3, 2) | ||
@@ -60,0 +58,0 @@ nat_parameters = distribution_info.nat_parameter_generator(generator, shape=shape) |
| from __future__ import annotations | ||
| from typing import Any | ||
| import jax.numpy as jnp | ||
@@ -18,3 +16,3 @@ import numpy as np | ||
| def test_flatten(generator: Generator, | ||
| distribution_info: DistributionInfo[Any, Any, Any], | ||
| distribution_info: DistributionInfo, | ||
| *, | ||
@@ -21,0 +19,0 @@ natural: bool, |
| """These tests apply to only samplable distributions.""" | ||
| from __future__ import annotations | ||
| from typing import Any | ||
| import jax.numpy as jnp | ||
@@ -18,3 +16,3 @@ import jax.random as jr | ||
| def _sample_using_flattened(flattened_parameters: JaxRealArray, | ||
| flattener: Flattener[Any], | ||
| flattener: Flattener, | ||
| key: KeyArray, | ||
@@ -77,3 +75,3 @@ ) -> JaxArray: | ||
| key: KeyArray, | ||
| sampling_wc_distribution_info: DistributionInfo[Any, Any, Any], | ||
| sampling_wc_distribution_info: DistributionInfo, | ||
| *, | ||
@@ -80,0 +78,0 @@ distribution_name: str | None, |
+7
-9
| from __future__ import annotations | ||
| from typing import Any | ||
| import numpy as np | ||
@@ -19,6 +17,6 @@ from numpy.linalg import det, inv | ||
| def prelude(generator: Generator, | ||
| distribution_info_kl: DistributionInfo[Any, Any, Any], | ||
| distribution_name: str | None | ||
| ) -> tuple[ExpectationParametrization[Any], NaturalParametrization[Any, Any], | ||
| def _prelude(generator: Generator, | ||
| distribution_info_kl: DistributionInfo, | ||
| distribution_name: str | None | ||
| ) -> tuple[ExpectationParametrization, NaturalParametrization, | ||
| JaxRealArray]: | ||
@@ -35,3 +33,3 @@ shape = (3, 2) | ||
| """Test the KL divergence.""" | ||
| x, y, my_kl = prelude(generator, NormalInfo(), distribution_name) | ||
| x, y, my_kl = _prelude(generator, NormalInfo(), distribution_name) | ||
| assert isinstance(x, NormalEP) | ||
@@ -50,3 +48,3 @@ assert isinstance(y, NormalNP) | ||
| """Test the KL divergence.""" | ||
| x, y, my_kl = prelude(generator, MultivariateNormalInfo(dimensions=4), distribution_name) | ||
| x, y, my_kl = _prelude(generator, MultivariateNormalInfo(dimensions=4), distribution_name) | ||
| assert isinstance(x, MultivariateNormalEP) | ||
@@ -69,3 +67,3 @@ assert isinstance(y, MultivariateNormalNP) | ||
| """Test the KL divergence.""" | ||
| x, y, my_kl = prelude(generator, GammaInfo(), distribution_name) | ||
| x, y, my_kl = _prelude(generator, GammaInfo(), distribution_name) | ||
| assert isinstance(x, GammaEP) | ||
@@ -72,0 +70,0 @@ assert isinstance(y, GammaNP) |
+27
-27
@@ -19,15 +19,15 @@ """These tests apply to only samplable distributions.""" | ||
| Path: TypeAlias = tuple[str, ...] | ||
| _Path: TypeAlias = tuple[str, ...] | ||
| def produce_samples(generator: Generator, | ||
| key: KeyArray, | ||
| sampling_distribution_info: DistributionInfo[NaturalParametrization[Any, Any], | ||
| ExpectationParametrization[Any], | ||
| Any], | ||
| distribution_shape: Shape, | ||
| sample_shape: Shape, | ||
| *, | ||
| natural: bool) -> tuple[ExpectationParametrization[Any], | ||
| dict[str, Any] | JaxComplexArray]: | ||
| def _produce_samples(generator: Generator, | ||
| key: KeyArray, | ||
| sampling_distribution_info: DistributionInfo[NaturalParametrization, | ||
| ExpectationParametrization, | ||
| Any], | ||
| distribution_shape: Shape, | ||
| sample_shape: Shape, | ||
| *, | ||
| natural: bool) -> tuple[ExpectationParametrization, | ||
| dict[str, Any] | JaxComplexArray]: | ||
| sampling_object: Distribution | ||
@@ -52,7 +52,7 @@ if natural: | ||
| def verify_sample_shape(distribution_shape: Shape, | ||
| sample_shape: Shape, | ||
| structure: Structure[ExpectationParametrization[Any]], | ||
| flat_map_of_samples: dict[Path, Any] | ||
| ) -> None: | ||
| def _verify_sample_shape(distribution_shape: Shape, | ||
| sample_shape: Shape, | ||
| structure: Structure[ExpectationParametrization], | ||
| flat_map_of_samples: dict[_Path, Any] | ||
| ) -> None: | ||
| ideal_samples_shape = {info.path: (*sample_shape, *distribution_shape, | ||
@@ -66,9 +66,9 @@ *info.type_.domain_support().shape(info.dimensions)) | ||
| def verify_maximum_likelihood_estimate( | ||
| sampling_distribution_info: DistributionInfo[NaturalParametrization[Any, Any], | ||
| ExpectationParametrization[Any], | ||
| def _verify_maximum_likelihood_estimate( | ||
| sampling_distribution_info: DistributionInfo[NaturalParametrization, | ||
| ExpectationParametrization, | ||
| Any], | ||
| sample_shape: Shape, | ||
| structure: Structure[ExpectationParametrization[Any]], | ||
| exp_parameters: ExpectationParametrization[Any], | ||
| structure: Structure[ExpectationParametrization], | ||
| exp_parameters: ExpectationParametrization, | ||
| samples: dict[str, Any] | JaxComplexArray | ||
@@ -96,3 +96,3 @@ ) -> None: | ||
| key: KeyArray, | ||
| sampling_distribution_info: DistributionInfo[Any, Any, Any], | ||
| sampling_distribution_info: DistributionInfo, | ||
| *, | ||
@@ -109,9 +109,9 @@ distribution_name: str | None, | ||
| sample_shape = (1024, 64) # The number of samples that are taken to do the estimation. | ||
| exp_parameters, samples = produce_samples(generator, key, sampling_distribution_info, | ||
| distribution_shape, sample_shape, natural=natural) | ||
| exp_parameters, samples = _produce_samples(generator, key, sampling_distribution_info, | ||
| distribution_shape, sample_shape, natural=natural) | ||
| flat_map_of_samples = flat_dict_of_observations(samples) | ||
| structure = Structure.create(exp_parameters) | ||
| flat_map_of_samples = flatten_mapping(samples) if isinstance(samples, dict) else {(): samples} | ||
| verify_sample_shape(distribution_shape, sample_shape, structure, flat_map_of_samples) | ||
| verify_maximum_likelihood_estimate(sampling_distribution_info, sample_shape, structure, | ||
| exp_parameters, samples) | ||
| _verify_sample_shape(distribution_shape, sample_shape, structure, flat_map_of_samples) | ||
| _verify_maximum_likelihood_estimate(sampling_distribution_info, sample_shape, structure, | ||
| exp_parameters, samples) |
+1
-10
| from __future__ import annotations | ||
| from typing import Any | ||
| import numpy as np | ||
| from numpy.random import Generator | ||
@@ -13,3 +10,3 @@ | ||
| def test_shapes(generator: Generator, distribution_info: DistributionInfo[Any, Any, Any]) -> None: | ||
| def test_shapes(generator: Generator, distribution_info: DistributionInfo) -> None: | ||
| """Test that the methods produce the correct shapes.""" | ||
@@ -43,7 +40,1 @@ shape = (3, 4) | ||
| assert q.pdf(x).shape == shape | ||
| def test_types(distribution_info: DistributionInfo[Any, Any, Any]) -> None: | ||
| if isinstance(distribution_info.exp_parameter_generator(np.random.default_rng(), ()), tuple): | ||
| msg = "This should return a number or an ndarray" | ||
| raise TypeError(msg) |
| from __future__ import annotations | ||
| from collections.abc import Callable | ||
| from typing import Any, TypeAlias | ||
| import pytest | ||
| from jax import grad, jvp, vjp | ||
| from jax.custom_derivatives import zero_from_primal | ||
| from numpy.random import Generator | ||
| from numpy.testing import assert_allclose | ||
| from tjax import JaxRealArray, assert_tree_allclose, jit, print_generic, tree_allclose | ||
| from efax import NaturalParametrization, Structure, parameters | ||
| from .create_info import GeneralizedDirichletInfo | ||
| from .distribution_info import DistributionInfo | ||
| LogNormalizer: TypeAlias = Callable[[NaturalParametrization[Any, Any]], JaxRealArray] | ||
| def test_conversion(generator: Generator, | ||
| distribution_info: DistributionInfo[Any, Any, Any] | ||
| ) -> None: | ||
| """Test that the conversion between the different parametrizations are consistent.""" | ||
| if isinstance(distribution_info, GeneralizedDirichletInfo): | ||
| pytest.skip() | ||
| n = 30 | ||
| shape = (n,) | ||
| original_np = distribution_info.nat_parameter_generator(generator, shape=shape) | ||
| intermediate_ep = original_np.to_exp() | ||
| final_np = intermediate_ep.to_nat() | ||
| # Check round trip. | ||
| if not tree_allclose(final_np, original_np): | ||
| for i in range(n): | ||
| if not tree_allclose(final_np[i], original_np[i]): | ||
| print_generic({"original": original_np[i], | ||
| "intermediate": intermediate_ep[i], | ||
| "final": final_np[i]}) | ||
| pytest.fail("Conversion failure") | ||
| # Check fixed parameters. | ||
| original_fixed = parameters(original_np, fixed=True) | ||
| intermediate_fixed = parameters(intermediate_ep, fixed=True) | ||
| final_fixed = parameters(final_np, fixed=True) | ||
| assert_tree_allclose(original_fixed, intermediate_fixed) | ||
| assert_tree_allclose(original_fixed, final_fixed) | ||
| def prelude(generator: Generator, | ||
| distribution_info: DistributionInfo[Any, Any, Any] | ||
| ) -> tuple[LogNormalizer, LogNormalizer]: | ||
| cls = distribution_info.nat_class() | ||
| original_ln = cls._original_log_normalizer | ||
| optimized_ln = cls.log_normalizer | ||
| return original_ln, optimized_ln | ||
| def test_gradient_log_normalizer_primals(generator: Generator, | ||
| distribution_info: DistributionInfo[Any, Any, Any] | ||
| ) -> None: | ||
| """Tests that the gradient log-normalizer equals the gradient of the log-normalizer.""" | ||
| original_ln, optimized_ln = prelude(generator, distribution_info) | ||
| original_gln = jit(grad(original_ln, allow_int=True)) | ||
| optimized_gln = jit(grad(optimized_ln, allow_int=True)) | ||
| for _ in range(20): | ||
| generated_np = distribution_info.nat_parameter_generator(generator, shape=()) | ||
| generated_ep = generated_np.to_exp() # Regular transformation. | ||
| generated_parameters = parameters(generated_ep, fixed=False) | ||
| structure_ep = Structure.create(generated_ep) | ||
| # Original GLN. | ||
| original_gln_np = original_gln(generated_np) | ||
| original_gln_ep = structure_ep.reinterpret(original_gln_np) | ||
| original_gln_parameters = parameters(original_gln_ep, fixed=False) | ||
| # Optimized GLN. | ||
| optimized_gln_np = optimized_gln(generated_np) | ||
| optimized_gln_ep = structure_ep.reinterpret(optimized_gln_np) | ||
| optimized_gln_parameters = parameters(optimized_gln_ep, fixed=False) | ||
| # Test primal evaluation. | ||
| # parameters(generated_ep, fixed=False) | ||
| assert_tree_allclose(generated_parameters, original_gln_parameters, rtol=1e-5) | ||
| assert_tree_allclose(generated_parameters, optimized_gln_parameters, rtol=1e-5) | ||
| def unit_tangent(nat_parameters: NaturalParametrization[Any, Any] | ||
| ) -> NaturalParametrization[Any, Any]: | ||
| xp = nat_parameters.array_namespace() | ||
| new_variable_parameters = {path: xp.ones_like(value) | ||
| for path, value in parameters(nat_parameters, fixed=False).items()} | ||
| new_fixed_parameters = {path: zero_from_primal(value, symbolic_zeros=False) | ||
| for path, value in parameters(nat_parameters, fixed=True).items()} | ||
| structure = Structure.create(nat_parameters) | ||
| return structure.assemble({**new_variable_parameters, **new_fixed_parameters}) | ||
| def test_gradient_log_normalizer_jvp(generator: Generator, | ||
| distribution_info: DistributionInfo[Any, Any, Any] | ||
| ) -> None: | ||
| """Tests that the gradient log-normalizer equals the gradient of the log-normalizer.""" | ||
| original_ln, optimized_ln = prelude(generator, distribution_info) | ||
| for _ in range(20): | ||
| nat_parameters = distribution_info.nat_parameter_generator(generator, shape=()) | ||
| # Test JVP. | ||
| nat_tangent = unit_tangent(nat_parameters) | ||
| original_ln_of_nat, original_jvp = jvp(original_ln, (nat_parameters,), (nat_tangent,)) | ||
| optimized_ln_of_nat, optimized_jvp = jvp(optimized_ln, (nat_parameters,), (nat_tangent,)) | ||
| assert_allclose(original_ln_of_nat, optimized_ln_of_nat, rtol=1.5e-5) | ||
| assert_allclose(original_jvp, optimized_jvp, rtol=1.5e-5) | ||
| def test_gradient_log_normalizer_vjp(generator: Generator, | ||
| distribution_info: DistributionInfo[Any, Any, Any] | ||
| ) -> None: | ||
| """Tests that the gradient log-normalizer equals the gradient of the log-normalizer.""" | ||
| original_ln, optimized_ln = prelude(generator, distribution_info) | ||
| for _ in range(20): | ||
| nat_parameters = distribution_info.nat_parameter_generator(generator, shape=()) | ||
| nat_tangent = unit_tangent(nat_parameters) | ||
| original_ln_of_nat, _ = jvp(original_ln, (nat_parameters,), (nat_tangent,)) | ||
| original_ln_of_nat_b, original_vjp = vjp(original_ln, nat_parameters) | ||
| original_gln_of_nat, = original_vjp(1.0) | ||
| optimized_ln_of_nat_b, optimized_vjp = vjp(optimized_ln, nat_parameters) | ||
| optimized_gln_of_nat, = optimized_vjp(1.0) | ||
| assert_allclose(original_ln_of_nat_b, optimized_ln_of_nat_b, rtol=1e-5) | ||
| assert_allclose(original_ln_of_nat, original_ln_of_nat_b, rtol=1e-5) | ||
| assert_tree_allclose(parameters(original_gln_of_nat, fixed=False), | ||
| parameters(optimized_gln_of_nat, fixed=False), | ||
| rtol=1e-5) |
| """These tests ensure that our distributions match scipy's.""" | ||
| from __future__ import annotations | ||
| from functools import partial | ||
| from typing import Any | ||
| import numpy as np | ||
| from jax import Array | ||
| from numpy.random import Generator | ||
| from numpy.testing import assert_allclose | ||
| from tjax import JaxComplexArray, assert_tree_allclose | ||
| from efax import (HasEntropyEP, HasEntropyNP, JointDistributionN, MaximumLikelihoodEstimator, | ||
| Multidimensional, NaturalParametrization, SimpleDistribution, | ||
| flat_dict_of_observations, flat_dict_of_parameters, parameter_map, | ||
| unflatten_mapping) | ||
| from .create_info import (ChiInfo, ChiSquareInfo, ComplexCircularlySymmetricNormalInfo, | ||
| IsotropicNormalInfo, MultivariateDiagonalNormalInfo, | ||
| MultivariateFixedVarianceNormalInfo, MultivariateNormalInfo, | ||
| VonMisesFisherInfo) | ||
| from .distribution_info import DistributionInfo | ||
| def test_nat_entropy(generator: Generator, | ||
| entropy_distribution_info: DistributionInfo[Any, Any, Any] | ||
| ) -> None: | ||
| """Test that the entropy calculation matches scipy's.""" | ||
| shape = (7, 13) | ||
| nat_parameters = entropy_distribution_info.nat_parameter_generator(generator, shape=shape) | ||
| assert isinstance(nat_parameters, HasEntropyNP) | ||
| scipy_distribution = entropy_distribution_info.nat_to_scipy_distribution(nat_parameters) | ||
| rtol = 2e-5 | ||
| my_entropy = nat_parameters.entropy() | ||
| scipy_entropy = scipy_distribution.entropy() | ||
| assert_allclose(my_entropy, scipy_entropy, rtol=rtol) | ||
| def test_exp_entropy(generator: Generator, | ||
| entropy_distribution_info: DistributionInfo[Any, Any, Any] | ||
| ) -> None: | ||
| """Test that the entropy calculation matches scipy's.""" | ||
| shape = (7, 13) | ||
| exp_parameters = entropy_distribution_info.exp_parameter_generator(generator, shape=shape) | ||
| assert isinstance(exp_parameters, HasEntropyEP) | ||
| scipy_distribution = entropy_distribution_info.exp_to_scipy_distribution(exp_parameters) | ||
| rtol = (1e-5 | ||
| if isinstance(entropy_distribution_info, ChiInfo | ChiSquareInfo) | ||
| else 1e-6) | ||
| my_entropy = exp_parameters.entropy() | ||
| scipy_entropy = scipy_distribution.entropy() | ||
| assert_allclose(my_entropy, scipy_entropy, rtol=rtol) | ||
| def check_observation_shape(nat_parameters: NaturalParametrization[Any, Any], | ||
| efax_x: JaxComplexArray | dict[str, Any], | ||
| distribution_shape: tuple[int, ...], | ||
| ) -> None: | ||
| """Verify that the sufficient statistics have the right shape.""" | ||
| if isinstance(nat_parameters, JointDistributionN): | ||
| assert isinstance(efax_x, dict) | ||
| for name, value in nat_parameters.sub_distributions().items(): | ||
| check_observation_shape(value, efax_x[name], distribution_shape) | ||
| return | ||
| assert isinstance(nat_parameters, SimpleDistribution) # type: ignore[unreachable] | ||
| assert isinstance(efax_x, Array) # type: ignore[unreachable] | ||
| dimensions = (nat_parameters.dimensions() | ||
| if isinstance(nat_parameters, Multidimensional) | ||
| else 0) | ||
| ideal_shape = distribution_shape + nat_parameters.domain_support().shape(dimensions) | ||
| assert efax_x.shape == ideal_shape | ||
| def test_pdf(generator: Generator, distribution_info: DistributionInfo[Any, Any, Any]) -> None: | ||
| """Test that the density/mass function calculation matches scipy's.""" | ||
| distribution_shape = (10,) | ||
| nat_parameters = distribution_info.nat_parameter_generator(generator, shape=distribution_shape) | ||
| scipy_distribution = distribution_info.nat_to_scipy_distribution(nat_parameters) | ||
| scipy_x = scipy_distribution.rvs(random_state=generator) | ||
| efax_x = distribution_info.scipy_to_exp_family_observation(scipy_x) | ||
| check_observation_shape(nat_parameters, efax_x, distribution_shape) | ||
| # Verify that the density matches scipy. | ||
| efax_density = np.asarray(nat_parameters.pdf(efax_x), dtype=np.float64) | ||
| try: | ||
| scipy_density = scipy_distribution.pdf(scipy_x) | ||
| except AttributeError: | ||
| scipy_density = scipy_distribution.pmf(scipy_x) | ||
| if isinstance(distribution_info, MultivariateDiagonalNormalInfo): | ||
| atol = 1e-5 | ||
| rtol = 3e-4 | ||
| else: | ||
| atol = 1e-5 | ||
| rtol = 1e-4 | ||
| assert_allclose(efax_density, scipy_density, rtol=rtol, atol=atol) | ||
| def test_maximum_likelihood_estimation( | ||
| generator: Generator, | ||
| distribution_info: DistributionInfo[NaturalParametrization[Any, Any], Any, Any] | ||
| ) -> None: | ||
| """Test maximum likelihood estimation using SciPy. | ||
| Test that maximum likelihood estimation from scipy-generated variates produce the same | ||
| distribution from which they were drawn. | ||
| """ | ||
| rtol = 2e-2 | ||
| if isinstance(distribution_info, | ||
| ComplexCircularlySymmetricNormalInfo | MultivariateNormalInfo | ||
| | VonMisesFisherInfo | MultivariateFixedVarianceNormalInfo): | ||
| atol = 1e-2 | ||
| elif isinstance(distribution_info, IsotropicNormalInfo): | ||
| atol = 1e-3 | ||
| else: | ||
| atol = 1e-6 | ||
| n = 70000 | ||
| # Generate a distribution with expectation parameters. | ||
| exp_parameters = distribution_info.exp_parameter_generator(generator, shape=()) | ||
| # Generate variates from the corresponding scipy distribution. | ||
| scipy_distribution = distribution_info.exp_to_scipy_distribution(exp_parameters) | ||
| scipy_x = scipy_distribution.rvs(random_state=generator, size=n) | ||
| # Convert the variates to sufficient statistics. | ||
| efax_x = distribution_info.scipy_to_exp_family_observation(scipy_x) | ||
| flat_efax_x = flat_dict_of_observations(efax_x) | ||
| flat_parameters = flat_dict_of_parameters(exp_parameters) | ||
| flat_efax_x_clamped = {path: flat_parameters[path].domain_support().clamp(value) | ||
| for path, value in flat_efax_x.items()} | ||
| efax_x_clamped: Array | dict[str, Any] = (flat_efax_x_clamped[()] | ||
| if flat_efax_x_clamped.keys() == {()} | ||
| else unflatten_mapping(flat_efax_x_clamped)) | ||
| estimator = MaximumLikelihoodEstimator.create_estimator(exp_parameters) | ||
| sufficient_stats = estimator.sufficient_statistics(efax_x_clamped) | ||
| # Verify that the mean of the sufficient statistics equals the expectation parameters. | ||
| calculated_parameters = parameter_map(partial(np.mean, axis=0), # type: ignore[arg-type] | ||
| sufficient_stats) | ||
| assert_tree_allclose(exp_parameters, calculated_parameters, rtol=rtol, atol=atol) |
Sorry, the diff of this file is too big to display
Alert delta unavailable
Currently unable to show alert delta for PyPI packages.
997491
0.4%122
3.39%8066
0.45%