From 79a13aedbb3eb1a6cc5e49b0f1572091f7c0024e Mon Sep 17 00:00:00 2001
From: Tom Schierenbeck
Date: Mon, 7 Oct 2024 08:43:41 +0200
Subject: [PATCH] Sum Layer CDF

---
 .../probabilistic_circuit/jax/inner_layer.py | 53 ++++++++++++++++++--
 .../probabilistic_circuit/jax/input_layer.py |  5 --
 test/test_jax/test_sum_layer.py              | 16 ++++++-
 3 files changed, 64 insertions(+), 10 deletions(-)

diff --git a/src/probabilistic_model/probabilistic_circuit/jax/inner_layer.py b/src/probabilistic_model/probabilistic_circuit/jax/inner_layer.py
index 8fd062e..7cde5ca 100644
--- a/src/probabilistic_model/probabilistic_circuit/jax/inner_layer.py
+++ b/src/probabilistic_model/probabilistic_circuit/jax/inner_layer.py
@@ -46,15 +46,33 @@ def log_likelihood_of_nodes_single(self, x: jnp.array) -> jnp.array:
         """
         Calculate the log-likelihood of the distribution.
 
-        .. Note::
-            The shape of the log likelihood depends on the number of samples and nodes.
-            The shape of the result is (#samples, #nodes).
+        :param x: The input vector.
+        :return: The log-likelihood of every node in the layer for x.
         """
         raise NotImplementedError
 
     def log_likelihood_of_nodes(self, x: jnp.array) -> jnp.array:
+        """
+        Vectorized version of :meth:`log_likelihood_of_nodes_single`.
+        """
         return jax.vmap(self.log_likelihood_of_nodes_single)(x)
 
+    def cdf_of_nodes_single(self, x: jnp.array) -> jnp.array:
+        """
+        Calculate the cumulative distribution function of the distribution, if applicable.
+
+        :param x: The input vector.
+        :return: The cumulative distribution function of every node in the layer for x.
+        """
+        raise NotImplementedError
+
+    def cdf_of_nodes(self, x: jnp.array) -> jnp.array:
+        """
+        Vectorized version of :meth:`cdf_of_nodes_single`.
+        """
+        return jax.vmap(self.cdf_of_nodes_single)(x)
+
     def validate(self):
         """
         Validate the parameters and their layouts.
@@ -139,6 +157,8 @@ def create_layer_from_nodes_with_same_type_and_scope(cls, nodes: List[Probabilis
     def partition(self) -> Tuple[Any, Any]:
         """
         Partition the layer into the parameters and the static structure.
+
+        :return: A tuple containing the parameters and the static structure as pytrees.
""" return eqx.partition(self, eqx.is_inexact_array) @@ -320,7 +339,7 @@ def normalized_weights(self): return result def log_likelihood_of_nodes_single(self, x: jax.Array) -> jax.Array: - result = jnp.zeros(self.number_of_nodes) + result = jnp.zeros(self.number_of_nodes, dtype=jnp.float32) for log_weights, child_layer in self.log_weighted_child_layers: # get the log likelihoods of the child nodes @@ -341,6 +360,31 @@ def log_likelihood_of_nodes_single(self, x: jax.Array) -> jax.Array: return jnp.where(result > 0, jnp.log(result) - self.log_normalization_constants, -jnp.inf) + def cdf_of_nodes_single(self, x: jnp.array) -> jnp.array: + result = jnp.zeros(self.number_of_nodes, dtype=jnp.float32) + + for log_weights, child_layer in self.log_weighted_child_layers: + # get the cdf of the child nodes + child_layer_cdf = child_layer.cdf_of_nodes_single(x) + + # weight the cdf of the child nodes by the weight for each node of this layer + cloned_log_weights = copy_bcoo(log_weights) # clone the weights + + # multiply the weights with the child layer cdf + cloned_log_weights.data = jnp.exp(cloned_log_weights.data) # exponent weights + cloned_log_weights.data *= child_layer_cdf[cloned_log_weights.indices[:, 1]] + + # sum the weights for each node + ll = cloned_log_weights.sum(1).todense() + + # sum the child layer result + result += ll + + # normalize the result + normalization_constants = jnp.exp(self.log_normalization_constants) + return result / normalization_constants + + def sample_from_frequencies(self, frequencies: jax.Array, key: jax.random.PRNGKey) -> BCOO: # calculate the probabilities for the latent variable interpretation of this layer diff --git a/src/probabilistic_model/probabilistic_circuit/jax/input_layer.py b/src/probabilistic_model/probabilistic_circuit/jax/input_layer.py index 6caa92b..862483b 100644 --- a/src/probabilistic_model/probabilistic_circuit/jax/input_layer.py +++ b/src/probabilistic_model/probabilistic_circuit/jax/input_layer.py @@ -19,11 +19,6 @@ class ContinuousLayer(InputLayer, ABC): Abstract base class for continuous univariate input units. 
""" - def cdf_of_nodes_single(self, x: jnp.array) -> jnp.array: - raise NotImplementedError - - def cdf_of_nodes(self, x: jnp.array) -> jnp.array: - return jax.vmap(self.cdf_of_nodes_single)(x) class ContinuousLayerWithFiniteSupport(ContinuousLayer, ABC): diff --git a/test/test_jax/test_sum_layer.py b/test/test_jax/test_sum_layer.py index 10e0a29..d913194 100644 --- a/test/test_jax/test_sum_layer.py +++ b/test/test_jax/test_sum_layer.py @@ -3,6 +3,7 @@ from jax.experimental.sparse import BCOO from random_events.variable import Continuous import jax.numpy as jnp +from triton.language import dtype from probabilistic_model.probabilistic_circuit.jax import in_bound_elements_from_sparse_slice from probabilistic_model.probabilistic_circuit.jax.input_layer import DiracDeltaLayer @@ -69,4 +70,18 @@ def test_sampling(self): _, sample_row = in_bound_elements_from_sparse_slice(sample_row) self.assertEqual(len(sample_row), frequencies[index]) likelihood = self.sum_layer.log_likelihood_of_nodes(sample_row) - self.assertTrue(all(likelihood[:, index] > -jnp.inf)) \ No newline at end of file + self.assertTrue(all(likelihood[:, index] > -jnp.inf)) + + def test_cdf(self): + data = jnp.arange(7, dtype=jnp.float32).reshape(-1, 1) - 0.5 + cdf = self.sum_layer.cdf_of_nodes(data) + self.assertEqual(cdf.shape, (7, 2)) + result = jnp.array([[0, 0], # -0.5 + [0, 0.4], # 0.5 + [0.1, 0.4], # 1.5 + [0.3, 0.7], # 2.5 + [0.6, 0.7], # 3.5 + [0.6, 0.8], # 4.5 + [1, 1], # 5.5 + ], dtype=jnp.float32) + self.assertTrue(jnp.allclose(cdf, result))