diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-04 23:08:16 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-04 23:08:16 +0200 |
commit | 2d4714fcfeb8914f240a0d36d938b434e82f191b (patch) | |
tree | 32e7b3446332cee4685ec90870210c51f9f1279f /text_recognizer/networks/transformer | |
parent | 5dc8a7097ab6b4f39f0a3add408e3fd0f131f85b (diff) |
Add new transformer network
Diffstat (limited to 'text_recognizer/networks/transformer')
-rw-r--r-- | text_recognizer/networks/transformer/__init__.py | 2 | ||||
-rw-r--r-- | text_recognizer/networks/transformer/attention.py | 3 | ||||
-rw-r--r-- | text_recognizer/networks/transformer/positional_encoding.py | 14 |
3 files changed, 15 insertions, 4 deletions
diff --git a/text_recognizer/networks/transformer/__init__.py b/text_recognizer/networks/transformer/__init__.py index 9febc88..139cd23 100644 --- a/text_recognizer/networks/transformer/__init__.py +++ b/text_recognizer/networks/transformer/__init__.py @@ -1,3 +1,3 @@ """Transformer modules.""" -from .positional_encoding import PositionalEncoding +from .positional_encoding import PositionalEncoding, PositionalEncoding2D, target_padding_mask from .transformer import Decoder, Encoder, EncoderLayer, Transformer diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py index cce1ecc..ac75d2f 100644 --- a/text_recognizer/networks/transformer/attention.py +++ b/text_recognizer/networks/transformer/attention.py @@ -50,8 +50,9 @@ class MultiHeadAttention(nn.Module): ) nn.init.xavier_normal_(self.fc_out.weight) + @staticmethod def scaled_dot_product_attention( - self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None + query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None ) -> Tensor: """Calculates the scaled dot product attention.""" diff --git a/text_recognizer/networks/transformer/positional_encoding.py b/text_recognizer/networks/transformer/positional_encoding.py index d67d297..dbde887 100644 --- a/text_recognizer/networks/transformer/positional_encoding.py +++ b/text_recognizer/networks/transformer/positional_encoding.py @@ -56,9 +56,9 @@ class PositionalEncoding2D(nn.Module): pe_h = repeat(pe_h, "h w d -> d h (w tile)", tile=max_w) pe_w = PositionalEncoding.make_pe( - hidden_dim // 2, max_len=max_h + hidden_dim // 2, max_len=max_w ) # [W, 1, D // 2] - pe_w = repeat(pe_w, "h w d -> d (h tile) w", tile=max_h) + pe_w = repeat(pe_w, "w h d -> d (h tile) w", tile=max_h) pe = torch.cat([pe_h, pe_w], dim=0) # [D, H, W] return pe @@ -70,3 +70,13 @@ class PositionalEncoding2D(nn.Module): raise ValueError("Hidden dimensions does not match.") x += self.pe[:, : x.shape[2], : x.shape[3]] return x + +def target_padding_mask(trg: Tensor, pad_index: int) -> Tensor: + """Returns causal target mask.""" + trg_pad_mask = (trg != pad_index)[:, None, None] + trg_len = trg.shape[1] + trg_sub_mask = torch.tril( + torch.ones((trg_len, trg_len), device=trg.device) + ).bool() + trg_mask = trg_pad_mask & trg_sub_mask + return trg_mask |