summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/conformer
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks/conformer')
-rw-r--r--text_recognizer/networks/conformer/conv.py35
1 files changed, 35 insertions, 0 deletions
diff --git a/text_recognizer/networks/conformer/conv.py b/text_recognizer/networks/conformer/conv.py
new file mode 100644
index 0000000..f031dc7
--- /dev/null
+++ b/text_recognizer/networks/conformer/conv.py
@@ -0,0 +1,35 @@
+"""Conformer convolutional block."""
+from einops import rearrange
+from einops.layers.torch import Rearrange
+from torch import nn, Tensor
+
+
+from text_recognizer.networks.conformer.depth_wise_conv import DepthwiseConv1D
+from text_recognizer.networks.conformer.glu import GLU
+
+
+class ConformerConv(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ expansion_factor: int = 2,
+ kernel_size: int = 31,
+ dropout: int = 0.0,
+ ) -> None:
+ super().__init__()
+ inner_dim = expansion_factor * dim
+ self.layers = nn.Sequential(
+ nn.LayerNorm(dim),
+ Rearrange("b n c -> b c n"),
+ nn.Conv1D(dim, 2 * inner_dim, 1),
+ GLU(dim=1),
+ DepthwiseConv1D(inner_dim, inner_dim, kernel_size),
+ nn.BatchNorm1d(inner_dim),
+ nn.Mish(inplace=True),
+ nn.Conv1D(inner_dim, dim, 1),
+ Rearrange("b c n -> b n c"),
+ nn.Dropout(dropout),
+ )
+
+ def forward(self, x: Tensor) -> Tensor:
+ return self.layers(x)