diff options
Diffstat (limited to 'text_recognizer/networks/conformer')
| -rw-r--r-- | text_recognizer/networks/conformer/block.py | 4 | 
1 files changed, 2 insertions, 2 deletions
| diff --git a/text_recognizer/networks/conformer/block.py b/text_recognizer/networks/conformer/block.py index 4b31aec..4ea33c0 100644 --- a/text_recognizer/networks/conformer/block.py +++ b/text_recognizer/networks/conformer/block.py @@ -26,9 +26,9 @@ class ConformerBlock(nn.Module):          self.conv = conv          self.post_norm = nn.LayerNorm(dim) -    def forward(self, x: Tensor, mask: Optional[Tensor]) -> Tensor: +    def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:          x = self.ff_1(x) + x -        x = self.attn(x, mask=mask) + x +        x = self.attn(x, input_mask=mask) + x          x = self.conv(x) + x          x = self.ff_2(x) + x          return self.post_norm(x) |