Skip to content

Commit

Permalink
Added simplify
Browse files Browse the repository at this point in the history
  • Loading branch information
tomsch420 committed Oct 18, 2024
1 parent ffa2aa2 commit fa3f4db
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,8 @@ def log_conditional_of_simple_event(self, event: SimpleEvent,

return result, log_prob

# @cache_inference_result
def simplify(self) -> Self:
return self.__copy__()
...

def empty_copy(self) -> Self:
return self.__copy__()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,6 @@ def __hash__(self):
return id(self)

def __eq__(self, other):
# TODO isnt hash equal here better?
return isinstance(other, self.__class__) and self.subcircuits == other.subcircuits

def __copy__(self):
Expand All @@ -345,8 +344,8 @@ def empty_copy(self) -> Self:

def simplify(self) -> Self:
"""
Simplify the circuit by removing nodes and redirected edges that have no impact.
Essentially, this method transform the circuit into an alternating order of sum and product units.
Simplify the circuit by removing nodes and redirected edges that have no impact in-place.
Essentially, this method transforms the circuit into an alternating order of sum and product units.
:return: The simplified circuit.
"""
Expand Down Expand Up @@ -669,45 +668,45 @@ def mount_from_bayesian_network(self, other: 'SumUnit'):
proxy_product_node.add_subcircuit(own_subcircuit)
proxy_product_node.add_subcircuit(other_subcircuit)

@cache_inference_result
def simplify(self) -> Self:
# TODO check with multiple parents
def simplify(self):

# if this has only one child
if len(self.subcircuits) == 1:
return self.subcircuits[0].simplify()

# create empty copy
result = self.empty_copy()
# redirect every incoming edge to the child
incoming_edges = list(self.probabilistic_circuit.in_edges(self, data=True))
for parent, _, data in incoming_edges:
self.probabilistic_circuit.add_edge(parent, self.subcircuits[0], **data)

# remove this node
self.probabilistic_circuit.remove_node(self)

return

# for every subcircuit
for weight, subcircuit in self.weighted_subcircuits:

# if the weight is 0, skip this subcircuit
if weight == 0:
continue

# simplify the subcircuit
simplified_subcircuit = subcircuit.simplify()
# remove the edge
self.probabilistic_circuit.remove_edge(self, subcircuit)

# if the simplified subcircuit is of the same type as this
if type(simplified_subcircuit) is type(self):
if type(subcircuit) is type(self):

# type hinting
simplified_subcircuit: Self
subcircuit: Self

# mount the children of that circuit directly
for sub_weight, sub_subcircuit in simplified_subcircuit.weighted_subcircuits:
for sub_weight, sub_subcircuit in subcircuit.weighted_subcircuits:
new_weight = sub_weight * weight
if new_weight > 0:
result.add_subcircuit(sub_subcircuit, new_weight)

# if this cannot be simplified
else:
# add an edge to that subcircuit
self.add_subcircuit(sub_subcircuit, new_weight, mount=False)

# mount the simplified subcircuit
result.add_subcircuit(simplified_subcircuit, weight)
# remove the old node
self.probabilistic_circuit.remove_node(subcircuit)

return result

def normalize(self):
"""
Expand Down Expand Up @@ -952,37 +951,31 @@ def __copy__(self):
result.add_subcircuit(copied_subcircuit)
return result

@cache_inference_result
def simplify(self) -> Self:
def simplify(self):

# if this has only one child
if len(self.subcircuits) == 1:
return self.subcircuits[0].simplify()
# connect the parents of this with the children of this
for parent in self.probabilistic_circuit.predecessors(self):
self.probabilistic_circuit.add_edge(parent, self.subcircuits[0])

# create empty copy
result = self.empty_copy()
# remove this node
self.probabilistic_circuit.remove_node(self)
return

# for every subcircuit
for subcircuit in self.subcircuits:

# simplify the subcircuit
simplified_subcircuit = subcircuit.simplify()

# if the simplified subcircuit is of the same type as this
if type(simplified_subcircuit) is type(self):
if type(subcircuit) is type(self):

# type hinting
simplified_subcircuit: Self
subcircuit: Self

# mount the children of that circuit directly
for sub_subcircuit in simplified_subcircuit.subcircuits:
result.add_subcircuit(sub_subcircuit)
for sub_subcircuit in subcircuit.subcircuits:
subcircuit.add_subcircuit(sub_subcircuit, mount=False)

# if this cannot be simplified
else:
# mount the simplified subcircuit
result.add_subcircuit(simplified_subcircuit)

return result


class ProbabilisticCircuit(ProbabilisticModel, nx.DiGraph, SubclassJSONSerializer):
Expand Down Expand Up @@ -1087,9 +1080,14 @@ def sample(self, amount: int) -> np.array:
def moment(self, order: OrderType, center: CenterType) -> MomentType:
return self.root.moment(order, center)

@graph_inference_caching_wrapper
def simplify(self) -> Self:
return self.root.simplify().probabilistic_circuit
def simplify(self):
"""
Simplify the circuit inplace.
"""
bfs_layers = list(nx.bfs_layers(self, self.root))
for layer in reversed(bfs_layers):
for node in layer:
node.simplify()

@property
def support(self) -> Event:
Expand All @@ -1109,6 +1107,18 @@ def is_decomposable(self) -> bool:
def __eq__(self, other: 'ProbabilisticCircuit'):
return self.root == other.root

def __copy__(self):
result = self.__class__()
new_node_map = {node: node.__copy__() for node in self.nodes}
result.add_nodes_from(new_node_map.values())
new_unweighted_edges = [(new_node_map[source], new_node_map[target]) for source, target in self.unweighted_edges]
new_weighted_edges = [(new_node_map[source], new_node_map[target], weight)
for source, target, weight in self.weighted_edges]
result.add_edges_from(new_unweighted_edges)
result.add_weighted_edges_from(new_weighted_edges)
return result


def to_json(self) -> Dict[str, Any]:

# get super result
Expand Down
21 changes: 13 additions & 8 deletions test/test_nx/test_probabilistic_circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,11 +269,6 @@ def test_plot_non_deterministic(self):
self.assertGreater(len(traces), 0)
# go.Figure(mixture.plot(), mixture.plotly_layout()).show()

def test_simplify(self):
simplified = self.model.simplify()
self.assertEqual(len(simplified.probabilistic_circuit.nodes()), 7)
self.assertEqual(len(simplified.probabilistic_circuit.edges()), 6)


class ComplexMountedInferenceTestCase(unittest.TestCase):
x: Continuous = Continuous("x")
Expand Down Expand Up @@ -556,9 +551,19 @@ def test_conditioning_with_multiple_parents_no_orphans(self):
self.assertEqual(len(list(conditional.edges())), 6)
self.assertEqual(event.simple_sets[0][self.z], conditional.support.simple_sets[0][self.z])

def test_conditioning_with_orphans(self):
self.model.plot_structure()
plt.show()
def test_copy(self):
copy = self.model.__copy__()
self.assertEqual(self.model, copy)
copy.root.add_subcircuit(UniformDistribution(self.x, SimpleInterval(0, 1)), 0.5)
self.assertNotEqual(self.model, copy)

def test_simplify(self):
event = SimpleEvent({self.y: open_closed(0, 0.25) | closed(0.5, 0.75)}).as_composite_set()
conditional, prob = self.model.conditional(event)
conditional.simplify()
self.assertEqual(len(list(conditional.nodes())), 6)
self.assertEqual(len(list(conditional.edges())), 5)



class SmallCircuitTestCast(unittest.TestCase):
Expand Down

0 comments on commit fa3f4db

Please sign in to comment.