From 4cdc50e2c89015f49973eadddfbee88ba2744f06 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Thu, 2 Jun 2022 23:42:56 +0200 Subject: Add conformer conv layer --- text_recognizer/networks/conformer/conv.py | 35 ++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 text_recognizer/networks/conformer/conv.py (limited to 'text_recognizer/networks/conformer') diff --git a/text_recognizer/networks/conformer/conv.py b/text_recognizer/networks/conformer/conv.py new file mode 100644 index 0000000..f031dc7 --- /dev/null +++ b/text_recognizer/networks/conformer/conv.py @@ -0,0 +1,35 @@ +"""Conformer convolutional block.""" +from einops import rearrange +from einops.layers.torch import Rearrange +from torch import nn, Tensor + + +from text_recognizer.networks.conformer.depth_wise_conv import DepthwiseConv1D +from text_recognizer.networks.conformer.glu import GLU + + +class ConformerConv(nn.Module): + def __init__( + self, + dim: int, + expansion_factor: int = 2, + kernel_size: int = 31, + dropout: int = 0.0, + ) -> None: + super().__init__() + inner_dim = expansion_factor * dim + self.layers = nn.Sequential( + nn.LayerNorm(dim), + Rearrange("b n c -> b c n"), + nn.Conv1D(dim, 2 * inner_dim, 1), + GLU(dim=1), + DepthwiseConv1D(inner_dim, inner_dim, kernel_size), + nn.BatchNorm1d(inner_dim), + nn.Mish(inplace=True), + nn.Conv1D(inner_dim, dim, 1), + Rearrange("b c n -> b n c"), + nn.Dropout(dropout), + ) + + def forward(self, x: Tensor) -> Tensor: + return self.layers(x) -- cgit v1.2.3-70-g09d2