From c614c472707910658b86bb28b9f02062e6982999 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Fri, 30 Sep 2022 01:12:13 +0200
Subject: Make rotary pos encoding mandatory

---
 text_recognizer/networks/text_decoder.py           |  5 +--
 text_recognizer/networks/transformer/decoder.py    |  6 ++--
 .../networks/transformer/decoder_block.py          |  2 +-
 .../networks/transformer/embeddings/absolute.py    | 34 --------------------
 .../networks/transformer/embeddings/fourier.py     | 36 ----------------------
 5 files changed, 4 insertions(+), 79 deletions(-)
 delete mode 100644 text_recognizer/networks/transformer/embeddings/absolute.py
 delete mode 100644 text_recognizer/networks/transformer/embeddings/fourier.py

(limited to 'text_recognizer/networks')

diff --git a/text_recognizer/networks/text_decoder.py b/text_recognizer/networks/text_decoder.py
index c054b41..7ee6720 100644
--- a/text_recognizer/networks/text_decoder.py
+++ b/text_recognizer/networks/text_decoder.py
@@ -1,5 +1,5 @@
 """Text decoder."""
-from typing import Type
+from typing import Optional, Type
 
 import torch
 from torch import Tensor, nn
@@ -16,7 +16,6 @@ class TextDecoder(nn.Module):
         num_classes: int,
         pad_index: Tensor,
         decoder: Decoder,
-        token_pos_embedding: Type[nn.Module],
     ) -> None:
         super().__init__()
         self.hidden_dim = hidden_dim
@@ -26,7 +25,6 @@ class TextDecoder(nn.Module):
         self.token_embedding = nn.Embedding(
             num_embeddings=self.num_classes, embedding_dim=self.hidden_dim
         )
-        self.token_pos_embedding = token_pos_embedding
         self.to_logits = nn.Linear(
             in_features=self.hidden_dim, out_features=self.num_classes
         )
@@ -52,7 +50,6 @@ class TextDecoder(nn.Module):
         tokens = tokens.long()
         mask = tokens != self.pad_index
         tokens = self.token_embedding(tokens)
-        tokens = tokens + self.token_pos_embedding(tokens)
         tokens = self.decoder(x=tokens, context=img_features, mask=mask)
         logits = (
             tokens @ torch.transpose(self.token_embedding.weight.to(tokens.dtype), 0, 1)
diff --git a/text_recognizer/networks/transformer/decoder.py b/text_recognizer/networks/transformer/decoder.py
index 741f5b3..09d2dce 100644
--- a/text_recognizer/networks/transformer/decoder.py
+++ b/text_recognizer/networks/transformer/decoder.py
@@ -1,13 +1,11 @@
 """Transformer decoder module."""
 from copy import deepcopy
-from typing import Optional, Type
+from typing import Optional
 
 from torch import Tensor, nn
 
-from text_recognizer.networks.transformer.attention import Attention
 from text_recognizer.networks.transformer.decoder_block import DecoderBlock
 from text_recognizer.networks.transformer.embeddings.rotary import RotaryEmbedding
-from text_recognizer.networks.transformer.ff import FeedForward
 
 
 class Decoder(nn.Module):
@@ -18,7 +16,7 @@ class Decoder(nn.Module):
         depth: int,
         dim: int,
         block: DecoderBlock,
-        rotary_embedding: Optional[RotaryEmbedding] = None,
+        rotary_embedding: RotaryEmbedding,
     ) -> None:
         super().__init__()
         self.depth = depth
diff --git a/text_recognizer/networks/transformer/decoder_block.py b/text_recognizer/networks/transformer/decoder_block.py
index 2dc4ddf..f7ae454 100644
--- a/text_recognizer/networks/transformer/decoder_block.py
+++ b/text_recognizer/networks/transformer/decoder_block.py
@@ -30,9 +30,9 @@ class DecoderBlock(nn.Module):
     def forward(
         self,
         x: Tensor,
+        rotary_embedding: RotaryEmbedding,
         context: Optional[Tensor] = None,
         mask: Optional[Tensor] = None,
-        rotary_embedding: Optional[RotaryEmbedding] = None,
     ) -> Tensor:
         """Applies decoder block on input signals."""
         x = x + self.attn(self.ln_attn(x), mask=mask, rotary_embedding=rotary_embedding)
diff --git a/text_recognizer/networks/transformer/embeddings/absolute.py b/text_recognizer/networks/transformer/embeddings/absolute.py
deleted file mode 100644
index 9274b55..0000000
--- a/text_recognizer/networks/transformer/embeddings/absolute.py
+++ /dev/null
@@ -1,34 +0,0 @@
-"""Absolute positional embedding."""
-
-import torch
-import torch.nn.functional as F
-from einops import rearrange
-from torch import nn
-
-
-def l2norm(t, groups=1):
-    t = rearrange(t, "... (g d) -> ... g d", g=groups)
-    t = F.normalize(t, p=2, dim=-1)
-    return rearrange(t, "... g d -> ... (g d)")
-
-
-class AbsolutePositionalEmbedding(nn.Module):
-    def __init__(self, dim, max_seq_len, l2norm_embed=False):
-        super().__init__()
-        self.scale = dim**-0.5 if not l2norm_embed else 1.0
-        self.max_seq_len = max_seq_len
-        self.l2norm_embed = l2norm_embed
-        self.emb = nn.Embedding(max_seq_len, dim)
-
-    def forward(self, x, pos=None):
-        seq_len = x.shape[1]
-        assert (
-            seq_len <= self.max_seq_len
-        ), f"you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}"
-
-        if pos is None:
-            pos = torch.arange(seq_len, device=x.device)
-
-        pos_emb = self.emb(pos)
-        pos_emb = pos_emb * self.scale
-        return l2norm(pos_emb) if self.l2norm_embed else pos_emb
diff --git a/text_recognizer/networks/transformer/embeddings/fourier.py b/text_recognizer/networks/transformer/embeddings/fourier.py
deleted file mode 100644
index 28da7a1..0000000
--- a/text_recognizer/networks/transformer/embeddings/fourier.py
+++ /dev/null
@@ -1,36 +0,0 @@
-"""Fourier positional embedding."""
-import numpy as np
-import torch
-from torch import Tensor, nn
-
-
-class PositionalEncoding(nn.Module):
-    """Encodes a sense of distance or time for transformer networks."""
-
-    def __init__(self, dim: int, dropout_rate: float, max_len: int = 1000) -> None:
-        super().__init__()
-        self.dropout = nn.Dropout(p=dropout_rate)
-        pe = self.make_pe(dim, max_len)
-        self.register_buffer("pe", pe)
-
-    @staticmethod
-    def make_pe(hidden_dim: int, max_len: int) -> Tensor:
-        """Returns positional encoding."""
-        pe = torch.zeros(max_len, hidden_dim)
-        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
-        div_term = torch.exp(
-            torch.arange(0, hidden_dim, 2).float() * (-np.log(10000.0) / hidden_dim)
-        )
-
-        pe[:, 0::2] = torch.sin(position * div_term)
-        pe[:, 1::2] = torch.cos(position * div_term)
-        pe = pe.unsqueeze(1)
-        return pe
-
-    def forward(self, x: Tensor) -> Tensor:
-        """Encodes the tensor with a postional embedding."""
-        # [T, B, D]
-        if x.shape[2] != self.pe.shape[2]:
-            raise ValueError("x shape does not match pe in the 3rd dim.")
-        x = x + self.pe[: x.shape[0]]
-        return self.dropout(x)
-- 
cgit v1.2.3-70-g09d2