diff --git a/src/probabilistic_model/probabilistic_circuit/nx/probabilistic_circuit.py b/src/probabilistic_model/probabilistic_circuit/nx/probabilistic_circuit.py index 3d69d58..e6c313d 100644 --- a/src/probabilistic_model/probabilistic_circuit/nx/probabilistic_circuit.py +++ b/src/probabilistic_model/probabilistic_circuit/nx/probabilistic_circuit.py @@ -955,9 +955,10 @@ def simplify(self): # if this has only one child if len(self.subcircuits) == 1: - # connect the parents of this with the children of this - for parent in self.probabilistic_circuit.predecessors(self): - self.probabilistic_circuit.add_edge(parent, self.subcircuits[0]) + # redirect every incoming edge to the child + incoming_edges = list(self.probabilistic_circuit.in_edges(self, data=True)) + for parent, _, data in incoming_edges: + self.probabilistic_circuit.add_edge(parent, self.subcircuits[0], **data) # remove this node self.probabilistic_circuit.remove_node(self) @@ -1080,7 +1081,7 @@ def sample(self, amount: int) -> np.array: def moment(self, order: OrderType, center: CenterType) -> MomentType: return self.root.moment(order, center) - def simplify(self): + def simplify(self) -> Self: """ Simplify the circuit inplace. """ @@ -1088,6 +1089,7 @@ def simplify(self): for layer in reversed(bfs_layers): for node in layer: node.simplify() + return self @property def support(self) -> Event: diff --git a/test/test_bayesian_network/test_bayesian_network.py b/test/test_bayesian_network/test_bayesian_network.py index 85a58a7..26419ad 100644 --- a/test/test_bayesian_network/test_bayesian_network.py +++ b/test/test_bayesian_network/test_bayesian_network.py @@ -168,7 +168,8 @@ def plot(self): plt.show() def test_as_probabilistic_circuit(self): - circuit = self.bayesian_network.as_probabilistic_circuit().simplify() + circuit = self.bayesian_network.as_probabilistic_circuit() + circuit.simplify() self.assertEqual(circuit.probability(circuit.universal_simple_event().as_composite_set()), 1.) event = SimpleEvent({self.x: Set(XEnum.ZERO, XEnum(1)), self.y: closed(1.5, 2)}) self.assertAlmostEqual(0.075, circuit.probability(event.as_composite_set())) diff --git a/test/test_jpt/test_jpt.py b/test/test_jpt/test_jpt.py index cfd83ae..9fab6bc 100644 --- a/test/test_jpt/test_jpt.py +++ b/test/test_jpt/test_jpt.py @@ -466,7 +466,8 @@ def test_to_bayesian_network(self): complex_event = SimpleEvent( {self.species: singleton(0), self.sl: closed(4.5, 5.5)}).as_composite_set() pc = bayesian_network.as_probabilistic_circuit() - pc_m = pc.marginal([v for v in pc.variables if not v.name.endswith(".latent")]).simplify() + pc_m = pc.marginal([v for v in pc.variables if not v.name.endswith(".latent")]) + pc_m = pc_m.simplify() self.assertEqual(pc_m.variables, SortedSet([self.pl, self.pw, self.sl, self.sw, self.species])) self.assertAlmostEqual(pc_m.probability(complex_event), 0.2333333) diff --git a/test/test_nx/test_probabilistic_circuit.py b/test/test_nx/test_probabilistic_circuit.py index 64bc64c..cf2e143 100644 --- a/test/test_nx/test_probabilistic_circuit.py +++ b/test/test_nx/test_probabilistic_circuit.py @@ -293,13 +293,10 @@ def setUp(self): next_model.mount_with_interaction_terms(model, transition_model) self.model = next_model - @unittest.skip("This test is not working since the caching removal.") def test_simplify(self): - simplified = self.model.probabilistic_circuit.simplify().root - print(self.model.probabilistic_circuit) - print(simplified.probabilistic_circuit) - self.assertEqual(len(simplified.probabilistic_circuit.nodes()), len(self.model.probabilistic_circuit.nodes)) - self.assertEqual(len(simplified.probabilistic_circuit.edges()), len(self.model.probabilistic_circuit.edges)) + simplified = self.model.probabilistic_circuit.__copy__().simplify() + self.assertEqual(len(simplified.nodes()), len(self.model.probabilistic_circuit.nodes)) + self.assertEqual(len(simplified.edges()), len(self.model.probabilistic_circuit.edges)) def test_sample_not_equal(self): samples = self.model.sample(10)