summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/transformer/layers.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-11-05 19:25:59 +0100
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-11-05 19:26:39 +0100
commite5e776cb7ce3486d1a9e16f6ae328f55fd20f02b (patch)
tree61ff5abc85015a720793fe724d7b65c4ca82764e /text_recognizer/networks/transformer/layers.py
parentea525029b8b0355c656280e491796b4821c491a4 (diff)
Rename mask to input_mask
Rename mask to input_mask Rename mask to input_mask
Diffstat (limited to 'text_recognizer/networks/transformer/layers.py')
-rw-r--r--text_recognizer/networks/transformer/layers.py8
1 files changed, 5 insertions, 3 deletions
diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py
index f740244..8387fa4 100644
--- a/text_recognizer/networks/transformer/layers.py
+++ b/text_recognizer/networks/transformer/layers.py
@@ -77,7 +77,7 @@ class AttentionLayers(nn.Module):
self,
x: Tensor,
context: Optional[Tensor] = None,
- mask: Optional[Tensor] = None,
+ input_mask: Optional[Tensor] = None,
context_mask: Optional[Tensor] = None,
) -> Tensor:
"""Forward pass."""
@@ -91,9 +91,11 @@ class AttentionLayers(nn.Module):
x = norm(x)
if layer_type == "a":
- out = block(x=x, mask=mask)
+ out = block(x=x, input_mask=input_mask)
elif layer_type == "c":
- out = block(x, context=context, mask=mask, context_mask=context_mask)
+ out = block(
+ x, context=context, input_mask=input_mask, context_mask=context_mask
+ )
elif layer_type == "f":
out = block(x)