From 4e60c836fb710baceba570c28c06437db3ad5c9b Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sat, 24 Apr 2021 23:09:20 +0200 Subject: Implementing CoaT transformer, continue tomorrow... --- notebooks/00-testing-stuff-out.ipynb | 403 +++++++++++++++------ text_recognizer/networks/coat/__init__.py | 0 text_recognizer/networks/coat/factor_attention.py | 9 + text_recognizer/networks/coat/patch_embedding.py | 38 ++ .../networks/coat/positional_encodings.py | 76 ++++ training/configs/vqvae.yaml | 44 +-- training/run_experiment.py | 2 +- 7 files changed, 431 insertions(+), 141 deletions(-) create mode 100644 text_recognizer/networks/coat/__init__.py create mode 100644 text_recognizer/networks/coat/factor_attention.py create mode 100644 text_recognizer/networks/coat/patch_embedding.py create mode 100644 text_recognizer/networks/coat/positional_encodings.py diff --git a/notebooks/00-testing-stuff-out.ipynb b/notebooks/00-testing-stuff-out.ipynb index d4840ef..e6cf099 100644 --- a/notebooks/00-testing-stuff-out.ipynb +++ b/notebooks/00-testing-stuff-out.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -25,7 +25,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 16, "metadata": {}, "outputs": [], "source": [ @@ -34,7 +34,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 17, "metadata": {}, "outputs": [], "source": [ @@ -43,7 +43,7 @@ }, { "cell_type": "code", - "execution_count": 74, + "execution_count": 18, "metadata": {}, "outputs": [], "source": [ @@ -52,7 +52,7 @@ }, { "cell_type": "code", - "execution_count": 75, + "execution_count": 19, "metadata": {}, "outputs": [ { @@ -62,31 +62,37 @@ "seed: 4711\n", "network:\n", " desc: Configuration of the PyTorch neural network.\n", - " type: ImageTransformer\n", + " type: VQVAE\n", " args:\n", " in_channels: 1\n", " channels:\n", - " - 128\n", - " - 64\n", " - 32\n", + " - 64\n", + " - 96\n", + " - 96\n", + " - 128\n", " kernel_sizes:\n", " - 4\n", " - 4\n", " - 4\n", + " - 4\n", + " - 4\n", " strides:\n", " - 2\n", " - 2\n", " - 2\n", - " num_residual_layers: 4\n", + " - 2\n", + " - 2\n", + " num_residual_layers: 2\n", " embedding_dim: 128\n", " num_embeddings: 1024\n", " upsampling: null\n", - " beta: 6.6\n", + " beta: 0.25\n", " activation: leaky_relu\n", - " dropout_rate: 0.25\n", + " dropout_rate: 0.1\n", "model:\n", " desc: Configuration of the PyTorch Lightning model.\n", - " type: LitTransformerModel\n", + " type: LitVQVAEModel\n", " args:\n", " optimizer:\n", " type: MADGRAD\n", @@ -96,18 +102,16 @@ " weight_decay: 0\n", " eps: 1.0e-06\n", " lr_scheduler:\n", - " type: OneCycle\n", + " type: OneCycleLR\n", " args:\n", " interval: step\n", " max_lr: 0.001\n", " three_phase: true\n", - " epochs: 512\n", - " steps_per_epoch: 1246\n", + " epochs: 1024\n", + " steps_per_epoch: 317\n", " criterion:\n", - " type: CrossEntropyLoss\n", + " type: MSELoss\n", " args:\n", - " weight: None\n", - " ignore_index: -100\n", " reduction: mean\n", " monitor: val_loss\n", " mapping: sentence_piece\n", @@ -115,7 +119,7 @@ " desc: Configuration of the training/test data.\n", " type: IAMExtendedParagraphs\n", " args:\n", - " batch_size: 16\n", + " batch_size: 64\n", " num_workers: 12\n", " train_fraction: 0.8\n", " augment: true\n", @@ -125,33 +129,21 @@ " monitor: val_loss\n", " mode: min\n", " save_last: true\n", - "- type: StochasticWeightAveraging\n", - " args:\n", - " swa_epoch_start: 0.8\n", - " swa_lrs: 
0.05\n", - " annealing_epochs: 10\n", - " annealing_strategy: cos\n", - " device: null\n", "- type: LearningRateMonitor\n", " args:\n", " logging_interval: step\n", - "- type: EarlyStopping\n", - " args:\n", - " monitor: val_loss\n", - " mode: min\n", - " patience: 10\n", "trainer:\n", " desc: Configuration of the PyTorch Lightning Trainer.\n", " args:\n", - " stochastic_weight_avg: true\n", + " stochastic_weight_avg: false\n", " auto_scale_batch_size: binsearch\n", " gradient_clip_val: 0\n", " fast_dev_run: false\n", " gpus: 1\n", " precision: 16\n", - " max_epochs: 512\n", + " max_epochs: 1024\n", " terminate_on_nan: true\n", - " weights_summary: true\n", + " weights_summary: full\n", "load_checkpoint: null\n", "\n" ] @@ -163,7 +155,7 @@ }, { "cell_type": "code", - "execution_count": 76, + "execution_count": 20, "metadata": {}, "outputs": [], "source": [ @@ -172,7 +164,7 @@ }, { "cell_type": "code", - "execution_count": 78, + "execution_count": 21, "metadata": {}, "outputs": [], "source": [ @@ -181,7 +173,7 @@ }, { "cell_type": "code", - "execution_count": 79, + "execution_count": 22, "metadata": {}, "outputs": [ { @@ -194,50 +186,44 @@ " (0): Conv2d(1, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n", " )\n", - " (1): Dropout(p=0.25, inplace=False)\n", + " (1): Dropout(p=0.1, inplace=False)\n", " (2): Sequential(\n", " (0): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n", " )\n", - " (3): Dropout(p=0.25, inplace=False)\n", + " (3): Dropout(p=0.1, inplace=False)\n", " (4): Sequential(\n", - " (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", + " (0): Conv2d(64, 96, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n", " )\n", - " (5): Dropout(p=0.25, inplace=False)\n", - " (6): _ResidualBlock(\n", - " (block): Sequential(\n", - " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " (1): ReLU(inplace=True)\n", - " (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (3): Dropout(p=0.25, inplace=False)\n", - " )\n", + " (5): Dropout(p=0.1, inplace=False)\n", + " (6): Sequential(\n", + " (0): Conv2d(96, 96, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", + " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n", " )\n", - " (7): _ResidualBlock(\n", - " (block): Sequential(\n", - " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " (1): ReLU(inplace=True)\n", - " (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (3): Dropout(p=0.25, inplace=False)\n", - " )\n", + " (7): Dropout(p=0.1, inplace=False)\n", + " (8): Sequential(\n", + " (0): Conv2d(96, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", + " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n", " )\n", - " (8): _ResidualBlock(\n", + " (9): Dropout(p=0.1, inplace=False)\n", + " (10): _ResidualBlock(\n", " (block): Sequential(\n", " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (1): ReLU(inplace=True)\n", " (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (3): Dropout(p=0.25, inplace=False)\n", + " (3): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", - " (9): _ResidualBlock(\n", + " (11): _ResidualBlock(\n", " (block): Sequential(\n", " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), 
padding=(1, 1), bias=False)\n", " (1): ReLU(inplace=True)\n", " (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (3): Dropout(p=0.25, inplace=False)\n", + " (3): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", - " (10): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))\n", + " (12): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", " (vector_quantizer): VectorQuantizer(\n", " (embedding): Embedding(1024, 128)\n", @@ -251,7 +237,7 @@ " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (1): ReLU(inplace=True)\n", " (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (3): Dropout(p=0.25, inplace=False)\n", + " (3): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (2): _ResidualBlock(\n", @@ -259,39 +245,33 @@ " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (1): ReLU(inplace=True)\n", " (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (3): Dropout(p=0.25, inplace=False)\n", - " )\n", - " )\n", - " (3): _ResidualBlock(\n", - " (block): Sequential(\n", - " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " (1): ReLU(inplace=True)\n", - " (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (3): Dropout(p=0.25, inplace=False)\n", - " )\n", - " )\n", - " (4): _ResidualBlock(\n", - " (block): Sequential(\n", - " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " (1): ReLU(inplace=True)\n", - " (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (3): Dropout(p=0.25, inplace=False)\n", + " (3): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " )\n", " (upsampling_block): Sequential(\n", " (0): Sequential(\n", - " (0): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", + " (0): ConvTranspose2d(128, 96, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n", " )\n", - " (1): Dropout(p=0.25, inplace=False)\n", + " (1): Dropout(p=0.1, inplace=False)\n", " (2): Sequential(\n", + " (0): ConvTranspose2d(96, 96, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", + " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n", + " )\n", + " (3): Dropout(p=0.1, inplace=False)\n", + " (4): Sequential(\n", + " (0): ConvTranspose2d(96, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", + " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n", + " )\n", + " (5): Dropout(p=0.1, inplace=False)\n", + " (6): Sequential(\n", " (0): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n", " )\n", - " (3): Dropout(p=0.25, inplace=False)\n", - " (4): ConvTranspose2d(32, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", - " (5): Tanh()\n", + " (7): Dropout(p=0.1, inplace=False)\n", + " (8): ConvTranspose2d(32, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", + " (9): Tanh()\n", " )\n", " (decoder): Sequential(\n", " (0): Sequential(\n", @@ -301,7 +281,7 @@ " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (1): ReLU(inplace=True)\n", " (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (3): Dropout(p=0.25, inplace=False)\n", + " (3): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (2): _ResidualBlock(\n", @@ -309,46 +289,40 @@ " (0): Conv2d(128, 128, 
kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (1): ReLU(inplace=True)\n", " (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (3): Dropout(p=0.25, inplace=False)\n", - " )\n", - " )\n", - " (3): _ResidualBlock(\n", - " (block): Sequential(\n", - " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " (1): ReLU(inplace=True)\n", - " (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (3): Dropout(p=0.25, inplace=False)\n", - " )\n", - " )\n", - " (4): _ResidualBlock(\n", - " (block): Sequential(\n", - " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " (1): ReLU(inplace=True)\n", - " (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (3): Dropout(p=0.25, inplace=False)\n", + " (3): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " )\n", " (1): Sequential(\n", " (0): Sequential(\n", - " (0): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", + " (0): ConvTranspose2d(128, 96, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n", " )\n", - " (1): Dropout(p=0.25, inplace=False)\n", + " (1): Dropout(p=0.1, inplace=False)\n", " (2): Sequential(\n", + " (0): ConvTranspose2d(96, 96, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", + " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n", + " )\n", + " (3): Dropout(p=0.1, inplace=False)\n", + " (4): Sequential(\n", + " (0): ConvTranspose2d(96, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", + " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n", + " )\n", + " (5): Dropout(p=0.1, inplace=False)\n", + " (6): Sequential(\n", " (0): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n", " )\n", - " (3): Dropout(p=0.25, inplace=False)\n", - " (4): ConvTranspose2d(32, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", - " (5): Tanh()\n", + " (7): Dropout(p=0.1, inplace=False)\n", + " (8): ConvTranspose2d(32, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", + " (9): Tanh()\n", " )\n", " )\n", " )\n", ")" ] }, - "execution_count": 79, + "execution_count": 22, "metadata": {}, "output_type": "execute_result" } @@ -359,36 +333,229 @@ }, { "cell_type": "code", - "execution_count": 80, + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "datum = torch.randn([2, 1, 576, 640])" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "proj = nn.Conv2d(1, 32, kernel_size=16, stride=16)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "x = proj(datum)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "tensor([1.])" + "torch.Size([2, 32, 36, 40])" ] }, - "execution_count": 80, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "torch.Tensor([1])" + "x.shape" ] }, { "cell_type": "code", - "execution_count": 81, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ - "datum = torch.randn([2, 1, 576, 640])" + "xx = x.flatten(2)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([2, 32, 1440])" + ] + }, + "execution_count": 10, + "metadata": 
{}, + "output_type": "execute_result" + } + ], + "source": [ + "xx.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "xxx = xx.transpose(1,2)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([2, 1440, 32])" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "xxx.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "from einops import rearrange" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "xxxx = rearrange(x, \"b c h w -> b ( h w ) c\")" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([2, 1440, 32])" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "xxxx.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + " B, N, C = x.shape\n", + " H, W = size\n", + " assert N == 1 + H * W\n", + "\n", + " # Extract CLS token and image tokens.\n", + " cls_token, img_tokens = x[:, :1], x[:, 1:] # Shape: [B, 1, C], [B, H*W, C].\n", + " \n", + " # Depthwise convolution.\n", + " feat = img_tokens.transpose(1, 2).view(B, C, H, W)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([2, 32, 36, 40])" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "xxx.transpose(1, 2).view(2, 32, 36, 40).shape" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "text/plain": [ + "72.0" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "576 / 8" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "80.0" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "640 / 8" ] }, { "cell_type": "code", - "execution_count": 82, + "execution_count": 26, "metadata": {}, "outputs": [ { @@ -397,7 +564,7 @@ "torch.Size([2, 1, 576, 640])" ] }, - "execution_count": 82, + "execution_count": 26, "metadata": {}, "output_type": "execute_result" } @@ -408,16 +575,16 @@ }, { "cell_type": "code", - "execution_count": 85, + "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "torch.Size([2, 128, 72, 80])" + "torch.Size([2, 128, 18, 20])" ] }, - "execution_count": 85, + "execution_count": 27, "metadata": {}, "output_type": "execute_result" } diff --git a/text_recognizer/networks/coat/__init__.py b/text_recognizer/networks/coat/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/text_recognizer/networks/coat/factor_attention.py b/text_recognizer/networks/coat/factor_attention.py new file mode 100644 index 0000000..f91c5ef --- /dev/null +++ b/text_recognizer/networks/coat/factor_attention.py @@ -0,0 +1,9 @@ +"""Factorized attention with convolutional relative positional encodings.""" +from torch import nn + + +class FactorAttention(nn.Module): + """Factorized attention with relative positional encodings.""" + def 
__init__(self, dim: int, num_heads: int) -> None:
+        super().__init__()
+
diff --git a/text_recognizer/networks/coat/patch_embedding.py b/text_recognizer/networks/coat/patch_embedding.py
new file mode 100644
index 0000000..3b7b76a
--- /dev/null
+++ b/text_recognizer/networks/coat/patch_embedding.py
@@ -0,0 +1,38 @@
+"""Patch embedding for images and feature maps."""
+from typing import Sequence, Tuple
+
+from einops import rearrange
+from loguru import logger
+from torch import nn
+from torch import Tensor
+
+
+class PatchEmbedding(nn.Module):
+    """Patch embedding of images."""
+
+    def __init__(
+        self,
+        image_shape: Sequence[int],
+        patch_size: int = 16,
+        in_channels: int = 1,
+        embedding_dim: int = 512,
+    ) -> None:
+        super().__init__()
+        if image_shape[0] % patch_size != 0 or image_shape[1] % patch_size != 0:
+            logger.error(
+                f"Image shape {image_shape} not divisible by patch size {patch_size}"
+            )
+        self.patch_size = patch_size
+        self.embedding = nn.Conv2d(
+            in_channels, embedding_dim, kernel_size=patch_size, stride=patch_size
+        )
+        self.norm = nn.LayerNorm(embedding_dim)
+
+    def forward(self, x: Tensor) -> Tuple[Tensor, Tuple[int, int]]:
+        """Embeds image or feature maps with patch embedding."""
+        _, _, h, w = x.shape
+        h_out, w_out = h // self.patch_size, w // self.patch_size
+        x = self.embedding(x)
+        x = rearrange(x, "b c h w -> b (h w) c")
+        x = self.norm(x)
+        return x, (h_out, w_out)
diff --git a/text_recognizer/networks/coat/positional_encodings.py b/text_recognizer/networks/coat/positional_encodings.py
new file mode 100644
index 0000000..925db04
--- /dev/null
+++ b/text_recognizer/networks/coat/positional_encodings.py
@@ -0,0 +1,76 @@
+"""Positional encodings for input sequence to transformer."""
+from typing import Dict, Union, Tuple
+
+from einops import rearrange
+from loguru import logger
+import torch
+from torch import nn
+from torch import Tensor
+
+
+class RelativeEncoding(nn.Module):
+    """Relative positional encoding."""
+    def __init__(self, channels: int, heads: int, windows: Union[int, Dict[int, int]]) -> None:
+        super().__init__()
+        self.windows = {windows: heads} if isinstance(windows, int) else windows
+        self.heads = list(self.windows.values())
+        self.channel_heads = [head * channels for head in self.heads]
+        self.convs = nn.ModuleList([
+            nn.Conv2d(in_channels=head * channels,
+                out_channels=head * channels,
+                kernel_size=window,
+                padding=window // 2,
+                dilation=1,
+                groups=head * channels,
+            ) for window, head in self.windows.items()])
+
+    def forward(self, q: Tensor, v: Tensor, shape: Tuple[int, int]) -> Tensor:
+        """Applies relative positional encoding."""
+        b, heads, hw, c = q.shape
+        h, w = shape
+        if hw != h * w:
+            logger.exception(f"Sequence length {hw} does not match height x width {h * w}")
+            raise ValueError
+
+        v = rearrange(v, "b heads (h w) c -> b (heads c) h w", h=h, w=w)
+        v = torch.split(v, self.channel_heads, dim=1)
+        v = [conv(x) for conv, x in zip(self.convs, v)]
+        v = torch.cat(v, dim=1)
+        v = rearrange(v, "b (heads c) h w -> b heads (h w) c", heads=heads)
+
+        encoding = q * v
+        zeros = torch.zeros((b, heads, 1, c), dtype=q.dtype, layout=q.layout, device=q.device)
+        encoding = torch.cat((zeros, encoding), dim=2)
+        return encoding
+
+
+class PositionalEncoding(nn.Module):
+    """Convolutional positional encoding."""
+    def __init__(self, dim: int, k: int = 3) -> None:
+        super().__init__()
+        self.encode = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=k, stride=1, padding=k//2, groups=dim)
+
+    def forward(self, x: Tensor, shape: Tuple[int, int]) -> Tensor:
+        """Applies 
convolutional positional encoding."""
+        _, hw, _ = x.shape
+        h, w = shape
+
+        if hw != h * w:
+            logger.exception(f"Sequence length {hw} does not match height x width {h * w}")
+            raise ValueError
+
+        # Depthwise convolution with a residual connection.
+        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
+        x = self.encode(x) + x
+        x = rearrange(x, "b c h w -> b (h w) c")
+        return x
+
+
+
+
+
+
+
+
+
+
diff --git a/training/configs/vqvae.yaml b/training/configs/vqvae.yaml
index 90082f7..a7acb3a 100644
--- a/training/configs/vqvae.yaml
+++ b/training/configs/vqvae.yaml
@@ -5,16 +5,16 @@ network:
   type: VQVAE
   args:
     in_channels: 1
-    channels: [32, 64, 96]
+    channels: [32, 64, 64]
     kernel_sizes: [4, 4, 4]
     strides: [2, 2, 2]
     num_residual_layers: 2
-    embedding_dim: 64
-    num_embeddings: 1024
+    embedding_dim: 128
+    num_embeddings: 512
     upsampling: null
     beta: 0.25
     activation: leaky_relu
-    dropout_rate: 0.1
+    dropout_rate: 0.2
 
 model:
   desc: Configuration of the PyTorch Lightning model.
@@ -33,8 +33,8 @@ model:
       interval: &interval step
       max_lr: 1.0e-3
       three_phase: true
-      epochs: 512
-      steps_per_epoch: 317 # num_samples / batch_size
+      epochs: 64
+      steps_per_epoch: 633 # num_samples / batch_size
   criterion:
     type: MSELoss
     args:
@@ -46,7 +46,7 @@ data:
   desc: Configuration of the training/test data.
   type: IAMExtendedParagraphs
   args:
-    batch_size: 64
+    batch_size: 32
    num_workers: 12
    train_fraction: 0.8
    augment: true
@@ -57,33 +57,33 @@ callbacks:
      monitor: val_loss
      mode: min
      save_last: true
-  # - type: StochasticWeightAveraging
-  #   args:
-  #     swa_epoch_start: 0.8
-  #     swa_lrs: 0.05
-  #     annealing_epochs: 10
-  #     annealing_strategy: cos
-  #     device: null
+  - type: StochasticWeightAveraging
+    args:
+      swa_epoch_start: 0.8
+      swa_lrs: 0.05
+      annealing_epochs: 10
+      annealing_strategy: cos
+      device: null
   - type: LearningRateMonitor
     args:
       logging_interval: *interval
-  - type: EarlyStopping
-    args:
-      monitor: val_loss
-      mode: min
-      patience: 10
+  # - type: EarlyStopping
+  #   args:
+  #     monitor: val_loss
+  #     mode: min
+  #     patience: 10
 
 trainer:
   desc: Configuration of the PyTorch Lightning Trainer.
   args:
-    stochastic_weight_avg: false # true
+    stochastic_weight_avg: true
     auto_scale_batch_size: binsearch
     gradient_clip_val: 0
     fast_dev_run: false
     gpus: 1
     precision: 16
-    max_epochs: 512
+    max_epochs: 64
     terminate_on_nan: true
-    weights_summary: full
+    weights_summary: top
 
 load_checkpoint: null
diff --git a/training/run_experiment.py b/training/run_experiment.py
index e1aae4e..bdefbf0 100644
--- a/training/run_experiment.py
+++ b/training/run_experiment.py
@@ -22,7 +22,7 @@ def _create_experiment_dir(config: DictConfig) -> Path:
     """Creates log directory for experiment."""
     log_dir = (
         LOGS_DIRNAME
-        / f"{config.model.type}_{config.network.type}"
+        / f"{config.model.type}_{config.network.type}".lower()
         / datetime.now().strftime("%m%d_%H%M%S")
     )
     log_dir.mkdir(parents=True, exist_ok=True)
-- 
cgit v1.2.3-70-g09d2
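
Note (reviewer addition, not part of the patch): the notebook cells above walk through exactly what PatchEmbedding computes — a 16x16 strided Conv2d turns a [2, 1, 576, 640] batch into [2, 32, 36, 40], and flatten/transpose (or the equivalent einops rearrange) yields [2, 1440, 32] tokens. A minimal end-to-end sketch of the two new modules, assuming the patched files as fixed above (constructors calling super().__init__()):

import torch

from text_recognizer.networks.coat.patch_embedding import PatchEmbedding
from text_recognizer.networks.coat.positional_encodings import PositionalEncoding

# Same dummy batch as the notebook: one-channel 576 x 640 images.
x = torch.randn(2, 1, 576, 640)

embed = PatchEmbedding(image_shape=(576, 640), patch_size=16, in_channels=1, embedding_dim=32)
encode = PositionalEncoding(dim=32, k=3)

tokens, (h, w) = embed(x)        # tokens: [2, 1440, 32]; (h, w) == (36, 40), i.e. 576/16 x 640/16.
tokens = encode(tokens, (h, w))  # Depthwise-conv positional encoding; shape unchanged.
print(tokens.shape)              # torch.Size([2, 1440, 32])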
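
FactorAttention itself is still an empty stub in this commit ("continue tomorrow"). For context, a hypothetical sketch of the factorized attention it is presumably building toward, per the CoaT paper: softmax is applied over the token axis of K, so K^T V is a small [C x C] matrix and the product with Q is linear in sequence length; the relative-encoding term q * conv(v) produced by RelativeEncoding above would be added to this result before the output projection. The class name, projection layout, and omission of the relative-encoding hookup below are all assumptions, not the author's code:

import torch
from torch import Tensor, nn


class FactorAttentionSketch(nn.Module):
    """Factorized attention core from CoaT (hypothetical sketch; relative encoding omitted)."""

    def __init__(self, dim: int, num_heads: int) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5
        self.qkv = nn.Linear(dim, 3 * dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: Tensor) -> Tensor:
        b, n, _ = x.shape
        # Split into per-head queries, keys, and values: each [B, heads, N, C // heads].
        qkv = self.qkv(x).reshape(b, n, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Factorized attention: softmax over tokens, K^T V is [C x C], so cost is O(N).
        factor_att = q @ (k.softmax(dim=2).transpose(-2, -1) @ v)
        out = (self.scale * factor_att).transpose(1, 2).reshape(b, n, -1)
        return self.proj(out)


tokens = torch.randn(2, 1440, 32)  # Token sequence shaped like the patch-embedding sketch above.
print(FactorAttentionSketch(dim=32, num_heads=8)(tokens).shape)  # torch.Size([2, 1440, 32])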