Skip to content

Commit

Permalink
Halfway done with conditioniong for product layers
Browse files Browse the repository at this point in the history
  • Loading branch information
tomsch420 committed Oct 15, 2024
1 parent b27e30b commit fa9715e
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 3 deletions.
112 changes: 110 additions & 2 deletions src/probabilistic_model/probabilistic_circuit/jax/inner_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,16 @@
import numpy as np
import tqdm
from jax import numpy as jnp, tree_flatten
from jax.experimental.sparse import BCOO, bcoo_concatenate
from jax.experimental.sparse import BCOO, bcoo_concatenate, bcoo_slice
from jaxtyping import Int
from random_events.product_algebra import SimpleEvent
from random_events.utils import recursive_subclasses, SubclassJSONSerializer
from scipy.sparse import coo_matrix, coo_array
from sortedcontainers import SortedSet
from triton.language import dtype
from typing_extensions import List, Iterator, Tuple, Union, Type, Dict, Any, Self, Optional

from . import embed_sparse_array_in_nan_array
from . import embed_sparse_array_in_nan_array, shrink_index_array
from .utils import copy_bcoo, sample_from_sparse_probabilities_csc, sparse_remove_rows_and_cols_where_all
from ..nx.probabilistic_circuit import SumUnit, ProductUnit, ProbabilisticCircuitMixin
from ...utils import timeit_print
Expand Down Expand Up @@ -218,6 +219,15 @@ def merge_with(self, others: List[Self]) -> Self:
"""
raise NotImplementedError

def remove_nodes(self, remove_mask: jax.Array) -> Self:
"""
Remove nodes from the layer.
:param remove_mask: A boolean mask of the nodes to remove.
:return: The layer with the nodes removed.
"""
raise NotImplementedError


class InnerLayer(Layer, ABC):
"""
Expand Down Expand Up @@ -722,6 +732,104 @@ def to_json(self) -> Dict[str, Any]:
result["edges"] = (self.edges.data.tolist(), self.edges.indices.tolist(), self.edges.shape)
return result

def log_conditional_of_simple_event(self, event: SimpleEvent, ) -> Tuple[Optional[Self], jax.Array]:

# initialize the conditional child layers and the log probabilities
log_probabilities = jnp.zeros(self.number_of_nodes, dtype=jnp.float32)
conditional_child_layers = []
remapped_edges = []

# for edge bundle and child layer
for index, (edges, child_layer) in enumerate(zip(self.edges, self.child_layers)):
edges: BCOO
edges = edges.sum_duplicates(remove_zeros=False)

# condition the child layer
conditional, child_log_prob = child_layer.log_conditional_of_simple_event(event)

# if it is entirely impossible, this layer also is
if conditional is None:
continue

# update the log probabilities and child layers
log_probabilities = log_probabilities.at[edges.indices[:, 0]].add(child_log_prob[edges.data])
conditional_child_layers.append(conditional)

# create the remapping of the node indices. nan indicates the node got deleted
# enumerate the indices of the conditional child layer nodes
new_node_indices = jnp.arange(conditional.number_of_nodes)

# initialize the remapping of the child layer node indices
layer_remap = jnp.full((child_layer.number_of_nodes,), jnp.nan, dtype=jnp.float32)
layer_remap = layer_remap.at[child_log_prob > -jnp.inf].set(new_node_indices)

# update the edges
remapped_child_edges = layer_remap[edges.data]
valid_edges = ~jnp.isnan(remapped_child_edges)

# create new indices for the edges
new_indices = edges.indices[valid_edges]
new_indices = jnp.concatenate([new_indices, jnp.zeros((len(new_indices), 1), dtype=jnp.int32)],
axis=1)

new_edges = BCOO((remapped_child_edges[valid_edges].astype(jnp.int32),
new_indices),
shape = (self.number_of_nodes, 1), indices_sorted=True,
unique_indices=True)
remapped_edges.append(new_edges)

remapped_edges = bcoo_concatenate(remapped_edges, dimension=1).sort_indices()

print(embed_sparse_array_in_nan_array(remapped_edges))

# get nodes that should be removed as boolean mask
remove_mask = log_probabilities == -jnp.inf # shape (#nodes, )
keep_mask = ~remove_mask

print(remapped_edges.indices, remapped_edges.data)
# remove the nodes that have -inf log probabilities from remapped_edges
remapped_edges = coo_array((remapped_edges.data, remapped_edges.indices.T), shape=remapped_edges.shape).tocsr()
remapped_edges = remapped_edges[keep_mask].tocoo()

remapped_edges = BCOO((remapped_edges.data, jnp.stack((remapped_edges.row, remapped_edges.col)).T),
shape=remapped_edges.shape, indices_sorted=True, unique_indices=True)

# construct result and clean it up
result = self.__class__(conditional_child_layers, remapped_edges)
result = result.clean_up_orphans()
return result, log_probabilities

def clean_up_orphans(self):
"""
Clean up the layer by removing orphans in the child layers.
"""

new_child_layers = []

for index, (edges, child_layer) in enumerate(zip(self.edges, self.child_layers)):
edges: BCOO
edges = edges.sum_duplicates(remove_zeros=False)
# mask rather nodes have parent edges or not
orphans = jnp.ones(child_layer.number_of_nodes, dtype=jnp.bool)

# mark nodes that have parents with False
data = edges.data
if len(data) > 0:
orphans = orphans.at[data].set(False)

# if orphans exist
if orphans.any():
# remove them from the child layer
child_layer = child_layer.remove_nodes(orphans)
new_child_layers.append(child_layer)

# compress edges
shrunken_indices = shrink_index_array(self.edges.indices)
new_edges = BCOO((self.edges.data, shrunken_indices), shape=self.edges.shape, indices_sorted=True,
unique_indices=True)
return self.__class__(new_child_layers, new_edges)


@classmethod
def _from_json(cls, data: Dict[str, Any]) -> Self:
child_layer = [Layer.from_json(child_layer) for child_layer in data["child_layers"]]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,9 @@ def to_json(self) -> Dict[str, Any]:
def __deepcopy__(self):
return self.__class__(self.variables[0].item(), self.interval.copy())

def remove_nodes(self, remove_mask: jax.Array) -> Self:
return self.__class__(self.variable, self.interval[~remove_mask])


class DiracDeltaLayer(ContinuousLayer):

Expand Down Expand Up @@ -271,4 +274,7 @@ def _from_json(cls, data: Dict[str, Any]) -> Self:

def merge_with(self, others: List[Self]) -> Self:
return self.__class__(self.variable, jnp.concatenate([self.location] + [other.location for other in others]),
jnp.concatenate([self.density_cap] + [other.density_cap for other in others]))
jnp.concatenate([self.density_cap] + [other.density_cap for other in others]))

def remove_nodes(self, remove_mask: jax.Array) -> Self:
return self.__class__(self.variable, self.location[~remove_mask], self.density_cap[~remove_mask])
17 changes: 17 additions & 0 deletions test/test_jax/test_product_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
from probabilistic_model.probabilistic_circuit.jax.inner_layer import ProductLayer
from probabilistic_model.probabilistic_circuit.jax.probabilistic_circuit import ProbabilisticCircuit

import warnings
warnings.filterwarnings("ignore")

class DiracProductTestCase(unittest.TestCase):

Expand Down Expand Up @@ -75,6 +77,21 @@ def test_probability(self):
result = jnp.array([1, 0], dtype=jnp.float32)
self.assertTrue(jnp.allclose(prob, result))

def test_conditioning(self):

event = SimpleEvent({self.x: closed(-1, 1),
self.y: closed(4.5, 5.5),
self.z: closed(5.5, 6.5)})

conditional, log_prob = self.product_layer.log_conditional_of_simple_event(event)
conditional.validate()
self.assertTrue(jnp.allclose(log_prob, jnp.log(jnp.array([1., 0.]))))
self.assertEqual(conditional.number_of_nodes, 1)
self.assertEqual(len(conditional.child_layers), 3)
self.assertEqual(conditional.child_layers[0].number_of_nodes, 1)
self.assertEqual(conditional.child_layers[1].number_of_nodes, 1)
self.assertEqual(conditional.child_layers[2].number_of_nodes, 1)


class PCProductLayerTestCase(unittest.TestCase):

Expand Down

0 comments on commit fa9715e

Please sign in to comment.