diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-05-02 13:51:15 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-05-02 13:51:15 +0200 |
commit | 1d0977585f01c42e9f6280559a1a98037907a62e (patch) | |
tree | 7e86dd71b163f3138ed2658cb52c44e805f21539 /text_recognizer | |
parent | 58ae7154aa945cfe5a46592cc1dfb28f0a4e51b3 (diff) |
Implemented training script with hydra
Diffstat (limited to 'text_recognizer')
-rw-r--r-- | text_recognizer/data/iam_synthetic_paragraphs.py | 2 | ||||
-rw-r--r-- | text_recognizer/networks/__init__.py | 3 | ||||
-rw-r--r-- | text_recognizer/networks/cnn_transformer.py | 364 | ||||
-rw-r--r-- | text_recognizer/networks/transformer/__init__.py | 3 | ||||
-rw-r--r-- | text_recognizer/networks/transformer/transformer.py | 520 |
5 files changed, 447 insertions, 445 deletions
diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py index 78e6c05..ad6fa25 100644 --- a/text_recognizer/data/iam_synthetic_paragraphs.py +++ b/text_recognizer/data/iam_synthetic_paragraphs.py @@ -76,7 +76,7 @@ class IAMSyntheticParagraphs(IAMParagraphs): def setup(self, stage: str = None) -> None: """Loading synthetic dataset.""" - logger.info(f"IAM Synthetic dataset steup for stage {stage}") + logger.info(f"IAM Synthetic dataset steup for stage {stage}...") if stage == "fit" or stage is None: line_crops, line_labels = load_line_crops_and_labels( diff --git a/text_recognizer/networks/__init__.py b/text_recognizer/networks/__init__.py index a9117f8..d1ebf1a 100644 --- a/text_recognizer/networks/__init__.py +++ b/text_recognizer/networks/__init__.py @@ -1,4 +1,5 @@ """Network modules""" from .encoders import EfficientNet from .vqvae import VQVAE -from .cnn_transformer import CNNTransformer + +# from .cnn_transformer import CNNTransformer diff --git a/text_recognizer/networks/cnn_transformer.py b/text_recognizer/networks/cnn_transformer.py index d42c29d..80798e1 100644 --- a/text_recognizer/networks/cnn_transformer.py +++ b/text_recognizer/networks/cnn_transformer.py @@ -1,182 +1,182 @@ -"""A Transformer with a cnn backbone. - -The network encodes a image with a convolutional backbone to a latent representation, -i.e. feature maps. A 2d positional encoding is applied to the feature maps for -spatial information. The resulting feature are then set to a transformer decoder -together with the target tokens. - -TODO: Local attention for lower layer in attention. - -""" -import importlib -import math -from typing import Dict, Optional, Union, Sequence, Type - -from einops import rearrange -from omegaconf import DictConfig, OmegaConf -import torch -from torch import nn -from torch import Tensor - -from text_recognizer.data.emnist import NUM_SPECIAL_TOKENS -from text_recognizer.networks.transformer import ( - Decoder, - DecoderLayer, - PositionalEncoding, - PositionalEncoding2D, - target_padding_mask, -) - -NUM_WORD_PIECES = 1000 - - -class CNNTransformer(nn.Module): - def __init__( - self, - input_dim: Sequence[int], - output_dims: Sequence[int], - encoder: Union[DictConfig, Dict], - vocab_size: Optional[int] = None, - num_decoder_layers: int = 4, - hidden_dim: int = 256, - num_heads: int = 4, - expansion_dim: int = 1024, - dropout_rate: float = 0.1, - transformer_activation: str = "glu", - *args, - **kwargs, - ) -> None: - super().__init__() - self.vocab_size = ( - NUM_WORD_PIECES + NUM_SPECIAL_TOKENS if vocab_size is None else vocab_size - ) - self.pad_index = 3 # TODO: fix me - self.hidden_dim = hidden_dim - self.max_output_length = output_dims[0] - - # Image backbone - self.encoder = self._configure_encoder(encoder) - self.encoder_proj = nn.Conv2d(256, hidden_dim, kernel_size=1) - self.feature_map_encoding = PositionalEncoding2D( - hidden_dim=hidden_dim, max_h=input_dim[1], max_w=input_dim[2] - ) - - # Target token embedding - self.trg_embedding = nn.Embedding(self.vocab_size, hidden_dim) - self.trg_position_encoding = PositionalEncoding( - hidden_dim, dropout_rate, max_len=output_dims[0] - ) - - # Transformer decoder - self.decoder = Decoder( - decoder_layer=DecoderLayer( - hidden_dim=hidden_dim, - num_heads=num_heads, - expansion_dim=expansion_dim, - dropout_rate=dropout_rate, - activation=transformer_activation, - ), - num_layers=num_decoder_layers, - norm=nn.LayerNorm(hidden_dim), - ) - - # Classification head - self.head = nn.Linear(hidden_dim, self.vocab_size) - - # Initialize weights - self._init_weights() - - def _init_weights(self) -> None: - """Initialize network weights.""" - self.trg_embedding.weight.data.uniform_(-0.1, 0.1) - self.head.bias.data.zero_() - self.head.weight.data.uniform_(-0.1, 0.1) - - nn.init.kaiming_normal_( - self.encoder_proj.weight.data, - a=0, - mode="fan_out", - nonlinearity="relu", - ) - if self.encoder_proj.bias is not None: - _, fan_out = nn.init._calculate_fan_in_and_fan_out( - self.encoder_proj.weight.data - ) - bound = 1 / math.sqrt(fan_out) - nn.init.normal_(self.encoder_proj.bias, -bound, bound) - - @staticmethod - def _configure_encoder(encoder: Union[DictConfig, Dict]) -> Type[nn.Module]: - encoder = OmegaConf.create(encoder) - args = encoder.args or {} - network_module = importlib.import_module("text_recognizer.networks") - encoder_class = getattr(network_module, encoder.type) - return encoder_class(**args) - - def encode(self, image: Tensor) -> Tensor: - """Extracts image features with backbone. - - Args: - image (Tensor): Image(s) of handwritten text. - - Retuns: - Tensor: Image features. - - Shapes: - - image: :math: `(B, C, H, W)` - - latent: :math: `(B, T, C)` - - """ - # Extract image features. - image_features = self.encoder(image) - image_features = self.encoder_proj(image_features) - - # Add 2d encoding to the feature maps. - image_features = self.feature_map_encoding(image_features) - - # Collapse features maps height and width. - image_features = rearrange(image_features, "b c h w -> b (h w) c") - return image_features - - def decode(self, memory: Tensor, trg: Tensor) -> Tensor: - """Decodes image features with transformer decoder.""" - trg_mask = target_padding_mask(trg=trg, pad_index=self.pad_index) - trg = self.trg_embedding(trg) * math.sqrt(self.hidden_dim) - trg = rearrange(trg, "b t d -> t b d") - trg = self.trg_position_encoding(trg) - trg = rearrange(trg, "t b d -> b t d") - out = self.decoder(trg=trg, memory=memory, trg_mask=trg_mask, memory_mask=None) - logits = self.head(out) - return logits - - def forward(self, image: Tensor, trg: Tensor) -> Tensor: - image_features = self.encode(image) - output = self.decode(image_features, trg) - output = rearrange(output, "b t c -> b c t") - return output - - def predict(self, image: Tensor) -> Tensor: - """Transcribes text in image(s).""" - bsz = image.shape[0] - image_features = self.encode(image) - - output_tokens = ( - (torch.ones((bsz, self.max_output_length)) * self.pad_index) - .type_as(image) - .long() - ) - output_tokens[:, 0] = self.start_index - for i in range(1, self.max_output_length): - trg = output_tokens[:, :i] - output = self.decode(image_features, trg) - output = torch.argmax(output, dim=-1) - output_tokens[:, i] = output[-1:] - - # Set all tokens after end token to be padding. - for i in range(1, self.max_output_length): - indices = output_tokens[:, i - 1] == self.end_index | ( - output_tokens[:, i - 1] == self.pad_index - ) - output_tokens[indices, i] = self.pad_index - - return output_tokens +# """A Transformer with a cnn backbone. +# +# The network encodes a image with a convolutional backbone to a latent representation, +# i.e. feature maps. A 2d positional encoding is applied to the feature maps for +# spatial information. The resulting feature are then set to a transformer decoder +# together with the target tokens. +# +# TODO: Local attention for lower layer in attention. +# +# """ +# import importlib +# import math +# from typing import Dict, Optional, Union, Sequence, Type +# +# from einops import rearrange +# from omegaconf import DictConfig, OmegaConf +# import torch +# from torch import nn +# from torch import Tensor +# +# from text_recognizer.data.emnist import NUM_SPECIAL_TOKENS +# from text_recognizer.networks.transformer import ( +# Decoder, +# DecoderLayer, +# PositionalEncoding, +# PositionalEncoding2D, +# target_padding_mask, +# ) +# +# NUM_WORD_PIECES = 1000 +# +# +# class CNNTransformer(nn.Module): +# def __init__( +# self, +# input_dim: Sequence[int], +# output_dims: Sequence[int], +# encoder: Union[DictConfig, Dict], +# vocab_size: Optional[int] = None, +# num_decoder_layers: int = 4, +# hidden_dim: int = 256, +# num_heads: int = 4, +# expansion_dim: int = 1024, +# dropout_rate: float = 0.1, +# transformer_activation: str = "glu", +# *args, +# **kwargs, +# ) -> None: +# super().__init__() +# self.vocab_size = ( +# NUM_WORD_PIECES + NUM_SPECIAL_TOKENS if vocab_size is None else vocab_size +# ) +# self.pad_index = 3 # TODO: fix me +# self.hidden_dim = hidden_dim +# self.max_output_length = output_dims[0] +# +# # Image backbone +# self.encoder = self._configure_encoder(encoder) +# self.encoder_proj = nn.Conv2d(256, hidden_dim, kernel_size=1) +# self.feature_map_encoding = PositionalEncoding2D( +# hidden_dim=hidden_dim, max_h=input_dim[1], max_w=input_dim[2] +# ) +# +# # Target token embedding +# self.trg_embedding = nn.Embedding(self.vocab_size, hidden_dim) +# self.trg_position_encoding = PositionalEncoding( +# hidden_dim, dropout_rate, max_len=output_dims[0] +# ) +# +# # Transformer decoder +# self.decoder = Decoder( +# decoder_layer=DecoderLayer( +# hidden_dim=hidden_dim, +# num_heads=num_heads, +# expansion_dim=expansion_dim, +# dropout_rate=dropout_rate, +# activation=transformer_activation, +# ), +# num_layers=num_decoder_layers, +# norm=nn.LayerNorm(hidden_dim), +# ) +# +# # Classification head +# self.head = nn.Linear(hidden_dim, self.vocab_size) +# +# # Initialize weights +# self._init_weights() +# +# def _init_weights(self) -> None: +# """Initialize network weights.""" +# self.trg_embedding.weight.data.uniform_(-0.1, 0.1) +# self.head.bias.data.zero_() +# self.head.weight.data.uniform_(-0.1, 0.1) +# +# nn.init.kaiming_normal_( +# self.encoder_proj.weight.data, +# a=0, +# mode="fan_out", +# nonlinearity="relu", +# ) +# if self.encoder_proj.bias is not None: +# _, fan_out = nn.init._calculate_fan_in_and_fan_out( +# self.encoder_proj.weight.data +# ) +# bound = 1 / math.sqrt(fan_out) +# nn.init.normal_(self.encoder_proj.bias, -bound, bound) +# +# @staticmethod +# def _configure_encoder(encoder: Union[DictConfig, Dict]) -> Type[nn.Module]: +# encoder = OmegaConf.create(encoder) +# args = encoder.args or {} +# network_module = importlib.import_module("text_recognizer.networks") +# encoder_class = getattr(network_module, encoder.type) +# return encoder_class(**args) +# +# def encode(self, image: Tensor) -> Tensor: +# """Extracts image features with backbone. +# +# Args: +# image (Tensor): Image(s) of handwritten text. +# +# Retuns: +# Tensor: Image features. +# +# Shapes: +# - image: :math: `(B, C, H, W)` +# - latent: :math: `(B, T, C)` +# +# """ +# # Extract image features. +# image_features = self.encoder(image) +# image_features = self.encoder_proj(image_features) +# +# # Add 2d encoding to the feature maps. +# image_features = self.feature_map_encoding(image_features) +# +# # Collapse features maps height and width. +# image_features = rearrange(image_features, "b c h w -> b (h w) c") +# return image_features +# +# def decode(self, memory: Tensor, trg: Tensor) -> Tensor: +# """Decodes image features with transformer decoder.""" +# trg_mask = target_padding_mask(trg=trg, pad_index=self.pad_index) +# trg = self.trg_embedding(trg) * math.sqrt(self.hidden_dim) +# trg = rearrange(trg, "b t d -> t b d") +# trg = self.trg_position_encoding(trg) +# trg = rearrange(trg, "t b d -> b t d") +# out = self.decoder(trg=trg, memory=memory, trg_mask=trg_mask, memory_mask=None) +# logits = self.head(out) +# return logits +# +# def forward(self, image: Tensor, trg: Tensor) -> Tensor: +# image_features = self.encode(image) +# output = self.decode(image_features, trg) +# output = rearrange(output, "b t c -> b c t") +# return output +# +# def predict(self, image: Tensor) -> Tensor: +# """Transcribes text in image(s).""" +# bsz = image.shape[0] +# image_features = self.encode(image) +# +# output_tokens = ( +# (torch.ones((bsz, self.max_output_length)) * self.pad_index) +# .type_as(image) +# .long() +# ) +# output_tokens[:, 0] = self.start_index +# for i in range(1, self.max_output_length): +# trg = output_tokens[:, :i] +# output = self.decode(image_features, trg) +# output = torch.argmax(output, dim=-1) +# output_tokens[:, i] = output[-1:] +# +# # Set all tokens after end token to be padding. +# for i in range(1, self.max_output_length): +# indices = output_tokens[:, i - 1] == self.end_index | ( +# output_tokens[:, i - 1] == self.pad_index +# ) +# output_tokens[indices, i] = self.pad_index +# +# return output_tokens diff --git a/text_recognizer/networks/transformer/__init__.py b/text_recognizer/networks/transformer/__init__.py index 627fa7b..4ff48f7 100644 --- a/text_recognizer/networks/transformer/__init__.py +++ b/text_recognizer/networks/transformer/__init__.py @@ -4,4 +4,5 @@ from .positional_encoding import ( PositionalEncoding2D, target_padding_mask, ) -from .transformer import Decoder, DecoderLayer, Encoder, EncoderLayer, Transformer + +# from .transformer import Decoder, DecoderLayer, Encoder, EncoderLayer, Transformer diff --git a/text_recognizer/networks/transformer/transformer.py b/text_recognizer/networks/transformer/transformer.py index 5ac2787..d49c85a 100644 --- a/text_recognizer/networks/transformer/transformer.py +++ b/text_recognizer/networks/transformer/transformer.py @@ -1,260 +1,260 @@ -"""Transfomer module.""" -import copy -from typing import Dict, Optional, Type, Union - -import numpy as np -import torch -from torch import nn -from torch import Tensor -import torch.nn.functional as F - -from text_recognizer.networks.transformer.attention import MultiHeadAttention -from text_recognizer.networks.util import activation_function - - -class GEGLU(nn.Module): - """GLU activation for improving feedforward activations.""" - - def __init__(self, dim_in: int, dim_out: int) -> None: - super().__init__() - self.proj = nn.Linear(dim_in, dim_out * 2) - - def forward(self, x: Tensor) -> Tensor: - """Forward propagation.""" - x, gate = self.proj(x).chunk(2, dim=-1) - return x * F.gelu(gate) - - -def _get_clones(module: Type[nn.Module], num_layers: int) -> nn.ModuleList: - return nn.ModuleList([copy.deepcopy(module) for _ in range(num_layers)]) - - -class _IntraLayerConnection(nn.Module): - """Preforms the residual connection inside the transfomer blocks and applies layernorm.""" - - def __init__(self, dropout_rate: float, hidden_dim: int) -> None: - super().__init__() - self.norm = nn.LayerNorm(normalized_shape=hidden_dim) - self.dropout = nn.Dropout(p=dropout_rate) - - def forward(self, src: Tensor, residual: Tensor) -> Tensor: - return self.norm(self.dropout(src) + residual) - - -class FeedForward(nn.Module): - def __init__( - self, - hidden_dim: int, - expansion_dim: int, - dropout_rate: float, - activation: str = "relu", - ) -> None: - super().__init__() - - in_projection = ( - nn.Sequential( - nn.Linear(hidden_dim, expansion_dim), activation_function(activation) - ) - if activation != "glu" - else GEGLU(hidden_dim, expansion_dim) - ) - - self.layer = nn.Sequential( - in_projection, - nn.Dropout(p=dropout_rate), - nn.Linear(in_features=expansion_dim, out_features=hidden_dim), - ) - - def forward(self, x: Tensor) -> Tensor: - return self.layer(x) - - -class EncoderLayer(nn.Module): - """Transfomer encoding layer.""" - - def __init__( - self, - hidden_dim: int, - num_heads: int, - expansion_dim: int, - dropout_rate: float, - activation: str = "relu", - ) -> None: - super().__init__() - self.self_attention = MultiHeadAttention(hidden_dim, num_heads, dropout_rate) - self.mlp = FeedForward(hidden_dim, expansion_dim, dropout_rate, activation) - self.block1 = _IntraLayerConnection(dropout_rate, hidden_dim) - self.block2 = _IntraLayerConnection(dropout_rate, hidden_dim) - - def forward(self, src: Tensor, mask: Optional[Tensor] = None) -> Tensor: - """Forward pass through the encoder.""" - # First block. - # Multi head attention. - out, _ = self.self_attention(src, src, src, mask) - - # Add & norm. - out = self.block1(out, src) - - # Second block. - # Apply 1D-convolution. - mlp_out = self.mlp(out) - - # Add & norm. - out = self.block2(mlp_out, out) - - return out - - -class Encoder(nn.Module): - """Transfomer encoder module.""" - - def __init__( - self, - num_layers: int, - encoder_layer: Type[nn.Module], - norm: Optional[Type[nn.Module]] = None, - ) -> None: - super().__init__() - self.layers = _get_clones(encoder_layer, num_layers) - self.norm = norm - - def forward(self, src: Tensor, src_mask: Optional[Tensor] = None) -> Tensor: - """Forward pass through all encoder layers.""" - for layer in self.layers: - src = layer(src, src_mask) - - if self.norm is not None: - src = self.norm(src) - - return src - - -class DecoderLayer(nn.Module): - """Transfomer decoder layer.""" - - def __init__( - self, - hidden_dim: int, - num_heads: int, - expansion_dim: int, - dropout_rate: float = 0.0, - activation: str = "relu", - ) -> None: - super().__init__() - self.hidden_dim = hidden_dim - self.self_attention = MultiHeadAttention(hidden_dim, num_heads, dropout_rate) - self.multihead_attention = MultiHeadAttention( - hidden_dim, num_heads, dropout_rate - ) - self.mlp = FeedForward(hidden_dim, expansion_dim, dropout_rate, activation) - self.block1 = _IntraLayerConnection(dropout_rate, hidden_dim) - self.block2 = _IntraLayerConnection(dropout_rate, hidden_dim) - self.block3 = _IntraLayerConnection(dropout_rate, hidden_dim) - - def forward( - self, - trg: Tensor, - memory: Tensor, - trg_mask: Optional[Tensor] = None, - memory_mask: Optional[Tensor] = None, - ) -> Tensor: - """Forward pass of the layer.""" - out, _ = self.self_attention(trg, trg, trg, trg_mask) - trg = self.block1(out, trg) - - out, _ = self.multihead_attention(trg, memory, memory, memory_mask) - trg = self.block2(out, trg) - - out = self.mlp(trg) - out = self.block3(out, trg) - - return out - - -class Decoder(nn.Module): - """Transfomer decoder module.""" - - def __init__( - self, - decoder_layer: Type[nn.Module], - num_layers: int, - norm: Optional[Type[nn.Module]] = None, - ) -> None: - super().__init__() - self.layers = _get_clones(decoder_layer, num_layers) - self.num_layers = num_layers - self.norm = norm - - def forward( - self, - trg: Tensor, - memory: Tensor, - trg_mask: Optional[Tensor] = None, - memory_mask: Optional[Tensor] = None, - ) -> Tensor: - """Forward pass through the decoder.""" - for layer in self.layers: - trg = layer(trg, memory, trg_mask, memory_mask) - - if self.norm is not None: - trg = self.norm(trg) - - return trg - - -class Transformer(nn.Module): - """Transformer network.""" - - def __init__( - self, - num_encoder_layers: int, - num_decoder_layers: int, - hidden_dim: int, - num_heads: int, - expansion_dim: int, - dropout_rate: float, - activation: str = "relu", - ) -> None: - super().__init__() - - # Configure encoder. - encoder_norm = nn.LayerNorm(hidden_dim) - encoder_layer = EncoderLayer( - hidden_dim, num_heads, expansion_dim, dropout_rate, activation - ) - self.encoder = Encoder(num_encoder_layers, encoder_layer, encoder_norm) - - # Configure decoder. - decoder_norm = nn.LayerNorm(hidden_dim) - decoder_layer = DecoderLayer( - hidden_dim, num_heads, expansion_dim, dropout_rate, activation - ) - self.decoder = Decoder(decoder_layer, num_decoder_layers, decoder_norm) - - self._reset_parameters() - - def _reset_parameters(self) -> None: - for p in self.parameters(): - if p.dim() > 1: - nn.init.xavier_uniform_(p) - - def forward( - self, - src: Tensor, - trg: Tensor, - src_mask: Optional[Tensor] = None, - trg_mask: Optional[Tensor] = None, - memory_mask: Optional[Tensor] = None, - ) -> Tensor: - """Forward pass through the transformer.""" - if src.shape[0] != trg.shape[0]: - print(trg.shape) - raise RuntimeError("The batch size of the src and trg must be the same.") - if src.shape[2] != trg.shape[2]: - raise RuntimeError( - "The number of features for the src and trg must be the same." - ) - - memory = self.encoder(src, src_mask) - output = self.decoder(trg, memory, trg_mask, memory_mask) - return output +# """Transfomer module.""" +# import copy +# from typing import Dict, Optional, Type, Union +# +# import numpy as np +# import torch +# from torch import nn +# from torch import Tensor +# import torch.nn.functional as F +# +# from text_recognizer.networks.transformer.attention import MultiHeadAttention +# from text_recognizer.networks.util import activation_function +# +# +# class GEGLU(nn.Module): +# """GLU activation for improving feedforward activations.""" +# +# def __init__(self, dim_in: int, dim_out: int) -> None: +# super().__init__() +# self.proj = nn.Linear(dim_in, dim_out * 2) +# +# def forward(self, x: Tensor) -> Tensor: +# """Forward propagation.""" +# x, gate = self.proj(x).chunk(2, dim=-1) +# return x * F.gelu(gate) +# +# +# def _get_clones(module: Type[nn.Module], num_layers: int) -> nn.ModuleList: +# return nn.ModuleList([copy.deepcopy(module) for _ in range(num_layers)]) +# +# +# class _IntraLayerConnection(nn.Module): +# """Preforms the residual connection inside the transfomer blocks and applies layernorm.""" +# +# def __init__(self, dropout_rate: float, hidden_dim: int) -> None: +# super().__init__() +# self.norm = nn.LayerNorm(normalized_shape=hidden_dim) +# self.dropout = nn.Dropout(p=dropout_rate) +# +# def forward(self, src: Tensor, residual: Tensor) -> Tensor: +# return self.norm(self.dropout(src) + residual) +# +# +# class FeedForward(nn.Module): +# def __init__( +# self, +# hidden_dim: int, +# expansion_dim: int, +# dropout_rate: float, +# activation: str = "relu", +# ) -> None: +# super().__init__() +# +# in_projection = ( +# nn.Sequential( +# nn.Linear(hidden_dim, expansion_dim), activation_function(activation) +# ) +# if activation != "glu" +# else GEGLU(hidden_dim, expansion_dim) +# ) +# +# self.layer = nn.Sequential( +# in_projection, +# nn.Dropout(p=dropout_rate), +# nn.Linear(in_features=expansion_dim, out_features=hidden_dim), +# ) +# +# def forward(self, x: Tensor) -> Tensor: +# return self.layer(x) +# +# +# class EncoderLayer(nn.Module): +# """Transfomer encoding layer.""" +# +# def __init__( +# self, +# hidden_dim: int, +# num_heads: int, +# expansion_dim: int, +# dropout_rate: float, +# activation: str = "relu", +# ) -> None: +# super().__init__() +# self.self_attention = MultiHeadAttention(hidden_dim, num_heads, dropout_rate) +# self.mlp = FeedForward(hidden_dim, expansion_dim, dropout_rate, activation) +# self.block1 = _IntraLayerConnection(dropout_rate, hidden_dim) +# self.block2 = _IntraLayerConnection(dropout_rate, hidden_dim) +# +# def forward(self, src: Tensor, mask: Optional[Tensor] = None) -> Tensor: +# """Forward pass through the encoder.""" +# # First block. +# # Multi head attention. +# out, _ = self.self_attention(src, src, src, mask) +# +# # Add & norm. +# out = self.block1(out, src) +# +# # Second block. +# # Apply 1D-convolution. +# mlp_out = self.mlp(out) +# +# # Add & norm. +# out = self.block2(mlp_out, out) +# +# return out +# +# +# class Encoder(nn.Module): +# """Transfomer encoder module.""" +# +# def __init__( +# self, +# num_layers: int, +# encoder_layer: Type[nn.Module], +# norm: Optional[Type[nn.Module]] = None, +# ) -> None: +# super().__init__() +# self.layers = _get_clones(encoder_layer, num_layers) +# self.norm = norm +# +# def forward(self, src: Tensor, src_mask: Optional[Tensor] = None) -> Tensor: +# """Forward pass through all encoder layers.""" +# for layer in self.layers: +# src = layer(src, src_mask) +# +# if self.norm is not None: +# src = self.norm(src) +# +# return src +# +# +# class DecoderLayer(nn.Module): +# """Transfomer decoder layer.""" +# +# def __init__( +# self, +# hidden_dim: int, +# num_heads: int, +# expansion_dim: int, +# dropout_rate: float = 0.0, +# activation: str = "relu", +# ) -> None: +# super().__init__() +# self.hidden_dim = hidden_dim +# self.self_attention = MultiHeadAttention(hidden_dim, num_heads, dropout_rate) +# self.multihead_attention = MultiHeadAttention( +# hidden_dim, num_heads, dropout_rate +# ) +# self.mlp = FeedForward(hidden_dim, expansion_dim, dropout_rate, activation) +# self.block1 = _IntraLayerConnection(dropout_rate, hidden_dim) +# self.block2 = _IntraLayerConnection(dropout_rate, hidden_dim) +# self.block3 = _IntraLayerConnection(dropout_rate, hidden_dim) +# +# def forward( +# self, +# trg: Tensor, +# memory: Tensor, +# trg_mask: Optional[Tensor] = None, +# memory_mask: Optional[Tensor] = None, +# ) -> Tensor: +# """Forward pass of the layer.""" +# out, _ = self.self_attention(trg, trg, trg, trg_mask) +# trg = self.block1(out, trg) +# +# out, _ = self.multihead_attention(trg, memory, memory, memory_mask) +# trg = self.block2(out, trg) +# +# out = self.mlp(trg) +# out = self.block3(out, trg) +# +# return out +# +# +# class Decoder(nn.Module): +# """Transfomer decoder module.""" +# +# def __init__( +# self, +# decoder_layer: Type[nn.Module], +# num_layers: int, +# norm: Optional[Type[nn.Module]] = None, +# ) -> None: +# super().__init__() +# self.layers = _get_clones(decoder_layer, num_layers) +# self.num_layers = num_layers +# self.norm = norm +# +# def forward( +# self, +# trg: Tensor, +# memory: Tensor, +# trg_mask: Optional[Tensor] = None, +# memory_mask: Optional[Tensor] = None, +# ) -> Tensor: +# """Forward pass through the decoder.""" +# for layer in self.layers: +# trg = layer(trg, memory, trg_mask, memory_mask) +# +# if self.norm is not None: +# trg = self.norm(trg) +# +# return trg +# +# +# class Transformer(nn.Module): +# """Transformer network.""" +# +# def __init__( +# self, +# num_encoder_layers: int, +# num_decoder_layers: int, +# hidden_dim: int, +# num_heads: int, +# expansion_dim: int, +# dropout_rate: float, +# activation: str = "relu", +# ) -> None: +# super().__init__() +# +# # Configure encoder. +# encoder_norm = nn.LayerNorm(hidden_dim) +# encoder_layer = EncoderLayer( +# hidden_dim, num_heads, expansion_dim, dropout_rate, activation +# ) +# self.encoder = Encoder(num_encoder_layers, encoder_layer, encoder_norm) +# +# # Configure decoder. +# decoder_norm = nn.LayerNorm(hidden_dim) +# decoder_layer = DecoderLayer( +# hidden_dim, num_heads, expansion_dim, dropout_rate, activation +# ) +# self.decoder = Decoder(decoder_layer, num_decoder_layers, decoder_norm) +# +# self._reset_parameters() +# +# def _reset_parameters(self) -> None: +# for p in self.parameters(): +# if p.dim() > 1: +# nn.init.xavier_uniform_(p) +# +# def forward( +# self, +# src: Tensor, +# trg: Tensor, +# src_mask: Optional[Tensor] = None, +# trg_mask: Optional[Tensor] = None, +# memory_mask: Optional[Tensor] = None, +# ) -> Tensor: +# """Forward pass through the transformer.""" +# if src.shape[0] != trg.shape[0]: +# print(trg.shape) +# raise RuntimeError("The batch size of the src and trg must be the same.") +# if src.shape[2] != trg.shape[2]: +# raise RuntimeError( +# "The number of features for the src and trg must be the same." +# ) +# +# memory = self.encoder(src, src_mask) +# output = self.decoder(trg, memory, trg_mask, memory_mask) +# return output |