summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/conformer/conv.py
blob: ac13f5d54aba252ef859dcf0f6aeecf6888a0473 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
"""Conformer convolutional block."""
from einops import rearrange
from einops.layers.torch import Rearrange
from torch import nn, Tensor


from text_recognizer.networks.conformer.glu import GLU


class ConformerConv(nn.Module):
    """Conformer convolution module.

    Pipeline (all held in ``self.layers``): LayerNorm -> pointwise Conv1d
    expanding to ``2 * expansion_factor * dim`` channels -> GLU over the
    channel axis -> depthwise Conv1d -> BatchNorm1d -> Mish -> pointwise
    Conv1d back to ``dim`` channels -> Dropout.

    Args:
        dim: Number of features in the last axis of the input (and output).
        expansion_factor: Multiplier for the inner channel width,
            ``inner_dim = expansion_factor * dim``.
        kernel_size: Kernel size of the depthwise convolution.
        dropout: Dropout probability applied to the module output.
    """

    def __init__(
        self,
        dim: int,
        expansion_factor: int = 2,
        kernel_size: int = 31,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        inner_dim = expansion_factor * dim
        self.layers = nn.Sequential(
            nn.LayerNorm(dim),
            # Conv1d expects channels-first input: (b, n, c) -> (b, c, n).
            Rearrange("b n c -> b c n"),
            # Pointwise conv to 2 * inner_dim channels. The GLU over dim=1 must
            # reduce this to inner_dim to match the depthwise conv's
            # in_channels. NOTE(review): relies on the project GLU halving
            # the channel axis.
            nn.Conv1d(dim, 2 * inner_dim, 1),
            GLU(dim=1),
            # Depthwise conv: groups == in_channels gives one filter per
            # channel. padding="same" (stride 1) keeps the sequence length.
            nn.Conv1d(
                in_channels=inner_dim,
                out_channels=inner_dim,
                kernel_size=kernel_size,
                groups=inner_dim,
                padding="same",
            ),
            nn.BatchNorm1d(inner_dim),
            nn.Mish(inplace=True),
            # Pointwise projection back to the model dimension.
            nn.Conv1d(inner_dim, dim, 1),
            # Restore the channels-last layout: (b, c, n) -> (b, n, c).
            Rearrange("b c n -> b n c"),
            nn.Dropout(dropout),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Apply the convolution module.

        Args:
            x: Channels-last input of shape ``(batch, seq, dim)``.

        Returns:
            Tensor of shape ``(batch, seq, dim)``.
        """
        return self.layers(x)