summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/conformer
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2022-06-02 23:42:56 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2022-06-02 23:42:56 +0200
commit4cdc50e2c89015f49973eadddfbee88ba2744f06 (patch)
tree448a053a7d4353847b98d99ee823ba2179c923c7 /text_recognizer/networks/conformer
parent1f3ab1c091b44b119765f785eab16e7dd06dfa4d (diff)
Add conformer conv layer
Diffstat (limited to 'text_recognizer/networks/conformer')
-rw-r--r--text_recognizer/networks/conformer/conv.py35
1 files changed, 35 insertions, 0 deletions
diff --git a/text_recognizer/networks/conformer/conv.py b/text_recognizer/networks/conformer/conv.py
new file mode 100644
index 0000000..f031dc7
--- /dev/null
+++ b/text_recognizer/networks/conformer/conv.py
@@ -0,0 +1,35 @@
+"""Conformer convolutional block."""
+from einops import rearrange
+from einops.layers.torch import Rearrange
+from torch import nn, Tensor
+
+
+from text_recognizer.networks.conformer.depth_wise_conv import DepthwiseConv1D
+from text_recognizer.networks.conformer.glu import GLU
+
+
+class ConformerConv(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ expansion_factor: int = 2,
+ kernel_size: int = 31,
+ dropout: int = 0.0,
+ ) -> None:
+ super().__init__()
+ inner_dim = expansion_factor * dim
+ self.layers = nn.Sequential(
+ nn.LayerNorm(dim),
+ Rearrange("b n c -> b c n"),
+ nn.Conv1D(dim, 2 * inner_dim, 1),
+ GLU(dim=1),
+ DepthwiseConv1D(inner_dim, inner_dim, kernel_size),
+ nn.BatchNorm1d(inner_dim),
+ nn.Mish(inplace=True),
+ nn.Conv1D(inner_dim, dim, 1),
+ Rearrange("b c n -> b n c"),
+ nn.Dropout(dropout),
+ )
+
+ def forward(self, x: Tensor) -> Tensor:
+ return self.layers(x)