Source code for echoflow.core.one_hot

from typing import List, Optional

import torch

from .base import BaseFlow


class OneHot(BaseFlow):
    """Convert integer category indices to noisy one-hot vectors and back.

    Each of the ``len(cardinality)`` input columns holds a category index.
    Column ``i`` owns a contiguous block of ``cardinality[i]`` output
    columns; the blocks are laid out one after another, so the encoded
    width is ``sum(cardinality)``.

    Attributes:
        cardinality: Number of categories of each input column.
        scale: Multiplier applied to the standard-normal noise that fills
            the encoded tensor before the selected slots are set to 1.
        input_dims: Number of categorical columns (``len(cardinality)``).
        output_dims: Width of the encoded tensor (``sum(cardinality)``).
    """

    def __init__(self, cardinality: List[int]):
        """Store the per-column category counts and the derived widths.

        Args:
            cardinality: Number of categories of each input column. Every
                entry must be greater than 1.

        Raises:
            AssertionError: If any entry of ``cardinality`` is ``<= 1``.
        """
        super().__init__()
        self.cardinality = cardinality
        self.scale = 0.1
        self.input_dims = len(cardinality)
        self.output_dims = sum(cardinality)
        # Kept as an assert so callers that rely on AssertionError still
        # see the same exception type.
        assert all(c > 1 for c in cardinality), (
            "every cardinality must be greater than 1"
        )

    def forward(
        self,
        inputs: torch.Tensor,
        contexts: Optional[torch.Tensor] = None,
        inverse: bool = False,
    ):
        """Encode category indices, or decode them when ``inverse`` is set.

        Args:
            inputs: With ``inverse=False``, a 2-D tensor whose column ``i``
                holds the category index of variable ``i``; values are cast
                to ``torch.long``. With ``inverse=True``, a 2-D tensor with
                at least ``output_dims`` columns.
            contexts: Accepted for interface compatibility; not used here.
            inverse: Select decoding (``True``) instead of encoding.

        Returns:
            A ``(tensor, 0.0)`` tuple. When encoding, the tensor has shape
            ``(inputs.size(0), output_dims)``: Gaussian noise scaled by
            ``self.scale`` with a 1 written at each selected category slot.
            When decoding, it is a ``torch.long`` tensor of shape
            ``(inputs.size(0), input_dims)`` holding the arg-max position
            inside each category block. The second element is always the
            constant ``0.0`` (presumably the log-determinant term of the
            flow interface -- confirm against ``BaseFlow``).

        Note:
            Indices are not range-checked: a value outside
            ``[0, cardinality[i])`` is written outside block ``i``.
        """
        batch_size = inputs.size(0)
        offset = 0

        if not inverse:
            # ``.long()`` changes only the dtype, so the indices stay on the
            # device of ``inputs`` instead of being forced onto the CPU.
            indices = inputs.long()
            noise = torch.randn(
                batch_size, self.output_dims, device=inputs.device
            )
            x = noise * self.scale
            for i, cardinality in enumerate(self.cardinality):
                # Shift the local index into this column's block, then mark
                # that slot with 1 for every row.
                slots = (offset + indices[:, i]).unsqueeze(1)
                x.scatter_(1, slots, 1)
                offset += cardinality
            return x, 0.0

        # Every column is overwritten below, so uninitialised storage is
        # safe; it is allocated on the same device as ``inputs``.
        x = torch.empty(
            batch_size,
            self.input_dims,
            dtype=torch.long,
            device=inputs.device,
        )
        for i, cardinality in enumerate(self.cardinality):
            block = inputs[:, offset:offset + cardinality]
            x[:, i] = torch.argmax(block, dim=1)
            offset += cardinality
        return x, 0.0