From a9edb8b64549de8b44edf1b446460208708803c3 Mon Sep 17 00:00:00 2001 From: Tom Schierenbeck Date: Tue, 30 Jul 2024 17:38:51 +0200 Subject: [PATCH] Single truncation is working. Multiple truncation is on its way --- .../learning/torch/__init__.py | 2 +- .../learning/torch/input_layer.py | 79 +++++++++++-------- src/probabilistic_model/learning/torch/pc.py | 4 +- .../learning/torch/uniform_layer.py | 29 ++++--- test/test_torch/test_dirac.py | 1 + test/test_torch/test_uniform.py | 26 ++++++ 6 files changed, 96 insertions(+), 45 deletions(-) diff --git a/src/probabilistic_model/learning/torch/__init__.py b/src/probabilistic_model/learning/torch/__init__.py index c1b8cb1..016df76 100644 --- a/src/probabilistic_model/learning/torch/__init__.py +++ b/src/probabilistic_model/learning/torch/__init__.py @@ -1,3 +1,3 @@ from .uniform_layer import * from .pc import * -from input_layer import * +from .input_layer import * diff --git a/src/probabilistic_model/learning/torch/input_layer.py b/src/probabilistic_model/learning/torch/input_layer.py index b78f20d..4ee254a 100644 --- a/src/probabilistic_model/learning/torch/input_layer.py +++ b/src/probabilistic_model/learning/torch/input_layer.py @@ -1,6 +1,6 @@ from __future__ import annotations -from abc import ABC +from abc import ABC, abstractmethod from typing import Optional import random_events @@ -9,7 +9,6 @@ from random_events.product_algebra import Event, SimpleEvent from random_events.sigma_algebra import AbstractCompositeSet from random_events.variable import Continuous -from sortedcontainers import SortedSet from typing_extensions import List, Tuple, Self from .pc import InputLayer, AnnotatedLayer, SumLayer @@ -39,35 +38,40 @@ def probability_of_simple_event(self, event: SimpleEvent) -> torch.Tensor: lower_bound_cdf = self.cdf(points[:, (0,)]) return (upper_bound_cdf - lower_bound_cdf).sum(dim=0).unsqueeze(-1) - def log_conditional(self, event: Event) -> Tuple[Optional[Self], float]: - + def log_conditional_of_simple_event(self, event: SimpleEvent): if event.is_empty(): return None, -torch.inf + interval: Interval = event[self.variable] - # extract the interval of the event - marginal_event = event.marginal(SortedSet(self.variables)) - assert len(marginal_event.simple_sets) == 1, "The event must be a simple event." - interval = marginal_event.simple_sets[0][self.variable] + if interval.is_singleton(): + return self.log_conditional_from_singleton(interval.simple_sets[0]) if len(interval.simple_sets) == 1: return self.log_conditional_from_simple_interval(interval.simple_sets[0]) else: return self.log_conditional_from_interval(interval) - def log_conditional_from_singletons(self, singletons: List[SimpleInterval]) -> Tuple[DiracDeltaLayer, torch.Tensor]: + def log_conditional_from_singleton(self, singleton: SimpleInterval) -> Tuple[DiracDeltaLayer, torch.Tensor]: """ - Calculate the conditional distribution given a list singleton events with p(event) > zero forall events. + Calculate the conditional distribution given singleton event. In this case, the conditional distribution is a Dirac delta distribution and the log-likelihood is chosen for the log-probability. - :param singletons: The singleton events - :return: The dirac delta layer and the log-likelihoods with shape (#singletons, 1). - """ - values = torch.tensor([s.lower for s in singletons]) - log_likelihoods = self.log_likelihood(values.reshape(-1, 1)) - return DiracDeltaLayer(self.variable, values, log_likelihoods), log_likelihoods + This method returns a Dirac delta layer that has at most the same number of nodes as the input layer. + :param singleton: The singleton event + :return: The dirac delta layer and the log-likelihoods with shape (something <= #singletons, 1). + """ + value = singleton.lower + log_likelihoods = self.log_likelihood(torch.tensor(value).reshape(-1, 1)).squeeze() # shape: (#nodes, ) + possible_indices = (log_likelihoods != -torch.inf).nonzero()[0] # shape: (#dirac-nodes, ) + filtered_likelihood = log_likelihoods[possible_indices] + locations = torch.full_like(filtered_likelihood, value) + layer = DiracDeltaLayer(self.variable, locations, torch.exp(filtered_likelihood)) + return layer, log_likelihoods + + @abstractmethod def log_conditional_from_simple_interval(self, interval: SimpleInterval) -> Tuple[Self, torch.Tensor]: """ Calculate the conditional distribution given a simple interval with p(interval) > 0. @@ -76,27 +80,36 @@ def log_conditional_from_simple_interval(self, interval: SimpleInterval) -> Tupl :param interval: The simple interval :return: The conditional distribution and the log-probability of the interval. """ - # form the intersection per node - intersection = [interval.intersection_with(node_interval.simple_sets[0]) for node_interval in - self.univariate_support_per_node] - singletons = [simple_interval for simple_interval in intersection if simple_interval.is_singleton()] - non_singletons = [simple_interval for simple_interval in intersection if not simple_interval.is_singleton()] - - - def log_conditional_from_non_singleton_simple_interval(self, interval: SimpleInterval) -> Tuple[SumLayer, float]: - """ - Calculate the conditional distribution given a non-singleton, simple interval with p(interval) > 0. - :param interval: The simple interval. - :return: The conditional distribution and the log-probability of the interval. - """ raise NotImplementedError - def log_conditional_from_interval(self, interval) -> Tuple[Self, float]: + def log_conditional_from_interval(self, interval: Interval) -> Tuple[SumLayer, torch.Tensor]: """ Calculate the conditional distribution given an interval with p(interval) > 0. :param interval: The simple interval :return: The conditional distribution and the log-probability of the interval. """ + results = [self.log_conditional_from_simple_interval(simple_interval) for simple_interval in + interval.simple_sets] + input_layer = results[0][0] + input_layer.merge_with([layer for layer, _ in results[1:]]) + log_weights = torch.full((self.number_of_nodes, input_layer.number_of_nodes), -torch.inf) + + for i, (layer, log_likelihood) in enumerate(results): + log_likelihood = log_likelihood.squeeze() + indices = (log_likelihood != -torch.inf).nonzero().squeeze() + print(indices) + log_weights[indices, i] = log_likelihood[log_likelihood != -torch.inf].float() + + resulting_layer = SumLayer([input_layer], [log_weights]) + return resulting_layer, resulting_layer.log_normalization_constants + + @abstractmethod + def merge_with(self, others: List[Self]): + """ + Merge this layer with another layer inplace. + + :param others: The other layers + """ raise NotImplementedError @@ -154,7 +167,6 @@ def included_condition(self, x: torch.Tensor) -> torch.Tensor: class DiracDeltaLayer(ContinuousLayer): - location: torch.Tensor """ The locations of the Dirac delta distributions. @@ -175,6 +187,10 @@ def validate(self): assert self.location.shape == self.density_cap.shape, "The shapes of the location and density cap must match." assert all(self.density_cap > 0), "The density cap must be positive." + @property + def number_of_nodes(self) -> int: + return len(self.location) + def log_likelihood(self, x: torch.Tensor) -> torch.Tensor: return torch.where(x == self.location, torch.log(self.density_cap), -torch.inf) @@ -196,4 +212,3 @@ def log_mode(self) -> Tuple[Event, float]: def sample(self, amount: int) -> torch.Tensor: pass - diff --git a/src/probabilistic_model/learning/torch/pc.py b/src/probabilistic_model/learning/torch/pc.py index a6db28e..501334a 100644 --- a/src/probabilistic_model/learning/torch/pc.py +++ b/src/probabilistic_model/learning/torch/pc.py @@ -81,6 +81,7 @@ def validate(self): raise NotImplementedError @property + @abstractmethod def number_of_nodes(self) -> int: """ The number of nodes in the layer. @@ -174,7 +175,7 @@ def log_conditional(self, event: Event) -> Tuple[Optional[Layer], float]: raise NotImplementedError @abstractmethod - def log_conditional_of_simple_event(self, event: SimpleEvent) -> Tuple[Optional[Self], torch.Tensor]: + def log_conditional_of_simple_event(self, event: SimpleEvent) -> Tuple[Optional[AnnotatedLayer], torch.Tensor]: raise NotImplementedError @@ -509,7 +510,6 @@ def log_conditional_of_simple_event(self, event: SimpleEvent) -> Tuple[Optional[ pass - @dataclass class AnnotatedLayer: layer: Layer diff --git a/src/probabilistic_model/learning/torch/uniform_layer.py b/src/probabilistic_model/learning/torch/uniform_layer.py index 8ac7bfd..7f938a8 100644 --- a/src/probabilistic_model/learning/torch/uniform_layer.py +++ b/src/probabilistic_model/learning/torch/uniform_layer.py @@ -1,6 +1,9 @@ from __future__ import annotations from typing import Tuple, Optional, Union, Type, List + +import random_events.interval +from random_events.interval import SimpleInterval, Bound from typing_extensions import Self import torch @@ -25,6 +28,9 @@ class UniformLayer(ContinuousLayerWithFiniteSupport): The index of the variable that this layer represents. """ + def merge_with(self, others: List[Self]): + self.interval = torch.cat([self.interval] + [other.interval for other in others]) + def __init__(self, variable: Continuous, interval: torch.Tensor): """ Initialize the uniform layer. @@ -41,15 +47,6 @@ def cdf(self, x: torch.Tensor) -> torch.Tensor: result = torch.clamp(result, 0, 1) return result - def log_mode(self) -> Tuple[Event, float]: - pass - - def log_conditional(self, event: Event) -> Tuple[Optional[Union[ProbabilisticModel, Self]], float]: - pass - - def sample(self, amount: int) -> torch.Tensor: - pass - @classmethod def original_class(cls) -> Tuple[Type, ...]: return UniformDistribution, @@ -79,9 +76,21 @@ def log_likelihood(self, x: torch.Tensor) -> torch.Tensor: """ return torch.where(self.included_condition(x), self.log_pdf_value(), -torch.inf) - def log_conditional_of_simple_event(self, event: SimpleEvent) -> Tuple[Optional[Self], torch.Tensor]: + def log_mode(self) -> Tuple[Event, float]: pass + def sample(self, amount: int) -> torch.Tensor: + pass + + def log_conditional_from_simple_interval(self, interval: SimpleInterval) -> Tuple[Self, torch.Tensor]: + probabilities = self.probability_of_simple_event(SimpleEvent({self.variable: interval})) + intersections = [interval.intersection_with(SimpleInterval(lower.item(), upper.item(), + Bound.OPEN, Bound.OPEN)) + for lower, upper in self.interval] + return self.__class__(self.variable, torch.stack([simple_interval_to_open_tensor(intersection) + for intersection in intersections + if not intersection.is_empty()])), probabilities + @classmethod def create_layer_from_nodes_with_same_type_and_scope(cls, nodes: List[UniformUnit], child_layers: List[AnnotatedLayer]) -> \ diff --git a/test/test_torch/test_dirac.py b/test/test_torch/test_dirac.py index c3185c2..c33790f 100644 --- a/test/test_torch/test_dirac.py +++ b/test/test_torch/test_dirac.py @@ -29,5 +29,6 @@ def test_support_per_node(self): SimpleEvent({self.x: singleton(1)}).as_composite_set()] self.assertEqual(support, result) + if __name__ == '__main__': unittest.main() diff --git a/test/test_torch/test_uniform.py b/test/test_torch/test_uniform.py index dbb8c46..bb55414 100644 --- a/test/test_torch/test_uniform.py +++ b/test/test_torch/test_uniform.py @@ -7,6 +7,7 @@ from random_events.variable import Continuous from torch.testing import assert_close +from probabilistic_model.learning.torch import SumLayer from probabilistic_model.learning.torch.uniform_layer import UniformLayer @@ -42,6 +43,31 @@ def test_support_per_node(self): SimpleEvent({self.x: open(1, 3)}).as_composite_set()] self.assertEqual(support, result) + def test_conditional_singleton(self): + event = SimpleEvent({self.x: closed(0.5, 0.5)}) + layer, ll = self.p_x.log_conditional_of_simple_event(event) + self.assertEqual(layer.number_of_nodes, 1) + assert_close(torch.tensor([0.5]), layer.location) + assert_close(torch.tensor([1.]), layer.density_cap) + + def test_conditional_single_truncation(self): + event = SimpleEvent({self.x: closed(0.5, 2.5)}) + layer, ll = self.p_x.log_conditional_of_simple_event(event) + self.assertEqual(layer.number_of_nodes, 2) + assert_close(layer.interval, torch.tensor([[0.5, 1], [1, 2.5]])) + assert_close(torch.tensor([0.5, 0.75]).reshape(-1, 1).double(), ll) + + def test_conditional_multiple_truncation(self): + event = closed(-1, 0.5) | closed(0.7, 0.8) | closed(2., 3.) | closed(3., 4.) + layer, ll = self.p_x.log_conditional_from_interval(event) + self.assertIsInstance(layer, SumLayer) + layer.validate() + self.assertEqual(layer.number_of_nodes, 2) + self.assertEqual(len(layer.child_layers), 1) + + + assert_close(torch.tensor([0.5, 0.75]).reshape(-1, 1).double(), ll) + if __name__ == '__main__': unittest.main()