Skip to content

Commit

Permalink
Fixed test for sampling.
Browse files Browse the repository at this point in the history
  • Loading branch information
tomsch420 committed Feb 6, 2024
1 parent f3d0a86 commit 6de04d3
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 136 deletions.
204 changes: 74 additions & 130 deletions examples/template_modelling.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion src/probabilistic_model/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "3.0.3"
__version__ = "3.0.4"
9 changes: 4 additions & 5 deletions test/test_probabilistic_circuits/test_graph_circuit.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import random
import unittest

import numpy as np
Expand Down Expand Up @@ -344,9 +345,7 @@ def test_serialization(self):
self.assertEqual(self.model, deserialized)

def test_update_variables(self):
print(self.model.variables)
self.model.update_variables(VariableMap({self.real: self.real3}))
print(self.model.variables)
self.assertEqual(self.model.variables, (self.real2, self.real3))


Expand Down Expand Up @@ -411,6 +410,7 @@ class MountedInferenceTestCase(unittest.TestCase, ShowMixin):
model: DeterministicSumUnit

def setUp(self):
random.seed(69)
model = DeterministicSumUnit()
model.add_subcircuit(UniformDistribution(self.x, portion.closed(-1.5, -0.5)), 0.5)
model.add_subcircuit(UniformDistribution(self.x, portion.closed(0.5, 1.5)), 0.5)
Expand All @@ -432,12 +432,11 @@ def test_sample_from_uniform(self):
samples = leaf.sample(2)
self.assertNotEqual(samples[0], samples[1])

@unittest.skip("Sampling multiple things with undirected cycles is weird")
def test_sample(self):
# self.show(self.model)
samples: List = self.model.probabilistic_circuit.sample(2)
samples: List = self.model.sample(2)
self.assertEqual(len(samples), 2)
print(samples)

self.assertNotEqual(samples[0], samples[1])

def test_samples_in_sequence(self):
Expand Down

0 comments on commit 6de04d3

Please sign in to comment.