summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/transformer/embeddings/fourier.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-10-28 21:21:00 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-10-28 21:21:00 +0200
commit2203a1ba52ab2f72682fcee738844ee9ec584bda (patch)
tree8a1bcc6ceb3f1c91d47d0f35ca64554bcce2fcbd /text_recognizer/networks/transformer/embeddings/fourier.py
parent863c42ea67823ad616ac7485fc8a1e5018cb4233 (diff)
Remove 2D positional embedding
Diffstat (limited to 'text_recognizer/networks/transformer/embeddings/fourier.py')
-rw-r--r--text_recognizer/networks/transformer/embeddings/fourier.py50
1 files changed, 2 insertions, 48 deletions
diff --git a/text_recognizer/networks/transformer/embeddings/fourier.py b/text_recognizer/networks/transformer/embeddings/fourier.py
index c50afc3..ade589c 100644
--- a/text_recognizer/networks/transformer/embeddings/fourier.py
+++ b/text_recognizer/networks/transformer/embeddings/fourier.py
@@ -1,5 +1,4 @@
-"""A positional encoding for the image features, as the transformer has no notation of the order of the sequence."""
-from einops import repeat
+"""Fourier positional embedding."""
import numpy as np
import torch
from torch import nn
@@ -35,51 +34,6 @@ class PositionalEncoding(nn.Module):
"""Encodes the tensor with a postional embedding."""
# [T, B, D]
if x.shape[2] != self.pe.shape[2]:
- raise ValueError(f"x shape does not match pe in the 3rd dim.")
+ raise ValueError("x shape does not match pe in the 3rd dim.")
x = x + self.pe[: x.shape[0]]
return self.dropout(x)
-
-
-class PositionalEncoding2D(nn.Module):
- """Positional encodings for feature maps."""
-
- def __init__(self, hidden_dim: int, max_h: int = 2048, max_w: int = 2048) -> None:
- super().__init__()
- if hidden_dim % 2 != 0:
- raise ValueError(f"Embedding depth {hidden_dim} is not even!")
- self.hidden_dim = hidden_dim
- pe = self.make_pe(hidden_dim, max_h, max_w)
- self.register_buffer("pe", pe)
-
- @staticmethod
- def make_pe(hidden_dim: int, max_h: int, max_w: int) -> Tensor:
- """Returns 2d postional encoding."""
- pe_h = PositionalEncoding.make_pe(
- hidden_dim // 2, max_len=max_h
- ) # [H, 1, D // 2]
- pe_h = repeat(pe_h, "h w d -> d h (w tile)", tile=max_w)
-
- pe_w = PositionalEncoding.make_pe(
- hidden_dim // 2, max_len=max_w
- ) # [W, 1, D // 2]
- pe_w = repeat(pe_w, "w h d -> d (h tile) w", tile=max_h)
-
- pe = torch.cat([pe_h, pe_w], dim=0) # [D, H, W]
- return pe
-
- def forward(self, x: Tensor) -> Tensor:
- """Adds 2D postional encoding to input tensor."""
- # Assumes x hase shape [B, D, H, W]
- if x.shape[1] != self.pe.shape[0]:
- raise ValueError("Hidden dimensions does not match.")
- x += self.pe[:, : x.shape[2], : x.shape[3]]
- return x
-
-
-def target_padding_mask(trg: Tensor, pad_index: int) -> Tensor:
- """Returns causal target mask."""
- trg_pad_mask = (trg != pad_index)[:, None, None]
- trg_len = trg.shape[1]
- trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), device=trg.device)).bool()
- trg_mask = trg_pad_mask & trg_sub_mask
- return trg_mask