diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-09-30 00:35:34 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-09-30 00:35:34 +0200 |
commit | bace9540792d517c1687d91f455f9503d9854609 (patch) | |
tree | 38aff5e65a9f7410930aa8a6a92bf7ecd7c89bd5 | |
parent | 9e85e0883f2e921ca9a57cb2fd93ec47a2535d59 (diff) |
Update conv transformer model
-rw-r--r-- | notebooks/04-conv-transformer.ipynb | 180 | ||||
-rw-r--r-- | text_recognizer/networks/conv_transformer.py | 122 | ||||
-rw-r--r-- | text_recognizer/networks/image_encoder.py | 45 | ||||
-rw-r--r-- | text_recognizer/networks/text_decoder.py | 62 | ||||
-rw-r--r-- | training/conf/experiment/conv_transformer_lines.yaml | 136 | ||||
-rw-r--r-- | training/conf/network/conv_transformer.yaml | 109 |
6 files changed, 303 insertions, 351 deletions
diff --git a/notebooks/04-conv-transformer.ipynb b/notebooks/04-conv-transformer.ipynb index b864098..0d8b370 100644 --- a/notebooks/04-conv-transformer.ipynb +++ b/notebooks/04-conv-transformer.ipynb @@ -2,19 +2,10 @@ "cells": [ { "cell_type": "code", - "execution_count": 8, + "execution_count": 1, "id": "7c02ae76-b540-4b16-9492-e9210b3b9249", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "The autoreload extension is already loaded. To reload it, use:\n", - " %reload_ext autoreload\n" - ] - } - ], + "outputs": [], "source": [ "import os\n", "os.environ['CUDA_VISIBLE_DEVICE'] = ''\n", @@ -49,7 +40,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 5, "id": "3cf50475-39f2-4642-a7d1-5bcbc0a036f7", "metadata": {}, "outputs": [], @@ -59,7 +50,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 6, "id": "e52ecb01-c975-4e55-925d-1182c7aea473", "metadata": {}, "outputs": [], @@ -70,17 +61,17 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 7, "id": "f939aa37-7b1d-45cc-885c-323c4540bda1", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "{'_target_': 'text_recognizer.networks.ConvTransformer', 'input_dims': [1, 1, 576, 640], 'hidden_dim': 128, 'num_classes': 58, 'pad_index': 3, 'encoder': {'_target_': 'text_recognizer.networks.convnext.ConvNext', 'dim': 16, 'dim_mults': [2, 4, 8], 'depths': [3, 3, 6], 'downsampling_factors': [[2, 2], [2, 2], [2, 2]]}, 'decoder': {'_target_': 'text_recognizer.networks.transformer.Decoder', 'dim': 128, 'depth': 10, 'block': {'_target_': 'text_recognizer.networks.transformer.decoder_block.DecoderBlock', 'self_attn': {'_target_': 'text_recognizer.networks.transformer.Attention', 'dim': 128, 'num_heads': 12, 'dim_head': 64, 'dropout_rate': 0.2, 'causal': True, 'rotary_embedding': {'_target_': 'text_recognizer.networks.transformer.RotaryEmbedding', 'dim': 64}}, 'cross_attn': {'_target_': 'text_recognizer.networks.transformer.Attention', 'dim': 128, 'num_heads': 12, 'dim_head': 64, 'dropout_rate': 0.2, 'causal': False}, 'norm': {'_target_': 'text_recognizer.networks.transformer.RMSNorm', 'dim': 128}, 'ff': {'_target_': 'text_recognizer.networks.transformer.FeedForward', 'dim': 128, 'dim_out': None, 'expansion_factor': 2, 'glu': True, 'dropout_rate': 0.2}}}, 'pixel_embedding': {'_target_': 'text_recognizer.networks.transformer.embeddings.axial.AxialPositionalEmbeddingImage', 'dim': 128, 'axial_shape': [7, 128], 'axial_dims': [64, 64]}, 'token_pos_embedding': {'_target_': 'text_recognizer.networks.transformer.embeddings.fourier.PositionalEncoding', 'dim': 128, 'dropout_rate': 0.1, 'max_len': 89}}" + "{'_target_': 'text_recognizer.networks.ConvTransformer', 'encoder': {'_target_': 'text_recognizer.networks.image_encoder.ImageEncoder', 'encoder': {'_target_': 'text_recognizer.networks.convnext.ConvNext', 'dim': 16, 'dim_mults': [2, 4, 8], 'depths': [3, 3, 6], 'downsampling_factors': [[2, 2], [2, 2], [2, 2]]}, 'pixel_embedding': {'_target_': 'text_recognizer.networks.transformer.embeddings.axial.AxialPositionalEmbeddingImage', 'dim': 128, 'axial_shape': [7, 128], 'axial_dims': [64, 64]}}, 'decoder': {'_target_': 'text_recognizer.networks.text_decoder.TextDecoder', 'hidden_dim': 128, 'num_classes': 58, 'pad_index': 3, 'decoder': {'_target_': 'text_recognizer.networks.transformer.Decoder', 'dim': 128, 'depth': 10, 'block': {'_target_': 'text_recognizer.networks.transformer.decoder_block.DecoderBlock', 'self_attn': {'_target_': 'text_recognizer.networks.transformer.Attention', 'dim': 128, 'num_heads': 12, 'dim_head': 64, 'dropout_rate': 0.2, 'causal': True}, 'cross_attn': {'_target_': 'text_recognizer.networks.transformer.Attention', 'dim': 128, 'num_heads': 12, 'dim_head': 64, 'dropout_rate': 0.2, 'causal': False}, 'norm': {'_target_': 'text_recognizer.networks.transformer.RMSNorm', 'dim': 128}, 'ff': {'_target_': 'text_recognizer.networks.transformer.FeedForward', 'dim': 128, 'dim_out': None, 'expansion_factor': 2, 'glu': True, 'dropout_rate': 0.2}}, 'rotary_embedding': {'_target_': 'text_recognizer.networks.transformer.RotaryEmbedding', 'dim': 64}}, 'token_pos_embedding': {'_target_': 'text_recognizer.networks.transformer.embeddings.fourier.PositionalEncoding', 'dim': 128, 'dropout_rate': 0.1, 'max_len': 89}}}" ] }, - "execution_count": 5, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -91,7 +82,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 13, "id": "aaeab329-aeb0-4a1b-aa35-5a2aab81b1d0", "metadata": { "scrolled": false @@ -103,7 +94,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 14, "id": "618b997c-e6a6-4487-b70c-9d260cb556d3", "metadata": {}, "outputs": [], @@ -113,125 +104,60 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 15, "id": "7daf1f49", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "=========================================================================================================\n", - "Layer (type:depth-idx) Output Shape Param #\n", - "=========================================================================================================\n", - "ConvTransformer [1, 58, 89] --\n", - "├─ConvNext: 1-1 [1, 128, 7, 128] 1,051,488\n", - "│ └─Conv2d: 2-1 [1, 16, 56, 1024] 800\n", - "│ └─ModuleList: 2 -- --\n", - "│ │ └─ModuleList: 3 -- --\n", - "│ │ │ └─ConvNextBlock: 4-1 [1, 16, 56, 1024] 10,080\n", - "│ │ │ └─Downsample: 4-2 [1, 32, 28, 512] 2,080\n", - "│ │ └─ModuleList: 3 -- --\n", - "│ │ │ └─ConvNextBlock: 4-3 [1, 32, 28, 512] 38,592\n", - "│ │ │ └─Downsample: 4-4 [1, 64, 14, 256] 8,256\n", - "│ │ └─ModuleList: 3 -- --\n", - "│ │ │ └─ConvNextBlock: 4-5 [1, 64, 14, 256] 150,912\n", - "│ │ │ └─Downsample: 4-6 [1, 128, 7, 128] 32,896\n", - "│ └─Identity: 2-2 [1, 128, 7, 128] --\n", - "│ └─LayerNorm: 2-3 [1, 128, 7, 128] 128\n", - "├─Conv2d: 1-2 [1, 128, 7, 128] 16,512\n", - "├─AxialPositionalEmbeddingImage: 1-3 [1, 128, 7, 128] --\n", - "│ └─AxialPositionalEmbedding: 2-4 [1, 896, 128] 8,640\n", - "├─Embedding: 1-4 [1, 89, 128] 7,424\n", - "├─PositionalEncoding: 1-5 [1, 89, 128] --\n", - "│ └─Dropout: 2-5 [1, 89, 128] --\n", - "├─Decoder: 1-6 [1, 89, 128] --\n", - "│ └─ModuleList: 2 -- --\n", - "│ │ └─DecoderBlock: 3-1 [1, 89, 128] --\n", - "│ │ │ └─RMSNorm: 4-7 [1, 89, 128] 128\n", - "│ │ │ └─Attention: 4-8 [1, 89, 128] 393,344\n", - "│ │ │ └─RMSNorm: 4-9 [1, 89, 128] 128\n", - "│ │ │ └─Attention: 4-10 [1, 89, 128] 393,344\n", - "│ │ │ └─RMSNorm: 4-11 [1, 89, 128] 128\n", - "│ │ │ └─FeedForward: 4-12 [1, 89, 128] 98,944\n", - "│ │ └─DecoderBlock: 3-2 [1, 89, 128] --\n", - "│ │ │ └─RMSNorm: 4-13 [1, 89, 128] 128\n", - "│ │ │ └─Attention: 4-14 [1, 89, 128] 393,344\n", - "│ │ │ └─RMSNorm: 4-15 [1, 89, 128] 128\n", - "│ │ │ └─Attention: 4-16 [1, 89, 128] 393,344\n", - "│ │ │ └─RMSNorm: 4-17 [1, 89, 128] 128\n", - "│ │ │ └─FeedForward: 4-18 [1, 89, 128] 98,944\n", - "│ │ └─DecoderBlock: 3-3 [1, 89, 128] --\n", - "│ │ │ └─RMSNorm: 4-19 [1, 89, 128] 128\n", - "│ │ │ └─Attention: 4-20 [1, 89, 128] 393,344\n", - "│ │ │ └─RMSNorm: 4-21 [1, 89, 128] 128\n", - "│ │ │ └─Attention: 4-22 [1, 89, 128] 393,344\n", - "│ │ │ └─RMSNorm: 4-23 [1, 89, 128] 128\n", - "│ │ │ └─FeedForward: 4-24 [1, 89, 128] 98,944\n", - "│ │ └─DecoderBlock: 3-4 [1, 89, 128] --\n", - "│ │ │ └─RMSNorm: 4-25 [1, 89, 128] 128\n", - "│ │ │ └─Attention: 4-26 [1, 89, 128] 393,344\n", - "│ │ │ └─RMSNorm: 4-27 [1, 89, 128] 128\n", - "│ │ │ └─Attention: 4-28 [1, 89, 128] 393,344\n", - "│ │ │ └─RMSNorm: 4-29 [1, 89, 128] 128\n", - "│ │ │ └─FeedForward: 4-30 [1, 89, 128] 98,944\n", - "│ │ └─DecoderBlock: 3-5 [1, 89, 128] --\n", - "│ │ │ └─RMSNorm: 4-31 [1, 89, 128] 128\n", - "│ │ │ └─Attention: 4-32 [1, 89, 128] 393,344\n", - "│ │ │ └─RMSNorm: 4-33 [1, 89, 128] 128\n", - "│ │ │ └─Attention: 4-34 [1, 89, 128] 393,344\n", - "│ │ │ └─RMSNorm: 4-35 [1, 89, 128] 128\n", - "│ │ │ └─FeedForward: 4-36 [1, 89, 128] 98,944\n", - "│ │ └─DecoderBlock: 3-6 [1, 89, 128] --\n", - "│ │ │ └─RMSNorm: 4-37 [1, 89, 128] 128\n", - "│ │ │ └─Attention: 4-38 [1, 89, 128] 393,344\n", - "│ │ │ └─RMSNorm: 4-39 [1, 89, 128] 128\n", - "│ │ │ └─Attention: 4-40 [1, 89, 128] 393,344\n", - "│ │ │ └─RMSNorm: 4-41 [1, 89, 128] 128\n", - "│ │ │ └─FeedForward: 4-42 [1, 89, 128] 98,944\n", - "│ │ └─DecoderBlock: 3-7 [1, 89, 128] --\n", - "│ │ │ └─RMSNorm: 4-43 [1, 89, 128] 128\n", - "│ │ │ └─Attention: 4-44 [1, 89, 128] 393,344\n", - "│ │ │ └─RMSNorm: 4-45 [1, 89, 128] 128\n", - "│ │ │ └─Attention: 4-46 [1, 89, 128] 393,344\n", - "│ │ │ └─RMSNorm: 4-47 [1, 89, 128] 128\n", - "│ │ │ └─FeedForward: 4-48 [1, 89, 128] 98,944\n", - "│ │ └─DecoderBlock: 3-8 [1, 89, 128] --\n", - "│ │ │ └─RMSNorm: 4-49 [1, 89, 128] 128\n", - "│ │ │ └─Attention: 4-50 [1, 89, 128] 393,344\n", - "│ │ │ └─RMSNorm: 4-51 [1, 89, 128] 128\n", - "│ │ │ └─Attention: 4-52 [1, 89, 128] 393,344\n", - "│ │ │ └─RMSNorm: 4-53 [1, 89, 128] 128\n", - "│ │ │ └─FeedForward: 4-54 [1, 89, 128] 98,944\n", - "│ │ └─DecoderBlock: 3-9 [1, 89, 128] --\n", - "│ │ │ └─RMSNorm: 4-55 [1, 89, 128] 128\n", - "│ │ │ └─Attention: 4-56 [1, 89, 128] 393,344\n", - "│ │ │ └─RMSNorm: 4-57 [1, 89, 128] 128\n", - "│ │ │ └─Attention: 4-58 [1, 89, 128] 393,344\n", - "│ │ │ └─RMSNorm: 4-59 [1, 89, 128] 128\n", - "│ │ │ └─FeedForward: 4-60 [1, 89, 128] 98,944\n", - "│ │ └─DecoderBlock: 3-10 [1, 89, 128] --\n", - "│ │ │ └─RMSNorm: 4-61 [1, 89, 128] 128\n", - "│ │ │ └─Attention: 4-62 [1, 89, 128] 393,344\n", - "│ │ │ └─RMSNorm: 4-63 [1, 89, 128] 128\n", - "│ │ │ └─Attention: 4-64 [1, 89, 128] 393,344\n", - "│ │ │ └─RMSNorm: 4-65 [1, 89, 128] 128\n", - "│ │ │ └─FeedForward: 4-66 [1, 89, 128] 98,944\n", - "│ └─LayerNorm: 2-6 [1, 89, 128] 256\n", - "├─Linear: 1-7 [1, 89, 58] 7,482\n", - "=========================================================================================================\n", - "Total params: 10,195,706\n", - "Trainable params: 10,195,706\n", + "==============================================================================================================\n", + "Layer (type:depth-idx) Output Shape Param #\n", + "==============================================================================================================\n", + "ConvTransformer [1, 58, 89] --\n", + "├─ImageEncoder: 1-1 [1, 896, 128] --\n", + "│ └─ConvNext: 2-1 [1, 128, 7, 128] --\n", + "│ │ └─Conv2d: 3-1 [1, 16, 56, 1024] 800\n", + "│ │ └─ModuleList: 3-2 -- --\n", + "│ │ │ └─ModuleList: 4-1 -- 42,400\n", + "│ │ │ └─ModuleList: 4-2 -- 162,624\n", + "│ │ │ └─ModuleList: 4-3 -- 1,089,280\n", + "│ │ └─Identity: 3-3 [1, 128, 7, 128] --\n", + "│ │ └─LayerNorm: 3-4 [1, 128, 7, 128] 128\n", + "│ └─AxialPositionalEmbeddingImage: 2-2 [1, 128, 7, 128] --\n", + "│ │ └─AxialPositionalEmbedding: 3-5 [1, 896, 128] 8,640\n", + "├─TextDecoder: 1-2 [1, 58, 89] --\n", + "│ └─Embedding: 2-3 [1, 89, 128] 7,424\n", + "│ └─PositionalEncoding: 2-4 [1, 89, 128] --\n", + "│ │ └─Dropout: 3-6 [1, 89, 128] --\n", + "│ └─Decoder: 2-5 [1, 89, 128] --\n", + "│ │ └─ModuleList: 3-7 -- --\n", + "│ │ │ └─DecoderBlock: 4-4 [1, 89, 128] 525,568\n", + "│ │ │ └─DecoderBlock: 4-5 [1, 89, 128] 525,568\n", + "│ │ │ └─DecoderBlock: 4-6 [1, 89, 128] 525,568\n", + "│ │ │ └─DecoderBlock: 4-7 [1, 89, 128] 525,568\n", + "│ │ │ └─DecoderBlock: 4-8 [1, 89, 128] 525,568\n", + "│ │ │ └─DecoderBlock: 4-9 [1, 89, 128] 525,568\n", + "│ │ │ └─DecoderBlock: 4-10 [1, 89, 128] 525,568\n", + "│ │ │ └─DecoderBlock: 4-11 [1, 89, 128] 525,568\n", + "│ │ │ └─DecoderBlock: 4-12 [1, 89, 128] 525,568\n", + "│ │ │ └─DecoderBlock: 4-13 [1, 89, 128] 525,568\n", + "│ │ └─LayerNorm: 3-8 [1, 89, 128] 256\n", + "│ └─Linear: 2-6 [1, 89, 58] 7,482\n", + "==============================================================================================================\n", + "Total params: 6,574,714\n", + "Trainable params: 6,574,714\n", "Non-trainable params: 0\n", - "Total mult-adds (G): 8.47\n", - "=========================================================================================================\n", + "Total mult-adds (G): 8.45\n", + "==============================================================================================================\n", "Input size (MB): 0.23\n", - "Forward/backward pass size (MB): 442.25\n", - "Params size (MB): 40.78\n", - "Estimated Total Size (MB): 483.26\n", - "=========================================================================================================" + "Forward/backward pass size (MB): 330.38\n", + "Params size (MB): 26.30\n", + "Estimated Total Size (MB): 356.91\n", + "==============================================================================================================" ] }, - "execution_count": 16, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py index 6d54918..e36a786 100644 --- a/text_recognizer/networks/conv_transformer.py +++ b/text_recognizer/networks/conv_transformer.py @@ -1,13 +1,10 @@ """Base network module.""" -from typing import Optional, Tuple, Type +from typing import Type import torch from torch import Tensor, nn from text_recognizer.networks.transformer.decoder import Decoder -from text_recognizer.networks.transformer.embeddings.axial import ( - AxialPositionalEmbeddingImage, -) class ConvTransformer(nn.Module): @@ -15,124 +12,39 @@ class ConvTransformer(nn.Module): def __init__( self, - input_dims: Tuple[int, int, int], - hidden_dim: int, - num_classes: int, - pad_index: Tensor, encoder: Type[nn.Module], decoder: Decoder, - pixel_embedding: AxialPositionalEmbeddingImage, - token_pos_embedding: Type[nn.Module], ) -> None: super().__init__() - self.input_dims = input_dims - self.hidden_dim = hidden_dim - self.num_classes = num_classes - self.pad_index = pad_index self.encoder = encoder self.decoder = decoder - # Token embedding. - self.token_embedding = nn.Embedding( - num_embeddings=self.num_classes, embedding_dim=self.hidden_dim - ) + def encode(self, img: Tensor) -> Tensor: + """Encodes images to latent representation.""" + return self.encoder(img) - # Positional encoding for decoder tokens. - self.token_pos_embedding = token_pos_embedding - self.pixel_embedding = pixel_embedding + def decode(self, tokens: Tensor, img_features: Tensor) -> Tensor: + """Decodes latent images embedding into characters.""" + return self.decoder(tokens, img_features) - # Latent projector for down sampling number of filters and 2d - # positional encoding. - self.conv = nn.Conv2d( - in_channels=self.encoder.out_channels, - out_channels=self.hidden_dim, - kernel_size=1, - ) - - # Output layer - self.to_logits = nn.Linear( - in_features=self.hidden_dim, out_features=self.num_classes - ) - - # Initalize weights for encoder. - self.init_weights() - - def init_weights(self) -> None: - """Initalize weights for decoder network and to_logits.""" - nn.init.kaiming_normal_(self.token_embedding.weight) - - def encode(self, x: Tensor) -> Tensor: - """Encodes an image into a latent feature vector. - - Args: - x (Tensor): Image tensor. - - Shape: - - x: :math: `(B, C, H, W)` - - z: :math: `(B, Sx, E)` - - where Sx is the length of the flattened feature maps projected from - the encoder. E latent dimension for each pixel in the projected - feature maps. - - Returns: - Tensor: A Latent embedding of the image. - """ - z = self.encoder(x) - z = self.conv(z) - z = z + self.pixel_embedding(z) - z = z.flatten(start_dim=2) - - # Permute tensor from [B, E, Ho * Wo] to [B, Sx, E] - z = z.permute(0, 2, 1) - return z - - def decode(self, src: Tensor, trg: Tensor) -> Tensor: - """Decodes latent images embedding into word pieces. - - Args: - src (Tensor): Latent images embedding. - trg (Tensor): Word embeddings. - - Shapes: - - z: :math: `(B, Sx, D)` - - context: :math: `(B, Sy)` - - out: :math: `(B, Sy, C)` - - where Sy is the length of the output and C is the number of classes. - - Returns: - Tensor: Sequence of word piece embeddings. - """ - trg = trg.long() - trg_mask = trg != self.pad_index - trg = self.token_embedding(trg) - trg = trg + self.token_pos_embedding(trg) - out = self.decoder(x=trg, context=src, input_mask=trg_mask) - logits = ( - out @ torch.transpose(self.token_embedding.weight.to(trg.dtype), 0, 1) - ).float() - logits = self.to_logits(out) # [B, Sy, C] - logits = logits.permute(0, 2, 1) # [B, C, Sy] - return logits - - def forward(self, x: Tensor, context: Tensor) -> Tensor: + def forward(self, img: Tensor, tokens: Tensor) -> Tensor: """Encodes images into word piece logtis. Args: - x (Tensor): Input image(s). - context (Tensor): Target word embeddings. + img (Tensor): Input image(s). + tokens (Tensor): Target word embeddings. Shapes: - - x: :math: `(B, D, H, W)` - - context: :math: `(B, Sy, C)` + - img: :math: `(B, 1, H, W)` + - tokens: :math: `(B, Sy)` + - logits: :math: `(B, Sy, C)` - where B is the batch size, D is the number of input channels, H is - the image height, W is the image width, and C is the number of classes. + where B is the batch size, H is the image height, W is the image + width, Sy the output length, and C is the number of classes. Returns: Tensor: Sequence of logits. """ - z = self.encode(x) - logits = self.decode(z, context) + img_features = self.encode(img) + logits = self.decode(tokens, img_features) return logits diff --git a/text_recognizer/networks/image_encoder.py b/text_recognizer/networks/image_encoder.py new file mode 100644 index 0000000..b5fd0c5 --- /dev/null +++ b/text_recognizer/networks/image_encoder.py @@ -0,0 +1,45 @@ +"""Encodes images to latent embeddings.""" +from typing import Tuple, Type + +from torch import Tensor, nn + +from text_recognizer.networks.transformer.embeddings.axial import ( + AxialPositionalEmbeddingImage, +) + + +class ImageEncoder(nn.Module): + """Base transformer network.""" + + def __init__( + self, + encoder: Type[nn.Module], + pixel_embedding: AxialPositionalEmbeddingImage, + ) -> None: + super().__init__() + self.encoder = encoder + self.pixel_embedding = pixel_embedding + + def forward(self, img: Tensor) -> Tensor: + """Encodes an image into a latent feature vector. + + Args: + img (Tensor): Image tensor. + + Shape: + - x: :math: `(B, C, H, W)` + - z: :math: `(B, Sx, D)` + + where Sx is the length of the flattened feature maps projected from + the encoder. D latent dimension for each pixel in the projected + feature maps. + + Returns: + Tensor: A Latent embedding of the image. + """ + z = self.encoder(img) + z = z + self.pixel_embedding(z) + z = z.flatten(start_dim=2) + # Permute tensor from [B, E, Ho * Wo] to [B, Sx, E] + z = z.permute(0, 2, 1) + return z diff --git a/text_recognizer/networks/text_decoder.py b/text_recognizer/networks/text_decoder.py new file mode 100644 index 0000000..c054b41 --- /dev/null +++ b/text_recognizer/networks/text_decoder.py @@ -0,0 +1,62 @@ +"""Text decoder.""" +from typing import Type + +import torch +from torch import Tensor, nn + +from text_recognizer.networks.transformer.decoder import Decoder + + +class TextDecoder(nn.Module): + """Decoder transformer network.""" + + def __init__( + self, + hidden_dim: int, + num_classes: int, + pad_index: Tensor, + decoder: Decoder, + token_pos_embedding: Type[nn.Module], + ) -> None: + super().__init__() + self.hidden_dim = hidden_dim + self.num_classes = num_classes + self.pad_index = pad_index + self.decoder = decoder + self.token_embedding = nn.Embedding( + num_embeddings=self.num_classes, embedding_dim=self.hidden_dim + ) + self.token_pos_embedding = token_pos_embedding + self.to_logits = nn.Linear( + in_features=self.hidden_dim, out_features=self.num_classes + ) + + def forward(self, tokens: Tensor, img_features: Tensor) -> Tensor: + """Decodes latent images embedding into word pieces. + + Args: + tokens (Tensor): Token indecies. + img_features (Tensor): Latent images embedding. + + Shapes: + - tokens: :math: `(B, Sy)` + - img_features: :math: `(B, Sx, D)` + - logits: :math: `(B, Sy, C)` + + where Sy is the length of the output, C is the number of classes + and D is the hidden dimension. + + Returns: + Tensor: Sequence of logits. + """ + tokens = tokens.long() + mask = tokens != self.pad_index + tokens = self.token_embedding(tokens) + tokens = tokens + self.token_pos_embedding(tokens) + tokens = self.decoder(x=tokens, context=img_features, mask=mask) + logits = ( + tokens @ torch.transpose(self.token_embedding.weight.to(tokens.dtype), 0, 1) + ).float() + logits = self.to_logits(tokens) # [B, Sy, C] + logits = logits.permute(0, 2, 1) # [B, C, Sy] + return logits diff --git a/training/conf/experiment/conv_transformer_lines.yaml b/training/conf/experiment/conv_transformer_lines.yaml index e0e426c..d32c7d6 100644 --- a/training/conf/experiment/conv_transformer_lines.yaml +++ b/training/conf/experiment/conv_transformer_lines.yaml @@ -12,7 +12,7 @@ defaults: tags: [lines] epochs: &epochs 260 ignore_index: &ignore_index 3 -num_classes: &num_classes 57 +num_classes: &num_classes 58 max_output_len: &max_output_len 89 # summary: [[1, 1, 56, 1024], [1, 89]] @@ -35,7 +35,7 @@ callbacks: optimizer: _target_: adan_pytorch.Adan - lr: 1.0e-3 + lr: 3.0e-4 betas: [0.02, 0.08, 0.01] weight_decay: 0.02 @@ -59,73 +59,77 @@ datamodule: network: _target_: text_recognizer.networks.ConvTransformer - input_dims: [1, 1, 56, 1024] - hidden_dim: &hidden_dim 384 - num_classes: 58 - pad_index: 3 encoder: - _target_: text_recognizer.networks.convnext.ConvNext - dim: 16 - dim_mults: [2, 4, 24] - depths: [3, 3, 6] - downsampling_factors: [[2, 2], [2, 2], [2, 2]] - attn: - _target_: text_recognizer.networks.convnext.TransformerBlock + _target_: text_recognizer.networks.image_encoder.ImageEncoder + encoder: + _target_: text_recognizer.networks.convnext.ConvNext + dim: 16 + dim_mults: [2, 4, 24] + depths: [3, 3, 6] + downsampling_factors: [[2, 2], [2, 2], [2, 2]] attn: - _target_: text_recognizer.networks.convnext.Attention - dim: *hidden_dim - heads: 4 - dim_head: 64 - scale: 8 - ff: - _target_: text_recognizer.networks.convnext.FeedForward - dim: *hidden_dim - mult: 2 + _target_: text_recognizer.networks.convnext.TransformerBlock + attn: + _target_: text_recognizer.networks.convnext.Attention + dim: *hidden_dim + heads: 4 + dim_head: 64 + scale: 8 + ff: + _target_: text_recognizer.networks.convnext.FeedForward + dim: *hidden_dim + mult: 2 + pixel_embedding: + _target_: "text_recognizer.networks.transformer.embeddings.axial.\ + AxialPositionalEmbeddingImage" + dim: &hidden_dim 384 + axial_shape: [7, 128] + axial_dims: [192, 192] decoder: - _target_: text_recognizer.networks.transformer.Decoder - depth: 6 - dim: *hidden_dim - block: - _target_: text_recognizer.networks.transformer.decoder_block.DecoderBlock - self_attn: - _target_: text_recognizer.networks.transformer.Attention - dim: *hidden_dim - num_heads: 8 - dim_head: 64 - dropout_rate: &dropout_rate 0.2 - causal: true - rotary_embedding: - _target_: text_recognizer.networks.transformer.RotaryEmbedding - dim: 64 - cross_attn: - _target_: text_recognizer.networks.transformer.Attention - dim: *hidden_dim - num_heads: 8 - dim_head: 64 - dropout_rate: *dropout_rate - causal: false - norm: - _target_: text_recognizer.networks.transformer.RMSNorm - dim: *hidden_dim - ff: - _target_: text_recognizer.networks.transformer.FeedForward - dim: *hidden_dim - dim_out: null - expansion_factor: 2 - glu: true - dropout_rate: *dropout_rate - pixel_embedding: - _target_: "text_recognizer.networks.transformer.embeddings.axial.\ - AxialPositionalEmbeddingImage" - dim: *hidden_dim - axial_shape: [7, 128] - axial_dims: [192, 192] - token_pos_embedding: - _target_: "text_recognizer.networks.transformer.embeddings.fourier.\ - PositionalEncoding" - dim: *hidden_dim - dropout_rate: 0.1 - max_len: 89 + _target_: text_recognizer.networks.text_decoder.TextDecoder + hidden_dim: *hidden_dim + num_classes: *num_classes + pad_index: *ignore_index + decoder: + _target_: text_recognizer.networks.transformer.Decoder + dim: *hidden_dim + depth: 6 + block: + _target_: text_recognizer.networks.transformer.decoder_block.\ + DecoderBlock + self_attn: + _target_: text_recognizer.networks.transformer.Attention + dim: *hidden_dim + num_heads: 10 + dim_head: 64 + dropout_rate: &dropout_rate 0.2 + causal: true + cross_attn: + _target_: text_recognizer.networks.transformer.Attention + dim: *hidden_dim + num_heads: 10 + dim_head: 64 + dropout_rate: *dropout_rate + causal: false + norm: + _target_: text_recognizer.networks.transformer.RMSNorm + dim: *hidden_dim + ff: + _target_: text_recognizer.networks.transformer.FeedForward + dim: *hidden_dim + dim_out: null + expansion_factor: 2 + glu: true + dropout_rate: *dropout_rate + rotary_embedding: + _target_: text_recognizer.networks.transformer.RotaryEmbedding + dim: 64 + token_pos_embedding: + _target_: "text_recognizer.networks.transformer.embeddings.fourier.\ + PositionalEncoding" + dim: *hidden_dim + dropout_rate: 0.1 + max_len: *max_output_len model: max_output_len: *max_output_len diff --git a/training/conf/network/conv_transformer.yaml b/training/conf/network/conv_transformer.yaml index 0ef862f..016adbb 100644 --- a/training/conf/network/conv_transformer.yaml +++ b/training/conf/network/conv_transformer.yaml @@ -1,56 +1,59 @@ _target_: text_recognizer.networks.ConvTransformer -input_dims: [1, 1, 576, 640] -hidden_dim: &hidden_dim 128 -num_classes: 58 -pad_index: 3 encoder: - _target_: text_recognizer.networks.convnext.ConvNext - dim: 16 - dim_mults: [2, 4, 8] - depths: [3, 3, 6] - downsampling_factors: [[2, 2], [2, 2], [2, 2]] + _target_: text_recognizer.networks.image_encoder.ImageEncoder + encoder: + _target_: text_recognizer.networks.convnext.ConvNext + dim: 16 + dim_mults: [2, 4, 8] + depths: [3, 3, 6] + downsampling_factors: [[2, 2], [2, 2], [2, 2]] + pixel_embedding: + _target_: "text_recognizer.networks.transformer.embeddings.axial.\ + AxialPositionalEmbeddingImage" + dim: &hidden_dim 128 + axial_shape: [7, 128] + axial_dims: [64, 64] decoder: - _target_: text_recognizer.networks.transformer.Decoder - dim: *hidden_dim - depth: 10 - block: - _target_: text_recognizer.networks.transformer.decoder_block.DecoderBlock - self_attn: - _target_: text_recognizer.networks.transformer.Attention - dim: *hidden_dim - num_heads: 12 - dim_head: 64 - dropout_rate: &dropout_rate 0.2 - causal: true - rotary_embedding: - _target_: text_recognizer.networks.transformer.RotaryEmbedding - dim: 64 - cross_attn: - _target_: text_recognizer.networks.transformer.Attention - dim: *hidden_dim - num_heads: 12 - dim_head: 64 - dropout_rate: *dropout_rate - causal: false - norm: - _target_: text_recognizer.networks.transformer.RMSNorm - dim: *hidden_dim - ff: - _target_: text_recognizer.networks.transformer.FeedForward - dim: *hidden_dim - dim_out: null - expansion_factor: 2 - glu: true - dropout_rate: *dropout_rate -pixel_embedding: - _target_: "text_recognizer.networks.transformer.embeddings.axial.\ - AxialPositionalEmbeddingImage" - dim: *hidden_dim - axial_shape: [7, 128] - axial_dims: [64, 64] -token_pos_embedding: - _target_: "text_recognizer.networks.transformer.embeddings.fourier.\ - PositionalEncoding" - dim: *hidden_dim - dropout_rate: 0.1 - max_len: 89 + _target_: text_recognizer.networks.text_decoder.TextDecoder + hidden_dim: *hidden_dim + num_classes: 58 + pad_index: 3 + decoder: + _target_: text_recognizer.networks.transformer.Decoder + dim: *hidden_dim + depth: 10 + block: + _target_: text_recognizer.networks.transformer.decoder_block.DecoderBlock + self_attn: + _target_: text_recognizer.networks.transformer.Attention + dim: *hidden_dim + num_heads: 12 + dim_head: 64 + dropout_rate: &dropout_rate 0.2 + causal: true + cross_attn: + _target_: text_recognizer.networks.transformer.Attention + dim: *hidden_dim + num_heads: 12 + dim_head: 64 + dropout_rate: *dropout_rate + causal: false + norm: + _target_: text_recognizer.networks.transformer.RMSNorm + dim: *hidden_dim + ff: + _target_: text_recognizer.networks.transformer.FeedForward + dim: *hidden_dim + dim_out: null + expansion_factor: 2 + glu: true + dropout_rate: *dropout_rate + rotary_embedding: + _target_: text_recognizer.networks.transformer.RotaryEmbedding + dim: 64 + token_pos_embedding: + _target_: "text_recognizer.networks.transformer.embeddings.fourier.\ + PositionalEncoding" + dim: *hidden_dim + dropout_rate: 0.1 + max_len: 89 |