diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-06-08 08:41:09 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-06-08 08:41:09 +0200 |
commit | 7b660c13ce3c0edeace1107838e62c559bc6f078 (patch) | |
tree | 117e8ca03815282907f7ba8da296ebc99de8ea7c /text_recognizer/networks/conformer/block.py | |
parent | 8ae1b802bb7d7c63cf758e44269e97a4c0788b65 (diff) |
Fix conformer net
Diffstat (limited to 'text_recognizer/networks/conformer/block.py')
-rw-r--r-- | text_recognizer/networks/conformer/block.py | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/text_recognizer/networks/conformer/block.py b/text_recognizer/networks/conformer/block.py index 4ea33c0..c53f339 100644 --- a/text_recognizer/networks/conformer/block.py +++ b/text_recognizer/networks/conformer/block.py @@ -26,9 +26,9 @@ class ConformerBlock(nn.Module): self.conv = conv self.post_norm = nn.LayerNorm(dim) - def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor: + def forward(self, x: Tensor) -> Tensor: x = self.ff_1(x) + x - x = self.attn(x, input_mask=mask) + x + x = self.attn(x) + x x = self.conv(x) + x x = self.ff_2(x) + x return self.post_norm(x) |