From fb90a53b1235fd836dee74452f3f2a621e0f363a Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Wed, 27 Oct 2021 22:13:54 +0200 Subject: Rename transformer embeddings --- text_recognizer/networks/transformer/attention.py | 4 +- .../networks/transformer/embeddings/__init__.py | 1 + .../networks/transformer/embeddings/absolute.py | 17 +++++ .../networks/transformer/embeddings/fourier.py | 85 ++++++++++++++++++++++ .../networks/transformer/embeddings/rotary.py | 42 +++++++++++ text_recognizer/networks/transformer/layers.py | 4 +- .../transformer/positional_encodings/__init__.py | 8 -- .../positional_encodings/absolute_embedding.py | 17 ----- .../positional_encodings/positional_encoding.py | 85 ---------------------- .../positional_encodings/rotary_embedding.py | 42 ----------- 10 files changed, 147 insertions(+), 158 deletions(-) create mode 100644 text_recognizer/networks/transformer/embeddings/__init__.py create mode 100644 text_recognizer/networks/transformer/embeddings/absolute.py create mode 100644 text_recognizer/networks/transformer/embeddings/fourier.py create mode 100644 text_recognizer/networks/transformer/embeddings/rotary.py delete mode 100644 text_recognizer/networks/transformer/positional_encodings/__init__.py delete mode 100644 text_recognizer/networks/transformer/positional_encodings/absolute_embedding.py delete mode 100644 text_recognizer/networks/transformer/positional_encodings/positional_encoding.py delete mode 100644 text_recognizer/networks/transformer/positional_encodings/rotary_embedding.py (limited to 'text_recognizer/networks/transformer') diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py index e098b63..3d2ece1 100644 --- a/text_recognizer/networks/transformer/attention.py +++ b/text_recognizer/networks/transformer/attention.py @@ -10,9 +10,7 @@ from torch import nn from torch import Tensor import torch.nn.functional as F -from text_recognizer.networks.transformer.positional_encodings.rotary_embedding import ( - apply_rotary_pos_emb, -) +from text_recognizer.networks.transformer.embeddings.rotary import apply_rotary_pos_emb @attr.s(eq=False) diff --git a/text_recognizer/networks/transformer/embeddings/__init__.py b/text_recognizer/networks/transformer/embeddings/__init__.py new file mode 100644 index 0000000..bb3f904 --- /dev/null +++ b/text_recognizer/networks/transformer/embeddings/__init__.py @@ -0,0 +1 @@ +"""Positional encodings for transformers.""" diff --git a/text_recognizer/networks/transformer/embeddings/absolute.py b/text_recognizer/networks/transformer/embeddings/absolute.py new file mode 100644 index 0000000..7140537 --- /dev/null +++ b/text_recognizer/networks/transformer/embeddings/absolute.py @@ -0,0 +1,17 @@ +"""Absolute positional embedding.""" +import torch +from torch import nn, Tensor + + +class AbsolutePositionalEmbedding(nn.Module): + def __init__(self, dim: int, max_seq_len: int) -> None: + super().__init__() + self.emb = nn.Embedding(max_seq_len, dim) + self._weight_init() + + def _weight_init(self) -> None: + nn.init.normal_(self.emb.weight, std=0.02) + + def forward(self, x: Tensor) -> Tensor: + n = torch.arange(x.shape[1], device=x.device) + return self.emb(n)[None, :, :] diff --git a/text_recognizer/networks/transformer/embeddings/fourier.py b/text_recognizer/networks/transformer/embeddings/fourier.py new file mode 100644 index 0000000..c50afc3 --- /dev/null +++ b/text_recognizer/networks/transformer/embeddings/fourier.py @@ -0,0 +1,85 @@ +"""A positional encoding for the image features, as the transformer has no notation of the order of the sequence.""" +from einops import repeat +import numpy as np +import torch +from torch import nn +from torch import Tensor + + +class PositionalEncoding(nn.Module): + """Encodes a sense of distance or time for transformer networks.""" + + def __init__( + self, hidden_dim: int, dropout_rate: float, max_len: int = 1000 + ) -> None: + super().__init__() + self.dropout = nn.Dropout(p=dropout_rate) + pe = self.make_pe(hidden_dim, max_len) + self.register_buffer("pe", pe) + + @staticmethod + def make_pe(hidden_dim: int, max_len: int) -> Tensor: + """Returns positional encoding.""" + pe = torch.zeros(max_len, hidden_dim) + position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, hidden_dim, 2).float() * (-np.log(10000.0) / hidden_dim) + ) + + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(1) + return pe + + def forward(self, x: Tensor) -> Tensor: + """Encodes the tensor with a postional embedding.""" + # [T, B, D] + if x.shape[2] != self.pe.shape[2]: + raise ValueError(f"x shape does not match pe in the 3rd dim.") + x = x + self.pe[: x.shape[0]] + return self.dropout(x) + + +class PositionalEncoding2D(nn.Module): + """Positional encodings for feature maps.""" + + def __init__(self, hidden_dim: int, max_h: int = 2048, max_w: int = 2048) -> None: + super().__init__() + if hidden_dim % 2 != 0: + raise ValueError(f"Embedding depth {hidden_dim} is not even!") + self.hidden_dim = hidden_dim + pe = self.make_pe(hidden_dim, max_h, max_w) + self.register_buffer("pe", pe) + + @staticmethod + def make_pe(hidden_dim: int, max_h: int, max_w: int) -> Tensor: + """Returns 2d postional encoding.""" + pe_h = PositionalEncoding.make_pe( + hidden_dim // 2, max_len=max_h + ) # [H, 1, D // 2] + pe_h = repeat(pe_h, "h w d -> d h (w tile)", tile=max_w) + + pe_w = PositionalEncoding.make_pe( + hidden_dim // 2, max_len=max_w + ) # [W, 1, D // 2] + pe_w = repeat(pe_w, "w h d -> d (h tile) w", tile=max_h) + + pe = torch.cat([pe_h, pe_w], dim=0) # [D, H, W] + return pe + + def forward(self, x: Tensor) -> Tensor: + """Adds 2D postional encoding to input tensor.""" + # Assumes x hase shape [B, D, H, W] + if x.shape[1] != self.pe.shape[0]: + raise ValueError("Hidden dimensions does not match.") + x += self.pe[:, : x.shape[2], : x.shape[3]] + return x + + +def target_padding_mask(trg: Tensor, pad_index: int) -> Tensor: + """Returns causal target mask.""" + trg_pad_mask = (trg != pad_index)[:, None, None] + trg_len = trg.shape[1] + trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), device=trg.device)).bool() + trg_mask = trg_pad_mask & trg_sub_mask + return trg_mask diff --git a/text_recognizer/networks/transformer/embeddings/rotary.py b/text_recognizer/networks/transformer/embeddings/rotary.py new file mode 100644 index 0000000..2f58964 --- /dev/null +++ b/text_recognizer/networks/transformer/embeddings/rotary.py @@ -0,0 +1,42 @@ +"""Roatary embedding. + +Stolen from lucidrains: + https://github.com/lucidrains/rotary-embedding-torch + +Explanation of roatary: + https://blog.eleuther.ai/rotary-embeddings/ +""" +from typing import Tuple + +from einops import rearrange +import torch +from torch import nn +from torch import Tensor + + +class RotaryEmbedding(nn.Module): + """Rotary positional embedding.""" + + def __init__(self, dim: int): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq) + + def forward(self, x: Tensor, seq_dim: int = 1) -> Tensor: + """Encodes tensor x with rotary embeddings.""" + t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + freqs = torch.einsum("i , j -> i j", t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + return rearrange(emb, "n d -> () () n d") + + +def rotate_half(x: Tensor) -> Tensor: + x = rearrange(x, "... (j d) -> ... j d", j=2) + x1, x2 = x.unbind(dim=-2) + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(t: Tensor, freqs: Tensor) -> Tensor: + seq_len = t.shape[-2] + freqs = freqs[:, :, -seq_len:] + return (t * freqs.cos()) + (rotate_half(t) * freqs.sin()) diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py index 70a0ac7..2b8427d 100644 --- a/text_recognizer/networks/transformer/layers.py +++ b/text_recognizer/networks/transformer/layers.py @@ -4,10 +4,8 @@ from typing import Any, Dict, Optional, Tuple import attr from torch import nn, Tensor +from text_recognizer.networks.transformer.embeddings.rotary import RotaryEmbedding from text_recognizer.networks.transformer.residual import Residual -from text_recognizer.networks.transformer.positional_encodings.rotary_embedding import ( - RotaryEmbedding, -) from text_recognizer.networks.util import load_partial_fn diff --git a/text_recognizer/networks/transformer/positional_encodings/__init__.py b/text_recognizer/networks/transformer/positional_encodings/__init__.py deleted file mode 100644 index 2ed8a12..0000000 --- a/text_recognizer/networks/transformer/positional_encodings/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -"""Positional encoding for transformers.""" -from .absolute_embedding import AbsolutePositionalEmbedding -from .positional_encoding import ( - PositionalEncoding, - PositionalEncoding2D, - target_padding_mask, -) -from .rotary_embedding import apply_rotary_pos_emb, RotaryEmbedding diff --git a/text_recognizer/networks/transformer/positional_encodings/absolute_embedding.py b/text_recognizer/networks/transformer/positional_encodings/absolute_embedding.py deleted file mode 100644 index 7140537..0000000 --- a/text_recognizer/networks/transformer/positional_encodings/absolute_embedding.py +++ /dev/null @@ -1,17 +0,0 @@ -"""Absolute positional embedding.""" -import torch -from torch import nn, Tensor - - -class AbsolutePositionalEmbedding(nn.Module): - def __init__(self, dim: int, max_seq_len: int) -> None: - super().__init__() - self.emb = nn.Embedding(max_seq_len, dim) - self._weight_init() - - def _weight_init(self) -> None: - nn.init.normal_(self.emb.weight, std=0.02) - - def forward(self, x: Tensor) -> Tensor: - n = torch.arange(x.shape[1], device=x.device) - return self.emb(n)[None, :, :] diff --git a/text_recognizer/networks/transformer/positional_encodings/positional_encoding.py b/text_recognizer/networks/transformer/positional_encodings/positional_encoding.py deleted file mode 100644 index c50afc3..0000000 --- a/text_recognizer/networks/transformer/positional_encodings/positional_encoding.py +++ /dev/null @@ -1,85 +0,0 @@ -"""A positional encoding for the image features, as the transformer has no notation of the order of the sequence.""" -from einops import repeat -import numpy as np -import torch -from torch import nn -from torch import Tensor - - -class PositionalEncoding(nn.Module): - """Encodes a sense of distance or time for transformer networks.""" - - def __init__( - self, hidden_dim: int, dropout_rate: float, max_len: int = 1000 - ) -> None: - super().__init__() - self.dropout = nn.Dropout(p=dropout_rate) - pe = self.make_pe(hidden_dim, max_len) - self.register_buffer("pe", pe) - - @staticmethod - def make_pe(hidden_dim: int, max_len: int) -> Tensor: - """Returns positional encoding.""" - pe = torch.zeros(max_len, hidden_dim) - position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) - div_term = torch.exp( - torch.arange(0, hidden_dim, 2).float() * (-np.log(10000.0) / hidden_dim) - ) - - pe[:, 0::2] = torch.sin(position * div_term) - pe[:, 1::2] = torch.cos(position * div_term) - pe = pe.unsqueeze(1) - return pe - - def forward(self, x: Tensor) -> Tensor: - """Encodes the tensor with a postional embedding.""" - # [T, B, D] - if x.shape[2] != self.pe.shape[2]: - raise ValueError(f"x shape does not match pe in the 3rd dim.") - x = x + self.pe[: x.shape[0]] - return self.dropout(x) - - -class PositionalEncoding2D(nn.Module): - """Positional encodings for feature maps.""" - - def __init__(self, hidden_dim: int, max_h: int = 2048, max_w: int = 2048) -> None: - super().__init__() - if hidden_dim % 2 != 0: - raise ValueError(f"Embedding depth {hidden_dim} is not even!") - self.hidden_dim = hidden_dim - pe = self.make_pe(hidden_dim, max_h, max_w) - self.register_buffer("pe", pe) - - @staticmethod - def make_pe(hidden_dim: int, max_h: int, max_w: int) -> Tensor: - """Returns 2d postional encoding.""" - pe_h = PositionalEncoding.make_pe( - hidden_dim // 2, max_len=max_h - ) # [H, 1, D // 2] - pe_h = repeat(pe_h, "h w d -> d h (w tile)", tile=max_w) - - pe_w = PositionalEncoding.make_pe( - hidden_dim // 2, max_len=max_w - ) # [W, 1, D // 2] - pe_w = repeat(pe_w, "w h d -> d (h tile) w", tile=max_h) - - pe = torch.cat([pe_h, pe_w], dim=0) # [D, H, W] - return pe - - def forward(self, x: Tensor) -> Tensor: - """Adds 2D postional encoding to input tensor.""" - # Assumes x hase shape [B, D, H, W] - if x.shape[1] != self.pe.shape[0]: - raise ValueError("Hidden dimensions does not match.") - x += self.pe[:, : x.shape[2], : x.shape[3]] - return x - - -def target_padding_mask(trg: Tensor, pad_index: int) -> Tensor: - """Returns causal target mask.""" - trg_pad_mask = (trg != pad_index)[:, None, None] - trg_len = trg.shape[1] - trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), device=trg.device)).bool() - trg_mask = trg_pad_mask & trg_sub_mask - return trg_mask diff --git a/text_recognizer/networks/transformer/positional_encodings/rotary_embedding.py b/text_recognizer/networks/transformer/positional_encodings/rotary_embedding.py deleted file mode 100644 index 2f58964..0000000 --- a/text_recognizer/networks/transformer/positional_encodings/rotary_embedding.py +++ /dev/null @@ -1,42 +0,0 @@ -"""Roatary embedding. - -Stolen from lucidrains: - https://github.com/lucidrains/rotary-embedding-torch - -Explanation of roatary: - https://blog.eleuther.ai/rotary-embeddings/ -""" -from typing import Tuple - -from einops import rearrange -import torch -from torch import nn -from torch import Tensor - - -class RotaryEmbedding(nn.Module): - """Rotary positional embedding.""" - - def __init__(self, dim: int): - super().__init__() - inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) - self.register_buffer("inv_freq", inv_freq) - - def forward(self, x: Tensor, seq_dim: int = 1) -> Tensor: - """Encodes tensor x with rotary embeddings.""" - t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) - freqs = torch.einsum("i , j -> i j", t, self.inv_freq) - emb = torch.cat((freqs, freqs), dim=-1) - return rearrange(emb, "n d -> () () n d") - - -def rotate_half(x: Tensor) -> Tensor: - x = rearrange(x, "... (j d) -> ... j d", j=2) - x1, x2 = x.unbind(dim=-2) - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(t: Tensor, freqs: Tensor) -> Tensor: - seq_len = t.shape[-2] - freqs = freqs[:, :, -seq_len:] - return (t * freqs.cos()) + (rotate_half(t) * freqs.sin()) -- cgit v1.2.3-70-g09d2