From 65a4a5e51446266135af293e1657e2059d614af3 Mon Sep 17 00:00:00 2001 From: Tom Schierenbeck Date: Fri, 8 Mar 2024 10:24:53 +0100 Subject: [PATCH] Circuits are now more efficient to serialize and give the correct result when containing cycles. --- src/probabilistic_model/__init__.py | 2 +- .../distributions/distributions.py | 3 ++ .../probabilistic_circuit.py | 38 +++++++++++++++++-- .../test_graph_circuit.py | 10 +++++ 4 files changed, 49 insertions(+), 4 deletions(-) diff --git a/src/probabilistic_model/__init__.py b/src/probabilistic_model/__init__.py index c90ab1b..1549c12 100644 --- a/src/probabilistic_model/__init__.py +++ b/src/probabilistic_model/__init__.py @@ -1 +1 @@ -__version__ = "3.3.5" +__version__ = "3.3.6" diff --git a/src/probabilistic_model/probabilistic_circuit/distributions/distributions.py b/src/probabilistic_model/probabilistic_circuit/distributions/distributions.py index c60d871..8651ad8 100644 --- a/src/probabilistic_model/probabilistic_circuit/distributions/distributions.py +++ b/src/probabilistic_model/probabilistic_circuit/distributions/distributions.py @@ -37,6 +37,9 @@ def __hash__(self): def simplify(self) -> Self: return self.__copy__() + def empty_copy(self) -> Self: + return self.__copy__() + class ContinuousDistribution(UnivariateDistribution, PMContinuousDistribution, ProbabilisticCircuitMixin): diff --git a/src/probabilistic_model/probabilistic_circuit/probabilistic_circuit.py b/src/probabilistic_model/probabilistic_circuit/probabilistic_circuit.py index 30700ce..044d9f5 100644 --- a/src/probabilistic_model/probabilistic_circuit/probabilistic_circuit.py +++ b/src/probabilistic_model/probabilistic_circuit/probabilistic_circuit.py @@ -1139,12 +1139,44 @@ def __eq__(self, other: 'ProbabilisticCircuit'): return self.root == other.root def to_json(self) -> Dict[str, Any]: - return {**super().to_json(), "root": self.root.to_json()} + + # get super result + result = super().to_json() + + hash_to_node_map = dict() + + for node in 
self.nodes: + node_json = node.empty_copy().to_json() + hash_to_node_map[hash(node)] = node_json + + unweighted_edges = [(hash(source), hash(target)) for source, target + in self.unweighted_edges] + weighted_edges = [(hash(source), hash(target), weight) + for source, target, weight in self.weighted_edges] + result["hash_to_node_map"] = hash_to_node_map + result["unweighted_edges"] = unweighted_edges + result["weighted_edges"] = weighted_edges + return result @classmethod def _from_json(cls, data: Dict[str, Any]) -> Self: - root = ProbabilisticCircuitMixin.from_json(data["root"]) - return root.probabilistic_circuit + result = ProbabilisticCircuit() + hash_remap: Dict[int, ProbabilisticCircuitMixin] = dict() + + for hash_, node_data in data["hash_to_node_map"].items(): + node = ProbabilisticCircuitMixin.from_json(node_data) + hash_remap[int(hash_)] = node  # JSON round-trips turn object keys into str; edges keep int hashes + result.add_node(node) + + for source_hash, target_hash in data["unweighted_edges"]: + result.add_edge(hash_remap[source_hash], hash_remap[target_hash]) + + for source_hash, target_hash, weight in data["weighted_edges"]: + result.add_edge(hash_remap[source_hash], hash_remap[target_hash], weight=weight) + + return result + + def update_variables(self, new_variables: VariableMap): """ diff --git a/test/test_probabilistic_circuits/test_graph_circuit.py b/test/test_probabilistic_circuits/test_graph_circuit.py index 983c1fe..3ae91b1 100644 --- a/test/test_probabilistic_circuits/test_graph_circuit.py +++ b/test/test_probabilistic_circuits/test_graph_circuit.py @@ -510,6 +510,16 @@ def test_sample_not_equal(self): same_samples = [s for s in samples if s == sample] self.assertEqual(len(same_samples), 1) + def test_serialization(self): + model = self.model.probabilistic_circuit + serialized_model = model.to_json() + deserialized_model = ProbabilisticCircuit.from_json(serialized_model) + self.assertIsInstance(deserialized_model, ProbabilisticCircuit) + self.assertEqual(len(model.nodes), len(deserialized_model.nodes)) + 
self.assertEqual(len(model.edges), len(deserialized_model.edges)) + event = Event({self.x: portion.closed(-1, 1), self.y: portion.closed(-1, 1)}) + self.assertEqual(model.probability(event), deserialized_model.probability(event)) + class NormalizationTestCase(unittest.TestCase): x: Continuous = Continuous("x")