summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/conformer/conformer.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks/conformer/conformer.py')
-rw-r--r--text_recognizer/networks/conformer/conformer.py35
1 files changed, 0 insertions, 35 deletions
diff --git a/text_recognizer/networks/conformer/conformer.py b/text_recognizer/networks/conformer/conformer.py
deleted file mode 100644
index 09aad55..0000000
--- a/text_recognizer/networks/conformer/conformer.py
+++ /dev/null
@@ -1,35 +0,0 @@
-"""Conformer module."""
-from copy import deepcopy
-from typing import Type
-
-from torch import nn, Tensor
-
-from text_recognizer.networks.conformer.block import ConformerBlock
-
-
-class Conformer(nn.Module):
- def __init__(
- self,
- dim: int,
- dim_gru: int,
- num_classes: int,
- subsampler: Type[nn.Module],
- block: ConformerBlock,
- depth: int,
- ) -> None:
- super().__init__()
- self.subsampler = subsampler
- self.blocks = nn.ModuleList([deepcopy(block) for _ in range(depth)])
- self.gru = nn.GRU(
- dim, dim_gru, 1, bidirectional=True, batch_first=True, bias=False
- )
- self.fc = nn.Linear(dim_gru, num_classes)
-
- def forward(self, x: Tensor) -> Tensor:
- x = self.subsampler(x)
- B, T, C = x.shape
- for fn in self.blocks:
- x = fn(x)
- x, _ = self.gru(x)
- x = x.view(B, T, 2, -1).sum(2)
- return self.fc(x)