Source code for echoflow.core.base

from typing import Optional

import torch
import torch.nn as nn


[docs]class BaseFlow(nn.Module): def __init__(self): r"""Normalizing flow layer.""" super(BaseFlow, self).__init__()
[docs] def forward( self, inputs: torch.Tensor, contexts: Optional[torch.Tensor] = None, inverse: bool = False, ): r"""Transform a batch of data. Parameters ---------- inputs: The input tensor. contexts: An optional context tensor (for conditional sampling). inverse: Whether to apply the direct or inverse transform. Returns ---------- outputs: torch.Tensor The output tensor. logdet: torch.Tensor The log-determinant of the Jacobian. """ raise NotImplementedError()