Skip to content

Commit

Permalink
Added normalization for sum units.
Browse files Browse the repository at this point in the history
  • Loading branch information
tomsch420 committed Mar 6, 2024
1 parent 235304e commit ee64a62
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/probabilistic_model/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "3.3.3"
__version__ = "3.3.4"
Original file line number Diff line number Diff line change
Expand Up @@ -724,6 +724,14 @@ def simplify(self) -> Self:

return result

def normalize(self):
"""
Normalize the weights of the subcircuits such that they sum up to 1 inplace.
"""
total_weight = sum([weight for weight, _ in self.weighted_subcircuits])
for subcircuit in self.subcircuits:
self.probabilistic_circuit.edges[self, subcircuit]["weight"] /= total_weight


class DeterministicSumUnit(SmoothSumUnit):
"""
Expand Down
15 changes: 15 additions & 0 deletions test/test_probabilistic_circuits/test_graph_circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,5 +486,20 @@ def test_simplify(self):
self.assertEqual(len(simplified.probabilistic_circuit.nodes()), len(self.model.probabilistic_circuit.nodes))
self.assertEqual(len(simplified.probabilistic_circuit.edges()), len(self.model.probabilistic_circuit.edges))


class NormalizationTestCase(unittest.TestCase):
x: Continuous = Continuous("x")

def test_normalization(self):
u1 = UniformDistribution(self.x, portion.closed(0, 1))
u2 = UniformDistribution(self.x, portion.closed(3, 4))
sum_unit = DeterministicSumUnit()
sum_unit.add_subcircuit(u1, 0.5)
sum_unit.add_subcircuit(u2, 0.3)
sum_unit.normalize()
self.assertAlmostEqual(sum_unit.weights[0], 0.5/0.8)
self.assertAlmostEqual(sum_unit.weights[1], 0.3/0.8)


if __name__ == '__main__':
unittest.main()

0 comments on commit ee64a62

Please sign in to comment.