Marginal Util Function #778
base: master
Changes from 17 commits
@@ -25,6 +25,7 @@
    'get_session',
    'get_siblings',
    'get_variables',
    'marginal',
    'Progbar',
    'random_variables',
    'rbf',
@@ -778,3 +778,66 @@ def transform(x, *args, **kwargs):
  new_x = TransformedDistribution(x, bij, *args, **kwargs)
  new_x.support = new_support
  return new_x


def marginal(x, n):
  """Performs a full graph sample on the provided random variable.

  Given a random variable and a sample size, adds an additional sample
  dimension to the root random variables in x's graph, and samples from
  a new graph in terms of that sample size.

  Args:
    x : RandomVariable.
      Random variable to perform full graph sample on.
    n : tf.Tensor or int.
      The size of the full graph sample to take.

  Returns:
    tf.Tensor.
    The fully sampled values from x, of shape [n] + x.shape.

  #### Examples

  ```python
  ed.get_session()
  loc = Normal(0.0, 100.0)
  y = Normal(loc, 0.0001)
  conditional_sample = y.sample(50)
  marginal_sample = ed.marginal(y, 50)

  np.std(conditional_sample.eval())
  ## 0.000100221

  np.std(marginal_sample.eval())
  ## 106.55982
  ```

  #### Notes

  The current implementation only works for graphs of RVs that don't use
  the `sample_shape` kwarg.
  """
  ancestors = get_ancestors(x)
  if any([rv.sample_shape != () for rv in ancestors]) or x.sample_shape != ():
    raise NotImplementedError("`marginal` doesn't support graphs of RVs "
                              "with non scalar sample_shape args.")
  elif ancestors == []:
    old_roots = [x]
  else:
    old_roots = [rv for rv in ancestors if get_ancestors(rv) == []]

  new_roots = []
  for rv in old_roots:
    new_rv = copy(rv)
    new_rv._sample_shape = tf.TensorShape(n).concatenate(new_rv._sample_shape)
Review comment: This also came up when I looked into #774, and I think it would need to be solved at the same time. Sample shape needs a TensorShape, and there's no nice way to turn a tensor into one. So I think either ... Let me know if you'd prefer this implemented in the same PR, and I'll push the other changes.
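As a hedged illustration of the reviewer's point above (not part of this diff), one option is to recover a static value from `n` before building the `TensorShape`; `static_sample_size` is a hypothetical helper name, and `tf.contrib.util.constant_value` only succeeds when `n` is statically known.

```python
import tensorflow as tf


def static_sample_size(n):
  """Return a Python int for `n`, which may be an int or a constant tf.Tensor."""
  if isinstance(n, tf.Tensor):
    static_n = tf.contrib.util.constant_value(n)  # None if not statically known
    if static_n is None:
      raise NotImplementedError("`marginal` needs a statically known sample "
                                "size to build a TensorShape.")
    return int(static_n)
  return int(n)


# tf.TensorShape(static_sample_size(tf.constant(5))) then behaves like
# tf.TensorShape(5); a fully dynamic `n` would still be unsupported.
```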
    new_rv._value = new_rv.sample(new_rv._sample_shape)
    new_roots.append(new_rv)
  dict_swap = dict(zip(old_roots, new_roots))
  x_full = copy(x, dict_swap, replace_itself=True)
  if x_full.shape[1:] != x.shape:
    print(x_full.shape)
    print(x.shape)
    raise ValueError('Could not transform graph for bulk sampling.')

  return x_full
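To make the docstring example above concrete, here is an illustrative sketch only (not part of the diff) of what the full graph sample computes, written with plain TensorFlow 1.x ops instead of Edward's graph copy: the root `loc` is drawn `n` times, and `y` is then drawn once per `loc` draw.

```python
import tensorflow as tf

n = 50
# Root RV: n independent draws of `loc` from its prior N(0, 100^2).
loc_draws = tf.random_normal([n], mean=0.0, stddev=100.0)
# Child RV: one draw of y per loc draw, with y | loc ~ N(loc, 0.0001^2).
y_draws = loc_draws + 0.0001 * tf.random_normal([n])

with tf.Session() as sess:
  samples = sess.run(y_draws)
# The spread of `samples` is on the order of 100 (driven by the prior on loc),
# whereas y.sample(n) conditions on a single draw of loc and has spread ~ 0.0001.
```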
@@ -0,0 +1,103 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import edward as ed
import numpy as np
import tensorflow as tf

from edward.models import Normal, InverseGamma
from tensorflow.contrib.distributions import bijectors
Review comment: I found this test very intuitive. Great work. One note: you don't use the `bijectors` import.
class test_marginal_class(tf.test.TestCase):

  def test_bad_graph(self):
    with self.test_session():
      loc = Normal(tf.zeros(5), 5.0)
      y_loc = tf.expand_dims(loc, 1)  # this displaces the sample dimension
      inv_scale = Normal(tf.zeros(3), 1.0)
      y_scale = tf.expand_dims(tf.nn.softplus(inv_scale), 0)
      y = Normal(y_loc, y_scale)
      with self.assertRaises(ValueError):
        ed.marginal(y, 20)

  def test_sample_arg(self):
    with self.test_session():
      y = Normal(0.0, 1.0, sample_shape=10)
      with self.assertRaises(NotImplementedError):
        ed.marginal(y, 20)

  def test_sample_arg_ancestor(self):
    with self.test_session():
      x = Normal(0.0, 1.0, sample_shape=10)
      y = Normal(x, 0.0)
      with self.assertRaises(NotImplementedError):
        ed.marginal(y, 20)

  def test_no_ancestor(self):
    with self.test_session():
      y = Normal(0.0, 1.0)
      sample = ed.marginal(y, 4)
      self.assertEqual(sample.shape, [4])

  def test_no_ancestor_batch(self):
    with self.test_session():
      y = Normal(tf.zeros([2, 3, 4]), 1.0)
      sample = ed.marginal(y, 5)
      self.assertEqual(sample.shape, [5, 2, 3, 4])

  def test_single_ancestor(self):
    with self.test_session():
      loc = Normal(0.0, 1.0)
      y = Normal(loc, 1.0)
      sample = ed.marginal(y, 4)
      self.assertEqual(sample.shape, [4])

  def test_single_ancestor_batch(self):
    with self.test_session():
      loc = Normal(tf.zeros([2, 3, 4]), 1.0)
      y = Normal(loc, 1.0)
      sample = ed.marginal(y, 5)
      self.assertEqual(sample.shape, [5, 2, 3, 4])

  def test_sample_passthrough(self):
    with self.test_session():
      loc = Normal(0.0, 100.0)
Review comment: It's with very low probability this can produce a false negative/positive, but in general you should always set the seed in tests when checking randomness. (A seeded variant is sketched after this test.)
      y = Normal(loc, 0.0001)
      conditional_sample = y.sample(50)
      marginal_sample = ed.marginal(y, 50)
      self.assertTrue(np.std(conditional_sample.eval()) < 1.0)
      self.assertTrue(np.std(marginal_sample.eval()) > 1.0)
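A hedged sketch of the reviewer's suggestion, reusing the imports and test class above; `test_sample_passthrough_seeded` is an illustrative name and 42 an arbitrary seed, neither part of the PR.

```python
  def test_sample_passthrough_seeded(self):
    with self.test_session():
      tf.set_random_seed(42)  # graph-level seed so the std assertions cannot flake
      loc = Normal(0.0, 100.0)
      y = Normal(loc, 0.0001)
      conditional_sample = y.sample(50)
      marginal_sample = ed.marginal(y, 50)
      self.assertTrue(np.std(conditional_sample.eval()) < 1.0)
      self.assertTrue(np.std(marginal_sample.eval()) > 1.0)
```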
  def test_multiple_ancestors(self):
    with self.test_session():
      loc = Normal(0.0, 1.0)
      scale = InverseGamma(1.0, 1.0)
      y = Normal(loc, scale)
      sample = ed.marginal(y, 4)
      self.assertEqual(sample.shape, [4])

  def test_multiple_ancestors_batch(self):
    with self.test_session():
      loc = Normal(tf.zeros(5), 1.0)
      scale = InverseGamma(tf.ones(5), 1.0)
      y = Normal(loc, scale)
      sample = ed.marginal(y, 4)
      self.assertEqual(sample.shape, [4, 5])

  def test_multiple_ancestors_batch_broadcast(self):
    with self.test_session():
      loc = Normal(tf.zeros([5, 1]), 1.0)
      scale = InverseGamma(tf.ones([1, 6]), 1.0)
      y = Normal(loc, scale)
      sample = ed.marginal(y, 4)
      self.assertEqual(sample.shape, [4, 5, 6])

  def test_multiple_ancestors_failed_broadcast(self):
    with self.test_session():
      loc = Normal(tf.zeros([5, 1]), 1.0)
      scale = InverseGamma(tf.ones([6]), 1.0)
      y = Normal(loc, scale)
      with self.assertRaises(ValueError):
        ed.marginal(y, 4)
Review comment: Functions should be placed according to alphabetical ordering of function names.