# Derived from https://github.com/ikostrikov/pytorch-flows
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from .base import BaseFlow
def get_mask(
    in_features: int, out_features: int, in_flow_features: int, mask_type=None
):
    """Build a MADE connectivity mask of shape ``(out_features, in_features)``.

    Every unit is assigned an integer "degree". Entry ``[o, i]`` is 1.0 when
    output unit ``o`` may receive input unit ``i`` (``deg_out[o] >= deg_in[i]``)
    and 0.0 otherwise.

    Args:
        in_features: Number of units feeding the masked layer.
        out_features: Number of units produced by the masked layer.
        in_flow_features: Dimensionality ``D`` of the flow variable.
        mask_type: ``"input"`` for the first layer (input unit ``i`` gets
            degree ``i % D``), ``"output"`` for the last layer (output unit
            ``j`` gets degree ``j % D - 1``, so the outputs for dimension ``d``
            only see inputs ``< d``), or ``None`` for a hidden-to-hidden layer.

    Returns:
        A float32 tensor of shape ``(out_features, in_features)``.
    """
    # Hidden units cycle through degrees 0..D-2. With D == 1 there is no valid
    # hidden degree, and ``% (D - 1)`` would be an integer modulo by zero.
    # Clamping the modulus to 1 gives every hidden unit degree 0. The output
    # layer still masks everything in that case, because all output degrees
    # are then -1.
    hidden_modulus = max(in_flow_features - 1, 1)

    if mask_type == "input":
        in_degrees = torch.arange(in_features) % in_flow_features
    else:
        in_degrees = torch.arange(in_features) % hidden_modulus

    if mask_type == "output":
        # Shift by -1 so the outputs for dimension d never see input d itself.
        out_degrees = torch.arange(out_features) % in_flow_features - 1
    else:
        out_degrees = torch.arange(out_features) % hidden_modulus

    # Broadcast (out, 1) against (1, in) to get the full connectivity matrix.
    return (out_degrees.unsqueeze(-1) >= in_degrees.unsqueeze(0)).float()
class MaskedLinear(BaseFlow):
    """Linear layer whose weight matrix is multiplied elementwise by a fixed mask.

    Zeros in ``weight_mask`` remove connections, which is how MADE enforces its
    autoregressive ordering. When ``context_dims`` is non-zero, an additional
    unmasked ``context_linear`` projects ``contexts`` and adds the result to
    the output.

    Args:
        input_dims: Number of input features.
        out_features: Number of output features.
        weight_mask: Mask with the same shape as ``self.linear.weight``,
            i.e. ``(out_features, input_dims)``.
        context_dims: Width of the optional conditioning input; ``0`` disables
            the context projection.
    """

    def __init__(
        self,
        input_dims: int,
        out_features: int,
        weight_mask: torch.Tensor,
        context_dims: int = 0,
    ):
        super(MaskedLinear, self).__init__()
        # Register the mask as a buffer so that ``.to(device)``, ``.cuda()``
        # and ``.double()`` move and cast it together with the weights. As a
        # plain attribute it stayed on CPU/float32, and ``forward`` failed
        # with a device or dtype mismatch.
        # ``persistent=False`` keeps it out of ``state_dict``, matching the
        # previous behaviour, so existing checkpoints still load strictly.
        self.register_buffer("weight_mask", weight_mask, persistent=False)
        self.linear = nn.Linear(input_dims, out_features)
        if context_dims:
            self.context_linear = nn.Linear(context_dims, out_features)

    def forward(
        self,
        inputs: torch.Tensor,
        contexts: Optional[torch.Tensor] = None,
        inverse: bool = False,
    ):
        """Apply the masked affine map, plus the context projection if given.

        Args:
            inputs: Tensor whose last dimension is ``input_dims``.
            contexts: Optional conditioning tensor. This requires the layer to
                have been built with ``context_dims > 0``; otherwise
                ``context_linear`` does not exist.
            inverse: Accepted for interface compatibility; not used here.

        Returns:
            The tensor ``inputs @ (W * mask).T + b`` (+ ``context_linear(contexts)``).
        """
        output = F.linear(
            inputs, self.linear.weight * self.weight_mask, self.linear.bias
        )
        if contexts is not None:
            output += self.context_linear(contexts)
        return output
class MADE(BaseFlow):
    """MADE-style masked autoregressive affine flow layer.

    A masked MLP (``joiner`` followed by ``trunk``) maps an input of width
    ``input_dims`` to ``2 * input_dims`` values. These are split into a shift
    ``m`` and a log-scale ``a``. The masks built by ``get_mask`` ensure that
    ``m[:, d]`` and ``a[:, d]`` never depend on ``inputs[:, d:]``. The
    transform is therefore autoregressive, with a triangular Jacobian.

    Args:
        input_dims: Dimensionality ``D`` of the flow variable.
        hidden_dims: Width of the hidden layers.
        context_dims: Width of an optional conditioning vector; ``0`` disables
            conditioning. Contexts enter, unmasked, at the first layer
            (``joiner``).
    """

    def __init__(self, input_dims, hidden_dims, context_dims=0):
        super(MADE, self).__init__()
        input_mask = get_mask(
            input_dims, hidden_dims, input_dims, mask_type="input"
        )
        hidden_mask = get_mask(hidden_dims, hidden_dims, input_dims)
        output_mask = get_mask(
            hidden_dims, input_dims * 2, input_dims, mask_type="output"
        )
        # First layer: the masked flow inputs plus the (optional) contexts.
        self.joiner = MaskedLinear(
            input_dims, hidden_dims, input_mask, context_dims
        )
        # The final layer emits [m, a] concatenated along dim 1.
        self.trunk = nn.Sequential(
            nn.LeakyReLU(inplace=True),
            MaskedLinear(hidden_dims, hidden_dims, hidden_mask),
            nn.LeakyReLU(inplace=True),
            MaskedLinear(hidden_dims, input_dims * 2, output_mask),
        )

    def forward(
        self,
        inputs: torch.Tensor,
        contexts: Optional[torch.Tensor] = None,
        inverse: bool = False,
    ):
        """Run the flow in the direct (x -> u) or inverse (u -> x) direction.

        Args:
            inputs: ``x`` when ``inverse`` is False and ``u`` otherwise,
                indexed as ``(batch, input_dims)``.
            contexts: Optional conditioning tensor forwarded to ``joiner``.
            inverse: If True, reconstruct ``x`` from ``u`` one column at a time.
                This takes ``input_dims`` sequential network passes.

        Returns:
            ``(outputs, logdet)`` where ``logdet`` is ``-a.sum(-1, keepdim=True)``
            in both directions.
        """
        if not inverse:
            h = self.joiner(inputs, contexts)
            m, a = self.trunk(h).chunk(2, 1)
            # u = (x - m) * exp(-a)  =>  log|det du/dx| = -sum(a)
            u = (inputs - m) * torch.exp(-a)
            return u, -a.sum(-1, keepdim=True)

        # Inverse: x_i = u_i * exp(a_i) + m_i, where a_i and m_i depend only on
        # x_{<i}. Each pass therefore fixes one more column of x.
        x = torch.zeros_like(inputs)
        for i in range(inputs.shape[1]):
            h = self.joiner(x, contexts)
            m, a = self.trunk(h).chunk(2, 1)
            # Write into a fresh copy. The current ``x`` was saved by the
            # joiner's linear layer for backward, so mutating it in place
            # would make backpropagating through the inverse raise an
            # autograd "modified by an inplace operation" error.
            x = x.clone()
            x[:, i] = inputs[:, i] * torch.exp(a[:, i]) + m[:, i]
        # NOTE(review): this returns -sum(a), which is log|det du/dx| of the
        # direct map at the returned x, not log|det dx/du| (= +sum(a)) of the
        # inverse map. Confirm this matches BaseFlow's convention.
        return x, -a.sum(-1, keepdim=True)