GPJax aims to provide a low-level interface to Gaussian process (GP) models in JAX, structured to give researchers maximum flexibility in extending the code to suit their own needs. The idea is that the code should be as close as possible to the maths we write on paper when working with GP models.
Package support
GPJax was founded by Thomas Pinder. Today, the maintenance of GPJax is undertaken by Thomas Pinder and Daniel Dodd.
We would be delighted to receive contributions from interested individuals and groups. To learn how you can get involved, please read our guide for contributing. If you have any questions, we encourage you to open an issue. For broader conversations, such as best GP fitting practices or questions about the mathematics of GPs, we invite you to open a discussion.
Feel free to join our Slack Channel, where we can discuss the development of GPJax and broader support for Gaussian process modelling.
Supported methods and interfaces
Notebook examples
Guides for customisation
Conversion between .ipynb and .py
The examples above are stored in the examples directory in the double percent (py:percent) format. Check out the jupytext using-cli documentation for more information.
- To convert example.py to example.ipynb, run:
jupytext --to notebook example.py
- To convert example.ipynb to example.py, run:
jupytext --to py:percent example.ipynb
Simple example
Let us import some dependencies and simulate a toy dataset $\mathcal{D}$.
import gpjax as gpx
from jax import grad, jit
import jax.numpy as jnp
import jax.random as jr
import optax as ox
key = jr.PRNGKey(123)
f = lambda x: 10 * jnp.sin(x)
n = 50
x = jr.uniform(key=key, minval=-3.0, maxval=3.0, shape=(n,1)).sort(axis=0)  # sort along the data axis; the default axis=-1 is a no-op for shape (n, 1)
y = f(x) + jr.normal(key, shape=(n,1))
D = gpx.Dataset(X=x, y=y)
The function of interest here, $f(\cdot)$, is sinusoidal, but our observations of it have been perturbed by Gaussian noise. We aim to utilise a Gaussian process to try and recover this latent function.
1. Constructing the prior and posterior
We begin by defining a zero-mean Gaussian process prior with a radial basis function kernel and assume the likelihood to be Gaussian.
prior = gpx.Prior(kernel = gpx.RBF())
likelihood = gpx.Gaussian(num_datapoints = n)
As we would write on paper, the posterior is constructed as the product of our prior and our likelihood.
posterior = prior * likelihood
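The product above mirrors Bayes' rule. As a sketch in notation that is our own assumption rather than this README's (latent function values $\mathbf{f}$ at the training inputs and observations $\mathbf{y}$), the posterior is proportional to the likelihood times the prior:

$$
p(\mathbf{f} \mid \mathcal{D}) = \frac{p(\mathbf{y} \mid \mathbf{f})\, p(\mathbf{f})}{p(\mathbf{y})} \propto p(\mathbf{y} \mid \mathbf{f})\, p(\mathbf{f}).
$$

The normalising constant $p(\mathbf{y})$ is the marginal likelihood, the quantity optimised in the next step.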
2. Learning hyperparameters
Equipped with the posterior, we seek to learn the model's hyperparameters through gradient-based optimisation of the marginal log-likelihood. We do this below, adding JAX's just-in-time (JIT) compilation to accelerate training.
mll = jit(posterior.marginal_log_likelihood(D, negative=True))
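To make the objective concrete, here is a minimal sketch of the negative marginal log-likelihood of a zero-mean GP with an RBF kernel and Gaussian noise, written directly in JAX. It is not GPJax's implementation: the helper names `rbf_kernel` and `negative_mll` and the parameter keys `lengthscale`, `variance` and `obs_noise` are hypothetical, chosen for illustration. With $\Sigma = K_{xx} + \sigma^2 I$, it evaluates $\tfrac{1}{2}\mathbf{y}^\top \Sigma^{-1}\mathbf{y} + \tfrac{1}{2}\log|\Sigma| + \tfrac{n}{2}\log 2\pi$ via a Cholesky factorisation.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve


def rbf_kernel(x1, x2, lengthscale=1.0, variance=1.0):
    # Squared Euclidean distance between every row of x1 and every row of x2.
    sq_dist = jnp.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
    return variance * jnp.exp(-0.5 * sq_dist / lengthscale**2)


def negative_mll(params, x, y):
    """Negative log marginal likelihood; x has shape (n, d), y has shape (n, 1)."""
    n = x.shape[0]
    Kxx = rbf_kernel(x, x, params["lengthscale"], params["variance"])
    Sigma = Kxx + params["obs_noise"] * jnp.eye(n)
    chol = cho_factor(Sigma, lower=True)
    alpha = cho_solve(chol, y)  # Sigma^{-1} y
    data_fit = 0.5 * jnp.sum(y * alpha)
    half_log_det = jnp.sum(jnp.log(jnp.diag(chol[0])))  # equals 0.5 * log|Sigma|
    return data_fit + half_log_det + 0.5 * n * jnp.log(2.0 * jnp.pi)


# Usage on a tiny noiseless sample of the sinusoid (illustrative values only).
x = jnp.linspace(-3.0, 3.0, 5).reshape(-1, 1)
y = 10 * jnp.sin(x)
params = {"lengthscale": 1.0, "variance": 1.0, "obs_noise": 1.0}
loss = jax.jit(negative_mll)(params, x, y)
grads = jax.grad(negative_mll)(params, x, y)  # gradient w.r.t. each hyperparameter
```

Because the function is a pure mapping from parameters to a scalar, it can be wrapped in `jit` and differentiated with `grad`, which is the property the training step relies on.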
For purposes of optimisation, we'll use optax's Adam.
opt = ox.adam(learning_rate=1e-3)
We define an initial parameter state through the initialise callable.
parameter_state = gpx.initialise(posterior, key=key)
Finally, we run an optimisation loop using the Adam optimiser via the fit callable.
inference_state = gpx.fit(mll, parameter_state, opt, num_iters=500)
3. Making predictions
Using our learned hyperparameters, we can obtain the posterior distribution of the latent function at novel test points.
learned_params, _ = inference_state.unpack()
xtest = jnp.linspace(-3., 3., 100).reshape(-1, 1)
latent_distribution = posterior(learned_params, D)(xtest)
predictive_distribution = likelihood(learned_params, latent_distribution)
predictive_mean = predictive_distribution.mean()
predictive_cov = predictive_distribution.covariance()
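The quantities above have closed forms when the likelihood is Gaussian. As a hedged illustration, the sketch below computes them directly in JAX for a zero-mean GP with an RBF kernel; `rbf_kernel` and `predict_sketch` are hypothetical helpers with assumed default hyperparameters, not part of GPJax. The latent mean is $K_{x*}^\top \Sigma^{-1}\mathbf{y}$ and the latent covariance is $K_{**} - K_{x*}^\top \Sigma^{-1} K_{x*}$, with $\Sigma = K_{xx} + \sigma^2 I$; the predictive covariance adds the observation noise back on the diagonal.

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve


def rbf_kernel(x1, x2, lengthscale=1.0, variance=1.0):
    # Squared Euclidean distance between every row of x1 and every row of x2.
    sq_dist = jnp.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
    return variance * jnp.exp(-0.5 * sq_dist / lengthscale**2)


def predict_sketch(x, y, xtest, lengthscale=1.0, variance=1.0, obs_noise=1.0):
    """Return the latent mean, latent covariance and predictive covariance at xtest."""
    n = x.shape[0]
    Kxx = rbf_kernel(x, x, lengthscale, variance)
    Kxt = rbf_kernel(x, xtest, lengthscale, variance)
    Ktt = rbf_kernel(xtest, xtest, lengthscale, variance)
    chol = cho_factor(Kxx + obs_noise * jnp.eye(n), lower=True)
    latent_mean = Kxt.T @ cho_solve(chol, y)
    latent_cov = Ktt - Kxt.T @ cho_solve(chol, Kxt)
    # The Gaussian likelihood adds observation noise to the latent uncertainty.
    predictive_cov = latent_cov + obs_noise * jnp.eye(xtest.shape[0])
    return latent_mean, latent_cov, predictive_cov


# Usage on a small noiseless sample of the sinusoid (illustrative values only).
x = jnp.linspace(-3.0, 3.0, 10).reshape(-1, 1)
y = 10 * jnp.sin(x)
xtest = jnp.linspace(-3.0, 3.0, 100).reshape(-1, 1)
mean, latent_cov, pred_cov = predict_sketch(x, y, xtest)
pred_std = jnp.sqrt(jnp.diag(pred_cov))  # pointwise predictive standard deviation
```

The pointwise standard deviation is what one would typically use to draw an uncertainty band around the predictive mean.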
Installation
Stable version
The latest stable version of GPJax can be installed via pip:
pip install gpjax
Note
We recommend you check your installation version:
python -c 'import gpjax; print(gpjax.__version__)'
Development version
Warning
This version is possibly unstable and may contain bugs.
Clone a copy of the repository to your local machine and run the setup configuration in development mode.
git clone https://github.com/JaxGaussianProcesses/GPJax.git
cd GPJax
python setup.py develop
Note
We advise you to create a virtual environment before installing:
conda create -n gpjax_experimental python=3.10.0
conda activate gpjax_experimental
and recommend you check your installation passes the supplied unit tests:
python -m pytest tests/
Citing GPJax
If you use GPJax in your research, please cite our JOSS paper.
@article{Pinder2022,
doi = {10.21105/joss.04455},
url = {https://doi.org/10.21105/joss.04455},
year = {2022},
publisher = {The Open Journal},
volume = {7},
number = {75},
pages = {4455},
author = {Thomas Pinder and Daniel Dodd},
title = {GPJax: A Gaussian Process Framework in JAX},
journal = {Journal of Open Source Software}
}