You're Invited:Meet the Socket Team at RSAC and BSidesSF 2026, March 23–26.RSVP
Socket
Book a DemoSign in
Socket

efax

Package Overview
Dependencies
Maintainers
1
Versions
115
Alerts
File Explorer

Advanced tools

Socket logo

Install Socket

Detect and block malicious and high-risk dependencies

Install

efax - pypi Package Compare versions

Comparing version
1.22.3
to
1.23.0
+1
-2
efax/_src/distributions/bernoulli.py
from __future__ import annotations
from typing import Self
from typing import Self, override

@@ -11,3 +11,2 @@ import jax.random as jr

from tjax.dataclasses import dataclass
from typing_extensions import override

@@ -14,0 +13,0 @@ from ..interfaces.conjugate_prior import HasConjugatePrior

from __future__ import annotations
from typing import override
import jax.random as jr

@@ -7,3 +9,2 @@ from array_api_compat import array_namespace

from tjax.dataclasses import dataclass
from typing_extensions import override

@@ -10,0 +11,0 @@ from ..expectation_parametrization import ExpectationParametrization

from __future__ import annotations
import math
from typing import override

@@ -10,3 +11,2 @@ import jax.random as jr

from tjax.dataclasses import dataclass
from typing_extensions import override

@@ -13,0 +13,0 @@ from ..interfaces.samplable import Samplable

from __future__ import annotations
import math
from typing import override

@@ -9,3 +10,2 @@ import jax.scipy.special as jss

from tjax.dataclasses import dataclass
from typing_extensions import override

@@ -12,0 +12,0 @@ from ..mixins.has_entropy import HasEntropyEP, HasEntropyNP

from __future__ import annotations
import math
from typing import override

@@ -9,3 +10,2 @@ import jax.random as jr

from tjax.dataclasses import dataclass
from typing_extensions import override

@@ -12,0 +12,0 @@ from ...interfaces.multidimensional import Multidimensional

from __future__ import annotations
import math
from typing import override

@@ -9,3 +10,2 @@ import jax.random as jr

from tjax.dataclasses import dataclass
from typing_extensions import override

@@ -12,0 +12,0 @@ from ...interfaces.multidimensional import Multidimensional

from __future__ import annotations
import math
from typing import override

@@ -9,3 +10,2 @@ import jax.random as jr

from tjax.dataclasses import dataclass
from typing_extensions import override

@@ -12,0 +12,0 @@ from ...interfaces.samplable import Samplable

from __future__ import annotations
import math
from typing import override

@@ -9,3 +10,2 @@ import jax.random as jr

from tjax.dataclasses import dataclass
from typing_extensions import override

@@ -12,0 +12,0 @@ from ...interfaces.samplable import Samplable

from __future__ import annotations
from typing import Any, Generic, TypeVar
from typing import Any, TypeVar, override

@@ -10,3 +10,2 @@ import jax.random as jr

from tjax.dataclasses import dataclass
from typing_extensions import override

@@ -24,7 +23,6 @@ from ..interfaces.multidimensional import Multidimensional

@dataclass
class DirichletCommonNP(HasEntropyNP[EP],
class DirichletCommonNP[EP: 'DirichletCommonEP[Any]'](HasEntropyNP[EP],
Samplable,
Multidimensional,
NaturalParametrization[EP, JaxRealArray],
Generic[EP]):
NaturalParametrization[EP, JaxRealArray]):
alpha_minus_one: JaxRealArray = distribution_parameter(VectorSupport(

@@ -65,6 +63,6 @@ ring=RealField(minimum=-1.0, generation_scale=3.0)))

@dataclass
class DirichletCommonEP(HasEntropyEP[NP],
class DirichletCommonEP[NP: DirichletCommonNP[Any]](HasEntropyEP[NP],
Samplable,
ExpToNat[NP],
Multidimensional, Generic[NP]):
Multidimensional):
mean_log_probability: JaxRealArray = distribution_parameter(VectorSupport(

@@ -71,0 +69,0 @@ ring=negative_support))

from __future__ import annotations
from typing import override
from array_api_compat import array_namespace
from tjax import JaxArray, JaxRealArray
from tjax.dataclasses import dataclass
from typing_extensions import override

@@ -8,0 +9,0 @@ from ..expectation_parametrization import ExpectationParametrization

from __future__ import annotations
from typing import Self
from typing import Self, override

@@ -9,3 +9,2 @@ import jax.random as jr

from tjax.dataclasses import dataclass
from typing_extensions import override

@@ -12,0 +11,0 @@ from ..interfaces.conjugate_prior import HasConjugatePrior

from __future__ import annotations
from typing import override
import jax.random as jr

@@ -8,3 +10,2 @@ import jax.scipy.special as jss

from tjax.dataclasses import dataclass
from typing_extensions import override

@@ -11,0 +12,0 @@ from ..interfaces.samplable import Samplable

@@ -10,2 +10,4 @@ """The generalized Dirichlet distribution.

from typing import override
import jax.scipy.special as jss

@@ -15,3 +17,2 @@ from array_api_compat import array_namespace

from tjax.dataclasses import dataclass
from typing_extensions import override

@@ -18,0 +19,0 @@ from ..interfaces.multidimensional import Multidimensional

from __future__ import annotations
from typing import override
import jax.random as jr

@@ -7,3 +9,2 @@ from array_api_compat import array_namespace

from tjax.dataclasses import dataclass
from typing_extensions import override

@@ -10,0 +11,0 @@ from ..interfaces.samplable import Samplable

from __future__ import annotations
from typing import override
from array_api_compat import array_namespace
from tjax import Array, JaxArray, JaxRealArray, KeyArray, Shape
from tjax.dataclasses import dataclass
from typing_extensions import override

@@ -8,0 +9,0 @@ from ..interfaces.samplable import Samplable

from __future__ import annotations
from typing import override
import jax.random as jr

@@ -7,3 +9,2 @@ from array_api_compat import array_namespace

from tjax.dataclasses import dataclass
from typing_extensions import override

@@ -10,0 +11,0 @@ from ..expectation_parametrization import ExpectationParametrization

from __future__ import annotations
from typing import override
import jax.random as jr

@@ -7,3 +9,2 @@ from array_api_compat import array_namespace

from tjax.dataclasses import dataclass
from typing_extensions import override

@@ -10,0 +11,0 @@ from ...interfaces.samplable import Samplable

from __future__ import annotations
from typing import Self
from typing import Self, override

@@ -9,3 +9,2 @@ import jax.random as jr

from tjax.dataclasses import dataclass
from typing_extensions import override

@@ -12,0 +11,0 @@ from ...interfaces.conjugate_prior import HasConjugatePrior

from __future__ import annotations
from typing import override
from array_api_compat import array_namespace
from tjax import JaxArray, JaxRealArray, Shape
from tjax.dataclasses import dataclass
from typing_extensions import override

@@ -8,0 +9,0 @@ from ..expectation_parametrization import ExpectationParametrization

from __future__ import annotations
from typing import Self
from typing import Self, override

@@ -12,3 +12,2 @@ import array_api_extra as xpx

from tjax.dataclasses import dataclass
from typing_extensions import override

@@ -78,3 +77,3 @@ from ..interfaces.conjugate_prior import HasGeneralizedConjugatePrior

retval = xpx.one_hot(jr.categorical(key, self.log_odds, shape=shape), self.dimensions())
assert isinstance(retval, JaxRealArray)
assert isinstance(retval, JaxArray)
return retval

@@ -81,0 +80,0 @@

from __future__ import annotations
import math
from typing import cast
from typing import cast, override

@@ -11,3 +11,2 @@ import jax.random as jr

from tjax.dataclasses import dataclass
from typing_extensions import override

@@ -14,0 +13,0 @@ from ...interfaces.multidimensional import Multidimensional

from __future__ import annotations
from typing import override
import jax.random as jr

@@ -8,3 +10,2 @@ import numpy as np

from tjax.dataclasses import dataclass
from typing_extensions import override

@@ -11,0 +12,0 @@ from ...interfaces.multidimensional import Multidimensional

from __future__ import annotations
import math
from typing import Self
from typing import Self, override

@@ -10,3 +10,2 @@ import jax.random as jr

from tjax.dataclasses import dataclass
from typing_extensions import override

@@ -13,0 +12,0 @@ from ...interfaces.conjugate_prior import HasGeneralizedConjugatePrior

from __future__ import annotations
from typing import override
import jax.random as jr

@@ -7,3 +9,2 @@ from array_api_compat import array_namespace

from tjax.dataclasses import dataclass
from typing_extensions import override

@@ -10,0 +11,0 @@ from ...interfaces.multidimensional import Multidimensional

from __future__ import annotations
import math
from typing import Self
from typing import Self, override

@@ -10,3 +10,2 @@ import jax.random as jr

from tjax.dataclasses import dataclass
from typing_extensions import override

@@ -13,0 +12,0 @@ from ...interfaces.conjugate_prior import HasGeneralizedConjugatePrior

from __future__ import annotations
from abc import abstractmethod
from typing import Any, TypeVar
from typing import Any, TypeVar, override

@@ -10,3 +10,2 @@ import jax.scipy.special as jss

from tjax.dataclasses import dataclass
from typing_extensions import override

@@ -13,0 +12,0 @@ from ..expectation_parametrization import ExpectationParametrization

from __future__ import annotations
from typing import override
import jax.random as jr

@@ -7,3 +9,2 @@ from array_api_compat import array_namespace

from tjax.dataclasses import dataclass
from typing_extensions import override

@@ -10,0 +11,0 @@ from ..expectation_parametrization import ExpectationParametrization

from __future__ import annotations
from typing import override
import jax.random as jr

@@ -8,3 +10,2 @@ import numpy as np

from tjax.dataclasses import dataclass
from typing_extensions import override

@@ -11,0 +12,0 @@ from ...interfaces.samplable import Samplable

from __future__ import annotations
import math
from typing import Self
from typing import Self, override

@@ -10,3 +10,2 @@ import jax.random as jr

from tjax.dataclasses import dataclass
from typing_extensions import override

@@ -13,0 +12,0 @@ from ...interfaces.conjugate_prior import HasConjugatePrior

from __future__ import annotations
from typing import Self
from typing import Self, override

@@ -10,3 +10,2 @@ import jax.random as jr

from tjax.dataclasses import dataclass
from typing_extensions import override

@@ -13,0 +12,0 @@ from ..expectation_parametrization import ExpectationParametrization

from __future__ import annotations
import math
from typing import override

@@ -10,3 +11,2 @@ import jax.random as jr

from tjax.dataclasses import dataclass
from typing_extensions import override

@@ -13,0 +13,0 @@ from ..interfaces.samplable import Samplable

from __future__ import annotations
from typing import cast
from typing import cast, override

@@ -9,3 +9,2 @@ import jax.random as jr

from tjax.dataclasses import dataclass
from typing_extensions import override

@@ -12,0 +11,0 @@ from ...interfaces.samplable import Samplable

from __future__ import annotations
from typing import Self, cast
from typing import Self, cast, override

@@ -9,3 +9,2 @@ import jax.random as jr

from tjax.dataclasses import dataclass
from typing_extensions import override

@@ -12,0 +11,0 @@ from ...interfaces.conjugate_prior import HasConjugatePrior

from __future__ import annotations
import math
from typing import cast
from typing import cast, override

@@ -9,3 +9,2 @@ from array_api_compat import array_namespace

from tjax.dataclasses import dataclass
from typing_extensions import override

@@ -12,0 +11,0 @@ from ..interfaces.multidimensional import Multidimensional

from __future__ import annotations
from typing import override
import jax.random as jr

@@ -8,3 +10,2 @@ import numpy as np

from tjax.dataclasses import dataclass
from typing_extensions import override

@@ -11,0 +12,0 @@ from ..interfaces.samplable import Samplable

from __future__ import annotations
from dataclasses import KW_ONLY, field
from typing import Any, Generic, Self, TypeAlias
from typing import Any, Generic, Self, override

@@ -10,3 +10,3 @@ from array_api_compat import array_namespace

from tjax.dataclasses import dataclass
from typing_extensions import TypeVar, override
from typing_extensions import TypeVar

@@ -19,3 +19,3 @@ from ...expectation_parametrization import ExpectationParametrization

NP = TypeVar('NP', bound=NaturalParametrization, default=Any)
SP: TypeAlias = JaxRealArray
type SP = JaxRealArray

@@ -22,0 +22,0 @@

@@ -1,2 +0,2 @@

from typing import Any, TypeAlias, TypeVar
from typing import Any, TypeVar, override

@@ -8,3 +8,2 @@ import optimistix as optx

from tjax.dataclasses import dataclass, field
from typing_extensions import override

@@ -16,3 +15,3 @@ from .exp_to_nat import ExpToNat, ExpToNatMinimizer

Aux = TypeVar('Aux')
RootFinder: TypeAlias = (
type RootFinder[Y, Out, Aux] = (
optx.AbstractRootFinder[Y, Out, Aux, Any]

@@ -19,0 +18,0 @@ | optx.AbstractLeastSquaresSolver[Y, Out, Aux, Any]

@@ -5,3 +5,3 @@ from __future__ import annotations

from functools import partial
from typing import Any, Generic, cast
from typing import Any, Generic, cast, override

@@ -11,3 +11,3 @@ from array_api_compat import array_namespace

from tjax import JaxArray, JaxComplexArray, JaxRealArray, Shape
from typing_extensions import TypeVar, override
from typing_extensions import TypeVar

@@ -24,3 +24,3 @@ from ..expectation_parametrization import ExpectationParametrization

class TransformedNaturalParametrization(NaturalParametrization[TEP, Domain],
Generic[NP, EP, TEP, Domain]):
Generic[NP, EP, TEP, Domain]): # noqa: UP046
"""Produce a NaturalParametrization by relating it to some base distrubtion NP."""

@@ -100,3 +100,3 @@ @classmethod

class TransformedExpectationParametrization(ExpectationParametrization[TNP],
Generic[EP, NP, TNP]):
Generic[EP, NP, TNP]): # noqa: UP046
"""Produce an ExpectationParametrization by relating it to some base distrubtion EP."""

@@ -103,0 +103,0 @@ @classmethod

@@ -6,2 +6,3 @@ from __future__ import annotations

from types import ModuleType
from typing import override

@@ -15,3 +16,2 @@ import array_api_extra as xpx

softplus)
from typing_extensions import override

@@ -192,2 +192,4 @@ from ..types import Namespace

if minimum is None:
minimum = 0.0
if maximum is None:

@@ -194,0 +196,0 @@ # x is outside the disk of the given minimum. Map it to the plane.

@@ -5,3 +5,3 @@ from __future__ import annotations

from math import comb, isqrt
from typing import Any, cast
from typing import Any, cast, override

@@ -15,3 +15,2 @@ import array_api_extra as xpx

from tjax import JaxArray, JaxRealArray, Shape
from typing_extensions import override

@@ -18,0 +17,0 @@ from ..types import Namespace

from __future__ import annotations
from typing import override
import numpy as np
from numpy.random import Generator
from tjax import NumpyComplexArray, NumpyRealArray, ShapeLike
from typing_extensions import override

@@ -45,2 +46,3 @@ from .multivariate_normal import ScipyMultivariateNormal, ScipyMultivariateNormalUnvectorized

def pdf(self, z: NumpyComplexArray, out: None = None) -> float:
assert z.ndim == 1
zr = np.concat([np.real(z), np.imag(z)], axis=-1)

@@ -117,3 +119,3 @@ return self.as_multivariate_normal().pdf(zr).item()

mean[i], variance[i], pseudo_variance[i])
super().__init__(shape, rvs_shape, dtype, objects)
super().__init__(shape, rvs_shape, dtype, objects, multivariate=True)

@@ -126,3 +128,4 @@ def as_multivariate_normal(self) -> ScipyMultivariateNormal:

objects[i] = this_object.as_multivariate_normal()
return ScipyMultivariateNormal(self.shape, self.rvs_shape, self.real_dtype, objects)
return ScipyMultivariateNormal(self.shape, self.rvs_shape, self.real_dtype, objects,
multivariate=True)

@@ -129,0 +132,0 @@ @property

from __future__ import annotations
from typing import Self
from typing import Self, override

@@ -8,3 +8,2 @@ import numpy as np

from tjax import NumpyComplexArray, NumpyComplexNumeric, NumpyRealArray, NumpyRealNumeric, ShapeLike
from typing_extensions import override

@@ -101,3 +100,3 @@ from .multivariate_normal import ScipyMultivariateNormal, ScipyMultivariateNormalUnvectorized

objects[i] = ScipyComplexNormalUnvectorized(mean[i], variance[i], pseudo_variance[i])
super().__init__(shape, rvs_shape, dtype, objects)
super().__init__(shape, rvs_shape, dtype, objects, multivariate=False)

@@ -123,3 +122,4 @@ @classmethod

objects[i] = this_object.as_multivariate_normal()
return ScipyMultivariateNormal(self.shape, self.rvs_shape, self.real_dtype, objects)
return ScipyMultivariateNormal(self.shape, self.rvs_shape, self.real_dtype, objects,
multivariate=True)

@@ -126,0 +126,0 @@ @property

from __future__ import annotations
from typing import override
import numpy as np

@@ -9,3 +11,2 @@ import optype.numpy as onp

from tjax import NumpyComplexArray, NumpyRealArray, ShapeLike
from typing_extensions import override

@@ -55,3 +56,3 @@ from .shaped_distribution import ShapedDistribution

objects[i] = ScipyDirichletFixRVsAndPDF(alpha[i])
super().__init__(shape, rvs_shape, dtype, objects)
super().__init__(shape, rvs_shape, dtype, objects, multivariate=True)

@@ -58,0 +59,0 @@ @override

@@ -58,2 +58,2 @@ from __future__ import annotations

objects[i] = ScipyMultivariateNormalUnvectorized(mean[i], cov[i])
return cls(shape, rvs_shape, dtype, objects)
return cls(shape, rvs_shape, dtype, objects, multivariate=True)
from __future__ import annotations
from typing import Any, Generic, TypeVar, cast
from typing import Any, TypeVar, cast, override

@@ -9,3 +9,2 @@ import numpy as np

from tjax import NumpyComplexArray, NumpyIntegralArray, NumpyRealArray, Shape
from typing_extensions import override

@@ -17,3 +16,3 @@ from .base import ScipyDiscreteDistribution, ScipyDistribution

class ShapedDistribution(Generic[T]):
class ShapedDistribution[T: ScipyDiscreteDistribution | ScipyDistribution]:
"""Allow a distributions with shape."""

@@ -25,3 +24,5 @@ @override

rvs_dtype: np.dtype[Any],
objects: npt.NDArray[np.object_]
objects: npt.NDArray[np.object_],
*,
multivariate: bool,
) -> None:

@@ -34,2 +35,3 @@ super().__init__()

self.objects = objects
self.multivariate = multivariate

@@ -56,3 +58,4 @@ @property

assert x.shape[:self.ndim] == self.shape
retval = np.empty(x.shape, dtype=self.real_dtype)
final_shape = x.shape[:-1] if self.multivariate else x.shape
retval = np.empty(final_shape, dtype=self.real_dtype)
for i in np.ndindex(*self.shape):

@@ -62,4 +65,11 @@ this_object = cast('T', self.objects[i])

raise NotImplementedError
for j in np.ndindex(*x.shape[self.ndim:]):
value = this_object.pdf(x[*i, *j])
j_range = x.shape[self.ndim: -1] if self.multivariate else x.shape[self.ndim:]
for j in np.ndindex(*j_range):
if self.multivariate:
x_ij = x[*i, *j, :]
assert x_ij.ndim == 1
else:
x_ij = x[*i, *j]
assert x_ij.ndim == 0
value = this_object.pdf(x_ij)
retval[*i, *j] = value

@@ -66,0 +76,0 @@ return retval

from __future__ import annotations
from typing import override
import numpy as np
import scipy.stats as ss
from tjax import NumpyRealArray
from typing_extensions import override

@@ -25,3 +26,3 @@ from .shaped_distribution import ShapedDistribution

objects[i] = ss.vonmises(kappa[i], loc[i])
super().__init__(shape, rvs_shape, dtype, objects)
super().__init__(shape, rvs_shape, dtype, objects, multivariate=False)

@@ -42,2 +43,2 @@

objects[i] = ss.vonmises_fisher(mu[i], kappa[i]) # type: ignore # pyright: ignore
super().__init__(shape, rvs_shape, dtype, objects)
super().__init__(shape, rvs_shape, dtype, objects, multivariate=True)

@@ -40,3 +40,3 @@ from __future__ import annotations

def parameter_mean(x: T, /, *, axis: Axis | None = None) -> T:
def parameter_mean[T: Distribution](x: T, /, *, axis: Axis | None = None) -> T:
"""Return the mean of the parameters (fixed and variable)."""

@@ -50,3 +50,3 @@ xp = array_namespace(x)

def parameter_map(operation: Callable[..., JaxComplexArray],
def parameter_map[T: Distribution](operation: Callable[..., JaxComplexArray],
x: T,

@@ -67,7 +67,3 @@ /,

_T = TypeVar('_T')
_V = TypeVar('_V')
def join_mappings(**field_to_map: Mapping[_T, _V]) -> dict[_T, dict[str, _V]]:
def join_mappings[T, V](**field_to_map: Mapping[T, V]) -> dict[T, dict[str, V]]:
"""Joins multiple mappings together using their common keys.

@@ -80,3 +76,3 @@

"""
retval = defaultdict[_T, dict[str, _V]](dict)
retval = defaultdict[T, dict[str, V]](dict)
for field_name, mapping in field_to_map.items():

@@ -83,0 +79,0 @@ for key, value in mapping.items():

from __future__ import annotations
from typing import Any, TypeAlias
from typing import Any
Axis: TypeAlias = int | tuple[int, ...]
Path: TypeAlias = tuple[str, ...]
Namespace: TypeAlias = Any
type Axis = int | tuple[int, ...]
type Path = tuple[str, ...]
type Namespace = Any
Metadata-Version: 2.4
Name: efax
Version: 1.22.3
Version: 1.23.0
Summary: Exponential families for JAX

@@ -18,9 +18,9 @@ Project-URL: source, https://github.com/NeilGirdhar/efax

Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Classifier: Programming Language :: Python :: 3.14
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Topic :: Software Development :: Libraries :: Python Modules
Classifier: Typing :: Typed
Requires-Python: <3.14,>=3.11
Requires-Python: <3.15,>=3.12
Requires-Dist: array-api-compat>=1.10

@@ -32,7 +32,7 @@ Requires-Dist: array-api-extra>=0.8

Requires-Dist: optimistix>=0.0.9
Requires-Dist: optype>=0.8.0
Requires-Dist: optype[numpy]>=0.8.0
Requires-Dist: scipy-stubs>=1.15
Requires-Dist: scipy>=1.15
Requires-Dist: tfp-nightly>=0.25
Requires-Dist: tjax>=1.3.10
Requires-Dist: tjax>=1.4.1
Requires-Dist: typing-extensions>=4.8

@@ -39,0 +39,0 @@ Description-Content-Type: text/x-rst

@@ -22,6 +22,6 @@ [build-system]

name = "efax"
version = "1.22.3"
version = "1.23.0"
description = "Exponential families for JAX"
readme = "README.rst"
requires-python = ">=3.11,<3.14"
requires-python = ">=3.12,<3.15"
license = "Apache-2.0"

@@ -36,5 +36,5 @@ authors = [{email = "mistersheik@gmail.com"}, {name = "Neil Girdhar"}]

"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Programming Language :: Python :: 3.14",
"Programming Language :: Python",

@@ -52,7 +52,7 @@ "Topic :: Scientific/Engineering :: Artificial Intelligence",

"optimistix>=0.0.9",
"optype>=0.8.0",
"optype[numpy]>=0.8.0",
"scipy-stubs>=1.15",
"scipy>=1.15",
"tfp-nightly>=0.25",
"tjax>=1.3.10",
"tjax>=1.4.1",
"typing_extensions>=4.8"

@@ -59,0 +59,0 @@ ]

from __future__ import annotations
from collections.abc import Mapping
from typing import Any, cast
from typing import Any, cast, override

@@ -11,3 +11,2 @@ import array_api_extra as xpx

from tjax import JaxRealArray, NumpyComplexArray, NumpyRealArray, abs_square
from typing_extensions import override

@@ -14,0 +13,0 @@ from efax import (BernoulliEP, BernoulliNP, BetaEP, BetaNP, ChiEP, ChiNP, ChiSquareEP, ChiSquareNP,

from __future__ import annotations
from collections.abc import Callable
from typing import Any, Generic, final
from typing import Any, Generic, final, override

@@ -10,3 +10,3 @@ import jax.numpy as jnp

from tjax import JaxComplexArray, NumpyComplexArray, Shape
from typing_extensions import TypeVar, override
from typing_extensions import TypeVar

@@ -13,0 +13,0 @@ from efax import (ExpectationParametrization, NaturalParametrization, ScipyDiscreteDistribution,

from __future__ import annotations
from collections.abc import Callable
from typing import TypeAlias

@@ -17,3 +16,3 @@ from array_api_compat import array_namespace

_LogNormalizer: TypeAlias = Callable[[NaturalParametrization], JaxRealArray]
type _LogNormalizer = Callable[[NaturalParametrization], JaxRealArray]

@@ -20,0 +19,0 @@

"""These tests apply to only samplable distributions."""
from __future__ import annotations
from typing import Any, TypeAlias
from typing import Any

@@ -19,3 +19,3 @@ import jax.numpy as jnp

_Path: TypeAlias = tuple[str, ...]
type _Path = tuple[str, ...]

@@ -22,0 +22,0 @@

Sorry, the diff of this file is too big to display