summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/transformer/transformer.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks/transformer/transformer.py')
-rw-r--r--text_recognizer/networks/transformer/transformer.py321
1 files changed, 61 insertions, 260 deletions
diff --git a/text_recognizer/networks/transformer/transformer.py b/text_recognizer/networks/transformer/transformer.py
index d49c85a..36f86ac 100644
--- a/text_recognizer/networks/transformer/transformer.py
+++ b/text_recognizer/networks/transformer/transformer.py
@@ -1,260 +1,61 @@
-# """Transfomer module."""
-# import copy
-# from typing import Dict, Optional, Type, Union
-#
-# import numpy as np
-# import torch
-# from torch import nn
-# from torch import Tensor
-# import torch.nn.functional as F
-#
-# from text_recognizer.networks.transformer.attention import MultiHeadAttention
-# from text_recognizer.networks.util import activation_function
-#
-#
-# class GEGLU(nn.Module):
-# """GLU activation for improving feedforward activations."""
-#
-# def __init__(self, dim_in: int, dim_out: int) -> None:
-# super().__init__()
-# self.proj = nn.Linear(dim_in, dim_out * 2)
-#
-# def forward(self, x: Tensor) -> Tensor:
-# """Forward propagation."""
-# x, gate = self.proj(x).chunk(2, dim=-1)
-# return x * F.gelu(gate)
-#
-#
-# def _get_clones(module: Type[nn.Module], num_layers: int) -> nn.ModuleList:
-# return nn.ModuleList([copy.deepcopy(module) for _ in range(num_layers)])
-#
-#
-# class _IntraLayerConnection(nn.Module):
-# """Preforms the residual connection inside the transfomer blocks and applies layernorm."""
-#
-# def __init__(self, dropout_rate: float, hidden_dim: int) -> None:
-# super().__init__()
-# self.norm = nn.LayerNorm(normalized_shape=hidden_dim)
-# self.dropout = nn.Dropout(p=dropout_rate)
-#
-# def forward(self, src: Tensor, residual: Tensor) -> Tensor:
-# return self.norm(self.dropout(src) + residual)
-#
-#
-# class FeedForward(nn.Module):
-# def __init__(
-# self,
-# hidden_dim: int,
-# expansion_dim: int,
-# dropout_rate: float,
-# activation: str = "relu",
-# ) -> None:
-# super().__init__()
-#
-# in_projection = (
-# nn.Sequential(
-# nn.Linear(hidden_dim, expansion_dim), activation_function(activation)
-# )
-# if activation != "glu"
-# else GEGLU(hidden_dim, expansion_dim)
-# )
-#
-# self.layer = nn.Sequential(
-# in_projection,
-# nn.Dropout(p=dropout_rate),
-# nn.Linear(in_features=expansion_dim, out_features=hidden_dim),
-# )
-#
-# def forward(self, x: Tensor) -> Tensor:
-# return self.layer(x)
-#
-#
-# class EncoderLayer(nn.Module):
-# """Transfomer encoding layer."""
-#
-# def __init__(
-# self,
-# hidden_dim: int,
-# num_heads: int,
-# expansion_dim: int,
-# dropout_rate: float,
-# activation: str = "relu",
-# ) -> None:
-# super().__init__()
-# self.self_attention = MultiHeadAttention(hidden_dim, num_heads, dropout_rate)
-# self.mlp = FeedForward(hidden_dim, expansion_dim, dropout_rate, activation)
-# self.block1 = _IntraLayerConnection(dropout_rate, hidden_dim)
-# self.block2 = _IntraLayerConnection(dropout_rate, hidden_dim)
-#
-# def forward(self, src: Tensor, mask: Optional[Tensor] = None) -> Tensor:
-# """Forward pass through the encoder."""
-# # First block.
-# # Multi head attention.
-# out, _ = self.self_attention(src, src, src, mask)
-#
-# # Add & norm.
-# out = self.block1(out, src)
-#
-# # Second block.
-# # Apply 1D-convolution.
-# mlp_out = self.mlp(out)
-#
-# # Add & norm.
-# out = self.block2(mlp_out, out)
-#
-# return out
-#
-#
-# class Encoder(nn.Module):
-# """Transfomer encoder module."""
-#
-# def __init__(
-# self,
-# num_layers: int,
-# encoder_layer: Type[nn.Module],
-# norm: Optional[Type[nn.Module]] = None,
-# ) -> None:
-# super().__init__()
-# self.layers = _get_clones(encoder_layer, num_layers)
-# self.norm = norm
-#
-# def forward(self, src: Tensor, src_mask: Optional[Tensor] = None) -> Tensor:
-# """Forward pass through all encoder layers."""
-# for layer in self.layers:
-# src = layer(src, src_mask)
-#
-# if self.norm is not None:
-# src = self.norm(src)
-#
-# return src
-#
-#
-# class DecoderLayer(nn.Module):
-# """Transfomer decoder layer."""
-#
-# def __init__(
-# self,
-# hidden_dim: int,
-# num_heads: int,
-# expansion_dim: int,
-# dropout_rate: float = 0.0,
-# activation: str = "relu",
-# ) -> None:
-# super().__init__()
-# self.hidden_dim = hidden_dim
-# self.self_attention = MultiHeadAttention(hidden_dim, num_heads, dropout_rate)
-# self.multihead_attention = MultiHeadAttention(
-# hidden_dim, num_heads, dropout_rate
-# )
-# self.mlp = FeedForward(hidden_dim, expansion_dim, dropout_rate, activation)
-# self.block1 = _IntraLayerConnection(dropout_rate, hidden_dim)
-# self.block2 = _IntraLayerConnection(dropout_rate, hidden_dim)
-# self.block3 = _IntraLayerConnection(dropout_rate, hidden_dim)
-#
-# def forward(
-# self,
-# trg: Tensor,
-# memory: Tensor,
-# trg_mask: Optional[Tensor] = None,
-# memory_mask: Optional[Tensor] = None,
-# ) -> Tensor:
-# """Forward pass of the layer."""
-# out, _ = self.self_attention(trg, trg, trg, trg_mask)
-# trg = self.block1(out, trg)
-#
-# out, _ = self.multihead_attention(trg, memory, memory, memory_mask)
-# trg = self.block2(out, trg)
-#
-# out = self.mlp(trg)
-# out = self.block3(out, trg)
-#
-# return out
-#
-#
-# class Decoder(nn.Module):
-# """Transfomer decoder module."""
-#
-# def __init__(
-# self,
-# decoder_layer: Type[nn.Module],
-# num_layers: int,
-# norm: Optional[Type[nn.Module]] = None,
-# ) -> None:
-# super().__init__()
-# self.layers = _get_clones(decoder_layer, num_layers)
-# self.num_layers = num_layers
-# self.norm = norm
-#
-# def forward(
-# self,
-# trg: Tensor,
-# memory: Tensor,
-# trg_mask: Optional[Tensor] = None,
-# memory_mask: Optional[Tensor] = None,
-# ) -> Tensor:
-# """Forward pass through the decoder."""
-# for layer in self.layers:
-# trg = layer(trg, memory, trg_mask, memory_mask)
-#
-# if self.norm is not None:
-# trg = self.norm(trg)
-#
-# return trg
-#
-#
-# class Transformer(nn.Module):
-# """Transformer network."""
-#
-# def __init__(
-# self,
-# num_encoder_layers: int,
-# num_decoder_layers: int,
-# hidden_dim: int,
-# num_heads: int,
-# expansion_dim: int,
-# dropout_rate: float,
-# activation: str = "relu",
-# ) -> None:
-# super().__init__()
-#
-# # Configure encoder.
-# encoder_norm = nn.LayerNorm(hidden_dim)
-# encoder_layer = EncoderLayer(
-# hidden_dim, num_heads, expansion_dim, dropout_rate, activation
-# )
-# self.encoder = Encoder(num_encoder_layers, encoder_layer, encoder_norm)
-#
-# # Configure decoder.
-# decoder_norm = nn.LayerNorm(hidden_dim)
-# decoder_layer = DecoderLayer(
-# hidden_dim, num_heads, expansion_dim, dropout_rate, activation
-# )
-# self.decoder = Decoder(decoder_layer, num_decoder_layers, decoder_norm)
-#
-# self._reset_parameters()
-#
-# def _reset_parameters(self) -> None:
-# for p in self.parameters():
-# if p.dim() > 1:
-# nn.init.xavier_uniform_(p)
-#
-# def forward(
-# self,
-# src: Tensor,
-# trg: Tensor,
-# src_mask: Optional[Tensor] = None,
-# trg_mask: Optional[Tensor] = None,
-# memory_mask: Optional[Tensor] = None,
-# ) -> Tensor:
-# """Forward pass through the transformer."""
-# if src.shape[0] != trg.shape[0]:
-# print(trg.shape)
-# raise RuntimeError("The batch size of the src and trg must be the same.")
-# if src.shape[2] != trg.shape[2]:
-# raise RuntimeError(
-# "The number of features for the src and trg must be the same."
-# )
-#
-# memory = self.encoder(src, src_mask)
-# output = self.decoder(trg, memory, trg_mask, memory_mask)
-# return output
+"""Transformer wrapper."""
+from typing import Optional, Type
+
+from torch import nn, Tensor
+
+from .layers import AttentionLayers
+from text_recognizer.networks.transformer.positional_encodings import (
+ AbsolutePositionalEmbedding,
+)
+
+
+class Transformer(nn.Module):
+ def __init__(
+ self,
+ num_tokens: int,
+ max_seq_len: int,
+ attn_layers: Type[AttentionLayers],
+ emb_dim: Optional[int] = None,
+ emb_dropout: float = 0.0,
+ use_pos_emb: bool = True,
+ ) -> None:
+ dim = attn_layers.dim
+ emb_dim = emb_dim if emb_dim is not None else dim
+ self.max_seq_len = max_seq_len
+
+ self.token_emb = nn.Embedding(num_tokens, emb_dim)
+ self.emb_dropout = nn.Dropout(emb_dropout)
+ self.pos_emb = (
+ AbsolutePositionalEmbedding(emb_dim, max_seq_len)
+ if (use_pos_emb and not self.attn_layers.has_pos_emb)
+ else None
+ )
+
+ self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
+ self.attn_layers = attn_layers
+ self.norm = nn.LayerNorm(dim)
+
+ self._init_weights()
+
+ self.logits = nn.Linear(dim, num_tokens)
+
+ def _init_weights(self) -> None:
+ nn.init.normal_(self.token_emb.weight, std=0.02)
+
+ def forward(
+ self,
+ x: Tensor,
+ mask: Optional[Tensor],
+ return_embeddings: bool = False,
+ **kwargs: Any
+ ) -> Tensor:
+ b, n, device = *x.shape, x.device
+ x += self.token_emb(x)
+ if self.pos_emb is not None:
+ x += self.pos_emb(x)
+ x = self.emb_dropout(x)
+
+ x = self.project_emb(x)
+ x = self.attn_layers(x, mask=mask, **kwargs)
+ out = self.logits(x) if not return_embeddings else x
+ return x