From 92dfb3629a735ac813b4577285288eabbe71be6d Mon Sep 17 00:00:00 2001 From: Tom Schierenbeck Date: Tue, 19 Mar 2024 17:25:04 +0100 Subject: [PATCH] Added function to check determinism --- .../probabilistic_circuit.py | 51 +++++++++++++------ .../test_graph_circuit.py | 15 ++++-- 2 files changed, 48 insertions(+), 18 deletions(-) diff --git a/src/probabilistic_model/probabilistic_circuit/probabilistic_circuit.py b/src/probabilistic_model/probabilistic_circuit/probabilistic_circuit.py index 018400b..b3ef32a 100644 --- a/src/probabilistic_model/probabilistic_circuit/probabilistic_circuit.py +++ b/src/probabilistic_model/probabilistic_circuit/probabilistic_circuit.py @@ -5,6 +5,7 @@ from typing import Tuple, Iterable, TYPE_CHECKING import networkx as nx +import numpy as np import portion from random_events.events import EncodedEvent, VariableMap, Event, ComplexEvent from random_events.variables import Variable, Symbolic, Continuous @@ -165,6 +166,12 @@ def mount(self, other: 'ProbabilisticCircuitMixin'): def cache_result(self) -> bool: return self._cache_result + def is_deterministic(self) -> bool: + """ + :return: Rather this node is deterministic or not. + """ + raise NotImplementedError + @cache_result.setter def cache_result(self, value: bool): """ @@ -452,24 +459,15 @@ def plot_2d(self, sample_amount: int = 5000) -> List[go.Scatter]: traces.append(go.Scatter(x=[expectation[self.variables[0]]], y=[expectation[self.variables[1]]], mode="markers", name="Expectation")) - mode_trace = None + mode_traces = None try: - x_mode_trace = [] - y_mode_trace = [] modes, _ = self.mode() - for mode in modes.events: - for x_mode in mode[self.variables[0]]: - for y_mode in mode[self.variables[1]]: - x_mode_trace.extend([x_mode.lower, x_mode.upper, x_mode.upper, x_mode.lower, x_mode.lower, None]) - y_mode_trace.extend([y_mode.lower, y_mode.lower, y_mode.upper, y_mode.upper, y_mode.lower, None]) - x_mode_trace.extend([x_mode.lower, x_mode.upper, x_mode.upper, x_mode.lower, x_mode.lower, None]) - y_mode_trace.extend([y_mode.lower, y_mode.lower, y_mode.upper, y_mode.upper, y_mode.lower, None]) - mode_trace = go.Scatter(x=x_mode_trace, y=y_mode_trace, mode="lines+markers", name="Mode", fill="toself") + mode_traces = modes.plot() except NotImplementedError: ... - if mode_trace: - traces.append(mode_trace) + if mode_traces: + traces.extend(mode_traces) return traces @@ -833,6 +831,19 @@ def normalize(self): for subcircuit in self.subcircuits: self.probabilistic_circuit.edges[self, subcircuit]["weight"] /= total_weight + def is_deterministic(self) -> bool: + + # for every unique combination of subcircuits + for index, subcircuit in enumerate(self.subcircuits): + for subcircuit_ in self.subcircuits[index+1:]: + + # if they intersect, the sum is not deterministic + if not subcircuit_.domain.intersection(subcircuit.domain).is_empty(): + return False + + # if none intersect, the subcircuit is deterministic + return True + class DeterministicSumUnit(SmoothSumUnit): """ @@ -892,6 +903,9 @@ def sub_circuit_index_of_sample(self, sample: Iterable) -> Optional[int]: return index return None + def is_deterministic(self) -> bool: + return True + class DecomposableProductUnit(ProbabilisticCircuitMixin): """ @@ -952,6 +966,9 @@ def _probability(self, event: EncodedEvent) -> float: return result + def is_deterministic(self) -> bool: + return True + @cache_inference_result def _mode(self) -> Tuple[ComplexEvent, float]: @@ -1277,8 +1294,6 @@ def _from_json(cls, data: Dict[str, Any]) -> Self: return result - - def update_variables(self, new_variables: VariableMap): """ Update the variables of this unit and its descendants. @@ -1326,3 +1341,9 @@ def plot(self): def plotly_layout(self): return self.root.plotly_layout() + + def is_deterministic(self) -> bool: + """ + :return: Rather this circuit is deterministic or not. + """ + return all(node.is_deterministic() for node in self.nodes if isinstance(node, SmoothSumUnit)) diff --git a/test/test_probabilistic_circuits/test_graph_circuit.py b/test/test_probabilistic_circuits/test_graph_circuit.py index 8b9e30b..3ed1ae1 100644 --- a/test/test_probabilistic_circuits/test_graph_circuit.py +++ b/test/test_probabilistic_circuits/test_graph_circuit.py @@ -124,6 +124,9 @@ def test_sample_not_equal(self): same_samples = [s for s in samples if s == sample] self.assertEqual(len(same_samples), 1) + def test_determinism(self): + self.assertTrue(self.model.is_deterministic()) + class SumUnitTestCase(unittest.TestCase, ShowMixin): x: Continuous = Continuous("x") @@ -234,6 +237,9 @@ def test_sample_not_equal(self): same_samples = [s for s in samples if s == sample] self.assertEqual(len(same_samples), 1) + def test_determinism(self): + self.assertTrue(self.model.is_deterministic()) + class MinimalGraphCircuitTestCase(unittest.TestCase, ShowMixin): integer = Integer("integer", (1, 2, 4)) @@ -368,6 +374,9 @@ def test_sample_not_equal(self): same_samples = [s for s in samples if s == sample] self.assertEqual(len(same_samples), 1) + def test_determinism(self): + self.assertFalse(self.model.is_deterministic()) + class FactorizationTestCase(unittest.TestCase, ShowMixin): x: Continuous = Continuous("x") @@ -480,7 +489,7 @@ def test_simplify(self): def test_plot_2d(self): traces = self.model.plot_2d() assert len(traces) > 0 - # go.Figure(traces, self.model.plotly_layout()).show() + # go.Figure(traces, self.model.plotly_layout()).show() class ComplexMountedInferenceTestCase(unittest.TestCase, ShowMixin): @@ -557,8 +566,8 @@ def setUp(self): def test_plot_2d(self): traces = self.model.plot() - assert len(traces) > 0 - # go.Figure(traces, self.model.plotly_layout()).show() + # assert len(traces) > 0 + go.Figure(traces, self.model.plotly_layout()).show() class ComplexInferenceTestCase(unittest.TestCase):