Source code for echoflow.core.sequential

from typing import Optional

import torch
import torch.nn as nn

from .base import BaseFlow


class SequentialFlow(BaseFlow):
    r"""Apply a sequence of flows."""

    def __init__(self, *modules):
        """
        Parameters
        ----------
        *modules:
            The layers to apply, in order. They are stored in an
            ``nn.ModuleList`` so their parameters are registered with this
            module. Each layer is called as ``layer(inputs, contexts,
            inverse)`` and must return an ``(outputs, logdet)`` pair.
        """
        super().__init__()
        self._layers = nn.ModuleList(modules)

    def forward(
        self,
        inputs: torch.Tensor,
        contexts: Optional[torch.Tensor] = None,
        inverse: bool = False,
    ):
        """Run ``inputs`` through every layer and accumulate the log-dets.

        Parameters
        ----------
        inputs:
            Tensor with at least two dimensions; ``inputs.size(0)`` is used
            as the batch size and ``inputs.size(1)`` is recorded on
            ``self.input_dims``.
        contexts:
            Optional tensor passed unchanged to every layer.
        inverse:
            When ``False`` the layers are applied in the order given to the
            constructor; when ``True`` they are applied in reverse order.
            The flag is also forwarded to each layer.

        Returns
        -------
        tuple
            ``(outputs, logdets)`` where ``outputs`` is the result of the
            last layer applied and ``logdets`` is a ``(batch, 1)`` tensor
            holding the sum of the ``logdet`` values returned by the layers.
        """
        self.input_dims = inputs.size(1)

        # Allocate the accumulator on the same device as the inputs so the
        # in-place additions below do not fail with a device mismatch when
        # the model runs on an accelerator. The dtype is promoted against
        # the default dtype: never lower precision than the previous
        # default-dtype accumulator, but float64 inputs keep float64.
        logdets = torch.zeros(
            inputs.size(0),
            1,
            dtype=torch.promote_types(inputs.dtype, torch.get_default_dtype()),
            device=inputs.device,
        )

        # The inverse pass undoes the layers last-to-first.
        layers = reversed(self._layers) if inverse else self._layers
        for layer in layers:
            inputs, logdet = layer(inputs, contexts, inverse)
            logdets += logdet

        return inputs, logdets