You're Invited:Meet the Socket Team at RSAC and BSidesSF 2026, March 23–26.RSVP
Socket
Book a DemoSign in
Socket

efax

Package Overview
Dependencies
Maintainers
1
Versions
115
Alerts
File Explorer

Advanced tools

Socket logo

Install Socket

Detect and block malicious and high-risk dependencies

Install

efax - pypi Package Compare versions

Comparing version
1.22.1
to
1.22.2
+31
efax/_src/scipy_replacement/base.py
from __future__ import annotations
from typing import Protocol, runtime_checkable
from numpy.random import Generator
from tjax import NumpyComplexArray, NumpyIntegralArray, NumpyRealArray, ShapeLike
@runtime_checkable
class ScipyDistribution(Protocol):
def pdf(self, x: NumpyComplexArray) -> NumpyRealArray:
...
def rvs(self, size: ShapeLike = (), random_state: Generator | None = None) -> NumpyComplexArray:
...
def entropy(self) -> NumpyRealArray:
...
@runtime_checkable
class ScipyDiscreteDistribution(Protocol):
def pmf(self, x: NumpyIntegralArray) -> NumpyRealArray:
...
def rvs(self, size: ShapeLike = (), random_state: Generator | None = None
) -> NumpyIntegralArray:
...
def entropy(self) -> NumpyRealArray:
...
pre-commit:
parallel: true
jobs:
- name: ruff-fix
glob: "*.py"
run: uv run ruff check --fix {staged_files}
stage_fixed: true
- name: pyright
glob: "*.py"
run: uv run pyright {staged_files}
- name: toml-sort
glob: "*.toml"
run: toml-sort -i {staged_files}
stage_fixed: true
+124
-121

@@ -60,2 +60,3 @@ """The EFAX Library."""

from ._src.parametrization import Distribution, SimpleDistribution
from ._src.scipy_replacement.base import ScipyDiscreteDistribution, ScipyDistribution
from ._src.scipy_replacement.complex_multivariate_normal import ScipyComplexMultivariateNormal

@@ -77,123 +78,125 @@ from ._src.scipy_replacement.complex_normal import ScipyComplexNormal

__all__ = [
'BernoulliEP',
'BernoulliNP',
'BetaEP',
'BetaNP',
'BooleanRing',
'ChiEP',
'ChiNP',
'ChiSquareEP',
'ChiSquareNP',
'CircularBoundedSupport',
'ComplexCircularlySymmetricNormalEP',
'ComplexCircularlySymmetricNormalNP',
'ComplexField',
'ComplexMultivariateUnitVarianceNormalEP',
'ComplexMultivariateUnitVarianceNormalNP',
'ComplexNormalEP',
'ComplexNormalNP',
'ComplexUnitVarianceNormalEP',
'ComplexUnitVarianceNormalNP',
'DirichletEP',
'DirichletNP',
'Distribution',
'ExpectationParametrization',
'ExponentialEP',
'ExponentialNP',
'Flattener',
'GammaEP',
'GammaNP',
'GammaVP',
'GeneralizedDirichletEP',
'GeneralizedDirichletNP',
'GeometricEP',
'GeometricNP',
'HasConjugatePrior',
'HasEntropy',
'HasEntropyEP',
'HasEntropyNP',
'HasGeneralizedConjugatePrior',
'IntegralRing',
'InverseGammaEP',
'InverseGammaNP',
'InverseGaussianEP',
'InverseGaussianNP',
'IsotropicNormalEP',
'IsotropicNormalNP',
'JointDistribution',
'JointDistributionE',
'JointDistributionN',
'LogNormalEP',
'LogNormalNP',
'LogarithmicEP',
'LogarithmicNP',
'MaximumLikelihoodEstimator',
'Multidimensional',
'MultinomialEP',
'MultinomialNP',
'MultivariateDiagonalNormalEP',
'MultivariateDiagonalNormalNP',
'MultivariateDiagonalNormalVP',
'MultivariateFixedVarianceNormalEP',
'MultivariateFixedVarianceNormalNP',
'MultivariateNormalEP',
'MultivariateNormalNP',
'MultivariateNormalVP',
'MultivariateUnitVarianceNormalEP',
'MultivariateUnitVarianceNormalNP',
'NaturalParametrization',
'NegativeBinomialEP',
'NegativeBinomialNP',
'NormalDP',
'NormalEP',
'NormalNP',
'NormalVP',
'PoissonEP',
'PoissonNP',
'RayleighEP',
'RayleighNP',
'RealField',
'Ring',
'Samplable',
'ScalarSupport',
'ScipyComplexMultivariateNormal',
'ScipyComplexNormal',
'ScipyDirichlet',
'ScipyGeneralizedDirichlet',
'ScipyGeometric',
'ScipyJointDistribution',
'ScipyLogNormal',
'ScipyMultivariateNormal',
'ScipySoftplusNormal',
'ScipyVonMises',
'ScipyVonMisesFisher',
'SimpleDistribution',
'SimplexSupport',
'SoftplusNormalEP',
'SoftplusNormalNP',
'SquareMatrixSupport',
'Structure',
'SubDistributionInfo',
'Support',
'SymmetricMatrixSupport',
'UnitVarianceLogNormalEP',
'UnitVarianceLogNormalNP',
'UnitVarianceNormalEP',
'UnitVarianceNormalNP',
'UnitVarianceSoftplusNormalEP',
'UnitVarianceSoftplusNormalNP',
'VectorSupport',
'VonMisesFisherEP',
'VonMisesFisherNP',
'WeibullEP',
'WeibullNP',
'distribution_parameter',
'flat_dict_of_observations',
'flat_dict_of_parameters',
'flatten_mapping',
'parameter_dot_product',
'parameter_map',
'parameter_mean',
'parameters',
'unflatten_mapping',
'BernoulliEP',
'BernoulliNP',
'BetaEP',
'BetaNP',
'BooleanRing',
'ChiEP',
'ChiNP',
'ChiSquareEP',
'ChiSquareNP',
'CircularBoundedSupport',
'ComplexCircularlySymmetricNormalEP',
'ComplexCircularlySymmetricNormalNP',
'ComplexField',
'ComplexMultivariateUnitVarianceNormalEP',
'ComplexMultivariateUnitVarianceNormalNP',
'ComplexNormalEP',
'ComplexNormalNP',
'ComplexUnitVarianceNormalEP',
'ComplexUnitVarianceNormalNP',
'DirichletEP',
'DirichletNP',
'Distribution',
'ExpectationParametrization',
'ExponentialEP',
'ExponentialNP',
'Flattener',
'GammaEP',
'GammaNP',
'GammaVP',
'GeneralizedDirichletEP',
'GeneralizedDirichletNP',
'GeometricEP',
'GeometricNP',
'HasConjugatePrior',
'HasEntropy',
'HasEntropyEP',
'HasEntropyNP',
'HasGeneralizedConjugatePrior',
'IntegralRing',
'InverseGammaEP',
'InverseGammaNP',
'InverseGaussianEP',
'InverseGaussianNP',
'IsotropicNormalEP',
'IsotropicNormalNP',
'JointDistribution',
'JointDistributionE',
'JointDistributionN',
'LogNormalEP',
'LogNormalNP',
'LogarithmicEP',
'LogarithmicNP',
'MaximumLikelihoodEstimator',
'Multidimensional',
'MultinomialEP',
'MultinomialNP',
'MultivariateDiagonalNormalEP',
'MultivariateDiagonalNormalNP',
'MultivariateDiagonalNormalVP',
'MultivariateFixedVarianceNormalEP',
'MultivariateFixedVarianceNormalNP',
'MultivariateNormalEP',
'MultivariateNormalNP',
'MultivariateNormalVP',
'MultivariateUnitVarianceNormalEP',
'MultivariateUnitVarianceNormalNP',
'NaturalParametrization',
'NegativeBinomialEP',
'NegativeBinomialNP',
'NormalDP',
'NormalEP',
'NormalNP',
'NormalVP',
'PoissonEP',
'PoissonNP',
'RayleighEP',
'RayleighNP',
'RealField',
'Ring',
'Samplable',
'ScalarSupport',
'ScipyComplexMultivariateNormal',
'ScipyComplexNormal',
'ScipyDirichlet',
'ScipyDiscreteDistribution',
'ScipyDistribution',
'ScipyGeneralizedDirichlet',
'ScipyGeometric',
'ScipyJointDistribution',
'ScipyLogNormal',
'ScipyMultivariateNormal',
'ScipySoftplusNormal',
'ScipyVonMises',
'ScipyVonMisesFisher',
'SimpleDistribution',
'SimplexSupport',
'SoftplusNormalEP',
'SoftplusNormalNP',
'SquareMatrixSupport',
'Structure',
'SubDistributionInfo',
'Support',
'SymmetricMatrixSupport',
'UnitVarianceLogNormalEP',
'UnitVarianceLogNormalNP',
'UnitVarianceNormalEP',
'UnitVarianceNormalNP',
'UnitVarianceSoftplusNormalEP',
'UnitVarianceSoftplusNormalNP',
'VectorSupport',
'VonMisesFisherEP',
'VonMisesFisherNP',
'WeibullEP',
'WeibullNP',
'distribution_parameter',
'flat_dict_of_observations',
'flat_dict_of_parameters',
'flatten_mapping',
'parameter_dot_product',
'parameter_map',
'parameter_mean',
'parameters',
'unflatten_mapping',
]
from __future__ import annotations
from typing import Any
from jax import Array
from tjax.dataclasses import field

@@ -14,5 +13,5 @@

static: bool = False
) -> Any:
) -> Array:
if static and not fixed:
raise ValueError
return field(static=static, metadata={'support': support, 'fixed': fixed, 'parameter': True})

@@ -5,4 +5,4 @@ from __future__ import annotations

from collections.abc import Mapping
from types import ModuleType
from typing import Any, Self, override
from types import EllipsisType, ModuleType
from typing import Self, override

@@ -19,3 +19,3 @@ from array_api_compat import array_namespace

"""The Distribution is the base class of all distributions."""
def __getitem__(self, key: Any) -> Self:
def __getitem__(self, key: tuple[int | slice | EllipsisType | None, ...]) -> Self:
from .iteration import parameters # noqa: PLC0415

@@ -22,0 +22,0 @@ from .structure.structure import Structure # noqa: PLC0415

@@ -82,3 +82,4 @@ from __future__ import annotations

class ScipyComplexMultivariateNormal(
ShapedDistribution[ScipyComplexMultivariateNormalUnvectorized]):
ShapedDistribution[
ScipyComplexMultivariateNormalUnvectorized]): # type: ignore # pyright: ignore
"""This class allows distributions having a non-empty shape."""

@@ -121,3 +122,5 @@ @override

for i in np.ndindex(*self.shape):
objects[i] = self.objects[i].as_multivariate_normal() # pyright: ignore
this_object = self.objects[i]
assert isinstance(this_object, ScipyComplexMultivariateNormalUnvectorized)
objects[i] = this_object.as_multivariate_normal()
return ScipyMultivariateNormal(self.shape, self.rvs_shape, self.real_dtype, objects)

@@ -136,3 +139,5 @@

for i in np.ndindex(*self.shape):
retval[i] = self.objects[i].variance # pyright: ignore
this_object = self.objects[i]
assert isinstance(this_object, ScipyComplexMultivariateNormalUnvectorized)
retval[i] = this_object.variance
return retval

@@ -144,3 +149,5 @@

for i in np.ndindex(*self.shape):
retval[i] = self.objects[i].pseudo_variance # pyright: ignore
this_object = self.objects[i]
assert isinstance(this_object, ScipyComplexMultivariateNormalUnvectorized)
retval[i] = this_object.pseudo_variance
return retval

@@ -70,3 +70,4 @@ from __future__ import annotations

class ScipyComplexNormal(ShapedDistribution[ScipyComplexNormalUnvectorized]):
class ScipyComplexNormal(
ShapedDistribution[ScipyComplexNormalUnvectorized]): # type: ignore # pyright: ignore
"""This class allows distributions having a non-empty shape."""

@@ -118,3 +119,5 @@ @override

for i in np.ndindex(*self.shape):
objects[i] = self.objects[i].as_multivariate_normal() # pyright: ignore
this_object = self.objects[i]
assert isinstance(this_object, ScipyComplexNormalUnvectorized)
objects[i] = this_object.as_multivariate_normal()
return ScipyMultivariateNormal(self.shape, self.rvs_shape, self.real_dtype, objects)

@@ -133,3 +136,5 @@

for i in np.ndindex(*self.shape):
retval[i] = self.objects[i].variance # pyright: ignore
this_object = self.objects[i]
assert isinstance(this_object, ScipyComplexNormalUnvectorized)
retval[i] = this_object.variance
return retval

@@ -141,3 +146,5 @@

for i in np.ndindex(*self.shape):
retval[i] = self.objects[i].pseudo_variance # pyright: ignore
this_object = self.objects[i]
assert isinstance(this_object, ScipyComplexNormalUnvectorized)
retval[i] = this_object.pseudo_variance
return retval
from __future__ import annotations
from typing import Any
import numpy as np

@@ -32,4 +30,4 @@ import optype.numpy as onp

def rvs(self,
size: Any = None,
random_state: Any = None
size: tuple[int, ...] | None = None,
random_state: np.random.Generator | None = None
) -> onp.ArrayND[np.float64]:

@@ -39,3 +37,4 @@ if size is None:

# This somehow fixes the behaviour of rvs.
return self.distribution.rvs(size=size, random_state=random_state)
return self.distribution.rvs(size=size, # type: ignore # pyright: ignore
random_state=random_state)

@@ -46,3 +45,4 @@ def entropy(self) -> NumpyRealArray:

class ScipyDirichlet(ShapedDistribution[ScipyDirichletFixRVsAndPDF]):
class ScipyDirichlet(
ShapedDistribution[ScipyDirichletFixRVsAndPDF]): # type: ignore # pyright: ignore
"""This class allows distributions having a non-empty shape."""

@@ -49,0 +49,0 @@ @override

from __future__ import annotations
from typing import Any, Self
from typing import Self

@@ -26,11 +26,7 @@ import numpy as np

def rvs(self,
size: Any = None,
random_state: Any = None) -> onp.ArrayND[np.float64]:
size: int | tuple[int, ...] = (),
random_state: np.random.Generator | None = None
) -> onp.ArrayND[np.float64]:
retval = self.distribution.rvs(size=size, random_state=random_state)
if size is None:
size = ()
elif isinstance(size, int):
size = (size,)
else:
size = tuple(size)
size = (size,) if isinstance(size, int) else tuple(size)
return np.reshape(retval, size + self.distribution.mean.shape)

@@ -42,3 +38,4 @@

class ScipyMultivariateNormal(ShapedDistribution[ScipyMultivariateNormalUnvectorized]):
class ScipyMultivariateNormal(
ShapedDistribution[ScipyMultivariateNormalUnvectorized]): # type: ignore # pyright: ignore
"""This class allows distributions having a non-empty shape."""

@@ -45,0 +42,0 @@ @classmethod

from __future__ import annotations
from typing import Any, Generic, TypeVar
from typing import Any, Generic, TypeVar, cast

@@ -8,8 +8,10 @@ import numpy as np

from numpy.random import Generator
from tjax import NumpyComplexArray, NumpyRealArray, Shape
from tjax import NumpyComplexArray, NumpyIntegralArray, NumpyRealArray, Shape
from typing_extensions import override
T = TypeVar('T')
from .base import ScipyDiscreteDistribution, ScipyDistribution
T = TypeVar('T', bound=ScipyDiscreteDistribution | ScipyDistribution)
class ShapedDistribution(Generic[T]):

@@ -41,3 +43,4 @@ """Allow a distributions with shape."""

for i in np.ndindex(*self.shape):
retval[i] = self.objects[i].rvs(size=size, random_state=random_state) # pyright: ignore
this_object = cast('T', self.objects[i])
retval[i] = this_object.rvs(size=size, random_state=random_state)
return retval

@@ -48,12 +51,24 @@

for i in np.ndindex(*self.shape):
value = self.objects[i].pdf(x[i]) # pyright: ignore
if i == ():
return value
this_object = cast('T', self.objects[i])
if not isinstance(this_object, ScipyDistribution):
raise NotImplementedError
value = this_object.pdf(x[i])
retval[i] = value
return retval
def pmf(self, x: NumpyIntegralArray) -> NumpyRealArray:
retval = np.empty(self.shape, dtype=self.real_dtype)
for i in np.ndindex(*self.shape):
this_object = cast('T', self.objects[i])
if not isinstance(this_object, ScipyDiscreteDistribution):
raise NotImplementedError
value = this_object.pmf(x[i])
retval[i] = value
return retval
def entropy(self) -> NumpyRealArray:
retval = np.empty(self.shape, dtype=self.real_dtype)
for i in np.ndindex(*self.shape):
retval[i] = self.objects[i].entropy() # pyright: ignore
this_object = cast('T', self.objects[i])
retval[i] = this_object.entropy()
return retval

@@ -60,0 +75,0 @@

@@ -33,1 +33,4 @@ from __future__ import annotations

return softplus(samples)
def entropy(self) -> NumpyRealArray:
raise NotImplementedError

@@ -11,3 +11,3 @@ from __future__ import annotations

class ScipyVonMises(ShapedDistribution[object]):
class ScipyVonMises(ShapedDistribution[object]): # type: ignore # pyright: ignore
"""This class allows distributions having a non-empty shape."""

@@ -29,3 +29,3 @@ @override

class ScipyVonMisesFisher(ShapedDistribution[object]):
class ScipyVonMisesFisher(ShapedDistribution[object]): # type: ignore # pyright: ignore
"""This class allows distributions having a non-empty shape."""

@@ -42,3 +42,3 @@ @override

for i in np.ndindex(*shape):
objects[i] = ss.vonmises_fisher(mu[i], kappa[i]) # pyright: ignore
objects[i] = ss.vonmises_fisher(mu[i], kappa[i]) # type: ignore # pyright: ignore
super().__init__(shape, rvs_shape, dtype, objects)

@@ -7,3 +7,3 @@ from __future__ import annotations

from itertools import starmap
from typing import TYPE_CHECKING, Any, TypeVar
from typing import TYPE_CHECKING, TypeVar

@@ -22,3 +22,3 @@ from array_api_compat import array_namespace

@jit
def parameter_dot_product(x: NaturalParametrization, y: Any, /) -> JaxRealArray:
def parameter_dot_product(x: NaturalParametrization, y: Distribution, /) -> JaxRealArray:
"""Return the vectorized dot product over all of the variable parameters."""

@@ -25,0 +25,0 @@ def dotted_fields() -> Iterable[JaxRealArray]:

Metadata-Version: 2.4
Name: efax
Version: 1.22.1
Version: 1.22.2
Summary: Exponential families for JAX

@@ -34,16 +34,5 @@ Project-URL: source, https://github.com/NeilGirdhar/efax

Requires-Dist: scipy>=1.15
Requires-Dist: tensorflow-probability>=0.15
Requires-Dist: tfp-nightly>=0.25
Requires-Dist: tjax>=1.3.10
Requires-Dist: typing-extensions>=4.8
Provides-Extra: dev
Requires-Dist: isort>=5.13; extra == 'dev'
Requires-Dist: jupyter>=1; extra == 'dev'
Requires-Dist: mypy>=1.12; extra == 'dev'
Requires-Dist: pre-commit>=4; extra == 'dev'
Requires-Dist: pylint>=3.3; extra == 'dev'
Requires-Dist: pyright>=1.1.401; extra == 'dev'
Requires-Dist: pytest-ordering; extra == 'dev'
Requires-Dist: pytest-xdist[psutil]>=3; extra == 'dev'
Requires-Dist: pytest>=8.4; extra == 'dev'
Requires-Dist: ruff>=0.9.10; extra == 'dev'
Description-Content-Type: text/x-rst

@@ -50,0 +39,0 @@

@@ -5,5 +5,20 @@ [build-system]

[dependency-groups]
dev = [
"isort>=5.13",
"jupyter>=1",
"lefthook>=1.11.13",
"mypy>=1.12",
"pylint>=3.3",
"pyright>=1.1.401",
"pytest-ordering",
"pytest-xdist[psutil]>=3",
"pytest>=8.4",
"ruff>=0.9.10",
"toml-sort>=0.24"
]
[project]
name = "efax"
version = "1.22.1"
version = "1.22.2"
description = "Exponential families for JAX"

@@ -27,3 +42,3 @@ readme = "README.rst"

"Topic :: Software Development :: Libraries :: Python Modules",
"Typing :: Typed",
"Typing :: Typed"
]

@@ -40,21 +55,7 @@ dependencies = [

"scipy>=1.15",
"tensorflow_probability>=0.15",
"tfp-nightly>=0.25",
"tjax>=1.3.10",
"typing_extensions>=4.8",
"typing_extensions>=4.8"
]
[project.optional-dependencies]
dev = [
"isort>=5.13",
"jupyter>=1",
"mypy>=1.12",
"pre-commit>=4",
"pylint>=3.3",
"pyright>=1.1.401",
"pytest-ordering",
"pytest-xdist[psutil]>=3",
"pytest>=8.4",
"ruff>=0.9.10",
]
[project.urls]

@@ -68,2 +69,31 @@ source = "https://github.com/NeilGirdhar/efax"

[tool.mypy]
files = ["efax", "tests", "examples"]
disable_error_code = ["type-abstract"]
check_untyped_defs = true
disallow_any_generics = true
disallow_incomplete_defs = true
# disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
no_implicit_optional = true
pretty = true
show_error_codes = true
show_error_context = false
strict_equality = true
warn_no_return = true
warn_redundant_casts = true
warn_return_any = false
warn_unreachable = true
warn_unused_configs = true
warn_unused_ignores = true
[[tool.mypy.overrides]]
module = [
"tensorflow_probability.substrates",
"array_api_compat",
"opt_einsum"
]
ignore_missing_imports = true
[tool.pylint.master]

@@ -80,3 +110,3 @@ jobs = 0

"pylint.extensions.overlapping_exceptions",
"pylint.extensions.typing",
"pylint.extensions.typing"
]

@@ -87,28 +117,262 @@

# Ruff
"C0103", "C0105", "C0112", "C0113", "C0114", "C0115", "C0116", "C0121", "C0123", "C0131", "C0132", "C0198", "C0199", "C0201", "C0202", "C0205",
"C0206", "C0208", "C0301", "C0303", "C0304", "C0305", "C0321", "C0410", "C0411", "C0412", "C0413", "C0414", "C0415", "C0501", "C1802", "C1901",
"C2201", "C2401", "C2403", "C2701", "C2801", "C3001", "C3002", "E0001", "E0013", "E0014", "E0100", "E0101", "E0102", "E0103", "E0104", "E0105",
"E0106", "E0107", "E0108", "E0112", "E0115", "E0116", "E0117", "E0118", "E0213", "E0237", "E0241", "E0302", "E0303", "E0304", "E0305", "E0308",
"E0309", "E0402", "E0602", "E0603", "E0604", "E0605", "E0643", "E0704", "E0711", "E1132", "E1142", "E1205", "E1206", "E1300", "E1301", "E1302",
"E1303", "E1304", "E1305", "E1306", "E1307", "E1310", "E1519", "E1520", "E1700", "E2502", "E2510", "E2512", "E2513", "E2514", "E2515", "E4703",
"E6004", "E6005", "R0022", "R0123", "R0124", "R0133", "R0202", "R0203", "R0205", "R0206", "R0402", "R0904", "R0911", "R0912", "R0913", "R0914",
"R0915", "R0916", "R1260", "R1701", "R1702", "R1703", "R1704", "R1705", "R1706", "R1707", "R1710", "R1711", "R1714", "R1715", "R1717", "R1718",
"R1719", "R1720", "R1721", "R1722", "R1723", "R1724", "R1725", "R1728", "R1729", "R1730", "R1731", "R1732", "R1733", "R1734", "R1735", "R1736",
"R2004", "R2044", "R5501", "R6002", "R6003", "R6104", "R6201", "R6301", "W0012", "W0102", "W0104", "W0106", "W0107", "W0108", "W0109", "W0120",
"W0122", "W0123", "W0127", "W0129", "W0130", "W0131", "W0133", "W0150", "W0160", "W0177", "W0199", "W0211", "W0212", "W0245", "W0301", "W0401",
"W0404", "W0406", "W0410", "W0511", "W0602", "W0603", "W0604", "W0611", "W0612", "W0613", "W0622", "W0640", "W0702", "W0705", "W0706", "W0707",
"W0711", "W0718", "W0719", "W1113", "W1201", "W1202", "W1203", "W1301", "W1302", "W1303", "W1304", "W1305", "W1309", "W1310", "W1401", "W1404",
"W1405", "W1406", "W1501", "W1502", "W1508", "W1509", "W1510", "W1514", "W1515", "W1518", "W1641", "W2101", "W2402", "W2601", "W2901", "W3201",
"C0103",
"C0105",
"C0112",
"C0113",
"C0114",
"C0115",
"C0116",
"C0121",
"C0123",
"C0131",
"C0132",
"C0198",
"C0199",
"C0201",
"C0202",
"C0205",
"C0206",
"C0208",
"C0301",
"C0303",
"C0304",
"C0305",
"C0321",
"C0410",
"C0411",
"C0412",
"C0413",
"C0414",
"C0415",
"C0501",
"C1802",
"C1901",
"C2201",
"C2401",
"C2403",
"C2701",
"C2801",
"C3001",
"C3002",
"E0001",
"E0013",
"E0014",
"E0100",
"E0101",
"E0102",
"E0103",
"E0104",
"E0105",
"E0106",
"E0107",
"E0108",
"E0112",
"E0115",
"E0116",
"E0117",
"E0118",
"E0213",
"E0237",
"E0241",
"E0302",
"E0303",
"E0304",
"E0305",
"E0308",
"E0309",
"E0402",
"E0602",
"E0603",
"E0604",
"E0605",
"E0643",
"E0704",
"E0711",
"E1132",
"E1142",
"E1205",
"E1206",
"E1300",
"E1301",
"E1302",
"E1303",
"E1304",
"E1305",
"E1306",
"E1307",
"E1310",
"E1519",
"E1520",
"E1700",
"E2502",
"E2510",
"E2512",
"E2513",
"E2514",
"E2515",
"E4703",
"E6004",
"E6005",
"R0022",
"R0123",
"R0124",
"R0133",
"R0202",
"R0203",
"R0205",
"R0206",
"R0402",
"R0904",
"R0911",
"R0912",
"R0913",
"R0914",
"R0915",
"R0916",
"R1260",
"R1701",
"R1702",
"R1703",
"R1704",
"R1705",
"R1706",
"R1707",
"R1710",
"R1711",
"R1714",
"R1715",
"R1717",
"R1718",
"R1719",
"R1720",
"R1721",
"R1722",
"R1723",
"R1724",
"R1725",
"R1728",
"R1729",
"R1730",
"R1731",
"R1732",
"R1733",
"R1734",
"R1735",
"R1736",
"R2004",
"R2044",
"R5501",
"R6002",
"R6003",
"R6104",
"R6201",
"R6301",
"W0012",
"W0102",
"W0104",
"W0106",
"W0107",
"W0108",
"W0109",
"W0120",
"W0122",
"W0123",
"W0127",
"W0129",
"W0130",
"W0131",
"W0133",
"W0150",
"W0160",
"W0177",
"W0199",
"W0211",
"W0212",
"W0245",
"W0301",
"W0401",
"W0404",
"W0406",
"W0410",
"W0511",
"W0602",
"W0603",
"W0604",
"W0611",
"W0612",
"W0613",
"W0622",
"W0640",
"W0702",
"W0705",
"W0706",
"W0707",
"W0711",
"W0718",
"W0719",
"W1113",
"W1201",
"W1202",
"W1203",
"W1301",
"W1302",
"W1303",
"W1304",
"W1305",
"W1309",
"W1310",
"W1401",
"W1404",
"W1405",
"W1406",
"W1501",
"W1502",
"W1508",
"W1509",
"W1510",
"W1514",
"W1515",
"W1518",
"W1641",
"W2101",
"W2402",
"W2601",
"W2901",
"W3201",
"W3301",
# Missing
"E0601", "R1737", "W0311", "W2301",
"E0601",
"R1737",
"W0311",
"W2301",
# Mine
"C0111", "E1101", "E1102", "E1120", "E1123", "E1130", "E1135", "E1136", "E3701", "R0204", "R0401", "R0801", "R0901", "R0902", "R0903", "R0917",
"R5601", "R6102", "R6103", "W0149", "W0221", "W0222", "W0223", "W0621", "W0717",
"C0111",
"E1101",
"E1102",
"E1120",
"E1123",
"E1130",
"E1135",
"E1136",
"E3701",
"R0204",
"R0401",
"R0801",
"R0901",
"R0902",
"R0903",
"R0917",
"R5601",
"R6102",
"R6103",
"W0149",
"W0221",
"W0222",
"W0223",
"W0621",
"W0717"
]
enable = [
"useless-suppression",
"use-symbolic-message-instead",
"use-symbolic-message-instead"
]

@@ -168,30 +432,2 @@

[tool.mypy]
files = ["efax", "tests", "examples"]
disable_error_code = ["type-abstract"]
check_untyped_defs = true
disallow_any_generics = true
disallow_incomplete_defs = true
# disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
no_implicit_optional = true
pretty = true
show_error_codes = true
show_error_context = false
strict_equality = true
warn_no_return = true
warn_redundant_casts = true
warn_return_any = false
warn_unreachable = true
warn_unused_configs = true
warn_unused_ignores = true
[[tool.mypy.overrides]]
module = [
"tensorflow_probability.substrates",
"array_api_compat",
]
ignore_missing_imports = true
[tool.ruff]

@@ -204,3 +440,2 @@ line-length = 100

ignore = [
"ANN401", # Dynamically typed expressions (Any).
"ARG001", # Unused function argument.

@@ -210,4 +445,2 @@ "ARG002", # Unused method argument.

"ARG004", # Unused static method argument.
"B011", # Do not assert false.
"C901", # Complex structure.
"COM812", # Trailing comma missing.

@@ -226,4 +459,2 @@ "CPY001", # Missing copyright.

"ERA001", # Commented-out code.
"F722", # Syntax error in forward annotation.
"FBT003", # Boolean positional value in function call.
"FIX002", # Line contains TODO, consider resolving the issue.

@@ -233,7 +464,5 @@ "G004", # Logging statement uses f-string.

"PD008", # Use .loc instead of .at. If speed is important, use NumPy.
"PD013", # `.melt` is preferred to `.stack`; provides same functionality
"PGH003", # Use specific rule codes when ignoring type issues.
"PLR0913", # Too many arguments in function definition.
"PLR6301", # Method doesn"t use self.
"PT013", # Found incorrect import of pytest, use simple import pytest instead.
"Q000", # Single quotes found but double quotes preferred.

@@ -247,6 +476,8 @@ "RUF021", # Parenthesize `a and b` expressions when chaining `and` and `or` together...

"TD003", # Missing issue link on the line following this TODO.
"TID252", # Relative imports from parent modules are banned.
"UP037", # Remove quotes from type annotation.
"TID252" # Relative imports from parent modules are banned.
]
[tool.ruff.lint.flake8-errmsg]
max-string-length = 40
[tool.ruff.lint.flake8-import-conventions.extend-aliases]

@@ -267,5 +498,2 @@ "array_api_extra" = "xpx"

[tool.ruff.lint.flake8-errmsg]
max-string-length = 40
[tool.ruff.lint.isort]

@@ -272,0 +500,0 @@ combine-as-imports = true

@@ -23,3 +23,3 @@ from __future__ import annotations

def _jax_fixture(request: pytest.FixtureRequest) -> Generator[None]: # pyright: ignore
with jax.debug_key_reuse(True), jax.numpy_rank_promotion('raise'), enable_x64():
with jax.debug_key_reuse(new_val=True), jax.numpy_rank_promotion('raise'), enable_x64():
yield

@@ -26,0 +26,0 @@

@@ -360,3 +360,3 @@ from __future__ import annotations

class LogNormal(DistributionInfo[LogNormalNP, LogNormalEP, NumpyRealArray]):
class LogNormalInfo(DistributionInfo[LogNormalNP, LogNormalEP, NumpyRealArray]):
@override

@@ -661,3 +661,3 @@ def exp_to_scipy_distribution(self, p: LogNormalEP) -> Any:

JointInfo(infos={'gamma': GammaInfo(), 'normal': NormalInfo()}),
LogNormal(),
LogNormalInfo(),
LogarithmicInfo(),

@@ -664,0 +664,0 @@ MultivariateDiagonalNormalInfo(dimensions=4),

@@ -12,3 +12,4 @@ from __future__ import annotations

from efax import ExpectationParametrization, NaturalParametrization, Structure, SubDistributionInfo
from efax import (ExpectationParametrization, NaturalParametrization, ScipyDiscreteDistribution,
ScipyDistribution, Structure, SubDistributionInfo)

@@ -26,3 +27,3 @@ NP = TypeVar('NP', bound=NaturalParametrization, default=Any)

def exp_to_scipy_distribution(self, p: EP) -> Any:
def exp_to_scipy_distribution(self, p: EP) -> ScipyDistribution | ScipyDiscreteDistribution:
"""Produce a corresponding scipy distribution from expectation parameters.

@@ -35,3 +36,3 @@

def nat_to_scipy_distribution(self, q: NP) -> Any:
def nat_to_scipy_distribution(self, q: NP) -> ScipyDistribution | ScipyDiscreteDistribution:
"""Produce a corresponding scipy distribution from natural parameters.

@@ -102,7 +103,7 @@

def new_method(*args: Any,
def new_method(*args: object,
old_method: Callable[..., Any] = old_method,
**kwargs: Any) -> Any:
**kwargs: object) -> object:
return old_method(*args, **kwargs)
setattr(cls, method, new_method)

@@ -11,3 +11,4 @@ from __future__ import annotations

from efax import JointDistributionN, Multidimensional, NaturalParametrization, SimpleDistribution
from efax import (JointDistributionN, Multidimensional, NaturalParametrization,
ScipyDiscreteDistribution, ScipyDistribution, SimpleDistribution)

@@ -48,6 +49,7 @@ from ..create_info import MultivariateDiagonalNormalInfo

efax_density = np.asarray(nat_parameters.pdf(efax_x), dtype=np.float64)
try:
if isinstance(scipy_distribution, ScipyDistribution):
scipy_density = scipy_distribution.pdf(scipy_x)
except AttributeError:
scipy_density = scipy_distribution.pmf(scipy_x)
else:
assert isinstance(scipy_distribution, ScipyDiscreteDistribution)
scipy_density = scipy_distribution.pmf(scipy_x) # type: ignore # pyright: ignore

@@ -54,0 +56,0 @@ if isinstance(distribution_info, MultivariateDiagonalNormalInfo):

@@ -38,3 +38,3 @@ from __future__ import annotations

np.average(
np.conj(z)[..., np.newaxis, :, :] * z[..., np.newaxis, :], # type: ignore[arg-type]
np.conj(z)[..., np.newaxis, :, :] * z[..., np.newaxis, :],
weights=weights,

@@ -41,0 +41,0 @@ axis=-1)

@@ -5,3 +5,2 @@ """These tests verify entropy gradients."""

from functools import partial
from typing import Any

@@ -30,3 +29,3 @@ import jax.numpy as jnp

@jit
def _all_finite(some_tree: Any, /) -> JaxBooleanArray:
def _all_finite(some_tree: object, /) -> JaxBooleanArray:
return dynamic_tree_all(tree.map(lambda x: jnp.all(jnp.isfinite(x)), some_tree))

@@ -33,0 +32,0 @@

from __future__ import annotations
from typing import Any
import numpy as np

@@ -10,3 +8,3 @@ import pytest

from efax import ScipyDirichlet, ScipyMultivariateNormal
from efax import ScipyDirichlet, ScipyDistribution, ScipyMultivariateNormal

@@ -22,3 +20,4 @@

])
def test_shaped(generator: NumpyGenerator, distribution: Any, m: Shape, n: int) -> None:
def test_shaped(generator: NumpyGenerator, distribution: ScipyDistribution, m: Shape, n: int
) -> None:
assert distribution.rvs().shape == (*m, n)

@@ -25,0 +24,0 @@ assert distribution.rvs(1).shape == (*m, 1, n)

# Install the pre-commit hooks below with
# 'pre-commit install'
# Auto-update the version of the hooks with
# 'pre-commit autoupdate'
# Run the hooks on all files with
# 'pre-commit run --all'
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
hooks:
- id: check-toml
- id: check-yaml
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.11.10
hooks:
- id: ruff
- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.401
hooks:
- id: pyright

Sorry, the diff of this file is too big to display