diff options
Diffstat (limited to 'src/text_recognizer/networks/transformer')
5 files changed, 369 insertions, 0 deletions
diff --git a/src/text_recognizer/networks/transformer/__init__.py b/src/text_recognizer/networks/transformer/__init__.py
new file mode 100644
index 0000000..020a917
--- /dev/null
+++ b/src/text_recognizer/networks/transformer/__init__.py
@@ -0,0 +1,3 @@
+"""Transformer modules."""
+from .positional_encoding import PositionalEncoding
+from .transformer import Decoder, Encoder, Transformer
diff --git a/src/text_recognizer/networks/transformer/attention.py b/src/text_recognizer/networks/transformer/attention.py
new file mode 100644
index 0000000..cce1ecc
--- /dev/null
+++ b/src/text_recognizer/networks/transformer/attention.py
@@ -0,0 +1,93 @@
+"""Implementes the attention module for the transformer."""
+from typing import Optional, Tuple
+
+from einops import rearrange
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+
+
+class MultiHeadAttention(nn.Module):
+    """Implementation of multihead attention."""
+
+    def __init__(
+        self, hidden_dim: int, num_heads: int = 8, dropout_rate: float = 0.0
+    ) -> None:
+        super().__init__()
+        self.hidden_dim = hidden_dim
+        self.num_heads = num_heads
+        self.fc_q = nn.Linear(
+            in_features=hidden_dim, out_features=hidden_dim, bias=False
+        )
+        self.fc_k = nn.Linear(
+            in_features=hidden_dim, out_features=hidden_dim, bias=False
+        )
+        self.fc_v = nn.Linear(
+            in_features=hidden_dim, out_features=hidden_dim, bias=False
+        )
+        self.fc_out = nn.Linear(in_features=hidden_dim, out_features=hidden_dim)
+
+        self._init_weights()
+
+        self.dropout = nn.Dropout(p=dropout_rate)
+
+    def _init_weights(self) -> None:
+        nn.init.normal_(
+            self.fc_q.weight,
+            mean=0,
+            std=np.sqrt(self.hidden_dim + int(self.hidden_dim / self.num_heads)),
+        )
+        nn.init.normal_(
+            self.fc_k.weight,
+            mean=0,
+            std=np.sqrt(self.hidden_dim + int(self.hidden_dim / self.num_heads)),
+        )
+        nn.init.normal_(
+            self.fc_v.weight,
+            mean=0,
+            std=np.sqrt(self.hidden_dim + int(self.hidden_dim / self.num_heads)),
+        )
+        nn.init.xavier_normal_(self.fc_out.weight)
+
+    def scaled_dot_product_attention(
+        self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None
+    ) -> Tensor:
+        """Calculates the scaled dot product attention."""
+
+        # Compute the energy.
+        energy = torch.einsum("bhlk,bhtk->bhlt", [query, key]) / np.sqrt(
+            query.shape[-1]
+        )
+
+        # If we have a mask for padding some inputs.
+        if mask is not None:
+            energy = energy.masked_fill(mask == 0, -np.inf)
+
+        # Compute the attention from the energy.
+        attention = torch.softmax(energy, dim=3)
+
+        out = torch.einsum("bhlt,bhtv->bhlv", [attention, value])
+        out = rearrange(out, "b head l v -> b l (head v)")
+        return out, attention
+
+    def forward(
+        self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None
+    ) -> Tuple[Tensor, Tensor]:
+        """Forward pass for computing the multihead attention."""
+        # Get the query, key, and value tensor.
+        query = rearrange(
+            self.fc_q(query), "b l (head k) -> b head l k", head=self.num_heads
+        )
+        key = rearrange(
+            self.fc_k(key), "b t (head k) -> b head t k", head=self.num_heads
+        )
+        value = rearrange(
+            self.fc_v(value), "b t (head v) -> b head t v", head=self.num_heads
+        )
+
+        out, attention = self.scaled_dot_product_attention(query, key, value, mask)
+
+        out = self.fc_out(out)
+        out = self.dropout(out)
+        return out, attention
diff --git a/src/text_recognizer/networks/transformer/positional_encoding.py b/src/text_recognizer/networks/transformer/positional_encoding.py
new file mode 100644
index 0000000..a47141b
--- /dev/null
+++ b/src/text_recognizer/networks/transformer/positional_encoding.py
@@ -0,0 +1,31 @@
+"""A positional encoding for the image features, as the transformer has no notation of the order of the sequence."""
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+
+
+class PositionalEncoding(nn.Module):
+    """Encodes a sense of distance or time for transformer networks."""
+
+    def __init__(
+        self, hidden_dim: int, dropout_rate: float, max_len: int = 1000
+    ) -> None:
+        super().__init__()
+        self.dropout = nn.Dropout(p=dropout_rate)
+
+        pe = torch.zeros(max_len, hidden_dim)
+        position = torch.arange(0, max_len).unsqueeze(1)
+        div_term = torch.exp(
+            torch.arange(0, hidden_dim, 2) * -(np.log(10000.0) / hidden_dim)
+        )
+
+        pe[:, 0::2] = torch.sin(position * div_term)
+        pe[:, 1::2] = torch.cos(position * div_term)
+        pe = pe.unsqueeze(0)
+        self.register_buffer("pe", pe)
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Encodes the tensor with a postional embedding."""
+        x = x + self.pe[:, : x.shape[1]]
+        return self.dropout(x)
diff --git a/src/text_recognizer/networks/transformer/sparse_transformer.py b/src/text_recognizer/networks/transformer/sparse_transformer.py
new file mode 100644
index 0000000..8c391c8
--- /dev/null
+++ b/src/text_recognizer/networks/transformer/sparse_transformer.py
@@ -0,0 +1 @@
+"""Encoder and Decoder modules using spares activations."""
diff --git a/src/text_recognizer/networks/transformer/transformer.py b/src/text_recognizer/networks/transformer/transformer.py
new file mode 100644
index 0000000..1c9c7dd
--- /dev/null
+++ b/src/text_recognizer/networks/transformer/transformer.py
@@ -0,0 +1,241 @@
+"""Transfomer module."""
+import copy
+from typing import Dict, Optional, Type, Union
+
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.transformer.attention import MultiHeadAttention
+from text_recognizer.networks.util import activation_function
+
+
+def _get_clones(module: Type[nn.Module], num_layers: int) -> nn.ModuleList:
+    return nn.ModuleList([copy.deepcopy(module) for _ in range(num_layers)])
+
+
+class _IntraLayerConnection(nn.Module):
+    """Preforms the residual connection inside the transfomer blocks and applies layernorm."""
+
+    def __init__(self, dropout_rate: float, hidden_dim: int) -> None:
+        super().__init__()
+        self.norm = nn.LayerNorm(normalized_shape=hidden_dim)
+        self.dropout = nn.Dropout(p=dropout_rate)
+
+    def forward(self, src: Tensor, residual: Tensor) -> Tensor:
+        return self.norm(self.dropout(src) + residual)
+
+
+class _ConvolutionalLayer(nn.Module):
+    def __init__(
+        self,
+        hidden_dim: int,
+        expansion_dim: int,
+        dropout_rate: float,
+        activation: str = "relu",
+    ) -> None:
+        super().__init__()
+        self.layer = nn.Sequential(
+            nn.Linear(in_features=hidden_dim, out_features=expansion_dim),
+            activation_function(activation),
+            nn.Dropout(p=dropout_rate),
+            nn.Linear(in_features=expansion_dim, out_features=hidden_dim),
+        )
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self.layer(x)
+
+
+class EncoderLayer(nn.Module):
+    """Transfomer encoding layer."""
+
+    def __init__(
+        self,
+        hidden_dim: int,
+        num_heads: int,
+        expansion_dim: int,
+        dropout_rate: float,
+        activation: str = "relu",
+    ) -> None:
+        super().__init__()
+        self.self_attention = MultiHeadAttention(hidden_dim, num_heads, dropout_rate)
+        self.cnn = _ConvolutionalLayer(
+            hidden_dim, expansion_dim, dropout_rate, activation
+        )
+        self.block1 = _IntraLayerConnection(dropout_rate, hidden_dim)
+        self.block2 = _IntraLayerConnection(dropout_rate, hidden_dim)
+
+    def forward(self, src: Tensor, mask: Optional[Tensor] = None) -> Tensor:
+        """Forward pass through the encoder."""
+        # First block.
+        # Multi head attention.
+        out, _ = self.self_attention(src, src, src, mask)
+
+        # Add & norm.
+        out = self.block1(out, src)
+
+        # Second block.
+        # Apply 1D-convolution.
+        cnn_out = self.cnn(out)
+
+        # Add & norm.
+        out = self.block2(cnn_out, out)
+
+        return out
+
+
+class Encoder(nn.Module):
+    """Transfomer encoder module."""
+
+    def __init__(
+        self,
+        num_layers: int,
+        encoder_layer: Type[nn.Module],
+        norm: Optional[Type[nn.Module]] = None,
+    ) -> None:
+        super().__init__()
+        self.layers = _get_clones(encoder_layer, num_layers)
+        self.norm = norm
+
+    def forward(self, src: Tensor, src_mask: Optional[Tensor] = None) -> Tensor:
+        """Forward pass through all encoder layers."""
+        for layer in self.layers:
+            src = layer(src, src_mask)
+
+        if self.norm is not None:
+            src = self.norm(src)
+
+        return src
+
+
+class DecoderLayer(nn.Module):
+    """Transfomer decoder layer."""
+
+    def __init__(
+        self,
+        hidden_dim: int,
+        num_heads: int,
+        expansion_dim: int,
+        dropout_rate: float = 0.0,
+        activation: str = "relu",
+    ) -> None:
+        super().__init__()
+        self.hidden_dim = hidden_dim
+        self.self_attention = MultiHeadAttention(hidden_dim, num_heads, dropout_rate)
+        self.multihead_attention = MultiHeadAttention(
+            hidden_dim, num_heads, dropout_rate
+        )
+        self.cnn = _ConvolutionalLayer(
+            hidden_dim, expansion_dim, dropout_rate, activation
+        )
+        self.block1 = _IntraLayerConnection(dropout_rate, hidden_dim)
+        self.block2 = _IntraLayerConnection(dropout_rate, hidden_dim)
+        self.block3 = _IntraLayerConnection(dropout_rate, hidden_dim)
+
+    def forward(
+        self,
+        trg: Tensor,
+        memory: Tensor,
+        trg_mask: Optional[Tensor] = None,
+        memory_mask: Optional[Tensor] = None,
+    ) -> Tensor:
+        """Forward pass of the layer."""
+        out, _ = self.self_attention(trg, trg, trg, trg_mask)
+        trg = self.block1(out, trg)
+
+        out, _ = self.multihead_attention(trg, memory, memory, memory_mask)
+        trg = self.block2(out, trg)
+
+        out = self.cnn(trg)
+        out = self.block3(out, trg)
+
+        return out
+
+
+class Decoder(nn.Module):
+    """Transfomer decoder module."""
+
+    def __init__(
+        self,
+        decoder_layer: Type[nn.Module],
+        num_layers: int,
+        norm: Optional[Type[nn.Module]] = None,
+    ) -> None:
+        super().__init__()
+        self.layers = _get_clones(decoder_layer, num_layers)
+        self.num_layers = num_layers
+        self.norm = norm
+
+    def forward(
+        self,
+        trg: Tensor,
+        memory: Tensor,
+        trg_mask: Optional[Tensor] = None,
+        memory_mask: Optional[Tensor] = None,
+    ) -> Tensor:
+        """Forward pass through the decoder."""
+        for layer in self.layers:
+            trg = layer(trg, memory, trg_mask, memory_mask)
+
+        if self.norm is not None:
+            trg = self.norm(trg)
+
+        return trg
+
+
+class Transformer(nn.Module):
+    """Transformer network."""
+
+    def __init__(
+        self,
+        num_encoder_layers: int,
+        num_decoder_layers: int,
+        hidden_dim: int,
+        num_heads: int,
+        expansion_dim: int,
+        dropout_rate: float,
+        activation: str = "relu",
+    ) -> None:
+        super().__init__()
+
+        # Configure encoder.
+        encoder_norm = nn.LayerNorm(hidden_dim)
+        encoder_layer = EncoderLayer(
+            hidden_dim, num_heads, expansion_dim, dropout_rate, activation
+        )
+        self.encoder = Encoder(num_encoder_layers, encoder_layer, encoder_norm)
+
+        # Configure decoder.
+        decoder_norm = nn.LayerNorm(hidden_dim)
+        decoder_layer = DecoderLayer(
+            hidden_dim, num_heads, expansion_dim, dropout_rate, activation
+        )
+        self.decoder = Decoder(decoder_layer, num_decoder_layers, decoder_norm)
+
+        self._reset_parameters()
+
+    def _reset_parameters(self) -> None:
+        for p in self.parameters():
+            if p.dim() > 1:
+                nn.init.xavier_uniform_(p)
+
+    def forward(
+        self,
+        src: Tensor,
+        trg: Tensor,
+        src_mask: Optional[Tensor] = None,
+        trg_mask: Optional[Tensor] = None,
+        memory_mask: Optional[Tensor] = None,
+    ) -> Tensor:
+        """Forward pass through the transformer."""
+        if src.shape[0] != trg.shape[0]:
+            raise RuntimeError("The batch size of the src and trg must be the same.")
+        if src.shape[2] != trg.shape[2]:
+            raise RuntimeError(
+                "The number of features for the src and trg must be the same."
+            )
+
+        memory = self.encoder(src, src_mask)
+        output = self.decoder(trg, memory, trg_mask, memory_mask)
+        return output