author     Gustaf Rydholm <gustaf.rydholm@gmail.com>   2021-04-24 23:09:20 +0200
committer  Gustaf Rydholm <gustaf.rydholm@gmail.com>   2021-04-24 23:09:20 +0200
commit     4e60c836fb710baceba570c28c06437db3ad5c9b
tree       21caf6d1792bd83a47fb3d372ee7120211e83f18
parent     1ca8b0b9e0613c1e02f6a5d8b49e20c4d6916412
Implementing CoaT transformer, continue tomorrow...
-rw-r--r--  notebooks/00-testing-stuff-out.ipynb  403
-rw-r--r--  text_recognizer/networks/coat/__init__.py  0
-rw-r--r--  text_recognizer/networks/coat/factor_attention.py  9
-rw-r--r--  text_recognizer/networks/coat/patch_embedding.py  39
-rw-r--r--  text_recognizer/networks/coat/positional_encodings.py  76
-rw-r--r--  training/configs/vqvae.yaml  44
-rw-r--r--  training/run_experiment.py  2
7 files changed, 432 insertions, 141 deletions
diff --git a/notebooks/00-testing-stuff-out.ipynb b/notebooks/00-testing-stuff-out.ipynb
index d4840ef..e6cf099 100644
--- a/notebooks/00-testing-stuff-out.ipynb
+++ b/notebooks/00-testing-stuff-out.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
@@ -25,7 +25,7 @@
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
@@ -34,7 +34,7 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
@@ -43,7 +43,7 @@
},
{
"cell_type": "code",
- "execution_count": 74,
+ "execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
@@ -52,7 +52,7 @@
},
{
"cell_type": "code",
- "execution_count": 75,
+ "execution_count": 19,
"metadata": {},
"outputs": [
{
@@ -62,31 +62,37 @@
"seed: 4711\n",
"network:\n",
" desc: Configuration of the PyTorch neural network.\n",
- " type: ImageTransformer\n",
+ " type: VQVAE\n",
" args:\n",
" in_channels: 1\n",
" channels:\n",
- " - 128\n",
- " - 64\n",
" - 32\n",
+ " - 64\n",
+ " - 96\n",
+ " - 96\n",
+ " - 128\n",
" kernel_sizes:\n",
" - 4\n",
" - 4\n",
" - 4\n",
+ " - 4\n",
+ " - 4\n",
" strides:\n",
" - 2\n",
" - 2\n",
" - 2\n",
- " num_residual_layers: 4\n",
+ " - 2\n",
+ " - 2\n",
+ " num_residual_layers: 2\n",
" embedding_dim: 128\n",
" num_embeddings: 1024\n",
" upsampling: null\n",
- " beta: 6.6\n",
+ " beta: 0.25\n",
" activation: leaky_relu\n",
- " dropout_rate: 0.25\n",
+ " dropout_rate: 0.1\n",
"model:\n",
" desc: Configuration of the PyTorch Lightning model.\n",
- " type: LitTransformerModel\n",
+ " type: LitVQVAEModel\n",
" args:\n",
" optimizer:\n",
" type: MADGRAD\n",
@@ -96,18 +102,16 @@
" weight_decay: 0\n",
" eps: 1.0e-06\n",
" lr_scheduler:\n",
- " type: OneCycle\n",
+ " type: OneCycleLR\n",
" args:\n",
" interval: step\n",
" max_lr: 0.001\n",
" three_phase: true\n",
- " epochs: 512\n",
- " steps_per_epoch: 1246\n",
+ " epochs: 1024\n",
+ " steps_per_epoch: 317\n",
" criterion:\n",
- " type: CrossEntropyLoss\n",
+ " type: MSELoss\n",
" args:\n",
- " weight: None\n",
- " ignore_index: -100\n",
" reduction: mean\n",
" monitor: val_loss\n",
" mapping: sentence_piece\n",
@@ -115,7 +119,7 @@
" desc: Configuration of the training/test data.\n",
" type: IAMExtendedParagraphs\n",
" args:\n",
- " batch_size: 16\n",
+ " batch_size: 64\n",
" num_workers: 12\n",
" train_fraction: 0.8\n",
" augment: true\n",
@@ -125,33 +129,21 @@
" monitor: val_loss\n",
" mode: min\n",
" save_last: true\n",
- "- type: StochasticWeightAveraging\n",
- " args:\n",
- " swa_epoch_start: 0.8\n",
- " swa_lrs: 0.05\n",
- " annealing_epochs: 10\n",
- " annealing_strategy: cos\n",
- " device: null\n",
"- type: LearningRateMonitor\n",
" args:\n",
" logging_interval: step\n",
- "- type: EarlyStopping\n",
- " args:\n",
- " monitor: val_loss\n",
- " mode: min\n",
- " patience: 10\n",
"trainer:\n",
" desc: Configuration of the PyTorch Lightning Trainer.\n",
" args:\n",
- " stochastic_weight_avg: true\n",
+ " stochastic_weight_avg: false\n",
" auto_scale_batch_size: binsearch\n",
" gradient_clip_val: 0\n",
" fast_dev_run: false\n",
" gpus: 1\n",
" precision: 16\n",
- " max_epochs: 512\n",
+ " max_epochs: 1024\n",
" terminate_on_nan: true\n",
- " weights_summary: true\n",
+ " weights_summary: full\n",
"load_checkpoint: null\n",
"\n"
]
@@ -163,7 +155,7 @@
},
{
"cell_type": "code",
- "execution_count": 76,
+ "execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
@@ -172,7 +164,7 @@
},
{
"cell_type": "code",
- "execution_count": 78,
+ "execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
@@ -181,7 +173,7 @@
},
{
"cell_type": "code",
- "execution_count": 79,
+ "execution_count": 22,
"metadata": {},
"outputs": [
{
@@ -194,50 +186,44 @@
" (0): Conv2d(1, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
" (1): LeakyReLU(negative_slope=0.01, inplace=True)\n",
" )\n",
- " (1): Dropout(p=0.25, inplace=False)\n",
+ " (1): Dropout(p=0.1, inplace=False)\n",
" (2): Sequential(\n",
" (0): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
" (1): LeakyReLU(negative_slope=0.01, inplace=True)\n",
" )\n",
- " (3): Dropout(p=0.25, inplace=False)\n",
+ " (3): Dropout(p=0.1, inplace=False)\n",
" (4): Sequential(\n",
- " (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
+ " (0): Conv2d(64, 96, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
" (1): LeakyReLU(negative_slope=0.01, inplace=True)\n",
" )\n",
- " (5): Dropout(p=0.25, inplace=False)\n",
- " (6): _ResidualBlock(\n",
- " (block): Sequential(\n",
- " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
- " (1): ReLU(inplace=True)\n",
- " (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
- " (3): Dropout(p=0.25, inplace=False)\n",
- " )\n",
+ " (5): Dropout(p=0.1, inplace=False)\n",
+ " (6): Sequential(\n",
+ " (0): Conv2d(96, 96, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
+ " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n",
" )\n",
- " (7): _ResidualBlock(\n",
- " (block): Sequential(\n",
- " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
- " (1): ReLU(inplace=True)\n",
- " (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
- " (3): Dropout(p=0.25, inplace=False)\n",
- " )\n",
+ " (7): Dropout(p=0.1, inplace=False)\n",
+ " (8): Sequential(\n",
+ " (0): Conv2d(96, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
+ " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n",
" )\n",
- " (8): _ResidualBlock(\n",
+ " (9): Dropout(p=0.1, inplace=False)\n",
+ " (10): _ResidualBlock(\n",
" (block): Sequential(\n",
" (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (1): ReLU(inplace=True)\n",
" (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
- " (3): Dropout(p=0.25, inplace=False)\n",
+ " (3): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
- " (9): _ResidualBlock(\n",
+ " (11): _ResidualBlock(\n",
" (block): Sequential(\n",
" (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (1): ReLU(inplace=True)\n",
" (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
- " (3): Dropout(p=0.25, inplace=False)\n",
+ " (3): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
- " (10): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))\n",
+ " (12): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))\n",
" )\n",
" (vector_quantizer): VectorQuantizer(\n",
" (embedding): Embedding(1024, 128)\n",
@@ -251,7 +237,7 @@
" (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (1): ReLU(inplace=True)\n",
" (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
- " (3): Dropout(p=0.25, inplace=False)\n",
+ " (3): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (2): _ResidualBlock(\n",
@@ -259,39 +245,33 @@
" (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (1): ReLU(inplace=True)\n",
" (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
- " (3): Dropout(p=0.25, inplace=False)\n",
- " )\n",
- " )\n",
- " (3): _ResidualBlock(\n",
- " (block): Sequential(\n",
- " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
- " (1): ReLU(inplace=True)\n",
- " (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
- " (3): Dropout(p=0.25, inplace=False)\n",
- " )\n",
- " )\n",
- " (4): _ResidualBlock(\n",
- " (block): Sequential(\n",
- " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
- " (1): ReLU(inplace=True)\n",
- " (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
- " (3): Dropout(p=0.25, inplace=False)\n",
+ " (3): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" (upsampling_block): Sequential(\n",
" (0): Sequential(\n",
- " (0): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
+ " (0): ConvTranspose2d(128, 96, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
" (1): LeakyReLU(negative_slope=0.01, inplace=True)\n",
" )\n",
- " (1): Dropout(p=0.25, inplace=False)\n",
+ " (1): Dropout(p=0.1, inplace=False)\n",
" (2): Sequential(\n",
+ " (0): ConvTranspose2d(96, 96, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
+ " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n",
+ " )\n",
+ " (3): Dropout(p=0.1, inplace=False)\n",
+ " (4): Sequential(\n",
+ " (0): ConvTranspose2d(96, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
+ " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n",
+ " )\n",
+ " (5): Dropout(p=0.1, inplace=False)\n",
+ " (6): Sequential(\n",
" (0): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
" (1): LeakyReLU(negative_slope=0.01, inplace=True)\n",
" )\n",
- " (3): Dropout(p=0.25, inplace=False)\n",
- " (4): ConvTranspose2d(32, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
- " (5): Tanh()\n",
+ " (7): Dropout(p=0.1, inplace=False)\n",
+ " (8): ConvTranspose2d(32, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
+ " (9): Tanh()\n",
" )\n",
" (decoder): Sequential(\n",
" (0): Sequential(\n",
@@ -301,7 +281,7 @@
" (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (1): ReLU(inplace=True)\n",
" (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
- " (3): Dropout(p=0.25, inplace=False)\n",
+ " (3): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (2): _ResidualBlock(\n",
@@ -309,46 +289,40 @@
" (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (1): ReLU(inplace=True)\n",
" (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
- " (3): Dropout(p=0.25, inplace=False)\n",
- " )\n",
- " )\n",
- " (3): _ResidualBlock(\n",
- " (block): Sequential(\n",
- " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
- " (1): ReLU(inplace=True)\n",
- " (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
- " (3): Dropout(p=0.25, inplace=False)\n",
- " )\n",
- " )\n",
- " (4): _ResidualBlock(\n",
- " (block): Sequential(\n",
- " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
- " (1): ReLU(inplace=True)\n",
- " (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
- " (3): Dropout(p=0.25, inplace=False)\n",
+ " (3): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" (1): Sequential(\n",
" (0): Sequential(\n",
- " (0): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
+ " (0): ConvTranspose2d(128, 96, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
" (1): LeakyReLU(negative_slope=0.01, inplace=True)\n",
" )\n",
- " (1): Dropout(p=0.25, inplace=False)\n",
+ " (1): Dropout(p=0.1, inplace=False)\n",
" (2): Sequential(\n",
+ " (0): ConvTranspose2d(96, 96, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
+ " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n",
+ " )\n",
+ " (3): Dropout(p=0.1, inplace=False)\n",
+ " (4): Sequential(\n",
+ " (0): ConvTranspose2d(96, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
+ " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n",
+ " )\n",
+ " (5): Dropout(p=0.1, inplace=False)\n",
+ " (6): Sequential(\n",
" (0): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
" (1): LeakyReLU(negative_slope=0.01, inplace=True)\n",
" )\n",
- " (3): Dropout(p=0.25, inplace=False)\n",
- " (4): ConvTranspose2d(32, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
- " (5): Tanh()\n",
+ " (7): Dropout(p=0.1, inplace=False)\n",
+ " (8): ConvTranspose2d(32, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
+ " (9): Tanh()\n",
" )\n",
" )\n",
" )\n",
")"
]
},
- "execution_count": 79,
+ "execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
@@ -359,36 +333,229 @@
},
{
"cell_type": "code",
- "execution_count": 80,
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "datum = torch.randn([2, 1, 576, 640])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "proj = nn.Conv2d(1, 32, kernel_size=16, stride=16)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "x = proj(datum)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
- "tensor([1.])"
+ "torch.Size([2, 32, 36, 40])"
]
},
- "execution_count": 80,
+ "execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
- "torch.Tensor([1])"
+ "x.shape"
]
},
{
"cell_type": "code",
- "execution_count": 81,
+ "execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
- "datum = torch.randn([2, 1, 576, 640])"
+ "xx = x.flatten(2)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([2, 32, 1440])"
+ ]
+ },
+ "execution_count": 10,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "xx.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "xxx = xx.transpose(1,2)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([2, 1440, 32])"
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "xxx.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from einops import rearrange"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "xxxx = rearrange(x, \"b c h w -> b ( h w ) c\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([2, 1440, 32])"
+ ]
+ },
+ "execution_count": 15,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "xxxx.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ " B, N, C = x.shape\n",
+ " H, W = size\n",
+ " assert N == 1 + H * W\n",
+ "\n",
+ " # Extract CLS token and image tokens.\n",
+ " cls_token, img_tokens = x[:, :1], x[:, 1:] # Shape: [B, 1, C], [B, H*W, C].\n",
+ " \n",
+ " # Depthwise convolution.\n",
+ " feat = img_tokens.transpose(1, 2).view(B, C, H, W)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([2, 32, 36, 40])"
+ ]
+ },
+ "execution_count": 22,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "xxx.transpose(1, 2).view(2, 32, 36, 40).shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "72.0"
+ ]
+ },
+ "execution_count": 18,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "576 / 8"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "80.0"
+ ]
+ },
+ "execution_count": 19,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "640 / 8"
]
},
{
"cell_type": "code",
- "execution_count": 82,
+ "execution_count": 26,
"metadata": {},
"outputs": [
{
@@ -397,7 +564,7 @@
"torch.Size([2, 1, 576, 640])"
]
},
- "execution_count": 82,
+ "execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
@@ -408,16 +575,16 @@
},
{
"cell_type": "code",
- "execution_count": 85,
+ "execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
- "torch.Size([2, 128, 72, 80])"
+ "torch.Size([2, 128, 18, 20])"
]
},
- "execution_count": 85,
+ "execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
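The notebook cells above derive the patch-embedding shapes one step at a time (projection, flatten, transpose, then the einops equivalent). A consolidated sketch of the same exploration, using only torch and einops as the notebook does, with the notebook's 576x640 input:

import torch
from einops import rearrange
from torch import nn

# Two 576x640 grayscale images, as in the notebook.
datum = torch.randn([2, 1, 576, 640])

# Non-overlapping 16x16 patches: 576/16 = 36 and 640/16 = 40, i.e. 1440 patches.
proj = nn.Conv2d(1, 32, kernel_size=16, stride=16)
x = proj(datum)
assert x.shape == (2, 32, 36, 40)

# Flatten the patch grid and move channels last: [B, C, H, W] -> [B, H*W, C].
tokens = x.flatten(2).transpose(1, 2)
assert tokens.shape == (2, 1440, 32)

# einops expresses the same flatten-and-transpose in one declarative step.
assert torch.equal(tokens, rearrange(x, "b c h w -> b (h w) c"))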
diff --git a/text_recognizer/networks/coat/__init__.py b/text_recognizer/networks/coat/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/text_recognizer/networks/coat/__init__.py
diff --git a/text_recognizer/networks/coat/factor_attention.py b/text_recognizer/networks/coat/factor_attention.py
new file mode 100644
index 0000000..f91c5ef
--- /dev/null
+++ b/text_recognizer/networks/coat/factor_attention.py
@@ -0,0 +1,9 @@
+"""Factorized attention with convolutional relative positional encodings."""
+from torch import nn
+
+
+class FactorAttention(nn.Module):
+ """Factorized attention with relative positional encodings."""
+ def __init__(self, dim: int, num_heads: int) -> None:
+ super().__init__()
+
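factor_attention.py is committed as a stub (per the commit message). For orientation, a minimal sketch of the factorized attention that CoaT describes: cost linear in sequence length, obtained by applying softmax over the keys and multiplying in the order q @ (k^T v). The factorized_attention helper name is illustrative, and the convolutional relative positional encoding term the real module would add is omitted:

import torch
from torch import Tensor


def factorized_attention(q: Tensor, k: Tensor, v: Tensor) -> Tensor:
    """Sketch of CoaT-style factorized attention for q, k, v of shape [B, heads, N, d]."""
    k_softmax = k.softmax(dim=2)               # Normalize keys over the token axis.
    context = k_softmax.transpose(-2, -1) @ v  # [B, heads, d, d] summary of the sequence.
    scale = q.shape[-1] ** -0.5
    return scale * (q @ context)               # [B, heads, N, d], O(N * d^2) overall.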
diff --git a/text_recognizer/networks/coat/patch_embedding.py b/text_recognizer/networks/coat/patch_embedding.py
new file mode 100644
index 0000000..3b7b76a
--- /dev/null
+++ b/text_recognizer/networks/coat/patch_embedding.py
@@ -0,0 +1,39 @@
+"""Patch embedding for images and feature maps."""
+from typing import Sequence, Tuple
+
+from einops import rearrange
+from loguru import logger
+from torch import nn
+from torch import Tensor
+
+
+class PatchEmbedding(nn.Module):
+ """Patch embedding of images."""
+
+ def __init__(
+ self,
+ image_shape: Sequence[int],
+ patch_size: int = 16,
+ in_channels: int = 1,
+ embedding_dim: int = 512,
+ ) -> None:
+ super().__init__()
+ if image_shape[0] % patch_size != 0 or image_shape[1] % patch_size != 0:
+ logger.error(
+ f"Image shape {image_shape} not divisible by patch size {patch_size}"
+ )
+
+ self.patch_size = patch_size
+ self.embedding = nn.Conv2d(
+ in_channels, embedding_dim, kernel_size=patch_size, stride=patch_size
+ )
+ self.norm = nn.LayerNorm(embedding_dim)
+
+ def forward(self, x: Tensor) -> Tuple[Tensor, Tuple[int, int]]:
+ """Embeds image or feature maps with patch embedding."""
+ _, _, h, w = x.shape
+ h_out, w_out = h // self.patch_size, w // self.patch_size
+ x = self.embedding(x)
+ x = rearrange(x, "b c h w -> b (h w) c")
+ x = self.norm(x)
+ return x, (h_out, w_out)
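A quick shape check of the new module against the notebook's 576x640 input (assuming the super().__init__() call added above):

import torch
from text_recognizer.networks.coat.patch_embedding import PatchEmbedding

embed = PatchEmbedding(image_shape=(576, 640), patch_size=16, in_channels=1, embedding_dim=512)
x, (h, w) = embed(torch.randn(2, 1, 576, 640))
print(x.shape, (h, w))  # torch.Size([2, 1440, 512]) (36, 40)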
diff --git a/text_recognizer/networks/coat/positional_encodings.py b/text_recognizer/networks/coat/positional_encodings.py
new file mode 100644
index 0000000..925db04
--- /dev/null
+++ b/text_recognizer/networks/coat/positional_encodings.py
@@ -0,0 +1,76 @@
+"""Positional encodings for input sequence to transformer."""
+from typing import Dict, Union, Tuple
+
+from einops import rearrange
+from loguru import logger
+import torch
+from torch import nn
+from torch import Tensor
+
+
+class RelativeEncoding(nn.Module):
+ """Relative positional encoding."""
+ def __init__(self, channels: int, heads: int, windows: Union[int, Dict[int, int]]) -> None:
+ super().__init__()
+ self.windows = {windows: heads} if isinstance(windows, int) else windows
+ self.heads = list(self.windows.values())
+ self.channel_heads = [head * channels for head in self.heads]
+ self.convs = nn.ModuleList([
+ nn.Conv2d(in_channels=head * channels,
+ out_channels=head * channels,
+ kernel_size=window,
+ padding=window // 2,
+ dilation=1,
+ groups=head * channels,
+ ) for window, head in self.windows.items()])
+
+ def forward(self, q: Tensor, v: Tensor, shape: Tuple[int, int]) -> Tensor:
+ """Applies relative positional encoding."""
+ b, heads, hw, c = q.shape
+ h, w = shape
+ if hw != h * w:
+ logger.exception(f"Query token count {hw} does not match height x width {h * w}")
+ raise ValueError
+
+ v = rearrange(v, "b heads (h w) c -> b (heads c) h w", h=h, w=w)
+ v = torch.split(v, self.channel_heads, dim=1)
+ v = [conv(x) for conv, x in zip(self.convs, v)]
+ v = torch.cat(v, dim=1)
+ v = rearrange(v, "b (heads c) h w -> b heads (h w) c", heads=heads)
+
+ encoding = q * v
+ zeros = torch.zeros((b, heads, 1, c), dtype=q.dtype, layout=q.layout, device=q.device)
+ encoding = torch.cat((zeros, encoding), dim=2)
+ return encoding
+
+
+class PositionalEncoding(nn.Module):
+ """Convolutional positional encoding."""
+ def __init__(self, dim: int, k: int = 3) -> None:
+ super().__init__()
+ self.encode = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=k, stride=1, padding=k//2, groups=dim)
+
+ def forward(self, x: Tensor, shape: Tuple[int, int]) -> Tensor:
+ """Applies convolutional encoding."""
+ _, hw, _ = x.shape
+ h, w = shape
+
+ if hw != h * w:
+ logger.exception(f"Sequence length {hw} does not match height x width {h * w}")
+ raise ValueError
+
+ # Depthwise convolution.
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
+ x = self.encode(x) + x
+ x = rearrange(x, "b c h w -> b (h w) c")
+ return x
+
+
+
+
+
+
+
+
+
+
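Both encodings operate on flattened token sequences and need the spatial shape to fold the tokens back onto the grid. A minimal smoke test of PositionalEncoding with the 36x40 token grid explored in the notebook (RelativeEncoding is exercised the same way, but on per-head query/value tensors):

import torch
from text_recognizer.networks.coat.positional_encodings import PositionalEncoding

cpe = PositionalEncoding(dim=32, k=3)
tokens = torch.randn(2, 36 * 40, 32)  # [B, H*W, C] token sequence.
out = cpe(tokens, shape=(36, 40))     # Depthwise 3x3 conv on the grid, plus residual.
print(out.shape)                      # torch.Size([2, 1440, 32])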
diff --git a/training/configs/vqvae.yaml b/training/configs/vqvae.yaml
index 90082f7..a7acb3a 100644
--- a/training/configs/vqvae.yaml
+++ b/training/configs/vqvae.yaml
@@ -5,16 +5,16 @@ network:
type: VQVAE
args:
in_channels: 1
- channels: [32, 64, 96]
+ channels: [32, 64, 64]
kernel_sizes: [4, 4, 4]
strides: [2, 2, 2]
num_residual_layers: 2
- embedding_dim: 64
- num_embeddings: 1024
+ embedding_dim: 128
+ num_embeddings: 512
upsampling: null
beta: 0.25
activation: leaky_relu
- dropout_rate: 0.1
+ dropout_rate: 0.2
model:
desc: Configuration of the PyTorch Lightning model.
@@ -33,8 +33,8 @@ model:
interval: &interval step
max_lr: 1.0e-3
three_phase: true
- epochs: 512
- steps_per_epoch: 317 # num_samples / batch_size
+ epochs: 64
+ steps_per_epoch: 633 # num_samples / batch_size
criterion:
type: MSELoss
args:
@@ -46,7 +46,7 @@ data:
desc: Configuration of the training/test data.
type: IAMExtendedParagraphs
args:
- batch_size: 64
+ batch_size: 32
num_workers: 12
train_fraction: 0.8
augment: true
@@ -57,33 +57,33 @@ callbacks:
monitor: val_loss
mode: min
save_last: true
- # - type: StochasticWeightAveraging
- # args:
- # swa_epoch_start: 0.8
- # swa_lrs: 0.05
- # annealing_epochs: 10
- # annealing_strategy: cos
- # device: null
+ - type: StochasticWeightAveraging
+ args:
+ swa_epoch_start: 0.8
+ swa_lrs: 0.05
+ annealing_epochs: 10
+ annealing_strategy: cos
+ device: null
- type: LearningRateMonitor
args:
logging_interval: *interval
- - type: EarlyStopping
- args:
- monitor: val_loss
- mode: min
- patience: 10
+ # - type: EarlyStopping
+ # args:
+ # monitor: val_loss
+ # mode: min
+ # patience: 10
trainer:
desc: Configuration of the PyTorch Lightning Trainer.
args:
- stochastic_weight_avg: false # true
+ stochastic_weight_avg: true
auto_scale_batch_size: binsearch
gradient_clip_val: 0
fast_dev_run: false
gpus: 1
precision: 16
- max_epochs: 512
+ max_epochs: 64
terminate_on_nan: true
- weights_summary: full
+ weights_summary: top
load_checkpoint: null
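The steps_per_epoch comment encodes num_samples / batch_size; a quick sanity check in plain Python that the old and new settings describe the same dataset size (roughly 20k training samples):

# Old: 317 steps x batch 64 = 20288 samples; new: 633 steps x batch 32 = 20256.
assert abs(317 * 64 - 633 * 32) < 64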
diff --git a/training/run_experiment.py b/training/run_experiment.py
index e1aae4e..bdefbf0 100644
--- a/training/run_experiment.py
+++ b/training/run_experiment.py
@@ -22,7 +22,7 @@ def _create_experiment_dir(config: DictConfig) -> Path:
"""Creates log directory for experiment."""
log_dir = (
LOGS_DIRNAME
- / f"{config.model.type}_{config.network.type}"
+ / f"{config.model.type}_{config.network.type}".lower()
/ datetime.now().strftime("%m%d_%H%M%S")
)
log_dir.mkdir(parents=True, exist_ok=True)
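The added .lower() keeps experiment directory names case-consistent; for the VQ-VAE config above the effect would be, for example:

# Before: logs/LitVQVAEModel_VQVAE/0424_230920
# After:  logs/litvqvaemodel_vqvae/0424_230920
print(f"{'LitVQVAEModel'}_{'VQVAE'}".lower())  # litvqvaemodel_vqvae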