Skip to content

Commit

Permalink
Gradient is not yet working
Browse files Browse the repository at this point in the history
  • Loading branch information
tomsch420 committed Jun 20, 2024
1 parent 955a0c5 commit 7471188
Show file tree
Hide file tree
Showing 5 changed files with 227 additions and 15 deletions.
26 changes: 11 additions & 15 deletions src/probabilistic_model/distributions/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,13 @@
from typing import Optional

import numpy as np
import plotly.graph_objects as go
from random_events.interval import *
from random_events.product_algebra import Event, SimpleEvent, VariableMap
from random_events.variable import *
from random_events.interval import *
from typing_extensions import Union, Iterable, Any, Self, Dict, List, Tuple
import plotly.graph_objects as go
from probabilistic_model.constants import SCALING_FACTOR_FOR_EXPECTATION_IN_PLOT


from probabilistic_model.constants import SCALING_FACTOR_FOR_EXPECTATION_IN_PLOT
from ..probabilistic_model import ProbabilisticModel, OrderType, MomentType, CenterType
from ..utils import SubclassJSONSerializer, MissingDict, interval_as_array

Expand All @@ -25,7 +24,7 @@ class UnivariateDistribution(ProbabilisticModel, SubclassJSONSerializer):

@property
def variables(self) -> Tuple[Variable, ...]:
return (self.variable, )
return (self.variable,)

def support(self) -> Event:
return SimpleEvent({self.variable: self.univariate_support}).as_composite_set()
Expand Down Expand Up @@ -96,8 +95,8 @@ def cdf(self, x: np.array) -> np.array:
def probability_of_simple_event(self, event: SimpleEvent) -> float:
interval: Interval = event[self.variable]
points = interval_as_array(interval)
upper_bound_cdf = self.cdf(points[:, (1, )])
lower_bound_cdf = self.cdf(points[:, (0, )])
upper_bound_cdf = self.cdf(points[:, (1,)])
lower_bound_cdf = self.cdf(points[:, (0,)])
return (upper_bound_cdf - lower_bound_cdf).sum()

def log_conditional(self, event: Event) -> Tuple[Optional[Self], float]:
Expand Down Expand Up @@ -183,17 +182,15 @@ def left_included_condition(self, x: np.array) -> np.array:
:param x: The data
:return: A boolean array
"""
return ((self.interval.lower <= x if self.interval.left == Bound.CLOSED else self.interval.lower < x).
reshape(-1, 1))
return (self.interval.lower <= x if self.interval.left == Bound.CLOSED else self.interval.lower < x)

def right_included_condition(self, x: np.array) -> np.array:
"""
Check if x is included in the right bound of the interval.
:param x: The data
:return: A boolean array
"""
return ((x < self.interval.upper if self.interval.right == Bound.OPEN else x <= self.interval.upper).
reshape(-1, 1))
return (x < self.interval.upper if self.interval.right == Bound.OPEN else x <= self.interval.upper)

def included_condition(self, x: np.array) -> np.array:
"""
Expand All @@ -204,7 +201,7 @@ def included_condition(self, x: np.array) -> np.array:
return self.left_included_condition(x) & self.right_included_condition(x)

def log_likelihood(self, x: np.array) -> np.array:
result = np.full(len(x), -np.inf)
result = np.full(x.shape[:-1], -np.inf)
include_condition = self.included_condition(x)
filtered_x = x[include_condition].reshape(-1, 1)
result[include_condition[:, 0]] = self.log_likelihood_without_bounds_check(filtered_x)
Expand Down Expand Up @@ -406,7 +403,7 @@ def univariate_support(self) -> Interval:
return result

def cdf(self, x: np.array) -> np.array:
result = np.zeros((len(x), ))
result = np.zeros((len(x),))
maximum_value = max(x)
for value, p in self.probabilities.items():
if value > maximum_value:
Expand Down Expand Up @@ -474,7 +471,7 @@ def log_likelihood(self, events: np.array) -> np.array:
return result

def cdf(self, x: np.array) -> np.array:
result = np.zeros((len(x), ))
result = np.zeros((len(x),))
result[x[:, 0] >= self.location] = 1.
return result

Expand Down Expand Up @@ -551,4 +548,3 @@ def plot(self, **kwargs) -> List:
SCALING_FACTOR_FOR_EXPECTATION_IN_PLOT],
mode="lines+markers", name="Mode")
return [pdf_trace, cdf_trace, expectation_trace, mode_trace]

Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
from __future__ import annotations

from abc import abstractmethod

import networkx as nx
import numpy as np
from equinox import AbstractVar
from networkx.classes.digraph import _CachedPropertyResetterPred
from random_events.interval import SimpleInterval
from random_events.variable import Continuous
from typing_extensions import Self, Type, Dict

from probabilistic_model.probabilistic_circuit.probabilistic_circuit import (SmoothSumUnit,
ProbabilisticCircuit as PMProbabilisticCircuit,
DecomposableProductUnit,
ProbabilisticCircuitMixin)
from probabilistic_model.probabilistic_circuit.distributions import UniformDistribution as PCUniformDistribution
import equinox
import jax.numpy as jnp
from jax import Array
from random_events.utils import recursive_subclasses


class ProbabilisticCircuit(PMProbabilisticCircuit):

@classmethod
def from_probabilistic_circuit(cls, probabilistic_circuit: PMProbabilisticCircuit) -> Self:

node_remap: Dict = dict()

result = cls()
for node in probabilistic_circuit.nodes:
jax_node = inverse_class_of(type(node)).from_unit(unit=node, probabilistic_circuit=result)
nx.DiGraph.add_node(result, jax_node)
node_remap[node] = jax_node

for edge in probabilistic_circuit.edges:
result.add_edge(node_remap[edge[0]], node_remap[edge[1]])
return result

def log_likelihood(self, events: Array) -> Array:
return self.root.log_likelihood(events)


def inverse_class_of(clazz: Type[ProbabilisticCircuitMixin]) -> Type[ModuleMixin]:
for subclass in recursive_subclasses(ModuleMixin):
if issubclass(clazz, subclass.origin_class()):
return subclass
raise TypeError(f"Could not find class for {clazz}")


class ModuleMixin:
"""
Mixin for JAX modules that are capable of being converted to the original probabilistic circuit module.
JAX modules are limited in functionality, as only the log_likelihood method is supported and automatically
differentiable.
"""

@staticmethod
@abstractmethod
def origin_class() -> Type[ProbabilisticCircuitMixin]:
"""
The original class of the module.
:return: The original class of the module.
"""
raise NotImplementedError

@classmethod
@abstractmethod
def from_unit(cls, unit: ProbabilisticCircuitMixin, probabilistic_circuit: ProbabilisticCircuit) -> Self:
"""
Create a new instance of this class from a unit.
:param unit: The unit to read the parameters from.
:param probabilistic_circuit: The probabilistic circuit where the unit should be added.
:return: The jax version of the unit.
"""
raise NotImplementedError


class UniformDistribution(PCUniformDistribution, equinox.Module, ModuleMixin):

variable: Continuous
interval: SimpleInterval
probabilistic_circuit: ProbabilisticCircuit = equinox.field(static=True)

def __init__(self, variable: Continuous, interval: SimpleInterval, probabilistic_circuit: ProbabilisticCircuit):
self.variable = variable
self.interval = interval
self.probabilistic_circuit = probabilistic_circuit

@staticmethod
def origin_class() -> Type[PCUniformDistribution]:
return PCUniformDistribution

@classmethod
def from_unit(cls, unit: PCUniformDistribution, probabilistic_circuit: ProbabilisticCircuit) -> Self:
return cls(unit.variable, unit.interval, probabilistic_circuit)

def log_pdf_value(self) -> Array:
return -jnp.log(self.upper - self.lower)

def log_likelihood_without_bounds_check(self, x: Array) -> Array:
return jnp.full((x.shape[:-1]), self.log_pdf_value())

def log_likelihood(self, x: Array) -> Array:
result = jnp.full(x.shape[:-1], -jnp.inf)
include_condition = self.included_condition(x)
filtered_x = x[include_condition].reshape(-1, 1)
likelihoods = self.log_likelihood_without_bounds_check(filtered_x)
result = result.at[include_condition[:, 0]].set(likelihoods)
return result

def __hash__(self):
return id(self)


class SumUnit(equinox.Module, SmoothSumUnit, ModuleMixin):

_weights: Array
probabilistic_circuit: ProbabilisticCircuit = equinox.field(static=True)

def __init__(self, initial_weights: Array, probabilistic_circuit: ProbabilisticCircuit):
super().__init__()
self._weights = initial_weights
self.probabilistic_circuit = probabilistic_circuit

@staticmethod
def origin_class() -> Type[SmoothSumUnit]:
return SmoothSumUnit

@property
def weights(self) -> Array:
exp_weights = jnp.exp(self._weights)
return exp_weights / exp_weights.sum()

def log_likelihood(self, events: Array) -> Array:
result = jnp.zeros(events.shape[:-1])
for weight, subcircuit in zip(self.weights, self.subcircuits):
subcircuit_likelihood = jnp.exp(subcircuit.log_likelihood(events))
result += weight * subcircuit_likelihood
return jnp.log(result)

def __call__(self, x):
return self.log_likelihood(x)

@classmethod
def from_unit(cls, unit: SmoothSumUnit, probabilistic_circuit: ProbabilisticCircuit) -> Self:
result = cls(jnp.log(unit.weights), probabilistic_circuit)
return result

def __hash__(self):
return id(self)


class ProductUnit(equinox.Module, DecomposableProductUnit, ModuleMixin):
probabilistic_circuit: ProbabilisticCircuit = equinox.field(static=True)

@staticmethod
def origin_class() -> Type[DecomposableProductUnit]:
return DecomposableProductUnit


Empty file added test/test_jax/__init__.py
Empty file.
52 changes: 52 additions & 0 deletions test/test_jax/test_probabilistic_circuit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import unittest

import equinox
import jax
from random_events.variable import Continuous

from probabilistic_model.learning.nyga_distribution import NygaDistribution
from probabilistic_model.probabilistic_circuit.jax.probabilistic_circuit import *
import plotly.graph_objects as go
import numpy as np
import jax.numpy as jnp


class TestJaxUnits(unittest.TestCase):

x: Continuous = Continuous("x")
y: Continuous = Continuous("y")
np.random.seed(69)
data: np.ndarray = np.random.multivariate_normal(np.array([0, 0]), np.array([[1, .5], [.5, 1]]), size=(1000))
nyga_distribution: NygaDistribution

@classmethod
def setUp(cls) -> None:
cls.nyga_distribution = NygaDistribution(cls.y, 50)
cls.nyga_distribution.fit(cls.data[:, 1])

def show(self):
fig = go.Figure(self.nyga_distribution.plot(), self.nyga_distribution.plotly_layout())
fig.show()

def test_from_probabilistic_circuit(self):
probabilistic_circuit = ProbabilisticCircuit.from_probabilistic_circuit(self.nyga_distribution.probabilistic_circuit)
self.assertIsInstance(probabilistic_circuit, ProbabilisticCircuit)
self.assertEqual(len(probabilistic_circuit.nodes), len(self.nyga_distribution.probabilistic_circuit.nodes))
self.assertEqual(len(probabilistic_circuit.edges), len(self.nyga_distribution.probabilistic_circuit.edges))

def test_likelihood(self):
probabilistic_circuit = ProbabilisticCircuit.from_probabilistic_circuit(self.nyga_distribution.probabilistic_circuit)
log_likelihood = probabilistic_circuit.log_likelihood(jnp.array(self.data[:, (1, )]))
self.assertTrue(jnp.allclose(log_likelihood, self.nyga_distribution.log_likelihood(self.data[:, (1, )])))

def test_grad(self):
pc = ProbabilisticCircuit.from_probabilistic_circuit(self.nyga_distribution.probabilistic_circuit).root

@jax.jit
@jax.grad
def loss_fn(model, x):
log_likelihood = jax.vmap(model)(x)
return jnp.mean(log_likelihood)

grad = loss_fn(pc, jnp.array(self.data[:, (1, )]))
print(grad)

0 comments on commit 7471188

Please sign in to comment.