Skip to content

Commit

Permalink
Added better rejection sampling for
Browse files Browse the repository at this point in the history
  • Loading branch information
tomsch420 committed Mar 14, 2024
1 parent 64b1845 commit 2179b7b
Show file tree
Hide file tree
Showing 3 changed files with 212 additions and 39 deletions.
155 changes: 131 additions & 24 deletions src/probabilistic_model/distributions/gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Tuple, List, Optional, Dict, Any, Union

import numpy as np
from scipy.stats import gamma
from scipy.stats import gamma, expon


import portion
Expand All @@ -25,15 +25,15 @@ class GaussianDistribution(ContinuousDistribution):
The mean of the Gaussian distribution.
"""

variance: float
scale: float
"""
The variance of the Gaussian distribution.
"""

def __init__(self, variable: Continuous, mean: float, variance: float):
def __init__(self, variable: Continuous, mean: float, scale: float):
super().__init__(variable)
self.mean = mean
self.variance = variance
self.scale = scale

@property
def domain(self) -> Event:
Expand All @@ -51,7 +51,7 @@ def _pdf(self, value: float) -> float:
"""
if value == -portion.inf or value == portion.inf:
return 0
return 1/math.sqrt(2 * math.pi * self.variance) * math.exp(-1/2 * (value - self.mean) ** 2 / self.variance)
return 1/math.sqrt(2 * math.pi * self.scale) * math.exp(-1 / 2 * (value - self.mean) ** 2 / self.scale)

def _cdf(self, value: float) -> float:
r"""
Expand All @@ -67,13 +67,13 @@ def _cdf(self, value: float) -> float:
return 0
elif value == portion.inf:
return 1
return 0.5 * (1 + math.erf((value - self.mean) / math.sqrt(2 * self.variance)))
return 0.5 * (1 + math.erf((value - self.mean) / math.sqrt(2 * self.scale)))

def _mode(self) -> Tuple[List[EncodedEvent], float]:
return [EncodedEvent({self.variable: portion.singleton(self.mean)})], self._pdf(self.mean)

def sample(self, amount: int) -> List[List[float]]:
    """
    Draw independent samples from this Gaussian distribution.

    ``random.gauss`` expects the standard deviation as its second argument, while ``scale`` is used
    as the variance in the density and cumulative distribution of this class. Hence, the square root
    of ``scale`` is handed to the generator.

    :param amount: The amount of samples to generate
    :return: A list with ``amount`` rows, each holding a single sampled value
    """
    standard_deviation = math.sqrt(self.scale)
    return [[random.gauss(self.mean, standard_deviation)] for _ in range(amount)]

def raw_moment(self, order: int) -> float:
r"""
Expand All @@ -90,7 +90,7 @@ def raw_moment(self, order: int) -> float:
raw_moment = 0 # Initialize the raw moment
for j in range(math.floor(order/2)+1):
mu_term= self.mean ** (order - 2*j)
sigma_term = self.variance ** j
sigma_term = self.scale ** j

raw_moment += (math.comb(order, 2*j) * mu_term * sigma_term * math.factorial(2*j) /
(math.factorial(j) * (2 ** j)))
Expand Down Expand Up @@ -134,23 +134,23 @@ def conditional_from_simple_interval(self, interval: portion.Interval) \
resulting_distribution = TruncatedGaussianDistribution(self.variable,
interval=intersection,
mean=self.mean,
variance=self.variance)
scale=self.scale)
return resulting_distribution, probability

def __eq__(self, other):
return self.mean == other.mean and self.variance == other.variance and super().__eq__(other)
return self.mean == other.mean and self.scale == other.scale and super().__eq__(other)

@property
def representation(self):
return f"N({self.mean}, {self.variance})"
return f"N({self.mean}, {self.scale})"

def __copy__(self):
return self.__class__(self.variable, self.mean, self.variance)
return self.__class__(self.variable, self.mean, self.scale)

def to_json(self) -> Dict[str, Any]:
return {**super().to_json(),
"mean": self.mean,
"variance": self.variance}
"variance": self.scale}

@classmethod
def _from_json(cls, data: Dict[str, Any]) -> Self:
Expand All @@ -163,8 +163,8 @@ class TruncatedGaussianDistribution(GaussianDistribution):
Class for Truncated Gaussian distributions.
"""

def __init__(self, variable: Continuous, interval: portion.Interval, mean: float, variance: float):
super().__init__(variable, mean, variance)
def __init__(self, variable: Continuous, interval: portion.Interval, mean: float, scale: float):
super().__init__(variable, mean, scale)
self.interval = interval

@property
Expand Down Expand Up @@ -216,9 +216,8 @@ def _mode(self) -> Tuple[List[EncodedEvent], float]:
else:
return [EncodedEvent({self.variable: portion.singleton(self.upper)})], self._pdf(self.upper)

def sample(self, amount: int) -> List[List[float]]:
def rejection_sample(self, amount: int) -> List[List[float]]:
"""
.. note::
This uses rejection sampling and hence is inefficient.
Expand All @@ -227,7 +226,7 @@ def sample(self, amount: int) -> List[List[float]]:
samples = [sample for sample in samples if sample[0] in self.interval]
rejected_samples = amount - len(samples)
if rejected_samples > 0:
samples.extend(self.sample(rejected_samples))
samples.extend(self.rejection_sample(rejected_samples))
return samples

def moment(self, order: OrderType, center: CenterType) -> MomentType:
Expand Down Expand Up @@ -258,9 +257,9 @@ def moment(self, order: OrderType, center: CenterType) -> MomentType:
order = order[self.variable]
center = center[self.variable]

lower_bound=(self.lower-self.mean)/math.sqrt(self.variance) #normalize the lower bound
upper_bound=(self.upper-self.mean)/math.sqrt(self.variance) #normalize the upper bound
normalized_center = (center-self.mean)/math.sqrt(self.variance) #normalize the center
lower_bound=(self.lower-self.mean)/math.sqrt(self.scale) #normalize the lower bound
upper_bound=(self.upper-self.mean)/math.sqrt(self.scale) #normalize the upper bound
normalized_center = (center-self.mean)/math.sqrt(self.scale) #normalize the center
truncated_moment = 0

for k in range(order + 1):
Expand All @@ -280,7 +279,7 @@ def moment(self, order: OrderType, center: CenterType) -> MomentType:
truncated_moment += (multiplying_constant * (gamma_term_lower + gamma_term_upper) * (-normalized_center)
** (order - k))

truncated_moment *= (math.sqrt(self.variance) ** order) / self.normalizing_constant
truncated_moment *= (math.sqrt(self.scale) ** order) / self.normalizing_constant

return VariableMap({self.variable: truncated_moment})

Expand All @@ -289,10 +288,10 @@ def __eq__(self, other):

@property
def representation(self):
return f"N({self.mean},{self.variance} | {self.interval})"
return f"N({self.mean},{self.scale} | {self.interval})"

def __copy__(self):
return self.__class__(self.variable, self.interval, self.mean, self.variance)
return self.__class__(self.variable, self.interval, self.mean, self.scale)

def to_json(self) -> Dict[str, Any]:
return {**super().to_json(), "interval": portion.to_data(self.interval)}
Expand All @@ -302,3 +301,111 @@ def _from_json(cls, data: Dict[str, Any]) -> Self:
variable = Continuous.from_json(data["variable"])
interval = portion.from_data(data["interval"])
return cls(variable, interval, data["mean"], data["variance"])

def robert_rejection_sample(self, amount: int) -> np.ndarray:
    """
    Use Robert rejection sampling to sample from the truncated Gaussian distribution.

    The truncation interval is standardized (shifted by the mean and divided by the square root of
    ``scale``), samples are drawn from the truncated standard normal distribution and finally mapped
    back to the mean and scale of this distribution.

    :param amount: The amount of samples to generate
    :return: A one-dimensional array containing the samples
    """
    # nothing to draw; also avoids indexing an empty array in the plain rejection sampling fallback
    if amount <= 0:
        return np.zeros(0)

    standard_deviation = np.sqrt(self.scale)

    # handle the case where the distribution is not the standard normal
    new_interval = self.interval.replace(lower=(self.interval.lower - self.mean) / standard_deviation,
                                         upper=(self.interval.upper - self.mean) / standard_deviation)
    standard_distribution = self.__class__(self.variable, new_interval, 0, 1)

    # if the interval is unbounded below, mirror it at zero such that it is unbounded above instead;
    # an explicit flag is kept since the truthiness of an interval object is not a reliable marker
    flipped = standard_distribution.interval.lower <= -float("inf")
    if flipped:
        standard_distribution.interval = new_interval.replace(lower=-new_interval.upper,
                                                              upper=-new_interval.lower)

    # if the (possibly mirrored) interval is unbounded above
    if standard_distribution.interval.upper >= float("inf"):

        # if the interval includes the mean
        if standard_distribution.interval.lower < 0:
            # fallback plan is rejection sampling
            samples = np.array(standard_distribution.rejection_sample(amount))[:, 0]
        else:
            samples = (standard_distribution.
                       robert_rejection_sample_from_standard_normal_with_single_truncation(amount))
    else:
        # sample from double truncated standard normal instead
        samples = standard_distribution.robert_rejection_sample_from_standard_normal_with_double_truncation(amount)

    # undo the mirroring while the samples are still standardized; negating after the shift by the
    # mean would yield -(sigma * z + mean) instead of sigma * (-z) + mean
    if flipped:
        samples = -samples

    # transform samples to this distributions mean and scale
    return samples * standard_deviation + self.mean

def robert_rejection_sample_from_standard_normal_with_double_truncation(self, amount: int) -> np.ndarray:
    """
    Use Robert rejection sampling to sample from the standard normal distribution truncated on both sides.

    Proposals are drawn uniformly from the interval and accepted with a probability proportional to
    the standard normal density, rescaled such that it is 1 at the point of the interval that is
    closest to 0.

    :param amount: The amount of samples to generate
    :return: A one-dimensional array containing the samples
    """
    assert self.scale == 1 and self.mean == 0
    lower, upper = self.interval.lower, self.interval.upper

    # The bounds are compared numerically instead of testing `0 in self.interval`, such that an
    # interval with an open bound at exactly 0 is handled as well.
    if upper <= 0:
        # the interval lies below the mean, the density is maximal at the upper bound
        exponent_offset = upper ** 2 / 2
    elif lower >= 0:
        # the interval lies above the mean, the density is maximal at the lower bound
        exponent_offset = lower ** 2 / 2
    else:
        # the mean lies inside the interval, the density is maximal at 0
        exponent_offset = 0.

    accepted_samples = np.zeros(0)

    # Redraw only the missing samples until enough got accepted. A loop is used instead of recursion
    # such that low acceptance rates cannot exhaust the recursion limit.
    while len(accepted_samples) < amount:
        missing = amount - len(accepted_samples)

        # sample from uniform distribution over this distribution's interval
        uniform_samples = np.random.uniform(lower, upper, missing)
        limiting_function = np.exp(exponent_offset - uniform_samples ** 2 / 2)

        # generate standard uniform samples
        different_uniform_samples = np.random.uniform(0, 1, missing)

        # accept samples that are below the limiting function
        accepted_samples = np.concatenate([accepted_samples,
                                           uniform_samples[different_uniform_samples <= limiting_function]])

    return accepted_samples

def robert_rejection_sample_from_standard_normal_with_single_truncation(self, amount: int) -> np.ndarray:
    """
    Sample from a one-sided, truncated standard normal distribution using Robert rejection sampling.

    Proposals are drawn from an exponential distribution that is shifted to the lower bound of the
    interval. A proposal z is accepted with probability exp(-(z - rate)^2 / 2), where rate is the
    rate of the exponential distribution.

    :param amount: The amount of samples to generate
    :return: A one-dimensional array containing the samples
    """
    lower = self.interval.lower

    # rate of the translated exponential proposal distribution
    rate = (lower + np.sqrt(lower ** 2 + 4)) / 2

    accepted_samples = np.zeros(0)

    # Redraw only the missing samples until enough got accepted. A loop is used instead of recursion
    # such that low acceptance rates cannot exhaust the recursion limit.
    while len(accepted_samples) < amount:
        missing = amount - len(accepted_samples)

        # numpy parametrizes the exponential distribution by its scale, which is the inverse of the rate
        samples_from_shifted_exponential = np.random.exponential(scale=1 / rate, size=missing) + lower

        # acceptance probability of each proposal; never exceeds 1
        limiting_function = np.exp(-(samples_from_shifted_exponential - rate) ** 2 / 2)

        uniform_samples = np.random.uniform(0, 1, missing)
        accepted_samples = np.concatenate([accepted_samples,
                                           samples_from_shifted_exponential[uniform_samples <= limiting_function]])

    return accepted_samples

def sample(self, amount: int) -> List[List[float]]:
    """
    Draw samples by Robert rejection sampling and return them as a list of single-element rows.

    :param amount: The amount of samples to generate
    :return: A list with one row per sample, each row holding the sampled value
    """
    flat_samples = self.robert_rejection_sample(amount).ravel()
    return [[value] for value in flat_samples.tolist()]
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def conditional_from_simple_interval(self, interval: portion.Interval) -> (
Tuple)[Optional[TruncatedGaussianDistribution], float]:
conditional, probability = PMGaussianDistribution.conditional_from_simple_interval(self, interval)
return TruncatedGaussianDistribution(conditional.variable, conditional.interval,
conditional.mean, conditional.variance), probability
conditional.mean, conditional.scale), probability


class TruncatedGaussianDistribution(GaussianDistribution, ContinuousDistribution, PMTruncatedGaussianDistribution):
Expand Down
Loading

0 comments on commit 2179b7b

Please sign in to comment.