Skip to content

Commit

Permalink
Fixed bug in simplify
Browse files Browse the repository at this point in the history
  • Loading branch information
tomsch420 committed Oct 18, 2024
1 parent fa3f4db commit 13ce073
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -955,9 +955,10 @@ def simplify(self):

# if this has only one child
if len(self.subcircuits) == 1:
# connect the parents of this with the children of this
for parent in self.probabilistic_circuit.predecessors(self):
self.probabilistic_circuit.add_edge(parent, self.subcircuits[0])
# redirect every incoming edge to the child
incoming_edges = list(self.probabilistic_circuit.in_edges(self, data=True))
for parent, _, data in incoming_edges:
self.probabilistic_circuit.add_edge(parent, self.subcircuits[0], **data)

# remove this node
self.probabilistic_circuit.remove_node(self)
Expand Down Expand Up @@ -1080,14 +1081,15 @@ def sample(self, amount: int) -> np.array:
def moment(self, order: OrderType, center: CenterType) -> MomentType:
return self.root.moment(order, center)

def simplify(self):
def simplify(self) -> Self:
"""
Simplify the circuit inplace.
"""
bfs_layers = list(nx.bfs_layers(self, self.root))
for layer in reversed(bfs_layers):
for node in layer:
node.simplify()
return self

@property
def support(self) -> Event:
Expand Down
3 changes: 2 additions & 1 deletion test/test_bayesian_network/test_bayesian_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,8 @@ def plot(self):
plt.show()

def test_as_probabilistic_circuit(self):
circuit = self.bayesian_network.as_probabilistic_circuit().simplify()
circuit = self.bayesian_network.as_probabilistic_circuit()
circuit.simplify()
self.assertEqual(circuit.probability(circuit.universal_simple_event().as_composite_set()), 1.)
event = SimpleEvent({self.x: Set(XEnum.ZERO, XEnum(1)), self.y: closed(1.5, 2)})
self.assertAlmostEqual(0.075, circuit.probability(event.as_composite_set()))
Expand Down
3 changes: 2 additions & 1 deletion test/test_jpt/test_jpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,8 @@ def test_to_bayesian_network(self):
complex_event = SimpleEvent(
{self.species: singleton(0), self.sl: closed(4.5, 5.5)}).as_composite_set()
pc = bayesian_network.as_probabilistic_circuit()
pc_m = pc.marginal([v for v in pc.variables if not v.name.endswith(".latent")]).simplify()
pc_m = pc.marginal([v for v in pc.variables if not v.name.endswith(".latent")])
pc_m = pc_m.simplify()
self.assertEqual(pc_m.variables, SortedSet([self.pl, self.pw, self.sl, self.sw, self.species]))

self.assertAlmostEqual(pc_m.probability(complex_event), 0.2333333)
Expand Down
9 changes: 3 additions & 6 deletions test/test_nx/test_probabilistic_circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,13 +293,10 @@ def setUp(self):
next_model.mount_with_interaction_terms(model, transition_model)
self.model = next_model

@unittest.skip("This test is not working since the caching removal.")
def test_simplify(self):
simplified = self.model.probabilistic_circuit.simplify().root
print(self.model.probabilistic_circuit)
print(simplified.probabilistic_circuit)
self.assertEqual(len(simplified.probabilistic_circuit.nodes()), len(self.model.probabilistic_circuit.nodes))
self.assertEqual(len(simplified.probabilistic_circuit.edges()), len(self.model.probabilistic_circuit.edges))
simplified = self.model.probabilistic_circuit.__copy__().simplify()
self.assertEqual(len(simplified.nodes()), len(self.model.probabilistic_circuit.nodes))
self.assertEqual(len(simplified.edges()), len(self.model.probabilistic_circuit.edges))

def test_sample_not_equal(self):
samples = self.model.sample(10)
Expand Down

0 comments on commit 13ce073

Please sign in to comment.