From d2c4d31a94d9813b6985108b743a5bd3ffed36b9 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sun, 4 Apr 2021 12:45:54 +0200 Subject: Add 2d positional encoding --- notebooks/03-look-at-iam-paragraphs.ipynb | 16 +++---- .../networks/transformer/positional_encoding.py | 50 +++++++++++++++++++--- 2 files changed, 52 insertions(+), 14 deletions(-) diff --git a/notebooks/03-look-at-iam-paragraphs.ipynb b/notebooks/03-look-at-iam-paragraphs.ipynb index 69e5996..73045c6 100644 --- a/notebooks/03-look-at-iam-paragraphs.ipynb +++ b/notebooks/03-look-at-iam-paragraphs.ipynb @@ -3,7 +3,7 @@ { "cell_type": "code", "execution_count": 7, - "id": "4b00e00c", + "id": "a6f19997", "metadata": {}, "outputs": [ { @@ -40,7 +40,7 @@ { "cell_type": "code", "execution_count": 2, - "id": "a9955e92", + "id": "abe7e727", "metadata": {}, "outputs": [], "source": [ @@ -57,7 +57,7 @@ { "cell_type": "code", "execution_count": 3, - "id": "bd882f2d", + "id": "10519f10", "metadata": {}, "outputs": [ { @@ -94,7 +94,7 @@ { "cell_type": "code", "execution_count": 4, - "id": "8a2b8cc5", + "id": "2672fb27", "metadata": { "scrolled": false }, @@ -172,7 +172,7 @@ { "cell_type": "code", "execution_count": 5, - "id": "d7884595", + "id": "8b9ef38c", "metadata": { "scrolled": false }, @@ -251,7 +251,7 @@ { "cell_type": "code", "execution_count": 8, - "id": "67f6c35e", + "id": "09b91f61", "metadata": {}, "outputs": [ { @@ -286,7 +286,7 @@ { "cell_type": "code", "execution_count": 9, - "id": "69c4dc90", + "id": "c883fa43", "metadata": { "scrolled": false }, @@ -364,7 +364,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7671f207", + "id": "6703bfaf", "metadata": {}, "outputs": [], "source": [] diff --git a/text_recognizer/networks/transformer/positional_encoding.py b/text_recognizer/networks/transformer/positional_encoding.py index 1ba5537..d03f630 100644 --- a/text_recognizer/networks/transformer/positional_encoding.py +++ b/text_recognizer/networks/transformer/positional_encoding.py @@ -1,4 +1,5 @@ """A positional encoding for the image features, as the transformer has no notation of the order of the sequence.""" +from einops import repeat import numpy as np import torch from torch import nn @@ -13,20 +14,57 @@ class PositionalEncoding(nn.Module): ) -> None: super().__init__() self.dropout = nn.Dropout(p=dropout_rate) - self.max_len = max_len - + pe = self.make_pe(hidden_dim, max_len) + self.register_buffer("pe", pe) + + @staticmethod + def make_pe(hidden_dim: int, max_len: int) -> Tensor: + """Returns positional encoding.""" pe = torch.zeros(max_len, hidden_dim) - position = torch.arange(0, max_len).unsqueeze(1) + position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) div_term = torch.exp( - torch.arange(0, hidden_dim, 2) * -(np.log(10000.0) / hidden_dim) + torch.arange(0, hidden_dim, 2).float() * (-np.log(10000.0) / hidden_dim) ) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) - pe = pe.unsqueeze(0) - self.register_buffer("pe", pe) + pe = pe.unsqueeze(1) + return pe def forward(self, x: Tensor) -> Tensor: """Encodes the tensor with a postional embedding.""" x = x + self.pe[:, : x.shape[1]] return self.dropout(x) + + +class PositionalEncoding2D(nn.Module): + """Positional encodings for feature maps.""" + + def __init__(self, hidden_dim: int, max_h: int = 2048, max_w: int =2048) -> None: + super().__init__() + if hidden_dim % 2 != 0: + raise ValueError(f"Embedding depth {hidden_dim} is not even!") + self.hidden_dim = hidden_dim + pe = self.make_pe(hidden_dim, max_h, max_w) + self.register_buffer("pe", pe) + + def make_pe(hidden_dim: int, max_h: int, max_w: int) -> Tensor: + """Returns 2d postional encoding.""" + pe_h = PositionalEncoding.make_pe(hidden_dim // 2, max_len=max_h) # [H, 1, D // 2] + pe_h = repeat(pe_h, "h w d -> d h (w tile)", tile=max_w) + + pe_w = PositionalEncoding.make_pe(hidden_dim // 2, max_len=max_h) # [W, 1, D // 2] + pe_w = repeat(pe_w, "h w d -> d (h tile) w", tile=max_h) + + pe = torch.cat([pe_h, pe_w], dim=0) # [D, H, W] + return pe + + def forward(self, x: Tensor) -> Tensor: + """Adds 2D postional encoding to input tensor.""" + # Assumes x hase shape [B, D, H, W] + if x.shape[1] != self.pe.shape[0]: + raise ValueError("Hidden dimensions does not match.") + x += self.pe[:, :x.shape[2], :x.shape[3]] + return x + + -- cgit v1.2.3-70-g09d2