summaryrefslogtreecommitdiff
path: root/text_recognizer/network/transformer/ff.py
blob: 9181323f2dc8f16f8eec1ec6e5c06ab5aaf2b8f5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
"""Feedforward layer in transformer."""
from torch import Tensor, nn


class FeedForward(nn.Module):
    """Feedforward sub-layer: LayerNorm -> Linear -> GELU -> Dropout -> Linear.

    The input is normalised, projected from ``dim`` to ``inner_dim``, passed
    through GELU and dropout, and projected back to ``dim``. No residual
    connection is added here; the module returns only the transformed tensor.
    """

    def __init__(
        self,
        dim: int,
        inner_dim: int,
        dropout_rate: float = 0.0,
    ) -> None:
        """Build the layer stack.

        Args:
            dim: Size of the feature dimension that is normalised and that
                the final projection maps back to.
            inner_dim: Size of the hidden projection between the two linear
                layers.
            dropout_rate: Probability passed to ``nn.Dropout``, applied after
                the activation.
        """
        super().__init__()
        # The two linear layers are created in expand -> contract order so
        # that seeded parameter initialisation stays reproducible.
        layers = [
            nn.LayerNorm(dim),
            nn.Linear(in_features=dim, out_features=inner_dim),
            nn.GELU(),
            nn.Dropout(p=dropout_rate),
            nn.Linear(in_features=inner_dim, out_features=dim),
        ]
        self.ff = nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        """Run ``x`` through the feedforward stack and return the result.

        Assumes the last dimension of ``x`` has size ``dim`` (TODO confirm
        against callers); the output's last dimension is ``dim`` as well.
        """
        out = self.ff(x)
        return out