diff options
| author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-06-02 23:42:56 +0200 | 
|---|---|---|
| committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-06-02 23:42:56 +0200 | 
| commit | 4cdc50e2c89015f49973eadddfbee88ba2744f06 (patch) | |
| tree | 448a053a7d4353847b98d99ee823ba2179c923c7 /text_recognizer/networks | |
| parent | 1f3ab1c091b44b119765f785eab16e7dd06dfa4d (diff) | |
Add conformer conv layer
Diffstat (limited to 'text_recognizer/networks')
| -rw-r--r-- | text_recognizer/networks/conformer/conv.py | 35 | 
1 files changed, 35 insertions, 0 deletions
diff --git a/text_recognizer/networks/conformer/conv.py b/text_recognizer/networks/conformer/conv.py new file mode 100644 index 0000000..f031dc7 --- /dev/null +++ b/text_recognizer/networks/conformer/conv.py @@ -0,0 +1,35 @@ +"""Conformer convolutional block.""" +from einops import rearrange +from einops.layers.torch import Rearrange +from torch import nn, Tensor + + +from text_recognizer.networks.conformer.depth_wise_conv import DepthwiseConv1D +from text_recognizer.networks.conformer.glu import GLU + + +class ConformerConv(nn.Module): +    def __init__( +        self, +        dim: int, +        expansion_factor: int = 2, +        kernel_size: int = 31, +        dropout: int = 0.0, +    ) -> None: +        super().__init__() +        inner_dim = expansion_factor * dim +        self.layers = nn.Sequential( +            nn.LayerNorm(dim), +            Rearrange("b n c -> b c n"), +            nn.Conv1D(dim, 2 * inner_dim, 1), +            GLU(dim=1), +            DepthwiseConv1D(inner_dim, inner_dim, kernel_size), +            nn.BatchNorm1d(inner_dim), +            nn.Mish(inplace=True), +            nn.Conv1D(inner_dim, dim, 1), +            Rearrange("b c n -> b n c"), +            nn.Dropout(dropout), +        ) + +    def forward(self, x: Tensor) -> Tensor: +        return self.layers(x)  |