efax
Advanced tools
| from __future__ import annotations | ||
| from typing import Protocol, runtime_checkable | ||
| from numpy.random import Generator | ||
| from tjax import NumpyComplexArray, NumpyIntegralArray, NumpyRealArray, ShapeLike | ||
| @runtime_checkable | ||
| class ScipyDistribution(Protocol): | ||
| def pdf(self, x: NumpyComplexArray) -> NumpyRealArray: | ||
| ... | ||
| def rvs(self, size: ShapeLike = (), random_state: Generator | None = None) -> NumpyComplexArray: | ||
| ... | ||
| def entropy(self) -> NumpyRealArray: | ||
| ... | ||
| @runtime_checkable | ||
| class ScipyDiscreteDistribution(Protocol): | ||
| def pmf(self, x: NumpyIntegralArray) -> NumpyRealArray: | ||
| ... | ||
| def rvs(self, size: ShapeLike = (), random_state: Generator | None = None | ||
| ) -> NumpyIntegralArray: | ||
| ... | ||
| def entropy(self) -> NumpyRealArray: | ||
| ... |
+14
| pre-commit: | ||
| parallel: true | ||
| jobs: | ||
| - name: ruff-fix | ||
| glob: "*.py" | ||
| run: uv run ruff check --fix {staged_files} | ||
| stage_fixed: true | ||
| - name: pyright | ||
| glob: "*.py" | ||
| run: uv run pyright {staged_files} | ||
| - name: toml-sort | ||
| glob: "*.toml" | ||
| run: toml-sort -i {staged_files} | ||
| stage_fixed: true |
+124
-121
@@ -60,2 +60,3 @@ """The EFAX Library.""" | ||
| from ._src.parametrization import Distribution, SimpleDistribution | ||
| from ._src.scipy_replacement.base import ScipyDiscreteDistribution, ScipyDistribution | ||
| from ._src.scipy_replacement.complex_multivariate_normal import ScipyComplexMultivariateNormal | ||
@@ -77,123 +78,125 @@ from ._src.scipy_replacement.complex_normal import ScipyComplexNormal | ||
| __all__ = [ | ||
| 'BernoulliEP', | ||
| 'BernoulliNP', | ||
| 'BetaEP', | ||
| 'BetaNP', | ||
| 'BooleanRing', | ||
| 'ChiEP', | ||
| 'ChiNP', | ||
| 'ChiSquareEP', | ||
| 'ChiSquareNP', | ||
| 'CircularBoundedSupport', | ||
| 'ComplexCircularlySymmetricNormalEP', | ||
| 'ComplexCircularlySymmetricNormalNP', | ||
| 'ComplexField', | ||
| 'ComplexMultivariateUnitVarianceNormalEP', | ||
| 'ComplexMultivariateUnitVarianceNormalNP', | ||
| 'ComplexNormalEP', | ||
| 'ComplexNormalNP', | ||
| 'ComplexUnitVarianceNormalEP', | ||
| 'ComplexUnitVarianceNormalNP', | ||
| 'DirichletEP', | ||
| 'DirichletNP', | ||
| 'Distribution', | ||
| 'ExpectationParametrization', | ||
| 'ExponentialEP', | ||
| 'ExponentialNP', | ||
| 'Flattener', | ||
| 'GammaEP', | ||
| 'GammaNP', | ||
| 'GammaVP', | ||
| 'GeneralizedDirichletEP', | ||
| 'GeneralizedDirichletNP', | ||
| 'GeometricEP', | ||
| 'GeometricNP', | ||
| 'HasConjugatePrior', | ||
| 'HasEntropy', | ||
| 'HasEntropyEP', | ||
| 'HasEntropyNP', | ||
| 'HasGeneralizedConjugatePrior', | ||
| 'IntegralRing', | ||
| 'InverseGammaEP', | ||
| 'InverseGammaNP', | ||
| 'InverseGaussianEP', | ||
| 'InverseGaussianNP', | ||
| 'IsotropicNormalEP', | ||
| 'IsotropicNormalNP', | ||
| 'JointDistribution', | ||
| 'JointDistributionE', | ||
| 'JointDistributionN', | ||
| 'LogNormalEP', | ||
| 'LogNormalNP', | ||
| 'LogarithmicEP', | ||
| 'LogarithmicNP', | ||
| 'MaximumLikelihoodEstimator', | ||
| 'Multidimensional', | ||
| 'MultinomialEP', | ||
| 'MultinomialNP', | ||
| 'MultivariateDiagonalNormalEP', | ||
| 'MultivariateDiagonalNormalNP', | ||
| 'MultivariateDiagonalNormalVP', | ||
| 'MultivariateFixedVarianceNormalEP', | ||
| 'MultivariateFixedVarianceNormalNP', | ||
| 'MultivariateNormalEP', | ||
| 'MultivariateNormalNP', | ||
| 'MultivariateNormalVP', | ||
| 'MultivariateUnitVarianceNormalEP', | ||
| 'MultivariateUnitVarianceNormalNP', | ||
| 'NaturalParametrization', | ||
| 'NegativeBinomialEP', | ||
| 'NegativeBinomialNP', | ||
| 'NormalDP', | ||
| 'NormalEP', | ||
| 'NormalNP', | ||
| 'NormalVP', | ||
| 'PoissonEP', | ||
| 'PoissonNP', | ||
| 'RayleighEP', | ||
| 'RayleighNP', | ||
| 'RealField', | ||
| 'Ring', | ||
| 'Samplable', | ||
| 'ScalarSupport', | ||
| 'ScipyComplexMultivariateNormal', | ||
| 'ScipyComplexNormal', | ||
| 'ScipyDirichlet', | ||
| 'ScipyGeneralizedDirichlet', | ||
| 'ScipyGeometric', | ||
| 'ScipyJointDistribution', | ||
| 'ScipyLogNormal', | ||
| 'ScipyMultivariateNormal', | ||
| 'ScipySoftplusNormal', | ||
| 'ScipyVonMises', | ||
| 'ScipyVonMisesFisher', | ||
| 'SimpleDistribution', | ||
| 'SimplexSupport', | ||
| 'SoftplusNormalEP', | ||
| 'SoftplusNormalNP', | ||
| 'SquareMatrixSupport', | ||
| 'Structure', | ||
| 'SubDistributionInfo', | ||
| 'Support', | ||
| 'SymmetricMatrixSupport', | ||
| 'UnitVarianceLogNormalEP', | ||
| 'UnitVarianceLogNormalNP', | ||
| 'UnitVarianceNormalEP', | ||
| 'UnitVarianceNormalNP', | ||
| 'UnitVarianceSoftplusNormalEP', | ||
| 'UnitVarianceSoftplusNormalNP', | ||
| 'VectorSupport', | ||
| 'VonMisesFisherEP', | ||
| 'VonMisesFisherNP', | ||
| 'WeibullEP', | ||
| 'WeibullNP', | ||
| 'distribution_parameter', | ||
| 'flat_dict_of_observations', | ||
| 'flat_dict_of_parameters', | ||
| 'flatten_mapping', | ||
| 'parameter_dot_product', | ||
| 'parameter_map', | ||
| 'parameter_mean', | ||
| 'parameters', | ||
| 'unflatten_mapping', | ||
| 'BernoulliEP', | ||
| 'BernoulliNP', | ||
| 'BetaEP', | ||
| 'BetaNP', | ||
| 'BooleanRing', | ||
| 'ChiEP', | ||
| 'ChiNP', | ||
| 'ChiSquareEP', | ||
| 'ChiSquareNP', | ||
| 'CircularBoundedSupport', | ||
| 'ComplexCircularlySymmetricNormalEP', | ||
| 'ComplexCircularlySymmetricNormalNP', | ||
| 'ComplexField', | ||
| 'ComplexMultivariateUnitVarianceNormalEP', | ||
| 'ComplexMultivariateUnitVarianceNormalNP', | ||
| 'ComplexNormalEP', | ||
| 'ComplexNormalNP', | ||
| 'ComplexUnitVarianceNormalEP', | ||
| 'ComplexUnitVarianceNormalNP', | ||
| 'DirichletEP', | ||
| 'DirichletNP', | ||
| 'Distribution', | ||
| 'ExpectationParametrization', | ||
| 'ExponentialEP', | ||
| 'ExponentialNP', | ||
| 'Flattener', | ||
| 'GammaEP', | ||
| 'GammaNP', | ||
| 'GammaVP', | ||
| 'GeneralizedDirichletEP', | ||
| 'GeneralizedDirichletNP', | ||
| 'GeometricEP', | ||
| 'GeometricNP', | ||
| 'HasConjugatePrior', | ||
| 'HasEntropy', | ||
| 'HasEntropyEP', | ||
| 'HasEntropyNP', | ||
| 'HasGeneralizedConjugatePrior', | ||
| 'IntegralRing', | ||
| 'InverseGammaEP', | ||
| 'InverseGammaNP', | ||
| 'InverseGaussianEP', | ||
| 'InverseGaussianNP', | ||
| 'IsotropicNormalEP', | ||
| 'IsotropicNormalNP', | ||
| 'JointDistribution', | ||
| 'JointDistributionE', | ||
| 'JointDistributionN', | ||
| 'LogNormalEP', | ||
| 'LogNormalNP', | ||
| 'LogarithmicEP', | ||
| 'LogarithmicNP', | ||
| 'MaximumLikelihoodEstimator', | ||
| 'Multidimensional', | ||
| 'MultinomialEP', | ||
| 'MultinomialNP', | ||
| 'MultivariateDiagonalNormalEP', | ||
| 'MultivariateDiagonalNormalNP', | ||
| 'MultivariateDiagonalNormalVP', | ||
| 'MultivariateFixedVarianceNormalEP', | ||
| 'MultivariateFixedVarianceNormalNP', | ||
| 'MultivariateNormalEP', | ||
| 'MultivariateNormalNP', | ||
| 'MultivariateNormalVP', | ||
| 'MultivariateUnitVarianceNormalEP', | ||
| 'MultivariateUnitVarianceNormalNP', | ||
| 'NaturalParametrization', | ||
| 'NegativeBinomialEP', | ||
| 'NegativeBinomialNP', | ||
| 'NormalDP', | ||
| 'NormalEP', | ||
| 'NormalNP', | ||
| 'NormalVP', | ||
| 'PoissonEP', | ||
| 'PoissonNP', | ||
| 'RayleighEP', | ||
| 'RayleighNP', | ||
| 'RealField', | ||
| 'Ring', | ||
| 'Samplable', | ||
| 'ScalarSupport', | ||
| 'ScipyComplexMultivariateNormal', | ||
| 'ScipyComplexNormal', | ||
| 'ScipyDirichlet', | ||
| 'ScipyDiscreteDistribution', | ||
| 'ScipyDistribution', | ||
| 'ScipyGeneralizedDirichlet', | ||
| 'ScipyGeometric', | ||
| 'ScipyJointDistribution', | ||
| 'ScipyLogNormal', | ||
| 'ScipyMultivariateNormal', | ||
| 'ScipySoftplusNormal', | ||
| 'ScipyVonMises', | ||
| 'ScipyVonMisesFisher', | ||
| 'SimpleDistribution', | ||
| 'SimplexSupport', | ||
| 'SoftplusNormalEP', | ||
| 'SoftplusNormalNP', | ||
| 'SquareMatrixSupport', | ||
| 'Structure', | ||
| 'SubDistributionInfo', | ||
| 'Support', | ||
| 'SymmetricMatrixSupport', | ||
| 'UnitVarianceLogNormalEP', | ||
| 'UnitVarianceLogNormalNP', | ||
| 'UnitVarianceNormalEP', | ||
| 'UnitVarianceNormalNP', | ||
| 'UnitVarianceSoftplusNormalEP', | ||
| 'UnitVarianceSoftplusNormalNP', | ||
| 'VectorSupport', | ||
| 'VonMisesFisherEP', | ||
| 'VonMisesFisherNP', | ||
| 'WeibullEP', | ||
| 'WeibullNP', | ||
| 'distribution_parameter', | ||
| 'flat_dict_of_observations', | ||
| 'flat_dict_of_parameters', | ||
| 'flatten_mapping', | ||
| 'parameter_dot_product', | ||
| 'parameter_map', | ||
| 'parameter_mean', | ||
| 'parameters', | ||
| 'unflatten_mapping', | ||
| ] |
| from __future__ import annotations | ||
| from typing import Any | ||
| from jax import Array | ||
| from tjax.dataclasses import field | ||
@@ -14,5 +13,5 @@ | ||
| static: bool = False | ||
| ) -> Any: | ||
| ) -> Array: | ||
| if static and not fixed: | ||
| raise ValueError | ||
| return field(static=static, metadata={'support': support, 'fixed': fixed, 'parameter': True}) |
@@ -5,4 +5,4 @@ from __future__ import annotations | ||
| from collections.abc import Mapping | ||
| from types import ModuleType | ||
| from typing import Any, Self, override | ||
| from types import EllipsisType, ModuleType | ||
| from typing import Self, override | ||
@@ -19,3 +19,3 @@ from array_api_compat import array_namespace | ||
| """The Distribution is the base class of all distributions.""" | ||
| def __getitem__(self, key: Any) -> Self: | ||
| def __getitem__(self, key: tuple[int | slice | EllipsisType | None, ...]) -> Self: | ||
| from .iteration import parameters # noqa: PLC0415 | ||
@@ -22,0 +22,0 @@ from .structure.structure import Structure # noqa: PLC0415 |
@@ -82,3 +82,4 @@ from __future__ import annotations | ||
| class ScipyComplexMultivariateNormal( | ||
| ShapedDistribution[ScipyComplexMultivariateNormalUnvectorized]): | ||
| ShapedDistribution[ | ||
| ScipyComplexMultivariateNormalUnvectorized]): # type: ignore # pyright: ignore | ||
| """This class allows distributions having a non-empty shape.""" | ||
@@ -121,3 +122,5 @@ @override | ||
| for i in np.ndindex(*self.shape): | ||
| objects[i] = self.objects[i].as_multivariate_normal() # pyright: ignore | ||
| this_object = self.objects[i] | ||
| assert isinstance(this_object, ScipyComplexMultivariateNormalUnvectorized) | ||
| objects[i] = this_object.as_multivariate_normal() | ||
| return ScipyMultivariateNormal(self.shape, self.rvs_shape, self.real_dtype, objects) | ||
@@ -136,3 +139,5 @@ | ||
| for i in np.ndindex(*self.shape): | ||
| retval[i] = self.objects[i].variance # pyright: ignore | ||
| this_object = self.objects[i] | ||
| assert isinstance(this_object, ScipyComplexMultivariateNormalUnvectorized) | ||
| retval[i] = this_object.variance | ||
| return retval | ||
@@ -144,3 +149,5 @@ | ||
| for i in np.ndindex(*self.shape): | ||
| retval[i] = self.objects[i].pseudo_variance # pyright: ignore | ||
| this_object = self.objects[i] | ||
| assert isinstance(this_object, ScipyComplexMultivariateNormalUnvectorized) | ||
| retval[i] = this_object.pseudo_variance | ||
| return retval |
@@ -70,3 +70,4 @@ from __future__ import annotations | ||
| class ScipyComplexNormal(ShapedDistribution[ScipyComplexNormalUnvectorized]): | ||
| class ScipyComplexNormal( | ||
| ShapedDistribution[ScipyComplexNormalUnvectorized]): # type: ignore # pyright: ignore | ||
| """This class allows distributions having a non-empty shape.""" | ||
@@ -118,3 +119,5 @@ @override | ||
| for i in np.ndindex(*self.shape): | ||
| objects[i] = self.objects[i].as_multivariate_normal() # pyright: ignore | ||
| this_object = self.objects[i] | ||
| assert isinstance(this_object, ScipyComplexNormalUnvectorized) | ||
| objects[i] = this_object.as_multivariate_normal() | ||
| return ScipyMultivariateNormal(self.shape, self.rvs_shape, self.real_dtype, objects) | ||
@@ -133,3 +136,5 @@ | ||
| for i in np.ndindex(*self.shape): | ||
| retval[i] = self.objects[i].variance # pyright: ignore | ||
| this_object = self.objects[i] | ||
| assert isinstance(this_object, ScipyComplexNormalUnvectorized) | ||
| retval[i] = this_object.variance | ||
| return retval | ||
@@ -141,3 +146,5 @@ | ||
| for i in np.ndindex(*self.shape): | ||
| retval[i] = self.objects[i].pseudo_variance # pyright: ignore | ||
| this_object = self.objects[i] | ||
| assert isinstance(this_object, ScipyComplexNormalUnvectorized) | ||
| retval[i] = this_object.pseudo_variance | ||
| return retval |
| from __future__ import annotations | ||
| from typing import Any | ||
| import numpy as np | ||
@@ -32,4 +30,4 @@ import optype.numpy as onp | ||
| def rvs(self, | ||
| size: Any = None, | ||
| random_state: Any = None | ||
| size: tuple[int, ...] | None = None, | ||
| random_state: np.random.Generator | None = None | ||
| ) -> onp.ArrayND[np.float64]: | ||
@@ -39,3 +37,4 @@ if size is None: | ||
| # This somehow fixes the behaviour of rvs. | ||
| return self.distribution.rvs(size=size, random_state=random_state) | ||
| return self.distribution.rvs(size=size, # type: ignore # pyright: ignore | ||
| random_state=random_state) | ||
@@ -46,3 +45,4 @@ def entropy(self) -> NumpyRealArray: | ||
| class ScipyDirichlet(ShapedDistribution[ScipyDirichletFixRVsAndPDF]): | ||
| class ScipyDirichlet( | ||
| ShapedDistribution[ScipyDirichletFixRVsAndPDF]): # type: ignore # pyright: ignore | ||
| """This class allows distributions having a non-empty shape.""" | ||
@@ -49,0 +49,0 @@ @override |
| from __future__ import annotations | ||
| from typing import Any, Self | ||
| from typing import Self | ||
@@ -26,11 +26,7 @@ import numpy as np | ||
| def rvs(self, | ||
| size: Any = None, | ||
| random_state: Any = None) -> onp.ArrayND[np.float64]: | ||
| size: int | tuple[int, ...] = (), | ||
| random_state: np.random.Generator | None = None | ||
| ) -> onp.ArrayND[np.float64]: | ||
| retval = self.distribution.rvs(size=size, random_state=random_state) | ||
| if size is None: | ||
| size = () | ||
| elif isinstance(size, int): | ||
| size = (size,) | ||
| else: | ||
| size = tuple(size) | ||
| size = (size,) if isinstance(size, int) else tuple(size) | ||
| return np.reshape(retval, size + self.distribution.mean.shape) | ||
@@ -42,3 +38,4 @@ | ||
| class ScipyMultivariateNormal(ShapedDistribution[ScipyMultivariateNormalUnvectorized]): | ||
| class ScipyMultivariateNormal( | ||
| ShapedDistribution[ScipyMultivariateNormalUnvectorized]): # type: ignore # pyright: ignore | ||
| """This class allows distributions having a non-empty shape.""" | ||
@@ -45,0 +42,0 @@ @classmethod |
| from __future__ import annotations | ||
| from typing import Any, Generic, TypeVar | ||
| from typing import Any, Generic, TypeVar, cast | ||
@@ -8,8 +8,10 @@ import numpy as np | ||
| from numpy.random import Generator | ||
| from tjax import NumpyComplexArray, NumpyRealArray, Shape | ||
| from tjax import NumpyComplexArray, NumpyIntegralArray, NumpyRealArray, Shape | ||
| from typing_extensions import override | ||
| T = TypeVar('T') | ||
| from .base import ScipyDiscreteDistribution, ScipyDistribution | ||
| T = TypeVar('T', bound=ScipyDiscreteDistribution | ScipyDistribution) | ||
| class ShapedDistribution(Generic[T]): | ||
@@ -41,3 +43,4 @@ """Allow a distributions with shape.""" | ||
| for i in np.ndindex(*self.shape): | ||
| retval[i] = self.objects[i].rvs(size=size, random_state=random_state) # pyright: ignore | ||
| this_object = cast('T', self.objects[i]) | ||
| retval[i] = this_object.rvs(size=size, random_state=random_state) | ||
| return retval | ||
@@ -48,12 +51,24 @@ | ||
| for i in np.ndindex(*self.shape): | ||
| value = self.objects[i].pdf(x[i]) # pyright: ignore | ||
| if i == (): | ||
| return value | ||
| this_object = cast('T', self.objects[i]) | ||
| if not isinstance(this_object, ScipyDistribution): | ||
| raise NotImplementedError | ||
| value = this_object.pdf(x[i]) | ||
| retval[i] = value | ||
| return retval | ||
| def pmf(self, x: NumpyIntegralArray) -> NumpyRealArray: | ||
| retval = np.empty(self.shape, dtype=self.real_dtype) | ||
| for i in np.ndindex(*self.shape): | ||
| this_object = cast('T', self.objects[i]) | ||
| if not isinstance(this_object, ScipyDiscreteDistribution): | ||
| raise NotImplementedError | ||
| value = this_object.pmf(x[i]) | ||
| retval[i] = value | ||
| return retval | ||
| def entropy(self) -> NumpyRealArray: | ||
| retval = np.empty(self.shape, dtype=self.real_dtype) | ||
| for i in np.ndindex(*self.shape): | ||
| retval[i] = self.objects[i].entropy() # pyright: ignore | ||
| this_object = cast('T', self.objects[i]) | ||
| retval[i] = this_object.entropy() | ||
| return retval | ||
@@ -60,0 +75,0 @@ |
@@ -33,1 +33,4 @@ from __future__ import annotations | ||
| return softplus(samples) | ||
| def entropy(self) -> NumpyRealArray: | ||
| raise NotImplementedError |
@@ -11,3 +11,3 @@ from __future__ import annotations | ||
| class ScipyVonMises(ShapedDistribution[object]): | ||
| class ScipyVonMises(ShapedDistribution[object]): # type: ignore # pyright: ignore | ||
| """This class allows distributions having a non-empty shape.""" | ||
@@ -29,3 +29,3 @@ @override | ||
| class ScipyVonMisesFisher(ShapedDistribution[object]): | ||
| class ScipyVonMisesFisher(ShapedDistribution[object]): # type: ignore # pyright: ignore | ||
| """This class allows distributions having a non-empty shape.""" | ||
@@ -42,3 +42,3 @@ @override | ||
| for i in np.ndindex(*shape): | ||
| objects[i] = ss.vonmises_fisher(mu[i], kappa[i]) # pyright: ignore | ||
| objects[i] = ss.vonmises_fisher(mu[i], kappa[i]) # type: ignore # pyright: ignore | ||
| super().__init__(shape, rvs_shape, dtype, objects) |
@@ -7,3 +7,3 @@ from __future__ import annotations | ||
| from itertools import starmap | ||
| from typing import TYPE_CHECKING, Any, TypeVar | ||
| from typing import TYPE_CHECKING, TypeVar | ||
@@ -22,3 +22,3 @@ from array_api_compat import array_namespace | ||
| @jit | ||
| def parameter_dot_product(x: NaturalParametrization, y: Any, /) -> JaxRealArray: | ||
| def parameter_dot_product(x: NaturalParametrization, y: Distribution, /) -> JaxRealArray: | ||
| """Return the vectorized dot product over all of the variable parameters.""" | ||
@@ -25,0 +25,0 @@ def dotted_fields() -> Iterable[JaxRealArray]: |
+2
-13
| Metadata-Version: 2.4 | ||
| Name: efax | ||
| Version: 1.22.1 | ||
| Version: 1.22.2 | ||
| Summary: Exponential families for JAX | ||
@@ -34,16 +34,5 @@ Project-URL: source, https://github.com/NeilGirdhar/efax | ||
| Requires-Dist: scipy>=1.15 | ||
| Requires-Dist: tensorflow-probability>=0.15 | ||
| Requires-Dist: tfp-nightly>=0.25 | ||
| Requires-Dist: tjax>=1.3.10 | ||
| Requires-Dist: typing-extensions>=4.8 | ||
| Provides-Extra: dev | ||
| Requires-Dist: isort>=5.13; extra == 'dev' | ||
| Requires-Dist: jupyter>=1; extra == 'dev' | ||
| Requires-Dist: mypy>=1.12; extra == 'dev' | ||
| Requires-Dist: pre-commit>=4; extra == 'dev' | ||
| Requires-Dist: pylint>=3.3; extra == 'dev' | ||
| Requires-Dist: pyright>=1.1.401; extra == 'dev' | ||
| Requires-Dist: pytest-ordering; extra == 'dev' | ||
| Requires-Dist: pytest-xdist[psutil]>=3; extra == 'dev' | ||
| Requires-Dist: pytest>=8.4; extra == 'dev' | ||
| Requires-Dist: ruff>=0.9.10; extra == 'dev' | ||
| Description-Content-Type: text/x-rst | ||
@@ -50,0 +39,0 @@ |
+307
-79
@@ -5,5 +5,20 @@ [build-system] | ||
| [dependency-groups] | ||
| dev = [ | ||
| "isort>=5.13", | ||
| "jupyter>=1", | ||
| "lefthook>=1.11.13", | ||
| "mypy>=1.12", | ||
| "pylint>=3.3", | ||
| "pyright>=1.1.401", | ||
| "pytest-ordering", | ||
| "pytest-xdist[psutil]>=3", | ||
| "pytest>=8.4", | ||
| "ruff>=0.9.10", | ||
| "toml-sort>=0.24" | ||
| ] | ||
| [project] | ||
| name = "efax" | ||
| version = "1.22.1" | ||
| version = "1.22.2" | ||
| description = "Exponential families for JAX" | ||
@@ -27,3 +42,3 @@ readme = "README.rst" | ||
| "Topic :: Software Development :: Libraries :: Python Modules", | ||
| "Typing :: Typed", | ||
| "Typing :: Typed" | ||
| ] | ||
@@ -40,21 +55,7 @@ dependencies = [ | ||
| "scipy>=1.15", | ||
| "tensorflow_probability>=0.15", | ||
| "tfp-nightly>=0.25", | ||
| "tjax>=1.3.10", | ||
| "typing_extensions>=4.8", | ||
| "typing_extensions>=4.8" | ||
| ] | ||
| [project.optional-dependencies] | ||
| dev = [ | ||
| "isort>=5.13", | ||
| "jupyter>=1", | ||
| "mypy>=1.12", | ||
| "pre-commit>=4", | ||
| "pylint>=3.3", | ||
| "pyright>=1.1.401", | ||
| "pytest-ordering", | ||
| "pytest-xdist[psutil]>=3", | ||
| "pytest>=8.4", | ||
| "ruff>=0.9.10", | ||
| ] | ||
| [project.urls] | ||
@@ -68,2 +69,31 @@ source = "https://github.com/NeilGirdhar/efax" | ||
| [tool.mypy] | ||
| files = ["efax", "tests", "examples"] | ||
| disable_error_code = ["type-abstract"] | ||
| check_untyped_defs = true | ||
| disallow_any_generics = true | ||
| disallow_incomplete_defs = true | ||
| # disallow_untyped_calls = true | ||
| disallow_untyped_decorators = true | ||
| disallow_untyped_defs = true | ||
| no_implicit_optional = true | ||
| pretty = true | ||
| show_error_codes = true | ||
| show_error_context = false | ||
| strict_equality = true | ||
| warn_no_return = true | ||
| warn_redundant_casts = true | ||
| warn_return_any = false | ||
| warn_unreachable = true | ||
| warn_unused_configs = true | ||
| warn_unused_ignores = true | ||
| [[tool.mypy.overrides]] | ||
| module = [ | ||
| "tensorflow_probability.substrates", | ||
| "array_api_compat", | ||
| "opt_einsum" | ||
| ] | ||
| ignore_missing_imports = true | ||
| [tool.pylint.master] | ||
@@ -80,3 +110,3 @@ jobs = 0 | ||
| "pylint.extensions.overlapping_exceptions", | ||
| "pylint.extensions.typing", | ||
| "pylint.extensions.typing" | ||
| ] | ||
@@ -87,28 +117,262 @@ | ||
| # Ruff | ||
| "C0103", "C0105", "C0112", "C0113", "C0114", "C0115", "C0116", "C0121", "C0123", "C0131", "C0132", "C0198", "C0199", "C0201", "C0202", "C0205", | ||
| "C0206", "C0208", "C0301", "C0303", "C0304", "C0305", "C0321", "C0410", "C0411", "C0412", "C0413", "C0414", "C0415", "C0501", "C1802", "C1901", | ||
| "C2201", "C2401", "C2403", "C2701", "C2801", "C3001", "C3002", "E0001", "E0013", "E0014", "E0100", "E0101", "E0102", "E0103", "E0104", "E0105", | ||
| "E0106", "E0107", "E0108", "E0112", "E0115", "E0116", "E0117", "E0118", "E0213", "E0237", "E0241", "E0302", "E0303", "E0304", "E0305", "E0308", | ||
| "E0309", "E0402", "E0602", "E0603", "E0604", "E0605", "E0643", "E0704", "E0711", "E1132", "E1142", "E1205", "E1206", "E1300", "E1301", "E1302", | ||
| "E1303", "E1304", "E1305", "E1306", "E1307", "E1310", "E1519", "E1520", "E1700", "E2502", "E2510", "E2512", "E2513", "E2514", "E2515", "E4703", | ||
| "E6004", "E6005", "R0022", "R0123", "R0124", "R0133", "R0202", "R0203", "R0205", "R0206", "R0402", "R0904", "R0911", "R0912", "R0913", "R0914", | ||
| "R0915", "R0916", "R1260", "R1701", "R1702", "R1703", "R1704", "R1705", "R1706", "R1707", "R1710", "R1711", "R1714", "R1715", "R1717", "R1718", | ||
| "R1719", "R1720", "R1721", "R1722", "R1723", "R1724", "R1725", "R1728", "R1729", "R1730", "R1731", "R1732", "R1733", "R1734", "R1735", "R1736", | ||
| "R2004", "R2044", "R5501", "R6002", "R6003", "R6104", "R6201", "R6301", "W0012", "W0102", "W0104", "W0106", "W0107", "W0108", "W0109", "W0120", | ||
| "W0122", "W0123", "W0127", "W0129", "W0130", "W0131", "W0133", "W0150", "W0160", "W0177", "W0199", "W0211", "W0212", "W0245", "W0301", "W0401", | ||
| "W0404", "W0406", "W0410", "W0511", "W0602", "W0603", "W0604", "W0611", "W0612", "W0613", "W0622", "W0640", "W0702", "W0705", "W0706", "W0707", | ||
| "W0711", "W0718", "W0719", "W1113", "W1201", "W1202", "W1203", "W1301", "W1302", "W1303", "W1304", "W1305", "W1309", "W1310", "W1401", "W1404", | ||
| "W1405", "W1406", "W1501", "W1502", "W1508", "W1509", "W1510", "W1514", "W1515", "W1518", "W1641", "W2101", "W2402", "W2601", "W2901", "W3201", | ||
| "C0103", | ||
| "C0105", | ||
| "C0112", | ||
| "C0113", | ||
| "C0114", | ||
| "C0115", | ||
| "C0116", | ||
| "C0121", | ||
| "C0123", | ||
| "C0131", | ||
| "C0132", | ||
| "C0198", | ||
| "C0199", | ||
| "C0201", | ||
| "C0202", | ||
| "C0205", | ||
| "C0206", | ||
| "C0208", | ||
| "C0301", | ||
| "C0303", | ||
| "C0304", | ||
| "C0305", | ||
| "C0321", | ||
| "C0410", | ||
| "C0411", | ||
| "C0412", | ||
| "C0413", | ||
| "C0414", | ||
| "C0415", | ||
| "C0501", | ||
| "C1802", | ||
| "C1901", | ||
| "C2201", | ||
| "C2401", | ||
| "C2403", | ||
| "C2701", | ||
| "C2801", | ||
| "C3001", | ||
| "C3002", | ||
| "E0001", | ||
| "E0013", | ||
| "E0014", | ||
| "E0100", | ||
| "E0101", | ||
| "E0102", | ||
| "E0103", | ||
| "E0104", | ||
| "E0105", | ||
| "E0106", | ||
| "E0107", | ||
| "E0108", | ||
| "E0112", | ||
| "E0115", | ||
| "E0116", | ||
| "E0117", | ||
| "E0118", | ||
| "E0213", | ||
| "E0237", | ||
| "E0241", | ||
| "E0302", | ||
| "E0303", | ||
| "E0304", | ||
| "E0305", | ||
| "E0308", | ||
| "E0309", | ||
| "E0402", | ||
| "E0602", | ||
| "E0603", | ||
| "E0604", | ||
| "E0605", | ||
| "E0643", | ||
| "E0704", | ||
| "E0711", | ||
| "E1132", | ||
| "E1142", | ||
| "E1205", | ||
| "E1206", | ||
| "E1300", | ||
| "E1301", | ||
| "E1302", | ||
| "E1303", | ||
| "E1304", | ||
| "E1305", | ||
| "E1306", | ||
| "E1307", | ||
| "E1310", | ||
| "E1519", | ||
| "E1520", | ||
| "E1700", | ||
| "E2502", | ||
| "E2510", | ||
| "E2512", | ||
| "E2513", | ||
| "E2514", | ||
| "E2515", | ||
| "E4703", | ||
| "E6004", | ||
| "E6005", | ||
| "R0022", | ||
| "R0123", | ||
| "R0124", | ||
| "R0133", | ||
| "R0202", | ||
| "R0203", | ||
| "R0205", | ||
| "R0206", | ||
| "R0402", | ||
| "R0904", | ||
| "R0911", | ||
| "R0912", | ||
| "R0913", | ||
| "R0914", | ||
| "R0915", | ||
| "R0916", | ||
| "R1260", | ||
| "R1701", | ||
| "R1702", | ||
| "R1703", | ||
| "R1704", | ||
| "R1705", | ||
| "R1706", | ||
| "R1707", | ||
| "R1710", | ||
| "R1711", | ||
| "R1714", | ||
| "R1715", | ||
| "R1717", | ||
| "R1718", | ||
| "R1719", | ||
| "R1720", | ||
| "R1721", | ||
| "R1722", | ||
| "R1723", | ||
| "R1724", | ||
| "R1725", | ||
| "R1728", | ||
| "R1729", | ||
| "R1730", | ||
| "R1731", | ||
| "R1732", | ||
| "R1733", | ||
| "R1734", | ||
| "R1735", | ||
| "R1736", | ||
| "R2004", | ||
| "R2044", | ||
| "R5501", | ||
| "R6002", | ||
| "R6003", | ||
| "R6104", | ||
| "R6201", | ||
| "R6301", | ||
| "W0012", | ||
| "W0102", | ||
| "W0104", | ||
| "W0106", | ||
| "W0107", | ||
| "W0108", | ||
| "W0109", | ||
| "W0120", | ||
| "W0122", | ||
| "W0123", | ||
| "W0127", | ||
| "W0129", | ||
| "W0130", | ||
| "W0131", | ||
| "W0133", | ||
| "W0150", | ||
| "W0160", | ||
| "W0177", | ||
| "W0199", | ||
| "W0211", | ||
| "W0212", | ||
| "W0245", | ||
| "W0301", | ||
| "W0401", | ||
| "W0404", | ||
| "W0406", | ||
| "W0410", | ||
| "W0511", | ||
| "W0602", | ||
| "W0603", | ||
| "W0604", | ||
| "W0611", | ||
| "W0612", | ||
| "W0613", | ||
| "W0622", | ||
| "W0640", | ||
| "W0702", | ||
| "W0705", | ||
| "W0706", | ||
| "W0707", | ||
| "W0711", | ||
| "W0718", | ||
| "W0719", | ||
| "W1113", | ||
| "W1201", | ||
| "W1202", | ||
| "W1203", | ||
| "W1301", | ||
| "W1302", | ||
| "W1303", | ||
| "W1304", | ||
| "W1305", | ||
| "W1309", | ||
| "W1310", | ||
| "W1401", | ||
| "W1404", | ||
| "W1405", | ||
| "W1406", | ||
| "W1501", | ||
| "W1502", | ||
| "W1508", | ||
| "W1509", | ||
| "W1510", | ||
| "W1514", | ||
| "W1515", | ||
| "W1518", | ||
| "W1641", | ||
| "W2101", | ||
| "W2402", | ||
| "W2601", | ||
| "W2901", | ||
| "W3201", | ||
| "W3301", | ||
| # Missing | ||
| "E0601", "R1737", "W0311", "W2301", | ||
| "E0601", | ||
| "R1737", | ||
| "W0311", | ||
| "W2301", | ||
| # Mine | ||
| "C0111", "E1101", "E1102", "E1120", "E1123", "E1130", "E1135", "E1136", "E3701", "R0204", "R0401", "R0801", "R0901", "R0902", "R0903", "R0917", | ||
| "R5601", "R6102", "R6103", "W0149", "W0221", "W0222", "W0223", "W0621", "W0717", | ||
| "C0111", | ||
| "E1101", | ||
| "E1102", | ||
| "E1120", | ||
| "E1123", | ||
| "E1130", | ||
| "E1135", | ||
| "E1136", | ||
| "E3701", | ||
| "R0204", | ||
| "R0401", | ||
| "R0801", | ||
| "R0901", | ||
| "R0902", | ||
| "R0903", | ||
| "R0917", | ||
| "R5601", | ||
| "R6102", | ||
| "R6103", | ||
| "W0149", | ||
| "W0221", | ||
| "W0222", | ||
| "W0223", | ||
| "W0621", | ||
| "W0717" | ||
| ] | ||
| enable = [ | ||
| "useless-suppression", | ||
| "use-symbolic-message-instead", | ||
| "use-symbolic-message-instead" | ||
| ] | ||
@@ -168,30 +432,2 @@ | ||
| [tool.mypy] | ||
| files = ["efax", "tests", "examples"] | ||
| disable_error_code = ["type-abstract"] | ||
| check_untyped_defs = true | ||
| disallow_any_generics = true | ||
| disallow_incomplete_defs = true | ||
| # disallow_untyped_calls = true | ||
| disallow_untyped_decorators = true | ||
| disallow_untyped_defs = true | ||
| no_implicit_optional = true | ||
| pretty = true | ||
| show_error_codes = true | ||
| show_error_context = false | ||
| strict_equality = true | ||
| warn_no_return = true | ||
| warn_redundant_casts = true | ||
| warn_return_any = false | ||
| warn_unreachable = true | ||
| warn_unused_configs = true | ||
| warn_unused_ignores = true | ||
| [[tool.mypy.overrides]] | ||
| module = [ | ||
| "tensorflow_probability.substrates", | ||
| "array_api_compat", | ||
| ] | ||
| ignore_missing_imports = true | ||
| [tool.ruff] | ||
@@ -204,3 +440,2 @@ line-length = 100 | ||
| ignore = [ | ||
| "ANN401", # Dynamically typed expressions (Any). | ||
| "ARG001", # Unused function argument. | ||
@@ -210,4 +445,2 @@ "ARG002", # Unused method argument. | ||
| "ARG004", # Unused static method argument. | ||
| "B011", # Do not assert false. | ||
| "C901", # Complex structure. | ||
| "COM812", # Trailing comma missing. | ||
@@ -226,4 +459,2 @@ "CPY001", # Missing copyright. | ||
| "ERA001", # Commented-out code. | ||
| "F722", # Syntax error in forward annotation. | ||
| "FBT003", # Boolean positional value in function call. | ||
| "FIX002", # Line contains TODO, consider resolving the issue. | ||
@@ -233,7 +464,5 @@ "G004", # Logging statement uses f-string. | ||
| "PD008", # Use .loc instead of .at. If speed is important, use NumPy. | ||
| "PD013", # `.melt` is preferred to `.stack`; provides same functionality | ||
| "PGH003", # Use specific rule codes when ignoring type issues. | ||
| "PLR0913", # Too many arguments in function definition. | ||
| "PLR6301", # Method doesn"t use self. | ||
| "PT013", # Found incorrect import of pytest, use simple import pytest instead. | ||
| "Q000", # Single quotes found but double quotes preferred. | ||
@@ -247,6 +476,8 @@ "RUF021", # Parenthesize `a and b` expressions when chaining `and` and `or` together... | ||
| "TD003", # Missing issue link on the line following this TODO. | ||
| "TID252", # Relative imports from parent modules are banned. | ||
| "UP037", # Remove quotes from type annotation. | ||
| "TID252" # Relative imports from parent modules are banned. | ||
| ] | ||
| [tool.ruff.lint.flake8-errmsg] | ||
| max-string-length = 40 | ||
| [tool.ruff.lint.flake8-import-conventions.extend-aliases] | ||
@@ -267,5 +498,2 @@ "array_api_extra" = "xpx" | ||
| [tool.ruff.lint.flake8-errmsg] | ||
| max-string-length = 40 | ||
| [tool.ruff.lint.isort] | ||
@@ -272,0 +500,0 @@ combine-as-imports = true |
@@ -23,3 +23,3 @@ from __future__ import annotations | ||
| def _jax_fixture(request: pytest.FixtureRequest) -> Generator[None]: # pyright: ignore | ||
| with jax.debug_key_reuse(True), jax.numpy_rank_promotion('raise'), enable_x64(): | ||
| with jax.debug_key_reuse(new_val=True), jax.numpy_rank_promotion('raise'), enable_x64(): | ||
| yield | ||
@@ -26,0 +26,0 @@ |
@@ -360,3 +360,3 @@ from __future__ import annotations | ||
| class LogNormal(DistributionInfo[LogNormalNP, LogNormalEP, NumpyRealArray]): | ||
| class LogNormalInfo(DistributionInfo[LogNormalNP, LogNormalEP, NumpyRealArray]): | ||
| @override | ||
@@ -661,3 +661,3 @@ def exp_to_scipy_distribution(self, p: LogNormalEP) -> Any: | ||
| JointInfo(infos={'gamma': GammaInfo(), 'normal': NormalInfo()}), | ||
| LogNormal(), | ||
| LogNormalInfo(), | ||
| LogarithmicInfo(), | ||
@@ -664,0 +664,0 @@ MultivariateDiagonalNormalInfo(dimensions=4), |
@@ -12,3 +12,4 @@ from __future__ import annotations | ||
| from efax import ExpectationParametrization, NaturalParametrization, Structure, SubDistributionInfo | ||
| from efax import (ExpectationParametrization, NaturalParametrization, ScipyDiscreteDistribution, | ||
| ScipyDistribution, Structure, SubDistributionInfo) | ||
@@ -26,3 +27,3 @@ NP = TypeVar('NP', bound=NaturalParametrization, default=Any) | ||
| def exp_to_scipy_distribution(self, p: EP) -> Any: | ||
| def exp_to_scipy_distribution(self, p: EP) -> ScipyDistribution | ScipyDiscreteDistribution: | ||
| """Produce a corresponding scipy distribution from expectation parameters. | ||
@@ -35,3 +36,3 @@ | ||
| def nat_to_scipy_distribution(self, q: NP) -> Any: | ||
| def nat_to_scipy_distribution(self, q: NP) -> ScipyDistribution | ScipyDiscreteDistribution: | ||
| """Produce a corresponding scipy distribution from natural parameters. | ||
@@ -102,7 +103,7 @@ | ||
| def new_method(*args: Any, | ||
| def new_method(*args: object, | ||
| old_method: Callable[..., Any] = old_method, | ||
| **kwargs: Any) -> Any: | ||
| **kwargs: object) -> object: | ||
| return old_method(*args, **kwargs) | ||
| setattr(cls, method, new_method) |
@@ -11,3 +11,4 @@ from __future__ import annotations | ||
| from efax import JointDistributionN, Multidimensional, NaturalParametrization, SimpleDistribution | ||
| from efax import (JointDistributionN, Multidimensional, NaturalParametrization, | ||
| ScipyDiscreteDistribution, ScipyDistribution, SimpleDistribution) | ||
@@ -48,6 +49,7 @@ from ..create_info import MultivariateDiagonalNormalInfo | ||
| efax_density = np.asarray(nat_parameters.pdf(efax_x), dtype=np.float64) | ||
| try: | ||
| if isinstance(scipy_distribution, ScipyDistribution): | ||
| scipy_density = scipy_distribution.pdf(scipy_x) | ||
| except AttributeError: | ||
| scipy_density = scipy_distribution.pmf(scipy_x) | ||
| else: | ||
| assert isinstance(scipy_distribution, ScipyDiscreteDistribution) | ||
| scipy_density = scipy_distribution.pmf(scipy_x) # type: ignore # pyright: ignore | ||
@@ -54,0 +56,0 @@ if isinstance(distribution_info, MultivariateDiagonalNormalInfo): |
@@ -38,3 +38,3 @@ from __future__ import annotations | ||
| np.average( | ||
| np.conj(z)[..., np.newaxis, :, :] * z[..., np.newaxis, :], # type: ignore[arg-type] | ||
| np.conj(z)[..., np.newaxis, :, :] * z[..., np.newaxis, :], | ||
| weights=weights, | ||
@@ -41,0 +41,0 @@ axis=-1) |
@@ -5,3 +5,2 @@ """These tests verify entropy gradients.""" | ||
| from functools import partial | ||
| from typing import Any | ||
@@ -30,3 +29,3 @@ import jax.numpy as jnp | ||
| @jit | ||
| def _all_finite(some_tree: Any, /) -> JaxBooleanArray: | ||
| def _all_finite(some_tree: object, /) -> JaxBooleanArray: | ||
| return dynamic_tree_all(tree.map(lambda x: jnp.all(jnp.isfinite(x)), some_tree)) | ||
@@ -33,0 +32,0 @@ |
| from __future__ import annotations | ||
| from typing import Any | ||
| import numpy as np | ||
@@ -10,3 +8,3 @@ import pytest | ||
| from efax import ScipyDirichlet, ScipyMultivariateNormal | ||
| from efax import ScipyDirichlet, ScipyDistribution, ScipyMultivariateNormal | ||
@@ -22,3 +20,4 @@ | ||
| ]) | ||
| def test_shaped(generator: NumpyGenerator, distribution: Any, m: Shape, n: int) -> None: | ||
| def test_shaped(generator: NumpyGenerator, distribution: ScipyDistribution, m: Shape, n: int | ||
| ) -> None: | ||
| assert distribution.rvs().shape == (*m, n) | ||
@@ -25,0 +24,0 @@ assert distribution.rvs(1).shape == (*m, 1, n) |
| # Install the pre-commit hooks below with | ||
| # 'pre-commit install' | ||
| # Auto-update the version of the hooks with | ||
| # 'pre-commit autoupdate' | ||
| # Run the hooks on all files with | ||
| # 'pre-commit run --all' | ||
| repos: | ||
| - repo: https://github.com/pre-commit/pre-commit-hooks | ||
| rev: v5.0.0 | ||
| hooks: | ||
| - id: check-toml | ||
| - id: check-yaml | ||
| - repo: https://github.com/astral-sh/ruff-pre-commit | ||
| rev: v0.11.10 | ||
| hooks: | ||
| - id: ruff | ||
| - repo: https://github.com/RobertCraigie/pyright-python | ||
| rev: v1.1.401 | ||
| hooks: | ||
| - id: pyright |
Sorry, the diff of this file is too big to display
Alert delta unavailable
Currently unable to show alert delta for PyPI packages.
123
0.82%8119
0.66%995395
-0.21%