summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/conformer/block.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks/conformer/block.py')
-rw-r--r--text_recognizer/networks/conformer/block.py34
1 files changed, 0 insertions, 34 deletions
diff --git a/text_recognizer/networks/conformer/block.py b/text_recognizer/networks/conformer/block.py
deleted file mode 100644
index c53f339..0000000
--- a/text_recognizer/networks/conformer/block.py
+++ /dev/null
@@ -1,34 +0,0 @@
-"""Conformer block."""
-from copy import deepcopy
-from typing import Optional
-
-from torch import nn, Tensor
-from text_recognizer.networks.conformer.conv import ConformerConv
-
-from text_recognizer.networks.conformer.ff import Feedforward
-from text_recognizer.networks.conformer.scale import Scale
-from text_recognizer.networks.transformer.attention import Attention
-from text_recognizer.networks.transformer.norm import PreNorm
-
-
-class ConformerBlock(nn.Module):
- def __init__(
- self,
- dim: int,
- ff: Feedforward,
- attn: Attention,
- conv: ConformerConv,
- ) -> None:
- super().__init__()
- self.attn = PreNorm(dim, attn)
- self.ff_1 = Scale(0.5, ff)
- self.ff_2 = deepcopy(self.ff_1)
- self.conv = conv
- self.post_norm = nn.LayerNorm(dim)
-
- def forward(self, x: Tensor) -> Tensor:
- x = self.ff_1(x) + x
- x = self.attn(x) + x
- x = self.conv(x) + x
- x = self.ff_2(x) + x
- return self.post_norm(x)