author     Gustaf Rydholm <gustaf.rydholm@gmail.com>  2021-04-25 23:32:50 +0200
committer  Gustaf Rydholm <gustaf.rydholm@gmail.com>  2021-04-25 23:32:50 +0200
commit     9426cc794d8c28a65bbbf5ae5466a0a343078558 (patch)
tree       44e31b0a7c58597d603ac29a693462aae4b6e9b0
parent     4e60c836fb710baceba570c28c06437db3ad5c9b (diff)
EfficientNet and non-working transformer model.
-rw-r--r--  notebooks/00-testing-stuff-out.ipynb  555
-rw-r--r--  notebooks/03-look-at-iam-paragraphs.ipynb  63
-rw-r--r--  text_recognizer/data/iam_extended_paragraphs.py  12
-rw-r--r--  text_recognizer/data/iam_paragraphs.py  13
-rw-r--r--  text_recognizer/data/iam_synthetic_paragraphs.py  2
-rw-r--r--  text_recognizer/data/mappings.py  12
-rw-r--r--  text_recognizer/data/transforms.py  4
-rw-r--r--  text_recognizer/models/transformer.py  2
-rw-r--r--  text_recognizer/networks/__init__.py  2
-rw-r--r--  text_recognizer/networks/backbones/__init__.py  2
-rw-r--r--  text_recognizer/networks/backbones/efficientnet.py  145
-rw-r--r--  text_recognizer/networks/cnn_transformer.py  37
-rw-r--r--  text_recognizer/networks/coat/__init__.py  0
-rw-r--r--  text_recognizer/networks/coat/factor_attention.py  9
-rw-r--r--  text_recognizer/networks/coat/patch_embedding.py  38
-rw-r--r--  text_recognizer/networks/coat/positional_encodings.py  76
-rw-r--r--  text_recognizer/networks/residual_network.py  6
-rw-r--r--  text_recognizer/networks/transducer/transducer.py  7
-rw-r--r--  text_recognizer/networks/transformer/positional_encoding.py  6
-rw-r--r--  text_recognizer/networks/transformer/rotary_embedding.py  39
-rw-r--r--  text_recognizer/networks/vqvae/decoder.py  18
-rw-r--r--  text_recognizer/networks/vqvae/encoder.py  12
-rw-r--r--  training/configs/cnn_transformer.yaml (renamed from training/configs/image_transformer.yaml)  41
-rw-r--r--  training/configs/vqvae.yaml  10
24 files changed, 474 insertions, 637 deletions
diff --git a/notebooks/00-testing-stuff-out.ipynb b/notebooks/00-testing-stuff-out.ipynb
index e6cf099..7c7b3a6 100644
--- a/notebooks/00-testing-stuff-out.ipynb
+++ b/notebooks/00-testing-stuff-out.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
@@ -25,7 +25,7 @@
},
{
"cell_type": "code",
- "execution_count": 16,
+ "execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@@ -34,7 +34,7 @@
},
{
"cell_type": "code",
- "execution_count": 17,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -43,7 +43,7 @@
},
{
"cell_type": "code",
- "execution_count": 18,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -52,110 +52,16 @@
},
{
"cell_type": "code",
- "execution_count": 19,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "seed: 4711\n",
- "network:\n",
- " desc: Configuration of the PyTorch neural network.\n",
- " type: VQVAE\n",
- " args:\n",
- " in_channels: 1\n",
- " channels:\n",
- " - 32\n",
- " - 64\n",
- " - 96\n",
- " - 96\n",
- " - 128\n",
- " kernel_sizes:\n",
- " - 4\n",
- " - 4\n",
- " - 4\n",
- " - 4\n",
- " - 4\n",
- " strides:\n",
- " - 2\n",
- " - 2\n",
- " - 2\n",
- " - 2\n",
- " - 2\n",
- " num_residual_layers: 2\n",
- " embedding_dim: 128\n",
- " num_embeddings: 1024\n",
- " upsampling: null\n",
- " beta: 0.25\n",
- " activation: leaky_relu\n",
- " dropout_rate: 0.1\n",
- "model:\n",
- " desc: Configuration of the PyTorch Lightning model.\n",
- " type: LitVQVAEModel\n",
- " args:\n",
- " optimizer:\n",
- " type: MADGRAD\n",
- " args:\n",
- " lr: 0.001\n",
- " momentum: 0.9\n",
- " weight_decay: 0\n",
- " eps: 1.0e-06\n",
- " lr_scheduler:\n",
- " type: OneCycleLR\n",
- " args:\n",
- " interval: step\n",
- " max_lr: 0.001\n",
- " three_phase: true\n",
- " epochs: 1024\n",
- " steps_per_epoch: 317\n",
- " criterion:\n",
- " type: MSELoss\n",
- " args:\n",
- " reduction: mean\n",
- " monitor: val_loss\n",
- " mapping: sentence_piece\n",
- "data:\n",
- " desc: Configuration of the training/test data.\n",
- " type: IAMExtendedParagraphs\n",
- " args:\n",
- " batch_size: 64\n",
- " num_workers: 12\n",
- " train_fraction: 0.8\n",
- " augment: true\n",
- "callbacks:\n",
- "- type: ModelCheckpoint\n",
- " args:\n",
- " monitor: val_loss\n",
- " mode: min\n",
- " save_last: true\n",
- "- type: LearningRateMonitor\n",
- " args:\n",
- " logging_interval: step\n",
- "trainer:\n",
- " desc: Configuration of the PyTorch Lightning Trainer.\n",
- " args:\n",
- " stochastic_weight_avg: false\n",
- " auto_scale_batch_size: binsearch\n",
- " gradient_clip_val: 0\n",
- " fast_dev_run: false\n",
- " gpus: 1\n",
- " precision: 16\n",
- " max_epochs: 1024\n",
- " terminate_on_nan: true\n",
- " weights_summary: full\n",
- "load_checkpoint: null\n",
- "\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"print(OmegaConf.to_yaml(conf))"
]
},
{
"cell_type": "code",
- "execution_count": 20,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -164,7 +70,7 @@
},
{
"cell_type": "code",
- "execution_count": 21,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -173,167 +79,16 @@
},
{
"cell_type": "code",
- "execution_count": 22,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "VQVAE(\n",
- " (encoder): Encoder(\n",
- " (encoder): Sequential(\n",
- " (0): Sequential(\n",
- " (0): Conv2d(1, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
- " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n",
- " )\n",
- " (1): Dropout(p=0.1, inplace=False)\n",
- " (2): Sequential(\n",
- " (0): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
- " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n",
- " )\n",
- " (3): Dropout(p=0.1, inplace=False)\n",
- " (4): Sequential(\n",
- " (0): Conv2d(64, 96, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
- " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n",
- " )\n",
- " (5): Dropout(p=0.1, inplace=False)\n",
- " (6): Sequential(\n",
- " (0): Conv2d(96, 96, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
- " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n",
- " )\n",
- " (7): Dropout(p=0.1, inplace=False)\n",
- " (8): Sequential(\n",
- " (0): Conv2d(96, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
- " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n",
- " )\n",
- " (9): Dropout(p=0.1, inplace=False)\n",
- " (10): _ResidualBlock(\n",
- " (block): Sequential(\n",
- " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
- " (1): ReLU(inplace=True)\n",
- " (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
- " (3): Dropout(p=0.1, inplace=False)\n",
- " )\n",
- " )\n",
- " (11): _ResidualBlock(\n",
- " (block): Sequential(\n",
- " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
- " (1): ReLU(inplace=True)\n",
- " (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
- " (3): Dropout(p=0.1, inplace=False)\n",
- " )\n",
- " )\n",
- " (12): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))\n",
- " )\n",
- " (vector_quantizer): VectorQuantizer(\n",
- " (embedding): Embedding(1024, 128)\n",
- " )\n",
- " )\n",
- " (decoder): Decoder(\n",
- " (res_block): Sequential(\n",
- " (0): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))\n",
- " (1): _ResidualBlock(\n",
- " (block): Sequential(\n",
- " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
- " (1): ReLU(inplace=True)\n",
- " (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
- " (3): Dropout(p=0.1, inplace=False)\n",
- " )\n",
- " )\n",
- " (2): _ResidualBlock(\n",
- " (block): Sequential(\n",
- " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
- " (1): ReLU(inplace=True)\n",
- " (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
- " (3): Dropout(p=0.1, inplace=False)\n",
- " )\n",
- " )\n",
- " )\n",
- " (upsampling_block): Sequential(\n",
- " (0): Sequential(\n",
- " (0): ConvTranspose2d(128, 96, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
- " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n",
- " )\n",
- " (1): Dropout(p=0.1, inplace=False)\n",
- " (2): Sequential(\n",
- " (0): ConvTranspose2d(96, 96, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
- " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n",
- " )\n",
- " (3): Dropout(p=0.1, inplace=False)\n",
- " (4): Sequential(\n",
- " (0): ConvTranspose2d(96, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
- " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n",
- " )\n",
- " (5): Dropout(p=0.1, inplace=False)\n",
- " (6): Sequential(\n",
- " (0): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
- " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n",
- " )\n",
- " (7): Dropout(p=0.1, inplace=False)\n",
- " (8): ConvTranspose2d(32, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
- " (9): Tanh()\n",
- " )\n",
- " (decoder): Sequential(\n",
- " (0): Sequential(\n",
- " (0): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))\n",
- " (1): _ResidualBlock(\n",
- " (block): Sequential(\n",
- " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
- " (1): ReLU(inplace=True)\n",
- " (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
- " (3): Dropout(p=0.1, inplace=False)\n",
- " )\n",
- " )\n",
- " (2): _ResidualBlock(\n",
- " (block): Sequential(\n",
- " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
- " (1): ReLU(inplace=True)\n",
- " (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
- " (3): Dropout(p=0.1, inplace=False)\n",
- " )\n",
- " )\n",
- " )\n",
- " (1): Sequential(\n",
- " (0): Sequential(\n",
- " (0): ConvTranspose2d(128, 96, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
- " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n",
- " )\n",
- " (1): Dropout(p=0.1, inplace=False)\n",
- " (2): Sequential(\n",
- " (0): ConvTranspose2d(96, 96, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
- " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n",
- " )\n",
- " (3): Dropout(p=0.1, inplace=False)\n",
- " (4): Sequential(\n",
- " (0): ConvTranspose2d(96, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
- " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n",
- " )\n",
- " (5): Dropout(p=0.1, inplace=False)\n",
- " (6): Sequential(\n",
- " (0): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
- " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n",
- " )\n",
- " (7): Dropout(p=0.1, inplace=False)\n",
- " (8): ConvTranspose2d(32, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
- " (9): Tanh()\n",
- " )\n",
- " )\n",
- " )\n",
- ")"
- ]
- },
- "execution_count": 22,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"vae"
]
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -342,275 +97,259 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
- "proj = nn.Conv2d(1, 32, kernel_size=16, stride=16)"
+ "vae.encoder(datum)[0].shape"
]
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
- "x = proj(datum)"
+ "vae(datum)[0].shape"
]
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "torch.Size([2, 32, 36, 40])"
- ]
- },
- "execution_count": 8,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
- "x.shape"
+ "from text_recognizer.networks.backbones.efficientnet import EfficientNet"
]
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
- "xx = x.flatten(2)"
+ "en = EfficientNet()"
]
},
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": 3,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "torch.Size([2, 32, 1440])"
- ]
- },
- "execution_count": 10,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
- "xx.shape"
+ "datum = torch.randn([2, 1, 576, 640])"
]
},
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
- "xxx = xx.transpose(1,2)"
+ "trg = torch.randint(0, 1000, [2, 682])"
]
},
{
"cell_type": "code",
- "execution_count": 12,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "torch.Size([2, 1440, 32])"
- ]
- },
- "execution_count": 12,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
- "xxx.shape"
+ "trg.shape"
]
},
{
"cell_type": "code",
- "execution_count": 13,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
- "from einops import rearrange"
+ "datum = torch.randn([2, 1, 224, 224])"
]
},
{
"cell_type": "code",
- "execution_count": 14,
- "metadata": {},
+ "execution_count": null,
+ "metadata": {
+ "scrolled": false
+ },
"outputs": [],
"source": [
- "xxxx = rearrange(x, \"b c h w -> b ( h w ) c\")"
+ "en(datum).shape"
]
},
{
"cell_type": "code",
- "execution_count": 15,
+ "execution_count": 5,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "torch.Size([2, 1440, 32])"
- ]
- },
- "execution_count": 15,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
- "xxxx.shape"
+ "path = \"../training/configs/cnn_transformer.yaml\""
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
- " B, N, C = x.shape\n",
- " H, W = size\n",
- " assert N == 1 + H * W\n",
- "\n",
- " # Extract CLS token and image tokens.\n",
- " cls_token, img_tokens = x[:, :1], x[:, 1:] # Shape: [B, 1, C], [B, H*W, C].\n",
- " \n",
- " # Depthwise convolution.\n",
- " feat = img_tokens.transpose(1, 2).view(B, C, H, W)"
+ "conf = OmegaConf.load(path)"
]
},
{
"cell_type": "code",
- "execution_count": 22,
+ "execution_count": 7,
"metadata": {},
"outputs": [
{
- "data": {
- "text/plain": [
- "torch.Size([2, 32, 36, 40])"
- ]
- },
- "execution_count": 22,
- "metadata": {},
- "output_type": "execute_result"
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "seed: 4711\n",
+ "network:\n",
+ " desc: Configuration of the PyTorch neural network.\n",
+ " type: CNNTransformer\n",
+ " args:\n",
+ " encoder:\n",
+ " type: EfficientNet\n",
+ " args: null\n",
+ " num_decoder_layers: 4\n",
+ " hidden_dim: 256\n",
+ " num_heads: 4\n",
+ " expansion_dim: 1024\n",
+ " dropout_rate: 0.1\n",
+ " transformer_activation: glu\n",
+ "model:\n",
+ " desc: Configuration of the PyTorch Lightning model.\n",
+ " type: LitTransformerModel\n",
+ " args:\n",
+ " optimizer:\n",
+ " type: MADGRAD\n",
+ " args:\n",
+ " lr: 0.001\n",
+ " momentum: 0.9\n",
+ " weight_decay: 0\n",
+ " eps: 1.0e-06\n",
+ " lr_scheduler:\n",
+ " type: OneCycleLR\n",
+ " args:\n",
+ " interval: step\n",
+ " max_lr: 0.001\n",
+ " three_phase: true\n",
+ " epochs: 512\n",
+ " steps_per_epoch: 1246\n",
+ " criterion:\n",
+ " type: CrossEntropyLoss\n",
+ " args:\n",
+ " weight: None\n",
+ " ignore_index: -100\n",
+ " reduction: mean\n",
+ " monitor: val_loss\n",
+ " mapping: sentence_piece\n",
+ "data:\n",
+ " desc: Configuration of the training/test data.\n",
+ " type: IAMExtendedParagraphs\n",
+ " args:\n",
+ " batch_size: 16\n",
+ " num_workers: 12\n",
+ " train_fraction: 0.8\n",
+ " augment: true\n",
+ "callbacks:\n",
+ "- type: ModelCheckpoint\n",
+ " args:\n",
+ " monitor: val_loss\n",
+ " mode: min\n",
+ " save_last: true\n",
+ "- type: StochasticWeightAveraging\n",
+ " args:\n",
+ " swa_epoch_start: 0.8\n",
+ " swa_lrs: 0.05\n",
+ " annealing_epochs: 10\n",
+ " annealing_strategy: cos\n",
+ " device: null\n",
+ "- type: LearningRateMonitor\n",
+ " args:\n",
+ " logging_interval: step\n",
+ "- type: EarlyStopping\n",
+ " args:\n",
+ " monitor: val_loss\n",
+ " mode: min\n",
+ " patience: 10\n",
+ "trainer:\n",
+ " desc: Configuration of the PyTorch Lightning Trainer.\n",
+ " args:\n",
+ " stochastic_weight_avg: true\n",
+ " auto_scale_batch_size: binsearch\n",
+ " gradient_clip_val: 0\n",
+ " fast_dev_run: false\n",
+ " gpus: 1\n",
+ " precision: 16\n",
+ " max_epochs: 512\n",
+ " terminate_on_nan: true\n",
+ " weights_summary: true\n",
+ "load_checkpoint: null\n",
+ "\n"
+ ]
}
],
"source": [
- "xxx.transpose(1, 2).view(2, 32, 36, 40).shape"
+ "print(OmegaConf.to_yaml(conf))"
]
},
{
"cell_type": "code",
- "execution_count": 18,
- "metadata": {
- "scrolled": true
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "72.0"
- ]
- },
- "execution_count": 18,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [],
"source": [
- "576 / 8"
+ "from text_recognizer.networks.cnn_transformer import CNNTransformer"
]
},
{
"cell_type": "code",
- "execution_count": 19,
+ "execution_count": 9,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "80.0"
- ]
- },
- "execution_count": 19,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
- "640 / 8"
+ "t = CNNTransformer(input_shape=(1, 576, 640), output_shape=(682, 1), **conf.network.args)"
]
},
{
"cell_type": "code",
- "execution_count": 26,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "torch.Size([2, 1, 576, 640])"
- ]
- },
- "execution_count": 26,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
- "datum.shape"
+ "t.encode(datum).shape"
]
},
{
"cell_type": "code",
- "execution_count": 27,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "torch.Size([2, 128, 18, 20])"
- ]
- },
- "execution_count": 27,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
- "vae.encoder(datum)[0].shape"
+ "trg.shape"
]
},
{
"cell_type": "code",
- "execution_count": 87,
+ "execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
- "torch.Size([2, 1, 576, 640])"
+ "torch.Size([2, 682, 1004])"
]
},
- "execution_count": 87,
+ "execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
- "vae(datum)[0].shape"
+ "t(datum, trg).shape"
]
},
{
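For readers skimming the diff, the new notebook cells above amount to roughly the script below. This is a sketch only: it uses the renamed input_dim/output_dims constructor arguments introduced in this commit's cnn_transformer.py, whereas the notebook cell itself still passes input_shape/output_shape, and the commit message flags the transformer as not yet working, so treat it as intent rather than a verified run.

import torch
from omegaconf import OmegaConf
from text_recognizer.networks.cnn_transformer import CNNTransformer

datum = torch.randn([2, 1, 576, 640])   # dummy IAM paragraph batch, as in the notebook
trg = torch.randint(0, 1000, [2, 682])  # dummy target token sequence

conf = OmegaConf.load("training/configs/cnn_transformer.yaml")  # path relative to the repo root
net = CNNTransformer(input_dim=(1, 576, 640), output_dims=(682, 1), **conf.network.args)
logits = net(datum, trg)  # the notebook records torch.Size([2, 682, 1004]) for its version of this call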
diff --git a/notebooks/03-look-at-iam-paragraphs.ipynb b/notebooks/03-look-at-iam-paragraphs.ipynb
index eaf5397..add0b80 100644
--- a/notebooks/03-look-at-iam-paragraphs.ipynb
+++ b/notebooks/03-look-at-iam-paragraphs.ipynb
@@ -31,7 +31,7 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 2,
"id": "726ac25b",
"metadata": {},
"outputs": [],
@@ -56,8 +56,8 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "2021-04-16 23:01:52.352 | INFO | text_recognizer.data.iam_paragraphs:setup:107 - Loading IAM paragraph regions and lines for None...\n",
- "2021-04-16 23:02:08.521 | INFO | text_recognizer.data.iam_synthetic_paragraphs:setup:79 - IAM Synthetic dataset steup for stage None\n"
+ "2021-04-25 23:17:44.177 | INFO | text_recognizer.data.iam_paragraphs:setup:107 - Loading IAM paragraph regions and lines for None...\n",
+ "2021-04-25 23:18:00.750 | INFO | text_recognizer.data.iam_synthetic_paragraphs:setup:79 - IAM Synthetic dataset steup for stage None\n"
]
},
{
@@ -68,9 +68,9 @@
"Num classes: 84\n",
"Dims: (1, 576, 640)\n",
"Output dims: (682, 1)\n",
- "Train/val/test sizes: 19912, 262, 231\n",
- "Train Batch x stats: (torch.Size([1, 1, 576, 640]), torch.float32, tensor(0.), tensor(0.0043), tensor(0.0333), tensor(0.8588))\n",
- "Train Batch y stats: (torch.Size([1, 682]), torch.int64, tensor(1), tensor(78))\n",
+ "Train/val/test sizes: 19948, 262, 231\n",
+ "Train Batch x stats: (torch.Size([1, 1, 576, 640]), torch.float32, tensor(0.), tensor(0.0109), tensor(0.0499), tensor(0.8314))\n",
+ "Train Batch y stats: (torch.Size([1, 682]), torch.int64, tensor(1), tensor(83))\n",
"Test Batch x stats: (torch.Size([1, 1, 576, 640]), torch.float32, tensor(0.), tensor(0.0372), tensor(0.0767), tensor(0.8118))\n",
"Test Batch y stats: (torch.Size([1, 682]), torch.int64, tensor(1), tensor(83))\n",
"\n"
@@ -86,10 +86,34 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 4,
"id": "42501428",
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2021-04-25 23:18:14.449 | INFO | text_recognizer.data.iam_paragraphs:setup:107 - Loading IAM paragraph regions and lines for None...\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "IAM Paragraphs Dataset\n",
+ "Num classes: 84\n",
+ "Input dims: (1, 576, 640)\n",
+ "Output dims: (682, 1)\n",
+ "Train/val/test sizes: 1046, 262, 231\n",
+ "Train Batch x stats: (torch.Size([16, 1, 576, 640]), torch.float32, tensor(0.), tensor(0.0393), tensor(0.0924), tensor(1.))\n",
+ "Train Batch y stats: (torch.Size([16, 682]), torch.int64, tensor(1), tensor(83))\n",
+ "Test Batch x stats: (torch.Size([16, 1, 576, 640]), torch.float32, tensor(0.), tensor(0.0312), tensor(0.0817), tensor(0.9294))\n",
+ "Test Batch y stats: (torch.Size([16, 682]), torch.int64, tensor(1), tensor(83))\n",
+ "\n"
+ ]
+ }
+ ],
"source": [
"dataset = IAMParagraphs()\n",
"dataset.prepare_data()\n",
@@ -99,7 +123,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 5,
"id": "0cf22683",
"metadata": {},
"outputs": [],
@@ -109,6 +133,27 @@
},
{
"cell_type": "code",
+ "execution_count": 6,
+ "id": "af7747a8",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([682])"
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "y.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
"execution_count": 7,
"id": "e7778ae2",
"metadata": {
diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py
index 2380660..0a30a42 100644
--- a/text_recognizer/data/iam_extended_paragraphs.py
+++ b/text_recognizer/data/iam_extended_paragraphs.py
@@ -19,18 +19,10 @@ class IAMExtendedParagraphs(BaseDataModule):
super().__init__(batch_size, num_workers)
self.iam_paragraphs = IAMParagraphs(
- batch_size,
- num_workers,
- train_fraction,
- augment,
- word_pieces,
+ batch_size, num_workers, train_fraction, augment, word_pieces,
)
self.iam_synthetic_paragraphs = IAMSyntheticParagraphs(
- batch_size,
- num_workers,
- train_fraction,
- augment,
- word_pieces,
+ batch_size, num_workers, train_fraction, augment, word_pieces,
)
self.dims = self.iam_paragraphs.dims
diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py
index 62c44f9..24409bc 100644
--- a/text_recognizer/data/iam_paragraphs.py
+++ b/text_recognizer/data/iam_paragraphs.py
@@ -101,7 +101,7 @@ class IAMParagraphs(BaseDataModule):
data,
targets,
transform=get_transform(image_shape=self.dims[1:], augment=augment),
- target_transform=get_target_transform(self.word_pieces)
+ target_transform=get_target_transform(self.word_pieces),
)
logger.info(f"Loading IAM paragraph regions and lines for {stage}...")
@@ -162,10 +162,7 @@ def get_dataset_properties() -> Dict:
"min": min(_get_property_values("num_lines")),
"max": max(_get_property_values("num_lines")),
},
- "crop_shape": {
- "min": crop_shapes.min(axis=0),
- "max": crop_shapes.max(axis=0),
- },
+ "crop_shape": {"min": crop_shapes.min(axis=0), "max": crop_shapes.max(axis=0),},
"aspect_ratio": {
"min": aspect_ratio.min(axis=0),
"max": aspect_ratio.max(axis=0),
@@ -286,9 +283,7 @@ def get_transform(image_shape: Tuple[int, int], augment: bool) -> transforms.Com
),
transforms.ColorJitter(brightness=(0.8, 1.6)),
transforms.RandomAffine(
- degrees=1,
- shear=(-10, 10),
- interpolation=InterpolationMode.BILINEAR,
+ degrees=1, shear=(-10, 10), interpolation=InterpolationMode.BILINEAR,
),
]
else:
@@ -296,10 +291,12 @@ def get_transform(image_shape: Tuple[int, int], augment: bool) -> transforms.Com
transforms_list.append(transforms.ToTensor())
return transforms.Compose(transforms_list)
+
def get_target_transform(word_pieces: bool) -> Optional[transforms.Compose]:
"""Transform emnist characters to word pieces."""
return transforms.Compose([WordPiece()]) if word_pieces else None
+
def _labels_filename(split: str) -> Path:
"""Return filename of processed labels."""
return PROCESSED_DATA_DIRNAME / split / "_labels.json"
diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py
index 4ccc5c2..78e6c05 100644
--- a/text_recognizer/data/iam_synthetic_paragraphs.py
+++ b/text_recognizer/data/iam_synthetic_paragraphs.py
@@ -97,7 +97,7 @@ class IAMSyntheticParagraphs(IAMParagraphs):
transform=get_transform(
image_shape=self.dims[1:], augment=self.augment
),
- target_transform=get_target_transform(self.word_pieces)
+ target_transform=get_target_transform(self.word_pieces),
)
def __repr__(self) -> str:
diff --git a/text_recognizer/data/mappings.py b/text_recognizer/data/mappings.py
index f4016ba..190febe 100644
--- a/text_recognizer/data/mappings.py
+++ b/text_recognizer/data/mappings.py
@@ -58,13 +58,13 @@ class WordPieceMapping(EmnistMapping):
def __init__(
self,
num_features: int = 1000,
- tokens: str = "iamdb_1kwp_tokens_1000.txt" ,
+ tokens: str = "iamdb_1kwp_tokens_1000.txt",
lexicon: str = "iamdb_1kwp_lex_1000.txt",
data_dir: Optional[Union[str, Path]] = None,
use_words: bool = False,
prepend_wordsep: bool = False,
special_tokens: Sequence[str] = ("<s>", "<e>", "<p>"),
- extra_symbols: Optional[Sequence[str]] = ("\n", ),
+ extra_symbols: Optional[Sequence[str]] = ("\n",),
) -> None:
super().__init__(extra_symbols)
self.wordpiece_processor = self._configure_wordpiece_processor(
@@ -90,7 +90,13 @@ class WordPieceMapping(EmnistMapping):
extra_symbols: Optional[Sequence[str]],
) -> Preprocessor:
data_dir = (
- (Path(__file__).resolve().parents[2] / "data" / "downloaded" / "iam" / "iamdb")
+ (
+ Path(__file__).resolve().parents[2]
+ / "data"
+ / "downloaded"
+ / "iam"
+ / "iamdb"
+ )
if data_dir is None
else Path(data_dir)
)
diff --git a/text_recognizer/data/transforms.py b/text_recognizer/data/transforms.py
index 8d1bedd..d0f1f35 100644
--- a/text_recognizer/data/transforms.py
+++ b/text_recognizer/data/transforms.py
@@ -13,7 +13,7 @@ class WordPiece:
def __init__(
self,
num_features: int = 1000,
- tokens: str = "iamdb_1kwp_tokens_1000.txt" ,
+ tokens: str = "iamdb_1kwp_tokens_1000.txt",
lexicon: str = "iamdb_1kwp_lex_1000.txt",
data_dir: Optional[Union[str, Path]] = None,
use_words: bool = False,
@@ -35,4 +35,4 @@ class WordPiece:
self.max_len = max_len
def __call__(self, x: Tensor) -> Tensor:
- return self.mapping.emnist_to_wordpiece_indices(x)[:self.max_len]
+ return self.mapping.emnist_to_wordpiece_indices(x)[: self.max_len]
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index 7dc1352..8dd4db2 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -39,7 +39,7 @@ class LitTransformerModel(LitBaseModel):
def configure_mapping(mapping: Optional[List[str]]) -> Tuple[List[str], List[int]]:
"""Configure mapping."""
# TODO: Fix me!!!
- mapping, inverse_mapping, _ = emnist_mapping()
+ mapping, inverse_mapping, _ = emnist_mapping(["\n"])
start_index = inverse_mapping["<s>"]
end_index = inverse_mapping["<e>"]
pad_index = inverse_mapping["<p>"]
diff --git a/text_recognizer/networks/__init__.py b/text_recognizer/networks/__init__.py
index 41fd43f..63b43b2 100644
--- a/text_recognizer/networks/__init__.py
+++ b/text_recognizer/networks/__init__.py
@@ -1,2 +1,4 @@
"""Network modules"""
+from .backbones import EfficientNet
from .vqvae import VQVAE
+from .cnn_transformer import CNNTransformer
diff --git a/text_recognizer/networks/backbones/__init__.py b/text_recognizer/networks/backbones/__init__.py
new file mode 100644
index 0000000..25aed0e
--- /dev/null
+++ b/text_recognizer/networks/backbones/__init__.py
@@ -0,0 +1,2 @@
+"""Vision backbones."""
+from .efficientnet import EfficientNet
diff --git a/text_recognizer/networks/backbones/efficientnet.py b/text_recognizer/networks/backbones/efficientnet.py
new file mode 100644
index 0000000..61dea77
--- /dev/null
+++ b/text_recognizer/networks/backbones/efficientnet.py
@@ -0,0 +1,145 @@
+"""Efficient net b0 implementation."""
+import torch
+from torch import nn
+from torch import Tensor
+
+
+class ConvNorm(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int,
+ stride: int,
+ padding: int,
+ groups: int = 1,
+ ) -> None:
+ super().__init__()
+ self.block = nn.Sequential(
+ nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ groups=groups,
+ bias=False,
+ ),
+ nn.BatchNorm2d(num_features=out_channels),
+ nn.SiLU(inplace=True),
+ )
+
+ def forward(self, x: Tensor) -> Tensor:
+ return self.block(x)
+
+
+class SqueezeExcite(nn.Module):
+ def __init__(self, in_channels: int, reduce_dim: int) -> None:
+ super().__init__()
+ self.se = nn.Sequential(
+ nn.AdaptiveAvgPool2d(1), # [C, H, W] -> [C, 1, 1]
+ nn.Conv2d(in_channels=in_channels, out_channels=reduce_dim, kernel_size=1),
+ nn.SiLU(),
+ nn.Conv2d(in_channels=reduce_dim, out_channels=in_channels, kernel_size=1),
+ nn.Sigmoid(),
+ )
+
+ def forward(self, x: Tensor) -> Tensor:
+ return x * self.se(x)
+
+
+class InvertedResidulaBlock(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int,
+ stride: int,
+ padding: int,
+ expand_ratio: float,
+ reduction: int = 4,
+ survival_prob: float = 0.8,
+ ) -> None:
+ super().__init__()
+ self.survival_prob = survival_prob
+ self.use_residual = in_channels == out_channels and stride == 1
+ hidden_dim = in_channels * expand_ratio
+ self.expand = in_channels != hidden_dim
+ reduce_dim = in_channels // reduction
+
+ if self.expand:
+ self.expand_conv = ConvNorm(
+ in_channels, hidden_dim, kernel_size=3, stride=1, padding=1
+ )
+
+ self.conv = nn.Sequential(
+ ConvNorm(
+ hidden_dim, hidden_dim, kernel_size, stride, padding, groups=hidden_dim
+ ),
+ SqueezeExcite(hidden_dim, reduce_dim),
+ nn.Conv2d(
+ in_channels=hidden_dim,
+ out_channels=out_channels,
+ kernel_size=1,
+ bias=False,
+ ),
+ nn.BatchNorm2d(num_features=out_channels),
+ )
+
+ def stochastic_depth(self, x: Tensor) -> Tensor:
+ if not self.training:
+ return x
+
+ binary_tensor = (
+ torch.rand(x.shape[0], 1, 1, 1, device=x.device) < self.survival_prob
+ )
+ return torch.div(x, self.survival_prob) * binary_tensor
+
+ def forward(self, x: Tensor) -> Tensor:
+ out = self.expand_conv(x) if self.expand else x
+ if self.use_residual:
+ return self.stochastic_depth(self.conv(out)) + x
+ return self.conv(out)
+
+
+class EfficientNet(nn.Module):
+ """Efficient net b0 backbone."""
+
+ def __init__(self) -> None:
+ super().__init__()
+ self.base_model = [
+ # expand_ratio, channels, repeats, stride, kernel_size
+ [1, 16, 1, 1, 3],
+ [6, 24, 2, 2, 3],
+ [6, 40, 2, 2, 5],
+ [6, 80, 3, 2, 3],
+ [6, 112, 3, 1, 5],
+ [6, 192, 4, 2, 5],
+ [6, 320, 1, 1, 3],
+ ]
+
+ self.backbone = self._build_b0()
+
+ def _build_b0(self) -> nn.Sequential:
+ in_channels = 32
+ layers = [ConvNorm(1, in_channels, 3, stride=2, padding=1)]
+
+ for expand_ratio, out_channels, repeats, stride, kernel_size in self.base_model:
+ for i in range(repeats):
+ layers.append(
+ InvertedResidulaBlock(
+ in_channels,
+ out_channels,
+ expand_ratio=expand_ratio,
+ stride=stride if i == 0 else 1,
+ kernel_size=kernel_size,
+ padding=kernel_size // 2,
+ )
+ )
+ in_channels = out_channels
+ layers.append(ConvNorm(in_channels, 256, kernel_size=1, stride=1, padding=0))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x: Tensor) -> Tensor:
+ return self.backbone(x)
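A quick way to sanity-check the new backbone's output geometry (a sketch, assuming the import path added above and the 576x640 grayscale inputs used throughout this repo): the stem plus four of the block groups use stride 2, so the spatial size shrinks by a factor of 32, and the final 1x1 ConvNorm maps to 256 channels.

import torch
from text_recognizer.networks.backbones.efficientnet import EfficientNet

backbone = EfficientNet()
x = torch.randn(2, 1, 576, 640)  # [B, C, H, W] grayscale paragraphs
features = backbone(x)
# 576 / 32 = 18 and 640 / 32 = 20, so the feature map is [2, 256, 18, 20].
print(features.shape)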
diff --git a/text_recognizer/networks/cnn_transformer.py b/text_recognizer/networks/cnn_transformer.py
index e23a15d..d42c29d 100644
--- a/text_recognizer/networks/cnn_transformer.py
+++ b/text_recognizer/networks/cnn_transformer.py
@@ -33,8 +33,8 @@ NUM_WORD_PIECES = 1000
class CNNTransformer(nn.Module):
def __init__(
self,
- input_shape: Sequence[int],
- output_shape: Sequence[int],
+ input_dim: Sequence[int],
+ output_dims: Sequence[int],
encoder: Union[DictConfig, Dict],
vocab_size: Optional[int] = None,
num_decoder_layers: int = 4,
@@ -43,22 +43,29 @@ class CNNTransformer(nn.Module):
expansion_dim: int = 1024,
dropout_rate: float = 0.1,
transformer_activation: str = "glu",
+ *args,
+ **kwargs,
) -> None:
+ super().__init__()
self.vocab_size = (
NUM_WORD_PIECES + NUM_SPECIAL_TOKENS if vocab_size is None else vocab_size
)
+ self.pad_index = 3 # TODO: fix me
self.hidden_dim = hidden_dim
- self.max_output_length = output_shape[0]
+ self.max_output_length = output_dims[0]
# Image backbone
self.encoder = self._configure_encoder(encoder)
+ self.encoder_proj = nn.Conv2d(256, hidden_dim, kernel_size=1)
self.feature_map_encoding = PositionalEncoding2D(
- hidden_dim=hidden_dim, max_h=input_shape[1], max_w=input_shape[2]
+ hidden_dim=hidden_dim, max_h=input_dim[1], max_w=input_dim[2]
)
# Target token embedding
self.trg_embedding = nn.Embedding(self.vocab_size, hidden_dim)
- self.trg_position_encoding = PositionalEncoding(hidden_dim, dropout_rate)
+ self.trg_position_encoding = PositionalEncoding(
+ hidden_dim, dropout_rate, max_len=output_dims[0]
+ )
# Transformer decoder
self.decoder = Decoder(
@@ -86,24 +93,25 @@ class CNNTransformer(nn.Module):
self.head.weight.data.uniform_(-0.1, 0.1)
nn.init.kaiming_normal_(
- self.feature_map_encoding.weight.data,
+ self.encoder_proj.weight.data,
a=0,
mode="fan_out",
nonlinearity="relu",
)
- if self.feature_map_encoding.bias is not None:
+ if self.encoder_proj.bias is not None:
_, fan_out = nn.init._calculate_fan_in_and_fan_out(
- self.feature_map_encoding.weight.data
+ self.encoder_proj.weight.data
)
bound = 1 / math.sqrt(fan_out)
- nn.init.normal_(self.feature_map_encoding.bias, -bound, bound)
+ nn.init.normal_(self.encoder_proj.bias, -bound, bound)
@staticmethod
def _configure_encoder(encoder: Union[DictConfig, Dict]) -> Type[nn.Module]:
encoder = OmegaConf.create(encoder)
+ args = encoder.args or {}
network_module = importlib.import_module("text_recognizer.networks")
encoder_class = getattr(network_module, encoder.type)
- return encoder_class(**encoder.args)
+ return encoder_class(**args)
def encode(self, image: Tensor) -> Tensor:
"""Extracts image features with backbone.
@@ -121,6 +129,7 @@ class CNNTransformer(nn.Module):
"""
# Extract image features.
image_features = self.encoder(image)
+ image_features = self.encoder_proj(image_features)
# Add 2d encoding to the feature maps.
image_features = self.feature_map_encoding(image_features)
@@ -133,11 +142,19 @@ class CNNTransformer(nn.Module):
"""Decodes image features with transformer decoder."""
trg_mask = target_padding_mask(trg=trg, pad_index=self.pad_index)
trg = self.trg_embedding(trg) * math.sqrt(self.hidden_dim)
+ trg = rearrange(trg, "b t d -> t b d")
trg = self.trg_position_encoding(trg)
+ trg = rearrange(trg, "t b d -> b t d")
out = self.decoder(trg=trg, memory=memory, trg_mask=trg_mask, memory_mask=None)
logits = self.head(out)
return logits
+ def forward(self, image: Tensor, trg: Tensor) -> Tensor:
+ image_features = self.encode(image)
+ output = self.decode(image_features, trg)
+ output = rearrange(output, "b t c -> b c t")
+ return output
+
def predict(self, image: Tensor) -> Tensor:
"""Transcribes text in image(s)."""
bsz = image.shape[0]
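The main structural change above is the new encoder_proj: the EfficientNet backbone ends in 256 channels, while the decoder and positional encodings work in hidden_dim, so a 1x1 convolution bridges the two before the 2D positional encoding is added. A minimal illustration of that bridging step follows (not the repo's exact encode(), whose flattening of the feature map into decoder memory lies outside this hunk and is assumed here).

import torch
from torch import nn

hidden_dim = 256                        # value used in cnn_transformer.yaml
features = torch.randn(2, 256, 18, 20)  # EfficientNet output for a 576x640 input

encoder_proj = nn.Conv2d(256, hidden_dim, kernel_size=1)
memory = encoder_proj(features)         # [2, hidden_dim, 18, 20]
memory = memory.flatten(start_dim=2)    # [2, hidden_dim, 360]
memory = memory.permute(0, 2, 1)        # [2, 360, hidden_dim], token sequence for the decoder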
diff --git a/text_recognizer/networks/coat/__init__.py b/text_recognizer/networks/coat/__init__.py
deleted file mode 100644
index e69de29..0000000
--- a/text_recognizer/networks/coat/__init__.py
+++ /dev/null
diff --git a/text_recognizer/networks/coat/factor_attention.py b/text_recognizer/networks/coat/factor_attention.py
deleted file mode 100644
index f91c5ef..0000000
--- a/text_recognizer/networks/coat/factor_attention.py
+++ /dev/null
@@ -1,9 +0,0 @@
-"""Factorized attention with convolutional relative positional encodings."""
-from torch import nn
-
-
-class FactorAttention(nn.Module):
- """Factorized attention with relative positional encodings."""
- def __init__(self, dim: int, num_heads: int) -> None:
- pass
-
diff --git a/text_recognizer/networks/coat/patch_embedding.py b/text_recognizer/networks/coat/patch_embedding.py
deleted file mode 100644
index 3b7b76a..0000000
--- a/text_recognizer/networks/coat/patch_embedding.py
+++ /dev/null
@@ -1,38 +0,0 @@
-"""Patch embedding for images and feature maps."""
-from typing import Sequence, Tuple
-
-from einops import rearrange
-from loguru import logger
-from torch import nn
-from torch import Tensor
-
-
-class PatchEmbedding(nn.Module):
- """Patch embedding of images."""
-
- def __init__(
- self,
- image_shape: Sequence[int],
- patch_size: int = 16,
- in_channels: int = 1,
- embedding_dim: int = 512,
- ) -> None:
- if image_shape[0] % patch_size == 0 and image_shape[1] % patch_size == 0:
- logger.error(
- f"Image shape {image_shape} not divisable by patch size {patch_size}"
- )
-
- self.patch_size = patch_size
- self.embedding = nn.Conv2d(
- in_channels, embedding_dim, kernel_size=patch_size, stride=patch_size
- )
- self.norm = nn.LayerNorm(embedding_dim)
-
- def forward(self, x: Tensor) -> Tuple[Tensor, Tuple[int, int]]:
- """Embeds image or feature maps with patch embedding."""
- _, _, h, w = x.shape
- h_out, w_out = h // self.patch_size, w // self.patch_size
- x = self.embedding(x)
- x = rearrange(x, "b c h w -> b (h w) c")
- x = self.norm(x)
- return x, (h_out, w_out)
diff --git a/text_recognizer/networks/coat/positional_encodings.py b/text_recognizer/networks/coat/positional_encodings.py
deleted file mode 100644
index 925db04..0000000
--- a/text_recognizer/networks/coat/positional_encodings.py
+++ /dev/null
@@ -1,76 +0,0 @@
-"""Positional encodings for input sequence to transformer."""
-from typing import Dict, Union, Tuple
-
-from einops import rearrange
-from loguru import logger
-import torch
-from torch import nn
-from torch import Tensor
-
-
-class RelativeEncoding(nn.Module):
- """Relative positional encoding."""
- def __init__(self, channels: int, heads: int, windows: Union[int, Dict[int, int]]) -> None:
- super().__init__()
- self.windows = {windows: heads} if isinstance(windows, int) else windows
- self.heads = list(self.windows.values())
- self.channel_heads = [head * channels for head in self.heads]
- self.convs = nn.ModuleList([
- nn.Conv2d(in_channels=head * channels,
- out_channels=head * channels,
- kernel_shape=window,
- padding=window // 2,
- dilation=1,
- groups=head * channels,
- ) for window, head in self.windows.items()])
-
- def forward(self, q: Tensor, v: Tensor, shape: Tuple[int, int]) -> Tensor:
- """Applies relative positional encoding."""
- b, heads, hw, c = q.shape
- h, w = shape
- if hw != h * w:
- logger.exception(f"Query width {hw} neq to height x width {h * w}")
- raise ValueError
-
- v = rearrange(v, "b heads (h w) c -> b (heads c) h w", h=h, w=w)
- v = torch.split(v, self.channel_heads, dim=1)
- v = [conv(x) for conv, x in zip(self.convs, v)]
- v = torch.cat(v, dim=1)
- v = rearrange(v, "b (heads c) h w -> b heads (h w) c", heads=heads)
-
- encoding = q * v
- zeros = torch.zeros((b, heads, 1, c), dtype=q.dtype, layout=q.layout, device=q.device)
- encoding = torch.cat((zeros, encoding), dim=2)
- return encoding
-
-
-class PositionalEncoding(nn.Module):
- """Convolutional positional encoding."""
- def __init__(self, dim: int, k: int = 3) -> None:
- super().__init__()
- self.encode = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=k, stride=1, padding=k//2, groups=dim)
-
- def forward(self, x: Tensor, shape: Tuple[int, int]) -> Tensor:
- """Applies convolutional encoding."""
- _, hw, _ = x.shape
- h, w = shape
-
- if hw != h * w:
- logger.exception(f"Query width {hw} neq to height x width {h * w}")
- raise ValueError
-
- # Depthwise convolution.
- x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
- x = self.encode(x) + x
- x = rearrange(x, "b c h w -> b (h w) c")
- return x
-
-
-
-
-
-
-
-
-
-
diff --git a/text_recognizer/networks/residual_network.py b/text_recognizer/networks/residual_network.py
index da7553d..c33f419 100644
--- a/text_recognizer/networks/residual_network.py
+++ b/text_recognizer/networks/residual_network.py
@@ -20,11 +20,7 @@ class Conv2dAuto(nn.Conv2d):
def conv_bn(in_channels: int, out_channels: int, *args, **kwargs) -> nn.Sequential:
"""3x3 convolution with batch norm."""
- conv3x3 = partial(
- Conv2dAuto,
- kernel_size=3,
- bias=False,
- )
+ conv3x3 = partial(Conv2dAuto, kernel_size=3, bias=False,)
return nn.Sequential(
conv3x3(in_channels, out_channels, *args, **kwargs),
nn.BatchNorm2d(out_channels),
diff --git a/text_recognizer/networks/transducer/transducer.py b/text_recognizer/networks/transducer/transducer.py
index b10f93a..d7e3d08 100644
--- a/text_recognizer/networks/transducer/transducer.py
+++ b/text_recognizer/networks/transducer/transducer.py
@@ -392,12 +392,7 @@ def load_transducer_loss(
transitions = gtn.load(str(processed_path / transitions))
preprocessor = Preprocessor(
- data_dir,
- num_features,
- tokens_path,
- lexicon_path,
- use_words,
- prepend_wordsep,
+ data_dir, num_features, tokens_path, lexicon_path, use_words, prepend_wordsep,
)
num_tokens = preprocessor.num_tokens
diff --git a/text_recognizer/networks/transformer/positional_encoding.py b/text_recognizer/networks/transformer/positional_encoding.py
index 5874e97..c50afc3 100644
--- a/text_recognizer/networks/transformer/positional_encoding.py
+++ b/text_recognizer/networks/transformer/positional_encoding.py
@@ -33,7 +33,10 @@ class PositionalEncoding(nn.Module):
def forward(self, x: Tensor) -> Tensor:
"""Encodes the tensor with a postional embedding."""
- x = x + self.pe[:, : x.shape[1]]
+ # [T, B, D]
+ if x.shape[2] != self.pe.shape[2]:
+ raise ValueError(f"x shape does not match pe in the 3rd dim.")
+ x = x + self.pe[: x.shape[0]]
return self.dropout(x)
@@ -48,6 +51,7 @@ class PositionalEncoding2D(nn.Module):
pe = self.make_pe(hidden_dim, max_h, max_w)
self.register_buffer("pe", pe)
+ @staticmethod
def make_pe(hidden_dim: int, max_h: int, max_w: int) -> Tensor:
"""Returns 2d postional encoding."""
pe_h = PositionalEncoding.make_pe(
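The forward change above switches PositionalEncoding to a sequence-first convention: input arrives as [T, B, D] and the table is sliced on the time dimension. The sketch below illustrates that convention with a standard sinusoidal table; it is an illustration only, since the repo's make_pe is not shown in this hunk and the construction here is an assumption.

import math
import torch

def make_pe(max_len: int, hidden_dim: int) -> torch.Tensor:
    # Standard sinusoidal table laid out as [T, 1, D] so it broadcasts over the batch dim.
    position = torch.arange(max_len).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, hidden_dim, 2) * (-math.log(10000.0) / hidden_dim))
    pe = torch.zeros(max_len, 1, hidden_dim)
    pe[:, 0, 0::2] = torch.sin(position * div_term)
    pe[:, 0, 1::2] = torch.cos(position * div_term)
    return pe

x = torch.zeros(682, 2, 256)      # [T, B, D], matching the rearrange added in cnn_transformer.py
pe = make_pe(max_len=1000, hidden_dim=256)
out = x + pe[: x.shape[0]]        # slice on the time dimension, as the new forward does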
diff --git a/text_recognizer/networks/transformer/rotary_embedding.py b/text_recognizer/networks/transformer/rotary_embedding.py
new file mode 100644
index 0000000..5e80572
--- /dev/null
+++ b/text_recognizer/networks/transformer/rotary_embedding.py
@@ -0,0 +1,39 @@
+"""Roatary embedding.
+
+Stolen from lucidrains:
+ https://github.com/lucidrains/x-transformers/blob/main/x_transformers/x_transformers.py
+
+Explanation of roatary:
+ https://blog.eleuther.ai/rotary-embeddings/
+
+"""
+from typing import Tuple
+
+from einops import rearrange
+import torch
+from torch import nn
+from torch import Tensor
+
+
+class RotaryEmbedding(nn.Module):
+ def __init__(self, dim: int):
+ super().__init__()
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
+ self.register_buffer("inv_freq", inv_freq)
+
+ def forward(self, x: Tensor, seq_dim: int = 1) -> Tensor:
+ t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
+ freqs = torch.einsum("i , j -> i j", t, self.inv_freq)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ return emb[None, :, :]
+
+
+def rotate_half(x: Tensor) -> Tensor:
+ x = rearrange(x, "... (j d) -> ... j d", j=2)
+ x1, x2 = x.unbind(dim=-2)
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q: Tensor, k: Tensor, freqs: Tensor) -> Tuple[Tensor, Tensor]:
+ q, k = map(lambda t: (t * freqs.cos()) + (rotate_half(t) * freqs.sin()), (q, k))
+ return q, k
diff --git a/text_recognizer/networks/vqvae/decoder.py b/text_recognizer/networks/vqvae/decoder.py
index 93a1e43..32de912 100644
--- a/text_recognizer/networks/vqvae/decoder.py
+++ b/text_recognizer/networks/vqvae/decoder.py
@@ -44,12 +44,7 @@ class Decoder(nn.Module):
# Configure encoder.
self.decoder = self._build_decoder(
- channels,
- kernel_sizes,
- strides,
- num_residual_layers,
- activation,
- dropout,
+ channels, kernel_sizes, strides, num_residual_layers, activation, dropout,
)
def _build_decompression_block(
@@ -78,9 +73,7 @@ class Decoder(nn.Module):
)
if self.upsampling and i < len(self.upsampling):
- modules.append(
- nn.Upsample(size=self.upsampling[i]),
- )
+ modules.append(nn.Upsample(size=self.upsampling[i]),)
if dropout is not None:
modules.append(dropout)
@@ -109,12 +102,7 @@ class Decoder(nn.Module):
) -> nn.Sequential:
self.res_block.append(
- nn.Conv2d(
- self.embedding_dim,
- channels[0],
- kernel_size=1,
- stride=1,
- )
+ nn.Conv2d(self.embedding_dim, channels[0], kernel_size=1, stride=1,)
)
# Bottleneck module.
diff --git a/text_recognizer/networks/vqvae/encoder.py b/text_recognizer/networks/vqvae/encoder.py
index b0cceed..65801df 100644
--- a/text_recognizer/networks/vqvae/encoder.py
+++ b/text_recognizer/networks/vqvae/encoder.py
@@ -11,10 +11,7 @@ from text_recognizer.networks.vqvae.vector_quantizer import VectorQuantizer
class _ResidualBlock(nn.Module):
def __init__(
- self,
- in_channels: int,
- out_channels: int,
- dropout: Optional[Type[nn.Module]],
+ self, in_channels: int, out_channels: int, dropout: Optional[Type[nn.Module]],
) -> None:
super().__init__()
self.block = [
@@ -138,12 +135,7 @@ class Encoder(nn.Module):
)
encoder.append(
- nn.Conv2d(
- channels[-1],
- self.embedding_dim,
- kernel_size=1,
- stride=1,
- )
+ nn.Conv2d(channels[-1], self.embedding_dim, kernel_size=1, stride=1,)
)
return nn.Sequential(*encoder)
diff --git a/training/configs/image_transformer.yaml b/training/configs/cnn_transformer.yaml
index e6637f2..a4f16df 100644
--- a/training/configs/image_transformer.yaml
+++ b/training/configs/cnn_transformer.yaml
@@ -2,12 +2,13 @@ seed: 4711
network:
desc: Configuration of the PyTorch neural network.
- type: ImageTransformer
+ type: CNNTransformer
args:
encoder:
- type: null
+ type: EfficientNet
args: null
num_decoder_layers: 4
+ vocab_size: 84
hidden_dim: 256
num_heads: 4
expansion_dim: 1024
@@ -26,7 +27,7 @@ model:
weight_decay: 0
eps: 1.0e-6
lr_scheduler:
- type: OneCycle
+ type: OneCycleLR
args:
interval: &interval step
max_lr: 1.0e-3
@@ -36,7 +37,7 @@ model:
criterion:
type: CrossEntropyLoss
args:
- weight: None
+ weight: null
ignore_index: -100
reduction: mean
monitor: val_loss
@@ -46,7 +47,7 @@ data:
desc: Configuration of the training/test data.
type: IAMExtendedParagraphs
args:
- batch_size: 16
+ batch_size: 8
num_workers: 12
train_fraction: 0.8
augment: true
@@ -57,33 +58,33 @@ callbacks:
monitor: val_loss
mode: min
save_last: true
- - type: StochasticWeightAveraging
- args:
- swa_epoch_start: 0.8
- swa_lrs: 0.05
- annealing_epochs: 10
- annealing_strategy: cos
- device: null
+ # - type: StochasticWeightAveraging
+ # args:
+ # swa_epoch_start: 0.8
+ # swa_lrs: 0.05
+ # annealing_epochs: 10
+ # annealing_strategy: cos
+ # device: null
- type: LearningRateMonitor
args:
logging_interval: *interval
- - type: EarlyStopping
- args:
- monitor: val_loss
- mode: min
- patience: 10
+ # - type: EarlyStopping
+ # args:
+ # monitor: val_loss
+ # mode: min
+ # patience: 10
trainer:
desc: Configuration of the PyTorch Lightning Trainer.
args:
- stochastic_weight_avg: true
+ stochastic_weight_avg: false
auto_scale_batch_size: binsearch
gradient_clip_val: 0
- fast_dev_run: false
+ fast_dev_run: true
gpus: 1
precision: 16
max_epochs: 512
terminate_on_nan: true
- weights_summary: true
+ weights_summary: top
load_checkpoint: null
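Since the config file was renamed, anything loading it by path needs the new name, and the file now defaults to fast_dev_run: true, so a real training run has to flip that back. A small OmegaConf sketch (path relative to the repo root, using OmegaConf the same way the notebooks do):

from omegaconf import OmegaConf

conf = OmegaConf.load("training/configs/cnn_transformer.yaml")
conf.trainer.args.fast_dev_run = False   # the renamed file ships with a dry-run default
print(OmegaConf.to_yaml(conf.network))   # EfficientNet encoder + transformer decoder settings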
diff --git a/training/configs/vqvae.yaml b/training/configs/vqvae.yaml
index a7acb3a..13d7c97 100644
--- a/training/configs/vqvae.yaml
+++ b/training/configs/vqvae.yaml
@@ -5,12 +5,12 @@ network:
type: VQVAE
args:
in_channels: 1
- channels: [32, 64, 64]
- kernel_sizes: [4, 4, 4]
- strides: [2, 2, 2]
+ channels: [32, 64, 64, 96, 96]
+ kernel_sizes: [4, 4, 4, 4, 4]
+ strides: [2, 2, 2, 2, 2]
num_residual_layers: 2
- embedding_dim: 128
- num_embeddings: 512
+ embedding_dim: 512
+ num_embeddings: 1024
upsampling: null
beta: 0.25
activation: leaky_relu
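The VQVAE config now uses five stride-2 stages instead of three, which matches the 18x20 latent grid recorded in the notebook diff earlier (that output was produced with the previous embedding_dim of 128; this file raises it to 512). The downsampling arithmetic, as a quick check:

# Five stride-2 convolutions shrink each spatial dimension by 2**5 = 32.
assert 576 // 2 ** 5 == 18
assert 640 // 2 ** 5 == 20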