Commit a9edb8b
Single truncation is working. Multiple truncation is on its way.
tomsch420 committed Jul 30, 2024
1 parent 23660bb commit a9edb8b
Showing 6 changed files with 96 additions and 45 deletions.
2 changes: 1 addition & 1 deletion src/probabilistic_model/learning/torch/__init__.py
@@ -1,3 +1,3 @@
from .uniform_layer import *
from .pc import *
from input_layer import *
from .input_layer import *
79 changes: 47 additions & 32 deletions src/probabilistic_model/learning/torch/input_layer.py
@@ -1,6 +1,6 @@
from __future__ import annotations

from abc import ABC
from abc import ABC, abstractmethod
from typing import Optional

import random_events
@@ -9,7 +9,6 @@
from random_events.product_algebra import Event, SimpleEvent
from random_events.sigma_algebra import AbstractCompositeSet
from random_events.variable import Continuous
from sortedcontainers import SortedSet
from typing_extensions import List, Tuple, Self

from .pc import InputLayer, AnnotatedLayer, SumLayer
@@ -39,35 +38,40 @@ def probability_of_simple_event(self, event: SimpleEvent) -> torch.Tensor:
lower_bound_cdf = self.cdf(points[:, (0,)])
return (upper_bound_cdf - lower_bound_cdf).sum(dim=0).unsqueeze(-1)

def log_conditional(self, event: Event) -> Tuple[Optional[Self], float]:

def log_conditional_of_simple_event(self, event: SimpleEvent):
if event.is_empty():
return None, -torch.inf
interval: Interval = event[self.variable]

# extract the interval of the event
marginal_event = event.marginal(SortedSet(self.variables))
assert len(marginal_event.simple_sets) == 1, "The event must be a simple event."
interval = marginal_event.simple_sets[0][self.variable]
if interval.is_singleton():
return self.log_conditional_from_singleton(interval.simple_sets[0])

if len(interval.simple_sets) == 1:
return self.log_conditional_from_simple_interval(interval.simple_sets[0])
else:
return self.log_conditional_from_interval(interval)

def log_conditional_from_singletons(self, singletons: List[SimpleInterval]) -> Tuple[DiracDeltaLayer, torch.Tensor]:
def log_conditional_from_singleton(self, singleton: SimpleInterval) -> Tuple[DiracDeltaLayer, torch.Tensor]:
"""
Calculate the conditional distribution given a list singleton events with p(event) > zero forall events.
Calculate the conditional distribution given a singleton event.
In this case, the conditional distribution is a Dirac delta distribution and the log-likelihood at the
singleton is used as the log-probability.
:param singletons: The singleton events
:return: The dirac delta layer and the log-likelihoods with shape (#singletons, 1).
"""
values = torch.tensor([s.lower for s in singletons])
log_likelihoods = self.log_likelihood(values.reshape(-1, 1))
return DiracDeltaLayer(self.variable, values, log_likelihoods), log_likelihoods
This method returns a Dirac delta layer that has at most the same number of nodes as the input layer.
:param singleton: The singleton event
:return: The Dirac delta layer and the per-node log-likelihoods of the singleton with shape (#nodes, ).
"""
value = singleton.lower
log_likelihoods = self.log_likelihood(torch.tensor(value).reshape(-1, 1)).squeeze() # shape: (#nodes, )
possible_indices = (log_likelihoods != -torch.inf).nonzero()[:, 0]  # indices of nodes with non-zero density, shape: (#dirac-nodes, )
filtered_likelihood = log_likelihoods[possible_indices]
locations = torch.full_like(filtered_likelihood, value)
layer = DiracDeltaLayer(self.variable, locations, torch.exp(filtered_likelihood))
return layer, log_likelihoods

@abstractmethod
def log_conditional_from_simple_interval(self, interval: SimpleInterval) -> Tuple[Self, torch.Tensor]:
"""
Calculate the conditional distribution given a simple interval with p(interval) > 0.
@@ -76,27 +80,36 @@ def log_conditional_from_simple_interval(self, interval: SimpleInterval) -> Tupl
:param interval: The simple interval
:return: The conditional distribution and the log-probability of the interval.
"""
# form the intersection per node
intersection = [interval.intersection_with(node_interval.simple_sets[0]) for node_interval in
self.univariate_support_per_node]
singletons = [simple_interval for simple_interval in intersection if simple_interval.is_singleton()]
non_singletons = [simple_interval for simple_interval in intersection if not simple_interval.is_singleton()]


def log_conditional_from_non_singleton_simple_interval(self, interval: SimpleInterval) -> Tuple[SumLayer, float]:
"""
Calculate the conditional distribution given a non-singleton, simple interval with p(interval) > 0.
:param interval: The simple interval.
:return: The conditional distribution and the log-probability of the interval.
"""
raise NotImplementedError

def log_conditional_from_interval(self, interval) -> Tuple[Self, float]:
def log_conditional_from_interval(self, interval: Interval) -> Tuple[SumLayer, torch.Tensor]:
"""
Calculate the conditional distribution given an interval with p(interval) > 0.
:param interval: The interval
:return: The conditional distribution and the log-probability of the interval.
"""
results = [self.log_conditional_from_simple_interval(simple_interval) for simple_interval in
interval.simple_sets]
input_layer = results[0][0]
input_layer.merge_with([layer for layer, _ in results[1:]])
log_weights = torch.full((self.number_of_nodes, input_layer.number_of_nodes), -torch.inf)

for i, (layer, log_likelihood) in enumerate(results):
log_likelihood = log_likelihood.squeeze()
indices = (log_likelihood != -torch.inf).nonzero().squeeze()
log_weights[indices, i] = log_likelihood[log_likelihood != -torch.inf].float()

resulting_layer = SumLayer([input_layer], [log_weights])
return resulting_layer, resulting_layer.log_normalization_constants

@abstractmethod
def merge_with(self, others: List[Self]):
"""
Merge this layer with the other layers in place.
:param others: The other layers
"""
raise NotImplementedError


@@ -154,7 +167,6 @@ def included_condition(self, x: torch.Tensor) -> torch.Tensor:


class DiracDeltaLayer(ContinuousLayer):

location: torch.Tensor
"""
The locations of the Dirac delta distributions.
@@ -175,6 +187,10 @@ def validate(self):
assert self.location.shape == self.density_cap.shape, "The shapes of the location and density cap must match."
assert all(self.density_cap > 0), "The density cap must be positive."

@property
def number_of_nodes(self) -> int:
return len(self.location)

def log_likelihood(self, x: torch.Tensor) -> torch.Tensor:
return torch.where(x == self.location, torch.log(self.density_cap), -torch.inf)

@@ -196,4 +212,3 @@ def log_mode(self) -> Tuple[Event, float]:

def sample(self, amount: int) -> torch.Tensor:
pass
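Below, a minimal usage sketch of the two conditioning paths added to input_layer.py. The two-node layer, imports, and event helpers mirror test/test_torch/test_uniform.py and are assumptions, not part of this commit:

import torch
from random_events.interval import closed
from random_events.product_algebra import SimpleEvent
from random_events.variable import Continuous

from probabilistic_model.learning.torch.uniform_layer import UniformLayer

x = Continuous("x")
# Two uniform nodes, U(0, 1) and U(1, 3), matching the layer used in the tests.
p_x = UniformLayer(x, torch.Tensor([[0, 1], [1, 3]]))

# Singleton path: only U(0, 1) has non-zero density at 0.5, so the conditional
# collapses to a DiracDeltaLayer with a single node.
dirac_layer, log_likelihoods = p_x.log_conditional_of_simple_event(
    SimpleEvent({x: closed(0.5, 0.5)}))

# Multi-interval path: each simple interval is conditioned on separately via
# log_conditional_from_simple_interval, the per-interval layers are merged with
# merge_with, and a SumLayer mixes the merged nodes.
event = closed(-1, 0.5) | closed(2.0, 2.5)
mixture, log_norm = p_x.log_conditional_from_interval(event)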

4 changes: 2 additions & 2 deletions src/probabilistic_model/learning/torch/pc.py
@@ -81,6 +81,7 @@ def validate(self):
raise NotImplementedError

@property
@abstractmethod
def number_of_nodes(self) -> int:
"""
The number of nodes in the layer.
@@ -174,7 +175,7 @@ def log_conditional(self, event: Event) -> Tuple[Optional[Layer], float]:
raise NotImplementedError

@abstractmethod
def log_conditional_of_simple_event(self, event: SimpleEvent) -> Tuple[Optional[Self], torch.Tensor]:
def log_conditional_of_simple_event(self, event: SimpleEvent) -> Tuple[Optional[AnnotatedLayer], torch.Tensor]:
raise NotImplementedError


@@ -509,7 +510,6 @@ def log_conditional_of_simple_event(self, event: SimpleEvent) -> Tuple[Optional[
pass



@dataclass
class AnnotatedLayer:
layer: Layer
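One note on the pc.py change above: for the abstract number_of_nodes property, @abstractmethod must be the innermost decorator (below @property) so that ABCMeta registers the member as abstract. A standalone sketch of the pattern, with illustrative class names only:

from abc import ABC, abstractmethod

class Layer(ABC):

    @property
    @abstractmethod
    def number_of_nodes(self) -> int:
        """The number of nodes in the layer."""
        raise NotImplementedError

class TwoNodeLayer(Layer):

    @property
    def number_of_nodes(self) -> int:
        return 2

Concrete layers in this commit satisfy the contract the same way, e.g. DiracDeltaLayer returns len(self.location).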
29 changes: 19 additions & 10 deletions src/probabilistic_model/learning/torch/uniform_layer.py
@@ -1,6 +1,9 @@
from __future__ import annotations

from typing import Tuple, Optional, Union, Type, List

import random_events.interval
from random_events.interval import SimpleInterval, Bound
from typing_extensions import Self

import torch
@@ -25,6 +28,9 @@ class UniformLayer(ContinuousLayerWithFiniteSupport):
The index of the variable that this layer represents.
"""

def merge_with(self, others: List[Self]):
self.interval = torch.cat([self.interval] + [other.interval for other in others])

def __init__(self, variable: Continuous, interval: torch.Tensor):
"""
Initialize the uniform layer.
@@ -41,15 +47,6 @@ def cdf(self, x: torch.Tensor) -> torch.Tensor:
result = torch.clamp(result, 0, 1)
return result

def log_mode(self) -> Tuple[Event, float]:
pass

def log_conditional(self, event: Event) -> Tuple[Optional[Union[ProbabilisticModel, Self]], float]:
pass

def sample(self, amount: int) -> torch.Tensor:
pass

@classmethod
def original_class(cls) -> Tuple[Type, ...]:
return UniformDistribution,
@@ -79,9 +76,21 @@ def log_likelihood(self, x: torch.Tensor) -> torch.Tensor:
"""
return torch.where(self.included_condition(x), self.log_pdf_value(), -torch.inf)

def log_conditional_of_simple_event(self, event: SimpleEvent) -> Tuple[Optional[Self], torch.Tensor]:
def log_mode(self) -> Tuple[Event, float]:
pass

def sample(self, amount: int) -> torch.Tensor:
pass

def log_conditional_from_simple_interval(self, interval: SimpleInterval) -> Tuple[Self, torch.Tensor]:
probabilities = self.probability_of_simple_event(SimpleEvent({self.variable: interval}))
intersections = [interval.intersection_with(SimpleInterval(lower.item(), upper.item(),
Bound.OPEN, Bound.OPEN))
for lower, upper in self.interval]
return self.__class__(self.variable, torch.stack([simple_interval_to_open_tensor(intersection)
for intersection in intersections
if not intersection.is_empty()])), probabilities

@classmethod
def create_layer_from_nodes_with_same_type_and_scope(cls, nodes: List[UniformUnit],
child_layers: List[AnnotatedLayer]) -> \
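To make the single-truncation arithmetic concrete, here is a hand computation of what log_conditional_from_simple_interval returns for the two-node layer from the tests, a sketch under the same assumptions as above. Note that, despite the log_ prefix, the second return value is the probability tensor produced by probability_of_simple_event, which is what the new test asserts against directly:

import torch
from random_events.interval import SimpleInterval, Bound
from random_events.variable import Continuous

from probabilistic_model.learning.torch.uniform_layer import UniformLayer

x = Continuous("x")
p_x = UniformLayer(x, torch.Tensor([[0, 1], [1, 3]]))  # U(0, 1) and U(1, 3)

truncated, prob = p_x.log_conditional_from_simple_interval(
    SimpleInterval(0.5, 2.5, Bound.OPEN, Bound.OPEN))

# Node 1 ~ U(0, 1), density 1.0: P((0.5, 2.5)) = |(0.5, 1)| * 1.0 = 0.5
# Node 2 ~ U(1, 3), density 0.5: P((0.5, 2.5)) = |(1, 2.5)| * 0.5 = 0.75
# The supports of the truncated layer are the intersections (0.5, 1) and
# (1, 2.5), stacked into a fresh UniformLayer; prob is [[0.5], [0.75]].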
1 change: 1 addition & 0 deletions test/test_torch/test_dirac.py
@@ -29,5 +29,6 @@ def test_support_per_node(self):
SimpleEvent({self.x: singleton(1)}).as_composite_set()]
self.assertEqual(support, result)


if __name__ == '__main__':
unittest.main()
26 changes: 26 additions & 0 deletions test/test_torch/test_uniform.py
@@ -7,6 +7,7 @@
from random_events.variable import Continuous
from torch.testing import assert_close

from probabilistic_model.learning.torch import SumLayer
from probabilistic_model.learning.torch.uniform_layer import UniformLayer


@@ -42,6 +43,31 @@ def test_support_per_node(self):
SimpleEvent({self.x: open(1, 3)}).as_composite_set()]
self.assertEqual(support, result)

def test_conditional_singleton(self):
event = SimpleEvent({self.x: closed(0.5, 0.5)})
layer, ll = self.p_x.log_conditional_of_simple_event(event)
self.assertEqual(layer.number_of_nodes, 1)
assert_close(torch.tensor([0.5]), layer.location)
assert_close(torch.tensor([1.]), layer.density_cap)

def test_conditional_single_truncation(self):
event = SimpleEvent({self.x: closed(0.5, 2.5)})
layer, ll = self.p_x.log_conditional_of_simple_event(event)
self.assertEqual(layer.number_of_nodes, 2)
assert_close(layer.interval, torch.tensor([[0.5, 1], [1, 2.5]]))
assert_close(torch.tensor([0.5, 0.75]).reshape(-1, 1).double(), ll)

def test_conditional_multiple_truncation(self):
event = closed(-1, 0.5) | closed(0.7, 0.8) | closed(2., 3.) | closed(3., 4.)
layer, ll = self.p_x.log_conditional_from_interval(event)
self.assertIsInstance(layer, SumLayer)
layer.validate()
self.assertEqual(layer.number_of_nodes, 2)
self.assertEqual(len(layer.child_layers), 1)

assert_close(torch.tensor([0.5, 0.75]).reshape(-1, 1).double(), ll)


if __name__ == '__main__':
unittest.main()
