diff options
Diffstat (limited to 'text_recognizer/networks')
| -rw-r--r-- | text_recognizer/networks/__init__.py | 3 | ||||
| -rw-r--r-- | text_recognizer/networks/cnn_transformer.py | 364 | ||||
| -rw-r--r-- | text_recognizer/networks/transformer/__init__.py | 3 | ||||
| -rw-r--r-- | text_recognizer/networks/transformer/transformer.py | 520 | 
4 files changed, 446 insertions, 444 deletions
| diff --git a/text_recognizer/networks/__init__.py b/text_recognizer/networks/__init__.py index a9117f8..d1ebf1a 100644 --- a/text_recognizer/networks/__init__.py +++ b/text_recognizer/networks/__init__.py @@ -1,4 +1,5 @@  """Network modules"""  from .encoders import EfficientNet  from .vqvae import VQVAE -from .cnn_transformer import CNNTransformer + +# from .cnn_transformer import CNNTransformer diff --git a/text_recognizer/networks/cnn_transformer.py b/text_recognizer/networks/cnn_transformer.py index d42c29d..80798e1 100644 --- a/text_recognizer/networks/cnn_transformer.py +++ b/text_recognizer/networks/cnn_transformer.py @@ -1,182 +1,182 @@ -"""A Transformer with a cnn backbone. - -The network encodes a image with a convolutional backbone to a latent representation, -i.e. feature maps. A 2d positional encoding is applied to the feature maps for -spatial information. The resulting feature are then set to a transformer decoder -together with the target tokens. - -TODO: Local attention for lower layer in attention. - -""" -import importlib -import math -from typing import Dict, Optional, Union, Sequence, Type - -from einops import rearrange -from omegaconf import DictConfig, OmegaConf -import torch -from torch import nn -from torch import Tensor - -from text_recognizer.data.emnist import NUM_SPECIAL_TOKENS -from text_recognizer.networks.transformer import ( -    Decoder, -    DecoderLayer, -    PositionalEncoding, -    PositionalEncoding2D, -    target_padding_mask, -) - -NUM_WORD_PIECES = 1000 - - -class CNNTransformer(nn.Module): -    def __init__( -        self, -        input_dim: Sequence[int], -        output_dims: Sequence[int], -        encoder: Union[DictConfig, Dict], -        vocab_size: Optional[int] = None, -        num_decoder_layers: int = 4, -        hidden_dim: int = 256, -        num_heads: int = 4, -        expansion_dim: int = 1024, -        dropout_rate: float = 0.1, -        transformer_activation: str = "glu", -        *args, -        **kwargs, -    ) -> None: -        super().__init__() -        self.vocab_size = ( -            NUM_WORD_PIECES + NUM_SPECIAL_TOKENS if vocab_size is None else vocab_size -        ) -        self.pad_index = 3  # TODO: fix me -        self.hidden_dim = hidden_dim -        self.max_output_length = output_dims[0] - -        # Image backbone -        self.encoder = self._configure_encoder(encoder) -        self.encoder_proj = nn.Conv2d(256, hidden_dim, kernel_size=1) -        self.feature_map_encoding = PositionalEncoding2D( -            hidden_dim=hidden_dim, max_h=input_dim[1], max_w=input_dim[2] -        ) - -        # Target token embedding -        self.trg_embedding = nn.Embedding(self.vocab_size, hidden_dim) -        self.trg_position_encoding = PositionalEncoding( -            hidden_dim, dropout_rate, max_len=output_dims[0] -        ) - -        # Transformer decoder -        self.decoder = Decoder( -            decoder_layer=DecoderLayer( -                hidden_dim=hidden_dim, -                num_heads=num_heads, -                expansion_dim=expansion_dim, -                dropout_rate=dropout_rate, -                activation=transformer_activation, -            ), -            num_layers=num_decoder_layers, -            norm=nn.LayerNorm(hidden_dim), -        ) - -        # Classification head -        self.head = nn.Linear(hidden_dim, self.vocab_size) - -        # Initialize weights -        self._init_weights() - -    def _init_weights(self) -> None: -        """Initialize network weights.""" -        self.trg_embedding.weight.data.uniform_(-0.1, 0.1) -        self.head.bias.data.zero_() -        self.head.weight.data.uniform_(-0.1, 0.1) - -        nn.init.kaiming_normal_( -            self.encoder_proj.weight.data, -            a=0, -            mode="fan_out", -            nonlinearity="relu", -        ) -        if self.encoder_proj.bias is not None: -            _, fan_out = nn.init._calculate_fan_in_and_fan_out( -                self.encoder_proj.weight.data -            ) -            bound = 1 / math.sqrt(fan_out) -            nn.init.normal_(self.encoder_proj.bias, -bound, bound) - -    @staticmethod -    def _configure_encoder(encoder: Union[DictConfig, Dict]) -> Type[nn.Module]: -        encoder = OmegaConf.create(encoder) -        args = encoder.args or {} -        network_module = importlib.import_module("text_recognizer.networks") -        encoder_class = getattr(network_module, encoder.type) -        return encoder_class(**args) - -    def encode(self, image: Tensor) -> Tensor: -        """Extracts image features with backbone. - -        Args: -            image (Tensor): Image(s) of handwritten text. - -        Retuns: -            Tensor: Image features. - -        Shapes: -            - image: :math: `(B, C, H, W)` -            - latent: :math: `(B, T, C)` - -        """ -        # Extract image features. -        image_features = self.encoder(image) -        image_features = self.encoder_proj(image_features) - -        # Add 2d encoding to the feature maps. -        image_features = self.feature_map_encoding(image_features) - -        # Collapse features maps height and width. -        image_features = rearrange(image_features, "b c h w -> b (h w) c") -        return image_features - -    def decode(self, memory: Tensor, trg: Tensor) -> Tensor: -        """Decodes image features with transformer decoder.""" -        trg_mask = target_padding_mask(trg=trg, pad_index=self.pad_index) -        trg = self.trg_embedding(trg) * math.sqrt(self.hidden_dim) -        trg = rearrange(trg, "b t d -> t b d") -        trg = self.trg_position_encoding(trg) -        trg = rearrange(trg, "t b d -> b t d") -        out = self.decoder(trg=trg, memory=memory, trg_mask=trg_mask, memory_mask=None) -        logits = self.head(out) -        return logits - -    def forward(self, image: Tensor, trg: Tensor) -> Tensor: -        image_features = self.encode(image) -        output = self.decode(image_features, trg) -        output = rearrange(output, "b t c -> b c t") -        return output - -    def predict(self, image: Tensor) -> Tensor: -        """Transcribes text in image(s).""" -        bsz = image.shape[0] -        image_features = self.encode(image) - -        output_tokens = ( -            (torch.ones((bsz, self.max_output_length)) * self.pad_index) -            .type_as(image) -            .long() -        ) -        output_tokens[:, 0] = self.start_index -        for i in range(1, self.max_output_length): -            trg = output_tokens[:, :i] -            output = self.decode(image_features, trg) -            output = torch.argmax(output, dim=-1) -            output_tokens[:, i] = output[-1:] - -        # Set all tokens after end token to be padding. -        for i in range(1, self.max_output_length): -            indices = output_tokens[:, i - 1] == self.end_index | ( -                output_tokens[:, i - 1] == self.pad_index -            ) -            output_tokens[indices, i] = self.pad_index - -        return output_tokens +# """A Transformer with a cnn backbone. +# +# The network encodes a image with a convolutional backbone to a latent representation, +# i.e. feature maps. A 2d positional encoding is applied to the feature maps for +# spatial information. The resulting feature are then set to a transformer decoder +# together with the target tokens. +# +# TODO: Local attention for lower layer in attention. +# +# """ +# import importlib +# import math +# from typing import Dict, Optional, Union, Sequence, Type +# +# from einops import rearrange +# from omegaconf import DictConfig, OmegaConf +# import torch +# from torch import nn +# from torch import Tensor +# +# from text_recognizer.data.emnist import NUM_SPECIAL_TOKENS +# from text_recognizer.networks.transformer import ( +#     Decoder, +#     DecoderLayer, +#     PositionalEncoding, +#     PositionalEncoding2D, +#     target_padding_mask, +# ) +# +# NUM_WORD_PIECES = 1000 +# +# +# class CNNTransformer(nn.Module): +#     def __init__( +#         self, +#         input_dim: Sequence[int], +#         output_dims: Sequence[int], +#         encoder: Union[DictConfig, Dict], +#         vocab_size: Optional[int] = None, +#         num_decoder_layers: int = 4, +#         hidden_dim: int = 256, +#         num_heads: int = 4, +#         expansion_dim: int = 1024, +#         dropout_rate: float = 0.1, +#         transformer_activation: str = "glu", +#         *args, +#         **kwargs, +#     ) -> None: +#         super().__init__() +#         self.vocab_size = ( +#             NUM_WORD_PIECES + NUM_SPECIAL_TOKENS if vocab_size is None else vocab_size +#         ) +#         self.pad_index = 3  # TODO: fix me +#         self.hidden_dim = hidden_dim +#         self.max_output_length = output_dims[0] +# +#         # Image backbone +#         self.encoder = self._configure_encoder(encoder) +#         self.encoder_proj = nn.Conv2d(256, hidden_dim, kernel_size=1) +#         self.feature_map_encoding = PositionalEncoding2D( +#             hidden_dim=hidden_dim, max_h=input_dim[1], max_w=input_dim[2] +#         ) +# +#         # Target token embedding +#         self.trg_embedding = nn.Embedding(self.vocab_size, hidden_dim) +#         self.trg_position_encoding = PositionalEncoding( +#             hidden_dim, dropout_rate, max_len=output_dims[0] +#         ) +# +#         # Transformer decoder +#         self.decoder = Decoder( +#             decoder_layer=DecoderLayer( +#                 hidden_dim=hidden_dim, +#                 num_heads=num_heads, +#                 expansion_dim=expansion_dim, +#                 dropout_rate=dropout_rate, +#                 activation=transformer_activation, +#             ), +#             num_layers=num_decoder_layers, +#             norm=nn.LayerNorm(hidden_dim), +#         ) +# +#         # Classification head +#         self.head = nn.Linear(hidden_dim, self.vocab_size) +# +#         # Initialize weights +#         self._init_weights() +# +#     def _init_weights(self) -> None: +#         """Initialize network weights.""" +#         self.trg_embedding.weight.data.uniform_(-0.1, 0.1) +#         self.head.bias.data.zero_() +#         self.head.weight.data.uniform_(-0.1, 0.1) +# +#         nn.init.kaiming_normal_( +#             self.encoder_proj.weight.data, +#             a=0, +#             mode="fan_out", +#             nonlinearity="relu", +#         ) +#         if self.encoder_proj.bias is not None: +#             _, fan_out = nn.init._calculate_fan_in_and_fan_out( +#                 self.encoder_proj.weight.data +#             ) +#             bound = 1 / math.sqrt(fan_out) +#             nn.init.normal_(self.encoder_proj.bias, -bound, bound) +# +#     @staticmethod +#     def _configure_encoder(encoder: Union[DictConfig, Dict]) -> Type[nn.Module]: +#         encoder = OmegaConf.create(encoder) +#         args = encoder.args or {} +#         network_module = importlib.import_module("text_recognizer.networks") +#         encoder_class = getattr(network_module, encoder.type) +#         return encoder_class(**args) +# +#     def encode(self, image: Tensor) -> Tensor: +#         """Extracts image features with backbone. +# +#         Args: +#             image (Tensor): Image(s) of handwritten text. +# +#         Retuns: +#             Tensor: Image features. +# +#         Shapes: +#             - image: :math: `(B, C, H, W)` +#             - latent: :math: `(B, T, C)` +# +#         """ +#         # Extract image features. +#         image_features = self.encoder(image) +#         image_features = self.encoder_proj(image_features) +# +#         # Add 2d encoding to the feature maps. +#         image_features = self.feature_map_encoding(image_features) +# +#         # Collapse features maps height and width. +#         image_features = rearrange(image_features, "b c h w -> b (h w) c") +#         return image_features +# +#     def decode(self, memory: Tensor, trg: Tensor) -> Tensor: +#         """Decodes image features with transformer decoder.""" +#         trg_mask = target_padding_mask(trg=trg, pad_index=self.pad_index) +#         trg = self.trg_embedding(trg) * math.sqrt(self.hidden_dim) +#         trg = rearrange(trg, "b t d -> t b d") +#         trg = self.trg_position_encoding(trg) +#         trg = rearrange(trg, "t b d -> b t d") +#         out = self.decoder(trg=trg, memory=memory, trg_mask=trg_mask, memory_mask=None) +#         logits = self.head(out) +#         return logits +# +#     def forward(self, image: Tensor, trg: Tensor) -> Tensor: +#         image_features = self.encode(image) +#         output = self.decode(image_features, trg) +#         output = rearrange(output, "b t c -> b c t") +#         return output +# +#     def predict(self, image: Tensor) -> Tensor: +#         """Transcribes text in image(s).""" +#         bsz = image.shape[0] +#         image_features = self.encode(image) +# +#         output_tokens = ( +#             (torch.ones((bsz, self.max_output_length)) * self.pad_index) +#             .type_as(image) +#             .long() +#         ) +#         output_tokens[:, 0] = self.start_index +#         for i in range(1, self.max_output_length): +#             trg = output_tokens[:, :i] +#             output = self.decode(image_features, trg) +#             output = torch.argmax(output, dim=-1) +#             output_tokens[:, i] = output[-1:] +# +#         # Set all tokens after end token to be padding. +#         for i in range(1, self.max_output_length): +#             indices = output_tokens[:, i - 1] == self.end_index | ( +#                 output_tokens[:, i - 1] == self.pad_index +#             ) +#             output_tokens[indices, i] = self.pad_index +# +#         return output_tokens diff --git a/text_recognizer/networks/transformer/__init__.py b/text_recognizer/networks/transformer/__init__.py index 627fa7b..4ff48f7 100644 --- a/text_recognizer/networks/transformer/__init__.py +++ b/text_recognizer/networks/transformer/__init__.py @@ -4,4 +4,5 @@ from .positional_encoding import (      PositionalEncoding2D,      target_padding_mask,  ) -from .transformer import Decoder, DecoderLayer, Encoder, EncoderLayer, Transformer + +# from .transformer import Decoder, DecoderLayer, Encoder, EncoderLayer, Transformer diff --git a/text_recognizer/networks/transformer/transformer.py b/text_recognizer/networks/transformer/transformer.py index 5ac2787..d49c85a 100644 --- a/text_recognizer/networks/transformer/transformer.py +++ b/text_recognizer/networks/transformer/transformer.py @@ -1,260 +1,260 @@ -"""Transfomer module.""" -import copy -from typing import Dict, Optional, Type, Union - -import numpy as np -import torch -from torch import nn -from torch import Tensor -import torch.nn.functional as F - -from text_recognizer.networks.transformer.attention import MultiHeadAttention -from text_recognizer.networks.util import activation_function - - -class GEGLU(nn.Module): -    """GLU activation for improving feedforward activations.""" - -    def __init__(self, dim_in: int, dim_out: int) -> None: -        super().__init__() -        self.proj = nn.Linear(dim_in, dim_out * 2) - -    def forward(self, x: Tensor) -> Tensor: -        """Forward propagation.""" -        x, gate = self.proj(x).chunk(2, dim=-1) -        return x * F.gelu(gate) - - -def _get_clones(module: Type[nn.Module], num_layers: int) -> nn.ModuleList: -    return nn.ModuleList([copy.deepcopy(module) for _ in range(num_layers)]) - - -class _IntraLayerConnection(nn.Module): -    """Preforms the residual connection inside the transfomer blocks and applies layernorm.""" - -    def __init__(self, dropout_rate: float, hidden_dim: int) -> None: -        super().__init__() -        self.norm = nn.LayerNorm(normalized_shape=hidden_dim) -        self.dropout = nn.Dropout(p=dropout_rate) - -    def forward(self, src: Tensor, residual: Tensor) -> Tensor: -        return self.norm(self.dropout(src) + residual) - - -class FeedForward(nn.Module): -    def __init__( -        self, -        hidden_dim: int, -        expansion_dim: int, -        dropout_rate: float, -        activation: str = "relu", -    ) -> None: -        super().__init__() - -        in_projection = ( -            nn.Sequential( -                nn.Linear(hidden_dim, expansion_dim), activation_function(activation) -            ) -            if activation != "glu" -            else GEGLU(hidden_dim, expansion_dim) -        ) - -        self.layer = nn.Sequential( -            in_projection, -            nn.Dropout(p=dropout_rate), -            nn.Linear(in_features=expansion_dim, out_features=hidden_dim), -        ) - -    def forward(self, x: Tensor) -> Tensor: -        return self.layer(x) - - -class EncoderLayer(nn.Module): -    """Transfomer encoding layer.""" - -    def __init__( -        self, -        hidden_dim: int, -        num_heads: int, -        expansion_dim: int, -        dropout_rate: float, -        activation: str = "relu", -    ) -> None: -        super().__init__() -        self.self_attention = MultiHeadAttention(hidden_dim, num_heads, dropout_rate) -        self.mlp = FeedForward(hidden_dim, expansion_dim, dropout_rate, activation) -        self.block1 = _IntraLayerConnection(dropout_rate, hidden_dim) -        self.block2 = _IntraLayerConnection(dropout_rate, hidden_dim) - -    def forward(self, src: Tensor, mask: Optional[Tensor] = None) -> Tensor: -        """Forward pass through the encoder.""" -        # First block. -        # Multi head attention. -        out, _ = self.self_attention(src, src, src, mask) - -        # Add & norm. -        out = self.block1(out, src) - -        # Second block. -        # Apply 1D-convolution. -        mlp_out = self.mlp(out) - -        # Add & norm. -        out = self.block2(mlp_out, out) - -        return out - - -class Encoder(nn.Module): -    """Transfomer encoder module.""" - -    def __init__( -        self, -        num_layers: int, -        encoder_layer: Type[nn.Module], -        norm: Optional[Type[nn.Module]] = None, -    ) -> None: -        super().__init__() -        self.layers = _get_clones(encoder_layer, num_layers) -        self.norm = norm - -    def forward(self, src: Tensor, src_mask: Optional[Tensor] = None) -> Tensor: -        """Forward pass through all encoder layers.""" -        for layer in self.layers: -            src = layer(src, src_mask) - -        if self.norm is not None: -            src = self.norm(src) - -        return src - - -class DecoderLayer(nn.Module): -    """Transfomer decoder layer.""" - -    def __init__( -        self, -        hidden_dim: int, -        num_heads: int, -        expansion_dim: int, -        dropout_rate: float = 0.0, -        activation: str = "relu", -    ) -> None: -        super().__init__() -        self.hidden_dim = hidden_dim -        self.self_attention = MultiHeadAttention(hidden_dim, num_heads, dropout_rate) -        self.multihead_attention = MultiHeadAttention( -            hidden_dim, num_heads, dropout_rate -        ) -        self.mlp = FeedForward(hidden_dim, expansion_dim, dropout_rate, activation) -        self.block1 = _IntraLayerConnection(dropout_rate, hidden_dim) -        self.block2 = _IntraLayerConnection(dropout_rate, hidden_dim) -        self.block3 = _IntraLayerConnection(dropout_rate, hidden_dim) - -    def forward( -        self, -        trg: Tensor, -        memory: Tensor, -        trg_mask: Optional[Tensor] = None, -        memory_mask: Optional[Tensor] = None, -    ) -> Tensor: -        """Forward pass of the layer.""" -        out, _ = self.self_attention(trg, trg, trg, trg_mask) -        trg = self.block1(out, trg) - -        out, _ = self.multihead_attention(trg, memory, memory, memory_mask) -        trg = self.block2(out, trg) - -        out = self.mlp(trg) -        out = self.block3(out, trg) - -        return out - - -class Decoder(nn.Module): -    """Transfomer decoder module.""" - -    def __init__( -        self, -        decoder_layer: Type[nn.Module], -        num_layers: int, -        norm: Optional[Type[nn.Module]] = None, -    ) -> None: -        super().__init__() -        self.layers = _get_clones(decoder_layer, num_layers) -        self.num_layers = num_layers -        self.norm = norm - -    def forward( -        self, -        trg: Tensor, -        memory: Tensor, -        trg_mask: Optional[Tensor] = None, -        memory_mask: Optional[Tensor] = None, -    ) -> Tensor: -        """Forward pass through the decoder.""" -        for layer in self.layers: -            trg = layer(trg, memory, trg_mask, memory_mask) - -        if self.norm is not None: -            trg = self.norm(trg) - -        return trg - - -class Transformer(nn.Module): -    """Transformer network.""" - -    def __init__( -        self, -        num_encoder_layers: int, -        num_decoder_layers: int, -        hidden_dim: int, -        num_heads: int, -        expansion_dim: int, -        dropout_rate: float, -        activation: str = "relu", -    ) -> None: -        super().__init__() - -        # Configure encoder. -        encoder_norm = nn.LayerNorm(hidden_dim) -        encoder_layer = EncoderLayer( -            hidden_dim, num_heads, expansion_dim, dropout_rate, activation -        ) -        self.encoder = Encoder(num_encoder_layers, encoder_layer, encoder_norm) - -        # Configure decoder. -        decoder_norm = nn.LayerNorm(hidden_dim) -        decoder_layer = DecoderLayer( -            hidden_dim, num_heads, expansion_dim, dropout_rate, activation -        ) -        self.decoder = Decoder(decoder_layer, num_decoder_layers, decoder_norm) - -        self._reset_parameters() - -    def _reset_parameters(self) -> None: -        for p in self.parameters(): -            if p.dim() > 1: -                nn.init.xavier_uniform_(p) - -    def forward( -        self, -        src: Tensor, -        trg: Tensor, -        src_mask: Optional[Tensor] = None, -        trg_mask: Optional[Tensor] = None, -        memory_mask: Optional[Tensor] = None, -    ) -> Tensor: -        """Forward pass through the transformer.""" -        if src.shape[0] != trg.shape[0]: -            print(trg.shape) -            raise RuntimeError("The batch size of the src and trg must be the same.") -        if src.shape[2] != trg.shape[2]: -            raise RuntimeError( -                "The number of features for the src and trg must be the same." -            ) - -        memory = self.encoder(src, src_mask) -        output = self.decoder(trg, memory, trg_mask, memory_mask) -        return output +# """Transfomer module.""" +# import copy +# from typing import Dict, Optional, Type, Union +# +# import numpy as np +# import torch +# from torch import nn +# from torch import Tensor +# import torch.nn.functional as F +# +# from text_recognizer.networks.transformer.attention import MultiHeadAttention +# from text_recognizer.networks.util import activation_function +# +# +# class GEGLU(nn.Module): +#     """GLU activation for improving feedforward activations.""" +# +#     def __init__(self, dim_in: int, dim_out: int) -> None: +#         super().__init__() +#         self.proj = nn.Linear(dim_in, dim_out * 2) +# +#     def forward(self, x: Tensor) -> Tensor: +#         """Forward propagation.""" +#         x, gate = self.proj(x).chunk(2, dim=-1) +#         return x * F.gelu(gate) +# +# +# def _get_clones(module: Type[nn.Module], num_layers: int) -> nn.ModuleList: +#     return nn.ModuleList([copy.deepcopy(module) for _ in range(num_layers)]) +# +# +# class _IntraLayerConnection(nn.Module): +#     """Preforms the residual connection inside the transfomer blocks and applies layernorm.""" +# +#     def __init__(self, dropout_rate: float, hidden_dim: int) -> None: +#         super().__init__() +#         self.norm = nn.LayerNorm(normalized_shape=hidden_dim) +#         self.dropout = nn.Dropout(p=dropout_rate) +# +#     def forward(self, src: Tensor, residual: Tensor) -> Tensor: +#         return self.norm(self.dropout(src) + residual) +# +# +# class FeedForward(nn.Module): +#     def __init__( +#         self, +#         hidden_dim: int, +#         expansion_dim: int, +#         dropout_rate: float, +#         activation: str = "relu", +#     ) -> None: +#         super().__init__() +# +#         in_projection = ( +#             nn.Sequential( +#                 nn.Linear(hidden_dim, expansion_dim), activation_function(activation) +#             ) +#             if activation != "glu" +#             else GEGLU(hidden_dim, expansion_dim) +#         ) +# +#         self.layer = nn.Sequential( +#             in_projection, +#             nn.Dropout(p=dropout_rate), +#             nn.Linear(in_features=expansion_dim, out_features=hidden_dim), +#         ) +# +#     def forward(self, x: Tensor) -> Tensor: +#         return self.layer(x) +# +# +# class EncoderLayer(nn.Module): +#     """Transfomer encoding layer.""" +# +#     def __init__( +#         self, +#         hidden_dim: int, +#         num_heads: int, +#         expansion_dim: int, +#         dropout_rate: float, +#         activation: str = "relu", +#     ) -> None: +#         super().__init__() +#         self.self_attention = MultiHeadAttention(hidden_dim, num_heads, dropout_rate) +#         self.mlp = FeedForward(hidden_dim, expansion_dim, dropout_rate, activation) +#         self.block1 = _IntraLayerConnection(dropout_rate, hidden_dim) +#         self.block2 = _IntraLayerConnection(dropout_rate, hidden_dim) +# +#     def forward(self, src: Tensor, mask: Optional[Tensor] = None) -> Tensor: +#         """Forward pass through the encoder.""" +#         # First block. +#         # Multi head attention. +#         out, _ = self.self_attention(src, src, src, mask) +# +#         # Add & norm. +#         out = self.block1(out, src) +# +#         # Second block. +#         # Apply 1D-convolution. +#         mlp_out = self.mlp(out) +# +#         # Add & norm. +#         out = self.block2(mlp_out, out) +# +#         return out +# +# +# class Encoder(nn.Module): +#     """Transfomer encoder module.""" +# +#     def __init__( +#         self, +#         num_layers: int, +#         encoder_layer: Type[nn.Module], +#         norm: Optional[Type[nn.Module]] = None, +#     ) -> None: +#         super().__init__() +#         self.layers = _get_clones(encoder_layer, num_layers) +#         self.norm = norm +# +#     def forward(self, src: Tensor, src_mask: Optional[Tensor] = None) -> Tensor: +#         """Forward pass through all encoder layers.""" +#         for layer in self.layers: +#             src = layer(src, src_mask) +# +#         if self.norm is not None: +#             src = self.norm(src) +# +#         return src +# +# +# class DecoderLayer(nn.Module): +#     """Transfomer decoder layer.""" +# +#     def __init__( +#         self, +#         hidden_dim: int, +#         num_heads: int, +#         expansion_dim: int, +#         dropout_rate: float = 0.0, +#         activation: str = "relu", +#     ) -> None: +#         super().__init__() +#         self.hidden_dim = hidden_dim +#         self.self_attention = MultiHeadAttention(hidden_dim, num_heads, dropout_rate) +#         self.multihead_attention = MultiHeadAttention( +#             hidden_dim, num_heads, dropout_rate +#         ) +#         self.mlp = FeedForward(hidden_dim, expansion_dim, dropout_rate, activation) +#         self.block1 = _IntraLayerConnection(dropout_rate, hidden_dim) +#         self.block2 = _IntraLayerConnection(dropout_rate, hidden_dim) +#         self.block3 = _IntraLayerConnection(dropout_rate, hidden_dim) +# +#     def forward( +#         self, +#         trg: Tensor, +#         memory: Tensor, +#         trg_mask: Optional[Tensor] = None, +#         memory_mask: Optional[Tensor] = None, +#     ) -> Tensor: +#         """Forward pass of the layer.""" +#         out, _ = self.self_attention(trg, trg, trg, trg_mask) +#         trg = self.block1(out, trg) +# +#         out, _ = self.multihead_attention(trg, memory, memory, memory_mask) +#         trg = self.block2(out, trg) +# +#         out = self.mlp(trg) +#         out = self.block3(out, trg) +# +#         return out +# +# +# class Decoder(nn.Module): +#     """Transfomer decoder module.""" +# +#     def __init__( +#         self, +#         decoder_layer: Type[nn.Module], +#         num_layers: int, +#         norm: Optional[Type[nn.Module]] = None, +#     ) -> None: +#         super().__init__() +#         self.layers = _get_clones(decoder_layer, num_layers) +#         self.num_layers = num_layers +#         self.norm = norm +# +#     def forward( +#         self, +#         trg: Tensor, +#         memory: Tensor, +#         trg_mask: Optional[Tensor] = None, +#         memory_mask: Optional[Tensor] = None, +#     ) -> Tensor: +#         """Forward pass through the decoder.""" +#         for layer in self.layers: +#             trg = layer(trg, memory, trg_mask, memory_mask) +# +#         if self.norm is not None: +#             trg = self.norm(trg) +# +#         return trg +# +# +# class Transformer(nn.Module): +#     """Transformer network.""" +# +#     def __init__( +#         self, +#         num_encoder_layers: int, +#         num_decoder_layers: int, +#         hidden_dim: int, +#         num_heads: int, +#         expansion_dim: int, +#         dropout_rate: float, +#         activation: str = "relu", +#     ) -> None: +#         super().__init__() +# +#         # Configure encoder. +#         encoder_norm = nn.LayerNorm(hidden_dim) +#         encoder_layer = EncoderLayer( +#             hidden_dim, num_heads, expansion_dim, dropout_rate, activation +#         ) +#         self.encoder = Encoder(num_encoder_layers, encoder_layer, encoder_norm) +# +#         # Configure decoder. +#         decoder_norm = nn.LayerNorm(hidden_dim) +#         decoder_layer = DecoderLayer( +#             hidden_dim, num_heads, expansion_dim, dropout_rate, activation +#         ) +#         self.decoder = Decoder(decoder_layer, num_decoder_layers, decoder_norm) +# +#         self._reset_parameters() +# +#     def _reset_parameters(self) -> None: +#         for p in self.parameters(): +#             if p.dim() > 1: +#                 nn.init.xavier_uniform_(p) +# +#     def forward( +#         self, +#         src: Tensor, +#         trg: Tensor, +#         src_mask: Optional[Tensor] = None, +#         trg_mask: Optional[Tensor] = None, +#         memory_mask: Optional[Tensor] = None, +#     ) -> Tensor: +#         """Forward pass through the transformer.""" +#         if src.shape[0] != trg.shape[0]: +#             print(trg.shape) +#             raise RuntimeError("The batch size of the src and trg must be the same.") +#         if src.shape[2] != trg.shape[2]: +#             raise RuntimeError( +#                 "The number of features for the src and trg must be the same." +#             ) +# +#         memory = self.encoder(src, src_mask) +#         output = self.decoder(trg, memory, trg_mask, memory_mask) +#         return output |