Skip to content

Commit

Permalink
Started to work on sampling using vmap instead of loops
Browse files Browse the repository at this point in the history
  • Loading branch information
tomsch420 committed Aug 19, 2024
1 parent 30268fb commit 579cec8
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 0 deletions.
23 changes: 23 additions & 0 deletions src/probabilistic_model/learning/torch/uniform_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,29 @@ def sample_from_frequencies(self, frequencies: torch.Tensor) -> torch.Tensor:
result.append(samples)
return torch.stack(result).coalesce()

def sample_from_frequencies_vmap(self, frequencies: torch.Tensor) -> torch.Tensor:
max_frequency = max(frequencies)

values_for_sparse_tensor = torch.distributions.Uniform(low=0, high=1).sample((sum(frequencies), ))
indices = torch.repeat_interleave(torch.arange(len(frequencies)), frequencies)

# generate the second dimension indexing dimension of the sparse tensor

def concatenate_without_loop(frequencies):
cumulative_sums = torch.cumsum(torch.tensor(frequencies), dim=0)
second_index_dimension = torch.arange(cumulative_sums[-1]).long()
second_index_dimension -= torch.repeat_interleave(cumulative_sums[:-1], frequencies[:-1])
return second_index_dimension

second_index_dimension = torch.concatenate([torch.arange(frequency) for frequency in frequencies])
print(second_index_dimension)
print(second_index_dimension.shape)
indices = torch.stack([indices, torch.arange(sum(frequencies))])

print(indices.shape)
exit()
return samples

def log_conditional_from_simple_interval(self, interval: SimpleInterval) -> Tuple[Self, torch.Tensor]:
probabilities = self.probability_of_simple_event(SimpleEvent({self.variable: interval})).log()
intersections = [interval.intersection_with(SimpleInterval(lower.item(), upper.item(),
Expand Down
14 changes: 14 additions & 0 deletions test/test_torch/test_uniform.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from random_events.interval import closed, open
from random_events.product_algebra import SimpleEvent
from random_events.variable import Continuous
from sympy.physics.units import frequency
from torch.testing import assert_close

from probabilistic_model.learning.torch import SumLayer
Expand Down Expand Up @@ -95,5 +96,18 @@ def test_sampling(self):
self.assertTrue(all(l_n0 > 0))
self.assertTrue(all(l_n1 > 0))


class UniformSamplingSpeedTest(unittest.TestCase):
x: Continuous = Continuous("x")

p_x = UniformLayer(x, torch.Tensor([[0, 1]] * 100))

def test_sampling(self):
frequencies = torch.full((self.p_x.number_of_nodes, ), 10)
frequencies[0] = 20
frequencies[1] = 2
# samples = self.p_x.sample_from_frequencies(frequencies)
samples_vmap = self.p_x.sample_from_frequencies_vmap(frequencies)

if __name__ == '__main__':
unittest.main()

0 comments on commit 579cec8

Please sign in to comment.