extend {klpq,klqp,laplace,map,wake_sleep}.py,gans to Trace
dustinvtran committed Jan 19, 2018
1 parent 97d4961 commit bfa81f7
Showing 21 changed files with 676 additions and 1,996 deletions.
27 changes: 9 additions & 18 deletions edward/__init__.py
@@ -14,24 +14,20 @@
     bigan_inference,
     complete_conditional,
     gan_inference,
-    implicit_klqp,
     klpq,
     klqp,
-    reparameterization_klqp,
-    reparameterization_kl_klqp,
-    reparameterization_entropy_klqp,
-    score_klqp,
-    score_kl_klqp,
-    score_entropy_klqp,
-    score_rb_klqp,
+    klqp_implicit,
+    klqp_reparameterization,
+    klqp_reparameterization_kl,
+    klqp_score,
     laplace,
     map,
     wake_sleep,
     wgan_inference,
 )
 # from edward.inferences import MonteCarlo, HMC, MetropolisHastings, SGLD, SGHMC, Gibbs
 from edward.models import RandomVariable, Trace
-from edward.util import copy, dot, \
+from edward.util import dot, \
     get_ancestors, get_blanket, get_children, get_control_variate_coef, \
     get_descendants, get_parents, get_siblings, get_variables, \
     is_independent, Progbar, random_variables, rbf, \
@@ -53,29 +49,24 @@
     'bigan_inference',
     'complete_conditional',
     'gan_inference',
-    'implicit_klqp',
     'MonteCarlo',
     'HMC',
     'MetropolisHastings',
     'SGLD',
     'SGHMC',
     'klpq',
     'klqp',
-    'reparameterization_klqp',
-    'reparameterization_kl_klqp',
-    'reparameterization_entropy_klqp',
-    'score_klqp',
-    'score_kl_klqp',
-    'score_entropy_klqp',
-    'score_rb_klqp',
+    'klqp_implicit',
+    'klqp_reparameterization',
+    'klqp_reparameterization_kl',
+    'klqp_score',
     'laplace',
     'map',
     'wake_sleep',
     'wgan_inference',
     'Gibbs',
     'RandomVariable',
     'Trace',
-    'copy',
     'dot',
     'get_ancestors',
     'get_blanket',
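Read as an API change, the hunks above are a pure rename plus a trim of the public surface: the estimator family now prefixes the name, so the KLQP variants sort together. A comment-only sketch of the mapping (this file shows no call signatures, so none are invented here):

```python
import edward as ed

# Renamed in this commit (old -> new):
#   ed.implicit_klqp               -> ed.klqp_implicit
#   ed.reparameterization_klqp     -> ed.klqp_reparameterization
#   ed.reparameterization_kl_klqp  -> ed.klqp_reparameterization_kl
#   ed.score_klqp                  -> ed.klqp_score
# Dropped from the top-level namespace with no renamed counterpart here:
#   reparameterization_entropy_klqp, score_kl_klqp, score_entropy_klqp,
#   score_rb_klqp, and edward.util's copy.
```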
14 changes: 5 additions & 9 deletions edward/inferences/__init__.py
@@ -9,10 +9,10 @@
 from edward.inferences.gan_inference import *
 # from edward.inferences.gibbs import *
 # from edward.inferences.hmc import *
-from edward.inferences.implicit_klqp import *
 from edward.inferences.inference import *
 from edward.inferences.klpq import *
 from edward.inferences.klqp import *
+from edward.inferences.klqp_implicit import *
 from edward.inferences.laplace import *
 from edward.inferences.map import *
 # from edward.inferences.metropolis_hastings import *
@@ -28,18 +28,14 @@
     'bigan_inference',
     'complete_conditional',
     'gan_inference',
-    'implicit_klqp',
     'Gibbs',
     'HMC',
     'klpq',
     'klqp',
-    'reparameterization_klqp',
-    'reparameterization_kl_klqp',
-    'reparameterization_entropy_klqp',
-    'score_klqp',
-    'score_kl_klqp',
-    'score_entropy_klqp',
-    'score_rb_klqp',
+    'klqp_implicit',
+    'klqp_reparameterization',
+    'klqp_reparameterization_kl',
+    'klqp_score',
     'laplace',
     'map',
     'MetropolisHastings',
46 changes: 19 additions & 27 deletions edward/inferences/bigan_inference.py
@@ -5,13 +5,12 @@
 import six
 import tensorflow as tf
 
-from edward.inferences.inference import (check_and_maybe_build_data,
-    check_and_maybe_build_latent_vars, transform, check_and_maybe_build_dict, check_and_maybe_build_var_list)
+from edward.models import Trace
+from edward.inferences.inference import call_function_up_to_args
 
 
-def bigan_inference(latent_vars=None, data=None, discriminator=None,
-                    auto_transform=True, scale=None, var_list=None,
-                    collections=None):
+def bigan_inference(model, variational, discriminator, align_data,
+                    align_latent, collections=None, *args, **kwargs):
   """Adversarially Learned Inference [@dumoulin2017adversarially] or
   Bidirectional Generative Adversarial Networks [@donahue2017adversarial]
   for joint learning of generator and inference networks.
@@ -44,20 +43,23 @@
   zf = gen_latent(x_ph)
   inference = ed.BiGANInference({z_ph: zf}, {xf: x_ph}, discriminator)
   ```
+
+  `align_latent` must only align one random variable in `model` and
+  `variational`. `model` must return the generated data.
   """
-  if not callable(discriminator):
-    raise TypeError("discriminator must be a callable function.")
-  latent_vars = check_and_maybe_build_latent_vars(latent_vars)
-  data = check_and_maybe_build_data(data)
-  latent_vars, _ = transform(latent_vars, auto_transform)
-  scale = check_and_maybe_build_dict(scale)
-  var_list = check_and_maybe_build_var_list(var_list, latent_vars, data)
+  with Trace() as posterior_trace:
+    call_function_up_to_args(variational, *args, **kwargs)
+  with Trace() as model_trace:
+    x_fake = call_function_up_to_args(model, *args, **kwargs)
 
-  x_true = list(six.itervalues(self.data))[0]
-  x_fake = list(six.iterkeys(self.data))[0]
+  x_true = align_data(x_fake.name)
 
-  z_true = list(six.iterkeys(self.latent_vars))[0]
-  z_fake = list(six.itervalues(self.latent_vars))[0]
+  for name, node in six.iteritems(model_trace):
+    aligned = align_latent(name)
+    if aligned != name:
+      z_true = node.value
+      z_fake = posterior_trace[aligned].value
+      break
 
   with tf.variable_scope("Disc"):
     # xtzf := x_true, z_fake
@@ -80,14 +82,4 @@
 
   loss_d = tf.reduce_mean(loss_d) + tf.reduce_sum(reg_terms_d)
   loss = tf.reduce_mean(loss) + tf.reduce_sum(reg_terms)
-
-  var_list_d = tf.get_collection(
-      tf.GraphKeys.TRAINABLE_VARIABLES, scope="Disc")
-  var_list = tf.get_collection(
-      tf.GraphKeys.TRAINABLE_VARIABLES, scope="Gen")
-
-  grads_d = tf.gradients(loss_d, var_list_d)
-  grads = tf.gradients(loss, var_list)
-  grads_and_vars_d = list(zip(grads_d, var_list_d))
-  grads_and_vars = list(zip(grads, var_list))
-  return loss, grads_and_vars, loss_d, grads_and_vars_d
+  return loss, loss_d
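The rewritten body is the Trace pattern this commit rolls out across the inference functions: run the variational program and the model program each under a `Trace()` context, then pair off the recorded random variables via the `align_*` callables. A minimal usage sketch under stated assumptions: the shapes, names, and `x_data` placeholder are illustrative; the discriminator is assumed to take an `(x, z)` pair, as the `xtzf` naming in the body suggests; and `Trace` is assumed to key recorded variables by their `name` argument.

```python
import edward as ed
import tensorflow as tf
from edward.models import Normal

x_data = tf.placeholder(tf.float32, [1, 784])  # observed data batch

def model():
  # Generative program; must return the generated data.
  z = Normal(loc=tf.zeros([1, 64]), scale=tf.ones([1, 64]), name="z")
  return tf.layers.dense(z, 784)

def variational():
  # Inference network; its random variable is recorded by the Trace.
  Normal(loc=tf.layers.dense(x_data, 64),
         scale=tf.nn.softplus(tf.layers.dense(x_data, 64)),
         name="qz")

def discriminator(x, z):
  h = tf.layers.dense(tf.concat([x, z], -1), 256, tf.nn.relu)
  return tf.layers.dense(h, 1)  # logits, not probabilities

loss, loss_d = ed.bigan_inference(
    model, variational, discriminator,
    align_data=lambda name: x_data,  # pair generated x with observed data
    align_latent=lambda name: "qz" if name == "z" else name)
```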
2 changes: 1 addition & 1 deletion edward/inferences/conjugacy/conjugacy.py
@@ -12,7 +12,7 @@
 from edward.inferences.conjugacy.simplify \
     import symbolic_suff_stat, full_simplify, expr_contains, reconstruct_expr
 from edward.models.random_variables import *
-from edward.util import copy, get_blanket
+from edward.util import get_blanket
 
 
 def mvn_diag_from_natural_params(p1, p2):
35 changes: 9 additions & 26 deletions edward/inferences/gan_inference.py
@@ -5,12 +5,12 @@
 import six
 import tensorflow as tf
 
-from edward.inferences.inference import (check_and_maybe_build_data,
-    transform, check_and_maybe_build_dict, check_and_maybe_build_var_list)
+from edward.models import Trace
+from edward.inferences.inference import call_function_up_to_args
 
 
-def gan_inference(data=None, discriminator=None,
-                  scale=None, var_list=None, collections=None):
+def gan_inference(model, discriminator, align_data,
+                  collections=None, *args, **kwargs):
   """Parameter estimation with GAN-style training
   [@goodfellow2014generative].
@@ -55,18 +55,11 @@
     Function (with parameters) to discriminate samples. It should
     output logit probabilities (real-valued) and not probabilities
     in $[0, 1]$.
-  var_list: list of tf.Variable, optional.
-    List of TensorFlow variables to optimize over (in the generative
-    model). Default is all trainable variables that `data` depends on.
+
+  `model` must return the generated data.
   """
-  if not callable(discriminator):
-    raise TypeError("discriminator must be a callable function.")
-  data = check_and_maybe_build_data(data)
-  scale = check_and_maybe_build_dict(scale)
-  var_list = check_and_maybe_build_var_list(var_list, {}, data)
 
-  x_true = list(six.itervalues(data))[0]
-  x_fake = list(six.iterkeys(data))[0]
+  x_fake = call_function_up_to_args(model, *args, **kwargs)
+  x_true = align_data(x_fake.name)
   with tf.variable_scope("Disc"):
     d_true = discriminator(x_true)
 
@@ -90,14 +83,4 @@
         labels=tf.ones_like(d_fake), logits=d_fake)
   loss_d = tf.reduce_mean(loss_d) + tf.reduce_sum(reg_terms_d)
   loss = tf.reduce_mean(loss) + tf.reduce_sum(reg_terms)
-
-  var_list_d = tf.get_collection(
-      tf.GraphKeys.TRAINABLE_VARIABLES, scope="Disc")
-  if var_list is None:
-    var_list = [v for v in tf.trainable_variables() if v not in var_list_d]
-
-  grads_d = tf.gradients(loss_d, var_list_d)
-  grads = tf.gradients(loss, var_list)
-  grads_and_vars_d = list(zip(grads_d, var_list_d))
-  grads_and_vars = list(zip(grads, var_list))
-  return loss, grads_and_vars, loss_d, grads_and_vars_d
+  return loss, loss_d
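The same pattern in its simplest form: `gan_inference` now takes a generative program plus a data-alignment callable in place of the old `data` dict, and returns the two losses, leaving gradients and optimizers to the caller. A minimal sketch (shapes, names, and the optimizer steps are illustrative; the `"Disc"` scope is the one the function itself opens, per the body above):

```python
import edward as ed
import tensorflow as tf
from edward.models import Normal

x_data = tf.placeholder(tf.float32, [1, 784])  # observed data batch

def model():
  # Generative program; must return the generated data.
  eps = Normal(loc=tf.zeros([1, 64]), scale=tf.ones([1, 64]), name="eps")
  return tf.layers.dense(eps, 784)

def discriminator(x):
  return tf.layers.dense(x, 1)  # logits, not probabilities

loss, loss_d = ed.gan_inference(
    model, discriminator, align_data=lambda name: x_data)

# Gradient/variable pairs are no longer returned; optimize explicitly.
var_list_d = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="Disc")
var_list_g = [v for v in tf.trainable_variables() if v not in var_list_d]
train_d = tf.train.AdamOptimizer().minimize(loss_d, var_list=var_list_d)
train_g = tf.train.AdamOptimizer().minimize(loss, var_list=var_list_g)
```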
1 change: 0 additions & 1 deletion edward/inferences/hmc.py
@@ -8,7 +8,6 @@
 from collections import OrderedDict
 from edward.inferences.monte_carlo import MonteCarlo
 from edward.models import RandomVariable
-from edward.util import copy
 
 try:
   from edward.models import Normal, Uniform