diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-04 23:08:46 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-04 23:08:46 +0200 |
commit | 9e54591b7e342edc93b0bb04809a0f54045c6a15 (patch) | |
tree | a0f8ba9a72389e65d306c5733cbc6bbc36ea2fcf /text_recognizer/networks/transformer | |
parent | 2d4714fcfeb8914f240a0d36d938b434e82f191b (diff) |
black reformatting
Diffstat (limited to 'text_recognizer/networks/transformer')
-rw-r--r-- | text_recognizer/networks/transformer/__init__.py | 6 | ||||
-rw-r--r-- | text_recognizer/networks/transformer/positional_encoding.py | 5 |
2 files changed, 7 insertions, 4 deletions
diff --git a/text_recognizer/networks/transformer/__init__.py b/text_recognizer/networks/transformer/__init__.py index 139cd23..652e82e 100644 --- a/text_recognizer/networks/transformer/__init__.py +++ b/text_recognizer/networks/transformer/__init__.py @@ -1,3 +1,7 @@ """Transformer modules.""" -from .positional_encoding import PositionalEncoding, PositionalEncoding2D, target_padding_mask +from .positional_encoding import ( + PositionalEncoding, + PositionalEncoding2D, + target_padding_mask, +) from .transformer import Decoder, Encoder, EncoderLayer, Transformer diff --git a/text_recognizer/networks/transformer/positional_encoding.py b/text_recognizer/networks/transformer/positional_encoding.py index dbde887..5874e97 100644 --- a/text_recognizer/networks/transformer/positional_encoding.py +++ b/text_recognizer/networks/transformer/positional_encoding.py @@ -71,12 +71,11 @@ class PositionalEncoding2D(nn.Module): x += self.pe[:, : x.shape[2], : x.shape[3]] return x + def target_padding_mask(trg: Tensor, pad_index: int) -> Tensor: """Returns causal target mask.""" trg_pad_mask = (trg != pad_index)[:, None, None] trg_len = trg.shape[1] - trg_sub_mask = torch.tril( - torch.ones((trg_len, trg_len), device=trg.device) - ).bool() + trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), device=trg.device)).bool() trg_mask = trg_pad_mask & trg_sub_mask return trg_mask |