Skip to content

Commit

Permalink
Conditioning for sum units
Browse files Browse the repository at this point in the history
  • Loading branch information
tomsch420 committed Oct 15, 2024
1 parent 16f518d commit b27e30b
Show file tree
Hide file tree
Showing 7 changed files with 158 additions and 12 deletions.
2 changes: 1 addition & 1 deletion scripts/jpt_speed_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def timed_jax_method():

# times_nx, times_jax = eval_performance(nx_model.log_likelihood, (data, ), compiled_ll_jax, (data_jax, ), 20, 2)
# times_nx, times_jax = eval_performance(prob_nx, event, prob_jax, event, 15, 10)
times_nx, times_jax = eval_performance(nx_model.sample, (10000, ), jax_model.sample, (10000,), 15, 5)
times_nx, times_jax = eval_performance(nx_model.sample, (10000, ), jax_model.sample, (10000,), 1, 0)

time_jax = np.mean(times_jax), np.std(times_jax)
time_nx = np.mean(times_nx), np.std(times_nx)
Expand Down
59 changes: 58 additions & 1 deletion src/probabilistic_model/probabilistic_circuit/jax/inner_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
from sortedcontainers import SortedSet
from typing_extensions import List, Iterator, Tuple, Union, Type, Dict, Any, Self, Optional

from .utils import copy_bcoo, sample_from_sparse_probabilities_csc
from . import embed_sparse_array_in_nan_array
from .utils import copy_bcoo, sample_from_sparse_probabilities_csc, sparse_remove_rows_and_cols_where_all
from ..nx.probabilistic_circuit import SumUnit, ProductUnit, ProbabilisticCircuitMixin
from ...utils import timeit_print

Expand Down Expand Up @@ -472,6 +473,62 @@ def node_to_child_frequency_map(self, frequencies: np.array):
csr = coo_matrix((clw.data, clw.indices.T), shape=clw.shape).tocsr(copy=False)
return sample_from_sparse_probabilities_csc(csr, frequencies)

def log_conditional_of_simple_event(self, event: SimpleEvent, ) -> Tuple[Optional[Self], jax.Array]:
conditional_child_layers = []
conditional_log_weights = []

probabilities = jnp.zeros(self.number_of_nodes, dtype=jnp.float32)

for log_weights, child_layer in self.log_weighted_child_layers:
# get the conditional of the child layer
conditional, child_log_prob = child_layer.log_conditional_of_simple_event(event)
if conditional is None:
continue

# clone weights
log_weights = copy_bcoo(log_weights)

# calculate the weighted sum of the child log probabilities
log_weights.data += child_log_prob[log_weights.indices[:, 1]]

# skip if this layer is not connected to anything anymore
if jnp.all(log_weights.data == -jnp.inf):
continue

log_weights.data = jnp.exp(log_weights.data)

# calculate the probabilities of the child nodes in total
current_probabilities = log_weights.sum(1).todense()
probabilities += current_probabilities

log_weights.data = jnp.log(log_weights.data)

conditional_child_layers.append(conditional)
conditional_log_weights.append(log_weights)

if len(conditional_child_layers) == 0:
return self.impossible_condition_result

log_probabilities = jnp.log(probabilities)

concatenated_log_weights = bcoo_concatenate(conditional_log_weights, dimension=1).sort_indices()
# remove rows and columns where all weights are -inf
cleaned_log_weights = sparse_remove_rows_and_cols_where_all(concatenated_log_weights, -jnp.inf)

# normalize the weights
z = cleaned_log_weights.sum(1).todense()
cleaned_log_weights.data -= z[cleaned_log_weights.indices[:, 0]]

# slice the weights for each child layer
log_weight_slices = jnp.array([0] + [ccl.number_of_nodes for ccl in conditional_child_layers])
log_weight_slices = jnp.cumsum(log_weight_slices)
conditional_log_weights = [cleaned_log_weights[:, log_weight_slices[i]:log_weight_slices[i + 1]].sort_indices()
for i in range(len(conditional_child_layers))]

resulting_layer = SumLayer(conditional_child_layers, conditional_log_weights)
return resulting_layer, (log_probabilities - self.log_normalization_constants)


def __deepcopy__(self):
child_layers = [child_layer.__deepcopy__() for child_layer in self.child_layers]
log_weights = [copy_bcoo(log_weight) for log_weight in self.log_weights]
Expand Down
12 changes: 6 additions & 6 deletions src/probabilistic_model/probabilistic_circuit/jax/input_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,16 +248,16 @@ def moment_of_nodes(self, order: jax.Array, center: jax.Array):
return result.reshape(-1, 1)

def log_conditional_from_simple_interval(self, interval: SimpleInterval) -> Tuple[Self, jax.Array]:
probabilities = jnp.log(self.probability_of_simple_interval(interval))
log_probs = jnp.log(self.probability_of_simple_interval(interval))

valid_probabilities = probabilities > -jnp.inf
valid_log_probs = log_probs > -jnp.inf

if not valid_probabilities.any():
if not valid_log_probs.any():
return self.impossible_condition_result

result = self.__class__(self.variable, self.location[valid_probabilities],
self.density_cap[valid_probabilities])
return result, probabilities
result = self.__class__(self.variable, self.location[valid_log_probs],
self.density_cap[valid_log_probs])
return result, log_probs

def to_json(self) -> Dict[str, Any]:
result = super().to_json()
Expand Down
59 changes: 58 additions & 1 deletion src/probabilistic_model/probabilistic_circuit/jax/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import jax.numpy as jnp
import numpy as np
from jax.experimental.sparse import BCOO, BCSR
from jax.experimental.sparse import BCOO, BCSR, CSC, CSR
from random_events.interval import SimpleInterval, Bound
import jax
from scipy.sparse import csr_matrix, csr_array, csc_array
Expand Down Expand Up @@ -181,3 +181,60 @@ def remove_rows_and_cols_where_all(array: jax.Array, value: float) -> jax.Array:
# remove rows and cols that are entirely -inf
valid = array[valid_rows][:, valid_cols]
return valid

def shrink_index_array(index_array: jax.Array) -> jax.Array:
"""
Shrink an index array to only contain successive indices.
Example::
>>> shrink_index_array(jnp.array([[0, 3], [1, 0], [4, 1]]))
[[0 2]
[1 0]
[2 1]]
:param index_array: The index tensor to shrink.
:return: The shrunken index tensor.
"""
result = index_array.copy()

for dim in range(index_array.shape[1]):
unique_indices = jnp.unique(index_array[:, dim])

# map the old indices to the new indices
for new_index, unique_index in zip(range(len(unique_indices)), unique_indices):
result = result.at[result[:, dim] == unique_index, dim].set(new_index)


return result


def sparse_remove_rows_and_cols_where_all(array: BCOO, value: float) -> BCOO:
"""
Remove rows and columns from a sparse tensor where all elements are equal to a given value.
Example::
>>> array = BCOO.fromdense(jnp.array([[1, 0, 3], [0, 0, 0], [7, 0, 9]]))
>>> sparse_remove_rows_and_cols_where_all(array, 0).todense()
[[1 3]
[7 9]]
:param array: The sparse tensor to remove rows and columns from.
:param value: The value to remove.
:return: The tensor without the unnecessary rows and columns.
"""
# get indices of values where all elements are equal to a given value
values = array.data
valid_elements = (values != value)

# filter indices by valid elements
valid_indices = array.indices[valid_elements]

# shrink indices
valid_indices = shrink_index_array(valid_indices)

new_shape = jnp.max(valid_indices, axis=0) + 1

# construct result tensor
result = BCOO((values[valid_elements], valid_indices), shape=new_shape, indices_sorted=array.indices_sorted,
unique_indices=array.unique_indices)
return result
4 changes: 2 additions & 2 deletions src/probabilistic_model/probabilistic_circuit/torch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ def shrink_index_tensor(index_tensor: torch.Tensor) -> torch.Tensor:
Example::
>>> shrink_index_tensor(torch.tensor([[0, 3], [1, 0], [4, 1]]))
tensor([[0, 2], [1, 0], [2, 1]])
>>> shrink_index_tensor(jnp.array([[0, 3], [1, 0], [4, 1]]))
:param index_tensor: The index tensor to shrink.
:return: The shrunken index tensor.
"""
Expand Down
19 changes: 19 additions & 0 deletions test/test_jax/test_sum_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,25 @@ def test_probability(self):
result = jnp.array([0.7, 0.5], dtype=jnp.float32)
self.assertTrue(jnp.allclose(result, prob))

def test_conditional(self):
event = SimpleEvent({self.x: closed(0.5, 1.5)})
c, lp = self.sum_layer.log_conditional_of_simple_event(event)
c.validate()
self.assertEqual(c.number_of_nodes, 1)
self.assertEqual(len(c.child_layers), 1)
self.assertEqual(c.child_layers[0].number_of_nodes, 1)
self.assertTrue(jnp.allclose(c.log_weights[0].todense(), jnp.array([[0.]])))
self.assertTrue(jnp.allclose(lp, jnp.log(jnp.array([0.1, 0.]))))

def test_conditional_2(self):
event = SimpleEvent({self.x: closed(1.5, 4.5)})
c, lp = self.sum_layer.log_conditional_of_simple_event(event)
c.validate()
self.assertEqual(c.number_of_nodes, 2)
self.assertEqual(len(c.child_layers), 2)
self.assertEqual(c.child_layers[0].number_of_nodes, 1)
self.assertEqual(c.child_layers[1].number_of_nodes, 2)


class PCSumUnitTestCase(unittest.TestCase):
x: Continuous = Continuous("x")
Expand Down
15 changes: 14 additions & 1 deletion test/test_jax/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from random_events.interval import SimpleInterval
from scipy.sparse import coo_array

from probabilistic_model.probabilistic_circuit.jax import create_bcsr_indices_from_row_lengths
from probabilistic_model.probabilistic_circuit.jax import create_bcsr_indices_from_row_lengths, shrink_index_array, \
sparse_remove_rows_and_cols_where_all
from probabilistic_model.probabilistic_circuit.jax.utils import copy_bcoo, simple_interval_to_open_array, \
create_bcoo_indices_from_row_lengths, sample_from_sparse_probabilities_csc

Expand Down Expand Up @@ -53,6 +54,18 @@ def test_sample_from_sparse_probabilities_csc(self):
self.assertTrue(np.all(amounts == amount))
self.assertTrue(np.all(samples.data <= 3))

def test_shrink_index_array(self):
index_array = jnp.array([[0, 3], [1, 0], [4, 1]])
new_index_tensor = shrink_index_array(index_array)
result = jnp.array([[0, 2], [1, 0], [2, 1]])
self.assertTrue(jnp.allclose(new_index_tensor, result))

def test_sparse_remove_rows_and_cols_where_all(self):
array = BCOO.fromdense(jnp.array([[1, 0, 3], [0, 0, 0], [7, 0, 9]]))
result = jnp.array([[1, 3], [7, 9]])
new_array = sparse_remove_rows_and_cols_where_all(array, 0)
self.assertTrue(jnp.allclose(new_array.todense(), result))


class IntervalConversionTestCase(unittest.TestCase):

Expand Down

0 comments on commit b27e30b

Please sign in to comment.