1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
"""Generic residual layer."""

from typing import Any, Callable

from torch import Tensor, nn


class Residual(nn.Module):
    """Residual (skip-connection) wrapper computing ``fn(x, ...) + x``.

    Args:
        fn: Callable applied to the input. When ``fn`` is an ``nn.Module``,
            assigning it as an attribute registers it as a submodule, so its
            parameters are tracked by this module. Any other callable (e.g. a
            lambda) is stored as a plain attribute. Its output must be
            addable to ``x`` (same shape, or broadcastable under ``+``).
    """

    def __init__(self, fn: Callable[..., Tensor]) -> None:
        super().__init__()
        self.fn = fn

    def forward(self, x: Tensor, *args: Any, **kwargs: Any) -> Tensor:
        """Apply ``fn`` to ``x`` and add ``x`` back to the result.

        Args:
            x: Input tensor. It is passed to ``fn`` as the first argument and
                also added to ``fn``'s output.
            *args: Extra positional arguments forwarded unchanged to ``fn``.
            **kwargs: Extra keyword arguments forwarded unchanged to ``fn``
                (e.g. an attention mask), so wrapped layers that need more
                than ``x`` can still be used.

        Returns:
            ``fn(x, *args, **kwargs) + x``.
        """
        return self.fn(x, *args, **kwargs) + x