Estimate the true posterior distribution over the weights of a neural network using the Kronecker Factored Laplace Approximation. Compatible with TensorFlow 2.x
This library implements three Laplace approximation methods to approximate the true posterior distribution over the weights of a neural network. It is similar to this PyTorch implementation of [2].
The library, located in the laplace package, implements three approximations to the Fisher Information Matrix. Furthermore, it implements a sampler for each approximation to draw weights from the approximated posterior distribution. The library can be used with any TensorFlow 2 sequential model. So far, the approximations only consider Dense and Convolutional layers.
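As a rough sketch of the underlying idea (notation mine, not taken from the library's documentation): the trained weights θ* are treated as the mode of the posterior, and the posterior over the weights is approximated by a Gaussian centred at that mode whose covariance is the inverse of the regularized curvature,

p(θ | D) ≈ N(θ*, (N·F + τ·I)⁻¹)

where F is the approximated Fisher Information Matrix, N is the size of the training set and τ acts as a prior precision. Presumably these correspond to the n and tau arguments of Sampler.create in the example below.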
The experiments package contains the code for the experiments in experiments.py.
To install the library, clone or download the repository and run:
pip install .
or add it as a dependency to your environment using:
pip install git+https://github.com/ferewi/tf-laplace.git
This will install the following required dependencies:
- numpy
- tensorflow (v2.5)
- tensorflow_probability (v0.13)
It will also install the following libraries, which are needed to run the demo in the provided Jupyter notebook and the experiments in experiments.py:
- matplotlib
- pandas
- sklearn
This mini-example, as well as the demo provided in the Jupyter notebook, demonstrates how to use the library to approximate the posterior distribution of a neural network trained on a synthetic multi-label classification dataset. The dataset is contained in the experiments/dataset.py module.
# standard imports
import numpy as np
import tensorflow as tf
# library imports
from laplace.curvature import KFAC
from laplace.sampler import Sampler
# additional imports
from experiments.dataset import F3
# 1. Create Dataset
# We create a multi-label dataset which consists of two classes.
ds = F3.create(50, -5.5, 5.5)
training_set = tf.data.Dataset.from_tensor_slices(ds.get()).batch(32)
test_set = tf.data.Dataset.from_tensor_slices(ds.get_test_set(2000)).batch(256)
# 2. Build and Train Model
# As the method can be applied to already trained models, this is
# just for demo purposes.
NUM_CLASSES = 2
model = tf.keras.Sequential([
    tf.keras.layers.Dense(10, input_dim=2, activation='relu', kernel_regularizer=tf.keras.regularizers.L2(0.001)),
    tf.keras.layers.Dense(10, activation='relu', kernel_regularizer=tf.keras.regularizers.L2(0.001)),
    tf.keras.layers.Dense(NUM_CLASSES)
])
criterion = tf.keras.losses.BinaryCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
model.compile(loss=criterion, optimizer=optimizer, metrics=['accuracy'])
model.fit(training_set, epochs=1000, verbose=True)
# 3. Approximate Curvature
# We approximate the curvature by running a training loop over
# the training set. Instead of updating the model's parameters,
# the curvature approximation is computed incrementally.
kfac = KFAC.compute(model, training_set, criterion)
sampler = Sampler.create(kfac, tau=1, n=50)
# 4. Evaluate the Bayesian neural network on a Test Set
MC_SAMPLES = 50
predictions = np.zeros([MC_SAMPLES, 0, NUM_CLASSES], dtype=np.float32)
for i, (x, y) in enumerate(test_set):
    posterior_mean = model.get_weights()
    batch_predictions = np.zeros([MC_SAMPLES, x.shape[0], NUM_CLASSES], dtype=np.float32)
    for sample in range(MC_SAMPLES):
        sampler.sample_and_replace_weights()
        batch_predictions[sample] = tf.sigmoid(model.predict(x)).numpy()
        model.set_weights(posterior_mean)
    predictions = np.concatenate([predictions, batch_predictions], axis=1)
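The stacked Monte Carlo predictions can then be reduced to a predictive mean and a per-class uncertainty estimate. This post-processing is plain NumPy and not part of the library; the variable names are mine:

# Aggregate the Monte Carlo samples (axis 0 indexes the sampled weight sets).
mean_prediction = predictions.mean(axis=0)  # shape: (num_test_points, NUM_CLASSES)
predictive_std = predictions.std(axis=0)    # larger values indicate higher model uncertainty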
The experiments in experiments.py demonstrate how the model uncertainty estimates obtained from the Bayesian neural network can improve calibration and out-of-distribution detection compared to the deterministic baseline model. All code related to the experiments can be found in the experiments package.
The calibration experiment compares the calibration of the deterministic baseline model to that of the Bayesian neural network. This is visualised using calibration curve diagrams.
The experiment can be run via:
python experiments.py calibration
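For intuition, a calibration curve like the one produced by this experiment could be sketched with scikit-learn roughly as follows. This is illustrative only, not the experiment code; it assumes mean_prediction from the aggregation snippet above and collects the true 0/1 targets from the test set:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.calibration import calibration_curve

# Collect the true targets from the test set (same order as the predictions above).
labels = np.concatenate([y.numpy() for _, y in test_set])

# Calibration curve for the first class of the multi-label problem.
prob_true, prob_pred = calibration_curve(labels[:, 0], mean_prediction[:, 0], n_bins=10)
plt.plot(prob_pred, prob_true, marker='o', label='Bayesian NN')
plt.plot([0, 1], [0, 1], linestyle='--', label='perfectly calibrated')
plt.xlabel('mean predicted probability')
plt.ylabel('fraction of positives')
plt.legend()
plt.show()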
The out-of-distribution detection experiment demonstrates how using the predictive standard deviation of multiple probabilistic forward passes improves the separability of in- and out-of-distribution data compared to using the confidence scores of the baseline model. This is visualised using receiver operating characteristic (ROC) plots.
The experiment can be run via:
python experiments.py ood
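Conceptually, the experiment uses an uncertainty (or confidence) score to separate in- from out-of-distribution inputs and summarizes the result with a ROC curve. A minimal sketch with scikit-learn, assuming id_std and ood_std hold predictive standard deviations (computed as above) for in- and out-of-distribution data respectively, might look like this:

import numpy as np
from sklearn.metrics import roc_auc_score

# Per-example detection score: maximum predictive std across classes.
scores = np.concatenate([id_std.max(axis=1), ood_std.max(axis=1)])
# 0 = in-distribution, 1 = out-of-distribution.
is_ood = np.concatenate([np.zeros(len(id_std)), np.ones(len(ood_std))])
print('OOD detection AUROC:', roc_auc_score(is_ood, scores))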
If you have any questions, comments, bug reports or feature requests, please use the GitHub issue tracker or write me at: [email protected]
© 2021 Ferdinand Rewicki ([email protected])
The tf-laplace library is released under the MIT license.