Skip to content

Commit

Permalink
Started to refactor and limit the scope of bayesian networks.
Browse files Browse the repository at this point in the history
  • Loading branch information
tomsch420 committed Feb 20, 2024
1 parent 16e3e63 commit a7eeb3c
Show file tree
Hide file tree
Showing 5 changed files with 147 additions and 44 deletions.
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,27 @@
from random_events.variables import Variable, Symbolic, Integer
from typing_extensions import Self, List, Tuple, Iterable, Optional, Dict

from .probabilistic_circuit.distributions import SymbolicDistribution, IntegerDistribution
from .probabilistic_model import ProbabilisticModel
from .distributions.multinomial import MultinomialDistribution
from probabilistic_model.probabilistic_circuit.distributions import SymbolicDistribution, IntegerDistribution
from probabilistic_model.probabilistic_model import ProbabilisticModel
from probabilistic_model.distributions.multinomial import MultinomialDistribution
import networkx as nx
import numpy as np

from .probabilistic_circuit.probabilistic_circuit import (ProbabilisticCircuit, DeterministicSumUnit,
DecomposableProductUnit, ProbabilisticCircuitMixin)
from ..probabilistic_circuit.probabilistic_circuit import (ProbabilisticCircuit,
DeterministicSumUnit,
DecomposableProductUnit,
ProbabilisticCircuitMixin)
from ..distributions.distributions import DiscreteDistribution


class BayesianNetworkMixin:
class BayesianNetworkMixin(ProbabilisticModel):
"""
Mixin class for (conditional) probability distributions in bayesian networks.
Mixin class for conditional probability distributions in tree shaped bayesian networks.
"""

bayesian_network: BayesianNetwork

forward_message: MultinomialDistribution
forward_message: DiscreteDistribution
"""
The marginal distribution (message) as calculated in the forward pass.
"""
Expand All @@ -35,41 +38,39 @@ class BayesianNetworkMixin:
"""

@property
def parents(self) -> List[Self]:
return list(self.bayesian_network.predecessors(self))

@property
def is_root(self):
return len(self.parents) == 0

@property
def variables(self) -> Tuple[Variable, ...]:
raise NotImplementedError
def parent(self) -> Optional[Self]:
"""
The parent node if it exists and None if this is a root.
:return:
"""
parents = list(self.bayesian_network.predecessors(self))
if len(parents) > 1:
raise ValueError("Bayesian Network is not a tree.")
elif len(parents) == 1:
return parents[0]
else:
return None

@property
def parent_variables(self) -> Tuple[Variable, ...]:
parent_variables = [variable for parent in self.parents for variable in parent.variables]
return tuple(sorted(parent_variables))
def is_root(self) -> bool:
"""
:return: Rather this is the root or not.
"""
return self.parent is None

@property
def parent_and_node_variables(self):
return self.parent_variables + self.variables

def __hash__(self):
return id(self)

def _likelihood(self, event: Iterable, parent_event: Iterable) -> float:
def parent_and_node_variables(self) -> Tuple[Variable, ...]:
"""
Calculate the conditional likelihood of the event given the parent event.
:param event: The event to calculate the likelihood for.
:param parent_event: The parent event to condition on.
:return: The likelihood of the event given the parent event.
Get the parent variables together with this nodes variable.
:return: A tuple containing first the parent variable and second this nodes variable.
"""
raise NotImplementedError
if self.is_root:
return self.variables
else:
return self.parent.variables + self.variables

def as_probabilistic_circuit(self) -> DeterministicSumUnit:
raise NotImplementedError
def __hash__(self):
return id(self)

def as_probabilistic_circuit_with_parent_message(self) -> DeterministicSumUnit:
"""
Expand All @@ -86,6 +87,14 @@ def joint_distribution_with_parents(self) -> MultinomialDistribution:
"""
raise NotImplementedError

def forward_pass(self, event: EncodedEvent):
"""
Calculate the forward pass for this node given the event.
This includes calculating the forward message and the forward probability of said event.
:param event: The event to account for
"""
raise NotImplementedError


class ConditionalMultinomialDistribution(BayesianNetworkMixin, MultinomialDistribution):

Expand Down Expand Up @@ -135,7 +144,7 @@ def _likelihood(self, event: Iterable, parent_event: Optional[Iterable] = None)
parent_event = tuple()
return self.probabilities[tuple(parent_event) + tuple(event)].item()

def calculate_forward_message(self, event: EncodedEvent):
def forward_pass(self, event: EncodedEvent):
"""
Calculate the forward message for this node given the event and the forward probability of said event.
:param event: The event to account for
Expand Down Expand Up @@ -204,7 +213,7 @@ def _likelihood(self, event: Iterable, parent_event: Iterable) -> float:
circuit = self.circuits[tuple(parent_event)]
return circuit._likelihood(event)

def calculate_forward_message(self, event: EncodedEvent):
def forward_pass(self, event: EncodedEvent):
parent = self.parents[0]
probability = 0.
for parent_probability, circuit in zip(parent.forward_message.probabilities, self.circuits.values()):
Expand Down Expand Up @@ -244,6 +253,9 @@ def joint_distribution_with_parents(self) -> MultinomialDistribution:


class BayesianNetwork(ProbabilisticModel, nx.DiGraph):
"""
Class for Bayesian Networks that are tree shaped and have univariate inner nodes.
"""

def __init__(self):
ProbabilisticModel.__init__(self, None)
Expand Down Expand Up @@ -284,7 +296,7 @@ def forward_pass(self, event: EncodedEvent):
"""
# calculate forward pass
for node in self.nodes:
node.calculate_forward_message(event)
node.forward_pass(event)

def _probability(self, event: EncodedEvent) -> float:
self.forward_pass(event)
Expand Down
44 changes: 44 additions & 0 deletions src/probabilistic_model/bayesian_network/distributions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from random_events.events import Event, EncodedEvent
from typing_extensions import Tuple, Dict, Iterable, List

from .bayesian_network import BayesianNetworkMixin
from ..probabilistic_model import ProbabilisticModel
from random_events.variables import Discrete
from ..distributions.distributions import DiscreteDistribution


class RootDistribution(BayesianNetworkMixin, DiscreteDistribution):

def forward_pass(self, event: EncodedEvent):
self.forward_message, self.forward_probability = self._conditional(event)


class ConditionalProbabilityTable(BayesianNetworkMixin):

variables: Tuple[Discrete, ...]
conditional_probability_distributions: Dict[Tuple, DiscreteDistribution] = dict()

def __init__(self, variable: Discrete):
ProbabilisticModel.__init__(self, [variable])

@property
def variable(self) -> Discrete:
return self.variables[0]

def likelihood(self, event: Iterable) -> float:
return self._likelihood([variable.encode(value) for variable, value in zip(self.parent_and_node_variables, event)])

def _likelihood(self, event: Iterable) -> float:
parent_event = tuple(event[:1])
node_event = tuple(event[1:])
return self.conditional_probability_distributions[parent_event]._likelihood(node_event)

def __repr__(self):
return f"P({self.variable.name}|{self.parent.variable.name})"

def to_tabulate(self) -> List[List[str]]:
table = [[self.parent.variable.name, self.variable.name, repr(self)]]
for parent_event, distribution in self.conditional_probability_distributions.items():
for event, probability in zip(self.variable.domain, distribution.weights):
table.append([str(parent_event[0]), str(event), str(probability)])
return table
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import portion
from random_events.events import Event

from probabilistic_model.bayesian_network import (BayesianNetwork, ConditionalMultinomialDistribution,
ConditionalProbabilisticCircuit)
from probabilistic_model.bayesian_network.bayesian_network import (BayesianNetwork, ConditionalMultinomialDistribution,
ConditionalProbabilisticCircuit)
from probabilistic_model.probabilistic_circuit.distributions import UniformDistribution
from probabilistic_model.distributions.multinomial import MultinomialDistribution
from random_events.variables import Symbolic, Continuous, Integer
Expand All @@ -14,9 +14,6 @@
import matplotlib.pyplot as plt
import networkx as nx

from probabilistic_model.probabilistic_circuit.probabilistic_circuit import DeterministicSumUnit, \
DecomposableProductUnit


class MinimalBayesianNetworkTestCase(unittest.TestCase):

Expand Down
50 changes: 50 additions & 0 deletions test/test_bayesian_network/test_distributions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import unittest

from random_events.variables import Symbolic

from probabilistic_model.bayesian_network.distributions import ConditionalProbabilityTable, RootDistribution
from probabilistic_model.bayesian_network.bayesian_network import BayesianNetwork
from probabilistic_model.distributions.distributions import SymbolicDistribution

import tabulate


class DistributionTestCase(unittest.TestCase):

x = Symbolic("x", [0, 1, 2])
y = Symbolic("y", [0, 1])

p_x = ConditionalProbabilityTable(x)
p_yx = ConditionalProbabilityTable(y)

def setUp(self):
bayesian_network = BayesianNetwork()

# create the root distribution for x
self.p_x = RootDistribution(self.x, [0.5, 0.3, 0.2])

# create the conditional probability table for y
self.p_yx.conditional_probability_distributions[(0,)] = SymbolicDistribution(self.y, [0.5, 0.5])
self.p_yx.conditional_probability_distributions[(1,)] = SymbolicDistribution(self.y, [0.3, 0.7])
self.p_yx.conditional_probability_distributions[(2,)] = SymbolicDistribution(self.y, [0.1, 0.9])

# add the distributions to the bayesian network
bayesian_network.add_node(self.p_x)
bayesian_network.add_node(self.p_yx)

# add the edge between x and y
bayesian_network.add_edge(self.p_x, self.p_yx)

def test_to_tabulate(self):
table = tabulate.tabulate(self.p_yx.to_tabulate())
self.assertIsInstance(table, str)
# print(table)

def test_likelihood(self):
self.assertEqual(self.p_yx.likelihood([0, 1]), 0.5)

def test_probability(self):
...

if __name__ == '__main__':
unittest.main()

0 comments on commit a7eeb3c

Please sign in to comment.