Skip to content

Commit

Permalink
Integration tests with JPTs now working.
Browse files Browse the repository at this point in the history
  • Loading branch information
tomsch420 committed Feb 23, 2024
1 parent a214d74 commit dfbc4c2
Show file tree
Hide file tree
Showing 6 changed files with 157 additions and 14 deletions.
2 changes: 1 addition & 1 deletion src/probabilistic_model/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "3.2.9"
__version__ = "3.3.1"
36 changes: 34 additions & 2 deletions src/probabilistic_model/bayesian_network/distributions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
from matplotlib import pyplot as plt
from random_events.events import Event, EncodedEvent, VariableMap
from typing_extensions import Tuple, Dict, Iterable, List, Type, Union, Optional
from typing_extensions import Tuple, Dict, Iterable, List, Type, Union, Optional, Self

from .bayesian_network import BayesianNetworkMixin
from ..probabilistic_model import ProbabilisticModel
Expand All @@ -13,6 +13,7 @@
from ..probabilistic_circuit.distributions import (SymbolicDistribution as PCSymbolicDistribution,
IntegerDistribution as PCIntegerDistribution,
DiscreteDistribution as PCDiscreteDistribution)
from ..distributions.multinomial import MultinomialDistribution


class DiscreteDistribution(BayesianNetworkMixin, PCDiscreteDistribution):
Expand Down Expand Up @@ -171,6 +172,28 @@ def interaction_term(self, node_latent_variable: Discrete, parent_latent_variabl
self.parent.variable: parent_latent_variable}))
return interaction_term

def from_multinomial_distribution(self, distribution: MultinomialDistribution) -> Self:
    """
    Fill this conditional probability table from a two-variable multinomial distribution.

    For every value in the parent variable's domain the distribution is conditioned on that
    value, marginalized onto this node's variables, normalized, and stored as a discrete
    distribution under the event's entry for the parent variable.

    :param distribution: A multinomial distribution over exactly two variables, one of which
        is this node's variable. The other one is treated as the parent variable.
    :return: This table, after its conditional distributions have been set.
    """
    assert len(distribution.variables) == 2
    assert self.variable in distribution.variables

    # The parent is whichever of the two variables is not this node's own variable.
    candidate = distribution.variables[0]
    if candidate != self.variable:
        parent_variable = candidate
    else:
        parent_variable = distribution.variables[1]

    for parent_value in parent_variable.domain:
        evidence = Event({parent_variable: parent_value})
        conditional, _ = distribution.conditional(evidence)
        normalized_marginal = conditional.marginal(self.variables).normalize()

        # The table is keyed by the value as stored inside the event, not by the raw domain value.
        table_key = evidence[parent_variable]
        self.conditional_probability_distributions[table_key] = DiscreteDistribution(
            self.variable, normalized_marginal.probabilities.tolist())

    return self


class ConditionalProbabilisticCircuit(ConditionalProbabilityTable):

Expand Down Expand Up @@ -224,4 +247,13 @@ def interaction_term(self, node_latent_variable: Discrete, parent_latent_variabl

return result.probabilistic_circuit


def from_unit(self, unit: ProbabilisticCircuitMixin) -> Self:
    """
    Fill this conditional distribution from the children of a probabilistic circuit unit.

    Every subcircuit of ``unit`` is copied, and the probabilistic circuit of that copy is stored
    under the one-element tuple holding the subcircuit's position.

    :param unit: The unit whose subcircuits are mounted as conditional distributions.
    :return: This conditional distribution, after all subcircuits have been mounted.
    """
    for position, child in enumerate(unit.subcircuits):
        # Copy first, so the stored circuit is the one belonging to the copy.
        child_copy = child.__copy__()
        key = (position,)
        self.conditional_probability_distributions[key] = child_copy.probabilistic_circuit
    return self
33 changes: 32 additions & 1 deletion src/probabilistic_model/distributions/multinomial.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from random_events.events import EncodedEvent

from ..probabilistic_model import ProbabilisticModel
from typing_extensions import Self
from typing_extensions import Self, Any

from ..probabilistic_circuit.probabilistic_circuit import (ProbabilisticCircuit, DeterministicSumUnit,
DecomposableProductUnit)
Expand Down Expand Up @@ -156,3 +156,34 @@ def as_probabilistic_circuit(self) -> DeterministicSumUnit:
result.add_subcircuit(product_unit, probability)

return result

def encode_full_evidence_event(self, event: Iterable) -> List[int]:
    """
    Encode one fully observed event, value by value.

    The values are paired with ``self.variables`` in order and each value is encoded by its
    variable. Pairing stops at the shorter of the two sequences.

    :param event: The values of the event, ordered like ``self.variables``.
    :return: The encoded values, in the same order.
    """
    encoded: List[int] = []
    for variable, value in zip(self.variables, event):
        encoded.append(variable.encode(value))
    return encoded

def fit(self, data: Iterable[Iterable[Any]]) -> Self:
"""
Fit the distribution to the data.
:param data: The data to fit the distribution to.
:return: The fitted distribution.
"""
encoded_data = np.zeros((len(data), len(self.variables)), dtype=int)
for index, sample in enumerate(data):
indices = self.encode_full_evidence_event(sample)
encoded_data[index] = indices

return self._fit(encoded_data)

def _fit(self, data: np.ndarray) -> Self:
probabilities = np.zeros_like(self.probabilities)
uniques, counts = np.unique(data, return_counts=True, axis=0)

for unique, count in zip(uniques.astype(int), counts):
probabilities[tuple(unique)] = count

self.probabilities = probabilities / probabilities.sum()
return self
4 changes: 2 additions & 2 deletions src/probabilistic_model/learning/jpt/jpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,10 +439,10 @@ def _from_json(cls, data: Dict[str, Any]) -> Self:
result.probabilistic_circuit.add_edge(result, subcircuit, weight=weight)
return result

def marginal(self, variables: Iterable[Variable]) -> Optional[Self]:
def marginal(self, variables: Iterable[Variable], simplify_if_univariate=True) -> Optional[Self]:
result = super().marginal(variables)

if result is None or len(result.variables) > 1:
if result is None or len(result.variables) > 1 or not simplify_if_univariate:
return result

variable = result.variables[0]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -783,6 +783,7 @@ def sub_circuit_index_of_sample(self, sample: Iterable) -> Optional[int]:
return index
return None


class DecomposableProductUnit(ProbabilisticCircuitMixin):
"""
Decomposable Product Units for Probabilistic Circuits
Expand Down
95 changes: 87 additions & 8 deletions test/test_jpt/test_jpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,16 @@
from random_events.variables import Variable

from probabilistic_model.bayesian_network.bayesian_network import BayesianNetwork
from probabilistic_model.bayesian_network.distributions import (DiscreteDistribution, ConditionalProbabilisticCircuit,
ConditionalProbabilityTable)
from probabilistic_model.distributions.multinomial import MultinomialDistribution
from probabilistic_model.learning.jpt.jpt import JPT
from probabilistic_model.learning.jpt.variables import (ScaledContinuous, infer_variables_from_dataframe, Integer,
Symbolic, Continuous)
from probabilistic_model.learning.nyga_distribution import NygaDistribution
from probabilistic_model.probabilistic_circuit.distributions.distributions import IntegerDistribution, \
SymbolicDistribution
from probabilistic_model.probabilistic_circuit.probabilistic_circuit import DecomposableProductUnit, \
DeterministicSumUnit
from probabilistic_model.distributions.multinomial import MultinomialDistribution
from probabilistic_model.probabilistic_circuit.probabilistic_circuit import DecomposableProductUnit


class ShowMixin:
Expand Down Expand Up @@ -356,13 +357,10 @@ def test_serialization(self):


class BayesianJPTTestCase(unittest.TestCase):

model_sl_sw: JPT
model_pl_pw: JPT
model_species: JPT

bayesian_network: BayesianNetwork

sl: Continuous
sw: Continuous
pl: Continuous
Expand All @@ -372,6 +370,12 @@ class BayesianJPTTestCase(unittest.TestCase):
species_sepal_interaction_term: MultinomialDistribution
species_petal_interaction_term: MultinomialDistribution

subcircuit_indices: pd.DataFrame

species_latent_variable: random_events.variables.Discrete
sepal_latent_variable: random_events.variables.Discrete
petal_latent_variable: random_events.variables.Discrete

@classmethod
def setUpClass(cls):
iris = sklearn.datasets.load_iris(as_frame=True)
Expand All @@ -393,14 +397,89 @@ def setUpClass(cls):

model_species = JPT(variables, min_samples_leaf=0.3, features=[cls.species], targets=variables)
model_species.fit(df)
cls.model_species = DeterministicSumUnit.marginal(model_species, [cls.species])
cls.model_species = model_species.marginal([cls.species], simplify_if_univariate=False)

subcircuit_indices = np.zeros((len(df), 3))
for index, sample in enumerate(df.values):
sl, sw, pl, pw, species = sample
subcircuit_indices[index, 0] = cls.model_sl_sw.sub_circuit_index_of_sample((sl, sw))
subcircuit_indices[index, 1] = cls.model_pl_pw.sub_circuit_index_of_sample((pl, pw))
subcircuit_indices[index, 2] = cls.model_species.sub_circuit_index_of_sample((species,))

cls.subcircuit_indices = pd.DataFrame(subcircuit_indices, columns=["sl_sw", "pl_pw", "species"])

# Each latent variable ranges over the subcircuit indices of its own model.
cls.species_latent_variable = random_events.variables.Discrete(
    "species.latent", range(len(cls.model_species.subcircuits)))
cls.sepal_latent_variable = random_events.variables.Discrete(
    "sepal.latent", range(len(cls.model_sl_sw.subcircuits)))
# The petal indices are produced by model_pl_pw, so its subcircuits (not those of the sepal
# model) define the domain of the petal latent variable.
cls.petal_latent_variable = random_events.variables.Discrete(
    "petal.latent", range(len(cls.model_pl_pw.subcircuits)))

cls.species_sepal_interaction_term = MultinomialDistribution(
[cls.sepal_latent_variable, cls.species_latent_variable])
cls.species_sepal_interaction_term._fit(subcircuit_indices[:, (0, 2)])

cls.species_petal_interaction_term = MultinomialDistribution(
[cls.petal_latent_variable, cls.species_latent_variable])
cls.species_petal_interaction_term._fit(subcircuit_indices[:, (1, 2)])

def test_setup(self):
    """
    Check the fixtures created in ``setUpClass``: the variables of the three models, that each
    model has more than one subcircuit, that every sample was assigned a subcircuit index and
    that both interaction terms are normalized.
    """
    self.assertEqual(self.model_sl_sw.variables, (self.sl, self.sw))
    self.assertEqual(self.model_pl_pw.variables, (self.pl, self.pw))
    self.assertEqual(self.model_species.variables, (self.species,))

    for model in (self.model_sl_sw, self.model_pl_pw, self.model_species):
        self.assertGreater(len(model.subcircuits), 1)

    self.assertFalse(self.subcircuit_indices.isna().any().any())

    # The probabilities are the result of a floating point division, so their sum is compared
    # with a tolerance instead of exact equality.
    for interaction_term in (self.species_petal_interaction_term, self.species_sepal_interaction_term):
        self.assertAlmostEqual(interaction_term.probabilities.sum(), 1.)

def test_to_bayesian_network(self):
    """
    Assemble a bayesian network from the three fitted models and query it.

    The species latent variable is the root. For sepal and petal, a conditional probability table
    over the respective latent variable is attached to the root, and the distributions of the
    observed variables are attached below that table.
    """
    # root node: distribution over the species latent variable
    network = BayesianNetwork()
    species_node = DiscreteDistribution(self.species_latent_variable, self.model_species.weights)
    self.assertEqual(species_node.weights, [1 / 3] * 3)
    network.add_node(species_node)
    self.assertEqual(network.probability(Event()), 1.)

    # sepal latent variable, conditioned on the species latent variable
    sepal_latent_node = ConditionalProbabilityTable(self.sepal_latent_variable)
    sepal_latent_node.from_multinomial_distribution(self.species_sepal_interaction_term)
    network.add_node(sepal_latent_node)
    network.add_edge(species_node, sepal_latent_node)
    self.assertEqual(network.probability(Event()), 1.)

    # distributions over the sepal variables, one per sepal subcircuit
    sepal_node = ConditionalProbabilisticCircuit(self.model_sl_sw.variables)
    sepal_node.from_unit(self.model_sl_sw)
    for circuit in sepal_node.conditional_probability_distributions.values():
        self.assertIsInstance(circuit.root, DecomposableProductUnit)
    network.add_node(sepal_node)
    network.add_edge(sepal_latent_node, sepal_node)

    # petal latent variable, conditioned on the species latent variable
    petal_latent_node = ConditionalProbabilityTable(self.petal_latent_variable)
    petal_latent_node.from_multinomial_distribution(self.species_petal_interaction_term)
    network.add_node(petal_latent_node)
    network.add_edge(species_node, petal_latent_node)

    # distributions over the petal variables, one per petal subcircuit
    petal_node = ConditionalProbabilisticCircuit(self.model_pl_pw.variables)
    petal_node.from_unit(self.model_pl_pw)
    for circuit in petal_node.conditional_probability_distributions.values():
        self.assertIsInstance(circuit.root, DecomposableProductUnit)
    network.add_node(petal_node)
    network.add_edge(petal_latent_node, petal_node)

    # queries: the empty event has probability one, in the network and in its circuit form
    self.assertEqual(network.probability(Event()), 1.)
    self.assertAlmostEqual(network.as_probabilistic_circuit().probability(Event()), 1)

    # the first species component carries a third of the probability mass
    first_species_event = Event({self.species_latent_variable: 0})
    self.assertAlmostEqual(network.probability(first_species_event), 1 / 3)
    self.assertAlmostEqual(network.as_probabilistic_circuit().probability(first_species_event), 1 / 3)

0 comments on commit dfbc4c2

Please sign in to comment.