From d6345d85695e4adb66cadf29aae5a4e53efca245 Mon Sep 17 00:00:00 2001
From: Vincent Dumoulin
Date: Thu, 6 Nov 2014 14:47:40 -0500
Subject: [PATCH] Add NADE model

---
 code/pylearn2/costs/__init__.py                    |   0
 code/pylearn2/costs/nade.py                        |  37 ++
 .../models/directed_probabilistic/__init__.py      | 299 +++++++++++
 .../models/directed_probabilistic/nade.py          | 475 ++++++++++++++++++
 code/pylearn2/scripts/nade.yaml                    |  46 ++
 code/pylearn2/utils/unrolled_scan.py               | 131 +++++
 6 files changed, 988 insertions(+)
 create mode 100644 code/pylearn2/costs/__init__.py
 create mode 100644 code/pylearn2/costs/nade.py
 create mode 100644 code/pylearn2/models/directed_probabilistic/__init__.py
 create mode 100644 code/pylearn2/models/directed_probabilistic/nade.py
 create mode 100644 code/pylearn2/scripts/nade.yaml
 create mode 100755 code/pylearn2/utils/unrolled_scan.py

diff --git a/code/pylearn2/costs/__init__.py b/code/pylearn2/costs/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/code/pylearn2/costs/nade.py b/code/pylearn2/costs/nade.py
new file mode 100644
index 0000000..db82863
--- /dev/null
+++ b/code/pylearn2/costs/nade.py
@@ -0,0 +1,37 @@
+"""
+Neural autoregressive density estimator (NADE)-related costs
+"""
+__authors__ = "Vincent Dumoulin"
+__copyright__ = "Copyright 2014, Universite de Montreal"
+__credits__ = ["Vincent Dumoulin"]
+__license__ = "3-clause BSD"
+__maintainer__ = "Vincent Dumoulin"
+
+
+import theano.tensor as T
+from pylearn2.costs.cost import Cost, DefaultDataSpecsMixin
+from pylearn2.utils import wraps
+
+
+class NADECost(DefaultDataSpecsMixin, Cost):
+    """
+    NADE negative log-likelihood
+    """
+    @wraps(Cost.expr)
+    def expr(self, model, data, ** kwargs):
+        self.get_data_specs(model)[0].validate(data)
+        X = data
+        return -T.mean(model.log_likelihood(X))
+
+
+class CNADECost(DefaultDataSpecsMixin, Cost):
+    """
+    CNADE negative log-likelihood
+    """
+    supervised = True
+
+    @wraps(Cost.expr)
+    def expr(self, model, data, ** kwargs):
+        self.get_data_specs(model)[0].validate(data)
+        X, Y = data
+        return -T.mean(model.log_likelihood(X, Y))
diff --git a/code/pylearn2/models/directed_probabilistic/__init__.py b/code/pylearn2/models/directed_probabilistic/__init__.py
new file mode 100644
index 0000000..fdffcc8
--- /dev/null
+++ b/code/pylearn2/models/directed_probabilistic/__init__.py
@@ -0,0 +1,299 @@
+"""
+Directed probabilistic models
+"""
+__authors__ = "Vincent Dumoulin"
+__copyright__ = "Copyright 2014, Universite de Montreal"
+__credits__ = ["Vincent Dumoulin"]
+__license__ = "3-clause BSD"
+__maintainer__ = "Vincent Dumoulin"
+
+
+import numpy
+import theano.tensor as T
+from theano.compat import OrderedDict
+from theano.tensor.shared_randomstreams import RandomStreams
+from pylearn2.models.model import Model
+from pylearn2.utils import sharedX
+from pylearn2.utils import wraps
+from pylearn2.space import VectorSpace, NullSpace
+
+
+theano_rng = RandomStreams(seed=23541)
+
+
+class Distribution(Model):
+    """
+    Abstract base class for directed probabilistic models
+    """
+    def _initialize_weights(self, dim_0, dim_1):
+        """
+        Initialize a (dim_0, dim_1)-shaped weight matrix drawn uniformly
+        from (-1 / dim_0, 1 / dim_0)
+
+        Parameters
+        ----------
+        dim_0 : int
+            First dimension of the weights matrix
+        dim_1 : int
+            Second dimension of the weights matrix
+
+        Returns
+        -------
+        rval : `numpy.ndarray`
+            A (dim_0, dim_1)-shaped, properly initialized weights matrix
+        """
+        rval = (2 * numpy.random.uniform(size=(dim_0, dim_1)) - 1) / dim_0
+        return rval
+
+    def get_layer_monitoring_channels(self):
+        """
+        Returns monitoring channels for the min, max and mean of every
+        model parameter
+        """
+        rval = OrderedDict()
+
+        for param in self.get_params():
+            rval[param.name + "_min"] = param.min()
"_min"] = param.min() + rval[param.name + "_max"] = param.max() + rval[param.name + "_mean"] = param.mean() + + return rval + + +class JointDistribution(Distribution): + def _sample(self, num_samples): + raise NotImplementedError() + + def sample(self, num_samples, return_log_likelihood=False, + return_probabilities=False): + """ + Samples from the modeled joint distribution p(x) + + Parameters + ---------- + num_samples : int + Number of samples to draw + return_log_likelihood : bool, optional + If `True`, returns the log-likelihood of the samples in addition to + the samples themselves. Defaults to `False`. + return_probabilities : bool, optional + If `True`, returns the probabilities from which samples were drawn + in addition to the samples themselves. Defaults to `False`. + + Returns + ------- + samples : tensor-like + Batch of `num_samples` samples from p(x) + log_likelihood : tensor-like, optional + Log-likelihood of the drawn samples according to p(x). Returned + only if `return_log_likelihood` is set to `True`. + probabilities : tensor-like, optional + Probabilities from which the samples were drawn. Returned only if + `return_probabilities` is set to `True`. + """ + rval = self._sample(num_samples=num_samples) + samples, log_likelihood, probabilities = rval + + if not return_log_likelihood and not return_probabilities: + return samples + else: + rval = [samples] + if return_log_likelihood: + rval.append(log_likelihood) + if return_probabilities: + rval.append(probabilities) + return tuple(rval) + + def _log_likelihood(self, X): + raise NotImplementedError() + + def log_likelihood(self, X): + """ + Computes the log-likelihood of a batch of observed examples on a + per-example basis + + Parameters + ---------- + X : tensor-like + Batch of observed examples + + Returns + ------- + rval : tensor-like + Log-likelihood for the batch of visible examples, with shape + (X.shape[0],) + """ + return self._log_likelihood(X=X) + + +class ConditionalDistribution(Distribution): + def _sample(self, num_samples): + raise NotImplementedError() + + def sample(self, Y, return_log_likelihood=False, + return_probabilities=False): + """ + Samples from the conditional distribution p(x | y) + + Parameters + ---------- + return_log_likelihood : bool, optional + If `True`, returns the conditional log-likelihood of the samples in + addition to the samples themselves. Defaults to `False`. + return_probabilities : bool, optional + If `True`, returns the conditional probabilities from which samples + were drawn in addition to the samples themselves. Defaults to + `False`. + + Returns + ------- + samples : tensor-like + Batch of `num_samples` samples from p(x) + log_likelihood : tensor-like, optional + Log-likelihood of the drawn samples according to p(x | y). Returned + only if `return_log_likelihood` is set to `True`. + probabilities : tensor-like, optional + Probabilities from which the samples were drawn. Returned only if + `return_probabilities` is set to `True`. 
+ """ + rval = self._sample(Y=Y) + samples, log_likelihood, probabilities = rval + + if not return_log_likelihood and not return_probabilities: + return samples + else: + rval = [samples] + if return_log_likelihood: + rval.append(log_likelihood) + if return_probabilities: + rval.append(probabilities) + return tuple(rval) + + def _log_likelihood(self, X, Y): + raise NotImplementedError() + + def log_likelihood(self, X, Y): + """ + Computes the conditional log-likelihood of a batch of observed examples + on a per-example basis + + Parameters + ---------- + X : tensor-like + Batch of observed examples + Y : tensor-like + Batch of conditioning examples + + Returns + ------- + rval : tensor-like + Conditional Log-likelihood for the batch of visible examples, with + shape (X.shape[0],) + """ + return self._log_likelihood(X=X, Y=Y) + + +class ProductOfBernoulli(JointDistribution): + """ + Random binary vector whose distribution is a product of Bernoulli + distributions, i.e. + + p(v) = \prod_i v_i ** p_i * (1 - v_i) ** (1 - p_i) + """ + def __init__(self, dim): + """ + Parameters + ---------- + dim : int + Dimension of the random binary vector + """ + self.dim = dim + + # Parameter initialization + b_value = numpy.zeros(self.dim) + self.b = sharedX(b_value, 'b') + self.p = T.nnet.sigmoid(self.b) + + # Space initialization + self.input_space = NullSpace() + self.output_space = VectorSpace(dim=self.dim) + + def _sample(self, num_samples): + samples = theano_rng.uniform((num_samples, self.dim)) <= self.p + log_likelihood = self.log_likelihood(samples) + probabilities = T.zeros_like(samples) + self.p + return samples, log_likelihood, probabilities + + def _log_likelihood(self, X): + return (X * T.log(self.p) + (1 - X) * T.log(1 - self.p)).sum(axis=1) + + @wraps(Model.get_params) + def get_params(self): + return [self.b] + + +class StochasticSigmoid(ConditionalDistribution): + """ + Implements the conditional distribution of a random binary vector x given + an input vector y as a product of Bernoulli distributions, i.e. 
+
+    p(x | y) = \prod_i p(x_i | y),
+
+    where
+
+    p(x_i = 1 | y) = sigmoid(y.W_i + b_i)
+    """
+    def __init__(self, dim, dim_cond, clamp_sigmoid=False):
+        """
+        Parameters
+        ----------
+        dim : int
+            Dimension of the modeled vector x
+        dim_cond : int
+            Dimension of the conditioning vector y
+        clamp_sigmoid : bool, optional
+            If `True`, clamps sigmoid outputs away from 0 and 1 so that
+            their logarithm is always finite. Defaults to `False`.
+        """
+        self.dim_cond = dim_cond
+        self.dim = dim
+        self.clamp_sigmoid = clamp_sigmoid
+
+        # Bias initialization
+        b_value = numpy.zeros(self.dim)
+        self.b = sharedX(b_value, 'b')
+
+        # Weights initialization
+        W_value = self._initialize_weights(self.dim_cond, self.dim)
+        self.W = sharedX(W_value, 'W')
+
+        # Space initialization
+        self.input_space = VectorSpace(dim=self.dim_cond)
+        self.target_space = VectorSpace(dim=self.dim)
+
+    def _sigmoid(self, x):
+        """
+        Computes the sigmoid of x, optionally clamped away from 0 and 1
+
+        Parameters
+        ----------
+        x : tensor-like
+            Pre-sigmoid activations
+        """
+        if self.clamp_sigmoid:
+            return T.nnet.sigmoid(x) * 0.9999 + 0.000005
+        else:
+            return T.nnet.sigmoid(x)
+
+    def _sample(self, Y):
+        batch_size = Y.shape[0]
+        probabilities = self._sigmoid(T.dot(Y, self.W) + self.b)
+        samples = theano_rng.uniform((batch_size, self.dim)) <= probabilities
+        log_likelihood = (
+            samples * T.log(probabilities) +
+            (1 - samples) * T.log(1 - probabilities)
+        ).sum(axis=1)
+
+        return samples, log_likelihood, probabilities
+
+    def _log_likelihood(self, X, Y):
+        p = self._sigmoid(T.dot(Y, self.W) + self.b)
+        return (X * T.log(p) + (1 - X) * T.log(1 - p)).sum(axis=1)
+
+    @wraps(Model.get_params)
+    def get_params(self):
+        return [self.W, self.b]
+
+    def get_weights(self):
+        return self.W.get_value()
diff --git a/code/pylearn2/models/directed_probabilistic/nade.py b/code/pylearn2/models/directed_probabilistic/nade.py
new file mode 100644
index 0000000..2b3fcaa
--- /dev/null
+++ b/code/pylearn2/models/directed_probabilistic/nade.py
@@ -0,0 +1,475 @@
+"""
+Neural autoregressive density estimator (NADE) implementation
+"""
+__authors__ = "Vincent Dumoulin"
+__copyright__ = "Copyright 2014, Universite de Montreal"
+__credits__ = ["Jorg Bornschein", "Vincent Dumoulin"]
+__license__ = "3-clause BSD"
+__maintainer__ = "Vincent Dumoulin"
+
+
+import numpy
+import theano
+import theano.tensor as T
+from theano.tensor.shared_randomstreams import RandomStreams
+from pylearn2.models.model import Model
+from pylearn2.utils import sharedX
+from pylearn2.space import VectorSpace
+# from research.code.pylearn2.utils.unrolled_scan import unrolled_scan
+from research.code.pylearn2.models.directed_probabilistic import (
+    JointDistribution, ConditionalDistribution
+)
+
+
+theano_rng = RandomStreams(seed=2341)
+
+
+class NADEBase(Model):
+    """
+    Base class factoring out the parameters and utilities shared by NADE
+    and CNADE
+    """
+    def __init__(self, dim, dim_hid, clamp_sigmoid=False, unroll_scan=1):
+        """
+        Parameters
+        ----------
+        dim : int
+            Number of observed binary variables
+        dim_hid : int
+            Number of latent binary variables
+        clamp_sigmoid : bool, optional
+            If `True`, clamps sigmoid outputs away from 0 and 1 so that
+            their logarithm is always finite. Defaults to `False`.
+        unroll_scan : int, optional
+            Unrolling factor for the scan loop over input dimensions (see
+            `utils/unrolled_scan.py`). Defaults to 1.
+ """ + super(NADEBase, self).__init__() + + self.dim = dim + self.dim_hid = dim_hid + self.clamp_sigmoid = clamp_sigmoid + self.unroll_scan = unroll_scan + + self.input_space = VectorSpace(dim=self.dim) + + # Visible biases + b_value = numpy.zeros(self.dim) + self.b = sharedX(b_value, 'b') + # Hidden biases + c_value = numpy.zeros(self.dim_hid) + self.c = sharedX(c_value, 'c') + # Encoder weights + W_value = self._initialize_weights(self.dim, self.dim_hid) + self.W = sharedX(W_value, 'W') + # Decoder weights + V_value = self._initialize_weights(self.dim_hid, self.dim) + self.V = sharedX(V_value, 'V') + + def _initialize_weights(self, dim_0, dim_1): + """ + Initialize a (dim_0, dim_1)-shaped weight matrix + + Parameters + ---------- + dim_0 : int + First dimension of the weights matrix + dim_1 : int + Second dimension of the weights matrix + + Returns + ------- + rval : `numpy.ndarray` + A (dim_0, dim_1)-shaped, properly initialized weights matrix + """ + rval = (2 * numpy.random.normal(size=(dim_0, dim_1)) - 1) / dim_0 + return rval + + def sigmoid(self, x): + """ + WRITEME + + Parameters + ---------- + x : WRITEME + """ + if self.clamp_sigmoid: + return T.nnet.sigmoid(x)*0.9999 + 0.000005 + else: + return T.nnet.sigmoid(x) + + def get_params(self): + """ + Returns + ------- + params : list of tensor-like + The model's parameters + """ + return [self.b, self.c, self.W, self.V] + + def get_weights(self): + """ + Aliases to `NADE.get_encoder_weights` + """ + return self.get_encoder_weights() + + def set_weights(self, weights): + """ + Aliases to `NADE.set_encoder_weights` + """ + self.set_encoder_weights(weights) + + def get_encoder_weights(self): + """ + Returns + ------- + rval : `numpy.ndarray` + Encoder weights + """ + return self.W.get_value() + + def set_encoder_weights(self, weights): + """ + Sets encoder weight values + + Parameters + ---------- + weights : `numpy.ndarray` + Encoder weight values to assign to self.W + """ + self.W.set_value(weights) + + def get_decoder_weights(self): + """ + Returns + ------- + rval : `numpy.ndarray` + Decoder weights + """ + return self.V.get_value() + + def set_decoder_weights(self, weights): + """ + Sets decoder weight values + + Parameters + ---------- + weights : `numpy.ndarray` + Decoder weight values to assign to self.V + """ + self.V.set_value(weights) + + def get_visible_biases(self): + """ + Returns + ------- + rval : `numpy.ndarray` + Visible biases + """ + return self.b.get_value() + + def set_visible_biases(self, biases): + """ + Sets visible bias values + + Parameters + ---------- + biases : `numpy.ndarray` + Visible bias values to assign to self.b + """ + self.b.set_value(biases) + + def get_hidden_biases(self): + """ + Returns + ------- + rval : `numpy.ndarray` + Hidden biases + """ + return self.c.get_value() + + def set_hidden_biases(self, biases): + """ + Sets hidden bias values + + Parameters + ---------- + biases : `numpy.ndarray` + Hidden bias values to assign to self.c + """ + self.c.set_value(biases) + + def _base_log_likelihood(self, X, W, V, b, c): + """ + Computes the log-likelihood of a batch of visible examples + + Parameters + ---------- + X : tensor-like + Batch of visible examples + W : tensor-like + Encoder weights + V : tensor-like + Decoder weights + b : tensor-like + Visible biases + c : tensor-like + Hidden biases + + Returns + ------- + rval : tensor-like + Log-likelihood for the batch of visible examples + """ + # Transformation matrix. 
+        #
+        #     [[[       0,       0, ...,        0],
+        #       [X[0, 0],       0, ...,        0],
+        #       [X[0, 0], X[0, 1], ...,        0],
+        #       [X[0, 0], X[0, 1], ..., X[0, d]]],
+        #      ...
+        #      [[       0,       0, ...,        0],
+        #       [X[n, 0],       0, ...,        0],
+        #       [X[n, 0], X[n, 1], ...,        0],
+        #       [X[n, 0], X[n, 1], ..., X[n, d]]]]
+        #
+        # Its purpose is to make the `W_{.,
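
---

For reference, the quantity `_base_log_likelihood` computes is the standard
NADE factorization log p(x) = sum_i log p(x_i | x_{<i}). Below is a minimal
NumPy sketch of that computation for a single example, not part of the patch;
it assumes the shapes used above (`W` is (dim, dim_hid), `V` is
(dim_hid, dim), `b` is (dim,), `c` is (dim_hid,)), whereas the patched code
runs the same recurrence over a whole batch with Theano's scan:

import numpy


def sigmoid(x):
    return 1.0 / (1.0 + numpy.exp(-x))


def nade_log_likelihood(x, W, V, b, c):
    """Log-likelihood of one binary vector x under NADE parameters."""
    a = c.copy()          # hidden pre-activation for the empty prefix
    log_likelihood = 0.0
    for i in range(len(x)):
        h = sigmoid(a)                             # encodes x_{<i}
        p = sigmoid(numpy.dot(h, V[:, i]) + b[i])  # p(x_i = 1 | x_{<i})
        log_likelihood += (x[i] * numpy.log(p)
                           + (1 - x[i]) * numpy.log(1 - p))
        a += x[i] * W[i, :]                        # fold x_i into the prefix
    return log_likelihood


# Example usage with randomly initialized (hypothetical) parameters:
rng = numpy.random.RandomState(0)
dim, dim_hid = 8, 4
W = rng.uniform(-1.0 / dim, 1.0 / dim, size=(dim, dim_hid))
V = rng.uniform(-1.0 / dim_hid, 1.0 / dim_hid, size=(dim_hid, dim))
b, c = numpy.zeros(dim), numpy.zeros(dim_hid)
x = rng.randint(0, 2, size=dim)
print(nade_log_likelihood(x, W, V, b, c))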