Skip to content

Commit

Permalink
Abit of refactoring and truncation of simple events in sum units is a…
Browse files Browse the repository at this point in the history
…lmost done
  • Loading branch information
tomsch420 committed Aug 1, 2024
1 parent 0585055 commit 8797527
Show file tree
Hide file tree
Showing 7 changed files with 197 additions and 40 deletions.
13 changes: 4 additions & 9 deletions src/probabilistic_model/learning/torch/input_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from .pc import InputLayer, AnnotatedLayer, SumLayer
from ...probabilistic_circuit.probabilistic_circuit import ProbabilisticCircuitMixin
from ...utils import interval_as_array
from ...utils import interval_as_array, remove_rows_and_cols_where_all


class ContinuousLayer(InputLayer, ABC):
Expand Down Expand Up @@ -105,14 +105,9 @@ def log_conditional_from_interval(self, interval: Interval) -> Tuple[SumLayer, t
# calculate log probabilities of the entire interval
log_probabilities = stacked_log_probabilities.logsumexp(dim=0) # shape: (#nodes, 1)

stacked_log_probabilities.squeeze_()

# get the rows and columns that are not entirely -inf
valid_rows = (stacked_log_probabilities > -torch.inf).any(dim=1)
valid_cols = (stacked_log_probabilities > -torch.inf).any(dim=0)

# remove rows and cols that are entirely -inf
valid_log_probabilities = stacked_log_probabilities[valid_rows][:, valid_cols]
# remove rows and columns where all elements are -inf
stacked_log_probabilities.squeeze_(-1)
valid_log_probabilities = remove_rows_and_cols_where_all(stacked_log_probabilities, -torch.inf)

# create sparse log weights
log_weights = valid_log_probabilities.T.exp().to_sparse_coo()
Expand Down
71 changes: 46 additions & 25 deletions src/probabilistic_model/learning/torch/pc.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from ...probabilistic_circuit.probabilistic_circuit import ProbabilisticCircuit, \
ProbabilisticCircuitMixin, SumUnit, ProductUnit
from ...probabilistic_model import ProbabilisticModel
from ...utils import (remove_rows_and_cols_where_all, add_sparse_edges_dense_child_tensor_inplace,
sparse_remove_rows_and_cols_where_all)


def inverse_class_of(clazz: Type[ProbabilisticCircuitMixin]) -> Type[Layer]:
Expand Down Expand Up @@ -161,18 +163,13 @@ def create_layer_from_nodes_with_same_type_and_scope(cls, nodes: List[Probabilis
"""
raise NotImplementedError

def log_conditional(self, event: Event) -> Tuple[Optional[Layer], float]:

# skip trivial case
def log_conditional(self, event: Event) -> Tuple[Optional[Layer], torch.Tensor]:
if event.is_empty():
return None, -torch.inf

# if the event is easy, don't create a proxy node
elif len(event.simple_sets) == 1:
conditional, log_probability = self.log_conditional_of_simple_event(event.simple_sets[0])
return conditional, log_probability.item()

raise NotImplementedError
return self.impossible_condition_result
if len(event.simple_sets) == 1:
return self.log_conditional_of_simple_event(event.simple_sets[0])
else:
return self.log_conditional_of_composite_event(event)

@abstractmethod
def log_conditional_of_simple_event(self, event: SimpleEvent) -> Tuple[Optional[Layer], torch.Tensor]:
Expand Down Expand Up @@ -248,14 +245,6 @@ def log_conditional_of_composite_event(self, event: Event):
resulting_layer = SumLayer([merged_layer], [log_weights])
return resulting_layer, log_probabilities

def log_conditional(self, event: Event) -> Tuple[Optional[Layer], torch.Tensor]:
if event.is_empty():
return self.impossible_condition_result
if len(event.simple_sets) == 1:
return self.log_conditional_of_simple_event(event.simple_sets[0])
else:
return self.log_conditional_of_composite_event(event)


class InputLayer(Layer, ABC):
"""
Expand Down Expand Up @@ -365,9 +354,6 @@ def number_of_nodes(self) -> int:
def support_per_node(self) -> List[Event]:
pass

def log_conditional_of_simple_event(self, event: SimpleEvent) -> Tuple[Optional[Self], torch.Tensor]:
pass

def log_likelihood(self, x: torch.Tensor) -> torch.Tensor:
result = torch.zeros(len(x), self.number_of_nodes)

Expand Down Expand Up @@ -399,6 +385,41 @@ def log_likelihood(self, x: torch.Tensor) -> torch.Tensor:

return torch.log(result) - self.log_normalization_constants

def log_conditional_of_simple_event(self, event: SimpleEvent) -> Tuple[Optional[Layer], torch.Tensor]:

conditional_child_layers = []
conditional_log_weights = []

probabilities = torch.zeros(self.number_of_nodes, 1)

for log_weights, child_layer in self.log_weighted_child_layers:
# get the conditional of the child layer, log prob shape: (#child_nodes, 1)

conditional, child_log_prob = child_layer.log_conditional_of_simple_event(event)

if conditional is None:
continue

if log_weights.is_sparse:
log_weights = log_weights.clone().coalesce().double() # shape: (#nodes, #child_nodes)
add_sparse_edges_dense_child_tensor_inplace(log_weights, child_log_prob)
probabilities += log_weights.sum(1).unsqueeze(-1)

sparse_remove_rows_and_cols_where_all(log_weights, -torch.inf)
print("-" * 80)
else:
raise NotImplementedError("Only sparse weights are supported for conditioning.")

conditional_child_layers.append(conditional)
conditional_log_weights.append(log_weights)


if len(conditional_child_layers) == 0:
return self.impossible_condition_result

resulting_layer = SumLayer(conditional_child_layers, conditional_log_weights)
return resulting_layer, probabilities.log()

@classmethod
def create_layer_from_nodes_with_same_type_and_scope(cls, nodes: List[SumUnit],
child_layers: List[AnnotatedLayer]) -> \
Expand Down Expand Up @@ -450,9 +471,6 @@ def probability_of_simple_event(self, event: SimpleEvent) -> torch.Tensor:
def log_mode(self) -> Tuple[Event, float]:
pass

def log_conditional(self, event: Event) -> Tuple[Optional[Union[ProbabilisticModel, Self]], float]:
pass

def sample(self, amount: int) -> torch.Tensor:
pass

Expand Down Expand Up @@ -485,6 +503,9 @@ class ProductLayer(InnerLayer):
child layer. Nodes in the child layer can be mapped to by multiple nodes in this layer.
"""

def merge_with(self, others: List[Self]):
pass

def __init__(self, child_layers: List[Layer], edges: List[torch.Tensor]):
"""
Initialize the product layer.
Expand Down
2 changes: 1 addition & 1 deletion src/probabilistic_model/learning/torch/uniform_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def log_conditional_from_simple_interval(self, interval: SimpleInterval) -> Tupl
non_empty_intervals = [simple_interval_to_open_tensor(intersection) for intersection in intersections
if not intersection.is_empty()]
if len(non_empty_intervals) == 0:
return None, probabilities
return self.impossible_condition_result
new_intervals = torch.stack(non_empty_intervals)
return self.__class__(self.variable, new_intervals), probabilities

Expand Down
98 changes: 98 additions & 0 deletions src/probabilistic_model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,101 @@ def simple_interval_to_open_tensor(interval: SimpleInterval) -> torch.Tensor:
if interval.right == Bound.CLOSED:
upper = nextafter(upper, upper + 1)
return torch.tensor([lower, upper])


def remove_rows_and_cols_where_all(tensor: torch.Tensor, value: float) -> torch.Tensor:
"""
Remove rows and columns from a tensor where all elements are equal to a given value.
:param tensor: The tensor to remove rows and columns from.
:param value: The value to remove.
:return: The tensor without the rows and columns.
Example::
>>> t = torch.tensor([[1, 0, 3], [0, 0, 0], [7, 0, 9]])
>>> remove_rows_and_cols_where_all(t, 0)
tensor([[1, 3],
[7, 9]])
"""

# get the rows and columns that are not entirely -inf
valid_rows = (tensor != value).any(dim=1)
valid_cols = (tensor != value).any(dim=0)

# remove rows and cols that are entirely -inf
valid_tensor = tensor[valid_rows][:, valid_cols]
return valid_tensor


def sparse_remove_rows_and_cols_where_all(tensor: torch.Tensor, value: float) -> torch.Tensor:
# get indices of values where all elements are equal to a given value
values = tensor.values()
valid_elements = (values != value)
valid_indices = tensor.indices()[valid_elements]
print(values)
print(valid_elements)
result = torch.sparse_coo_tensor(valid_indices, values[valid_elements]).coalesce()
return result


def shrink_index_tensor(index_tensor: torch.Tensor) -> torch.Tensor:
"""
Shrink a 2D index tensor to only contain successive indices.
The tensor has shape (#indices, 2).
Example::
>>> shrink_index_tensor(torch.tensor([[0, 3], [1, 0], [4, 1]]))
tensor([[0, 2], [1, 0], [2, 1]])
:param index_tensor: The index tensor to shrink.
:return: The shrunken index tensor.
"""

result = index_tensor.clone()

for dim in range(2):
unique_indices = torch.unique(index_tensor[:, dim], sorted=True)

for new_index, unique_index in zip(range(len(unique_indices)), unique_indices):
result[result[:, dim] == unique_index, dim] = new_index

# map the old indices to the new indices
return result

def sparse_dense_mul_inplace(sparse: torch.Tensor, dense: torch.Tensor):
"""
Multiply a sparse tensor with a dense tensor element-wise in-place of the sparse tensor.
:param sparse: The sparse tensor
:param dense: The dense tensor
:return: The result of the multiplication
"""
indices = sparse._indices()

# get values from relevant entries of dense matrix
dense_values_at_sparse_indices = dense[indices[0, :], indices[1, :]]

# multiply sparse values with dense values inplace
sparse.values().mul_(dense_values_at_sparse_indices)

def add_sparse_edges_dense_child_tensor_inplace(edges: torch.Tensor, dense_child_tensor: torch.Tensor):
"""
Add a dense tensor to a sparse tensor at the positions specified by the edge tensor.
This method is used when a weighted sum of the child tensor is necessary.
The edges specify how to weight the child tensor and the dense tensor is the child tensor.
The result is stored in the sparse tensor.
:param edges: The edge tensor of shape (#edges, n).
:param dense_child_tensor: The dense tensor of shape (n, 1).
:return: The result of the addition
"""
# get indices of the sparse tensor
indices = edges._indices()

# get values from relevant entries of dense matrix
dense_values_at_sparse_indices = dense_child_tensor[indices[1]].squeeze()

# add sparse values with dense values inplace
edges.values().add_(dense_values_at_sparse_indices)
23 changes: 20 additions & 3 deletions test/test_torch/test_inner_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from random_events.variable import Continuous
from torch.testing import assert_close

from probabilistic_model.learning.torch.uniform_layer import UniformLayer
from probabilistic_model.learning.torch.pc import SumLayer, ProductLayer
from probabilistic_model.learning.torch.uniform_layer import UniformLayer
from probabilistic_model.probabilistic_circuit.distributions import UniformDistribution
from probabilistic_model.probabilistic_circuit.probabilistic_circuit import SumUnit, ProductUnit

Expand Down Expand Up @@ -64,6 +64,24 @@ def test_probability(self):
assert_almost_equal([p_by_hand_1, p_by_hand_2], prob[:, 0].tolist())


class SparseSumUnitTestCase(unittest.TestCase):
x = Continuous("x")
p1_x = UniformLayer(x, torch.Tensor([[0, 1]]))
p2_x = UniformLayer(x, torch.Tensor([[1, 3], [1, 1.5]]))
s1 = SumLayer([p1_x, p2_x],
log_weights=[torch.tensor([[math.log(2)], [1]]).to_sparse_coo(),
torch.tensor([[0, 0], [1, 1]]).to_sparse_coo()])

def test_conditional(self):
event = SimpleEvent({self.x: closed(2., 3.)}).as_composite_set()
c, lp = self.s1.log_conditional(event)
c.validate()
print(c.log_weights)
self.assertEqual(c.number_of_nodes, 1)




class ProductTestCase(unittest.TestCase):
x = Continuous("x")
y = Continuous("y")
Expand Down Expand Up @@ -95,8 +113,7 @@ def test_log_likelihood(self):
assert_almost_equal(ll_p2_by_hand.tolist(), ll[:, 1].tolist())

def test_probability(self):
event = SimpleEvent({self.x: closed(0.5, 2.5) | closed(3, 5),
self.y: closed(0.5, 2.5) | closed(3, 5)})
event = SimpleEvent({self.x: closed(0.5, 2.5) | closed(3, 5), self.y: closed(0.5, 2.5) | closed(3, 5)})
prob = self.product.probability_of_simple_event(event)
self.assertEqual(prob.shape, (2, 1))
p_by_hand_1 = self.product_1.probability_of_simple_event(event)
Expand Down
2 changes: 1 addition & 1 deletion test/test_torch/test_uniform.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def test_conditional_single_truncation(self):
layer, ll = self.p_x.log_conditional_of_simple_event(event)
self.assertEqual(layer.number_of_nodes, 2)
assert_close(layer.interval, torch.tensor([[0.5, 1], [1, 2.5]]))
assert_close(torch.tensor([0.5, 0.75]).reshape(-1, 1).double(), ll)
assert_close(torch.tensor([0.5, 0.75]).reshape(-1, 1).double().log(), ll)

def test_conditional_multiple_truncation(self):
event = closed(-1, 0.5) | closed(0.7, 0.8) | closed(2., 3.) | closed(3.5, 4.)
Expand Down
28 changes: 27 additions & 1 deletion test/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import unittest

from torch.testing import assert_close

from probabilistic_model.distributions.distributions import SymbolicDistribution
import probabilistic_model.probabilistic_circuit
import probabilistic_model.probabilistic_circuit.distributions
from probabilistic_model.utils import type_converter

import torch
from probabilistic_model.utils import sparse_dense_mul_inplace, add_sparse_edges_dense_child_tensor_inplace, shrink_index_tensor

class TypeConversionTestCase(unittest.TestCase):

Expand All @@ -13,5 +16,28 @@ def test_univariate_distribution_type_converter(self):
self.assertEqual(result, probabilistic_model.probabilistic_circuit.distributions.SymbolicDistribution)


class TorchUtilsTestCase(unittest.TestCase):

def test_sparse_dense_mul_inplace(self):
indices = torch.tensor([[0, 1], [1, 0]])
values = torch.tensor([2., 3.])
sparse = torch.sparse_coo_tensor(indices, values, ).coalesce()
dense = torch.tensor([[1., 2.], [3, 4]])
sparse_dense_mul_inplace(sparse, dense)
assert_close(sparse.values(), torch.tensor([4., 9.]))

def test_add_sparse_edges_dense_child_tensor_inplace(self):
indices = torch.tensor([[0, 1], [1, 0], [1, 1]]).T
values = torch.tensor([2., 3., 4.])
sparse = torch.sparse_coo_tensor(indices, values, ).coalesce()
dense = torch.tensor([1., 2.]).reshape(-1, 1)
add_sparse_edges_dense_child_tensor_inplace(sparse, dense)
assert_close(sparse.values(), torch.tensor([4., 4., 6.]))

def test_shrink_index_tensor(self):
shrank = shrink_index_tensor(torch.tensor([[0, 3], [1, 0], [4, 1]]))
assert_close(torch.tensor([[0, 2], [1, 0], [2, 1]]), shrank)


if __name__ == '__main__':
unittest.main()

0 comments on commit 8797527

Please sign in to comment.