# Source code for echoflow.core.made

# Derived from https://github.com/ikostrikov/pytorch-flows
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from .base import BaseFlow


def get_mask(
    in_features: int, out_features: int, in_flow_features: int, mask_type=None
):
    """Build the binary connectivity mask for one MaskedLinear layer.

    Entry (i, j) is 1 exactly when the degree assigned to output unit i is
    >= the degree assigned to input unit j, which is the MADE rule that
    enforces the autoregressive property.

    Args:
        in_features: fan-in of the layer being masked.
        out_features: fan-out of the layer being masked.
        in_flow_features: dimensionality of the flow's data variables.
        mask_type: "input" for the first layer, "output" for the last,
            or None for a hidden-to-hidden layer.

    Returns:
        Float tensor of shape (out_features, in_features) with 0/1 entries.
    """
    hidden_modulus = in_flow_features - 1
    in_degrees = torch.arange(in_features) % (
        in_flow_features if mask_type == "input" else hidden_modulus
    )
    if mask_type == "output":
        # The -1 shift is deliberate (parses as (arange % f) - 1): output
        # unit k may only see inputs of strictly smaller degree.
        out_degrees = (torch.arange(out_features) % in_flow_features) - 1
    else:
        out_degrees = torch.arange(out_features) % hidden_modulus
    return (out_degrees[:, None] >= in_degrees[None, :]).float()
class MaskedLinear(BaseFlow):
    """Linear layer whose weight matrix is element-wise gated by a fixed
    binary mask, with an optional (unmasked) additive projection of a
    conditioning/context vector.
    """

    def __init__(
        self,
        input_dims: int,
        out_features: int,
        weight_mask: torch.Tensor,
        context_dims: int = 0,
    ):
        super(MaskedLinear, self).__init__()
        self.weight_mask = weight_mask
        self.linear = nn.Linear(input_dims, out_features)
        # The context projection only exists when conditioning is enabled.
        if context_dims:
            self.context_linear = nn.Linear(context_dims, out_features)

    def forward(
        self,
        inputs: torch.Tensor,
        contexts: Optional[torch.Tensor] = None,
        inverse: bool = False,
    ):
        # Apply the mask at call time so gradients only flow through the
        # unmasked weight entries. `inverse` is accepted for interface
        # compatibility with other flows and does not change the map.
        masked_weight = self.linear.weight * self.weight_mask
        out = F.linear(inputs, masked_weight, self.linear.bias)
        if contexts is None:
            return out
        return out + self.context_linear(contexts)
class MADE(BaseFlow):
    """Masked autoencoder for distribution estimation used as a single
    autoregressive affine flow step: a masked MLP predicts a per-dimension
    shift and log-scale subject to the autoregressive ordering.
    """

    def __init__(self, input_dims, hidden_dims, context_dims=0):
        super(MADE, self).__init__()
        in_mask = get_mask(input_dims, hidden_dims, input_dims, mask_type="input")
        mid_mask = get_mask(hidden_dims, hidden_dims, input_dims)
        # The final layer emits 2 * input_dims values: shift and log-scale.
        out_mask = get_mask(
            hidden_dims, input_dims * 2, input_dims, mask_type="output"
        )
        self.joiner = MaskedLinear(input_dims, hidden_dims, in_mask, context_dims)
        self.trunk = nn.Sequential(
            nn.LeakyReLU(inplace=True),
            MaskedLinear(hidden_dims, hidden_dims, mid_mask),
            nn.LeakyReLU(inplace=True),
            MaskedLinear(hidden_dims, input_dims * 2, out_mask),
        )

    def forward(
        self,
        inputs: torch.Tensor,
        contexts: Optional[torch.Tensor] = None,
        inverse: bool = False,
    ):
        """Map data -> noise (forward) or noise -> data (inverse).

        Returns a tuple of the transformed tensor and the per-sample
        log-determinant term, shape (batch, 1).
        """
        if inverse:
            # Inverting an autoregressive flow is sequential: dimension i
            # of x depends on the already-reconstructed x[:, :i], so one
            # full network pass is needed per dimension.
            x = torch.zeros_like(inputs)
            for dim in range(inputs.shape[1]):
                hidden = self.joiner(x, contexts)
                shift, log_scale = self.trunk(hidden).chunk(2, 1)
                x[:, dim] = (
                    inputs[:, dim] * torch.exp(log_scale[:, dim]) + shift[:, dim]
                )
            # log-det reported from the final pass, matching the forward sign.
            return x, -log_scale.sum(-1, keepdim=True)
        hidden = self.joiner(inputs, contexts)
        shift, log_scale = self.trunk(hidden).chunk(2, 1)
        u = (inputs - shift) * torch.exp(-log_scale)
        return u, -log_scale.sum(-1, keepdim=True)