Skip to content

Commit

Permalink
Test case is more expressive now
Browse files Browse the repository at this point in the history
  • Loading branch information
tomsch420 committed Feb 23, 2024
1 parent dfbc4c2 commit ea08206
Showing 1 changed file with 20 additions and 4 deletions.
24 changes: 20 additions & 4 deletions test/test_jpt/test_jpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import networkx as nx
import numpy as np
import pandas as pd
import portion
import random_events
import sklearn.datasets
from jpt import infer_from_dataframe as old_infer_from_dataframe
Expand Down Expand Up @@ -389,6 +390,7 @@ def setUpClass(cls):

model_sl_sw = JPT(variables, min_samples_leaf=0.4, features=[cls.sl, cls.sw], targets=variables)
model_sl_sw.fit(df)

cls.model_sl_sw = model_sl_sw.marginal([cls.sl, cls.sw])

model_pl_pw = JPT(variables, min_samples_leaf=0.4, features=[cls.pl, cls.pw], targets=variables)
Expand Down Expand Up @@ -444,7 +446,12 @@ def test_to_bayesian_network(self):
root = DiscreteDistribution(self.species_latent_variable, self.model_species.weights)
self.assertEqual(root.weights, [1 / 3] * 3)
bayesian_network.add_node(root)
self.assertEqual(bayesian_network.probability(Event()), 1.)

p_species = ConditionalProbabilisticCircuit(self.model_species.variables)
p_species.from_unit(self.model_species)
bayesian_network.add_node(p_species)
bayesian_network.add_edge(root, p_species)


# mount the interaction term with the latent variable of the sepal distribution
p_sepal_species = ConditionalProbabilityTable(self.sepal_latent_variable)
Expand Down Expand Up @@ -479,7 +486,16 @@ def test_to_bayesian_network(self):
self.assertEqual(bayesian_network.probability(Event()), 1.)
self.assertAlmostEqual(bayesian_network.as_probabilistic_circuit().probability(Event()), 1)

p_specie_1 = Event({self.species_latent_variable: 0})
self.assertAlmostEqual(bayesian_network.probability(p_specie_1), 1 / 3)
self.assertAlmostEqual(bayesian_network.as_probabilistic_circuit().probability(p_specie_1), 1 / 3)
e_species_1 = Event({self.species: 0})
bn_p_species_1 = bayesian_network.probability(e_species_1)
self.assertAlmostEqual(bn_p_species_1, 1 / 3)
self.assertAlmostEqual(bayesian_network.as_probabilistic_circuit().probability(e_species_1), 1 / 3)

complex_event = Event({self.species: 0,
self.sl: portion.closed(4.5, 5.5)})
pc = bayesian_network.as_probabilistic_circuit()
pc_m = pc.marginal([v for v in pc.variables if not v.name.endswith(".latent")]).simplify()
self.assertEqual(pc_m.variables, (self.pl, self.pw, self.sl, self.sw, self.species))

self.assertAlmostEqual(pc_m.probability(complex_event), 0.2333333)
self.assertLess(len(pc_m.weighted_edges), math.prod([len(v.domain) for v in bayesian_network.variables]))

0 comments on commit ea08206

Please sign in to comment.