Source code for cdt.utils.torch

"""PyTorch utilities for models.

Author: Diviyan Kalainathan, Olivier Goudet
Date: 09/3/2018

.. MIT License
..
.. Copyright (c) 2018 Diviyan Kalainathan
..
.. Permission is hereby granted, free of charge, to any person obtaining a copy
.. of this software and associated documentation files (the "Software"), to deal
.. in the Software without restriction, including without limitation the rights
.. to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
.. copies of the Software, and to permit persons to whom the Software is
.. furnished to do so, subject to the following conditions:
..
.. The above copyright notice and this permission notice shall be included in all
.. copies or substantial portions of the Software.
..
.. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
.. IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
.. FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
.. AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
.. LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
.. OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
.. SOFTWARE.
"""
import math
import torch as th
from torch.nn import Parameter
from torch.nn.modules.batchnorm import _BatchNorm
import torch.distributions.relaxed_bernoulli as relaxed_bernoulli
from torch.distributions.transformed_distribution import TransformedDistribution
from torch.distributions.transforms import SigmoidTransform, AffineTransform
from torch.distributions.uniform import Uniform


def _sample_gumbel(shape, eps=1e-10, out=None):
    """
    Implementation adapted from PyTorch.
    (https://github.com/pytorch/pytorch/blob/e4eee7c2cf43f4edba7a14687ad59d3ed61d9833/torch/nn/functional.py)
    Sample from Gumbel(0, 1)
    based on
    https://github.com/ericjang/gumbel-softmax/blob/3c8584924603869e90ca74ac20a6a03d99a91ef9/Categorical%20VAE.ipynb ,
    (MIT license)
    """
    U = out.resize_(shape).uniform_() if out is not None else th.rand(shape)
    return - th.log(eps - th.log(U + eps))
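
# A quick sanity check of the Gumbel sampler above (illustrative only, not part
# of the original module): samples from Gumbel(0, 1) have mean equal to the
# Euler-Mascheroni constant (~0.5772).
#
#     >>> g = _sample_gumbel((100000,))
#     >>> float(g.mean())                        # close to 0.5772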


def _gumbel_softmax_sample(logits, tau=1, eps=1e-10):
    """
    Implementation adapted from PyTorch.
    (https://github.com/pytorch/pytorch/blob/e4eee7c2cf43f4edba7a14687ad59d3ed61d9833/torch/nn/functional.py)
    Draw a sample from the Gumbel-Softmax distribution
    based on
    https://github.com/ericjang/gumbel-softmax/blob/3c8584924603869e90ca74ac20a6a03d99a91ef9/Categorical%20VAE.ipynb
    (MIT license)
    """
    dims = logits.dim()
    gumbel_noise = _sample_gumbel(logits.size(), eps=eps, out=logits.data.new())
    y = logits + gumbel_noise
    return th.softmax(y / tau, dims-1)


def gumbel_softmax(logits, tau=1, hard=False, eps=1e-10):
    """
    Implementation adapted from PyTorch.
    (https://github.com/pytorch/pytorch/blob/e4eee7c2cf43f4edba7a14687ad59d3ed61d9833/torch/nn/functional.py)
    Sample from the Gumbel-Softmax distribution and optionally discretize.

    Args:
        logits: `[batch_size, n_class]` unnormalized log-probs
        tau: non-negative scalar temperature
        hard: if ``True``, take `argmax`, but differentiate w.r.t. soft sample y

    Returns:
        [batch_size, n_class] sample from the Gumbel-Softmax distribution.
        If hard=True, then the returned sample will be one-hot, otherwise it will
        be a probability distribution that sums to 1 across classes.

    Constraints:
        - this implementation only works on a batch_size x num_features tensor for now

    based on
    https://github.com/ericjang/gumbel-softmax/blob/3c8584924603869e90ca74ac20a6a03d99a91ef9/Categorical%20VAE.ipynb ,
    (MIT license)
    """
    shape = logits.size()
    assert len(shape) == 2
    y_soft = _gumbel_softmax_sample(logits, tau=tau, eps=eps)
    if hard:
        _, k = y_soft.data.max(-1)
        # this bit is based on
        # https://discuss.pytorch.org/t/stop-gradients-for-st-gumbel-softmax/530/5
        y_hard = logits.data.new(*shape).zero_().scatter_(-1, k.view(-1, 1), 1.0)
        # this cool bit of code achieves two things:
        # - makes the output value exactly one-hot (since we add then
        #   subtract y_soft value)
        # - makes the gradient equal to y_soft gradient (since we strip
        #   all other gradients)
        y = y_hard - y_soft.data + y_soft
    else:
        y = y_soft
    return y

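# Usage sketch for ``gumbel_softmax`` (illustrative only, not part of the
# original module):
#
#     >>> logits = th.randn(16, 4, requires_grad=True)       # [batch_size, n_class]
#     >>> soft = gumbel_softmax(logits, tau=0.5)              # each row sums to 1
#     >>> hard = gumbel_softmax(logits, tau=0.5, hard=True)   # each row is one-hot
#     >>> hard.sum().backward()                  # gradients flow through the soft sample
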
def _sample_logistic(shape, out=None):
    U = out.resize_(shape).uniform_() if out is not None else th.rand(shape)
    # U2 = out.resize_(shape).uniform_() if out is not None else th.rand(shape)
    return th.log(U) - th.log(1 - U)


def _sigmoid_sample(logits, tau=1):
    """
    Implementation of the Bernoulli reparametrization based on Maddison et al. 2017
    """
    dims = logits.dim()
    logistic_noise = _sample_logistic(logits.size(), out=logits.data.new())
    y = logits + logistic_noise
    return th.sigmoid(y / tau)


def gumbel_sigmoid(logits, ones_tensor, zeros_tensor, tau=1, hard=False):
    shape = logits.size()
    y_soft = _sigmoid_sample(logits, tau=tau)
    if hard:
        y_hard = th.where(y_soft > 0.5, ones_tensor, zeros_tensor)
        y = y_hard.data - y_soft.data + y_soft
    else:
        y = y_soft
    return y

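# Usage sketch for ``gumbel_sigmoid`` (illustrative only): each entry is a relaxed
# Bernoulli sample; with ``hard=True`` the values are exactly 0 or 1 while the
# gradient flows through the soft sigmoid sample (straight-through estimator).
#
#     >>> logits = th.zeros(5, 5, requires_grad=True)
#     >>> ones, zeros = th.ones(5, 5), th.zeros(5, 5)
#     >>> sample = gumbel_sigmoid(logits, ones, zeros, tau=1, hard=True)
#     >>> sample.sum().backward()                # gradients reach ``logits``
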
class ChannelBatchNorm1d(_BatchNorm):
    r"""Applies Batch Normalization over a 2D or 3D input (a mini-batch of 1D
    inputs with optional additional channel dimension) as described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`_ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter
    vectors of size `C` (where `C` is the input size).

    By default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done over the `C` dimension, computing
    statistics on `(N, L)` slices, it's common terminology to call this Temporal
    Batch Normalization.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, L)` or :math:`L` from input of size :math:`(N, L)`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to
            ``False``, this module does not track such statistics and always
            uses batch statistics in both training and eval modes.
            Default: ``True``

    Shape:
        - Input: :math:`(N, C)` or :math:`(N, C, L)`
        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)

    Examples::

        >>> # With Learnable Parameters
        >>> m = nn.BatchNorm1d(100)
        >>> # Without Learnable Parameters
        >>> m = nn.BatchNorm1d(100, affine=False)
        >>> input = torch.randn(20, 100)
        >>> output = m(input)

    .. _`Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`:
        https://arxiv.org/abs/1502.03167
    """

    def __init__(self, num_channels, num_features, *args, **kwargs):
        super(ChannelBatchNorm1d, self).__init__(num_channels * num_features, *args, **kwargs)
        self.num_channels = num_channels
        self.num_features = num_features

    def _check_input_dim(self, input):
        if input.dim() != 2 and input.dim() != 3:
            raise ValueError('expected 2D or 3D input (got {}D input)'
                             .format(input.dim()))

    def forward(self, input):
        _input = input.contiguous().view(-1, self.num_channels * self.num_features)
        output = super(ChannelBatchNorm1d, self).forward(_input)
        return output.view(-1, self.num_channels, self.num_features)

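# Shape sketch for ``ChannelBatchNorm1d`` (illustrative only, not part of the
# original module): the 3D input is flattened so that the parent ``_BatchNorm``
# normalizes each of the ``num_channels * num_features`` columns independently,
# then the original layout is restored.
#
#     >>> bn = ChannelBatchNorm1d(3, 10)
#     >>> x = th.randn(128, 3, 10)
#     >>> bn(x).shape
#     torch.Size([128, 3, 10])
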
class ParallelBatchNorm1d(th.nn.Module):

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True):
        super(ParallelBatchNorm1d, self).__init__()
        # self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            self.weight = Parameter(th.Tensor(num_features))
            self.bias = Parameter(th.Tensor(num_features))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        if self.track_running_stats:
            self.register_buffer('running_mean', th.zeros(1))
            self.register_buffer('running_var', th.ones(1))
            self.register_buffer('num_batches_tracked', th.tensor(0, dtype=th.long))
        else:
            self.register_parameter('running_mean', None)
            self.register_parameter('running_var', None)
            self.register_parameter('num_batches_tracked', None)
        self.reset_parameters()

    def reset_running_stats(self):
        if self.track_running_stats:
            self.running_mean.zero_()
            self.running_var.fill_(1)
            self.num_batches_tracked.zero_()

    def reset_parameters(self):
        self.reset_running_stats()
        if self.affine:
            th.nn.init.uniform_(self.weight)
            th.nn.init.zeros_(self.bias)

    def forward(self, input):
        exponential_average_factor = 0.0

        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        # calculate running estimates
        if self.training:
            mean = input.mean([0])
            # use biased var in train
            var = input.var([0], unbiased=False)
            n = input.numel()
            with th.no_grad():
                self.running_mean = exponential_average_factor * mean \
                    + (1 - exponential_average_factor) * self.running_mean
                # update running_var with unbiased var
                self.running_var = exponential_average_factor * var * n / (n - 1) \
                    + (1 - exponential_average_factor) * self.running_var
        else:
            mean = self.running_mean
            var = self.running_var

        input = (input - mean) / (th.sqrt(var + self.eps))
        if self.affine:
            input = input * self.weight + self.bias
        return input


def functional_linear3d(input, weight, bias=None):
    r"""
    Apply a linear transformation to the incoming data, one independent map per
    channel: :math:`y = xA + b`.

    Shape:
        - Input: :math:`(N, channels, in\_features)`
        - Weight: :math:`(channels, in\_features, out\_features)`
        - Bias: :math:`(channels, out\_features)`
        - Output: :math:`(N, channels, out\_features)`
    """
    output = input.transpose(0, 1).matmul(weight)
    if bias is not None:
        output += bias.unsqueeze(1)
    return output.transpose(0, 1)

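# Shape sketch for ``functional_linear3d`` (illustrative only): one independent
# linear map is applied per channel, in parallel over the batch.
#
#     >>> x = th.randn(128, 5, 20)               # (batch, channels, in_features)
#     >>> w = th.randn(5, 20, 30)                # (channels, in_features, out_features)
#     >>> b = th.randn(5, 30)                    # (channels, out_features)
#     >>> functional_linear3d(x, w, b).shape
#     torch.Size([128, 5, 30])
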
class Linear3D(th.nn.Module):
    r"""Applies a channel-wise linear transformation to the incoming data: :math:`y = Ax + b`.

    Args:
        channels: number of parallel linear transformations (one per channel)
        in_features: size of each input sample
        out_features: size of each output sample
        bias: If set to False, the layer will not learn an additive bias.
            Default: ``True``

    Shape:
        - Input: :math:`(N, *, in\_features)` where :math:`*` means any number of
          additional dimensions
        - Output: :math:`(N, *, out\_features)` where all but the last dimension
          are the same shape as the input.

    Attributes:
        weight: the learnable weights of the module of shape
            `(channels x in_features x out_features)`
        bias: the learnable bias of the module of shape `(channels x out_features)`

    Examples::

        >>> m = Linear3D(3, 20, 30)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
    """

    def __init__(self, channels, in_features, out_features, batch_size=-1, bias=True, noise=False):
        super(Linear3D, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.channels = channels
        if noise:
            self.in_features += 1
        self.weight = Parameter(th.Tensor(channels, self.in_features, out_features))
        if bias:
            self.bias = Parameter(th.Tensor(channels, out_features))
        else:
            self.register_parameter('bias', None)
        if noise:
            self.register_buffer("noise", th.Tensor(batch_size, channels, 1))
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input, adj_matrix=None, permutation_matrix=None):
        input_ = [input]

        if input.dim() == 2:
            if permutation_matrix is not None:
                input_.append(input.unsqueeze(1).expand([input.shape[0],
                                                         self.channels,
                                                         permutation_matrix.shape[1]]))
            elif hasattr(self, "noise"):
                input_.append(input.unsqueeze(1).expand([input.shape[0],
                                                         self.channels,
                                                         self.in_features - 1]))
            else:
                input_.append(input.unsqueeze(1).expand([input.shape[0],
                                                         self.channels,
                                                         self.in_features]))

        if adj_matrix is not None and permutation_matrix is not None:
            input_.append((input_[-1].transpose(0, 1) @ (adj_matrix.t().unsqueeze(2) *
                                                         permutation_matrix)).transpose(0, 1))
        elif adj_matrix is not None:
            input_.append(input_[-1] * adj_matrix.t().unsqueeze(0))
        elif permutation_matrix is not None:
            input_.append((input_[-1].transpose(0, 1) @ permutation_matrix).t())

        if hasattr(self, 'noise'):
            self.noise.normal_()
            input_.append(th.cat([input_[-1], self.noise], 2))

        return functional_linear3d(input_[-1], self.weight, self.bias)

    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )

    def apply_filter(self, permutation_matrix):
        transpose_weight = self.weight.transpose(1, 2) @ permutation_matrix
        self.weight = Parameter(transpose_weight.transpose(1, 2))

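# Usage sketch for ``Linear3D`` (illustrative only): a 2D input is broadcast over
# the channel dimension, and an optional adjacency matrix masks each channel's
# inputs before the per-channel linear map is applied.
#
#     >>> m = Linear3D(channels=20, in_features=20, out_features=1)
#     >>> x = th.randn(128, 20)
#     >>> m(x).shape
#     torch.Size([128, 20, 1])
#     >>> adj = th.ones(20, 20) - th.eye(20)     # adjacency mask with no self-loops
#     >>> m(x, adj_matrix=adj).shape
#     torch.Size([128, 20, 1])
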
class GraphSampler(th.nn.Module):
    """Ensemble of all the generators."""

    def __init__(self, n_noises, gnh, graph_size, mask=None):
        """Init the model."""
        super(GraphSampler, self).__init__()
        if not isinstance(graph_size, (list, tuple)):
            self.graph_size = (graph_size, graph_size)
        else:
            self.graph_size = graph_size

        if mask is None:
            mask = 1 - th.eye(*self.graph_size)
        if not (type(mask) == bool and not mask):
            self.register_buffer("mask", mask)

        ones_tensor = th.ones(*self.graph_size)
        self.register_buffer("ones_tensor", ones_tensor)
        zeros_tensor = th.zeros(*self.graph_size)
        self.register_buffer("zeros_tensor", zeros_tensor)

        self.register_buffer("noise_graph_sampler", th.Tensor(1, n_noises))

        layers = []
        layers.append(th.nn.Linear(n_noises, gnh))
        # layers.append(th.nn.BatchNorm1d(gnh))
        layers.append(th.nn.LeakyReLU(.2))
        layers.append(th.nn.Linear(gnh, gnh))
        # layers.append(th.nn.BatchNorm1d(gnh))
        layers.append(th.nn.LeakyReLU(.2))
        layers.append(th.nn.Linear(gnh, gnh))
        # layers.append(th.nn.BatchNorm1d(gnh))
        layers.append(th.nn.LeakyReLU(.2))
        # layers.append(th.nn.Linear(gnh, gnh))
        # layers.append(th.nn.BatchNorm1d(gnh))
        # layers.append(th.nn.LeakyReLU(.2))
        layers.append(th.nn.Linear(gnh, graph_size * graph_size))
        self.layers = th.nn.Sequential(*layers)
        self.reset_parameters()

    def forward(self):
        self.noise_graph_sampler.normal_()
        output_sampler = self.layers(self.noise_graph_sampler).view(*self.graph_size)
        sample_soft = th.sigmoid(output_sampler)
        sample_hard = th.where(output_sampler > 0, self.ones_tensor, self.zeros_tensor)
        # print(output_sampler * self.mask)
        # print(sample_soft * self.mask)
        # print(sample_hard * self.mask)
        sample = sample_hard - sample_soft.data + sample_soft
        return sample * self.mask

    def reset_parameters(self):
        for layer in self.layers:
            if hasattr(layer, 'reset_parameters'):
                layer.weight.data.normal_()

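# Usage sketch for ``GraphSampler`` (illustrative only): the sampler draws a new
# noise vector on every call and returns a binary adjacency matrix with a
# straight-through gradient; the default mask zeroes out the diagonal.
#
#     >>> sampler = GraphSampler(n_noises=10, gnh=32, graph_size=4)
#     >>> adj = sampler()                        # (4, 4) binary matrix, zero diagonal
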
class MatrixSampler(th.nn.Module):
    """Matrix Sampler, following a Bernoulli distribution. Differentiable."""

    def __init__(self, graph_size, mask=None, gumble=False):
        super(MatrixSampler, self).__init__()
        if not isinstance(graph_size, (list, tuple)):
            self.graph_size = (graph_size, graph_size)
        else:
            self.graph_size = graph_size
        self.weights = th.nn.Parameter(th.FloatTensor(*self.graph_size))
        self.weights.data.zero_()
        if mask is None:
            mask = 1 - th.eye(*self.graph_size)
        if not (type(mask) == bool and not mask):
            self.register_buffer("mask", mask)
        self.gumble = gumble

        ones_tensor = th.ones(*self.graph_size)
        self.register_buffer("ones_tensor", ones_tensor)
        zeros_tensor = th.zeros(*self.graph_size)
        self.register_buffer("zeros_tensor", zeros_tensor)

    def forward(self, tau=1, drawhard=True):
        """Return a sampled graph."""
        if self.gumble:
            drawn_proba = gumbel_softmax(th.stack([self.weights.view(-1), -self.weights.view(-1)], 1),
                                         tau=tau, hard=drawhard)[:, 0].view(*self.graph_size)
        else:
            drawn_proba = gumbel_sigmoid(2 * self.weights, self.ones_tensor, self.zeros_tensor,
                                         tau=tau, hard=drawhard)
        if hasattr(self, "mask"):
            return self.mask * drawn_proba
        else:
            return drawn_proba

    def get_proba(self):
        if hasattr(self, "mask"):
            return th.sigmoid(2 * self.weights) * self.mask
        else:
            return th.sigmoid(2 * self.weights)

    def set_skeleton(self, mask):
        self.register_buffer("mask", mask)

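# Usage sketch for ``MatrixSampler`` (illustrative only): ``forward`` draws a
# binary adjacency matrix with a straight-through gradient into ``weights``,
# while ``get_proba`` exposes the underlying edge probabilities.
#
#     >>> sampler = MatrixSampler(5)             # 5x5 graph, diagonal masked out
#     >>> adj = sampler(tau=1, drawhard=True)    # binary, differentiable sample
#     >>> probas = sampler.get_proba()           # sigmoid(2 * weights) * mask
#     >>> adj.sum().backward()                   # gradients reach sampler.weights
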
class MatrixSampler2(th.nn.Module):
    """Matrix Sampler, following a Bernoulli distribution. Differentiable."""

    def __init__(self, graph_size, mask=None):
        super(MatrixSampler2, self).__init__()
        if not isinstance(graph_size, (list, tuple)):
            self.graph_size = (graph_size, graph_size)
        else:
            self.graph_size = graph_size
        self.weights = th.nn.Parameter(th.FloatTensor(*self.graph_size))
        self.weights.data.zero_()
        self.v_weights = th.nn.Parameter(th.where(th.eye(*self.graph_size) > 0,
                                                  th.ones(*self.graph_size).fill_(1),
                                                  th.zeros(*self.graph_size))
                                         .repeat(self.graph_size[1], 1, 1)
                                         .transpose(0, 2))
        if mask is None:
            mask = 1 - th.eye(*self.graph_size)
        if not (type(mask) == bool and not mask):
            self.register_buffer("mask", mask)

    def forward(self, tau=1, drawhard=True):
        """Return a sampled graph."""
        # drawn_proba = gumbel_softmax(th.stack([self.weights.view(-1), -self.weights.view(-1)], 1),
        #                              tau=tau, hard=drawhard)[:, 0].view(*self.graph_size)
        # corr_weights = (drawn_proba.unsqueeze(0) *
        #                 (self.v_weights / (.5 * self.v_weights.abs().sum(1, keepdim=True)))).sum(0)
        corr_weights = (self.weights.unsqueeze(1) *
                        (self.v_weights / self.v_weights.abs().sum(1, keepdim=True))).sum(0)
        out_proba = gumbel_softmax(th.stack([corr_weights.view(-1), -corr_weights.view(-1)], 1),
                                   tau=tau, hard=drawhard)[:, 0].view(*self.graph_size)
        if hasattr(self, "mask"):
            return self.mask * out_proba
        else:
            return out_proba

    def get_proba(self):
        if hasattr(self, "mask"):
            return th.sigmoid(2 * self.weights) * self.mask
        else:
            return th.sigmoid(2 * self.weights)

    def set_skeleton(self, mask):
        self.register_buffer("mask", mask)


class MatrixSampler3(th.nn.Module):
    """Matrix Sampler, following a Bernoulli distribution. Differentiable."""

    def __init__(self, graph_size, mask=None, gumbel=True, k=None):
        super(MatrixSampler3, self).__init__()
        if not isinstance(graph_size, (list, tuple)):
            self.graph_size = (graph_size, graph_size)
        else:
            self.graph_size = graph_size
        self.k = k if k is not None else self.graph_size[0] - 1
        self.in_weights = th.nn.Parameter(th.FloatTensor(self.graph_size[0], self.k))
        self.out_weights = th.nn.Parameter(th.FloatTensor(self.k, self.graph_size[1]))
        self.in_weights.data.normal_()
        self.out_weights.data.normal_()
        self.gumbel_softmax = gumbel
        if not gumbel:
            ones_tensor = th.ones(*self.graph_size)
            zeros_tensor = th.zeros(*self.graph_size)
            self.register_buffer("ones_tensor", ones_tensor)
            self.register_buffer("zeros_tensor", zeros_tensor)
        if mask is None:
            mask = 1 - th.eye(*self.graph_size)
        if not (type(mask) == bool and not mask):
            self.register_buffer("mask", mask)

    def forward(self, tau=1, drawhard=True):
        """Return a sampled graph."""
        corr_weights = self.in_weights @ self.out_weights
        if self.gumbel_softmax:
            out_sample = gumbel_softmax(th.stack([corr_weights.view(-1), -corr_weights.view(-1)], 1),
                                        tau=tau, hard=drawhard)[:, 0].view(*self.graph_size)
        else:
            sample_soft = th.sigmoid(corr_weights)
            sample_hard = th.where(corr_weights > 0, self.ones_tensor, self.zeros_tensor)
            out_sample = sample_hard - sample_soft.data + sample_soft
        if hasattr(self, "mask"):
            return self.mask * out_sample
        else:
            return out_sample

    def get_proba(self):
        if hasattr(self, "mask"):
            return th.sigmoid(2 * (self.in_weights @ self.out_weights)) * self.mask
        else:
            return th.sigmoid(2 * (self.in_weights @ self.out_weights))

    def set_skeleton(self, mask):
        self.register_buffer("mask", mask)


class SimpleMatrixConnection(th.nn.Module):
    """Matrix Sampler, following a Bernoulli distribution. Differentiable."""

    def __init__(self, graph_size, mask=None):
        super(SimpleMatrixConnection, self).__init__()
        if not isinstance(graph_size, (list, tuple)):
            self.graph_size = (graph_size, graph_size)
        else:
            self.graph_size = graph_size
        self.weights = th.nn.Parameter(th.FloatTensor(*self.graph_size))
        self.weights.data.normal_()
        if mask is None:
            mask = 1 - th.eye(*self.graph_size)
        if not (type(mask) == bool and not mask):
            self.register_buffer("mask", mask)
        ones_tensor = th.ones(*self.graph_size)
        self.register_buffer("ones_tensor", ones_tensor)
        zeros_tensor = th.zeros(*self.graph_size)
        self.register_buffer("zeros_tensor", zeros_tensor)

    def forward(self):
        """Return a sampled graph."""
        sample_soft = th.sigmoid(2 * self.weights)
        sample_hard = th.where(self.weights > 0, self.ones_tensor, self.zeros_tensor)
        sample = sample_hard - sample_soft.data + sample_soft
        if hasattr(self, "mask"):
            return self.mask * sample_soft
        else:
            return sample_soft

    def get_proba(self):
        if hasattr(self, "mask"):
            return th.sigmoid(2 * self.weights) * self.mask
        else:
            return th.sigmoid(2 * self.weights)
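
# Usage sketch for ``SimpleMatrixConnection`` (illustrative only): the weights are
# thresholded deterministically with a straight-through trick; note that the
# ``forward`` above returns the soft (sigmoid) values rather than the hard sample.
#
#     >>> conn = SimpleMatrixConnection(5)
#     >>> soft = conn()                          # masked sigmoid(2 * weights)
#     >>> probas = conn.get_proba()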