Skip to content

Commit

Permalink
Finished MPE and decorator.
Browse files Browse the repository at this point in the history
  • Loading branch information
tomsch420 committed Jan 23, 2024
1 parent 87b4b9d commit 37d895a
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 60 deletions.
98 changes: 44 additions & 54 deletions src/probabilistic_model/graph_circuits/probabilistic_circuit.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import itertools
from typing import Tuple, Iterable

import networkx as nx
from random_events.events import EncodedEvent
from random_events.variables import Variable
from typing_extensions import List, Optional, Union, Any
from typing_extensions import List, Optional, Union, Any, Self

from ..probabilistic_model import ProbabilisticModel, ProbabilisticModelWrapper
import networkx as nx
from ..probabilistic_model import ProbabilisticModel, ProbabilisticModelWrapper, OrderType, CenterType, MomentType


class ProbabilisticCircuitMixin:
Expand Down Expand Up @@ -42,28 +42,41 @@ def edges_to_sub_circuits(self) -> List[Union['Edge', 'DirectedWeightedEdge']]:
"""
Return a list of targets to the children of this component.
"""
return [self.probabilistic_circuit[source][target]["edge"]
for source, target in self.probabilistic_circuit.out_edges(self)]
return [self.probabilistic_circuit[source][target]["edge"] for source, target in
self.probabilistic_circuit.out_edges(self)]

@property
def variables(self) -> Tuple[Variable]:
variables = set([variable for distribution in self.leaf_nodes() for variable in distribution.variables])
return tuple(sorted(variables))

def leaf_nodes(self) -> List[ProbabilisticModel]:
return [node for node in nx.descendants(self.probabilistic_circuit, self)
if self.probabilistic_circuit.out_degree(node) == 0]
return [node for node in nx.descendants(self.probabilistic_circuit, self) if
self.probabilistic_circuit.out_degree(node) == 0]

def reset_result_of_current_query(self):
"""
Reset the result of the current query recursively.
"""
self.result_of_current_query = None

for edge in self.edges_to_sub_circuits():
edge.target.reset_result_of_current_query()


def cache_inference_result(func):
"""
Decorator for caching the result of a function call in a 'ProbabilisticCircuitMixin' object.
"""

def wrapper(*args, **kwargs):
self: ProbabilisticCircuitMixin = args[0]
if self.result_of_current_query is None:
self.result_of_current_query = func(*args, **kwargs)
return self.result_of_current_query

return wrapper


class Component(ProbabilisticCircuitMixin, ProbabilisticModel):
"""
Class for non-leaf components in circuits.
Expand All @@ -76,36 +89,24 @@ def __init__(self):
class SmoothSumUnit(Component):
representation = "+"

@cache_inference_result
def _likelihood(self, event: Iterable) -> float:

# query cache
if self.result_of_current_query is not None:
return self.result_of_current_query

result = 0.

for edge in self.edges_to_sub_circuits():
result += edge.weight * edge.target._likelihood(event)

# update cache
self.result_of_current_query = result

return result

@cache_inference_result
def _probability(self, event: EncodedEvent) -> float:

# query cache
if self.result_of_current_query is not None:
return self.result_of_current_query

result = 0.

for edge in self.edges_to_sub_circuits():
result += edge.weight * edge.target._probability(event)

# update cache
self.result_of_current_query = result

return result


Expand Down Expand Up @@ -136,10 +137,6 @@ def merge_modes_if_one_dimensional(self, modes: List[EncodedEvent]) -> List[Enco

def _mode(self) -> Tuple[Iterable[EncodedEvent], float]:

# query cache
if self.result_of_current_query is not None:
return self.result_of_current_query

modes = []
likelihoods = []

Expand All @@ -160,10 +157,6 @@ def _mode(self) -> Tuple[Iterable[EncodedEvent], float]:
result.extend(mode)

modes = self.merge_modes_if_one_dimensional(result)

# update cache
self.result_of_current_query = (modes, maximum_likelihood)

return modes, maximum_likelihood


Expand All @@ -174,12 +167,9 @@ class DecomposableProductUnit(Component):

representation = "⊗"

@cache_inference_result
def _likelihood(self, event: Iterable) -> float:

# query cache
if self.result_of_current_query is not None:
return self.result_of_current_query

variables = self.variables

result = 1.
Expand All @@ -191,21 +181,14 @@ def _likelihood(self, event: Iterable) -> float:

result *= subcircuit._likelihood(partial_event)

# update cache
self.result_of_current_query = result

return result

@cache_inference_result
def _probability(self, event: EncodedEvent) -> float:

# query cache
if self.result_of_current_query is not None:
return self.result_of_current_query

result = 1.

for edge in self.edges_to_sub_circuits():

subcircuit = edge.target
subcircuit_variables = edge.target.variables

Expand All @@ -214,17 +197,10 @@ def _probability(self, event: EncodedEvent) -> float:
# construct partial event for child
result *= subcircuit._probability(subcircuit_event)

# update cache
self.result_of_current_query = result

return result

def _mode(self) -> Tuple[Iterable[EncodedEvent], float]:

# query cache
if self.result_of_current_query is not None:
return self.result_of_current_query

modes = []
resulting_likelihood = 1.

Expand All @@ -247,13 +223,9 @@ def _mode(self) -> Tuple[Iterable[EncodedEvent], float]:

result.append(mode)

# update cache
self.result_of_current_query = (result, resulting_likelihood)

return result, resulting_likelihood



class Edge:
"""
Class representing a directed edge in a probabilistic circuit.
Expand Down Expand Up @@ -339,6 +311,10 @@ def is_valid(self) -> bool:
return nx.is_directed_acyclic_graph(self) and nx.is_weakly_connected(self)

def add_node(self, component: ProbabilisticCircuitMixin, **attr):

if component in self.nodes():
return

component.probabilistic_circuit = self
component.id = max(node.id for node in self.nodes) + 1 if len(self.nodes) > 0 else 0
super().add_node(component, **attr)
Expand All @@ -353,6 +329,7 @@ def add_edge(self, edge: Edge, **kwargs):
if isinstance(edge.source, DecomposableProductUnit) and isinstance(edge, DirectedWeightedEdge):
raise ValueError(f"Product units can only have un-weighted edges. Got {type(edge)} instead.")

self.add_nodes_from([edge.source, edge.target])
super().add_edge(edge.source, edge.target, edge=edge, **kwargs)

def add_edges_from(self, edges: Iterable[Edge], **kwargs):
Expand Down Expand Up @@ -393,4 +370,17 @@ def _mode(self) -> Tuple[Iterable[EncodedEvent], float]:
root = self.root
result = self.root._mode()
root.reset_result_of_current_query()
return result
return result

def marginal(self, variables: Iterable[Variable]) -> Optional[Self]:
...

def _conditional(self, event: EncodedEvent) -> Tuple[Optional[Self], float]:
...

def sample(self, amount: int) -> Iterable:
...

def moment(self, order: OrderType, center: CenterType) -> MomentType:
...

26 changes: 20 additions & 6 deletions test/test_graph_circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ def setUp(self):

self.model = model

def show(self):
nx.draw(self.model, with_labels=True)
plt.show()

def test_setup(self):
node_ids = set()
for node in self.model.nodes():
Expand Down Expand Up @@ -108,6 +112,15 @@ def test_caching_reset(self):
for node in self.model.nodes():
self.assertIsNone(node.result_of_current_query)

def test_caching(self):
event = Event({self.real: portion.closed(0, 5),
self.real2: portion.closed(2, 5)})
_ = self.model.root.probability(event)

for node in self.model.nodes():
if not isinstance(node, LeafComponent):
self.assertIsNotNone(node.result_of_current_query)

def test_mode(self):
mode, likelihood = list(self.model.nodes)[2].mode()
self.assertEqual(likelihood, 0.5)
Expand All @@ -118,23 +131,24 @@ def test_mode_raising(self):
_ = self.model.mode()

def test_mode_with_product(self):
non_deterministic_node = list(self.model.nodes)[5]
non_deterministic_node = [node for node in self.model.nodes() if node.id == 5][0]

for descendant in nx.descendants(self.model, non_deterministic_node):
self.model.remove_node(descendant)

self.model.remove_node(non_deterministic_node)
new_node = LeafComponent(UniformDistribution(self.real2, portion.closed(2, 3)))
self.model.add_node(new_node)

nx.draw(self.model, with_labels=True)
plt.show()

new_edge = Edge(self.model.root, new_node)
self.model.add_edge(new_edge)


self.assertTrue(new_node in self.model.nodes())

mode, likelihood = self.model.mode()
self.assertEqual(likelihood, 0.5)
self.assertEqual(mode, [Event({self.real: portion.closed(0, 1),
self.real2: portion.closed(2, 3)})])


if __name__ == '__main__':
unittest.main()

0 comments on commit 37d895a

Please sign in to comment.