diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-25 23:32:50 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-25 23:32:50 +0200 |
commit | 9426cc794d8c28a65bbbf5ae5466a0a343078558 (patch) | |
tree | 44e31b0a7c58597d603ac29a693462aae4b6e9b0 | |
parent | 4e60c836fb710baceba570c28c06437db3ad5c9b (diff) |
Efficient net and non working transformer model.
24 files changed, 474 insertions, 637 deletions
diff --git a/notebooks/00-testing-stuff-out.ipynb b/notebooks/00-testing-stuff-out.ipynb index e6cf099..7c7b3a6 100644 --- a/notebooks/00-testing-stuff-out.ipynb +++ b/notebooks/00-testing-stuff-out.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 4, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -25,7 +25,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -34,7 +34,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -43,7 +43,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -52,110 +52,16 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "seed: 4711\n", - "network:\n", - " desc: Configuration of the PyTorch neural network.\n", - " type: VQVAE\n", - " args:\n", - " in_channels: 1\n", - " channels:\n", - " - 32\n", - " - 64\n", - " - 96\n", - " - 96\n", - " - 128\n", - " kernel_sizes:\n", - " - 4\n", - " - 4\n", - " - 4\n", - " - 4\n", - " - 4\n", - " strides:\n", - " - 2\n", - " - 2\n", - " - 2\n", - " - 2\n", - " - 2\n", - " num_residual_layers: 2\n", - " embedding_dim: 128\n", - " num_embeddings: 1024\n", - " upsampling: null\n", - " beta: 0.25\n", - " activation: leaky_relu\n", - " dropout_rate: 0.1\n", - "model:\n", - " desc: Configuration of the PyTorch Lightning model.\n", - " type: LitVQVAEModel\n", - " args:\n", - " optimizer:\n", - " type: MADGRAD\n", - " args:\n", - " lr: 0.001\n", - " momentum: 0.9\n", - " weight_decay: 0\n", - " eps: 1.0e-06\n", - " lr_scheduler:\n", - " type: OneCycleLR\n", - " args:\n", - " interval: step\n", - " max_lr: 0.001\n", - " three_phase: true\n", - " epochs: 1024\n", - " steps_per_epoch: 317\n", - " criterion:\n", - " type: MSELoss\n", - " args:\n", - " reduction: mean\n", - " monitor: val_loss\n", - " mapping: sentence_piece\n", - "data:\n", - " desc: Configuration of the training/test data.\n", - " type: IAMExtendedParagraphs\n", - " args:\n", - " batch_size: 64\n", - " num_workers: 12\n", - " train_fraction: 0.8\n", - " augment: true\n", - "callbacks:\n", - "- type: ModelCheckpoint\n", - " args:\n", - " monitor: val_loss\n", - " mode: min\n", - " save_last: true\n", - "- type: LearningRateMonitor\n", - " args:\n", - " logging_interval: step\n", - "trainer:\n", - " desc: Configuration of the PyTorch Lightning Trainer.\n", - " args:\n", - " stochastic_weight_avg: false\n", - " auto_scale_batch_size: binsearch\n", - " gradient_clip_val: 0\n", - " fast_dev_run: false\n", - " gpus: 1\n", - " precision: 16\n", - " max_epochs: 1024\n", - " terminate_on_nan: true\n", - " weights_summary: full\n", - "load_checkpoint: null\n", - "\n" - ] - } - ], + "outputs": [], "source": [ "print(OmegaConf.to_yaml(conf))" ] }, { "cell_type": "code", - "execution_count": 20, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -164,7 +70,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -173,167 +79,16 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "VQVAE(\n", - " (encoder): Encoder(\n", - " (encoder): Sequential(\n", - " (0): Sequential(\n", - " (0): Conv2d(1, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", - " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n", - " )\n", - " (1): Dropout(p=0.1, inplace=False)\n", - " (2): Sequential(\n", - " (0): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", - " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n", - " )\n", - " (3): Dropout(p=0.1, inplace=False)\n", - " (4): Sequential(\n", - " (0): Conv2d(64, 96, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", - " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n", - " )\n", - " (5): Dropout(p=0.1, inplace=False)\n", - " (6): Sequential(\n", - " (0): Conv2d(96, 96, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", - " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n", - " )\n", - " (7): Dropout(p=0.1, inplace=False)\n", - " (8): Sequential(\n", - " (0): Conv2d(96, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", - " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n", - " )\n", - " (9): Dropout(p=0.1, inplace=False)\n", - " (10): _ResidualBlock(\n", - " (block): Sequential(\n", - " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " (1): ReLU(inplace=True)\n", - " (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (3): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (11): _ResidualBlock(\n", - " (block): Sequential(\n", - " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " (1): ReLU(inplace=True)\n", - " (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (3): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (12): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))\n", - " )\n", - " (vector_quantizer): VectorQuantizer(\n", - " (embedding): Embedding(1024, 128)\n", - " )\n", - " )\n", - " (decoder): Decoder(\n", - " (res_block): Sequential(\n", - " (0): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))\n", - " (1): _ResidualBlock(\n", - " (block): Sequential(\n", - " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " (1): ReLU(inplace=True)\n", - " (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (3): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (2): _ResidualBlock(\n", - " (block): Sequential(\n", - " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " (1): ReLU(inplace=True)\n", - " (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (3): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " )\n", - " (upsampling_block): Sequential(\n", - " (0): Sequential(\n", - " (0): ConvTranspose2d(128, 96, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", - " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n", - " )\n", - " (1): Dropout(p=0.1, inplace=False)\n", - " (2): Sequential(\n", - " (0): ConvTranspose2d(96, 96, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", - " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n", - " )\n", - " (3): Dropout(p=0.1, inplace=False)\n", - " (4): Sequential(\n", - " (0): ConvTranspose2d(96, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", - " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n", - " )\n", - " (5): Dropout(p=0.1, inplace=False)\n", - " (6): Sequential(\n", - " (0): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", - " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n", - " )\n", - " (7): Dropout(p=0.1, inplace=False)\n", - " (8): ConvTranspose2d(32, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", - " (9): Tanh()\n", - " )\n", - " (decoder): Sequential(\n", - " (0): Sequential(\n", - " (0): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))\n", - " (1): _ResidualBlock(\n", - " (block): Sequential(\n", - " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " (1): ReLU(inplace=True)\n", - " (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (3): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (2): _ResidualBlock(\n", - " (block): Sequential(\n", - " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " (1): ReLU(inplace=True)\n", - " (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (3): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " )\n", - " (1): Sequential(\n", - " (0): Sequential(\n", - " (0): ConvTranspose2d(128, 96, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", - " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n", - " )\n", - " (1): Dropout(p=0.1, inplace=False)\n", - " (2): Sequential(\n", - " (0): ConvTranspose2d(96, 96, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", - " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n", - " )\n", - " (3): Dropout(p=0.1, inplace=False)\n", - " (4): Sequential(\n", - " (0): ConvTranspose2d(96, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", - " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n", - " )\n", - " (5): Dropout(p=0.1, inplace=False)\n", - " (6): Sequential(\n", - " (0): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", - " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n", - " )\n", - " (7): Dropout(p=0.1, inplace=False)\n", - " (8): ConvTranspose2d(32, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", - " (9): Tanh()\n", - " )\n", - " )\n", - " )\n", - ")" - ] - }, - "execution_count": 22, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "vae" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -342,275 +97,259 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "proj = nn.Conv2d(1, 32, kernel_size=16, stride=16)" + "vae.encoder(datum)[0].shape" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "x = proj(datum)" + "vae(datum)[0].shape" ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([2, 32, 36, 40])" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "x.shape" + "from text_recognizer.networks.backbones.efficientnet import EfficientNet" ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "xx = x.flatten(2)" + "en = EfficientNet()" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 3, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([2, 32, 1440])" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "xx.shape" + "datum = torch.randn([2, 1, 576, 640])" ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ - "xxx = xx.transpose(1,2)" + "trg = torch.randint(0, 1000, [2, 682])" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([2, 1440, 32])" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "xxx.shape" + "trg.shape" ] }, { "cell_type": "code", - "execution_count": 13, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "from einops import rearrange" + "datum = torch.randn([2, 1, 224, 224])" ] }, { "cell_type": "code", - "execution_count": 14, - "metadata": {}, + "execution_count": null, + "metadata": { + "scrolled": false + }, "outputs": [], "source": [ - "xxxx = rearrange(x, \"b c h w -> b ( h w ) c\")" + "en(datum).shape" ] }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 5, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([2, 1440, 32])" - ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "xxxx.shape" + "path = \"../training/configs/cnn_transformer.yaml\"" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ - " B, N, C = x.shape\n", - " H, W = size\n", - " assert N == 1 + H * W\n", - "\n", - " # Extract CLS token and image tokens.\n", - " cls_token, img_tokens = x[:, :1], x[:, 1:] # Shape: [B, 1, C], [B, H*W, C].\n", - " \n", - " # Depthwise convolution.\n", - " feat = img_tokens.transpose(1, 2).view(B, C, H, W)" + "conf = OmegaConf.load(path)" ] }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 7, "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "torch.Size([2, 32, 36, 40])" - ] - }, - "execution_count": 22, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + "seed: 4711\n", + "network:\n", + " desc: Configuration of the PyTorch neural network.\n", + " type: CNNTransformer\n", + " args:\n", + " encoder:\n", + " type: EfficientNet\n", + " args: null\n", + " num_decoder_layers: 4\n", + " hidden_dim: 256\n", + " num_heads: 4\n", + " expansion_dim: 1024\n", + " dropout_rate: 0.1\n", + " transformer_activation: glu\n", + "model:\n", + " desc: Configuration of the PyTorch Lightning model.\n", + " type: LitTransformerModel\n", + " args:\n", + " optimizer:\n", + " type: MADGRAD\n", + " args:\n", + " lr: 0.001\n", + " momentum: 0.9\n", + " weight_decay: 0\n", + " eps: 1.0e-06\n", + " lr_scheduler:\n", + " type: OneCycleLR\n", + " args:\n", + " interval: step\n", + " max_lr: 0.001\n", + " three_phase: true\n", + " epochs: 512\n", + " steps_per_epoch: 1246\n", + " criterion:\n", + " type: CrossEntropyLoss\n", + " args:\n", + " weight: None\n", + " ignore_index: -100\n", + " reduction: mean\n", + " monitor: val_loss\n", + " mapping: sentence_piece\n", + "data:\n", + " desc: Configuration of the training/test data.\n", + " type: IAMExtendedParagraphs\n", + " args:\n", + " batch_size: 16\n", + " num_workers: 12\n", + " train_fraction: 0.8\n", + " augment: true\n", + "callbacks:\n", + "- type: ModelCheckpoint\n", + " args:\n", + " monitor: val_loss\n", + " mode: min\n", + " save_last: true\n", + "- type: StochasticWeightAveraging\n", + " args:\n", + " swa_epoch_start: 0.8\n", + " swa_lrs: 0.05\n", + " annealing_epochs: 10\n", + " annealing_strategy: cos\n", + " device: null\n", + "- type: LearningRateMonitor\n", + " args:\n", + " logging_interval: step\n", + "- type: EarlyStopping\n", + " args:\n", + " monitor: val_loss\n", + " mode: min\n", + " patience: 10\n", + "trainer:\n", + " desc: Configuration of the PyTorch Lightning Trainer.\n", + " args:\n", + " stochastic_weight_avg: true\n", + " auto_scale_batch_size: binsearch\n", + " gradient_clip_val: 0\n", + " fast_dev_run: false\n", + " gpus: 1\n", + " precision: 16\n", + " max_epochs: 512\n", + " terminate_on_nan: true\n", + " weights_summary: true\n", + "load_checkpoint: null\n", + "\n" + ] } ], "source": [ - "xxx.transpose(1, 2).view(2, 32, 36, 40).shape" + "print(OmegaConf.to_yaml(conf))" ] }, { "cell_type": "code", - "execution_count": 18, - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "data": { - "text/plain": [ - "72.0" - ] - }, - "execution_count": 18, - "metadata": {}, - "output_type": "execute_result" - } - ], + "execution_count": 8, + "metadata": {}, + "outputs": [], "source": [ - "576 / 8" + "from text_recognizer.networks.cnn_transformer import CNNTransformer" ] }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 9, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "80.0" - ] - }, - "execution_count": 19, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "640 / 8" + "t = CNNTransformer(input_shape=(1, 576, 640), output_shape=(682, 1), **conf.network.args)" ] }, { "cell_type": "code", - "execution_count": 26, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([2, 1, 576, 640])" - ] - }, - "execution_count": 26, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "datum.shape" + "t.encode(datum).shape" ] }, { "cell_type": "code", - "execution_count": 27, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([2, 128, 18, 20])" - ] - }, - "execution_count": 27, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "vae.encoder(datum)[0].shape" + "trg.shape" ] }, { "cell_type": "code", - "execution_count": 87, + "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "torch.Size([2, 1, 576, 640])" + "torch.Size([2, 682, 1004])" ] }, - "execution_count": 87, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "vae(datum)[0].shape" + "t(datum, trg).shape" ] }, { diff --git a/notebooks/03-look-at-iam-paragraphs.ipynb b/notebooks/03-look-at-iam-paragraphs.ipynb index eaf5397..add0b80 100644 --- a/notebooks/03-look-at-iam-paragraphs.ipynb +++ b/notebooks/03-look-at-iam-paragraphs.ipynb @@ -31,7 +31,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 2, "id": "726ac25b", "metadata": {}, "outputs": [], @@ -56,8 +56,8 @@ "name": "stderr", "output_type": "stream", "text": [ - "2021-04-16 23:01:52.352 | INFO | text_recognizer.data.iam_paragraphs:setup:107 - Loading IAM paragraph regions and lines for None...\n", - "2021-04-16 23:02:08.521 | INFO | text_recognizer.data.iam_synthetic_paragraphs:setup:79 - IAM Synthetic dataset steup for stage None\n" + "2021-04-25 23:17:44.177 | INFO | text_recognizer.data.iam_paragraphs:setup:107 - Loading IAM paragraph regions and lines for None...\n", + "2021-04-25 23:18:00.750 | INFO | text_recognizer.data.iam_synthetic_paragraphs:setup:79 - IAM Synthetic dataset steup for stage None\n" ] }, { @@ -68,9 +68,9 @@ "Num classes: 84\n", "Dims: (1, 576, 640)\n", "Output dims: (682, 1)\n", - "Train/val/test sizes: 19912, 262, 231\n", - "Train Batch x stats: (torch.Size([1, 1, 576, 640]), torch.float32, tensor(0.), tensor(0.0043), tensor(0.0333), tensor(0.8588))\n", - "Train Batch y stats: (torch.Size([1, 682]), torch.int64, tensor(1), tensor(78))\n", + "Train/val/test sizes: 19948, 262, 231\n", + "Train Batch x stats: (torch.Size([1, 1, 576, 640]), torch.float32, tensor(0.), tensor(0.0109), tensor(0.0499), tensor(0.8314))\n", + "Train Batch y stats: (torch.Size([1, 682]), torch.int64, tensor(1), tensor(83))\n", "Test Batch x stats: (torch.Size([1, 1, 576, 640]), torch.float32, tensor(0.), tensor(0.0372), tensor(0.0767), tensor(0.8118))\n", "Test Batch y stats: (torch.Size([1, 682]), torch.int64, tensor(1), tensor(83))\n", "\n" @@ -86,10 +86,34 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "id": "42501428", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2021-04-25 23:18:14.449 | INFO | text_recognizer.data.iam_paragraphs:setup:107 - Loading IAM paragraph regions and lines for None...\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "IAM Paragraphs Dataset\n", + "Num classes: 84\n", + "Input dims: (1, 576, 640)\n", + "Output dims: (682, 1)\n", + "Train/val/test sizes: 1046, 262, 231\n", + "Train Batch x stats: (torch.Size([16, 1, 576, 640]), torch.float32, tensor(0.), tensor(0.0393), tensor(0.0924), tensor(1.))\n", + "Train Batch y stats: (torch.Size([16, 682]), torch.int64, tensor(1), tensor(83))\n", + "Test Batch x stats: (torch.Size([16, 1, 576, 640]), torch.float32, tensor(0.), tensor(0.0312), tensor(0.0817), tensor(0.9294))\n", + "Test Batch y stats: (torch.Size([16, 682]), torch.int64, tensor(1), tensor(83))\n", + "\n" + ] + } + ], "source": [ "dataset = IAMParagraphs()\n", "dataset.prepare_data()\n", @@ -99,7 +123,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "id": "0cf22683", "metadata": {}, "outputs": [], @@ -109,6 +133,27 @@ }, { "cell_type": "code", + "execution_count": 6, + "id": "af7747a8", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([682])" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "y.shape" + ] + }, + { + "cell_type": "code", "execution_count": 7, "id": "e7778ae2", "metadata": { diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py index 2380660..0a30a42 100644 --- a/text_recognizer/data/iam_extended_paragraphs.py +++ b/text_recognizer/data/iam_extended_paragraphs.py @@ -19,18 +19,10 @@ class IAMExtendedParagraphs(BaseDataModule): super().__init__(batch_size, num_workers) self.iam_paragraphs = IAMParagraphs( - batch_size, - num_workers, - train_fraction, - augment, - word_pieces, + batch_size, num_workers, train_fraction, augment, word_pieces, ) self.iam_synthetic_paragraphs = IAMSyntheticParagraphs( - batch_size, - num_workers, - train_fraction, - augment, - word_pieces, + batch_size, num_workers, train_fraction, augment, word_pieces, ) self.dims = self.iam_paragraphs.dims diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py index 62c44f9..24409bc 100644 --- a/text_recognizer/data/iam_paragraphs.py +++ b/text_recognizer/data/iam_paragraphs.py @@ -101,7 +101,7 @@ class IAMParagraphs(BaseDataModule): data, targets, transform=get_transform(image_shape=self.dims[1:], augment=augment), - target_transform=get_target_transform(self.word_pieces) + target_transform=get_target_transform(self.word_pieces), ) logger.info(f"Loading IAM paragraph regions and lines for {stage}...") @@ -162,10 +162,7 @@ def get_dataset_properties() -> Dict: "min": min(_get_property_values("num_lines")), "max": max(_get_property_values("num_lines")), }, - "crop_shape": { - "min": crop_shapes.min(axis=0), - "max": crop_shapes.max(axis=0), - }, + "crop_shape": {"min": crop_shapes.min(axis=0), "max": crop_shapes.max(axis=0),}, "aspect_ratio": { "min": aspect_ratio.min(axis=0), "max": aspect_ratio.max(axis=0), @@ -286,9 +283,7 @@ def get_transform(image_shape: Tuple[int, int], augment: bool) -> transforms.Com ), transforms.ColorJitter(brightness=(0.8, 1.6)), transforms.RandomAffine( - degrees=1, - shear=(-10, 10), - interpolation=InterpolationMode.BILINEAR, + degrees=1, shear=(-10, 10), interpolation=InterpolationMode.BILINEAR, ), ] else: @@ -296,10 +291,12 @@ def get_transform(image_shape: Tuple[int, int], augment: bool) -> transforms.Com transforms_list.append(transforms.ToTensor()) return transforms.Compose(transforms_list) + def get_target_transform(word_pieces: bool) -> Optional[transforms.Compose]: """Transform emnist characters to word pieces.""" return transforms.Compose([WordPiece()]) if word_pieces else None + def _labels_filename(split: str) -> Path: """Return filename of processed labels.""" return PROCESSED_DATA_DIRNAME / split / "_labels.json" diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py index 4ccc5c2..78e6c05 100644 --- a/text_recognizer/data/iam_synthetic_paragraphs.py +++ b/text_recognizer/data/iam_synthetic_paragraphs.py @@ -97,7 +97,7 @@ class IAMSyntheticParagraphs(IAMParagraphs): transform=get_transform( image_shape=self.dims[1:], augment=self.augment ), - target_transform=get_target_transform(self.word_pieces) + target_transform=get_target_transform(self.word_pieces), ) def __repr__(self) -> str: diff --git a/text_recognizer/data/mappings.py b/text_recognizer/data/mappings.py index f4016ba..190febe 100644 --- a/text_recognizer/data/mappings.py +++ b/text_recognizer/data/mappings.py @@ -58,13 +58,13 @@ class WordPieceMapping(EmnistMapping): def __init__( self, num_features: int = 1000, - tokens: str = "iamdb_1kwp_tokens_1000.txt" , + tokens: str = "iamdb_1kwp_tokens_1000.txt", lexicon: str = "iamdb_1kwp_lex_1000.txt", data_dir: Optional[Union[str, Path]] = None, use_words: bool = False, prepend_wordsep: bool = False, special_tokens: Sequence[str] = ("<s>", "<e>", "<p>"), - extra_symbols: Optional[Sequence[str]] = ("\n", ), + extra_symbols: Optional[Sequence[str]] = ("\n",), ) -> None: super().__init__(extra_symbols) self.wordpiece_processor = self._configure_wordpiece_processor( @@ -90,7 +90,13 @@ class WordPieceMapping(EmnistMapping): extra_symbols: Optional[Sequence[str]], ) -> Preprocessor: data_dir = ( - (Path(__file__).resolve().parents[2] / "data" / "downloaded" / "iam" / "iamdb") + ( + Path(__file__).resolve().parents[2] + / "data" + / "downloaded" + / "iam" + / "iamdb" + ) if data_dir is None else Path(data_dir) ) diff --git a/text_recognizer/data/transforms.py b/text_recognizer/data/transforms.py index 8d1bedd..d0f1f35 100644 --- a/text_recognizer/data/transforms.py +++ b/text_recognizer/data/transforms.py @@ -13,7 +13,7 @@ class WordPiece: def __init__( self, num_features: int = 1000, - tokens: str = "iamdb_1kwp_tokens_1000.txt" , + tokens: str = "iamdb_1kwp_tokens_1000.txt", lexicon: str = "iamdb_1kwp_lex_1000.txt", data_dir: Optional[Union[str, Path]] = None, use_words: bool = False, @@ -35,4 +35,4 @@ class WordPiece: self.max_len = max_len def __call__(self, x: Tensor) -> Tensor: - return self.mapping.emnist_to_wordpiece_indices(x)[:self.max_len] + return self.mapping.emnist_to_wordpiece_indices(x)[: self.max_len] diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py index 7dc1352..8dd4db2 100644 --- a/text_recognizer/models/transformer.py +++ b/text_recognizer/models/transformer.py @@ -39,7 +39,7 @@ class LitTransformerModel(LitBaseModel): def configure_mapping(mapping: Optional[List[str]]) -> Tuple[List[str], List[int]]: """Configure mapping.""" # TODO: Fix me!!! - mapping, inverse_mapping, _ = emnist_mapping() + mapping, inverse_mapping, _ = emnist_mapping(["\n"]) start_index = inverse_mapping["<s>"] end_index = inverse_mapping["<e>"] pad_index = inverse_mapping["<p>"] diff --git a/text_recognizer/networks/__init__.py b/text_recognizer/networks/__init__.py index 41fd43f..63b43b2 100644 --- a/text_recognizer/networks/__init__.py +++ b/text_recognizer/networks/__init__.py @@ -1,2 +1,4 @@ """Network modules""" +from .backbones import EfficientNet from .vqvae import VQVAE +from .cnn_transformer import CNNTransformer diff --git a/text_recognizer/networks/backbones/__init__.py b/text_recognizer/networks/backbones/__init__.py new file mode 100644 index 0000000..25aed0e --- /dev/null +++ b/text_recognizer/networks/backbones/__init__.py @@ -0,0 +1,2 @@ +"""Vision backbones.""" +from .efficientnet import EfficientNet diff --git a/text_recognizer/networks/backbones/efficientnet.py b/text_recognizer/networks/backbones/efficientnet.py new file mode 100644 index 0000000..61dea77 --- /dev/null +++ b/text_recognizer/networks/backbones/efficientnet.py @@ -0,0 +1,145 @@ +"""Efficient net b0 implementation.""" +import torch +from torch import nn +from torch import Tensor + + +class ConvNorm(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int, + padding: int, + groups: int = 1, + ) -> None: + super().__init__() + self.block = nn.Sequential( + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + groups=groups, + bias=False, + ), + nn.BatchNorm2d(num_features=out_channels), + nn.SiLU(inplace=True), + ) + + def forward(self, x: Tensor) -> Tensor: + return self.block(x) + + +class SqueezeExcite(nn.Module): + def __init__(self, in_channels: int, reduce_dim: int) -> None: + super().__init__() + self.se = nn.Sequential( + nn.AdaptiveAvgPool2d(1), # [C, H, W] -> [C, 1, 1] + nn.Conv2d(in_channels=in_channels, out_channels=reduce_dim, kernel_size=1), + nn.SiLU(), + nn.Conv2d(in_channels=reduce_dim, out_channels=in_channels, kernel_size=1), + nn.Sigmoid(), + ) + + def forward(self, x: Tensor) -> Tensor: + return x * self.se(x) + + +class InvertedResidulaBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int, + padding: int, + expand_ratio: float, + reduction: int = 4, + survival_prob: float = 0.8, + ) -> None: + super().__init__() + self.survival_prob = survival_prob + self.use_residual = in_channels == out_channels and stride == 1 + hidden_dim = in_channels * expand_ratio + self.expand = in_channels != hidden_dim + reduce_dim = in_channels // reduction + + if self.expand: + self.expand_conv = ConvNorm( + in_channels, hidden_dim, kernel_size=3, stride=1, padding=1 + ) + + self.conv = nn.Sequential( + ConvNorm( + hidden_dim, hidden_dim, kernel_size, stride, padding, groups=hidden_dim + ), + SqueezeExcite(hidden_dim, reduce_dim), + nn.Conv2d( + in_channels=hidden_dim, + out_channels=out_channels, + kernel_size=1, + bias=False, + ), + nn.BatchNorm2d(num_features=out_channels), + ) + + def stochastic_depth(self, x: Tensor) -> Tensor: + if not self.training: + return x + + binary_tensor = ( + torch.rand(x.shape[0], 1, 1, 1, device=x.device) < self.survival_prob + ) + return torch.div(x, self.survival_prob) * binary_tensor + + def forward(self, x: Tensor) -> Tensor: + out = self.expand_conv(x) if self.expand else x + if self.use_residual: + return self.stochastic_depth(self.conv(out)) + x + return self.conv(out) + + +class EfficientNet(nn.Module): + """Efficient net b0 backbone.""" + + def __init__(self) -> None: + super().__init__() + self.base_model = [ + # expand_ratio, channels, repeats, stride, kernel_size + [1, 16, 1, 1, 3], + [6, 24, 2, 2, 3], + [6, 40, 2, 2, 5], + [6, 80, 3, 2, 3], + [6, 112, 3, 1, 5], + [6, 192, 4, 2, 5], + [6, 320, 1, 1, 3], + ] + + self.backbone = self._build_b0() + + def _build_b0(self) -> nn.Sequential: + in_channels = 32 + layers = [ConvNorm(1, in_channels, 3, stride=2, padding=1)] + + for expand_ratio, out_channels, repeats, stride, kernel_size in self.base_model: + for i in range(repeats): + layers.append( + InvertedResidulaBlock( + in_channels, + out_channels, + expand_ratio=expand_ratio, + stride=stride if i == 0 else 1, + kernel_size=kernel_size, + padding=kernel_size // 2, + ) + ) + in_channels = out_channels + layers.append(ConvNorm(in_channels, 256, kernel_size=1, stride=1, padding=0)) + + return nn.Sequential(*layers) + + def forward(self, x: Tensor) -> Tensor: + return self.backbone(x) diff --git a/text_recognizer/networks/cnn_transformer.py b/text_recognizer/networks/cnn_transformer.py index e23a15d..d42c29d 100644 --- a/text_recognizer/networks/cnn_transformer.py +++ b/text_recognizer/networks/cnn_transformer.py @@ -33,8 +33,8 @@ NUM_WORD_PIECES = 1000 class CNNTransformer(nn.Module): def __init__( self, - input_shape: Sequence[int], - output_shape: Sequence[int], + input_dim: Sequence[int], + output_dims: Sequence[int], encoder: Union[DictConfig, Dict], vocab_size: Optional[int] = None, num_decoder_layers: int = 4, @@ -43,22 +43,29 @@ class CNNTransformer(nn.Module): expansion_dim: int = 1024, dropout_rate: float = 0.1, transformer_activation: str = "glu", + *args, + **kwargs, ) -> None: + super().__init__() self.vocab_size = ( NUM_WORD_PIECES + NUM_SPECIAL_TOKENS if vocab_size is None else vocab_size ) + self.pad_index = 3 # TODO: fix me self.hidden_dim = hidden_dim - self.max_output_length = output_shape[0] + self.max_output_length = output_dims[0] # Image backbone self.encoder = self._configure_encoder(encoder) + self.encoder_proj = nn.Conv2d(256, hidden_dim, kernel_size=1) self.feature_map_encoding = PositionalEncoding2D( - hidden_dim=hidden_dim, max_h=input_shape[1], max_w=input_shape[2] + hidden_dim=hidden_dim, max_h=input_dim[1], max_w=input_dim[2] ) # Target token embedding self.trg_embedding = nn.Embedding(self.vocab_size, hidden_dim) - self.trg_position_encoding = PositionalEncoding(hidden_dim, dropout_rate) + self.trg_position_encoding = PositionalEncoding( + hidden_dim, dropout_rate, max_len=output_dims[0] + ) # Transformer decoder self.decoder = Decoder( @@ -86,24 +93,25 @@ class CNNTransformer(nn.Module): self.head.weight.data.uniform_(-0.1, 0.1) nn.init.kaiming_normal_( - self.feature_map_encoding.weight.data, + self.encoder_proj.weight.data, a=0, mode="fan_out", nonlinearity="relu", ) - if self.feature_map_encoding.bias is not None: + if self.encoder_proj.bias is not None: _, fan_out = nn.init._calculate_fan_in_and_fan_out( - self.feature_map_encoding.weight.data + self.encoder_proj.weight.data ) bound = 1 / math.sqrt(fan_out) - nn.init.normal_(self.feature_map_encoding.bias, -bound, bound) + nn.init.normal_(self.encoder_proj.bias, -bound, bound) @staticmethod def _configure_encoder(encoder: Union[DictConfig, Dict]) -> Type[nn.Module]: encoder = OmegaConf.create(encoder) + args = encoder.args or {} network_module = importlib.import_module("text_recognizer.networks") encoder_class = getattr(network_module, encoder.type) - return encoder_class(**encoder.args) + return encoder_class(**args) def encode(self, image: Tensor) -> Tensor: """Extracts image features with backbone. @@ -121,6 +129,7 @@ class CNNTransformer(nn.Module): """ # Extract image features. image_features = self.encoder(image) + image_features = self.encoder_proj(image_features) # Add 2d encoding to the feature maps. image_features = self.feature_map_encoding(image_features) @@ -133,11 +142,19 @@ class CNNTransformer(nn.Module): """Decodes image features with transformer decoder.""" trg_mask = target_padding_mask(trg=trg, pad_index=self.pad_index) trg = self.trg_embedding(trg) * math.sqrt(self.hidden_dim) + trg = rearrange(trg, "b t d -> t b d") trg = self.trg_position_encoding(trg) + trg = rearrange(trg, "t b d -> b t d") out = self.decoder(trg=trg, memory=memory, trg_mask=trg_mask, memory_mask=None) logits = self.head(out) return logits + def forward(self, image: Tensor, trg: Tensor) -> Tensor: + image_features = self.encode(image) + output = self.decode(image_features, trg) + output = rearrange(output, "b t c -> b c t") + return output + def predict(self, image: Tensor) -> Tensor: """Transcribes text in image(s).""" bsz = image.shape[0] diff --git a/text_recognizer/networks/coat/__init__.py b/text_recognizer/networks/coat/__init__.py deleted file mode 100644 index e69de29..0000000 --- a/text_recognizer/networks/coat/__init__.py +++ /dev/null diff --git a/text_recognizer/networks/coat/factor_attention.py b/text_recognizer/networks/coat/factor_attention.py deleted file mode 100644 index f91c5ef..0000000 --- a/text_recognizer/networks/coat/factor_attention.py +++ /dev/null @@ -1,9 +0,0 @@ -"""Factorized attention with convolutional relative positional encodings.""" -from torch import nn - - -class FactorAttention(nn.Module): - """Factorized attention with relative positional encodings.""" - def __init__(self, dim: int, num_heads: int) -> None: - pass - diff --git a/text_recognizer/networks/coat/patch_embedding.py b/text_recognizer/networks/coat/patch_embedding.py deleted file mode 100644 index 3b7b76a..0000000 --- a/text_recognizer/networks/coat/patch_embedding.py +++ /dev/null @@ -1,38 +0,0 @@ -"""Patch embedding for images and feature maps.""" -from typing import Sequence, Tuple - -from einops import rearrange -from loguru import logger -from torch import nn -from torch import Tensor - - -class PatchEmbedding(nn.Module): - """Patch embedding of images.""" - - def __init__( - self, - image_shape: Sequence[int], - patch_size: int = 16, - in_channels: int = 1, - embedding_dim: int = 512, - ) -> None: - if image_shape[0] % patch_size == 0 and image_shape[1] % patch_size == 0: - logger.error( - f"Image shape {image_shape} not divisable by patch size {patch_size}" - ) - - self.patch_size = patch_size - self.embedding = nn.Conv2d( - in_channels, embedding_dim, kernel_size=patch_size, stride=patch_size - ) - self.norm = nn.LayerNorm(embedding_dim) - - def forward(self, x: Tensor) -> Tuple[Tensor, Tuple[int, int]]: - """Embeds image or feature maps with patch embedding.""" - _, _, h, w = x.shape - h_out, w_out = h // self.patch_size, w // self.patch_size - x = self.embedding(x) - x = rearrange(x, "b c h w -> b (h w) c") - x = self.norm(x) - return x, (h_out, w_out) diff --git a/text_recognizer/networks/coat/positional_encodings.py b/text_recognizer/networks/coat/positional_encodings.py deleted file mode 100644 index 925db04..0000000 --- a/text_recognizer/networks/coat/positional_encodings.py +++ /dev/null @@ -1,76 +0,0 @@ -"""Positional encodings for input sequence to transformer.""" -from typing import Dict, Union, Tuple - -from einops import rearrange -from loguru import logger -import torch -from torch import nn -from torch import Tensor - - -class RelativeEncoding(nn.Module): - """Relative positional encoding.""" - def __init__(self, channels: int, heads: int, windows: Union[int, Dict[int, int]]) -> None: - super().__init__() - self.windows = {windows: heads} if isinstance(windows, int) else windows - self.heads = list(self.windows.values()) - self.channel_heads = [head * channels for head in self.heads] - self.convs = nn.ModuleList([ - nn.Conv2d(in_channels=head * channels, - out_channels=head * channels, - kernel_shape=window, - padding=window // 2, - dilation=1, - groups=head * channels, - ) for window, head in self.windows.items()]) - - def forward(self, q: Tensor, v: Tensor, shape: Tuple[int, int]) -> Tensor: - """Applies relative positional encoding.""" - b, heads, hw, c = q.shape - h, w = shape - if hw != h * w: - logger.exception(f"Query width {hw} neq to height x width {h * w}") - raise ValueError - - v = rearrange(v, "b heads (h w) c -> b (heads c) h w", h=h, w=w) - v = torch.split(v, self.channel_heads, dim=1) - v = [conv(x) for conv, x in zip(self.convs, v)] - v = torch.cat(v, dim=1) - v = rearrange(v, "b (heads c) h w -> b heads (h w) c", heads=heads) - - encoding = q * v - zeros = torch.zeros((b, heads, 1, c), dtype=q.dtype, layout=q.layout, device=q.device) - encoding = torch.cat((zeros, encoding), dim=2) - return encoding - - -class PositionalEncoding(nn.Module): - """Convolutional positional encoding.""" - def __init__(self, dim: int, k: int = 3) -> None: - super().__init__() - self.encode = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=k, stride=1, padding=k//2, groups=dim) - - def forward(self, x: Tensor, shape: Tuple[int, int]) -> Tensor: - """Applies convolutional encoding.""" - _, hw, _ = x.shape - h, w = shape - - if hw != h * w: - logger.exception(f"Query width {hw} neq to height x width {h * w}") - raise ValueError - - # Depthwise convolution. - x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) - x = self.encode(x) + x - x = rearrange(x, "b c h w -> b (h w) c") - return x - - - - - - - - - - diff --git a/text_recognizer/networks/residual_network.py b/text_recognizer/networks/residual_network.py index da7553d..c33f419 100644 --- a/text_recognizer/networks/residual_network.py +++ b/text_recognizer/networks/residual_network.py @@ -20,11 +20,7 @@ class Conv2dAuto(nn.Conv2d): def conv_bn(in_channels: int, out_channels: int, *args, **kwargs) -> nn.Sequential: """3x3 convolution with batch norm.""" - conv3x3 = partial( - Conv2dAuto, - kernel_size=3, - bias=False, - ) + conv3x3 = partial(Conv2dAuto, kernel_size=3, bias=False,) return nn.Sequential( conv3x3(in_channels, out_channels, *args, **kwargs), nn.BatchNorm2d(out_channels), diff --git a/text_recognizer/networks/transducer/transducer.py b/text_recognizer/networks/transducer/transducer.py index b10f93a..d7e3d08 100644 --- a/text_recognizer/networks/transducer/transducer.py +++ b/text_recognizer/networks/transducer/transducer.py @@ -392,12 +392,7 @@ def load_transducer_loss( transitions = gtn.load(str(processed_path / transitions)) preprocessor = Preprocessor( - data_dir, - num_features, - tokens_path, - lexicon_path, - use_words, - prepend_wordsep, + data_dir, num_features, tokens_path, lexicon_path, use_words, prepend_wordsep, ) num_tokens = preprocessor.num_tokens diff --git a/text_recognizer/networks/transformer/positional_encoding.py b/text_recognizer/networks/transformer/positional_encoding.py index 5874e97..c50afc3 100644 --- a/text_recognizer/networks/transformer/positional_encoding.py +++ b/text_recognizer/networks/transformer/positional_encoding.py @@ -33,7 +33,10 @@ class PositionalEncoding(nn.Module): def forward(self, x: Tensor) -> Tensor: """Encodes the tensor with a postional embedding.""" - x = x + self.pe[:, : x.shape[1]] + # [T, B, D] + if x.shape[2] != self.pe.shape[2]: + raise ValueError(f"x shape does not match pe in the 3rd dim.") + x = x + self.pe[: x.shape[0]] return self.dropout(x) @@ -48,6 +51,7 @@ class PositionalEncoding2D(nn.Module): pe = self.make_pe(hidden_dim, max_h, max_w) self.register_buffer("pe", pe) + @staticmethod def make_pe(hidden_dim: int, max_h: int, max_w: int) -> Tensor: """Returns 2d postional encoding.""" pe_h = PositionalEncoding.make_pe( diff --git a/text_recognizer/networks/transformer/rotary_embedding.py b/text_recognizer/networks/transformer/rotary_embedding.py new file mode 100644 index 0000000..5e80572 --- /dev/null +++ b/text_recognizer/networks/transformer/rotary_embedding.py @@ -0,0 +1,39 @@ +"""Roatary embedding. + +Stolen from lucidrains: + https://github.com/lucidrains/x-transformers/blob/main/x_transformers/x_transformers.py + +Explanation of roatary: + https://blog.eleuther.ai/rotary-embeddings/ + +""" +from typing import Tuple + +from einops import rearrange +import torch +from torch import nn +from torch import Tensor + + +class RotaryEmbedding(nn.Module): + def __init__(self, dim: int): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq) + + def forward(self, x: Tensor, seq_dim: int = 1) -> Tensor: + t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + freqs = torch.einsum("i , j -> i j", t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + return emb[None, :, :] + + +def rotate_half(x: Tensor) -> Tensor: + x = rearrange(x, "... (j d) -> ... j d", j=2) + x1, x2 = x.unbind(dim=-2) + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q: Tensor, k: Tensor, freqs: Tensor) -> Tuple[Tensor, Tensor]: + q, k = map(lambda t: (t * freqs.cos()) + (rotate_half(t) * freqs.sin()), (q, k)) + return q, k diff --git a/text_recognizer/networks/vqvae/decoder.py b/text_recognizer/networks/vqvae/decoder.py index 93a1e43..32de912 100644 --- a/text_recognizer/networks/vqvae/decoder.py +++ b/text_recognizer/networks/vqvae/decoder.py @@ -44,12 +44,7 @@ class Decoder(nn.Module): # Configure encoder. self.decoder = self._build_decoder( - channels, - kernel_sizes, - strides, - num_residual_layers, - activation, - dropout, + channels, kernel_sizes, strides, num_residual_layers, activation, dropout, ) def _build_decompression_block( @@ -78,9 +73,7 @@ class Decoder(nn.Module): ) if self.upsampling and i < len(self.upsampling): - modules.append( - nn.Upsample(size=self.upsampling[i]), - ) + modules.append(nn.Upsample(size=self.upsampling[i]),) if dropout is not None: modules.append(dropout) @@ -109,12 +102,7 @@ class Decoder(nn.Module): ) -> nn.Sequential: self.res_block.append( - nn.Conv2d( - self.embedding_dim, - channels[0], - kernel_size=1, - stride=1, - ) + nn.Conv2d(self.embedding_dim, channels[0], kernel_size=1, stride=1,) ) # Bottleneck module. diff --git a/text_recognizer/networks/vqvae/encoder.py b/text_recognizer/networks/vqvae/encoder.py index b0cceed..65801df 100644 --- a/text_recognizer/networks/vqvae/encoder.py +++ b/text_recognizer/networks/vqvae/encoder.py @@ -11,10 +11,7 @@ from text_recognizer.networks.vqvae.vector_quantizer import VectorQuantizer class _ResidualBlock(nn.Module): def __init__( - self, - in_channels: int, - out_channels: int, - dropout: Optional[Type[nn.Module]], + self, in_channels: int, out_channels: int, dropout: Optional[Type[nn.Module]], ) -> None: super().__init__() self.block = [ @@ -138,12 +135,7 @@ class Encoder(nn.Module): ) encoder.append( - nn.Conv2d( - channels[-1], - self.embedding_dim, - kernel_size=1, - stride=1, - ) + nn.Conv2d(channels[-1], self.embedding_dim, kernel_size=1, stride=1,) ) return nn.Sequential(*encoder) diff --git a/training/configs/image_transformer.yaml b/training/configs/cnn_transformer.yaml index e6637f2..a4f16df 100644 --- a/training/configs/image_transformer.yaml +++ b/training/configs/cnn_transformer.yaml @@ -2,12 +2,13 @@ seed: 4711 network: desc: Configuration of the PyTorch neural network. - type: ImageTransformer + type: CNNTransformer args: encoder: - type: null + type: EfficientNet args: null num_decoder_layers: 4 + vocab_size: 84 hidden_dim: 256 num_heads: 4 expansion_dim: 1024 @@ -26,7 +27,7 @@ model: weight_decay: 0 eps: 1.0e-6 lr_scheduler: - type: OneCycle + type: OneCycleLR args: interval: &interval step max_lr: 1.0e-3 @@ -36,7 +37,7 @@ model: criterion: type: CrossEntropyLoss args: - weight: None + weight: null ignore_index: -100 reduction: mean monitor: val_loss @@ -46,7 +47,7 @@ data: desc: Configuration of the training/test data. type: IAMExtendedParagraphs args: - batch_size: 16 + batch_size: 8 num_workers: 12 train_fraction: 0.8 augment: true @@ -57,33 +58,33 @@ callbacks: monitor: val_loss mode: min save_last: true - - type: StochasticWeightAveraging - args: - swa_epoch_start: 0.8 - swa_lrs: 0.05 - annealing_epochs: 10 - annealing_strategy: cos - device: null + # - type: StochasticWeightAveraging + # args: + # swa_epoch_start: 0.8 + # swa_lrs: 0.05 + # annealing_epochs: 10 + # annealing_strategy: cos + # device: null - type: LearningRateMonitor args: logging_interval: *interval - - type: EarlyStopping - args: - monitor: val_loss - mode: min - patience: 10 + # - type: EarlyStopping + # args: + # monitor: val_loss + # mode: min + # patience: 10 trainer: desc: Configuration of the PyTorch Lightning Trainer. args: - stochastic_weight_avg: true + stochastic_weight_avg: false auto_scale_batch_size: binsearch gradient_clip_val: 0 - fast_dev_run: false + fast_dev_run: true gpus: 1 precision: 16 max_epochs: 512 terminate_on_nan: true - weights_summary: true + weights_summary: top load_checkpoint: null diff --git a/training/configs/vqvae.yaml b/training/configs/vqvae.yaml index a7acb3a..13d7c97 100644 --- a/training/configs/vqvae.yaml +++ b/training/configs/vqvae.yaml @@ -5,12 +5,12 @@ network: type: VQVAE args: in_channels: 1 - channels: [32, 64, 64] - kernel_sizes: [4, 4, 4] - strides: [2, 2, 2] + channels: [32, 64, 64, 96, 96] + kernel_sizes: [4, 4, 4, 4, 4] + strides: [2, 2, 2, 2, 2] num_residual_layers: 2 - embedding_dim: 128 - num_embeddings: 512 + embedding_dim: 512 + num_embeddings: 1024 upsampling: null beta: 0.25 activation: leaky_relu |