summaryrefslogtreecommitdiff
path: root/text_recognizer
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer')
-rw-r--r--text_recognizer/networks/conformer/block.py34
1 files changed, 34 insertions, 0 deletions
diff --git a/text_recognizer/networks/conformer/block.py b/text_recognizer/networks/conformer/block.py
index e69de29..d9782e8 100644
--- a/text_recognizer/networks/conformer/block.py
+++ b/text_recognizer/networks/conformer/block.py
@@ -0,0 +1,34 @@
+"""Conformer block."""
+from copy import deepcopy
+from typing import Optional
+
+from torch import nn, Tensor
+from text_recognizer.networks.conformer.conv import ConformerConv
+
+from text_recognizer.networks.conformer.mlp import MLP
+from text_recognizer.networks.conformer.scale import Scale
+from text_recognizer.networks.transformer.attention import Attention
+from text_recognizer.networks.transformer.norm import PreNorm
+
+
+class ConformerBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ ff: MLP,
+ attn: Attention,
+ conv: ConformerConv,
+ ) -> None:
+ super().__init__()
+ self.attn = PreNorm(dim, attn)
+ self.ff_1 = Scale(0.5, ff)
+ self.ff_2 = deepcopy(self.ff_1)
+ self.conv = conv
+ self.post_norm = nn.LayerNorm(dim)
+
+ def forward(self, x: Tensor, mask: Optional[Tensor]) -> Tensor:
+ x = self.ff_1(x) + x
+ x = self.attn(x, mask=mask) + x
+ x = self.conv(x) + x
+ x = self.ff_2(x) + x
+ return self.post_norm(x)