Sampling from a product layer now works.
tomsch420 committed Aug 22, 2024
1 parent 561cf41 commit 46c961f
Showing 4 changed files with 89 additions and 52 deletions.
12 changes: 10 additions & 2 deletions src/probabilistic_model/learning/torch/input_layer.py
@@ -13,7 +13,7 @@

from .pc import InputLayer, AnnotatedLayer, SumLayer
from ...probabilistic_circuit.probabilistic_circuit import ProbabilisticCircuitMixin
from ...utils import interval_as_array, remove_rows_and_cols_where_all
from ...utils import interval_as_array, remove_rows_and_cols_where_all, create_sparse_tensor_indices_from_row_lengths


class ContinuousLayer(InputLayer, ABC):
@@ -200,7 +200,7 @@ def number_of_nodes(self) -> int:
return len(self.location)

def log_likelihood(self, x: torch.Tensor) -> torch.Tensor:
return torch.where(x == self.location, torch.log(self.density_cap), -torch.inf)
return torch.where(x == self.location, torch.log(self.density_cap), -torch.inf).reshape(1, -1)

@classmethod
def create_layer_from_nodes_with_same_type_and_scope(cls, nodes: List[ProbabilisticCircuitMixin],
@@ -241,5 +241,13 @@ def remove_nodes_inplace(self, remove_mask: torch.BoolTensor):
self.location = self.location[~remove_mask]
self.density_cap = self.density_cap[~remove_mask]

def sample_from_frequencies(self, frequencies: torch.Tensor) -> torch.Tensor:
max_frequency = max(frequencies)
result_indices = create_sparse_tensor_indices_from_row_lengths(frequencies)
values = self.location.repeat_interleave(frequencies)
result = torch.sparse_coo_tensor(result_indices, values, (self.number_of_nodes, max_frequency),
is_coalesced=True)
return result

def __deepcopy__(self):
return self.__class__(self.variable, self.location.clone(), self.density_cap.clone())
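
The new DiracDeltaLayer.sample_from_frequencies packs its samples into a sparse COO tensor of shape (number_of_nodes, max_frequency), where row i holds frequencies[i] copies of the i-th location. A minimal sketch of the index layout this relies on, assuming create_sparse_tensor_indices_from_row_lengths (defined in probabilistic_model.utils, not shown in this diff) builds row/column indices for such a ragged result:

import torch

def indices_from_row_lengths_sketch(row_lengths: torch.Tensor) -> torch.Tensor:
    # row index i repeated row_lengths[i] times
    rows = torch.arange(len(row_lengths)).repeat_interleave(row_lengths)
    # column indices restart at 0 within every row
    offsets = torch.cat((torch.zeros(1, dtype=torch.long), row_lengths.cumsum(0)[:-1]))
    columns = torch.arange(int(row_lengths.sum())) - offsets.repeat_interleave(row_lengths)
    return torch.stack((rows, columns))

# For frequencies = torch.tensor([2, 3]) this yields
# [[0, 0, 1, 1, 1],
#  [0, 1, 0, 1, 2]],
# so pairing these indices with location.repeat_interleave(frequencies) stores two samples
# for node 0 and three samples for node 1 in a (2, 3) sparse tensor.
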
38 changes: 35 additions & 3 deletions src/probabilistic_model/learning/torch/pc.py
@@ -10,11 +10,13 @@
import numpy as np
import torch
import torch.nn as nn
from numpy.ma.core import shape, argsort
from random_events.product_algebra import Event, SimpleEvent
from random_events.sigma_algebra import AbstractCompositeSet
from random_events.utils import recursive_subclasses
from random_events.variable import Variable
from sortedcontainers import SortedSet
from triton.language import dtype
from typing_extensions import List, Tuple, Type, Dict, Union, Self

from ...probabilistic_circuit.probabilistic_circuit import ProbabilisticCircuit, \
@@ -820,6 +822,7 @@ def log_likelihood(self, x: torch.Tensor) -> torch.Tensor:
for columns, edges, layer in zip(self.columns_of_child_layers, self.edges, self.child_layers):

edges = edges.coalesce()

# calculate the log likelihood over the columns of the child layer
ll = layer.log_likelihood(x[:, columns]) # shape: (#x, #child_nodes)

@@ -871,9 +874,6 @@ def probability_of_simple_event(self, event: SimpleEvent) -> torch.Tensor:
def log_mode(self) -> Tuple[Event, float]:
pass

def sample(self, amount: int) -> torch.Tensor:
pass

@property
def support_per_node(self) -> List[Event]:
pass
@@ -983,6 +983,38 @@ def clean_up_orphans_inplace(self):
shrunken_indices = shrink_index_tensor(self.edges.indices())
self.edges = torch.sparse_coo_tensor(shrunken_indices, self.edges.values())

def sample_from_frequencies(self, frequencies: torch.Tensor) -> torch.Tensor:

concatenated_samples_per_variable = [torch.zeros(0) for _ in range(len(self.variables))]

for index, (edges, child_layer) in enumerate(zip(self.edges, self.child_layers)):
edges: torch.Tensor = edges.coalesce() # shape (self.number_of_nodes,)
squeezed_edge_indices = edges.indices().squeeze(0)

# count the number of samples for each child node
frequencies_for_child_layer = torch.zeros(child_layer.number_of_nodes, dtype=torch.long) # shape (#child_nodes)
frequencies_for_child_layer = frequencies_for_child_layer.scatter_add(0, edges.values(),
frequencies[squeezed_edge_indices])

# sample the child layer
child_layer_samples = child_layer.sample_from_frequencies(frequencies_for_child_layer)

# reorder the samples according to the child-node order requested by the values of the edges
reordered_sample_values = child_layer_samples.index_select(0, edges.values().unique_consecutive()).coalesce().values()

# write samples in the correct columns for the result
for column in self.columns_of_child_layers[index]:
concatenated_samples_per_variable[column] = (
torch.cat((concatenated_samples_per_variable[column], reordered_sample_values)))

# assemble the result
result_indices = create_sparse_tensor_indices_from_row_lengths(frequencies)
result_values = torch.stack(concatenated_samples_per_variable, dim=1)
result = torch.sparse_coo_tensor(result_indices, result_values,
size=(self.number_of_nodes, max(frequencies), len(self.variables)),
is_coalesced=True)
return result
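
ProductLayer.sample_from_frequencies first distributes the requested sample counts to its child layers: for each child layer, the matching row of self.edges is a sparse vector whose indices are product-node positions and whose values select a node inside that child layer, and scatter_add accumulates how many samples every child node has to produce. A minimal sketch of that aggregation step in isolation, with made-up wiring:

import torch

# hypothetical wiring: product node 0 uses child node 1, product node 1 uses child node 0
edges = torch.sparse_coo_tensor(torch.tensor([[0, 1]]), torch.tensor([1, 0]), (2,)).coalesce()
frequencies = torch.tensor([5, 3])  # samples requested from product nodes 0 and 1

frequencies_for_child_layer = torch.zeros(2, dtype=torch.long)
frequencies_for_child_layer = frequencies_for_child_layer.scatter_add(
    0, edges.values(), frequencies[edges.indices().squeeze(0)])

# child node 0 must produce 3 samples (for product node 1),
# child node 1 must produce 5 samples (for product node 0)
assert frequencies_for_child_layer.tolist() == [3, 5]
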


def __deepcopy__(self):
child_layers = [child_layer.__deepcopy__() for child_layer in self.child_layers]
4 changes: 4 additions & 0 deletions test/test_torch/test_dirac.py
@@ -39,6 +39,10 @@ def test_conditional_of_simple_interval(self):
assert_close(layer.location, torch.tensor([0.]))
assert_close(layer.density_cap, torch.tensor([1.]))

def test_sample(self):
s = self.p_x.sample_from_frequencies(torch.tensor([10, 5]))
self.assertTrue(torch.all(s.values()[:10] == 0.))
self.assertTrue(torch.all(s.values()[10:] == 1.))
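
For reference, a short sketch of how the sparse result of this test can be read, assuming p_x is a DiracDeltaLayer with two nodes located at 0. and 1.:

samples = self.p_x.sample_from_frequencies(torch.tensor([10, 5]))
node_0_samples = samples[0].coalesce().values()  # ten samples, all equal to 0.
node_1_samples = samples[1].coalesce().values()  # five samples, all equal to 1.
dense = samples.to_dense()  # shape (2, 10); the five unused slots in row 1 read as 0.
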


if __name__ == '__main__':
87 changes: 40 additions & 47 deletions test/test_torch/test_product_layer.py
@@ -9,10 +9,12 @@
from random_events.variable import Continuous
from torch.testing import assert_close

from probabilistic_model.learning.torch import DiracDeltaLayer
from probabilistic_model.learning.torch.pc import SumLayer, ProductLayer
from probabilistic_model.learning.torch.uniform_layer import UniformLayer
from probabilistic_model.probabilistic_circuit.distributions import UniformDistribution
from probabilistic_model.probabilistic_circuit.probabilistic_circuit import SumUnit, ProductUnit
from probabilistic_model.utils import embed_sparse_tensor_in_nan_tensor


class ProductTestCase2(unittest.TestCase):
@@ -59,63 +61,54 @@ def test_conditional_of_simple_event(self):
self.assertEqual(c.child_layers[1].number_of_nodes, 1)
assert_close(lp, torch.tensor([0.25, 0.]).log())

def test_sample_from_frequencies(self):
torch.random.manual_seed(69)
frequencies = torch.tensor([5, 3])
samples = self.product_layer.sample_from_frequencies(frequencies)
for index, sample_row in enumerate(samples):
sample_row = sample_row.coalesce().values()
self.assertEqual(len(sample_row), frequencies[index])
sample_row = sample_row.reshape(-1, 1)
likelihood = self.product_layer.likelihood(sample_row)
self.assertTrue(all(likelihood[:, index] > 0.))

class ProductTestCase(unittest.TestCase):

class ProductDiracTestCase(unittest.TestCase):
x = Continuous("x")
y = Continuous("y")
p1_x_by_hand = UniformDistribution(x, SimpleInterval(0, 1))
p1_y_by_hand = UniformDistribution(y, SimpleInterval(0.5, 1))
p2_y_by_hand = UniformDistribution(y, SimpleInterval(5, 6))
z = Continuous("z")

product_1 = ProductUnit()
product_1.add_subcircuit(p1_x_by_hand)
product_1.add_subcircuit(p1_y_by_hand)
p1_x = DiracDeltaLayer(x, torch.tensor([0., 1.]).double(), torch.tensor([1, 1]).double())
p2_x = DiracDeltaLayer(x, torch.tensor([2., 3.]).double(), torch.tensor([1, 1]).double())
p_y = DiracDeltaLayer(y, torch.tensor([4., 5.]).double(), torch.tensor([1, 1]).double())
p_z = DiracDeltaLayer(z, torch.tensor([6.]).double(), torch.tensor([1]).double())

product_2 = ProductUnit()
product_2.probabilistic_circuit = product_1.probabilistic_circuit
product_2.add_subcircuit(p1_x_by_hand)
product_2.add_subcircuit(p2_y_by_hand)
indices = torch.tensor([[1, 2, 3, 3, 0, 0],
[0, 1, 0, 1, 0, 1]])
values = torch.tensor([0, 0, 1, 0, 0, 0])
edges = torch.sparse_coo_tensor(indices, values).coalesce()

p1_x = UniformLayer(x, torch.Tensor([[0, 1]]))
p1_y = UniformLayer(y, torch.Tensor([[0.5, 1], [5, 6]]))
product_layer = ProductLayer([p_z, p1_x, p2_x, p_y, ], edges)

product = ProductLayer(child_layers=[p1_x, p1_y], edges=torch.tensor([[0, 0], [0, 1]]).long())
def test_likelihood(self):
data = [[0., 5., 6.]]
likelihood = self.product_layer.log_likelihood(torch.tensor(data))
self.assertTrue(all(likelihood[:, 0] > 0))

def test_log_likelihood(self):
data = [[0.5, 0.75], [0.9, 0.7], [0.5, 5.5]]
ll_p1_by_hand = self.product_1.log_likelihood(np.array(data))
ll_p2_by_hand = self.product_2.log_likelihood(np.array(data))
ll = self.product.log_likelihood(torch.tensor(data))
self.assertEqual(ll.shape, (3, 2))
assert_almost_equal(ll_p1_by_hand.tolist(), ll[:, 0].tolist())
assert_almost_equal(ll_p2_by_hand.tolist(), ll[:, 1].tolist())
def test_sample_from_frequencies(self):
torch.random.manual_seed(69)
frequencies = torch.tensor([5, 3])
samples = self.product_layer.sample_from_frequencies(frequencies)

def test_probability(self):
event = SimpleEvent({self.x: closed(0.5, 2.5) | closed(3, 5), self.y: closed(0.5, 2.5) | closed(3, 5)})
prob = self.product.probability_of_simple_event(event)
self.assertEqual(prob.shape, (2,))
p_by_hand_1 = self.product_1.probability_of_simple_event(event)
p_by_hand_2 = self.product_2.probability_of_simple_event(event)
assert_almost_equal([p_by_hand_1, p_by_hand_2], prob.tolist())
samples_n0 = samples[0].to_dense()
samples_n1 = samples[1].to_dense()

self.assertEqual(samples_n0.shape, torch.Size((5, 3)))
self.assertEqual(samples_n1.shape, torch.Size((5, 3)))
self.assertEqual(len(samples[1].coalesce().values()), 3)
self.assertTrue(torch.all(samples_n0 == torch.tensor([0, 5, 6])))
self.assertTrue(torch.all(samples_n1[:3] == torch.tensor([2, 4, 6])))
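
For reference, the expected values follow directly from the edges tensor defined above: its rows correspond to the child layers [p_z, p1_x, p2_x, p_y], its columns to the two product nodes, and each stored value picks a node inside the respective child layer (the sample columns follow the sorted variable order x, y, z):

# edges.to_dense() reads as
# [[0, 0],   # both product nodes take p_z node 0 -> z = 6.
#  [0, 0],   # product node 0 takes p1_x node 0 -> x = 0. (the entry for product node 1 is structurally absent)
#  [0, 0],   # product node 1 takes p2_x node 0 -> x = 2. (the entry for product node 0 is structurally absent)
#  [1, 0]]   # product node 0 takes p_y node 1 -> y = 5., product node 1 takes p_y node 0 -> y = 4.
expected_node_0 = torch.tensor([0., 5., 6.])
expected_node_1 = torch.tensor([2., 4., 6.])
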

def test_conditional_of_simple_event(self):
event = SimpleEvent({self.x: closed(0.5, 2.), self.y: closed(4, 5.5)})
c, lp = self.product.log_conditional_of_simple_event(event)
c.validate()
self.assertEqual(c.number_of_nodes, 1)
self.assertEqual(len(c.child_layers), 2)
self.assertEqual(c.child_layers[0].number_of_nodes, 1)
self.assertEqual(c.child_layers[1].number_of_nodes, 1)
assert_close(lp, torch.tensor([0., 0.25]).log())

def test_remove_nodes_inplace(self):
product = self.product.__deepcopy__()
remove_mask = torch.tensor([1, 0]).bool()
product.remove_nodes_inplace(remove_mask)
self.assertEqual(product.number_of_nodes, 1)
product.validate()
self.assertEqual(len(product.child_layers), 2)
self.assertTrue((product.edges == 0).all())


class CleanUpTestCase(unittest.TestCase):
