ndonnx
An ONNX-backed array library that is compliant with the Array API standard.
Installation
Releases are available on PyPI and conda-forge.
pip install ndonnx
conda install ndonnx
pixi add ndonnx
Development
You can install the package in development mode using:
git clone https://github.com/quantco/ndonnx
cd ndonnx
git submodule update --init --recursive
pixi shell
pre-commit run -a
pip install --no-build-isolation --no-deps -e .
pytest tests -n auto
Quick start
ndonnx is an ONNX-based Python array library.
It has a couple of key features:
- It implements the Array API standard. Standard-compliant code can be executed without changes across numerous backends, such as NumPy, JAX, and now ndonnx.
import numpy as np
import ndonnx as ndx
import jax.numpy as jnp
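def mean_drop_outliers(a, low=-5, high=5):
    # Look up the Array API namespace (NumPy, JAX, ndonnx, ...) that `a` belongs to.
    xp = a.__array_namespace__()
    # Keep only values strictly between `low` and `high`, then average them.
    return xp.mean(a[(low < a) & (a < high)])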
np_result = mean_drop_outliers(np.asarray([-10, 0.5, 1, 5]))
jax_result = mean_drop_outliers(jnp.asarray([-10, 0.5, 1, 5]))
onnx_result = mean_drop_outliers(ndx.asarray([-10, 0.5, 1, 5]))
assert np_result == onnx_result.to_numpy() == jax_result == 0.75
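The portability shown above rests on the Array API's `__array_namespace__` protocol: generic code asks the array for its namespace and calls functions from it, so it never imports a specific backend. The toy sketch below uses a hypothetical `ToyArray` class (not part of ndonnx, NumPy, or JAX) to show that the unchanged `mean_drop_outliers` runs on any object that implements the handful of operations it uses.

```python
from types import SimpleNamespace


class ToyArray:
    """A hypothetical, minimal 1-D array used only to illustrate dispatch."""

    def __init__(self, values):
        self.values = list(values)

    def __array_namespace__(self, api_version=None):
        # Array API convention: the array tells generic code which namespace to use.
        return toy_namespace

    def __lt__(self, other):
        return ToyArray(v < other for v in self.values)

    def __gt__(self, other):
        # Also reached for `low < a`, via Python's reflected comparison.
        return ToyArray(v > other for v in self.values)

    def __and__(self, other):
        return ToyArray(p and q for p, q in zip(self.values, other.values))

    def __getitem__(self, mask):
        # Boolean-mask indexing: keep the elements where the mask is True.
        return ToyArray(v for v, keep in zip(self.values, mask.values) if keep)


def _mean(a):
    return sum(a.values) / len(a.values)


# The toy "namespace" only needs the functions the generic code calls.
toy_namespace = SimpleNamespace(mean=_mean)


def mean_drop_outliers(a, low=-5, high=5):
    xp = a.__array_namespace__()
    return xp.mean(a[(low < a) & (a < high)])


print(mean_drop_outliers(ToyArray([-10, 0.5, 1, 5])))  # 0.75
```

Real backends implement the full standard, but the mechanism is the same: the function body is written once against `xp`, and the array decides what `xp` is.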
- It supports ONNX export. This allows you to persist your logic into an ONNX computation graph.
import ndonnx as ndx
import onnx
x = ndx.array(shape=("N",), dtype=ndx.float32)
y = mean_drop_outliers(x)
model = ndx.build({"x": x}, {"y": y})
onnx.save(model, "mean_drop_outliers.onnx")
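Conceptually, `ndx.array(shape=..., dtype=...)` creates a symbolic placeholder rather than holding data, so calling `mean_drop_outliers` on it records operations instead of computing values, and `ndx.build` turns that record into an ONNX model. The pure-Python toy below illustrates this tracing idea with hypothetical `Node`, `placeholder`, and `evaluate` helpers; it is not how ndonnx is implemented internally, and it produces no ONNX.

```python
from types import SimpleNamespace


class Node:
    """A recorded operation: an op name plus its inputs (nodes or constants)."""

    def __init__(self, op, *inputs):
        self.op = op
        self.inputs = inputs

    def __array_namespace__(self, api_version=None):
        # Generic Array API code looks up its functions here.
        return tracing_namespace

    # Operators record a node instead of computing a value.
    def __lt__(self, other):
        return Node("less", self, other)

    def __gt__(self, other):
        # `a > b` is recorded as `b < a`; also reached for `-5 < node`.
        return Node("less", other, self)

    def __and__(self, other):
        return Node("and", self, other)

    def __getitem__(self, mask):
        return Node("compress", self, mask)


tracing_namespace = SimpleNamespace(mean=lambda a: Node("mean", a))


def placeholder(name):
    # Hypothetical symbolic input, loosely analogous to `x` in the example above.
    return Node("input", name)


def _elementwise(fn, a, b):
    # Broadcast a scalar constant against a list so both sides line up.
    if not isinstance(a, list):
        a = [a] * len(b)
    if not isinstance(b, list):
        b = [b] * len(a)
    return [fn(p, q) for p, q in zip(a, b)]


def evaluate(node, feeds):
    """Compute the value of a recorded graph for concrete input values."""
    if not isinstance(node, Node):
        return node  # constant captured during tracing (e.g. low/high, input names)
    args = [evaluate(i, feeds) for i in node.inputs]
    if node.op == "input":
        return list(feeds[args[0]])
    if node.op == "less":
        return _elementwise(lambda p, q: p < q, *args)
    if node.op == "and":
        return _elementwise(lambda p, q: p and q, *args)
    if node.op == "compress":
        values, mask = args
        return [v for v, keep in zip(values, mask) if keep]
    if node.op == "mean":
        return sum(args[0]) / len(args[0])
    raise ValueError(f"unknown op {node.op!r}")


def mean_drop_outliers(a, low=-5, high=5):
    xp = a.__array_namespace__()
    return xp.mean(a[(low < a) & (a < high)])


# Trace once with a symbolic input...
graph = mean_drop_outliers(placeholder("x"))
# ...then evaluate the same recorded graph on different data.
print(evaluate(graph, {"x": [-10, 0.5, 1, 5]}))  # 0.75
print(evaluate(graph, {"x": [1, 2, 3, 100]}))  # 2.0
```

The point of the design is that tracing happens once, while the recorded graph can be evaluated many times on new inputs; in ndonnx's case that graph is serialized as ONNX so that a separate runtime can do the evaluation.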
You can then make predictions using a runtime of your choice.
import onnxruntime as ort
import numpy as np
inference_session = ort.InferenceSession("mean_drop_outliers.onnx")
prediction, = inference_session.run(None, {
"x": np.array([-10, 0.5, 1, 5], dtype=np.float32),
})
assert prediction == 0.75
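The exact check `prediction == 0.75` works here because 0.75 is exactly representable as a float32. For arbitrary inputs, the runtime computes in float32 and its result may differ slightly from a float64 NumPy computation, so comparing with a tolerance is more robust. Below is a minimal sketch with hypothetical helpers `reference_mean_drop_outliers` and `check_prediction`; the runtime call is replaced by a stand-in value so the snippet runs with NumPy alone, and the `rtol=1e-5` default is an assumption rather than a value taken from ndonnx.

```python
import numpy as np


def reference_mean_drop_outliers(x, low=-5, high=5):
    # Plain NumPy reference, computed in float64 to serve as ground truth.
    x = np.asarray(x, dtype=np.float64)
    return x[(low < x) & (x < high)].mean()


def check_prediction(prediction, x, rtol=1e-5):
    """Raise AssertionError if a model's output drifts from the NumPy reference."""
    np.testing.assert_allclose(prediction, reference_mean_drop_outliers(x), rtol=rtol)


x = np.array([-10, 0.1, 0.2, 0.3, 5], dtype=np.float32)
# Stand-in for: prediction, = inference_session.run(None, {"x": x})
prediction = np.float32(x[1:4].mean())
check_prediction(prediction, x)
```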
In the future, we will provide a stable API for an extensible data type system. This will allow users to define their own data types and the operations on arrays of those types.
Array API coverage
Array API compatibility is tracked in api-coverage-tests. Missing coverage is tracked in the skips.txt file. Contributions are welcome!
Summary (1119 total):
- 961 passed
- 107 failed
- 51 deselected
Run the tests with:
pixi run arrayapitests