diff options
Diffstat (limited to 'text_recognizer/networks/transformer/transformer.py')
-rw-r--r-- | text_recognizer/networks/transformer/transformer.py | 264 |
1 files changed, 264 insertions, 0 deletions
diff --git a/text_recognizer/networks/transformer/transformer.py b/text_recognizer/networks/transformer/transformer.py new file mode 100644 index 0000000..dd180c4 --- /dev/null +++ b/text_recognizer/networks/transformer/transformer.py @@ -0,0 +1,264 @@ +"""Transfomer module.""" +import copy +from typing import Dict, Optional, Type, Union + +import numpy as np +import torch +from torch import nn +from torch import Tensor +import torch.nn.functional as F + +from text_recognizer.networks.transformer.attention import MultiHeadAttention +from text_recognizer.networks.util import activation_function + + +class GEGLU(nn.Module): + """GLU activation for improving feedforward activations.""" + + def __init__(self, dim_in: int, dim_out: int) -> None: + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x: Tensor) -> Tensor: + """Forward propagation.""" + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +def _get_clones(module: Type[nn.Module], num_layers: int) -> nn.ModuleList: + return nn.ModuleList([copy.deepcopy(module) for _ in range(num_layers)]) + + +class _IntraLayerConnection(nn.Module): + """Preforms the residual connection inside the transfomer blocks and applies layernorm.""" + + def __init__(self, dropout_rate: float, hidden_dim: int) -> None: + super().__init__() + self.norm = nn.LayerNorm(normalized_shape=hidden_dim) + self.dropout = nn.Dropout(p=dropout_rate) + + def forward(self, src: Tensor, residual: Tensor) -> Tensor: + return self.norm(self.dropout(src) + residual) + + +class _ConvolutionalLayer(nn.Module): + def __init__( + self, + hidden_dim: int, + expansion_dim: int, + dropout_rate: float, + activation: str = "relu", + ) -> None: + super().__init__() + + in_projection = ( + nn.Sequential( + nn.Linear(hidden_dim, expansion_dim), activation_function(activation) + ) + if activation != "glu" + else GEGLU(hidden_dim, expansion_dim) + ) + + self.layer = nn.Sequential( + in_projection, + nn.Dropout(p=dropout_rate), + nn.Linear(in_features=expansion_dim, out_features=hidden_dim), + ) + + def forward(self, x: Tensor) -> Tensor: + return self.layer(x) + + +class EncoderLayer(nn.Module): + """Transfomer encoding layer.""" + + def __init__( + self, + hidden_dim: int, + num_heads: int, + expansion_dim: int, + dropout_rate: float, + activation: str = "relu", + ) -> None: + super().__init__() + self.self_attention = MultiHeadAttention(hidden_dim, num_heads, dropout_rate) + self.cnn = _ConvolutionalLayer( + hidden_dim, expansion_dim, dropout_rate, activation + ) + self.block1 = _IntraLayerConnection(dropout_rate, hidden_dim) + self.block2 = _IntraLayerConnection(dropout_rate, hidden_dim) + + def forward(self, src: Tensor, mask: Optional[Tensor] = None) -> Tensor: + """Forward pass through the encoder.""" + # First block. + # Multi head attention. + out, _ = self.self_attention(src, src, src, mask) + + # Add & norm. + out = self.block1(out, src) + + # Second block. + # Apply 1D-convolution. + cnn_out = self.cnn(out) + + # Add & norm. + out = self.block2(cnn_out, out) + + return out + + +class Encoder(nn.Module): + """Transfomer encoder module.""" + + def __init__( + self, + num_layers: int, + encoder_layer: Type[nn.Module], + norm: Optional[Type[nn.Module]] = None, + ) -> None: + super().__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.norm = norm + + def forward(self, src: Tensor, src_mask: Optional[Tensor] = None) -> Tensor: + """Forward pass through all encoder layers.""" + for layer in self.layers: + src = layer(src, src_mask) + + if self.norm is not None: + src = self.norm(src) + + return src + + +class DecoderLayer(nn.Module): + """Transfomer decoder layer.""" + + def __init__( + self, + hidden_dim: int, + num_heads: int, + expansion_dim: int, + dropout_rate: float = 0.0, + activation: str = "relu", + ) -> None: + super().__init__() + self.hidden_dim = hidden_dim + self.self_attention = MultiHeadAttention(hidden_dim, num_heads, dropout_rate) + self.multihead_attention = MultiHeadAttention( + hidden_dim, num_heads, dropout_rate + ) + self.cnn = _ConvolutionalLayer( + hidden_dim, expansion_dim, dropout_rate, activation + ) + self.block1 = _IntraLayerConnection(dropout_rate, hidden_dim) + self.block2 = _IntraLayerConnection(dropout_rate, hidden_dim) + self.block3 = _IntraLayerConnection(dropout_rate, hidden_dim) + + def forward( + self, + trg: Tensor, + memory: Tensor, + trg_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + ) -> Tensor: + """Forward pass of the layer.""" + out, _ = self.self_attention(trg, trg, trg, trg_mask) + trg = self.block1(out, trg) + + out, _ = self.multihead_attention(trg, memory, memory, memory_mask) + trg = self.block2(out, trg) + + out = self.cnn(trg) + out = self.block3(out, trg) + + return out + + +class Decoder(nn.Module): + """Transfomer decoder module.""" + + def __init__( + self, + decoder_layer: Type[nn.Module], + num_layers: int, + norm: Optional[Type[nn.Module]] = None, + ) -> None: + super().__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward( + self, + trg: Tensor, + memory: Tensor, + trg_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + ) -> Tensor: + """Forward pass through the decoder.""" + for layer in self.layers: + trg = layer(trg, memory, trg_mask, memory_mask) + + if self.norm is not None: + trg = self.norm(trg) + + return trg + + +class Transformer(nn.Module): + """Transformer network.""" + + def __init__( + self, + num_encoder_layers: int, + num_decoder_layers: int, + hidden_dim: int, + num_heads: int, + expansion_dim: int, + dropout_rate: float, + activation: str = "relu", + ) -> None: + super().__init__() + + # Configure encoder. + encoder_norm = nn.LayerNorm(hidden_dim) + encoder_layer = EncoderLayer( + hidden_dim, num_heads, expansion_dim, dropout_rate, activation + ) + self.encoder = Encoder(num_encoder_layers, encoder_layer, encoder_norm) + + # Configure decoder. + decoder_norm = nn.LayerNorm(hidden_dim) + decoder_layer = DecoderLayer( + hidden_dim, num_heads, expansion_dim, dropout_rate, activation + ) + self.decoder = Decoder(decoder_layer, num_decoder_layers, decoder_norm) + + self._reset_parameters() + + def _reset_parameters(self) -> None: + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward( + self, + src: Tensor, + trg: Tensor, + src_mask: Optional[Tensor] = None, + trg_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + ) -> Tensor: + """Forward pass through the transformer.""" + if src.shape[0] != trg.shape[0]: + print(trg.shape) + raise RuntimeError("The batch size of the src and trg must be the same.") + if src.shape[2] != trg.shape[2]: + raise RuntimeError( + "The number of features for the src and trg must be the same." + ) + + memory = self.encoder(src, src_mask) + output = self.decoder(trg, memory, trg_mask, memory_mask) + return output |