author    Gustaf Rydholm <gustaf.rydholm@gmail.com>    2021-04-04 23:08:46 +0200
committer Gustaf Rydholm <gustaf.rydholm@gmail.com>    2021-04-04 23:08:46 +0200
commit    9e54591b7e342edc93b0bb04809a0f54045c6a15 (patch)
tree      a0f8ba9a72389e65d306c5733cbc6bbc36ea2fcf /text_recognizer/networks/transformer/positional_encoding.py
parent    2d4714fcfeb8914f240a0d36d938b434e82f191b (diff)
black reformatting
Diffstat (limited to 'text_recognizer/networks/transformer/positional_encoding.py')
-rw-r--r--  text_recognizer/networks/transformer/positional_encoding.py  5
1 file changed, 2 insertions(+), 3 deletions(-)
diff --git a/text_recognizer/networks/transformer/positional_encoding.py b/text_recognizer/networks/transformer/positional_encoding.py
index dbde887..5874e97 100644
--- a/text_recognizer/networks/transformer/positional_encoding.py
+++ b/text_recognizer/networks/transformer/positional_encoding.py
@@ -71,12 +71,11 @@ class PositionalEncoding2D(nn.Module):
         x += self.pe[:, : x.shape[2], : x.shape[3]]
         return x
 
+
 def target_padding_mask(trg: Tensor, pad_index: int) -> Tensor:
     """Returns causal target mask."""
     trg_pad_mask = (trg != pad_index)[:, None, None]
     trg_len = trg.shape[1]
-    trg_sub_mask = torch.tril(
-        torch.ones((trg_len, trg_len), device=trg.device)
-    ).bool()
+    trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), device=trg.device)).bool()
     trg_mask = trg_pad_mask & trg_sub_mask
     return trg_mask