summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/conformer/block.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2022-06-05 21:20:52 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2022-06-05 21:20:52 +0200
commit6023d1254d6003233e52e943cd54fc9dece641f6 (patch)
tree97301877797f5bd88811531375514b86fba6573f /text_recognizer/networks/conformer/block.py
parentbdda28c77798d3c08913fd9c9059710f288e0e41 (diff)
Fix conformer block
Diffstat (limited to 'text_recognizer/networks/conformer/block.py')
-rw-r--r--text_recognizer/networks/conformer/block.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/text_recognizer/networks/conformer/block.py b/text_recognizer/networks/conformer/block.py
index 4b31aec..4ea33c0 100644
--- a/text_recognizer/networks/conformer/block.py
+++ b/text_recognizer/networks/conformer/block.py
@@ -26,9 +26,9 @@ class ConformerBlock(nn.Module):
self.conv = conv
self.post_norm = nn.LayerNorm(dim)
- def forward(self, x: Tensor, mask: Optional[Tensor]) -> Tensor:
+ def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
x = self.ff_1(x) + x
- x = self.attn(x, mask=mask) + x
+ x = self.attn(x, input_mask=mask) + x
x = self.conv(x) + x
x = self.ff_2(x) + x
return self.post_norm(x)