From 6023d1254d6003233e52e943cd54fc9dece641f6 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sun, 5 Jun 2022 21:20:52 +0200 Subject: Fix conformer block --- text_recognizer/networks/conformer/block.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/text_recognizer/networks/conformer/block.py b/text_recognizer/networks/conformer/block.py index 4b31aec..4ea33c0 100644 --- a/text_recognizer/networks/conformer/block.py +++ b/text_recognizer/networks/conformer/block.py @@ -26,9 +26,9 @@ class ConformerBlock(nn.Module): self.conv = conv self.post_norm = nn.LayerNorm(dim) - def forward(self, x: Tensor, mask: Optional[Tensor]) -> Tensor: + def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor: x = self.ff_1(x) + x - x = self.attn(x, mask=mask) + x + x = self.attn(x, input_mask=mask) + x x = self.conv(x) + x x = self.ff_2(x) + x return self.post_norm(x) -- cgit v1.2.3-70-g09d2