UPDATE: This repo is no longer maintained. Please see JaxGaussianProcesses/GPJax if you’re interested in a Gaussian process library in JAX.
GPJax is a minimal package for implementing Gaussian process models in Python using JAX.
I have spent a lot of time using GPflow and I like how they implement their GP library, in particular their focus on variational inference and their implementation of GP conditionals.
As such, this package takes a similar approach but offers the benefits (and ease) of having JAX under the hood.
GPJax keeps with the spirit of JAX (providing an extensible system of composable function transformations) by implementing its features as pure functions.
However, managing the parameters associated with the different components in Gaussian process methods (compositions of mean functions, kernels, variational parameters, etc.) makes a purely functional approach less appealing. I have experimented with both haiku and objax for state management, and neither provided a level of abstraction I was happy with.
As a result, GPJax now implements a simple approach to state management.
This package is a work in progress and functionality will be implemented when I need it for my research. If you like what I’m doing or have any recommendations then please reach out, or even better, get involved!
- `gpjax` relies upon its `gpjax.Module` class, which all constructs with trainable state must inherit, e.g. `MeanFunction`, `Kernel`, `GPModel`.
- Each `gpjax.Module` subclass must implement a `get_params()` method that returns a dictionary of the parameters associated with the object. For example, for a stationary kernel, `kernel_param_dict = kernel.get_params()` returns `{'lengthscales': DeviceArray([1.], dtype=float64), 'variance': DeviceArray([1.], dtype=float64)}` (see the sketch after this list).
- JAX’s transformations depend upon pure functions, so GPJax implements all of its functionality as standalone functions; classes then call the relevant functions.
- For example, the squared exponential kernel class is really just a convenience class that stores its associated parameters and wraps its associated function. This is handy because we can either create an instance of the kernel class and call it, or call the standalone function directly:

```python
kernel = SquaredExponential()
params = kernel.get_params()  # get the dictionary of parameters associated with the kernel

# X1 and X2 are input arrays, e.g. with shapes (num_x1, input_dim) and (num_x2, input_dim)
K_object = kernel(params, X1, X2)  # call the kernel object with the dictionary of parameters
K_func = squared_exponential_cov_fn(params, X1, X2)  # call the standalone function with the same parameters

# element-wise comparison returns an array, so reduce it before asserting
assert (K_object == K_func).all()
```
- In general, classes provide wrappers around functions that accept a dictionary of parameters as the first argument, e.g.

```python
gpjax.models.SVGP.predict_f(params, Xnew, ...)
gpjax.kernels.SquaredExponential.K(params, X1, X2)
```
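To make this pattern concrete, here is a minimal sketch of what such a class might look like. This is purely illustrative: the `Constant` mean function and `constant_mean_fn` below are hypothetical names, the real gpjax classes inherit from `gpjax.Module`, and their internals will differ; only the `get_params()` / standalone-function split described above is taken from the library.

```python
import jax.numpy as jnp


def constant_mean_fn(params: dict, X: jnp.ndarray) -> jnp.ndarray:
    # standalone pure function: broadcast the constant over the inputs
    return params["constant"] * jnp.ones((X.shape[0], 1))


class Constant:
    # hypothetical Module-style class; the real classes subclass gpjax.Module
    def __init__(self, constant: float = 0.0):
        self.constant = jnp.array([constant])

    def get_params(self) -> dict:
        # return the trainable state as a dictionary, mirroring kernel.get_params()
        return {"constant": self.constant}

    def __call__(self, params: dict, X: jnp.ndarray) -> jnp.ndarray:
        # the class is just a thin wrapper around the standalone function
        return constant_mean_fn(params, X)


mean_fn = Constant(constant=1.5)
params = mean_fn.get_params()         # {'constant': array of shape (1,)}
mu = mean_fn(params, jnp.ones((10, 2)))  # shape (10, 1)
```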
Let’s take a look at how state is managed in an example using a kernel.
```python
import jax
from gpjax.kernels import SquaredExponential

key = jax.random.PRNGKey(0)

# create some dummy data (split the key so X1 and X2 are drawn independently)
input_dim = 5
num_x1 = 100
num_x2 = 30
key1, key2 = jax.random.split(key)
X1 = jax.random.uniform(key1, shape=(num_x1, input_dim))
X2 = jax.random.uniform(key2, shape=(num_x2, input_dim))

# create an instance of the SE kernel
kernel = SquaredExponential()
params = kernel.get_params()  # get the dictionary of parameters associated with the kernel
K = kernel(params, X1, X2)  # call the kernel with the dictionary of parameters
```
Importantly, this is a pure function because the kernel’s hyperparameters (lengthscales and variance in this case) are passed as an argument. This means that we can easily compose it with any JAX transformation, for example `jax.jit` and `jax.jacfwd`.

We can calculate the derivative of our kernel w.r.t. its hyperparameters using `jax.jacfwd`:
```python
# create a function that returns the derivative of the kernel w.r.t. params (its first argument)
jac_kernel_wrt_params_fn = jax.jacfwd(kernel, argnums=0)

# evaluate the derivative of the kernel w.r.t. its hyperparameters
jac_kernel_wrt_params = jac_kernel_wrt_params_fn(params, X1, X2)
print(jac_kernel_wrt_params["lengthscales"].shape)
print(jac_kernel_wrt_params["variance"].shape)
```

```
(100, 30, 1)
(100, 30, 1)
```
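The kernel composes with `jax.jit` in the same way. The following is just a sketch that follows on from the example above, assuming (as described) that the kernel’s call is a pure function of its array arguments:

```python
import jax.numpy as jnp

# JIT-compile the kernel call; params, X1 and X2 are passed as ordinary array arguments
jitted_kernel = jax.jit(kernel)
K_jit = jitted_kernel(params, X1, X2)

# the compiled result should agree with the uncompiled one (up to floating point)
assert jnp.allclose(K_jit, kernel(params, X1, X2))
```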
This is a Python package that should be installed into a virtual environment. Start by cloning this repo from GitHub:

```bash
git clone https://github.com/aidanscannell/GPJax.git
```
The package can then be installed into a virtual environment by adding it as a local dependency.
GPJax’s dependencies and packaging are managed with Poetry, instead of other tools such as Pipenv. To install GPJax into an existing Poetry environment, add it as a dependency under `[tool.poetry.dependencies]` (in the ./pyproject.toml configuration file) with the following line:

```toml
gpjax = {path = "/path/to/gpjax"}
```
If you want to develop the gpjax codebase then set `develop=true`:

```toml
gpjax = {path = "/path/to/gpjax", develop=true}
```
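For reference, the relevant section of the consuming project’s ./pyproject.toml might then look something like the sketch below; the Python constraint and the relative path are illustrative assumptions, not values prescribed by GPJax:

```toml
[tool.poetry.dependencies]
# illustrative Python constraint; use whatever your project actually requires
python = "^3.8"
# local path dependency on the cloned GPJax repo (adjust the path to your checkout)
gpjax = {path = "../GPJax", develop=true}
```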
The dependencies in a ./pyproject.toml file are resolved and installed with:

```bash
poetry install
```

If you do not require the development packages then you can opt to install without them:

```bash
poetry install --no-dev
```
Alternatively, the package can be installed with pip. Create a new virtualenv and activate it, for example:

```bash
mkvirtualenv --python=python3 gpjax-env
workon gpjax-env
```

`cd` into the root of this package and install it and its dependencies with:

```bash
pip install .
```

If you want to develop the gpjax codebase then install it in “editable” or “develop” mode with:

```bash
pip install -e .
```
- [ ] Implement mean functions
  - [X] Implement zero
  - [X] Implement constant
- [ ] Implement kernels
  - [X] Implement base
  - [X] Implement squared exponential
  - [X] Implement multi output
  - [X] Implement separate independent
  - [ ] Implement shared independent
  - [ ] Implement LinearCoregionalization
- [ ] Implement conditionals
  - [X] Implement single-output conditionals
  - [X] Implement multi-output conditionals
  - [X] Implement dispatch for single/multioutput
  - [ ] Implement dispatch for different inducing variables
- [ ] Implement likelihoods
  - [X] Implement base likelihood
  - [X] Implement Gaussian likelihood
  - [ ] Implement Bernoulli likelihood
  - [ ] Implement Softmax likelihood
- [ ] Implement gpjax.models
  - [X] Implement gpjax.models.GPModel
    - [X] predict_f
    - [X] predict_y
  - [ ] Implement gpjax.models.GPR
  - [ ] Implement gpjax.models.SVGP
    - [X] predict_f
    - [X] init_variational_parameters
    - [X] KL
    - [X] lower bound
- [ ] Notebook examples
  - [ ] GPR regression
  - [X] SVGP regression
  - [ ] SVGP classification
- [X] Tests for mean functions
  - [X] Tests for zero
  - [X] Tests for constant
- [X] Tests for kernels
  - [X] Tests for squared exponential
  - [X] Tests for separate independent
- [ ] Tests for conditionals
  - [ ] Tests for single output conditionals
  - [ ] Tests for multi output conditionals
- [ ] Tests for likelihoods
  - [ ] Tests for gaussian likelihood
  - [ ] Tests for bernoulli likelihood
  - [ ] Tests for softmax likelihood
- [ ] Tests for gpjax.models.SVGP
  - [X] Tests for gpjax.models.SVGP.predict_f
  - [X] Tests for gpjax.models.SVGP.prior_kl