1 2 3 4 5 6 7 8 9 10 11 12 13
"""Generic residual layer."""

from typing import Callable

from torch import Tensor, nn


class Residual(nn.Module):
    """Wrap a callable and add its input back onto its output.

    The wrapped callable is kept on the instance as ``self.fn``;
    ``forward`` evaluates ``self.fn(x) + x``.
    """

    def __init__(self, fn: Callable) -> None:
        """Initialise the ``nn.Module`` base and store ``fn``.

        Args:
            fn: Callable applied to the input of ``forward``.
        """
        super().__init__()
        self.fn = fn

    def forward(self, x: Tensor) -> Tensor:
        """Return ``self.fn(x) + x``.

        Args:
            x: Input handed to ``self.fn`` and then added to its result.

        Returns:
            The result of ``self.fn(x)`` with ``x`` added to it.
        """
        # Keep the branch output as the left operand, as in ``fn(x) + x``.
        branch_out = self.fn(x)
        return branch_out + x