Skip to content

Commit

Permalink
Circuits are now more efficient to serialize and give the correct res…
Browse files Browse the repository at this point in the history
…ult when containing cycles.
  • Loading branch information
tomsch420 committed Mar 8, 2024
1 parent 417a809 commit 65a4a5e
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/probabilistic_model/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "3.3.5"
__version__ = "3.3.6"
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ def __hash__(self):
def simplify(self) -> Self:
return self.__copy__()

def empty_copy(self) -> Self:
return self.__copy__()


class ContinuousDistribution(UnivariateDistribution, PMContinuousDistribution, ProbabilisticCircuitMixin):

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1139,12 +1139,44 @@ def __eq__(self, other: 'ProbabilisticCircuit'):
return self.root == other.root

def to_json(self) -> Dict[str, Any]:
return {**super().to_json(), "root": self.root.to_json()}

# get super result
result = super().to_json()

hash_to_node_map = dict()

for node in self.nodes:
node_json = node.empty_copy().to_json()
hash_to_node_map[hash(node)] = node_json

unweighted_edges = [(hash(source), hash(target)) for source, target
in self.unweighted_edges]
weighted_edges = [(hash(source), hash(target), weight)
for source, target, weight in self.weighted_edges]
result["hash_to_node_map"] = hash_to_node_map
result["unweighted_edges"] = unweighted_edges
result["weighted_edges"] = weighted_edges
return result

@classmethod
def _from_json(cls, data: Dict[str, Any]) -> Self:
root = ProbabilisticCircuitMixin.from_json(data["root"])
return root.probabilistic_circuit
result = ProbabilisticCircuit()
hash_remap: Dict[int, ProbabilisticCircuitMixin] = dict()

for hash_, node_data in data["hash_to_node_map"].items():
node = ProbabilisticCircuitMixin.from_json(node_data)
hash_remap[hash_] = node
result.add_node(node)

for source_hash, target_hash in data["unweighted_edges"]:
result.add_edge(hash_remap[source_hash], hash_remap[target_hash])

for source_hash, target_hash, weight in data["weighted_edges"]:
result.add_edge(hash_remap[source_hash], hash_remap[target_hash], weight=weight)

return result



def update_variables(self, new_variables: VariableMap):
"""
Expand Down
10 changes: 10 additions & 0 deletions test/test_probabilistic_circuits/test_graph_circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,16 @@ def test_sample_not_equal(self):
same_samples = [s for s in samples if s == sample]
self.assertEqual(len(same_samples), 1)

def test_serialization(self):
model = self.model.probabilistic_circuit
serialized_model = model.to_json()
deserialized_model = ProbabilisticCircuit.from_json(serialized_model)
self.assertIsInstance(deserialized_model, ProbabilisticCircuit)
self.assertEqual(len(model.nodes), len(deserialized_model.nodes))
self.assertEqual(len(model.edges), len(deserialized_model.edges))
event = Event({self.x: portion.closed(-1, 1), self.y: portion.closed(-1, 1)})
self.assertEqual(model.probability(event), deserialized_model.probability(event))


class NormalizationTestCase(unittest.TestCase):
x: Continuous = Continuous("x")
Expand Down

0 comments on commit 65a4a5e

Please sign in to comment.