From c9c60678673e19ad3367339eb8e7a093e5a98474 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sun, 9 May 2021 22:46:09 +0200 Subject: Reformatting of positional encodings and ViT working --- text_recognizer/networks/transformer/__init__.py | 7 - text_recognizer/networks/transformer/attention.py | 4 +- text_recognizer/networks/transformer/layers.py | 29 +- .../networks/transformer/nystromer/attention.py | 4 +- .../networks/transformer/positional_encoding.py | 85 ------ .../transformer/positional_encodings/__init__.py | 4 + .../positional_encodings/absolute_embedding.py | 16 + .../positional_encodings/positional_encoding.py | 85 ++++++ .../positional_encodings/rotary_embedding.py | 39 +++ .../networks/transformer/rotary_embedding.py | 39 --- .../networks/transformer/transformer.py | 321 ++++----------------- text_recognizer/networks/transformer/vit.py | 46 +++ text_recognizer/networks/vision_transformer.py | 7 + 13 files changed, 284 insertions(+), 402 deletions(-) delete mode 100644 text_recognizer/networks/transformer/positional_encoding.py create mode 100644 text_recognizer/networks/transformer/positional_encodings/__init__.py create mode 100644 text_recognizer/networks/transformer/positional_encodings/absolute_embedding.py create mode 100644 text_recognizer/networks/transformer/positional_encodings/positional_encoding.py create mode 100644 text_recognizer/networks/transformer/positional_encodings/rotary_embedding.py delete mode 100644 text_recognizer/networks/transformer/rotary_embedding.py create mode 100644 text_recognizer/networks/vision_transformer.py (limited to 'text_recognizer') diff --git a/text_recognizer/networks/transformer/__init__.py b/text_recognizer/networks/transformer/__init__.py index 4ff48f7..a3f3011 100644 --- a/text_recognizer/networks/transformer/__init__.py +++ b/text_recognizer/networks/transformer/__init__.py @@ -1,8 +1 @@ """Transformer modules.""" -from .positional_encoding import ( - PositionalEncoding, - PositionalEncoding2D, - target_padding_mask, -) - -# from .transformer import Decoder, DecoderLayer, Encoder, EncoderLayer, Transformer diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py index 623d680..eabeadf 100644 --- a/text_recognizer/networks/transformer/attention.py +++ b/text_recognizer/networks/transformer/attention.py @@ -9,7 +9,9 @@ from torch import nn from torch import Tensor import torch.nn.functional as F -from text_recognizer.networks.transformer.rotary_embedding import apply_rotary_pos_emb +from text_recognizer.networks.transformer.positional_encodings.rotary_embedding import ( + apply_rotary_pos_emb, +) class Attention(nn.Module): diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py index a2fdb1a..4063425 100644 --- a/text_recognizer/networks/transformer/layers.py +++ b/text_recognizer/networks/transformer/layers.py @@ -1,6 +1,6 @@ """Generates the attention layer architecture.""" from functools import partial -from typing import Dict, Optional, Type +from typing import Any, Dict, Optional, Type from click.types import Tuple @@ -36,12 +36,11 @@ class AttentionLayers(nn.Module): norm_fn = partial(norm_fn, dim=dim) ff_fn = partial(ff_fn, dim=dim, **ff_kwargs) self.layer_types = self._get_layer_types(cross_attend) * depth - self.layers = self._build_network( - causal, attn_fn, norm_fn, ff_fn, residual_fn - ) + self.layers = self._build_network(causal, attn_fn, norm_fn, ff_fn, residual_fn) rotary_emb_dim = max(rotary_emb_dim, 32) if rotary_emb_dim is not None else None self.rotary_emb = RotaryEmbedding(rotary_emb_dim) if rotary_emb else None self.pre_norm = pre_norm + self.has_pos_emb = True if self.rotary_emb is not None else False @staticmethod def _get_layer_types(cross_attend: bool) -> Tuple: @@ -70,7 +69,7 @@ class AttentionLayers(nn.Module): residual_fn = residual_fn() - layers.append(nn.ModuleList([norm_fn(), layer, residual_fn])) + layers.append(nn.modulelist([norm_fn(), layer, residual_fn])) return layers def forward( @@ -82,10 +81,12 @@ class AttentionLayers(nn.Module): ) -> Tensor: rotary_pos_emb = self.rotary_emb(x) if self.rotary_emb is not None else None - for i, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)): + for i, (layer_type, (norm, block, residual_fn)) in enumerate( + zip(self.layer_types, self.layers) + ): is_last = i == len(self.layers) - 1 - - residual = x + + residual = x if self.pre_norm: x = norm(x) @@ -103,3 +104,15 @@ class AttentionLayers(nn.Module): x = norm(x) return x + + +class Encoder(AttentionLayers): + def __init__(self, **kwargs: Any) -> None: + assert "causal" not in kwargs, "Cannot set causality on encoder" + super().__init__(causal=False, **kwargs) + + +class Decoder(AttentionLayers): + def __init__(self, **kwargs: Any) -> None: + assert "causal" not in kwargs, "Cannot set causality on decoder" + super().__init__(causal=True, **kwargs) diff --git a/text_recognizer/networks/transformer/nystromer/attention.py b/text_recognizer/networks/transformer/nystromer/attention.py index c2871fb..5ab19cf 100644 --- a/text_recognizer/networks/transformer/nystromer/attention.py +++ b/text_recognizer/networks/transformer/nystromer/attention.py @@ -157,14 +157,14 @@ class NystromAttention(nn.Module): self, x: Tensor, mask: Optional[Tensor] = None, return_attn: bool = False ) -> Union[Tensor, Tuple[Tensor, Tensor]]: """Compute the Nystrom attention.""" - _, n, _, h, m = x.shape, self.num_heads + _, n, _, h, m = *x.shape, self.num_heads, self.num_landmarks if n % m != 0: x, mask = self._pad_sequence(x, mask, n, m) q, k, v = self.qkv_fn(x).chunk(3, dim=-1) q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) - q *= self.scale + q = q * self.scale out, attn = self._nystrom_attention(q, k, v, mask, n, m, return_attn) diff --git a/text_recognizer/networks/transformer/positional_encoding.py b/text_recognizer/networks/transformer/positional_encoding.py deleted file mode 100644 index c50afc3..0000000 --- a/text_recognizer/networks/transformer/positional_encoding.py +++ /dev/null @@ -1,85 +0,0 @@ -"""A positional encoding for the image features, as the transformer has no notation of the order of the sequence.""" -from einops import repeat -import numpy as np -import torch -from torch import nn -from torch import Tensor - - -class PositionalEncoding(nn.Module): - """Encodes a sense of distance or time for transformer networks.""" - - def __init__( - self, hidden_dim: int, dropout_rate: float, max_len: int = 1000 - ) -> None: - super().__init__() - self.dropout = nn.Dropout(p=dropout_rate) - pe = self.make_pe(hidden_dim, max_len) - self.register_buffer("pe", pe) - - @staticmethod - def make_pe(hidden_dim: int, max_len: int) -> Tensor: - """Returns positional encoding.""" - pe = torch.zeros(max_len, hidden_dim) - position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) - div_term = torch.exp( - torch.arange(0, hidden_dim, 2).float() * (-np.log(10000.0) / hidden_dim) - ) - - pe[:, 0::2] = torch.sin(position * div_term) - pe[:, 1::2] = torch.cos(position * div_term) - pe = pe.unsqueeze(1) - return pe - - def forward(self, x: Tensor) -> Tensor: - """Encodes the tensor with a postional embedding.""" - # [T, B, D] - if x.shape[2] != self.pe.shape[2]: - raise ValueError(f"x shape does not match pe in the 3rd dim.") - x = x + self.pe[: x.shape[0]] - return self.dropout(x) - - -class PositionalEncoding2D(nn.Module): - """Positional encodings for feature maps.""" - - def __init__(self, hidden_dim: int, max_h: int = 2048, max_w: int = 2048) -> None: - super().__init__() - if hidden_dim % 2 != 0: - raise ValueError(f"Embedding depth {hidden_dim} is not even!") - self.hidden_dim = hidden_dim - pe = self.make_pe(hidden_dim, max_h, max_w) - self.register_buffer("pe", pe) - - @staticmethod - def make_pe(hidden_dim: int, max_h: int, max_w: int) -> Tensor: - """Returns 2d postional encoding.""" - pe_h = PositionalEncoding.make_pe( - hidden_dim // 2, max_len=max_h - ) # [H, 1, D // 2] - pe_h = repeat(pe_h, "h w d -> d h (w tile)", tile=max_w) - - pe_w = PositionalEncoding.make_pe( - hidden_dim // 2, max_len=max_w - ) # [W, 1, D // 2] - pe_w = repeat(pe_w, "w h d -> d (h tile) w", tile=max_h) - - pe = torch.cat([pe_h, pe_w], dim=0) # [D, H, W] - return pe - - def forward(self, x: Tensor) -> Tensor: - """Adds 2D postional encoding to input tensor.""" - # Assumes x hase shape [B, D, H, W] - if x.shape[1] != self.pe.shape[0]: - raise ValueError("Hidden dimensions does not match.") - x += self.pe[:, : x.shape[2], : x.shape[3]] - return x - - -def target_padding_mask(trg: Tensor, pad_index: int) -> Tensor: - """Returns causal target mask.""" - trg_pad_mask = (trg != pad_index)[:, None, None] - trg_len = trg.shape[1] - trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), device=trg.device)).bool() - trg_mask = trg_pad_mask & trg_sub_mask - return trg_mask diff --git a/text_recognizer/networks/transformer/positional_encodings/__init__.py b/text_recognizer/networks/transformer/positional_encodings/__init__.py new file mode 100644 index 0000000..91278ee --- /dev/null +++ b/text_recognizer/networks/transformer/positional_encodings/__init__.py @@ -0,0 +1,4 @@ +"""Positional encoding for transformers.""" +from .absolute_embedding import AbsolutePositionalEmbedding +from .positional_encoding import PositionalEncoding, PositionalEncoding2D +from .rotary_embedding import apply_rotary_pos_emb, RotaryEmbedding diff --git a/text_recognizer/networks/transformer/positional_encodings/absolute_embedding.py b/text_recognizer/networks/transformer/positional_encodings/absolute_embedding.py new file mode 100644 index 0000000..9466f6e --- /dev/null +++ b/text_recognizer/networks/transformer/positional_encodings/absolute_embedding.py @@ -0,0 +1,16 @@ +"""Absolute positional embedding.""" +from torch import nn, Tensor + + +class AbsolutePositionalEmbedding(nn.Module): + def __init__(self, dim: int, max_seq_len: int) -> None: + super().__init__() + self.emb = nn.Embedding(max_seq_len, dim) + self._weight_init() + + def _weight_init(self) -> None: + nn.init.normal_(self.emb.weight, std=0.02) + + def forward(self, x: Tensor) -> Tensor: + n = torch.arange(x.shape[1], device=x.device) + return self.emb(n)[None, :, :] diff --git a/text_recognizer/networks/transformer/positional_encodings/positional_encoding.py b/text_recognizer/networks/transformer/positional_encodings/positional_encoding.py new file mode 100644 index 0000000..c50afc3 --- /dev/null +++ b/text_recognizer/networks/transformer/positional_encodings/positional_encoding.py @@ -0,0 +1,85 @@ +"""A positional encoding for the image features, as the transformer has no notation of the order of the sequence.""" +from einops import repeat +import numpy as np +import torch +from torch import nn +from torch import Tensor + + +class PositionalEncoding(nn.Module): + """Encodes a sense of distance or time for transformer networks.""" + + def __init__( + self, hidden_dim: int, dropout_rate: float, max_len: int = 1000 + ) -> None: + super().__init__() + self.dropout = nn.Dropout(p=dropout_rate) + pe = self.make_pe(hidden_dim, max_len) + self.register_buffer("pe", pe) + + @staticmethod + def make_pe(hidden_dim: int, max_len: int) -> Tensor: + """Returns positional encoding.""" + pe = torch.zeros(max_len, hidden_dim) + position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, hidden_dim, 2).float() * (-np.log(10000.0) / hidden_dim) + ) + + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(1) + return pe + + def forward(self, x: Tensor) -> Tensor: + """Encodes the tensor with a postional embedding.""" + # [T, B, D] + if x.shape[2] != self.pe.shape[2]: + raise ValueError(f"x shape does not match pe in the 3rd dim.") + x = x + self.pe[: x.shape[0]] + return self.dropout(x) + + +class PositionalEncoding2D(nn.Module): + """Positional encodings for feature maps.""" + + def __init__(self, hidden_dim: int, max_h: int = 2048, max_w: int = 2048) -> None: + super().__init__() + if hidden_dim % 2 != 0: + raise ValueError(f"Embedding depth {hidden_dim} is not even!") + self.hidden_dim = hidden_dim + pe = self.make_pe(hidden_dim, max_h, max_w) + self.register_buffer("pe", pe) + + @staticmethod + def make_pe(hidden_dim: int, max_h: int, max_w: int) -> Tensor: + """Returns 2d postional encoding.""" + pe_h = PositionalEncoding.make_pe( + hidden_dim // 2, max_len=max_h + ) # [H, 1, D // 2] + pe_h = repeat(pe_h, "h w d -> d h (w tile)", tile=max_w) + + pe_w = PositionalEncoding.make_pe( + hidden_dim // 2, max_len=max_w + ) # [W, 1, D // 2] + pe_w = repeat(pe_w, "w h d -> d (h tile) w", tile=max_h) + + pe = torch.cat([pe_h, pe_w], dim=0) # [D, H, W] + return pe + + def forward(self, x: Tensor) -> Tensor: + """Adds 2D postional encoding to input tensor.""" + # Assumes x hase shape [B, D, H, W] + if x.shape[1] != self.pe.shape[0]: + raise ValueError("Hidden dimensions does not match.") + x += self.pe[:, : x.shape[2], : x.shape[3]] + return x + + +def target_padding_mask(trg: Tensor, pad_index: int) -> Tensor: + """Returns causal target mask.""" + trg_pad_mask = (trg != pad_index)[:, None, None] + trg_len = trg.shape[1] + trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), device=trg.device)).bool() + trg_mask = trg_pad_mask & trg_sub_mask + return trg_mask diff --git a/text_recognizer/networks/transformer/positional_encodings/rotary_embedding.py b/text_recognizer/networks/transformer/positional_encodings/rotary_embedding.py new file mode 100644 index 0000000..5e80572 --- /dev/null +++ b/text_recognizer/networks/transformer/positional_encodings/rotary_embedding.py @@ -0,0 +1,39 @@ +"""Roatary embedding. + +Stolen from lucidrains: + https://github.com/lucidrains/x-transformers/blob/main/x_transformers/x_transformers.py + +Explanation of roatary: + https://blog.eleuther.ai/rotary-embeddings/ + +""" +from typing import Tuple + +from einops import rearrange +import torch +from torch import nn +from torch import Tensor + + +class RotaryEmbedding(nn.Module): + def __init__(self, dim: int): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq) + + def forward(self, x: Tensor, seq_dim: int = 1) -> Tensor: + t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + freqs = torch.einsum("i , j -> i j", t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + return emb[None, :, :] + + +def rotate_half(x: Tensor) -> Tensor: + x = rearrange(x, "... (j d) -> ... j d", j=2) + x1, x2 = x.unbind(dim=-2) + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q: Tensor, k: Tensor, freqs: Tensor) -> Tuple[Tensor, Tensor]: + q, k = map(lambda t: (t * freqs.cos()) + (rotate_half(t) * freqs.sin()), (q, k)) + return q, k diff --git a/text_recognizer/networks/transformer/rotary_embedding.py b/text_recognizer/networks/transformer/rotary_embedding.py deleted file mode 100644 index 5e80572..0000000 --- a/text_recognizer/networks/transformer/rotary_embedding.py +++ /dev/null @@ -1,39 +0,0 @@ -"""Roatary embedding. - -Stolen from lucidrains: - https://github.com/lucidrains/x-transformers/blob/main/x_transformers/x_transformers.py - -Explanation of roatary: - https://blog.eleuther.ai/rotary-embeddings/ - -""" -from typing import Tuple - -from einops import rearrange -import torch -from torch import nn -from torch import Tensor - - -class RotaryEmbedding(nn.Module): - def __init__(self, dim: int): - super().__init__() - inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) - self.register_buffer("inv_freq", inv_freq) - - def forward(self, x: Tensor, seq_dim: int = 1) -> Tensor: - t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) - freqs = torch.einsum("i , j -> i j", t, self.inv_freq) - emb = torch.cat((freqs, freqs), dim=-1) - return emb[None, :, :] - - -def rotate_half(x: Tensor) -> Tensor: - x = rearrange(x, "... (j d) -> ... j d", j=2) - x1, x2 = x.unbind(dim=-2) - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(q: Tensor, k: Tensor, freqs: Tensor) -> Tuple[Tensor, Tensor]: - q, k = map(lambda t: (t * freqs.cos()) + (rotate_half(t) * freqs.sin()), (q, k)) - return q, k diff --git a/text_recognizer/networks/transformer/transformer.py b/text_recognizer/networks/transformer/transformer.py index d49c85a..36f86ac 100644 --- a/text_recognizer/networks/transformer/transformer.py +++ b/text_recognizer/networks/transformer/transformer.py @@ -1,260 +1,61 @@ -# """Transfomer module.""" -# import copy -# from typing import Dict, Optional, Type, Union -# -# import numpy as np -# import torch -# from torch import nn -# from torch import Tensor -# import torch.nn.functional as F -# -# from text_recognizer.networks.transformer.attention import MultiHeadAttention -# from text_recognizer.networks.util import activation_function -# -# -# class GEGLU(nn.Module): -# """GLU activation for improving feedforward activations.""" -# -# def __init__(self, dim_in: int, dim_out: int) -> None: -# super().__init__() -# self.proj = nn.Linear(dim_in, dim_out * 2) -# -# def forward(self, x: Tensor) -> Tensor: -# """Forward propagation.""" -# x, gate = self.proj(x).chunk(2, dim=-1) -# return x * F.gelu(gate) -# -# -# def _get_clones(module: Type[nn.Module], num_layers: int) -> nn.ModuleList: -# return nn.ModuleList([copy.deepcopy(module) for _ in range(num_layers)]) -# -# -# class _IntraLayerConnection(nn.Module): -# """Preforms the residual connection inside the transfomer blocks and applies layernorm.""" -# -# def __init__(self, dropout_rate: float, hidden_dim: int) -> None: -# super().__init__() -# self.norm = nn.LayerNorm(normalized_shape=hidden_dim) -# self.dropout = nn.Dropout(p=dropout_rate) -# -# def forward(self, src: Tensor, residual: Tensor) -> Tensor: -# return self.norm(self.dropout(src) + residual) -# -# -# class FeedForward(nn.Module): -# def __init__( -# self, -# hidden_dim: int, -# expansion_dim: int, -# dropout_rate: float, -# activation: str = "relu", -# ) -> None: -# super().__init__() -# -# in_projection = ( -# nn.Sequential( -# nn.Linear(hidden_dim, expansion_dim), activation_function(activation) -# ) -# if activation != "glu" -# else GEGLU(hidden_dim, expansion_dim) -# ) -# -# self.layer = nn.Sequential( -# in_projection, -# nn.Dropout(p=dropout_rate), -# nn.Linear(in_features=expansion_dim, out_features=hidden_dim), -# ) -# -# def forward(self, x: Tensor) -> Tensor: -# return self.layer(x) -# -# -# class EncoderLayer(nn.Module): -# """Transfomer encoding layer.""" -# -# def __init__( -# self, -# hidden_dim: int, -# num_heads: int, -# expansion_dim: int, -# dropout_rate: float, -# activation: str = "relu", -# ) -> None: -# super().__init__() -# self.self_attention = MultiHeadAttention(hidden_dim, num_heads, dropout_rate) -# self.mlp = FeedForward(hidden_dim, expansion_dim, dropout_rate, activation) -# self.block1 = _IntraLayerConnection(dropout_rate, hidden_dim) -# self.block2 = _IntraLayerConnection(dropout_rate, hidden_dim) -# -# def forward(self, src: Tensor, mask: Optional[Tensor] = None) -> Tensor: -# """Forward pass through the encoder.""" -# # First block. -# # Multi head attention. -# out, _ = self.self_attention(src, src, src, mask) -# -# # Add & norm. -# out = self.block1(out, src) -# -# # Second block. -# # Apply 1D-convolution. -# mlp_out = self.mlp(out) -# -# # Add & norm. -# out = self.block2(mlp_out, out) -# -# return out -# -# -# class Encoder(nn.Module): -# """Transfomer encoder module.""" -# -# def __init__( -# self, -# num_layers: int, -# encoder_layer: Type[nn.Module], -# norm: Optional[Type[nn.Module]] = None, -# ) -> None: -# super().__init__() -# self.layers = _get_clones(encoder_layer, num_layers) -# self.norm = norm -# -# def forward(self, src: Tensor, src_mask: Optional[Tensor] = None) -> Tensor: -# """Forward pass through all encoder layers.""" -# for layer in self.layers: -# src = layer(src, src_mask) -# -# if self.norm is not None: -# src = self.norm(src) -# -# return src -# -# -# class DecoderLayer(nn.Module): -# """Transfomer decoder layer.""" -# -# def __init__( -# self, -# hidden_dim: int, -# num_heads: int, -# expansion_dim: int, -# dropout_rate: float = 0.0, -# activation: str = "relu", -# ) -> None: -# super().__init__() -# self.hidden_dim = hidden_dim -# self.self_attention = MultiHeadAttention(hidden_dim, num_heads, dropout_rate) -# self.multihead_attention = MultiHeadAttention( -# hidden_dim, num_heads, dropout_rate -# ) -# self.mlp = FeedForward(hidden_dim, expansion_dim, dropout_rate, activation) -# self.block1 = _IntraLayerConnection(dropout_rate, hidden_dim) -# self.block2 = _IntraLayerConnection(dropout_rate, hidden_dim) -# self.block3 = _IntraLayerConnection(dropout_rate, hidden_dim) -# -# def forward( -# self, -# trg: Tensor, -# memory: Tensor, -# trg_mask: Optional[Tensor] = None, -# memory_mask: Optional[Tensor] = None, -# ) -> Tensor: -# """Forward pass of the layer.""" -# out, _ = self.self_attention(trg, trg, trg, trg_mask) -# trg = self.block1(out, trg) -# -# out, _ = self.multihead_attention(trg, memory, memory, memory_mask) -# trg = self.block2(out, trg) -# -# out = self.mlp(trg) -# out = self.block3(out, trg) -# -# return out -# -# -# class Decoder(nn.Module): -# """Transfomer decoder module.""" -# -# def __init__( -# self, -# decoder_layer: Type[nn.Module], -# num_layers: int, -# norm: Optional[Type[nn.Module]] = None, -# ) -> None: -# super().__init__() -# self.layers = _get_clones(decoder_layer, num_layers) -# self.num_layers = num_layers -# self.norm = norm -# -# def forward( -# self, -# trg: Tensor, -# memory: Tensor, -# trg_mask: Optional[Tensor] = None, -# memory_mask: Optional[Tensor] = None, -# ) -> Tensor: -# """Forward pass through the decoder.""" -# for layer in self.layers: -# trg = layer(trg, memory, trg_mask, memory_mask) -# -# if self.norm is not None: -# trg = self.norm(trg) -# -# return trg -# -# -# class Transformer(nn.Module): -# """Transformer network.""" -# -# def __init__( -# self, -# num_encoder_layers: int, -# num_decoder_layers: int, -# hidden_dim: int, -# num_heads: int, -# expansion_dim: int, -# dropout_rate: float, -# activation: str = "relu", -# ) -> None: -# super().__init__() -# -# # Configure encoder. -# encoder_norm = nn.LayerNorm(hidden_dim) -# encoder_layer = EncoderLayer( -# hidden_dim, num_heads, expansion_dim, dropout_rate, activation -# ) -# self.encoder = Encoder(num_encoder_layers, encoder_layer, encoder_norm) -# -# # Configure decoder. -# decoder_norm = nn.LayerNorm(hidden_dim) -# decoder_layer = DecoderLayer( -# hidden_dim, num_heads, expansion_dim, dropout_rate, activation -# ) -# self.decoder = Decoder(decoder_layer, num_decoder_layers, decoder_norm) -# -# self._reset_parameters() -# -# def _reset_parameters(self) -> None: -# for p in self.parameters(): -# if p.dim() > 1: -# nn.init.xavier_uniform_(p) -# -# def forward( -# self, -# src: Tensor, -# trg: Tensor, -# src_mask: Optional[Tensor] = None, -# trg_mask: Optional[Tensor] = None, -# memory_mask: Optional[Tensor] = None, -# ) -> Tensor: -# """Forward pass through the transformer.""" -# if src.shape[0] != trg.shape[0]: -# print(trg.shape) -# raise RuntimeError("The batch size of the src and trg must be the same.") -# if src.shape[2] != trg.shape[2]: -# raise RuntimeError( -# "The number of features for the src and trg must be the same." -# ) -# -# memory = self.encoder(src, src_mask) -# output = self.decoder(trg, memory, trg_mask, memory_mask) -# return output +"""Transformer wrapper.""" +from typing import Optional, Type + +from torch import nn, Tensor + +from .layers import AttentionLayers +from text_recognizer.networks.transformer.positional_encodings import ( + AbsolutePositionalEmbedding, +) + + +class Transformer(nn.Module): + def __init__( + self, + num_tokens: int, + max_seq_len: int, + attn_layers: Type[AttentionLayers], + emb_dim: Optional[int] = None, + emb_dropout: float = 0.0, + use_pos_emb: bool = True, + ) -> None: + dim = attn_layers.dim + emb_dim = emb_dim if emb_dim is not None else dim + self.max_seq_len = max_seq_len + + self.token_emb = nn.Embedding(num_tokens, emb_dim) + self.emb_dropout = nn.Dropout(emb_dropout) + self.pos_emb = ( + AbsolutePositionalEmbedding(emb_dim, max_seq_len) + if (use_pos_emb and not self.attn_layers.has_pos_emb) + else None + ) + + self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity() + self.attn_layers = attn_layers + self.norm = nn.LayerNorm(dim) + + self._init_weights() + + self.logits = nn.Linear(dim, num_tokens) + + def _init_weights(self) -> None: + nn.init.normal_(self.token_emb.weight, std=0.02) + + def forward( + self, + x: Tensor, + mask: Optional[Tensor], + return_embeddings: bool = False, + **kwargs: Any + ) -> Tensor: + b, n, device = *x.shape, x.device + x += self.token_emb(x) + if self.pos_emb is not None: + x += self.pos_emb(x) + x = self.emb_dropout(x) + + x = self.project_emb(x) + x = self.attn_layers(x, mask=mask, **kwargs) + out = self.logits(x) if not return_embeddings else x + return x diff --git a/text_recognizer/networks/transformer/vit.py b/text_recognizer/networks/transformer/vit.py index e69de29..ab331f8 100644 --- a/text_recognizer/networks/transformer/vit.py +++ b/text_recognizer/networks/transformer/vit.py @@ -0,0 +1,46 @@ +"""Vision Transformer.""" +from typing import Tuple, Type + +from einops.layers.torch import Rearrange +import torch +from torch import nn, Tensor + + +class ViT(nn.Module): + def __init__( + self, + image_size: Tuple[int, int], + patch_size: Tuple[int, int], + dim: int, + transformer: Type[nn.Module], + channels: int = 1, + ) -> None: + super().__init__() + img_height, img_width = image_size + patch_height, patch_width = patch_size + assert img_height % patch_height == 0 + assert img_width % patch_width == 0 + + num_patches = (img_height // patch_height) * (img_width // patch_width) + patch_dim = channels * patch_height * patch_width + + self.to_patch_embedding = nn.Sequential( + Rearrange( + "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", + p1=patch_height, + p2=patch_width, + c=channels, + ), + nn.Linear(patch_dim, dim), + ) + + self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, dim)) + self.transformer = transformer + self.norm = nn.LayerNorm(dim) + + def forward(self, img: Tensor) -> Tensor: + x = self.to_patch_embedding(img) + _, n, _ = x.shape + x += self.pos_embedding[:, :n] + x = self.transformer(x) + return x diff --git a/text_recognizer/networks/vision_transformer.py b/text_recognizer/networks/vision_transformer.py new file mode 100644 index 0000000..b617c71 --- /dev/null +++ b/text_recognizer/networks/vision_transformer.py @@ -0,0 +1,7 @@ +"""Vision transformer for character recognition.""" +from torch import nn, Tensor + + +class VisionTransformer(nn.Module): + def __init__(self,) -> None: + pass -- cgit v1.2.3-70-g09d2