efax
Advanced tools
| from __future__ import annotations | ||
| from typing import Self | ||
| from typing import Self, override | ||
@@ -11,3 +11,2 @@ import jax.random as jr | ||
| from tjax.dataclasses import dataclass | ||
| from typing_extensions import override | ||
@@ -14,0 +13,0 @@ from ..interfaces.conjugate_prior import HasConjugatePrior |
| from __future__ import annotations | ||
| from typing import override | ||
| import jax.random as jr | ||
@@ -7,3 +9,2 @@ from array_api_compat import array_namespace | ||
| from tjax.dataclasses import dataclass | ||
| from typing_extensions import override | ||
@@ -10,0 +11,0 @@ from ..expectation_parametrization import ExpectationParametrization |
| from __future__ import annotations | ||
| import math | ||
| from typing import override | ||
@@ -10,3 +11,2 @@ import jax.random as jr | ||
| from tjax.dataclasses import dataclass | ||
| from typing_extensions import override | ||
@@ -13,0 +13,0 @@ from ..interfaces.samplable import Samplable |
| from __future__ import annotations | ||
| import math | ||
| from typing import override | ||
@@ -9,3 +10,2 @@ import jax.scipy.special as jss | ||
| from tjax.dataclasses import dataclass | ||
| from typing_extensions import override | ||
@@ -12,0 +12,0 @@ from ..mixins.has_entropy import HasEntropyEP, HasEntropyNP |
| from __future__ import annotations | ||
| import math | ||
| from typing import override | ||
@@ -9,3 +10,2 @@ import jax.random as jr | ||
| from tjax.dataclasses import dataclass | ||
| from typing_extensions import override | ||
@@ -12,0 +12,0 @@ from ...interfaces.multidimensional import Multidimensional |
| from __future__ import annotations | ||
| import math | ||
| from typing import override | ||
@@ -9,3 +10,2 @@ import jax.random as jr | ||
| from tjax.dataclasses import dataclass | ||
| from typing_extensions import override | ||
@@ -12,0 +12,0 @@ from ...interfaces.multidimensional import Multidimensional |
| from __future__ import annotations | ||
| import math | ||
| from typing import override | ||
@@ -9,3 +10,2 @@ import jax.random as jr | ||
| from tjax.dataclasses import dataclass | ||
| from typing_extensions import override | ||
@@ -12,0 +12,0 @@ from ...interfaces.samplable import Samplable |
| from __future__ import annotations | ||
| import math | ||
| from typing import override | ||
@@ -9,3 +10,2 @@ import jax.random as jr | ||
| from tjax.dataclasses import dataclass | ||
| from typing_extensions import override | ||
@@ -12,0 +12,0 @@ from ...interfaces.samplable import Samplable |
| from __future__ import annotations | ||
| from typing import Any, Generic, TypeVar | ||
| from typing import Any, TypeVar, override | ||
@@ -10,3 +10,2 @@ import jax.random as jr | ||
| from tjax.dataclasses import dataclass | ||
| from typing_extensions import override | ||
@@ -24,7 +23,6 @@ from ..interfaces.multidimensional import Multidimensional | ||
| @dataclass | ||
| class DirichletCommonNP(HasEntropyNP[EP], | ||
| class DirichletCommonNP[EP: 'DirichletCommonEP[Any]'](HasEntropyNP[EP], | ||
| Samplable, | ||
| Multidimensional, | ||
| NaturalParametrization[EP, JaxRealArray], | ||
| Generic[EP]): | ||
| NaturalParametrization[EP, JaxRealArray]): | ||
| alpha_minus_one: JaxRealArray = distribution_parameter(VectorSupport( | ||
@@ -65,6 +63,6 @@ ring=RealField(minimum=-1.0, generation_scale=3.0))) | ||
| @dataclass | ||
| class DirichletCommonEP(HasEntropyEP[NP], | ||
| class DirichletCommonEP[NP: DirichletCommonNP[Any]](HasEntropyEP[NP], | ||
| Samplable, | ||
| ExpToNat[NP], | ||
| Multidimensional, Generic[NP]): | ||
| Multidimensional): | ||
| mean_log_probability: JaxRealArray = distribution_parameter(VectorSupport( | ||
@@ -71,0 +69,0 @@ ring=negative_support)) |
| from __future__ import annotations | ||
| from typing import override | ||
| from array_api_compat import array_namespace | ||
| from tjax import JaxArray, JaxRealArray | ||
| from tjax.dataclasses import dataclass | ||
| from typing_extensions import override | ||
@@ -8,0 +9,0 @@ from ..expectation_parametrization import ExpectationParametrization |
| from __future__ import annotations | ||
| from typing import Self | ||
| from typing import Self, override | ||
@@ -9,3 +9,2 @@ import jax.random as jr | ||
| from tjax.dataclasses import dataclass | ||
| from typing_extensions import override | ||
@@ -12,0 +11,0 @@ from ..interfaces.conjugate_prior import HasConjugatePrior |
| from __future__ import annotations | ||
| from typing import override | ||
| import jax.random as jr | ||
@@ -8,3 +10,2 @@ import jax.scipy.special as jss | ||
| from tjax.dataclasses import dataclass | ||
| from typing_extensions import override | ||
@@ -11,0 +12,0 @@ from ..interfaces.samplable import Samplable |
@@ -10,2 +10,4 @@ """The generalized Dirichlet distribution. | ||
| from typing import override | ||
| import jax.scipy.special as jss | ||
@@ -15,3 +17,2 @@ from array_api_compat import array_namespace | ||
| from tjax.dataclasses import dataclass | ||
| from typing_extensions import override | ||
@@ -18,0 +19,0 @@ from ..interfaces.multidimensional import Multidimensional |
| from __future__ import annotations | ||
| from typing import override | ||
| import jax.random as jr | ||
@@ -7,3 +9,2 @@ from array_api_compat import array_namespace | ||
| from tjax.dataclasses import dataclass | ||
| from typing_extensions import override | ||
@@ -10,0 +11,0 @@ from ..interfaces.samplable import Samplable |
| from __future__ import annotations | ||
| from typing import override | ||
| from array_api_compat import array_namespace | ||
| from tjax import Array, JaxArray, JaxRealArray, KeyArray, Shape | ||
| from tjax.dataclasses import dataclass | ||
| from typing_extensions import override | ||
@@ -8,0 +9,0 @@ from ..interfaces.samplable import Samplable |
| from __future__ import annotations | ||
| from typing import override | ||
| import jax.random as jr | ||
@@ -7,3 +9,2 @@ from array_api_compat import array_namespace | ||
| from tjax.dataclasses import dataclass | ||
| from typing_extensions import override | ||
@@ -10,0 +11,0 @@ from ..expectation_parametrization import ExpectationParametrization |
| from __future__ import annotations | ||
| from typing import override | ||
| import jax.random as jr | ||
@@ -7,3 +9,2 @@ from array_api_compat import array_namespace | ||
| from tjax.dataclasses import dataclass | ||
| from typing_extensions import override | ||
@@ -10,0 +11,0 @@ from ...interfaces.samplable import Samplable |
| from __future__ import annotations | ||
| from typing import Self | ||
| from typing import Self, override | ||
@@ -9,3 +9,2 @@ import jax.random as jr | ||
| from tjax.dataclasses import dataclass | ||
| from typing_extensions import override | ||
@@ -12,0 +11,0 @@ from ...interfaces.conjugate_prior import HasConjugatePrior |
| from __future__ import annotations | ||
| from typing import override | ||
| from array_api_compat import array_namespace | ||
| from tjax import JaxArray, JaxRealArray, Shape | ||
| from tjax.dataclasses import dataclass | ||
| from typing_extensions import override | ||
@@ -8,0 +9,0 @@ from ..expectation_parametrization import ExpectationParametrization |
| from __future__ import annotations | ||
| from typing import Self | ||
| from typing import Self, override | ||
@@ -12,3 +12,2 @@ import array_api_extra as xpx | ||
| from tjax.dataclasses import dataclass | ||
| from typing_extensions import override | ||
@@ -78,3 +77,3 @@ from ..interfaces.conjugate_prior import HasGeneralizedConjugatePrior | ||
| retval = xpx.one_hot(jr.categorical(key, self.log_odds, shape=shape), self.dimensions()) | ||
| assert isinstance(retval, JaxRealArray) | ||
| assert isinstance(retval, JaxArray) | ||
| return retval | ||
@@ -81,0 +80,0 @@ |
| from __future__ import annotations | ||
| import math | ||
| from typing import cast | ||
| from typing import cast, override | ||
@@ -11,3 +11,2 @@ import jax.random as jr | ||
| from tjax.dataclasses import dataclass | ||
| from typing_extensions import override | ||
@@ -14,0 +13,0 @@ from ...interfaces.multidimensional import Multidimensional |
| from __future__ import annotations | ||
| from typing import override | ||
| import jax.random as jr | ||
@@ -8,3 +10,2 @@ import numpy as np | ||
| from tjax.dataclasses import dataclass | ||
| from typing_extensions import override | ||
@@ -11,0 +12,0 @@ from ...interfaces.multidimensional import Multidimensional |
| from __future__ import annotations | ||
| import math | ||
| from typing import Self | ||
| from typing import Self, override | ||
@@ -10,3 +10,2 @@ import jax.random as jr | ||
| from tjax.dataclasses import dataclass | ||
| from typing_extensions import override | ||
@@ -13,0 +12,0 @@ from ...interfaces.conjugate_prior import HasGeneralizedConjugatePrior |
| from __future__ import annotations | ||
| from typing import override | ||
| import jax.random as jr | ||
@@ -7,3 +9,2 @@ from array_api_compat import array_namespace | ||
| from tjax.dataclasses import dataclass | ||
| from typing_extensions import override | ||
@@ -10,0 +11,0 @@ from ...interfaces.multidimensional import Multidimensional |
| from __future__ import annotations | ||
| import math | ||
| from typing import Self | ||
| from typing import Self, override | ||
@@ -10,3 +10,2 @@ import jax.random as jr | ||
| from tjax.dataclasses import dataclass | ||
| from typing_extensions import override | ||
@@ -13,0 +12,0 @@ from ...interfaces.conjugate_prior import HasGeneralizedConjugatePrior |
| from __future__ import annotations | ||
| from abc import abstractmethod | ||
| from typing import Any, TypeVar | ||
| from typing import Any, TypeVar, override | ||
@@ -10,3 +10,2 @@ import jax.scipy.special as jss | ||
| from tjax.dataclasses import dataclass | ||
| from typing_extensions import override | ||
@@ -13,0 +12,0 @@ from ..expectation_parametrization import ExpectationParametrization |
| from __future__ import annotations | ||
| from typing import override | ||
| import jax.random as jr | ||
@@ -7,3 +9,2 @@ from array_api_compat import array_namespace | ||
| from tjax.dataclasses import dataclass | ||
| from typing_extensions import override | ||
@@ -10,0 +11,0 @@ from ..expectation_parametrization import ExpectationParametrization |
| from __future__ import annotations | ||
| from typing import override | ||
| import jax.random as jr | ||
@@ -8,3 +10,2 @@ import numpy as np | ||
| from tjax.dataclasses import dataclass | ||
| from typing_extensions import override | ||
@@ -11,0 +12,0 @@ from ...interfaces.samplable import Samplable |
| from __future__ import annotations | ||
| import math | ||
| from typing import Self | ||
| from typing import Self, override | ||
@@ -10,3 +10,2 @@ import jax.random as jr | ||
| from tjax.dataclasses import dataclass | ||
| from typing_extensions import override | ||
@@ -13,0 +12,0 @@ from ...interfaces.conjugate_prior import HasConjugatePrior |
| from __future__ import annotations | ||
| from typing import Self | ||
| from typing import Self, override | ||
@@ -10,3 +10,2 @@ import jax.random as jr | ||
| from tjax.dataclasses import dataclass | ||
| from typing_extensions import override | ||
@@ -13,0 +12,0 @@ from ..expectation_parametrization import ExpectationParametrization |
| from __future__ import annotations | ||
| import math | ||
| from typing import override | ||
@@ -10,3 +11,2 @@ import jax.random as jr | ||
| from tjax.dataclasses import dataclass | ||
| from typing_extensions import override | ||
@@ -13,0 +13,0 @@ from ..interfaces.samplable import Samplable |
| from __future__ import annotations | ||
| from typing import cast | ||
| from typing import cast, override | ||
@@ -9,3 +9,2 @@ import jax.random as jr | ||
| from tjax.dataclasses import dataclass | ||
| from typing_extensions import override | ||
@@ -12,0 +11,0 @@ from ...interfaces.samplable import Samplable |
| from __future__ import annotations | ||
| from typing import Self, cast | ||
| from typing import Self, cast, override | ||
@@ -9,3 +9,2 @@ import jax.random as jr | ||
| from tjax.dataclasses import dataclass | ||
| from typing_extensions import override | ||
@@ -12,0 +11,0 @@ from ...interfaces.conjugate_prior import HasConjugatePrior |
| from __future__ import annotations | ||
| import math | ||
| from typing import cast | ||
| from typing import cast, override | ||
@@ -9,3 +9,2 @@ from array_api_compat import array_namespace | ||
| from tjax.dataclasses import dataclass | ||
| from typing_extensions import override | ||
@@ -12,0 +11,0 @@ from ..interfaces.multidimensional import Multidimensional |
| from __future__ import annotations | ||
| from typing import override | ||
| import jax.random as jr | ||
@@ -8,3 +10,2 @@ import numpy as np | ||
| from tjax.dataclasses import dataclass | ||
| from typing_extensions import override | ||
@@ -11,0 +12,0 @@ from ..interfaces.samplable import Samplable |
| from __future__ import annotations | ||
| from dataclasses import KW_ONLY, field | ||
| from typing import Any, Generic, Self, TypeAlias | ||
| from typing import Any, Generic, Self, override | ||
@@ -10,3 +10,3 @@ from array_api_compat import array_namespace | ||
| from tjax.dataclasses import dataclass | ||
| from typing_extensions import TypeVar, override | ||
| from typing_extensions import TypeVar | ||
@@ -19,3 +19,3 @@ from ...expectation_parametrization import ExpectationParametrization | ||
| NP = TypeVar('NP', bound=NaturalParametrization, default=Any) | ||
| SP: TypeAlias = JaxRealArray | ||
| type SP = JaxRealArray | ||
@@ -22,0 +22,0 @@ |
@@ -1,2 +0,2 @@ | ||
| from typing import Any, TypeAlias, TypeVar | ||
| from typing import Any, TypeVar, override | ||
@@ -8,3 +8,2 @@ import optimistix as optx | ||
| from tjax.dataclasses import dataclass, field | ||
| from typing_extensions import override | ||
@@ -16,3 +15,3 @@ from .exp_to_nat import ExpToNat, ExpToNatMinimizer | ||
| Aux = TypeVar('Aux') | ||
| RootFinder: TypeAlias = ( | ||
| type RootFinder[Y, Out, Aux] = ( | ||
| optx.AbstractRootFinder[Y, Out, Aux, Any] | ||
@@ -19,0 +18,0 @@ | optx.AbstractLeastSquaresSolver[Y, Out, Aux, Any] |
@@ -5,3 +5,3 @@ from __future__ import annotations | ||
| from functools import partial | ||
| from typing import Any, Generic, cast | ||
| from typing import Any, Generic, cast, override | ||
@@ -11,3 +11,3 @@ from array_api_compat import array_namespace | ||
| from tjax import JaxArray, JaxComplexArray, JaxRealArray, Shape | ||
| from typing_extensions import TypeVar, override | ||
| from typing_extensions import TypeVar | ||
@@ -24,3 +24,3 @@ from ..expectation_parametrization import ExpectationParametrization | ||
| class TransformedNaturalParametrization(NaturalParametrization[TEP, Domain], | ||
| Generic[NP, EP, TEP, Domain]): | ||
| Generic[NP, EP, TEP, Domain]): # noqa: UP046 | ||
| """Produce a NaturalParametrization by relating it to some base distrubtion NP.""" | ||
@@ -100,3 +100,3 @@ @classmethod | ||
| class TransformedExpectationParametrization(ExpectationParametrization[TNP], | ||
| Generic[EP, NP, TNP]): | ||
| Generic[EP, NP, TNP]): # noqa: UP046 | ||
| """Produce an ExpectationParametrization by relating it to some base distrubtion EP.""" | ||
@@ -103,0 +103,0 @@ @classmethod |
@@ -6,2 +6,3 @@ from __future__ import annotations | ||
| from types import ModuleType | ||
| from typing import override | ||
@@ -15,3 +16,2 @@ import array_api_extra as xpx | ||
| softplus) | ||
| from typing_extensions import override | ||
@@ -192,2 +192,4 @@ from ..types import Namespace | ||
| if minimum is None: | ||
| minimum = 0.0 | ||
| if maximum is None: | ||
@@ -194,0 +196,0 @@ # x is outside the disk of the given minimum. Map it to the plane. |
@@ -5,3 +5,3 @@ from __future__ import annotations | ||
| from math import comb, isqrt | ||
| from typing import Any, cast | ||
| from typing import Any, cast, override | ||
@@ -15,3 +15,2 @@ import array_api_extra as xpx | ||
| from tjax import JaxArray, JaxRealArray, Shape | ||
| from typing_extensions import override | ||
@@ -18,0 +17,0 @@ from ..types import Namespace |
| from __future__ import annotations | ||
| from typing import override | ||
| import numpy as np | ||
| from numpy.random import Generator | ||
| from tjax import NumpyComplexArray, NumpyRealArray, ShapeLike | ||
| from typing_extensions import override | ||
@@ -45,2 +46,3 @@ from .multivariate_normal import ScipyMultivariateNormal, ScipyMultivariateNormalUnvectorized | ||
| def pdf(self, z: NumpyComplexArray, out: None = None) -> float: | ||
| assert z.ndim == 1 | ||
| zr = np.concat([np.real(z), np.imag(z)], axis=-1) | ||
@@ -117,3 +119,3 @@ return self.as_multivariate_normal().pdf(zr).item() | ||
| mean[i], variance[i], pseudo_variance[i]) | ||
| super().__init__(shape, rvs_shape, dtype, objects) | ||
| super().__init__(shape, rvs_shape, dtype, objects, multivariate=True) | ||
@@ -126,3 +128,4 @@ def as_multivariate_normal(self) -> ScipyMultivariateNormal: | ||
| objects[i] = this_object.as_multivariate_normal() | ||
| return ScipyMultivariateNormal(self.shape, self.rvs_shape, self.real_dtype, objects) | ||
| return ScipyMultivariateNormal(self.shape, self.rvs_shape, self.real_dtype, objects, | ||
| multivariate=True) | ||
@@ -129,0 +132,0 @@ @property |
| from __future__ import annotations | ||
| from typing import Self | ||
| from typing import Self, override | ||
@@ -8,3 +8,2 @@ import numpy as np | ||
| from tjax import NumpyComplexArray, NumpyComplexNumeric, NumpyRealArray, NumpyRealNumeric, ShapeLike | ||
| from typing_extensions import override | ||
@@ -101,3 +100,3 @@ from .multivariate_normal import ScipyMultivariateNormal, ScipyMultivariateNormalUnvectorized | ||
| objects[i] = ScipyComplexNormalUnvectorized(mean[i], variance[i], pseudo_variance[i]) | ||
| super().__init__(shape, rvs_shape, dtype, objects) | ||
| super().__init__(shape, rvs_shape, dtype, objects, multivariate=False) | ||
@@ -123,3 +122,4 @@ @classmethod | ||
| objects[i] = this_object.as_multivariate_normal() | ||
| return ScipyMultivariateNormal(self.shape, self.rvs_shape, self.real_dtype, objects) | ||
| return ScipyMultivariateNormal(self.shape, self.rvs_shape, self.real_dtype, objects, | ||
| multivariate=True) | ||
@@ -126,0 +126,0 @@ @property |
| from __future__ import annotations | ||
| from typing import override | ||
| import numpy as np | ||
@@ -9,3 +11,2 @@ import optype.numpy as onp | ||
| from tjax import NumpyComplexArray, NumpyRealArray, ShapeLike | ||
| from typing_extensions import override | ||
@@ -55,3 +56,3 @@ from .shaped_distribution import ShapedDistribution | ||
| objects[i] = ScipyDirichletFixRVsAndPDF(alpha[i]) | ||
| super().__init__(shape, rvs_shape, dtype, objects) | ||
| super().__init__(shape, rvs_shape, dtype, objects, multivariate=True) | ||
@@ -58,0 +59,0 @@ @override |
@@ -58,2 +58,2 @@ from __future__ import annotations | ||
| objects[i] = ScipyMultivariateNormalUnvectorized(mean[i], cov[i]) | ||
| return cls(shape, rvs_shape, dtype, objects) | ||
| return cls(shape, rvs_shape, dtype, objects, multivariate=True) |
| from __future__ import annotations | ||
| from typing import Any, Generic, TypeVar, cast | ||
| from typing import Any, TypeVar, cast, override | ||
@@ -9,3 +9,2 @@ import numpy as np | ||
| from tjax import NumpyComplexArray, NumpyIntegralArray, NumpyRealArray, Shape | ||
| from typing_extensions import override | ||
@@ -17,3 +16,3 @@ from .base import ScipyDiscreteDistribution, ScipyDistribution | ||
| class ShapedDistribution(Generic[T]): | ||
| class ShapedDistribution[T: ScipyDiscreteDistribution | ScipyDistribution]: | ||
| """Allow a distributions with shape.""" | ||
@@ -25,3 +24,5 @@ @override | ||
| rvs_dtype: np.dtype[Any], | ||
| objects: npt.NDArray[np.object_] | ||
| objects: npt.NDArray[np.object_], | ||
| *, | ||
| multivariate: bool, | ||
| ) -> None: | ||
@@ -34,2 +35,3 @@ super().__init__() | ||
| self.objects = objects | ||
| self.multivariate = multivariate | ||
@@ -56,3 +58,4 @@ @property | ||
| assert x.shape[:self.ndim] == self.shape | ||
| retval = np.empty(x.shape, dtype=self.real_dtype) | ||
| final_shape = x.shape[:-1] if self.multivariate else x.shape | ||
| retval = np.empty(final_shape, dtype=self.real_dtype) | ||
| for i in np.ndindex(*self.shape): | ||
@@ -62,4 +65,11 @@ this_object = cast('T', self.objects[i]) | ||
| raise NotImplementedError | ||
| for j in np.ndindex(*x.shape[self.ndim:]): | ||
| value = this_object.pdf(x[*i, *j]) | ||
| j_range = x.shape[self.ndim: -1] if self.multivariate else x.shape[self.ndim:] | ||
| for j in np.ndindex(*j_range): | ||
| if self.multivariate: | ||
| x_ij = x[*i, *j, :] | ||
| assert x_ij.ndim == 1 | ||
| else: | ||
| x_ij = x[*i, *j] | ||
| assert x_ij.ndim == 0 | ||
| value = this_object.pdf(x_ij) | ||
| retval[*i, *j] = value | ||
@@ -66,0 +76,0 @@ return retval |
| from __future__ import annotations | ||
| from typing import override | ||
| import numpy as np | ||
| import scipy.stats as ss | ||
| from tjax import NumpyRealArray | ||
| from typing_extensions import override | ||
@@ -25,3 +26,3 @@ from .shaped_distribution import ShapedDistribution | ||
| objects[i] = ss.vonmises(kappa[i], loc[i]) | ||
| super().__init__(shape, rvs_shape, dtype, objects) | ||
| super().__init__(shape, rvs_shape, dtype, objects, multivariate=False) | ||
@@ -42,2 +43,2 @@ | ||
| objects[i] = ss.vonmises_fisher(mu[i], kappa[i]) # type: ignore # pyright: ignore | ||
| super().__init__(shape, rvs_shape, dtype, objects) | ||
| super().__init__(shape, rvs_shape, dtype, objects, multivariate=True) |
@@ -40,3 +40,3 @@ from __future__ import annotations | ||
| def parameter_mean(x: T, /, *, axis: Axis | None = None) -> T: | ||
| def parameter_mean[T: Distribution](x: T, /, *, axis: Axis | None = None) -> T: | ||
| """Return the mean of the parameters (fixed and variable).""" | ||
@@ -50,3 +50,3 @@ xp = array_namespace(x) | ||
| def parameter_map(operation: Callable[..., JaxComplexArray], | ||
| def parameter_map[T: Distribution](operation: Callable[..., JaxComplexArray], | ||
| x: T, | ||
@@ -67,7 +67,3 @@ /, | ||
| _T = TypeVar('_T') | ||
| _V = TypeVar('_V') | ||
| def join_mappings(**field_to_map: Mapping[_T, _V]) -> dict[_T, dict[str, _V]]: | ||
| def join_mappings[T, V](**field_to_map: Mapping[T, V]) -> dict[T, dict[str, V]]: | ||
| """Joins multiple mappings together using their common keys. | ||
@@ -80,3 +76,3 @@ | ||
| """ | ||
| retval = defaultdict[_T, dict[str, _V]](dict) | ||
| retval = defaultdict[T, dict[str, V]](dict) | ||
| for field_name, mapping in field_to_map.items(): | ||
@@ -83,0 +79,0 @@ for key, value in mapping.items(): |
| from __future__ import annotations | ||
| from typing import Any, TypeAlias | ||
| from typing import Any | ||
| Axis: TypeAlias = int | tuple[int, ...] | ||
| Path: TypeAlias = tuple[str, ...] | ||
| Namespace: TypeAlias = Any | ||
| type Axis = int | tuple[int, ...] | ||
| type Path = tuple[str, ...] | ||
| type Namespace = Any |
+5
-5
| Metadata-Version: 2.4 | ||
| Name: efax | ||
| Version: 1.22.3 | ||
| Version: 1.23.0 | ||
| Summary: Exponential families for JAX | ||
@@ -18,9 +18,9 @@ Project-URL: source, https://github.com/NeilGirdhar/efax | ||
| Classifier: Programming Language :: Python :: 3 | ||
| Classifier: Programming Language :: Python :: 3.11 | ||
| Classifier: Programming Language :: Python :: 3.12 | ||
| Classifier: Programming Language :: Python :: 3.13 | ||
| Classifier: Programming Language :: Python :: 3.14 | ||
| Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence | ||
| Classifier: Topic :: Software Development :: Libraries :: Python Modules | ||
| Classifier: Typing :: Typed | ||
| Requires-Python: <3.14,>=3.11 | ||
| Requires-Python: <3.15,>=3.12 | ||
| Requires-Dist: array-api-compat>=1.10 | ||
@@ -32,7 +32,7 @@ Requires-Dist: array-api-extra>=0.8 | ||
| Requires-Dist: optimistix>=0.0.9 | ||
| Requires-Dist: optype>=0.8.0 | ||
| Requires-Dist: optype[numpy]>=0.8.0 | ||
| Requires-Dist: scipy-stubs>=1.15 | ||
| Requires-Dist: scipy>=1.15 | ||
| Requires-Dist: tfp-nightly>=0.25 | ||
| Requires-Dist: tjax>=1.3.10 | ||
| Requires-Dist: tjax>=1.4.1 | ||
| Requires-Dist: typing-extensions>=4.8 | ||
@@ -39,0 +39,0 @@ Description-Content-Type: text/x-rst |
+5
-5
@@ -22,6 +22,6 @@ [build-system] | ||
| name = "efax" | ||
| version = "1.22.3" | ||
| version = "1.23.0" | ||
| description = "Exponential families for JAX" | ||
| readme = "README.rst" | ||
| requires-python = ">=3.11,<3.14" | ||
| requires-python = ">=3.12,<3.15" | ||
| license = "Apache-2.0" | ||
@@ -36,5 +36,5 @@ authors = [{email = "mistersheik@gmail.com"}, {name = "Neil Girdhar"}] | ||
| "Programming Language :: Python :: 3", | ||
| "Programming Language :: Python :: 3.11", | ||
| "Programming Language :: Python :: 3.12", | ||
| "Programming Language :: Python :: 3.13", | ||
| "Programming Language :: Python :: 3.14", | ||
| "Programming Language :: Python", | ||
@@ -52,7 +52,7 @@ "Topic :: Scientific/Engineering :: Artificial Intelligence", | ||
| "optimistix>=0.0.9", | ||
| "optype>=0.8.0", | ||
| "optype[numpy]>=0.8.0", | ||
| "scipy-stubs>=1.15", | ||
| "scipy>=1.15", | ||
| "tfp-nightly>=0.25", | ||
| "tjax>=1.3.10", | ||
| "tjax>=1.4.1", | ||
| "typing_extensions>=4.8" | ||
@@ -59,0 +59,0 @@ ] |
| from __future__ import annotations | ||
| from collections.abc import Mapping | ||
| from typing import Any, cast | ||
| from typing import Any, cast, override | ||
@@ -11,3 +11,2 @@ import array_api_extra as xpx | ||
| from tjax import JaxRealArray, NumpyComplexArray, NumpyRealArray, abs_square | ||
| from typing_extensions import override | ||
@@ -14,0 +13,0 @@ from efax import (BernoulliEP, BernoulliNP, BetaEP, BetaNP, ChiEP, ChiNP, ChiSquareEP, ChiSquareNP, |
| from __future__ import annotations | ||
| from collections.abc import Callable | ||
| from typing import Any, Generic, final | ||
| from typing import Any, Generic, final, override | ||
@@ -10,3 +10,3 @@ import jax.numpy as jnp | ||
| from tjax import JaxComplexArray, NumpyComplexArray, Shape | ||
| from typing_extensions import TypeVar, override | ||
| from typing_extensions import TypeVar | ||
@@ -13,0 +13,0 @@ from efax import (ExpectationParametrization, NaturalParametrization, ScipyDiscreteDistribution, |
| from __future__ import annotations | ||
| from collections.abc import Callable | ||
| from typing import TypeAlias | ||
@@ -17,3 +16,3 @@ from array_api_compat import array_namespace | ||
| _LogNormalizer: TypeAlias = Callable[[NaturalParametrization], JaxRealArray] | ||
| type _LogNormalizer = Callable[[NaturalParametrization], JaxRealArray] | ||
@@ -20,0 +19,0 @@ |
| """These tests apply to only samplable distributions.""" | ||
| from __future__ import annotations | ||
| from typing import Any, TypeAlias | ||
| from typing import Any | ||
@@ -19,3 +19,3 @@ import jax.numpy as jnp | ||
| _Path: TypeAlias = tuple[str, ...] | ||
| type _Path = tuple[str, ...] | ||
@@ -22,0 +22,0 @@ |
Sorry, the diff of this file is too big to display
Alert delta unavailable
Currently unable to show alert delta for PyPI packages.
1002295
0.48%8119
-0.09%