diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-06-07 00:24:28 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-06-07 00:24:28 +0200 |
commit | 38dc6ca3b787bcdb54d43ac5c076e08af25d44b2 (patch) | |
tree | df12ee98c797c44c61f02369cf8cb794d6f47b7c /text_recognizer/networks/conformer/conformer.py | |
parent | 7d759b6c0efcb58b5c7c6858d7dcbd2060992430 (diff) |
Add subsampler layer
Diffstat (limited to 'text_recognizer/networks/conformer/conformer.py')
-rw-r--r-- | text_recognizer/networks/conformer/conformer.py | 10 |
1 files changed, 9 insertions, 1 deletions
diff --git a/text_recognizer/networks/conformer/conformer.py b/text_recognizer/networks/conformer/conformer.py index d56955e..8d0e98e 100644 --- a/text_recognizer/networks/conformer/conformer.py +++ b/text_recognizer/networks/conformer/conformer.py @@ -1,5 +1,6 @@ """Conformer module.""" from copy import deepcopy +from typing import Type from torch import nn, Tensor @@ -7,11 +8,18 @@ from text_recognizer.networks.conformer.block import ConformerBlock class Conformer(nn.Module): - def __init__(self, block: ConformerBlock, depth: int) -> None: + def __init__( + self, + subsampler: Type[nn.Module], + block: ConformerBlock, + depth: int, + ) -> None: super().__init__() + self.subsampler = subsampler self.blocks = nn.ModuleList([deepcopy(block) for _ in range(depth)]) def forward(self, x: Tensor) -> Tensor: + x = self.subsampler(x) for fn in self.blocks: x = fn(x) return x |