summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/conformer/conv.py
blob: f031dc769adb57bb7e3191d77fe658851a214d58 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
"""Conformer convolutional block."""
from einops import rearrange
from einops.layers.torch import Rearrange
from torch import nn, Tensor


from text_recognizer.networks.conformer.depth_wise_conv import DepthwiseConv1D
from text_recognizer.networks.conformer.glu import GLU


class ConformerConv(nn.Module):
    """Conformer convolution module.

    Applies, in order: layer norm over the feature axis, a pointwise (kernel
    size 1) convolution expanding to ``2 * expansion_factor * dim`` channels,
    GLU gating along the channel axis, a depthwise convolution, batch norm,
    Mish activation, a pointwise convolution back to ``dim`` channels, and
    dropout.

    Args:
        dim: Feature dimension of the input and output (last axis).
        expansion_factor: Multiplier giving the inner channel width
            ``inner_dim = expansion_factor * dim``.
        kernel_size: Kernel size passed to the depthwise convolution.
        dropout: Dropout probability applied to the block output.

    Shape:
        Input is ``(b, n, dim)`` (required by ``LayerNorm(dim)`` and the
        ``"b n c -> b c n"`` rearrange). Output is ``(b, n', dim)``. Whether
        ``n' == n`` depends on ``DepthwiseConv1D``'s padding, which is not
        visible here — presumably length-preserving; TODO confirm.
    """

    def __init__(
        self,
        dim: int,
        expansion_factor: int = 2,
        kernel_size: int = 31,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        inner_dim = expansion_factor * dim
        self.layers = nn.Sequential(
            nn.LayerNorm(dim),
            # Conv1d / BatchNorm1d expect channels-first: (b, n, c) -> (b, c, n).
            Rearrange("b n c -> b c n"),
            # NOTE: torch's class is ``nn.Conv1d``; ``nn.Conv1D`` does not exist.
            nn.Conv1d(dim, 2 * inner_dim, 1),
            # Gate along the channel axis. Presumably this halves the channels
            # 2 * inner_dim -> inner_dim, which is what the inner_dim-sized
            # layers below expect — TODO confirm against GLU's implementation.
            GLU(dim=1),
            DepthwiseConv1D(inner_dim, inner_dim, kernel_size),
            nn.BatchNorm1d(inner_dim),
            # In-place is safe here: BatchNorm1d's output is not reused elsewhere.
            nn.Mish(inplace=True),
            nn.Conv1d(inner_dim, dim, 1),
            # Back to the caller's (b, n, c) layout.
            Rearrange("b c n -> b n c"),
            nn.Dropout(dropout),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Apply the convolution module to ``x`` of shape ``(b, n, dim)``."""
        return self.layers(x)