summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/conformer/mlp.py
blob: 031bde991e6e66e0708719af301b57a125a54e2d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
"""Conformer feedforward block."""
from torch import nn, Tensor


class MLP(nn.Module):
    """Feedforward block: Linear -> Mish -> Dropout -> Linear -> Dropout.

    The hidden projection is ``mult`` times wider than ``dim``. The output
    projection maps back to ``dim``, so the last dimension of the output
    matches that of the input.

    Args:
        dim: Size of the last dimension of the input (and of the output).
        mult: Expansion factor for the hidden layer width.
        dropout: Dropout probability, applied after the activation and again
            after the output projection.
    """

    def __init__(self, dim: int, mult: int = 4, dropout: float = 0.0) -> None:
        super().__init__()
        hidden_dim = dim * mult
        # Keep the Sequential index layout (0: in-projection, 3: out-projection)
        # unchanged so state_dict keys such as "layers.0.weight" stay stable.
        stages = [
            nn.Linear(dim, hidden_dim),
            # In-place is applied to the freshly produced Linear output,
            # never to the caller's tensor.
            nn.Mish(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.layers = nn.Sequential(*stages)

    def forward(self, x: Tensor) -> Tensor:
        """Run ``x`` through the feedforward stack; its last dim must be ``dim``."""
        out = self.layers(x)
        return out