From d2c4d31a94d9813b6985108b743a5bd3ffed36b9 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Sun, 4 Apr 2021 12:45:54 +0200
Subject: Add 2d positional encoding

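Extract the sinusoidal table construction into a reusable
PositionalEncoding.make_pe staticmethod and add PositionalEncoding2D,
which builds a fixed [D, H, W] encoding for image feature maps: an
encoding over the height positions is repeated across the width axis,
an encoding over the width positions is repeated across the height
axis, and the two halves are concatenated along the channel dimension.

Rough usage sketch (the sizes below are illustrative only):

    import torch

    from text_recognizer.networks.transformer.positional_encoding import (
        PositionalEncoding2D,
    )

    pe_2d = PositionalEncoding2D(hidden_dim=128, max_h=64, max_w=256)
    x = torch.randn(1, 128, 18, 128)  # [B, D, H, W] feature map
    out = pe_2d(x)  # same shape, encoding added channel-wise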
---
 .../networks/transformer/positional_encoding.py    | 53 +++++++++++++++++++---
 1 file changed, 47 insertions(+), 6 deletions(-)

diff --git a/text_recognizer/networks/transformer/positional_encoding.py b/text_recognizer/networks/transformer/positional_encoding.py
index 1ba5537..d03f630 100644
--- a/text_recognizer/networks/transformer/positional_encoding.py
+++ b/text_recognizer/networks/transformer/positional_encoding.py
@@ -1,4 +1,5 @@
 """A positional encoding for the image features, as the transformer has no notation of the order of the sequence."""
+from einops import repeat
 import numpy as np
 import torch
 from torch import nn
@@ -13,20 +14,60 @@ class PositionalEncoding(nn.Module):
     ) -> None:
         super().__init__()
         self.dropout = nn.Dropout(p=dropout_rate)
-        self.max_len = max_len
-
+        pe = self.make_pe(hidden_dim, max_len)
+        self.register_buffer("pe", pe)
+
+    @staticmethod
+    def make_pe(hidden_dim: int, max_len: int) -> Tensor:
+        """Returns positional encoding."""
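+        # Standard sinusoidal encoding: even feature indices get sin, odd get cos,
+        # with angular frequencies 1 / 10000^(2i / hidden_dim) for pair index i.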
         pe = torch.zeros(max_len, hidden_dim)
-        position = torch.arange(0, max_len).unsqueeze(1)
+        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
         div_term = torch.exp(
-            torch.arange(0, hidden_dim, 2) * -(np.log(10000.0) / hidden_dim)
+            torch.arange(0, hidden_dim, 2).float() * (-np.log(10000.0) / hidden_dim)
         )
 
         pe[:, 0::2] = torch.sin(position * div_term)
         pe[:, 1::2] = torch.cos(position * div_term)
-        pe = pe.unsqueeze(0)
-        self.register_buffer("pe", pe)
+        pe = pe.unsqueeze(1)
+        return pe
 
     def forward(self, x: Tensor) -> Tensor:
         """Encodes the tensor with a postional embedding."""
         x = x + self.pe[:, : x.shape[1]]
         return self.dropout(x)
+
+
+class PositionalEncoding2D(nn.Module):
+    """Positional encodings for feature maps."""
+
+    def __init__(self, hidden_dim: int, max_h: int = 2048, max_w: int = 2048) -> None:
+        super().__init__()
+        if hidden_dim % 2 != 0:
+            raise ValueError(f"Embedding depth {hidden_dim} is not even!")
+        self.hidden_dim = hidden_dim
+        pe = self.make_pe(hidden_dim, max_h, max_w)
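+        # Register the encoding as a buffer: it follows the module across devices
+        # (e.g. .to(device)) but is not a learnable parameter.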
+        self.register_buffer("pe", pe)
+
+    @staticmethod
+    def make_pe(hidden_dim: int, max_h: int, max_w: int) -> Tensor:
+        """Returns the 2D positional encoding."""
+        pe_h = PositionalEncoding.make_pe(hidden_dim // 2, max_len=max_h)  # [H, 1, D // 2]
+        pe_h = repeat(pe_h, "h w d -> d h (w tile)", tile=max_w)  # [D // 2, H, W]
+
+        pe_w = PositionalEncoding.make_pe(hidden_dim // 2, max_len=max_w)  # [W, 1, D // 2]
+        pe_w = repeat(pe_w, "w h d -> d (h tile) w", tile=max_h)  # [D // 2, H, W]
+
+        pe = torch.cat([pe_h, pe_w], dim=0)  # [D, H, W]
+        return pe
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Adds a 2D positional encoding to the input tensor."""
+        # Assumes x has shape [B, D, H, W].
+        if x.shape[1] != self.pe.shape[0]:
+            raise ValueError("Hidden dimension does not match.")
+        x = x + self.pe[:, : x.shape[2], : x.shape[3]]
+        return x
-- 
cgit v1.2.3-70-g09d2