From 2d4714fcfeb8914f240a0d36d938b434e82f191b Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Sun, 4 Apr 2021 23:08:16 +0200
Subject: Add new transformer network

---
 text_recognizer/networks/transformer/__init__.py           |  2 +-
 text_recognizer/networks/transformer/attention.py          |  3 ++-
 .../networks/transformer/positional_encoding.py            | 14 ++++++++++++--
 3 files changed, 15 insertions(+), 4 deletions(-)

(limited to 'text_recognizer/networks/transformer')

diff --git a/text_recognizer/networks/transformer/__init__.py b/text_recognizer/networks/transformer/__init__.py
index 9febc88..139cd23 100644
--- a/text_recognizer/networks/transformer/__init__.py
+++ b/text_recognizer/networks/transformer/__init__.py
@@ -1,3 +1,3 @@
 """Transformer modules."""
-from .positional_encoding import PositionalEncoding
+from .positional_encoding import PositionalEncoding, PositionalEncoding2D, target_padding_mask
 from .transformer import Decoder, Encoder, EncoderLayer, Transformer
diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py
index cce1ecc..ac75d2f 100644
--- a/text_recognizer/networks/transformer/attention.py
+++ b/text_recognizer/networks/transformer/attention.py
@@ -50,8 +50,9 @@ class MultiHeadAttention(nn.Module):
         )
         nn.init.xavier_normal_(self.fc_out.weight)
 
+    @staticmethod
     def scaled_dot_product_attention(
-        self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None
+        query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None
     ) -> Tensor:
         """Calculates the scaled dot product attention."""
 
diff --git a/text_recognizer/networks/transformer/positional_encoding.py b/text_recognizer/networks/transformer/positional_encoding.py
index d67d297..dbde887 100644
--- a/text_recognizer/networks/transformer/positional_encoding.py
+++ b/text_recognizer/networks/transformer/positional_encoding.py
@@ -56,9 +56,9 @@ class PositionalEncoding2D(nn.Module):
         pe_h = repeat(pe_h, "h w d -> d h (w tile)", tile=max_w)
 
         pe_w = PositionalEncoding.make_pe(
-            hidden_dim // 2, max_len=max_h
+            hidden_dim // 2, max_len=max_w
         )  # [W, 1, D // 2]
-        pe_w = repeat(pe_w, "h w d -> d (h tile) w", tile=max_h)
+        pe_w = repeat(pe_w, "w h d -> d (h tile) w", tile=max_h)
 
         pe = torch.cat([pe_h, pe_w], dim=0)  # [D, H, W]
         return pe
@@ -70,3 +70,13 @@ class PositionalEncoding2D(nn.Module):
             raise ValueError("Hidden dimensions does not match.")
         x += self.pe[:, : x.shape[2], : x.shape[3]]
         return x
+
+def target_padding_mask(trg: Tensor, pad_index: int) -> Tensor:
+    """Returns causal target mask."""
+    trg_pad_mask = (trg != pad_index)[:, None, None]
+    trg_len = trg.shape[1]
+    trg_sub_mask = torch.tril(
+        torch.ones((trg_len, trg_len), device=trg.device)
+    ).bool()
+    trg_mask = trg_pad_mask & trg_sub_mask
+    return trg_mask
-- 
cgit v1.2.3-70-g09d2