blob: 1547df659f4b2f39e41c21426cec9a1eac608208 (
plain)
1
2
3
4
5
6
7
8
|
"""Residual function."""
from torch import nn, Tensor
class Residual(nn.Module):
    """Parameter-free residual (skip) connection.

    The module holds no learnable parameters; it only adds a previously
    saved tensor back onto the current activations.
    """

    def forward(self, x: Tensor, residual: Tensor) -> Tensor:
        """Applies the residual function.

        Args:
            x: Output of the wrapped sub-layer.
            residual: Tensor to add back onto ``x``. It must be
                broadcast-compatible with ``x``, following the standard
                PyTorch ``+`` broadcasting rules.

        Returns:
            The element-wise sum ``x + residual``, as a new tensor. Neither
            input is modified in place.
        """
        return x + residual
|