diff options
Diffstat (limited to 'text_recognizer/networks/conformer/conformer.py')
-rw-r--r-- | text_recognizer/networks/conformer/conformer.py | 35 |
1 files changed, 0 insertions, 35 deletions
diff --git a/text_recognizer/networks/conformer/conformer.py b/text_recognizer/networks/conformer/conformer.py deleted file mode 100644 index 09aad55..0000000 --- a/text_recognizer/networks/conformer/conformer.py +++ /dev/null @@ -1,35 +0,0 @@ -"""Conformer module.""" -from copy import deepcopy -from typing import Type - -from torch import nn, Tensor - -from text_recognizer.networks.conformer.block import ConformerBlock - - -class Conformer(nn.Module): - def __init__( - self, - dim: int, - dim_gru: int, - num_classes: int, - subsampler: Type[nn.Module], - block: ConformerBlock, - depth: int, - ) -> None: - super().__init__() - self.subsampler = subsampler - self.blocks = nn.ModuleList([deepcopy(block) for _ in range(depth)]) - self.gru = nn.GRU( - dim, dim_gru, 1, bidirectional=True, batch_first=True, bias=False - ) - self.fc = nn.Linear(dim_gru, num_classes) - - def forward(self, x: Tensor) -> Tensor: - x = self.subsampler(x) - B, T, C = x.shape - for fn in self.blocks: - x = fn(x) - x, _ = self.gru(x) - x = x.view(B, T, 2, -1).sum(2) - return self.fc(x) |