efax
Advanced tools
@@ -32,2 +32,6 @@ from __future__ import annotations | ||
| @property | ||
| def ndim(self) -> int: | ||
| return len(self.shape) | ||
| def rvs(self, | ||
@@ -48,3 +52,4 @@ size: int | Shape | None = None, | ||
| def pdf(self, x: NumpyComplexArray) -> NumpyRealArray: | ||
| retval = np.empty(self.shape, dtype=self.real_dtype) | ||
| assert x.shape[:self.ndim] == self.shape | ||
| retval = np.empty(x.shape, dtype=self.real_dtype) | ||
| for i in np.ndindex(*self.shape): | ||
@@ -54,8 +59,10 @@ this_object = cast('T', self.objects[i]) | ||
| raise NotImplementedError | ||
| value = this_object.pdf(x[i]) | ||
| retval[i] = value | ||
| for j in np.ndindex(*x.shape[self.ndim:]): | ||
| value = this_object.pdf(x[*i, *j]) | ||
| retval[*i, *j] = value | ||
| return retval | ||
| def pmf(self, x: NumpyIntegralArray) -> NumpyRealArray: | ||
| retval = np.empty(self.shape, dtype=self.real_dtype) | ||
| assert x.shape[:self.ndim] == self.shape | ||
| retval = np.empty(x.shape, dtype=self.real_dtype) | ||
| for i in np.ndindex(*self.shape): | ||
@@ -65,4 +72,5 @@ this_object = cast('T', self.objects[i]) | ||
| raise NotImplementedError | ||
| value = this_object.pmf(x[i]) | ||
| retval[i] = value | ||
| for j in np.ndindex(*x.shape[self.ndim:]): | ||
| value = this_object.pmf(x[*i, *j]) | ||
| retval[*i, *j] = value | ||
| return retval | ||
@@ -69,0 +77,0 @@ |
+1
-1
| Metadata-Version: 2.4 | ||
| Name: efax | ||
| Version: 1.22.2 | ||
| Version: 1.22.3 | ||
| Summary: Exponential families for JAX | ||
@@ -5,0 +5,0 @@ Project-URL: source, https://github.com/NeilGirdhar/efax |
+1
-1
@@ -22,3 +22,3 @@ [build-system] | ||
| name = "efax" | ||
| version = "1.22.2" | ||
| version = "1.22.3" | ||
| description = "Exponential families for JAX" | ||
@@ -25,0 +25,0 @@ readme = "README.rst" |
Sorry, the diff of this file is too big to display
Alert delta unavailable
Currently unable to show alert delta for PyPI packages.
997521
0.21%8126
0.09%