From c9c60678673e19ad3367339eb8e7a093e5a98474 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sun, 9 May 2021 22:46:09 +0200 Subject: Reformatting of positional encodings and ViT working --- .../networks/transformer/transformer.py | 321 ++++----------------- 1 file changed, 61 insertions(+), 260 deletions(-) (limited to 'text_recognizer/networks/transformer/transformer.py') diff --git a/text_recognizer/networks/transformer/transformer.py b/text_recognizer/networks/transformer/transformer.py index d49c85a..36f86ac 100644 --- a/text_recognizer/networks/transformer/transformer.py +++ b/text_recognizer/networks/transformer/transformer.py @@ -1,260 +1,61 @@ -# """Transfomer module.""" -# import copy -# from typing import Dict, Optional, Type, Union -# -# import numpy as np -# import torch -# from torch import nn -# from torch import Tensor -# import torch.nn.functional as F -# -# from text_recognizer.networks.transformer.attention import MultiHeadAttention -# from text_recognizer.networks.util import activation_function -# -# -# class GEGLU(nn.Module): -# """GLU activation for improving feedforward activations.""" -# -# def __init__(self, dim_in: int, dim_out: int) -> None: -# super().__init__() -# self.proj = nn.Linear(dim_in, dim_out * 2) -# -# def forward(self, x: Tensor) -> Tensor: -# """Forward propagation.""" -# x, gate = self.proj(x).chunk(2, dim=-1) -# return x * F.gelu(gate) -# -# -# def _get_clones(module: Type[nn.Module], num_layers: int) -> nn.ModuleList: -# return nn.ModuleList([copy.deepcopy(module) for _ in range(num_layers)]) -# -# -# class _IntraLayerConnection(nn.Module): -# """Preforms the residual connection inside the transfomer blocks and applies layernorm.""" -# -# def __init__(self, dropout_rate: float, hidden_dim: int) -> None: -# super().__init__() -# self.norm = nn.LayerNorm(normalized_shape=hidden_dim) -# self.dropout = nn.Dropout(p=dropout_rate) -# -# def forward(self, src: Tensor, residual: Tensor) -> Tensor: -# return self.norm(self.dropout(src) + residual) -# -# -# class FeedForward(nn.Module): -# def __init__( -# self, -# hidden_dim: int, -# expansion_dim: int, -# dropout_rate: float, -# activation: str = "relu", -# ) -> None: -# super().__init__() -# -# in_projection = ( -# nn.Sequential( -# nn.Linear(hidden_dim, expansion_dim), activation_function(activation) -# ) -# if activation != "glu" -# else GEGLU(hidden_dim, expansion_dim) -# ) -# -# self.layer = nn.Sequential( -# in_projection, -# nn.Dropout(p=dropout_rate), -# nn.Linear(in_features=expansion_dim, out_features=hidden_dim), -# ) -# -# def forward(self, x: Tensor) -> Tensor: -# return self.layer(x) -# -# -# class EncoderLayer(nn.Module): -# """Transfomer encoding layer.""" -# -# def __init__( -# self, -# hidden_dim: int, -# num_heads: int, -# expansion_dim: int, -# dropout_rate: float, -# activation: str = "relu", -# ) -> None: -# super().__init__() -# self.self_attention = MultiHeadAttention(hidden_dim, num_heads, dropout_rate) -# self.mlp = FeedForward(hidden_dim, expansion_dim, dropout_rate, activation) -# self.block1 = _IntraLayerConnection(dropout_rate, hidden_dim) -# self.block2 = _IntraLayerConnection(dropout_rate, hidden_dim) -# -# def forward(self, src: Tensor, mask: Optional[Tensor] = None) -> Tensor: -# """Forward pass through the encoder.""" -# # First block. -# # Multi head attention. -# out, _ = self.self_attention(src, src, src, mask) -# -# # Add & norm. -# out = self.block1(out, src) -# -# # Second block. -# # Apply 1D-convolution. -# mlp_out = self.mlp(out) -# -# # Add & norm. -# out = self.block2(mlp_out, out) -# -# return out -# -# -# class Encoder(nn.Module): -# """Transfomer encoder module.""" -# -# def __init__( -# self, -# num_layers: int, -# encoder_layer: Type[nn.Module], -# norm: Optional[Type[nn.Module]] = None, -# ) -> None: -# super().__init__() -# self.layers = _get_clones(encoder_layer, num_layers) -# self.norm = norm -# -# def forward(self, src: Tensor, src_mask: Optional[Tensor] = None) -> Tensor: -# """Forward pass through all encoder layers.""" -# for layer in self.layers: -# src = layer(src, src_mask) -# -# if self.norm is not None: -# src = self.norm(src) -# -# return src -# -# -# class DecoderLayer(nn.Module): -# """Transfomer decoder layer.""" -# -# def __init__( -# self, -# hidden_dim: int, -# num_heads: int, -# expansion_dim: int, -# dropout_rate: float = 0.0, -# activation: str = "relu", -# ) -> None: -# super().__init__() -# self.hidden_dim = hidden_dim -# self.self_attention = MultiHeadAttention(hidden_dim, num_heads, dropout_rate) -# self.multihead_attention = MultiHeadAttention( -# hidden_dim, num_heads, dropout_rate -# ) -# self.mlp = FeedForward(hidden_dim, expansion_dim, dropout_rate, activation) -# self.block1 = _IntraLayerConnection(dropout_rate, hidden_dim) -# self.block2 = _IntraLayerConnection(dropout_rate, hidden_dim) -# self.block3 = _IntraLayerConnection(dropout_rate, hidden_dim) -# -# def forward( -# self, -# trg: Tensor, -# memory: Tensor, -# trg_mask: Optional[Tensor] = None, -# memory_mask: Optional[Tensor] = None, -# ) -> Tensor: -# """Forward pass of the layer.""" -# out, _ = self.self_attention(trg, trg, trg, trg_mask) -# trg = self.block1(out, trg) -# -# out, _ = self.multihead_attention(trg, memory, memory, memory_mask) -# trg = self.block2(out, trg) -# -# out = self.mlp(trg) -# out = self.block3(out, trg) -# -# return out -# -# -# class Decoder(nn.Module): -# """Transfomer decoder module.""" -# -# def __init__( -# self, -# decoder_layer: Type[nn.Module], -# num_layers: int, -# norm: Optional[Type[nn.Module]] = None, -# ) -> None: -# super().__init__() -# self.layers = _get_clones(decoder_layer, num_layers) -# self.num_layers = num_layers -# self.norm = norm -# -# def forward( -# self, -# trg: Tensor, -# memory: Tensor, -# trg_mask: Optional[Tensor] = None, -# memory_mask: Optional[Tensor] = None, -# ) -> Tensor: -# """Forward pass through the decoder.""" -# for layer in self.layers: -# trg = layer(trg, memory, trg_mask, memory_mask) -# -# if self.norm is not None: -# trg = self.norm(trg) -# -# return trg -# -# -# class Transformer(nn.Module): -# """Transformer network.""" -# -# def __init__( -# self, -# num_encoder_layers: int, -# num_decoder_layers: int, -# hidden_dim: int, -# num_heads: int, -# expansion_dim: int, -# dropout_rate: float, -# activation: str = "relu", -# ) -> None: -# super().__init__() -# -# # Configure encoder. -# encoder_norm = nn.LayerNorm(hidden_dim) -# encoder_layer = EncoderLayer( -# hidden_dim, num_heads, expansion_dim, dropout_rate, activation -# ) -# self.encoder = Encoder(num_encoder_layers, encoder_layer, encoder_norm) -# -# # Configure decoder. -# decoder_norm = nn.LayerNorm(hidden_dim) -# decoder_layer = DecoderLayer( -# hidden_dim, num_heads, expansion_dim, dropout_rate, activation -# ) -# self.decoder = Decoder(decoder_layer, num_decoder_layers, decoder_norm) -# -# self._reset_parameters() -# -# def _reset_parameters(self) -> None: -# for p in self.parameters(): -# if p.dim() > 1: -# nn.init.xavier_uniform_(p) -# -# def forward( -# self, -# src: Tensor, -# trg: Tensor, -# src_mask: Optional[Tensor] = None, -# trg_mask: Optional[Tensor] = None, -# memory_mask: Optional[Tensor] = None, -# ) -> Tensor: -# """Forward pass through the transformer.""" -# if src.shape[0] != trg.shape[0]: -# print(trg.shape) -# raise RuntimeError("The batch size of the src and trg must be the same.") -# if src.shape[2] != trg.shape[2]: -# raise RuntimeError( -# "The number of features for the src and trg must be the same." -# ) -# -# memory = self.encoder(src, src_mask) -# output = self.decoder(trg, memory, trg_mask, memory_mask) -# return output +"""Transformer wrapper.""" +from typing import Optional, Type + +from torch import nn, Tensor + +from .layers import AttentionLayers +from text_recognizer.networks.transformer.positional_encodings import ( + AbsolutePositionalEmbedding, +) + + +class Transformer(nn.Module): + def __init__( + self, + num_tokens: int, + max_seq_len: int, + attn_layers: Type[AttentionLayers], + emb_dim: Optional[int] = None, + emb_dropout: float = 0.0, + use_pos_emb: bool = True, + ) -> None: + dim = attn_layers.dim + emb_dim = emb_dim if emb_dim is not None else dim + self.max_seq_len = max_seq_len + + self.token_emb = nn.Embedding(num_tokens, emb_dim) + self.emb_dropout = nn.Dropout(emb_dropout) + self.pos_emb = ( + AbsolutePositionalEmbedding(emb_dim, max_seq_len) + if (use_pos_emb and not self.attn_layers.has_pos_emb) + else None + ) + + self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity() + self.attn_layers = attn_layers + self.norm = nn.LayerNorm(dim) + + self._init_weights() + + self.logits = nn.Linear(dim, num_tokens) + + def _init_weights(self) -> None: + nn.init.normal_(self.token_emb.weight, std=0.02) + + def forward( + self, + x: Tensor, + mask: Optional[Tensor], + return_embeddings: bool = False, + **kwargs: Any + ) -> Tensor: + b, n, device = *x.shape, x.device + x += self.token_emb(x) + if self.pos_emb is not None: + x += self.pos_emb(x) + x = self.emb_dropout(x) + + x = self.project_emb(x) + x = self.attn_layers(x, mask=mask, **kwargs) + out = self.logits(x) if not return_embeddings else x + return x -- cgit v1.2.3-70-g09d2