From fb90a53b1235fd836dee74452f3f2a621e0f363a Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Wed, 27 Oct 2021 22:13:54 +0200
Subject: Rename transformer embeddings

---
 .../networks/transformer/embeddings/__init__.py    |  1 +
 .../networks/transformer/embeddings/absolute.py    | 17 +++++
 .../networks/transformer/embeddings/fourier.py     | 85 ++++++++++++++++++++++
 .../networks/transformer/embeddings/rotary.py      | 42 +++++++++++
 4 files changed, 145 insertions(+)
 create mode 100644 text_recognizer/networks/transformer/embeddings/__init__.py
 create mode 100644 text_recognizer/networks/transformer/embeddings/absolute.py
 create mode 100644 text_recognizer/networks/transformer/embeddings/fourier.py
 create mode 100644 text_recognizer/networks/transformer/embeddings/rotary.py

(limited to 'text_recognizer/networks/transformer/embeddings')

diff --git a/text_recognizer/networks/transformer/embeddings/__init__.py b/text_recognizer/networks/transformer/embeddings/__init__.py
new file mode 100644
index 0000000..bb3f904
--- /dev/null
+++ b/text_recognizer/networks/transformer/embeddings/__init__.py
@@ -0,0 +1 @@
+"""Positional encodings for transformers."""
diff --git a/text_recognizer/networks/transformer/embeddings/absolute.py b/text_recognizer/networks/transformer/embeddings/absolute.py
new file mode 100644
index 0000000..7140537
--- /dev/null
+++ b/text_recognizer/networks/transformer/embeddings/absolute.py
@@ -0,0 +1,17 @@
+"""Absolute positional embedding."""
+import torch
+from torch import nn, Tensor
+
+
+class AbsolutePositionalEmbedding(nn.Module):
+    def __init__(self, dim: int, max_seq_len: int) -> None:
+        super().__init__()
+        self.emb = nn.Embedding(max_seq_len, dim)
+        self._weight_init()
+
+    def _weight_init(self) -> None:
+        nn.init.normal_(self.emb.weight, std=0.02)
+
+    def forward(self, x: Tensor) -> Tensor:
+        n = torch.arange(x.shape[1], device=x.device)
+        return self.emb(n)[None, :, :]
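
A minimal usage sketch for the learned absolute embedding above, assuming the module path introduced by this patch; the tensor sizes and variable names are illustrative only:

    # Illustrative sketch; sizes and names are assumptions.
    import torch

    from text_recognizer.networks.transformer.embeddings.absolute import (
        AbsolutePositionalEmbedding,
    )

    emb = AbsolutePositionalEmbedding(dim=256, max_seq_len=128)
    tokens = torch.randn(4, 72, 256)   # [B, T, D]
    pos = emb(tokens)                  # [1, T, D], broadcasts over the batch
    tokens = tokens + pos              # token embeddings with positions added
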
diff --git a/text_recognizer/networks/transformer/embeddings/fourier.py b/text_recognizer/networks/transformer/embeddings/fourier.py
new file mode 100644
index 0000000..c50afc3
--- /dev/null
+++ b/text_recognizer/networks/transformer/embeddings/fourier.py
@@ -0,0 +1,85 @@
+"""A positional encoding for the image features, as the transformer has no notation of the order of the sequence."""
+from einops import repeat
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+
+
+class PositionalEncoding(nn.Module):
+    """Encodes a sense of distance or time for transformer networks."""
+
+    def __init__(
+        self, hidden_dim: int, dropout_rate: float, max_len: int = 1000
+    ) -> None:
+        super().__init__()
+        self.dropout = nn.Dropout(p=dropout_rate)
+        pe = self.make_pe(hidden_dim, max_len)
+        self.register_buffer("pe", pe)
+
+    @staticmethod
+    def make_pe(hidden_dim: int, max_len: int) -> Tensor:
+        """Returns positional encoding."""
+        pe = torch.zeros(max_len, hidden_dim)
+        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+        div_term = torch.exp(
+            torch.arange(0, hidden_dim, 2).float() * (-np.log(10000.0) / hidden_dim)
+        )
+
+        pe[:, 0::2] = torch.sin(position * div_term)
+        pe[:, 1::2] = torch.cos(position * div_term)
+        pe = pe.unsqueeze(1)
+        return pe
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Encodes the tensor with a postional embedding."""
+        # [T, B, D]
+        if x.shape[2] != self.pe.shape[2]:
+            raise ValueError(f"x shape does not match pe in the 3rd dim.")
+        x = x + self.pe[: x.shape[0]]
+        return self.dropout(x)
+
+
+class PositionalEncoding2D(nn.Module):
+    """Positional encodings for feature maps."""
+
+    def __init__(self, hidden_dim: int, max_h: int = 2048, max_w: int = 2048) -> None:
+        super().__init__()
+        if hidden_dim % 2 != 0:
+            raise ValueError(f"Embedding depth {hidden_dim} is not even!")
+        self.hidden_dim = hidden_dim
+        pe = self.make_pe(hidden_dim, max_h, max_w)
+        self.register_buffer("pe", pe)
+
+    @staticmethod
+    def make_pe(hidden_dim: int, max_h: int, max_w: int) -> Tensor:
+        """Returns 2d postional encoding."""
+        pe_h = PositionalEncoding.make_pe(
+            hidden_dim // 2, max_len=max_h
+        )  # [H, 1, D // 2]
+        pe_h = repeat(pe_h, "h w d -> d h (w tile)", tile=max_w)
+
+        pe_w = PositionalEncoding.make_pe(
+            hidden_dim // 2, max_len=max_w
+        )  # [W, 1, D // 2]
+        pe_w = repeat(pe_w, "w h d -> d (h tile) w", tile=max_h)
+
+        pe = torch.cat([pe_h, pe_w], dim=0)  # [D, H, W]
+        return pe
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Adds 2D postional encoding to input tensor."""
+        # Assumes x hase shape [B, D, H, W]
+        if x.shape[1] != self.pe.shape[0]:
+            raise ValueError("Hidden dimensions does not match.")
+        x += self.pe[:, : x.shape[2], : x.shape[3]]
+        return x
+
+
+def target_padding_mask(trg: Tensor, pad_index: int) -> Tensor:
+    """Returns causal target mask."""
+    trg_pad_mask = (trg != pad_index)[:, None, None]
+    trg_len = trg.shape[1]
+    trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), device=trg.device)).bool()
+    trg_mask = trg_pad_mask & trg_sub_mask
+    return trg_mask
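
A minimal usage sketch for the Fourier encodings above; the hidden sizes, feature-map shapes, and variable names are illustrative assumptions:

    # Illustrative sketch; sizes and names are assumptions.
    import torch

    from text_recognizer.networks.transformer.embeddings.fourier import (
        PositionalEncoding,
        PositionalEncoding2D,
    )

    # 1D case: expects [T, B, D] and adds pe[:T] before dropout.
    pe_1d = PositionalEncoding(hidden_dim=256, dropout_rate=0.1, max_len=1000)
    seq = torch.randn(100, 4, 256)     # [T, B, D]
    seq = pe_1d(seq)                   # same shape, positions added

    # 2D case: expects image features of shape [B, D, H, W].
    pe_2d = PositionalEncoding2D(hidden_dim=256, max_h=64, max_w=256)
    feats = torch.randn(4, 256, 18, 80)  # [B, D, H, W]
    feats = pe_2d(feats)               # same shape, 2D positions added
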
diff --git a/text_recognizer/networks/transformer/embeddings/rotary.py b/text_recognizer/networks/transformer/embeddings/rotary.py
new file mode 100644
index 0000000..2f58964
--- /dev/null
+++ b/text_recognizer/networks/transformer/embeddings/rotary.py
@@ -0,0 +1,42 @@
+"""Roatary embedding.
+
+Stolen from lucidrains:
+    https://github.com/lucidrains/rotary-embedding-torch
+
+Explanation of rotary embeddings:
+    https://blog.eleuther.ai/rotary-embeddings/
+"""
+from typing import Tuple
+
+from einops import rearrange
+import torch
+from torch import nn
+from torch import Tensor
+
+
+class RotaryEmbedding(nn.Module):
+    """Rotary positional embedding."""
+
+    def __init__(self, dim: int) -> None:
+        super().__init__()
+        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
+        self.register_buffer("inv_freq", inv_freq)
+
+    def forward(self, x: Tensor, seq_dim: int = 1) -> Tensor:
+        """Encodes tensor x with rotary embeddings."""
+        t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
+        freqs = torch.einsum("i , j -> i j", t, self.inv_freq)
+        emb = torch.cat((freqs, freqs), dim=-1)
+        return rearrange(emb, "n d -> () () n d")
+
+
+def rotate_half(x: Tensor) -> Tensor:
+    x = rearrange(x, "... (j d) -> ... j d", j=2)
+    x1, x2 = x.unbind(dim=-2)
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(t: Tensor, freqs: Tensor) -> Tensor:
+    seq_len = t.shape[-2]
+    freqs = freqs[:, :, -seq_len:]
+    return (t * freqs.cos()) + (rotate_half(t) * freqs.sin())
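
A minimal sketch of how the rotary pieces above would typically be combined inside attention, assuming queries and keys shaped [B, heads, T, head_dim]; the shapes and names are illustrative assumptions:

    # Illustrative sketch; sizes and names are assumptions.
    import torch

    from text_recognizer.networks.transformer.embeddings.rotary import (
        RotaryEmbedding,
        apply_rotary_pos_emb,
    )

    rotary = RotaryEmbedding(dim=64)   # dim is the per-head dimension here
    q = torch.randn(4, 8, 100, 64)     # [B, heads, T, head_dim]
    k = torch.randn(4, 8, 100, 64)

    freqs = rotary(q, seq_dim=2)       # [1, 1, T, head_dim]
    q = apply_rotary_pos_emb(q, freqs) # rotated queries, same shape
    k = apply_rotary_pos_emb(k, freqs) # rotated keys, same shape
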
-- 
cgit v1.2.3-70-g09d2