summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/conformer/block.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks/conformer/block.py')
-rw-r--r--text_recognizer/networks/conformer/block.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/text_recognizer/networks/conformer/block.py b/text_recognizer/networks/conformer/block.py
index 4b31aec..4ea33c0 100644
--- a/text_recognizer/networks/conformer/block.py
+++ b/text_recognizer/networks/conformer/block.py
@@ -26,9 +26,9 @@ class ConformerBlock(nn.Module):
self.conv = conv
self.post_norm = nn.LayerNorm(dim)
- def forward(self, x: Tensor, mask: Optional[Tensor]) -> Tensor:
+ def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
x = self.ff_1(x) + x
- x = self.attn(x, mask=mask) + x
+ x = self.attn(x, input_mask=mask) + x
x = self.conv(x) + x
x = self.ff_2(x) + x
return self.post_norm(x)