summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/transformer/residual.py
blob: 825a0fc62a118570d5c9c1645e51c029c8a5c673 (plain)
1
2
3
4
5
6
7
8
9
10
"""Residual function."""
from torch import nn, Tensor


class Residual(nn.Module):
    """Parameter-free module that adds a residual tensor onto its input.

    Defines no ``__init__`` and registers no parameters or buffers; the only
    work happens in :meth:`forward`.
    """

    def forward(self, x: Tensor, residual: Tensor) -> Tensor:
        """Return the elementwise sum of ``x`` and ``residual``.

        Args:
            x: Left-hand operand of the addition.
            residual: Right-hand operand, added onto ``x`` using the normal
                tensor ``+`` semantics (including broadcasting).

        Returns:
            A new tensor equal to ``x + residual``.
        """
        summed = x + residual
        return summed