diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-06-08 08:41:09 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-06-08 08:41:09 +0200 |
commit | 7b660c13ce3c0edeace1107838e62c559bc6f078 (patch) | |
tree | 117e8ca03815282907f7ba8da296ebc99de8ea7c /text_recognizer/networks/conformer/conformer.py | |
parent | 8ae1b802bb7d7c63cf758e44269e97a4c0788b65 (diff) |
Fix conformer net
Diffstat (limited to 'text_recognizer/networks/conformer/conformer.py')
-rw-r--r-- | text_recognizer/networks/conformer/conformer.py | 5 |
1 files changed, 4 insertions, 1 deletions
diff --git a/text_recognizer/networks/conformer/conformer.py b/text_recognizer/networks/conformer/conformer.py index 8d0e98e..e2dce27 100644 --- a/text_recognizer/networks/conformer/conformer.py +++ b/text_recognizer/networks/conformer/conformer.py @@ -10,6 +10,8 @@ from text_recognizer.networks.conformer.block import ConformerBlock class Conformer(nn.Module): def __init__( self, + dim: int, + num_classes: int, subsampler: Type[nn.Module], block: ConformerBlock, depth: int, @@ -17,9 +19,10 @@ class Conformer(nn.Module): super().__init__() self.subsampler = subsampler self.blocks = nn.ModuleList([deepcopy(block) for _ in range(depth)]) + self.fc = nn.Linear(dim, num_classes, bias=False) def forward(self, x: Tensor) -> Tensor: x = self.subsampler(x) for fn in self.blocks: x = fn(x) - return x + return self.fc(x) |