author Gustaf Rydholm <gustaf.rydholm@gmail.com> 2021-04-04 12:45:54 +0200
committer Gustaf Rydholm <gustaf.rydholm@gmail.com> 2021-04-04 12:45:54 +0200
commit d2c4d31a94d9813b6985108b743a5bd3ffed36b9 (patch)
tree e0c31ed89a4ea966e56eddba25956ae6a2aa904a
parent 186edf0890953f070cf707b6c3aef26961e1721f (diff)
Add 2d positional encoding
-rw-r--r--   notebooks/03-look-at-iam-paragraphs.ipynb                      16
-rw-r--r--   text_recognizer/networks/transformer/positional_encoding.py    50
2 files changed, 52 insertions, 14 deletions
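
The new PositionalEncoding2D builds its encoding from two 1D sinusoidal tables of hidden_dim // 2 channels each, one over row positions and one over column positions. Each half is tiled across the other axis with einops.repeat and the halves are concatenated along the channel axis, giving a [D, H, W] buffer that is added to CNN feature maps. A rough sketch of that tiling step, using stand-in tensors instead of the real sinusoid tables (sizes and names here are illustrative only, not from the repo):

    import torch
    from einops import repeat

    # Illustrative sizes: D channels, H rows, W columns.
    D, H, W = 8, 3, 5
    pe_rows = torch.randn(H, 1, D // 2)  # stand-in for the 1D table over rows
    pe_cols = torch.randn(W, 1, D // 2)  # stand-in for the 1D table over columns

    pe_rows = repeat(pe_rows, "h w d -> d h (w tile)", tile=W)  # [D // 2, H, W]
    pe_cols = repeat(pe_cols, "w h d -> d (h tile) w", tile=H)  # [D // 2, H, W]

    pe = torch.cat([pe_rows, pe_cols], dim=0)  # [D, H, W]

Because the first half varies only with the row index and the second half only with the column index, a transformer attending over the flattened feature map can recover both coordinates from the channels.
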
diff --git a/notebooks/03-look-at-iam-paragraphs.ipynb b/notebooks/03-look-at-iam-paragraphs.ipynb
index 69e5996..73045c6 100644
--- a/notebooks/03-look-at-iam-paragraphs.ipynb
+++ b/notebooks/03-look-at-iam-paragraphs.ipynb
@@ -3,7 +3,7 @@
{
"cell_type": "code",
"execution_count": 7,
- "id": "4b00e00c",
+ "id": "a6f19997",
"metadata": {},
"outputs": [
{
@@ -40,7 +40,7 @@
{
"cell_type": "code",
"execution_count": 2,
- "id": "a9955e92",
+ "id": "abe7e727",
"metadata": {},
"outputs": [],
"source": [
@@ -57,7 +57,7 @@
{
"cell_type": "code",
"execution_count": 3,
- "id": "bd882f2d",
+ "id": "10519f10",
"metadata": {},
"outputs": [
{
@@ -94,7 +94,7 @@
{
"cell_type": "code",
"execution_count": 4,
- "id": "8a2b8cc5",
+ "id": "2672fb27",
"metadata": {
"scrolled": false
},
@@ -172,7 +172,7 @@
{
"cell_type": "code",
"execution_count": 5,
- "id": "d7884595",
+ "id": "8b9ef38c",
"metadata": {
"scrolled": false
},
@@ -251,7 +251,7 @@
{
"cell_type": "code",
"execution_count": 8,
- "id": "67f6c35e",
+ "id": "09b91f61",
"metadata": {},
"outputs": [
{
@@ -286,7 +286,7 @@
{
"cell_type": "code",
"execution_count": 9,
- "id": "69c4dc90",
+ "id": "c883fa43",
"metadata": {
"scrolled": false
},
@@ -364,7 +364,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "7671f207",
+ "id": "6703bfaf",
"metadata": {},
"outputs": [],
"source": []
diff --git a/text_recognizer/networks/transformer/positional_encoding.py b/text_recognizer/networks/transformer/positional_encoding.py
index 1ba5537..d03f630 100644
--- a/text_recognizer/networks/transformer/positional_encoding.py
+++ b/text_recognizer/networks/transformer/positional_encoding.py
@@ -1,4 +1,5 @@
"""A positional encoding for the image features, as the transformer has no notation of the order of the sequence."""
+from einops import repeat
import numpy as np
import torch
from torch import nn
@@ -13,20 +14,57 @@ class PositionalEncoding(nn.Module):
    ) -> None:
        super().__init__()
        self.dropout = nn.Dropout(p=dropout_rate)
-        self.max_len = max_len
-
+        pe = self.make_pe(hidden_dim, max_len)
+        self.register_buffer("pe", pe)
+
+    @staticmethod
+    def make_pe(hidden_dim: int, max_len: int) -> Tensor:
+        """Returns positional encoding."""
        pe = torch.zeros(max_len, hidden_dim)
-        position = torch.arange(0, max_len).unsqueeze(1)
+        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
-            torch.arange(0, hidden_dim, 2) * -(np.log(10000.0) / hidden_dim)
+            torch.arange(0, hidden_dim, 2).float() * (-np.log(10000.0) / hidden_dim)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
-        pe = pe.unsqueeze(0)
-        self.register_buffer("pe", pe)
+        pe = pe.unsqueeze(1)  # [max_len, 1, hidden_dim]
+        return pe

    def forward(self, x: Tensor) -> Tensor:
        """Encodes the tensor with a positional embedding."""
        x = x + self.pe[:, : x.shape[1]]
        return self.dropout(x)

+
+
+class PositionalEncoding2D(nn.Module):
+    """Positional encodings for feature maps."""
+
+    def __init__(self, hidden_dim: int, max_h: int = 2048, max_w: int = 2048) -> None:
+        super().__init__()
+        if hidden_dim % 2 != 0:
+            raise ValueError(f"Embedding depth {hidden_dim} is not even!")
+        self.hidden_dim = hidden_dim
+        pe = self.make_pe(hidden_dim, max_h, max_w)
+        self.register_buffer("pe", pe)
+
+    def make_pe(self, hidden_dim: int, max_h: int, max_w: int) -> Tensor:
+        """Returns 2d positional encoding."""
+        pe_h = PositionalEncoding.make_pe(hidden_dim // 2, max_len=max_h)  # [H, 1, D // 2]
+        pe_h = repeat(pe_h, "h w d -> d h (w tile)", tile=max_w)
+
+        pe_w = PositionalEncoding.make_pe(hidden_dim // 2, max_len=max_w)  # [W, 1, D // 2]
+        pe_w = repeat(pe_w, "w h d -> d (h tile) w", tile=max_h)
+
+        pe = torch.cat([pe_h, pe_w], dim=0)  # [D, H, W]
+        return pe
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Adds 2D positional encoding to input tensor."""
+        # Assumes x has shape [B, D, H, W]
+        if x.shape[1] != self.pe.shape[0]:
+            raise ValueError("Hidden dimensions do not match.")
+        x += self.pe[:, : x.shape[2], : x.shape[3]]
+        return x
+
+
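
For reference, a minimal usage sketch of the module added above, assuming the package is importable under the path shown in the diff and that the feature map arrives as [B, D, H, W] with an even channel count (the sizes below are made up):

    import torch

    from text_recognizer.networks.transformer.positional_encoding import (
        PositionalEncoding2D,
    )

    pos_2d = PositionalEncoding2D(hidden_dim=128, max_h=64, max_w=256)

    x = torch.randn(4, 128, 18, 80)  # e.g. CNN features for a batch of paragraph images
    out = pos_2d(x)  # same shape; row/column sinusoids added channel-wise
    print(out.shape)  # torch.Size([4, 128, 18, 80])

The defaults max_h = max_w = 2048 allocate a fairly large buffer (hidden_dim * 2048 * 2048 floats), so passing maxima close to the actual feature-map size keeps memory down.
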