MCX is a probabilistic programming library with a laser focus on sampling methods. MCX transforms model definitions to generate logpdf or sampling functions. These functions are JIT-compiled with JAX; they support batching and can be executed on CPU, GPU or TPU transparently.
The project is currently in its infancy and is a moonshot towards providing sequential inference as a first-class citizen, as well as performant sampling methods for Bayesian deep learning.
MCX's philosophy
- Knowing how to express a graphical model and how to manipulate NumPy arrays should be enough to define a model.
- Models should be modular and re-usable.
- Inference should be performant and should leverage GPUs.
See the documentation for more information. See this issue for an updated roadmap for v0.1.
Note that there are still many moving pieces in mcx and the API may change slightly.
```python
import arviz as az
import jax
import jax.numpy as jnp
import numpy as np

import mcx
from mcx.distributions import Exponential, Normal
from mcx.inference import HMC

rng_key = jax.random.PRNGKey(0)

# Synthetic data: y = 3x + Gaussian noise
x_data = np.random.normal(0, 5, size=(1000,1))
y_data = 3 * x_data + np.random.normal(size=x_data.shape)

@mcx.model
def linear_regression(x, lmbda=1.):
    scale <~ Exponential(lmbda)
    coefs <~ Normal(jnp.zeros(jnp.shape(x)[-1]), 1)
    preds <~ Normal(jnp.dot(x, coefs), scale)
    return preds

# Draw from the prior predictive distribution
prior_predictive = mcx.prior_predict(rng_key, linear_regression, (x_data,))

# Condition on the observed `preds` and sample the posterior with HMC
posterior = mcx.sampler(
    rng_key,
    linear_regression,
    (x_data,),
    {'preds': y_data},
    HMC(100),
).run()

az.plot_trace(posterior)

posterior_predictive = mcx.posterior_predict(rng_key, linear_regression, (x_data,), posterior)
```
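Here is what "generating logpdf or sampling functions" means in practice. The block is a hand-written, pure-Python sketch of the two functions one could derive from `linear_regression`. It is not the code MCX generates. It assumes that `Exponential(lmbda)` is parameterized by its rate and `Normal(mu, sigma)` by its standard deviation. The function names are hypothetical.

```python
import math
import random


def exponential_logpdf(x, rate):
    """Log-density of Exponential(rate), assuming a rate parameterization."""
    if x < 0:
        return -math.inf
    return math.log(rate) - rate * x


def normal_logpdf(x, mu, sigma):
    """Log-density of Normal(mu, sigma), sigma being the standard deviation."""
    return -0.5 * ((x - mu) / sigma) ** 2 - math.log(sigma) - 0.5 * math.log(2 * math.pi)


def linear_regression_logpdf(scale, coefs, x, preds, lmbda=1.0):
    """Hypothetical hand-written logpdf: one term per `<~` statement."""
    if scale <= 0:
        return -math.inf  # outside the support of the Exponential prior
    logp = exponential_logpdf(scale, lmbda)
    logp += sum(normal_logpdf(c, 0.0, 1.0) for c in coefs)
    for row, y in zip(x, preds):
        mean = sum(xi * ci for xi, ci in zip(row, coefs))
        logp += normal_logpdf(y, mean, scale)
    return logp


def linear_regression_sample(rng, x, lmbda=1.0):
    """Hypothetical forward sampler: draw each variable in the order it is defined."""
    scale = rng.expovariate(lmbda)  # expovariate takes the rate
    coefs = [rng.gauss(0.0, 1.0) for _ in range(len(x[0]))]
    preds = [rng.gauss(sum(xi * ci for xi, ci in zip(row, coefs)), scale) for row in x]
    return {"scale": scale, "coefs": coefs, "preds": preds}


rng = random.Random(0)
x = [[rng.gauss(0.0, 5.0)] for _ in range(100)]
draw = linear_regression_sample(rng, x)
print(linear_regression_logpdf(draw["scale"], draw["coefs"], x, draw["preds"]))
```

The logpdf adds one term per `<~` statement. The sampler draws the variables in the order they are defined (ancestral sampling), which is what a prior predictive draw amounts to. In MCX these functions are generated from the model and JIT-compiled with JAX. The pure-Python version only shows what they compute.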
We are currently considering the following future directions:
- Neural network layers: You can follow discussions about the API in this Pull Request.
- Programs with stochastic support: Discussion in this Issue.
- Tools for causal inference: Made easier by the internal representation of models as graphs.
You are more than welcome to contribute to these discussions, or suggest potential future directions.
Like most PPLs, MCX implements a batch sampling runtime:
```python
sampler = mcx.sampler(
    rng_key,
    linear_regression,
    (x_data,),          # model arguments
    {'preds': y_data},  # observations
    HMC(100),           # kernel
)
posterior = sampler.run()
```
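A batch runtime applies the transition kernel a fixed number of times and returns the trace only at the end. The plain-Python sketch below illustrates this pattern. It swaps HMC for a random-walk Metropolis kernel to stay short, and it targets a standard normal. `rwm_step` and `run` are hypothetical names, not MCX functions.

```python
import math
import random


def rwm_step(rng, position, logpdf, step_size):
    """One random-walk Metropolis transition (a stand-in for the HMC kernel)."""
    proposal = position + step_size * rng.gauss(0.0, 1.0)
    accept_prob = math.exp(min(0.0, logpdf(proposal) - logpdf(position)))
    if rng.random() < accept_prob:
        return proposal
    return position


def run(rng, logpdf, init, num_samples, step_size=1.0):
    """Batch runtime: apply the kernel num_samples times, return the full trace at the end."""
    position = init
    trace = []
    for _ in range(num_samples):
        position = rwm_step(rng, position, logpdf, step_size)
        trace.append(position)
    return trace


def std_normal_logpdf(x):
    # Unnormalized log-density of N(0, 1); constants cancel in the acceptance ratio
    return -0.5 * x * x


samples = run(random.Random(0), std_normal_logpdf, init=0.0, num_samples=10_000, step_size=2.4)
```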
The warmup trace is discarded by default but you can obtain it by running:
```python
warmup_posterior = sampler.warmup()
posterior = sampler.run()
```
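Warmup is typically where a sampler tunes its parameters before drawing the samples that are kept. As an illustration only, the sketch below adapts the step size of a random-walk Metropolis kernel toward an acceptance rate of 0.44, a common target for one-dimensional random-walk proposals. It returns the warmup draws separately. MCX's HMC warmup may tune different quantities, and every name here is hypothetical.

```python
import math
import random


def rwm_step(rng, position, logpdf, step_size):
    """One random-walk Metropolis step; returns the new position and acceptance probability."""
    proposal = position + step_size * rng.gauss(0.0, 1.0)
    accept_prob = math.exp(min(0.0, logpdf(proposal) - logpdf(position)))
    if rng.random() < accept_prob:
        return proposal, accept_prob
    return position, accept_prob


def warmup(rng, logpdf, init, num_steps, target_accept=0.44, step_size=1.0):
    """Tune the step size; return the last position, tuned step size and warmup trace."""
    log_step = math.log(step_size)
    position = init
    warmup_trace = []
    for t in range(1, num_steps + 1):
        position, accept_prob = rwm_step(rng, position, logpdf, math.exp(log_step))
        # Stochastic-approximation update: grow the step when accepting more often than
        # the target, shrink it otherwise, with a decaying gain 1/sqrt(t)
        log_step += (accept_prob - target_accept) / math.sqrt(t)
        warmup_trace.append(position)
    return position, math.exp(log_step), warmup_trace


position, step_size, warmup_trace = warmup(
    random.Random(0), lambda v: -0.5 * v * v, init=0.0, num_steps=2000
)
# Sampling would then start from `position` with `step_size` held fixed;
# `warmup_trace` corresponds to the draws a sampler discards by default.
```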
You can extract more samples from the chain after a run and combine the two traces:
```python
posterior += sampler.run()
```
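Appending a second run only makes sense if the sampler keeps its state, so that the new samples continue the same chain. One way `+=` on traces could work is by concatenating the per-variable draws. `SketchTrace` is a hypothetical stand-in, not `mcx.Trace`, and the values in the example are toy numbers.

```python
from collections import defaultdict


class SketchTrace:
    """Hypothetical trace: one list of draws per variable name."""

    def __init__(self):
        self.samples = defaultdict(list)

    def append(self, sample):
        # `sample` maps each variable name to a single draw, e.g. {"scale": 1.3}
        for name, value in sample.items():
            self.samples[name].append(value)

    def __iadd__(self, other):
        # Concatenate along the sample axis; both traces must track the same variables
        if self.samples and other.samples and set(self.samples) != set(other.samples):
            raise ValueError("traces must contain the same variables")
        for name, values in other.samples.items():
            self.samples[name].extend(values)
        return self

    def __len__(self):
        return max((len(v) for v in self.samples.values()), default=0)


posterior = SketchTrace()
posterior.append({"scale": 1.0, "coefs": 2.9})
more = SketchTrace()
more.append({"scale": 1.1, "coefs": 3.1})
more.append({"scale": 0.9, "coefs": 3.0})
posterior += more
print(len(posterior), posterior.samples["scale"])  # 3 [1.0, 1.1, 0.9]
```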
By default MCX samples in interactive mode using a Python for loop, and displays a progress bar and various diagnostics. For faster sampling you can use:
```python
posterior = sampler.run(compile=True)
```
In a notebook, you can combine the two: sample interactively first to get a lower bound on the sampling rate, then decide on the number of samples to draw.
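Suppose the interactive loop drew `n` samples in `t` seconds. The batch sampler is faster, so it should draw at least `n / t` samples per second, and a time budget therefore translates into a conservative sample count. The helper below is hypothetical and the timings in the example are made up. It also ignores JIT compilation time, which a compiled run pays once at the start.

```python
def samples_within_budget(n_interactive, elapsed_seconds, budget_seconds):
    """Conservative number of samples that fit in `budget_seconds`.

    The interactive (Python loop) rate is a lower bound on the compiled rate,
    so the count returned here is a lower bound too.
    """
    rate = n_interactive / elapsed_seconds  # samples per second in interactive mode
    return int(rate * budget_seconds)


# Made-up timings: 1,000 interactive samples took 4 seconds; we can afford 10 minutes
print(samples_within_budget(1000, 4.0, 600.0))  # 150000
```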
Sampling the posterior is an iterative process, yet most libraries only provide batch sampling. The generator runtime is already implemented in mcx, which opens many possibilities such as:
- Dynamic interruption of inference (say, after getting a set number of effective samples);
- Real-time monitoring of inference with something like tensorboard;
- Easier debugging.
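A generator yields one draw at a time, so the caller decides when to stop. The sketch below stops once a crude effective-sample-size estimate passes a threshold. The estimate sums autocorrelations up to the first non-positive lag; libraries such as ArviZ provide more robust ESS estimators. A random-walk Metropolis generator stands in for MCX's sampler, and all names are hypothetical.

```python
import math
import random


def rwm_chain(rng, logpdf, position, step_size):
    """Endless random-walk Metropolis generator (a stand-in for MCX's sampler)."""
    while True:
        proposal = position + step_size * rng.gauss(0.0, 1.0)
        if rng.random() < math.exp(min(0.0, logpdf(proposal) - logpdf(position))):
            position = proposal
        yield position


def effective_sample_size(xs):
    """Crude ESS: n / (1 + 2 * sum of autocorrelations before the first non-positive lag)."""
    n = len(xs)
    mean = sum(xs) / n
    centered = [v - mean for v in xs]
    var = sum(c * c for c in centered) / n
    if var == 0.0:
        return 0.0  # a chain that never moved carries no usable information
    rho_sum = 0.0
    for lag in range(1, n):
        acov = sum(centered[i] * centered[i + lag] for i in range(n - lag)) / n
        rho = acov / var
        if rho <= 0.0:
            break
        rho_sum += rho
    return n / (1.0 + 2.0 * rho_sum)


def sample_until_ess(chain, target_ess, check_every=100, max_samples=100_000):
    """Consume draws until the ESS estimate reaches target_ess (or max_samples is hit)."""
    trace = []
    for draw in chain:
        trace.append(draw)
        if len(trace) >= max_samples:
            break
        if len(trace) % check_every == 0 and effective_sample_size(trace) >= target_ess:
            break
    return trace


chain = rwm_chain(random.Random(1), lambda v: -0.5 * v * v, position=0.0, step_size=2.4)
trace = sample_until_ess(chain, target_ess=200)
print(len(trace), effective_sample_size(trace))
```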
```python
sampler = mcx.sampler(
    rng_key,
    linear_regression,
    (x_data,),
    {'preds': y_data},
    HMC(100),
)

trace = mcx.Trace()
for sample in sampler:
    trace.append(sample)

# The sampler also follows Python's iterator protocol
iter(sampler)
next(sampler)
```
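The interactive sampler is iterated like any Python iterable, so the standard iterator protocol applies. The sketch below uses a plain generator as a hypothetical stand-in for the sampler; it is a Gaussian random walk, not a real MCMC kernel. It shows what `iter()` and `next()` do. It also shows how `itertools.islice` takes a bounded number of draws, assuming the stream of samples is endless.

```python
import itertools
import random


def draws(rng):
    """Endless generator standing in for a sampler in generator mode (hypothetical)."""
    position = 0.0
    while True:
        position += rng.gauss(0.0, 1.0)  # placeholder transition: a Gaussian random walk
        yield position


sampler = draws(random.Random(0))
print(iter(sampler) is sampler)  # True: a generator is its own iterator
first = next(sampler)  # advance the chain by exactly one draw
next_five = list(itertools.islice(sampler, 5))  # bounded slice of an endless stream
```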
Note that the performance of the interactive mode is significantly lower than that of the batch sampler. However, both can be used successively:
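```python
trace = mcx.Trace()
for i, sample in enumerate(sampler):
    print(do_something(sample))  # `do_something` stands for your own monitoring code
    trace.append(sample)
    if i % 10 == 0:
        # Periodically switch to the compiled batch runtime for a large block of samples
        trace += sampler.run(100_000, compile=True)
```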
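Mixing the two modes this way relies on them sharing the chain state, so that a batch run picks up where the interactive loop stopped. The hypothetical `HybridSampler` below makes that assumption explicit. `__next__` performs one random-walk Metropolis step (a stand-in kernel), and `run` reuses it. Interleaving the two modes therefore yields exactly the same chain as a single run with the same seed.

```python
import math
import random


class HybridSampler:
    """Hypothetical sampler whose interactive and batch modes share one chain state."""

    def __init__(self, logpdf, init, step_size, seed=0):
        self.logpdf = logpdf
        self.position = init
        self.step_size = step_size
        self.rng = random.Random(seed)

    def __iter__(self):
        return self

    def __next__(self):
        # Interactive mode: one random-walk Metropolis step (stand-in for the real kernel)
        proposal = self.position + self.step_size * self.rng.gauss(0.0, 1.0)
        log_ratio = self.logpdf(proposal) - self.logpdf(self.position)
        if self.rng.random() < math.exp(min(0.0, log_ratio)):
            self.position = proposal
        return self.position

    def run(self, num_samples):
        # Batch mode: same transition and same state, whole trace returned at once
        return [next(self) for _ in range(num_samples)]


def logpdf(v):
    return -0.5 * v * v


mixed = HybridSampler(logpdf, 0.0, 2.4)
trace = [next(mixed) for _ in range(10)]  # a few interactive steps
trace += mixed.run(1000)                  # then one batch run
continuous = HybridSampler(logpdf, 0.0, 2.4).run(1010)
print(trace == continuous)  # True
```

In this sketch `run` is still a Python loop. In MCX, the speed-up comes from the compiled batch mode.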
MCX takes a lot of inspiration from other probabilistic programming languages and libraries: Stan (NUTS and the very knowledgeable community), PyMC3 (for its simple API), TensorFlow Probability (for its shape system and inference vectorization), (Num)Pyro (for the use of JAX in the backend), Gen.jl and Turing.jl (for composable inference), Soss.jl (generative model API), Anglican, and many others that I forget.