Diffstat (limited to 'text_recognizer/networks/transformer/positional_encodings')
-rw-r--r--  text_recognizer/networks/transformer/positional_encodings/__init__.py             |  8
-rw-r--r--  text_recognizer/networks/transformer/positional_encodings/absolute_embedding.py   | 17
-rw-r--r--  text_recognizer/networks/transformer/positional_encodings/positional_encoding.py  | 85
-rw-r--r--  text_recognizer/networks/transformer/positional_encodings/rotary_embedding.py     | 42
4 files changed, 0 insertions, 152 deletions
diff --git a/text_recognizer/networks/transformer/positional_encodings/__init__.py b/text_recognizer/networks/transformer/positional_encodings/__init__.py
deleted file mode 100644
index 2ed8a12..0000000
--- a/text_recognizer/networks/transformer/positional_encodings/__init__.py
+++ /dev/null
@@ -1,8 +0,0 @@
-"""Positional encoding for transformers."""
-from .absolute_embedding import AbsolutePositionalEmbedding
-from .positional_encoding import (
- PositionalEncoding,
- PositionalEncoding2D,
- target_padding_mask,
-)
-from .rotary_embedding import apply_rotary_pos_emb, RotaryEmbedding
diff --git a/text_recognizer/networks/transformer/positional_encodings/absolute_embedding.py b/text_recognizer/networks/transformer/positional_encodings/absolute_embedding.py
deleted file mode 100644
index 7140537..0000000
--- a/text_recognizer/networks/transformer/positional_encodings/absolute_embedding.py
+++ /dev/null
@@ -1,17 +0,0 @@
-"""Absolute positional embedding."""
-import torch
-from torch import nn, Tensor
-
-
-class AbsolutePositionalEmbedding(nn.Module):
- def __init__(self, dim: int, max_seq_len: int) -> None:
- super().__init__()
- self.emb = nn.Embedding(max_seq_len, dim)
- self._weight_init()
-
- def _weight_init(self) -> None:
- nn.init.normal_(self.emb.weight, std=0.02)
-
- def forward(self, x: Tensor) -> Tensor:
- n = torch.arange(x.shape[1], device=x.device)
- return self.emb(n)[None, :, :]
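
The removed AbsolutePositionalEmbedding looked up one learned vector per position and returned it with shape [1, T, D], intended to be broadcast-added onto token embeddings. A minimal usage sketch under that assumption (vocabulary, batch, and dimension sizes below are illustrative, not taken from the repository):

    import torch
    from torch import nn

    # Hypothetical sizes: vocab of 100, batch of 8, 32 tokens, 256-dim model.
    tokens = torch.randint(0, 100, (8, 32))
    token_emb = nn.Embedding(100, 256)
    pos_emb = AbsolutePositionalEmbedding(dim=256, max_seq_len=512)

    x = token_emb(tokens)   # [8, 32, 256]
    x = x + pos_emb(x)      # broadcasts the [1, 32, 256] positional table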
diff --git a/text_recognizer/networks/transformer/positional_encodings/positional_encoding.py b/text_recognizer/networks/transformer/positional_encodings/positional_encoding.py
deleted file mode 100644
index c50afc3..0000000
--- a/text_recognizer/networks/transformer/positional_encodings/positional_encoding.py
+++ /dev/null
@@ -1,85 +0,0 @@
-"""A positional encoding for the image features, as the transformer has no notation of the order of the sequence."""
-from einops import repeat
-import numpy as np
-import torch
-from torch import nn
-from torch import Tensor
-
-
-class PositionalEncoding(nn.Module):
- """Encodes a sense of distance or time for transformer networks."""
-
- def __init__(
- self, hidden_dim: int, dropout_rate: float, max_len: int = 1000
- ) -> None:
- super().__init__()
- self.dropout = nn.Dropout(p=dropout_rate)
- pe = self.make_pe(hidden_dim, max_len)
- self.register_buffer("pe", pe)
-
- @staticmethod
- def make_pe(hidden_dim: int, max_len: int) -> Tensor:
- """Returns positional encoding."""
- pe = torch.zeros(max_len, hidden_dim)
- position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
- div_term = torch.exp(
- torch.arange(0, hidden_dim, 2).float() * (-np.log(10000.0) / hidden_dim)
- )
-
- pe[:, 0::2] = torch.sin(position * div_term)
- pe[:, 1::2] = torch.cos(position * div_term)
- pe = pe.unsqueeze(1)
- return pe
-
- def forward(self, x: Tensor) -> Tensor:
- """Encodes the tensor with a postional embedding."""
- # [T, B, D]
- if x.shape[2] != self.pe.shape[2]:
- raise ValueError(f"x shape does not match pe in the 3rd dim.")
- x = x + self.pe[: x.shape[0]]
- return self.dropout(x)
-
-
-class PositionalEncoding2D(nn.Module):
- """Positional encodings for feature maps."""
-
- def __init__(self, hidden_dim: int, max_h: int = 2048, max_w: int = 2048) -> None:
- super().__init__()
- if hidden_dim % 2 != 0:
- raise ValueError(f"Embedding depth {hidden_dim} is not even!")
- self.hidden_dim = hidden_dim
- pe = self.make_pe(hidden_dim, max_h, max_w)
- self.register_buffer("pe", pe)
-
- @staticmethod
- def make_pe(hidden_dim: int, max_h: int, max_w: int) -> Tensor:
- """Returns 2d postional encoding."""
- pe_h = PositionalEncoding.make_pe(
- hidden_dim // 2, max_len=max_h
- ) # [H, 1, D // 2]
- pe_h = repeat(pe_h, "h w d -> d h (w tile)", tile=max_w)
-
- pe_w = PositionalEncoding.make_pe(
- hidden_dim // 2, max_len=max_w
- ) # [W, 1, D // 2]
- pe_w = repeat(pe_w, "w h d -> d (h tile) w", tile=max_h)
-
- pe = torch.cat([pe_h, pe_w], dim=0) # [D, H, W]
- return pe
-
- def forward(self, x: Tensor) -> Tensor:
- """Adds 2D postional encoding to input tensor."""
- # Assumes x hase shape [B, D, H, W]
- if x.shape[1] != self.pe.shape[0]:
- raise ValueError("Hidden dimensions does not match.")
- x += self.pe[:, : x.shape[2], : x.shape[3]]
- return x
-
-
-def target_padding_mask(trg: Tensor, pad_index: int) -> Tensor:
- """Returns causal target mask."""
- trg_pad_mask = (trg != pad_index)[:, None, None]
- trg_len = trg.shape[1]
- trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), device=trg.device)).bool()
- trg_mask = trg_pad_mask & trg_sub_mask
- return trg_mask
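
For context, the deleted module built the standard sinusoidal table, PE(pos, 2i) = sin(pos / 10000^(2i/d)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d)), as a 1D variant expecting [T, B, D] inputs and a 2D variant expecting [B, D, H, W] feature maps, plus a helper that combines a padding mask with a lower-triangular causal mask. A rough usage sketch (all shapes and the pad index are illustrative assumptions):

    import torch

    pe_1d = PositionalEncoding(hidden_dim=256, dropout_rate=0.1)
    seq = torch.zeros(50, 8, 256)          # [T, B, D]
    seq = pe_1d(seq)                       # adds pe[:50], then dropout

    pe_2d = PositionalEncoding2D(hidden_dim=256, max_h=64, max_w=256)
    feats = torch.zeros(8, 256, 16, 128)   # [B, D, H, W]
    feats = pe_2d(feats)                   # adds pe[:, :16, :128]

    trg = torch.randint(1, 20, (8, 50))    # hypothetical target token ids
    mask = target_padding_mask(trg, pad_index=0)  # [B, 1, T, T], causal & padding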
diff --git a/text_recognizer/networks/transformer/positional_encodings/rotary_embedding.py b/text_recognizer/networks/transformer/positional_encodings/rotary_embedding.py
deleted file mode 100644
index 2f58964..0000000
--- a/text_recognizer/networks/transformer/positional_encodings/rotary_embedding.py
+++ /dev/null
@@ -1,42 +0,0 @@
-"""Roatary embedding.
-
-Stolen from lucidrains:
- https://github.com/lucidrains/rotary-embedding-torch
-
-Explanation of rotary embeddings:
- https://blog.eleuther.ai/rotary-embeddings/
-"""
-from typing import Tuple
-
-from einops import rearrange
-import torch
-from torch import nn
-from torch import Tensor
-
-
-class RotaryEmbedding(nn.Module):
- """Rotary positional embedding."""
-
- def __init__(self, dim: int):
- super().__init__()
- inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
- self.register_buffer("inv_freq", inv_freq)
-
- def forward(self, x: Tensor, seq_dim: int = 1) -> Tensor:
- """Encodes tensor x with rotary embeddings."""
- t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
- freqs = torch.einsum("i , j -> i j", t, self.inv_freq)
- emb = torch.cat((freqs, freqs), dim=-1)
- return rearrange(emb, "n d -> () () n d")
-
-
-def rotate_half(x: Tensor) -> Tensor:
- x = rearrange(x, "... (j d) -> ... j d", j=2)
- x1, x2 = x.unbind(dim=-2)
- return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_rotary_pos_emb(t: Tensor, freqs: Tensor) -> Tensor:
- seq_len = t.shape[-2]
- freqs = freqs[:, :, -seq_len:]
- return (t * freqs.cos()) + (rotate_half(t) * freqs.sin())
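
Note that RotaryEmbedding.forward returns the frequency table of shape [1, 1, T, dim] rather than a rotated tensor; the rotation is applied to queries and keys separately with apply_rotary_pos_emb inside attention. A minimal sketch of that pattern (head count and sizes are assumptions, not from this repository):

    import torch

    rotary = RotaryEmbedding(dim=64)       # per-head dimension
    q = torch.randn(8, 4, 50, 64)          # [B, heads, T, head_dim]
    k = torch.randn(8, 4, 50, 64)

    freqs = rotary(q, seq_dim=2)           # [1, 1, 50, 64]
    q = apply_rotary_pos_emb(q, freqs)
    k = apply_rotary_pos_emb(k, freqs)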