Source code for echoflow.echoflow

import math
from typing import Optional

import pandas as pd
import torch
import torch.optim as optim
from torch.utils.data import Dataset

from echoflow.core import MADE, BatchNorm, Coupling, OneHot, Reverse, SequentialFlow
from echoflow.transformer import SplitTransformer, TableTransformer


class FlowDataset(Dataset):
    """Dataset bundling optional continuous, categorical, and context tensors.

    Each of the three inputs is either ``None`` (that column group is absent)
    or an indexable tensor whose first dimension is the sample axis.
    """

    def __init__(self, continuous, categorical, contexts):
        self.continuous = continuous
        self.categorical = categorical
        self.contexts = contexts

    def __len__(self):
        # Length comes from whichever tensor group is present.  The original
        # implementation fell through and returned None when both `continuous`
        # and `categorical` were absent, which surfaced later as an opaque
        # TypeError inside the DataLoader; fall back to `contexts` and raise a
        # clear error if everything is None.
        for data in (self.continuous, self.categorical, self.contexts):
            if data is not None:
                return len(data)
        raise ValueError("FlowDataset requires at least one non-None input.")

    def __getitem__(self, idx):
        continuous = None if self.continuous is None else self.continuous[idx]
        categorical = None if self.categorical is None else self.categorical[idx]
        contexts = None if self.contexts is None else self.contexts[idx]
        return continuous, categorical, contexts
class EchoFlow:
    """Wrapper for training normalizing flow models."""

    def __init__(
        self,
        lr: float = 0.0001,
        nb_epochs: int = 1000,
        batch_size: int = 100,
        nb_blocks: int = 3,
        block_type: str = "RNVP",
        use_kde: bool = False,
    ):
        # Optimization hyperparameters.
        self.lr = lr
        self.nb_epochs = nb_epochs
        self.batch_size = batch_size
        # Flow architecture configuration.
        self.nb_blocks = nb_blocks
        self.block_type = block_type
        # Whether the column transformers use kernel density estimation.
        self.use_kde = use_kde
    def fit(self, df: pd.DataFrame, context: Optional[pd.DataFrame] = None):
        """Fit the flow model.

        Parameters
        ----------
        df:
            The dataframe containing the samples to model.
        context:
            The (optional) context dataframe for conditional sampling.
        """
        self.has_context = context is not None

        # Split the table into continuous and categorical tensors; the flow's
        # input width is the continuous dims plus the one-hot widths.
        self.df_transformer = SplitTransformer(self.use_kde)
        continuous, categorical = self.df_transformer.fit_transform(df)
        self.input_dims = self.df_transformer.continuous_dims + sum(
            self.df_transformer.cardinality
        )
        self.context_transformer = TableTransformer(self.use_kde)
        contexts = (
            self.context_transformer.fit_transform(context)
            if context is not None
            else None
        )

        # Build the flow: "RNVP" alternates masked coupling layers with batch
        # norms; any other value builds a MADE/Reverse (MAF-style) stack.
        layers = []
        for _ in range(self.nb_blocks):
            if self.block_type == "RNVP":
                # Alternating 0/1 mask over the input features.
                input_mask = torch.arange(0, self.input_dims) % 2
                layers.extend(
                    [
                        Coupling(
                            self.input_dims,
                            100,
                            input_mask,
                            self.context_transformer.dims,
                        ),
                        BatchNorm(self.input_dims),
                        Coupling(
                            self.input_dims,
                            100,
                            1.0 - input_mask,
                            self.context_transformer.dims,
                        ),
                        BatchNorm(self.input_dims),
                    ]
                )
            else:
                layers.extend(
                    [
                        MADE(
                            self.input_dims,
                            100,
                            self.context_transformer.dims,
                        ),
                        BatchNorm(self.input_dims),
                        Reverse(self.input_dims),
                    ]
                )
        self.flow = SequentialFlow(*layers)
        self.flow.train()

        dataset = FlowDataset(continuous, categorical, contexts)
        if categorical is not None:
            self.categorical_encoder = OneHot(self.df_transformer.cardinality)

        def collate_fn(data):
            # Stack per-sample tensors into batch tensors; a None in the first
            # sample means the whole column group is absent from the dataset.
            continuous, categorical, contexts = zip(*data)
            continuous = (
                None if continuous[0] is None else torch.stack(continuous, dim=0)
            )
            categorical = (
                None if categorical[0] is None else torch.stack(categorical, dim=0)
            )
            contexts = None if contexts[0] is None else torch.stack(contexts, dim=0)
            return continuous, categorical, contexts

        dataloader = torch.utils.data.DataLoader(
            dataset, batch_size=self.batch_size, shuffle=True, collate_fn=collate_fn
        )
        optimizer = optim.Adam(self.flow.parameters(), lr=self.lr)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min")
        for epoch in range(1, self.nb_epochs + 1):
            train_loss = []
            for _, (continuous, categorical, contexts) in enumerate(dataloader):
                optimizer.zero_grad()
                loss = torch.zeros(1)
                inputs = []
                if continuous is not None:
                    inputs.append(continuous)
                if categorical is not None:
                    # The one-hot encoder returns the encoded tensor plus its
                    # own loss term, which is added to the training objective.
                    categorical, _loss = self.categorical_encoder(categorical)
                    inputs.append(categorical)  # type: ignore
                    loss += _loss
                inputs = torch.cat(inputs, dim=1)
                # Maximize log-likelihood, i.e. minimize its negative mean.
                loss -= self._log_likelihood(inputs, contexts).mean()
                loss.backward()
                train_loss.append(loss.item())
                optimizer.step()
            train_loss = sum(train_loss) / len(train_loss)
            scheduler.step(train_loss)
            if epoch % 10 == 0:
                print(f"Epoch {epoch} | Train Loss {train_loss:.3f}")
[docs] def sample( self, num_samples: Optional[int] = None, context: Optional[pd.DataFrame] = None ) -> pd.DataFrame: """Generate samples via the inverse transform. Either `num_samples` or `contexts` must be provided. If both are provided, then they must be consistent (i.e. there are `num_samples` rows in `contexts`). Parameters ---------- num_samples: The number of samples. contexts: The (optional) context dataframe for conditional sampling. """ if self.has_context: assert context is not None contexts = ( self.context_transformer.transform(context) if context is not None else None ) if num_samples is None: assert contexts is not None num_samples = contexts.size(0) elif contexts is not None: assert num_samples == contexts.size(0) self.flow.eval() noise = torch.randn(num_samples, self.input_dims) samples, _ = self.flow(noise, contexts, inverse=True) continuous, categorical = None, None if self.df_transformer.continuous_dims: continuous = samples[:, : self.df_transformer.continuous_dims] if self.df_transformer.categorical_dims: categorical, _ = self.categorical_encoder( samples[:, self.df_transformer.continuous_dims :], inverse=True ) return self.df_transformer.inverse_transform(continuous, categorical)
[docs] def log_likelihood(self, df: pd.DataFrame, context: Optional[pd.DataFrame] = None): """Compute the log-likelihood of the data. Parameters ---------- df: The dataframe containing the samples to model. contexts: The (optional) context dataframe. If it was provided in the fit method, it must be provided here as well. """ if self.has_context: assert context is not None continuous, categorical = self.df_transformer.fit_transform(df) contexts = ( self.context_transformer.transform(context) if context is not None else None ) inputs = [] if continuous is not None: inputs.append(continuous) if categorical is not None: categorical, _ = self.categorical_encoder(categorical) inputs.append(categorical) # type: ignore inputs = torch.cat(inputs, dim=1) return self._log_likelihood(inputs, contexts)
def _log_likelihood( self, inputs: torch.Tensor, contexts: Optional[torch.Tensor] = None ): """Compute the log-likelihood to be maximized. Parameters ---------- inputs: The input tensor. contexts: An optional context tensor (for conditional sampling). If it was provided in the fit method, it must be provided here as well. """ u, log_jacob = self.flow(inputs, contexts) log_probs = (-0.5 * u.pow(2) - 0.5 * math.log(2 * math.pi)).sum( -1, keepdim=True ) return (log_probs + log_jacob).sum(-1, keepdim=True)