author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-11-05 19:25:59 +0100
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-11-05 19:26:39 +0100
commit | e5e776cb7ce3486d1a9e16f6ae328f55fd20f02b (patch)
tree | 61ff5abc85015a720793fe724d7b65c4ca82764e /text_recognizer/networks/transformer/layers.py
parent | ea525029b8b0355c656280e491796b4821c491a4 (diff)
Rename mask to input_mask
Diffstat (limited to 'text_recognizer/networks/transformer/layers.py')
-rw-r--r-- | text_recognizer/networks/transformer/layers.py | 8
1 file changed, 5 insertions, 3 deletions
diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py
index f740244..8387fa4 100644
--- a/text_recognizer/networks/transformer/layers.py
+++ b/text_recognizer/networks/transformer/layers.py
@@ -77,7 +77,7 @@ class AttentionLayers(nn.Module):
         self,
         x: Tensor,
         context: Optional[Tensor] = None,
-        mask: Optional[Tensor] = None,
+        input_mask: Optional[Tensor] = None,
         context_mask: Optional[Tensor] = None,
     ) -> Tensor:
         """Forward pass."""
@@ -91,9 +91,11 @@ class AttentionLayers(nn.Module):
             x = norm(x)
 
             if layer_type == "a":
-                out = block(x=x, mask=mask)
+                out = block(x=x, input_mask=input_mask)
             elif layer_type == "c":
-                out = block(x, context=context, mask=mask, context_mask=context_mask)
+                out = block(
+                    x, context=context, input_mask=input_mask, context_mask=context_mask
+                )
             elif layer_type == "f":
                 out = block(x)
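For context, a minimal usage sketch of the renamed keyword argument is shown below. The constructor arguments (`dim`, `depth`, `num_heads`) and the tensor shapes are illustrative assumptions, not taken from this commit; only the `input_mask` keyword in the forward call reflects the change above.

```python
import torch

from text_recognizer.networks.transformer.layers import AttentionLayers

# Hypothetical constructor arguments; the actual signature may differ.
layers = AttentionLayers(dim=128, depth=2, num_heads=4)

x = torch.randn(1, 32, 128)                        # (batch, sequence length, dim)
input_mask = torch.ones(1, 32, dtype=torch.bool)   # positions allowed to attend

# After this commit the padding mask is passed as `input_mask` instead of `mask`.
out = layers(x, input_mask=input_mask)
```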