From 4cdc50e2c89015f49973eadddfbee88ba2744f06 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Thu, 2 Jun 2022 23:42:56 +0200
Subject: Add conformer conv layer

---
 text_recognizer/networks/conformer/conv.py | 35 ++++++++++++++++++++++++++++++
 1 file changed, 35 insertions(+)
 create mode 100644 text_recognizer/networks/conformer/conv.py

(limited to 'text_recognizer/networks/conformer')

diff --git a/text_recognizer/networks/conformer/conv.py b/text_recognizer/networks/conformer/conv.py
new file mode 100644
index 0000000..f031dc7
--- /dev/null
+++ b/text_recognizer/networks/conformer/conv.py
@@ -0,0 +1,35 @@
+"""Conformer convolutional block."""
+from einops import rearrange
+from einops.layers.torch import Rearrange
+from torch import nn, Tensor
+
+
+from text_recognizer.networks.conformer.depth_wise_conv import DepthwiseConv1D
+from text_recognizer.networks.conformer.glu import GLU
+
+
+class ConformerConv(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        expansion_factor: int = 2,
+        kernel_size: int = 31,
+        dropout: int = 0.0,
+    ) -> None:
+        super().__init__()
+        inner_dim = expansion_factor * dim
+        self.layers = nn.Sequential(
+            nn.LayerNorm(dim),
+            Rearrange("b n c -> b c n"),
+            nn.Conv1D(dim, 2 * inner_dim, 1),
+            GLU(dim=1),
+            DepthwiseConv1D(inner_dim, inner_dim, kernel_size),
+            nn.BatchNorm1d(inner_dim),
+            nn.Mish(inplace=True),
+            nn.Conv1D(inner_dim, dim, 1),
+            Rearrange("b c n -> b n c"),
+            nn.Dropout(dropout),
+        )
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self.layers(x)
-- 
cgit v1.2.3-70-g09d2