Skip to content

Commit

Permalink
Added function to check determinism
Browse files Browse the repository at this point in the history
  • Loading branch information
tomsch420 committed Mar 19, 2024
1 parent e8c0fcc commit 92dfb36
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Tuple, Iterable, TYPE_CHECKING

import networkx as nx
import numpy as np
import portion
from random_events.events import EncodedEvent, VariableMap, Event, ComplexEvent
from random_events.variables import Variable, Symbolic, Continuous
Expand Down Expand Up @@ -165,6 +166,12 @@ def mount(self, other: 'ProbabilisticCircuitMixin'):
def cache_result(self) -> bool:
return self._cache_result

def is_deterministic(self) -> bool:
"""
:return: Rather this node is deterministic or not.
"""
raise NotImplementedError

@cache_result.setter
def cache_result(self, value: bool):
"""
Expand Down Expand Up @@ -452,24 +459,15 @@ def plot_2d(self, sample_amount: int = 5000) -> List[go.Scatter]:
traces.append(go.Scatter(x=[expectation[self.variables[0]]], y=[expectation[self.variables[1]]],
mode="markers", name="Expectation"))

mode_trace = None
mode_traces = None
try:
x_mode_trace = []
y_mode_trace = []
modes, _ = self.mode()
for mode in modes.events:
for x_mode in mode[self.variables[0]]:
for y_mode in mode[self.variables[1]]:
x_mode_trace.extend([x_mode.lower, x_mode.upper, x_mode.upper, x_mode.lower, x_mode.lower, None])
y_mode_trace.extend([y_mode.lower, y_mode.lower, y_mode.upper, y_mode.upper, y_mode.lower, None])
x_mode_trace.extend([x_mode.lower, x_mode.upper, x_mode.upper, x_mode.lower, x_mode.lower, None])
y_mode_trace.extend([y_mode.lower, y_mode.lower, y_mode.upper, y_mode.upper, y_mode.lower, None])
mode_trace = go.Scatter(x=x_mode_trace, y=y_mode_trace, mode="lines+markers", name="Mode", fill="toself")
mode_traces = modes.plot()
except NotImplementedError:
...

if mode_trace:
traces.append(mode_trace)
if mode_traces:
traces.extend(mode_traces)

return traces

Expand Down Expand Up @@ -833,6 +831,19 @@ def normalize(self):
for subcircuit in self.subcircuits:
self.probabilistic_circuit.edges[self, subcircuit]["weight"] /= total_weight

def is_deterministic(self) -> bool:

# for every unique combination of subcircuits
for index, subcircuit in enumerate(self.subcircuits):
for subcircuit_ in self.subcircuits[index+1:]:

# if they intersect, the sum is not deterministic
if not subcircuit_.domain.intersection(subcircuit.domain).is_empty():
return False

# if none intersect, the subcircuit is deterministic
return True


class DeterministicSumUnit(SmoothSumUnit):
"""
Expand Down Expand Up @@ -892,6 +903,9 @@ def sub_circuit_index_of_sample(self, sample: Iterable) -> Optional[int]:
return index
return None

def is_deterministic(self) -> bool:
return True


class DecomposableProductUnit(ProbabilisticCircuitMixin):
"""
Expand Down Expand Up @@ -952,6 +966,9 @@ def _probability(self, event: EncodedEvent) -> float:

return result

def is_deterministic(self) -> bool:
return True

@cache_inference_result
def _mode(self) -> Tuple[ComplexEvent, float]:

Expand Down Expand Up @@ -1277,8 +1294,6 @@ def _from_json(cls, data: Dict[str, Any]) -> Self:

return result



def update_variables(self, new_variables: VariableMap):
"""
Update the variables of this unit and its descendants.
Expand Down Expand Up @@ -1326,3 +1341,9 @@ def plot(self):

def plotly_layout(self):
return self.root.plotly_layout()

def is_deterministic(self) -> bool:
"""
:return: Rather this circuit is deterministic or not.
"""
return all(node.is_deterministic() for node in self.nodes if isinstance(node, SmoothSumUnit))
15 changes: 12 additions & 3 deletions test/test_probabilistic_circuits/test_graph_circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,9 @@ def test_sample_not_equal(self):
same_samples = [s for s in samples if s == sample]
self.assertEqual(len(same_samples), 1)

def test_determinism(self):
self.assertTrue(self.model.is_deterministic())


class SumUnitTestCase(unittest.TestCase, ShowMixin):
x: Continuous = Continuous("x")
Expand Down Expand Up @@ -234,6 +237,9 @@ def test_sample_not_equal(self):
same_samples = [s for s in samples if s == sample]
self.assertEqual(len(same_samples), 1)

def test_determinism(self):
self.assertTrue(self.model.is_deterministic())


class MinimalGraphCircuitTestCase(unittest.TestCase, ShowMixin):
integer = Integer("integer", (1, 2, 4))
Expand Down Expand Up @@ -368,6 +374,9 @@ def test_sample_not_equal(self):
same_samples = [s for s in samples if s == sample]
self.assertEqual(len(same_samples), 1)

def test_determinism(self):
self.assertFalse(self.model.is_deterministic())


class FactorizationTestCase(unittest.TestCase, ShowMixin):
x: Continuous = Continuous("x")
Expand Down Expand Up @@ -480,7 +489,7 @@ def test_simplify(self):
def test_plot_2d(self):
traces = self.model.plot_2d()
assert len(traces) > 0
# go.Figure(traces, self.model.plotly_layout()).show()
# go.Figure(traces, self.model.plotly_layout()).show()


class ComplexMountedInferenceTestCase(unittest.TestCase, ShowMixin):
Expand Down Expand Up @@ -557,8 +566,8 @@ def setUp(self):

def test_plot_2d(self):
traces = self.model.plot()
assert len(traces) > 0
# go.Figure(traces, self.model.plotly_layout()).show()
# assert len(traces) > 0
go.Figure(traces, self.model.plotly_layout()).show()


class ComplexInferenceTestCase(unittest.TestCase):
Expand Down

0 comments on commit 92dfb36

Please sign in to comment.