Source code for echoflow.core.coupling

# Derived from https://github.com/ikostrikov/pytorch-flows
from typing import Optional

import torch
import torch.nn as nn

from .base import BaseFlow


class Coupling(BaseFlow):
    r"""Coupling layer from RealNVP.

    The coupling layer partitions the input :math:`x` into two parts,
    :math:`x_1` and :math:`x_2`, and applies an invertible transform

    .. math::

        y_1 &= x_1 \\
        y_2 &= x_2 \odot \exp(s(x_1)) + t(x_1),

    which modifies only one of the partitions.
    """

    def __init__(
        self,
        input_dims: int,
        hidden_dims: int,
        input_mask: torch.Tensor,
        context_dims: int = 0,
    ):
        """
        Parameters
        ----------
        input_dims:
            The number of input dimensions.
        hidden_dims:
            The hidden size to use for the scale/translate nets.
        input_mask:
            A binary mask for the input: entries equal to 1 mark the
            pass-through partition, entries equal to 0 the transformed one.
        context_dims:
            The number of context dimensions. If specified, the output is
            conditioned on the context.
        """
        super().__init__()
        self.input_dims = input_dims
        self.input_mask = input_mask
        assert (
            input_mask.size(0) == input_dims
        ), "Expected input_mask to have size input_dims."

        # Computes log s(x1).
        self.scale_net = nn.Sequential(
            nn.Linear(input_dims + context_dims, hidden_dims),
            nn.Tanh(),
            nn.Linear(hidden_dims, hidden_dims),
            nn.Tanh(),
            nn.Linear(hidden_dims, input_dims),
        )
        # Computes t(x1).
        self.translate_net = nn.Sequential(
            nn.Linear(input_dims + context_dims, hidden_dims),
            nn.LeakyReLU(inplace=True),
            nn.Linear(hidden_dims, hidden_dims),
            nn.LeakyReLU(inplace=True),
            nn.Linear(hidden_dims, input_dims),
        )
    def forward(
        self,
        inputs: torch.Tensor,
        contexts: Optional[torch.Tensor] = None,
        inverse: bool = False,
    ):
        assert inputs.dim() == 2, "Expected tensor of shape (batch_size, input_dims)."
        if contexts is not None:
            assert (
                contexts.dim() == 2
            ), "Expected tensor of shape (batch_size, context_dims)."

        # The scale/translate nets only ever see the pass-through partition,
        # which is identical in x and y, so the same masking serves both
        # directions.
        masked_inputs = inputs * self.input_mask
        if contexts is not None:
            masked_inputs = torch.cat([masked_inputs, contexts], -1)

        if not inverse:
            # y2 = x2 * exp(s(x1)) + t(x1); multiplying by (1 - mask) keeps
            # the pass-through partition unchanged (log_s = 0 and t = 0 there).
            log_s = self.scale_net(masked_inputs) * (1 - self.input_mask)
            t = self.translate_net(masked_inputs) * (1 - self.input_mask)
            s = torch.exp(log_s)
            return inputs * s + t, log_s.sum(-1, keepdim=True)
        else:
            # x2 = (y2 - t(y1)) * exp(-s(y1)); the exact inverse, returned
            # with the negated log-determinant.
            log_s = self.scale_net(masked_inputs) * (1 - self.input_mask)
            t = self.translate_net(masked_inputs) * (1 - self.input_mask)
            s = torch.exp(-log_s)
            return (inputs - t) * s, -log_s.sum(-1, keepdim=True)
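
To make the forward/inverse contract concrete, here is a minimal usage sketch (not part of the module). It assumes that ``BaseFlow`` subclasses ``torch.nn.Module``, so the layer is callable, and uses an arbitrary alternating mask for illustration; the assertions check that the inverse pass recovers the input and that the log-determinants of the two directions cancel.

import torch

from echoflow.core.coupling import Coupling

input_dims, hidden_dims, batch_size = 6, 32, 4
# Alternating binary mask: even indices (mask = 1) pass through unchanged,
# odd indices (mask = 0) are transformed.
mask = (torch.arange(input_dims) % 2 == 0).float()

layer = Coupling(input_dims, hidden_dims, input_mask=mask)
x = torch.randn(batch_size, input_dims)

y, log_det = layer(x)                        # forward pass: x -> y
x_rec, neg_log_det = layer(y, inverse=True)  # inverse pass: y -> x

assert torch.allclose(x, x_rec, atol=1e-5)
assert torch.allclose(log_det, -neg_log_det, atol=1e-5)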