summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/transformer/positional_encodings/positional_encoding.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-05-09 22:46:09 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-05-09 22:46:09 +0200
commitc9c60678673e19ad3367339eb8e7a093e5a98474 (patch)
treeb787a7fbb535c2ee44f935720d75034cc24ffd30 /text_recognizer/networks/transformer/positional_encodings/positional_encoding.py
parenta2a3133ed5da283888efbdb9924d0e3733c274c8 (diff)
Reformatting of positional encodings and ViT working
Diffstat (limited to 'text_recognizer/networks/transformer/positional_encodings/positional_encoding.py')
-rw-r--r--text_recognizer/networks/transformer/positional_encodings/positional_encoding.py85
1 files changed, 85 insertions, 0 deletions
diff --git a/text_recognizer/networks/transformer/positional_encodings/positional_encoding.py b/text_recognizer/networks/transformer/positional_encodings/positional_encoding.py
new file mode 100644
index 0000000..c50afc3
--- /dev/null
+++ b/text_recognizer/networks/transformer/positional_encodings/positional_encoding.py
@@ -0,0 +1,85 @@
+"""A positional encoding for the image features, as the transformer has no notation of the order of the sequence."""
+from einops import repeat
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+
+
+class PositionalEncoding(nn.Module):
+ """Encodes a sense of distance or time for transformer networks."""
+
+ def __init__(
+ self, hidden_dim: int, dropout_rate: float, max_len: int = 1000
+ ) -> None:
+ super().__init__()
+ self.dropout = nn.Dropout(p=dropout_rate)
+ pe = self.make_pe(hidden_dim, max_len)
+ self.register_buffer("pe", pe)
+
+ @staticmethod
+ def make_pe(hidden_dim: int, max_len: int) -> Tensor:
+ """Returns positional encoding."""
+ pe = torch.zeros(max_len, hidden_dim)
+ position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+ div_term = torch.exp(
+ torch.arange(0, hidden_dim, 2).float() * (-np.log(10000.0) / hidden_dim)
+ )
+
+ pe[:, 0::2] = torch.sin(position * div_term)
+ pe[:, 1::2] = torch.cos(position * div_term)
+ pe = pe.unsqueeze(1)
+ return pe
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Encodes the tensor with a postional embedding."""
+ # [T, B, D]
+ if x.shape[2] != self.pe.shape[2]:
+ raise ValueError(f"x shape does not match pe in the 3rd dim.")
+ x = x + self.pe[: x.shape[0]]
+ return self.dropout(x)
+
+
+class PositionalEncoding2D(nn.Module):
+ """Positional encodings for feature maps."""
+
+ def __init__(self, hidden_dim: int, max_h: int = 2048, max_w: int = 2048) -> None:
+ super().__init__()
+ if hidden_dim % 2 != 0:
+ raise ValueError(f"Embedding depth {hidden_dim} is not even!")
+ self.hidden_dim = hidden_dim
+ pe = self.make_pe(hidden_dim, max_h, max_w)
+ self.register_buffer("pe", pe)
+
+ @staticmethod
+ def make_pe(hidden_dim: int, max_h: int, max_w: int) -> Tensor:
+ """Returns 2d postional encoding."""
+ pe_h = PositionalEncoding.make_pe(
+ hidden_dim // 2, max_len=max_h
+ ) # [H, 1, D // 2]
+ pe_h = repeat(pe_h, "h w d -> d h (w tile)", tile=max_w)
+
+ pe_w = PositionalEncoding.make_pe(
+ hidden_dim // 2, max_len=max_w
+ ) # [W, 1, D // 2]
+ pe_w = repeat(pe_w, "w h d -> d (h tile) w", tile=max_h)
+
+ pe = torch.cat([pe_h, pe_w], dim=0) # [D, H, W]
+ return pe
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Adds 2D postional encoding to input tensor."""
+ # Assumes x hase shape [B, D, H, W]
+ if x.shape[1] != self.pe.shape[0]:
+ raise ValueError("Hidden dimensions does not match.")
+ x += self.pe[:, : x.shape[2], : x.shape[3]]
+ return x
+
+
+def target_padding_mask(trg: Tensor, pad_index: int) -> Tensor:
+ """Returns causal target mask."""
+ trg_pad_mask = (trg != pad_index)[:, None, None]
+ trg_len = trg.shape[1]
+ trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), device=trg.device)).bool()
+ trg_mask = trg_pad_mask & trg_sub_mask
+ return trg_mask