From 3c7051bd2abb3b38ad164e87dcf1ca36915bddc1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Thu, 25 Aug 2022 13:27:29 -0600 Subject: [PATCH] Add a simple test for numba compilation --- tests/test_numba.py | 40 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) create mode 100644 tests/test_numba.py diff --git a/tests/test_numba.py b/tests/test_numba.py new file mode 100644 index 0000000..35e35f6 --- /dev/null +++ b/tests/test_numba.py @@ -0,0 +1,40 @@ +"""Test file created for the sole purpose of tracking the status of Numba compilation""" +import aesara +import aesara.tensor as at +from aeppl import joint_logprob + +import aehmc.nuts as nuts + + +def test_sample_with_numba(): + + srng = at.random.RandomStream(seed=0) + Y_rv = srng.normal(1, 2) + + def logprob_fn(y): + logprob = joint_logprob({Y_rv: y}) + return logprob + + # Build the transition kernel + kernel = nuts.new_kernel(srng, logprob_fn) + + # Compile a function that updates the chain + y_vv = Y_rv.clone() + initial_state = nuts.new_state(y_vv, logprob_fn) + + step_size = at.as_tensor(1e-2) + inverse_mass_matrix = at.as_tensor(1.0) + ( + next_state, + potential_energy, + potential_energy_grad, + acceptance_prob, + num_doublings, + is_turning, + is_diverging, + ), updates = kernel(*initial_state, step_size, inverse_mass_matrix) + + next_step_fn = aesara.function([y_vv], next_state, updates=updates, mode="NUMBA") + + # TODO: Assert something + next_step_fn(Y_rv.eval())