diff --git a/src/probabilistic_model/__init__.py b/src/probabilistic_model/__init__.py index 80014d0..8a6dd7c 100644 --- a/src/probabilistic_model/__init__.py +++ b/src/probabilistic_model/__init__.py @@ -1 +1 @@ -__version__ = "3.3.3" +__version__ = "3.3.4" diff --git a/src/probabilistic_model/probabilistic_circuit/probabilistic_circuit.py b/src/probabilistic_model/probabilistic_circuit/probabilistic_circuit.py index fe959f3..832064a 100644 --- a/src/probabilistic_model/probabilistic_circuit/probabilistic_circuit.py +++ b/src/probabilistic_model/probabilistic_circuit/probabilistic_circuit.py @@ -724,6 +724,14 @@ def simplify(self) -> Self: return result + def normalize(self): + """ + Normalize the weights of the subcircuits such that they sum up to 1 inplace. + """ + total_weight = sum([weight for weight, _ in self.weighted_subcircuits]) + for subcircuit in self.subcircuits: + self.probabilistic_circuit.edges[self, subcircuit]["weight"] /= total_weight + class DeterministicSumUnit(SmoothSumUnit): """ diff --git a/test/test_probabilistic_circuits/test_graph_circuit.py b/test/test_probabilistic_circuits/test_graph_circuit.py index 582cdea..dc55f74 100644 --- a/test/test_probabilistic_circuits/test_graph_circuit.py +++ b/test/test_probabilistic_circuits/test_graph_circuit.py @@ -486,5 +486,20 @@ def test_simplify(self): self.assertEqual(len(simplified.probabilistic_circuit.nodes()), len(self.model.probabilistic_circuit.nodes)) self.assertEqual(len(simplified.probabilistic_circuit.edges()), len(self.model.probabilistic_circuit.edges)) + +class NormalizationTestCase(unittest.TestCase): + x: Continuous = Continuous("x") + + def test_normalization(self): + u1 = UniformDistribution(self.x, portion.closed(0, 1)) + u2 = UniformDistribution(self.x, portion.closed(3, 4)) + sum_unit = DeterministicSumUnit() + sum_unit.add_subcircuit(u1, 0.5) + sum_unit.add_subcircuit(u2, 0.3) + sum_unit.normalize() + self.assertAlmostEqual(sum_unit.weights[0], 0.5/0.8) + self.assertAlmostEqual(sum_unit.weights[1], 0.3/0.8) + + if __name__ == '__main__': unittest.main()