From 4d1f2cef39688871d2caafce42a09316381a27ae Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Mon, 5 Jul 2021 23:05:25 +0200 Subject: Refactor with attr, working on cnn+transformer network --- notebooks/00-scratch-pad.ipynb | 644 +++++++++++++++++++-- notebooks/03-look-at-iam-paragraphs.ipynb | 43 +- text_recognizer/callbacks/__init__.py | 1 + text_recognizer/callbacks/wandb_callbacks.py | 8 +- text_recognizer/criterions/__init__.py | 1 + text_recognizer/criterions/label_smoothing_loss.py | 42 ++ text_recognizer/data/base_data_module.py | 14 +- text_recognizer/data/base_dataset.py | 24 +- text_recognizer/models/__init__.py | 2 - text_recognizer/models/base.py | 11 +- text_recognizer/models/transformer.py | 30 +- text_recognizer/models/vqvae.py | 6 +- text_recognizer/networks/cnn_tranformer.py | 14 + text_recognizer/networks/loss/__init__.py | 2 - .../networks/loss/label_smoothing_loss.py | 42 -- text_recognizer/networks/util.py | 7 +- 16 files changed, 732 insertions(+), 159 deletions(-) create mode 100644 text_recognizer/criterions/__init__.py create mode 100644 text_recognizer/criterions/label_smoothing_loss.py create mode 100644 text_recognizer/networks/cnn_tranformer.py delete mode 100644 text_recognizer/networks/loss/__init__.py delete mode 100644 text_recognizer/networks/loss/label_smoothing_loss.py diff --git a/notebooks/00-scratch-pad.ipynb b/notebooks/00-scratch-pad.ipynb index 2ade2bb..16c6533 100644 --- a/notebooks/00-scratch-pad.ipynb +++ b/notebooks/00-scratch-pad.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -30,106 +30,244 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ - "from text_recognizer.networks.encoders.efficientnet.efficientnet import EfficientNet" + "from pathlib import Path" ] }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "scrolled": false - }, + "execution_count": 2, + "metadata": {}, "outputs": [], "source": [ - "en = EfficientNet(\"b0\")" + "import attr" ] }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "scrolled": false - }, + "execution_count": 9, + "metadata": {}, "outputs": [], "source": [ - "summary(en, (1, 224, 224));" + "@attr.s\n", + "class B:\n", + " batch_size = attr.ib()\n", + " num_workers = attr.ib()" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ - "torch.cuda.is_available()" + "@attr.s\n", + "class T(B):\n", + "\n", + " def __attrs_post_init__(self) -> None:\n", + " super().__init__(self.batch_size, self.num_workers)\n", + " self.hej = None\n", + " \n", + " batch_size = attr.ib()\n", + " num_workers = attr.ib()\n", + " h: Path = attr.ib(converter=Path)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ - "decoder = Decoder(dim=128, depth=2, num_heads=8, ff_kwargs={}, attn_kwargs={}, cross_attend=True)" + "t = T(batch_size=16, num_workers=2, h=\"hej\")" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "PosixPath('hej')" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "decoder.cuda()" + "t.h" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "16" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "transformer_decoder = Transformer(num_tokens=1000, max_seq_len=690, attn_layers=decoder, emb_dim=128, emb_dropout=0.1)" + "t.batch_size" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "PosixPath('hej')" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "transformer_decoder.cuda()" + "t.h" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "../text_recognizer/__init__.py\n", + "../text_recognizer/callbacks/__init__.py\n", + "../text_recognizer/callbacks/wandb_callbacks.py\n", + "../text_recognizer/data/image_utils.py\n", + "../text_recognizer/data/emnist.py\n", + "../text_recognizer/data/iam_lines.py\n", + "../text_recognizer/data/download_utils.py\n", + "../text_recognizer/data/mappings.py\n", + "../text_recognizer/data/iam_preprocessor.py\n", + "../text_recognizer/data/__init__.py\n", + "../text_recognizer/data/make_wordpieces.py\n", + "../text_recognizer/data/iam_paragraphs.py\n", + "../text_recognizer/data/sentence_generator.py\n", + "../text_recognizer/data/emnist_lines.py\n", + "../text_recognizer/data/build_transitions.py\n", + "../text_recognizer/data/base_dataset.py\n", + "../text_recognizer/data/base_data_module.py\n", + "../text_recognizer/data/iam.py\n", + "../text_recognizer/data/iam_synthetic_paragraphs.py\n", + "../text_recognizer/data/transforms.py\n", + "../text_recognizer/data/iam_extended_paragraphs.py\n", + "../text_recognizer/networks/__init__.py\n", + "../text_recognizer/networks/util.py\n", + "../text_recognizer/networks/cnn_tranformer.py\n", + "../text_recognizer/networks/encoders/__init__.py\n", + "../text_recognizer/networks/encoders/efficientnet/efficientnet.py\n", + "../text_recognizer/networks/encoders/efficientnet/__init__.py\n", + "../text_recognizer/networks/encoders/efficientnet/utils.py\n", + "../text_recognizer/networks/encoders/efficientnet/mbconv.py\n", + "../text_recognizer/networks/loss/__init__.py\n", + "../text_recognizer/networks/loss/label_smoothing_loss.py\n", + "../text_recognizer/networks/vqvae/__init__.py\n", + "../text_recognizer/networks/vqvae/decoder.py\n", + "../text_recognizer/networks/vqvae/vqvae.py\n", + "../text_recognizer/networks/vqvae/vector_quantizer.py\n", + "../text_recognizer/networks/vqvae/encoder.py\n", + "../text_recognizer/networks/transformer/__init__.py\n", + "../text_recognizer/networks/transformer/layers.py\n", + "../text_recognizer/networks/transformer/residual.py\n", + "../text_recognizer/networks/transformer/attention.py\n", + "../text_recognizer/networks/transformer/transformer.py\n", + "../text_recognizer/networks/transformer/vit.py\n", + "../text_recognizer/networks/transformer/mlp.py\n", + "../text_recognizer/networks/transformer/norm.py\n", + "../text_recognizer/networks/transformer/positional_encodings/positional_encoding.py\n", + "../text_recognizer/networks/transformer/positional_encodings/__init__.py\n", + "../text_recognizer/networks/transformer/positional_encodings/absolute_embedding.py\n", + "../text_recognizer/networks/transformer/positional_encodings/rotary_embedding.py\n", + "../text_recognizer/networks/transformer/nystromer/__init__.py\n", + "../text_recognizer/networks/transformer/nystromer/nystromer.py\n", + "../text_recognizer/networks/transformer/nystromer/attention.py\n", + "../text_recognizer/models/__init__.py\n", + "../text_recognizer/models/base.py\n", + "../text_recognizer/models/vqvae.py\n", + "../text_recognizer/models/transformer.py\n", + "../text_recognizer/models/dino.py\n", + "../text_recognizer/models/metrics.py\n" + ] + } + ], + "source": [ + "for f in Path(\"../text_recognizer\").glob(\"**/*.py\"):\n", + " print(f)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "Path(\"..\").glob(\"**/*.py\")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ - "efficient_transformer = Nystromer(\n", - " dim = 64,\n", - " depth = 4,\n", - " num_heads = 8,\n", - " num_landmarks = 64\n", - ")" + "from text_recognizer.networks.encoders.efficientnet.efficientnet import EfficientNet" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "en = EfficientNet(\"b0\")" ] }, { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "scrolled": false + }, "outputs": [], "source": [ - "v = ViT(\n", - " dim = 64,\n", - " image_size = (576, 640),\n", - " patch_size = (32, 32),\n", - " transformer = efficient_transformer\n", - ").cuda()" + "summary(en, (1, 224, 224));" ] }, { @@ -138,7 +276,7 @@ "metadata": {}, "outputs": [], "source": [ - "t = torch.randn(8, 1, 576, 640).cuda()" + "torch.cuda.is_available()" ] }, { @@ -147,7 +285,7 @@ "metadata": {}, "outputs": [], "source": [ - "en.cuda()" + "decoder = Decoder(dim=128, depth=2, num_heads=8, ff_kwargs={}, attn_kwargs={}, cross_attend=True)" ] }, { @@ -156,7 +294,7 @@ "metadata": {}, "outputs": [], "source": [ - "en(t)" + "decoder.cuda()" ] }, { @@ -165,7 +303,7 @@ "metadata": {}, "outputs": [], "source": [ - "o = v(t)" + "transformer_decoder = Transformer(num_tokens=1003, max_seq_len=451, attn_layers=decoder, emb_dim=128, emb_dropout=0.1)" ] }, { @@ -174,7 +312,7 @@ "metadata": {}, "outputs": [], "source": [ - "caption = torch.randint(0, 90, (16, 690)).cuda()" + "transformer_decoder.cuda()" ] }, { @@ -183,7 +321,12 @@ "metadata": {}, "outputs": [], "source": [ - "o.shape" + "efficient_transformer = Nystromer(\n", + " dim = 64,\n", + " depth = 4,\n", + " num_heads = 8,\n", + " num_landmarks = 64\n", + ")" ] }, { @@ -192,16 +335,405 @@ "metadata": {}, "outputs": [], "source": [ - "caption.shape" + "v = ViT(\n", + " dim = 64,\n", + " image_size = (576, 640),\n", + " patch_size = (32, 32),\n", + " transformer = efficient_transformer\n", + ").cuda()" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ - "o = torch.randn(16, 20 * 18, 128).cuda()" + "t = torch.randn(4, 1, 576, 640).cuda()" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "EfficientNet(\n", + " (_conv_stem): Sequential(\n", + " (0): ZeroPad2d(padding=(0, 1, 0, 1), value=0.0)\n", + " (1): Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)\n", + " (2): BatchNorm2d(32, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " (3): Mish(inplace=True)\n", + " )\n", + " (_blocks): ModuleList(\n", + " (0): MBConvBlock(\n", + " (_depthwise): Sequential(\n", + " (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), groups=32, bias=False)\n", + " (1): BatchNorm2d(32, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " (2): Mish(inplace=True)\n", + " )\n", + " (_squeeze_excite): Sequential(\n", + " (0): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1))\n", + " (1): Mish(inplace=True)\n", + " (2): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " (_pointwise): Sequential(\n", + " (0): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(16, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (1): MBConvBlock(\n", + " (_inverted_bottleneck): Sequential(\n", + " (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(96, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " (2): Mish(inplace=True)\n", + " )\n", + " (_depthwise): Sequential(\n", + " (0): Conv2d(96, 96, kernel_size=(3, 3), stride=(2, 2), groups=96, bias=False)\n", + " (1): BatchNorm2d(96, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " (2): Mish(inplace=True)\n", + " )\n", + " (_squeeze_excite): Sequential(\n", + " (0): Conv2d(96, 24, kernel_size=(1, 1), stride=(1, 1))\n", + " (1): Mish(inplace=True)\n", + " (2): Conv2d(24, 96, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " (_pointwise): Sequential(\n", + " (0): Conv2d(96, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(24, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (2): MBConvBlock(\n", + " (_inverted_bottleneck): Sequential(\n", + " (0): Conv2d(24, 144, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(144, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " (2): Mish(inplace=True)\n", + " )\n", + " (_depthwise): Sequential(\n", + " (0): Conv2d(144, 144, kernel_size=(3, 3), stride=(1, 1), groups=144, bias=False)\n", + " (1): BatchNorm2d(144, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " (2): Mish(inplace=True)\n", + " )\n", + " (_squeeze_excite): Sequential(\n", + " (0): Conv2d(144, 36, kernel_size=(1, 1), stride=(1, 1))\n", + " (1): Mish(inplace=True)\n", + " (2): Conv2d(36, 144, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " (_pointwise): Sequential(\n", + " (0): Conv2d(144, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(24, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (3): MBConvBlock(\n", + " (_inverted_bottleneck): Sequential(\n", + " (0): Conv2d(24, 144, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(144, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " (2): Mish(inplace=True)\n", + " )\n", + " (_depthwise): Sequential(\n", + " (0): Conv2d(144, 144, kernel_size=(5, 5), stride=(2, 2), groups=144, bias=False)\n", + " (1): BatchNorm2d(144, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " (2): Mish(inplace=True)\n", + " )\n", + " (_squeeze_excite): Sequential(\n", + " (0): Conv2d(144, 36, kernel_size=(1, 1), stride=(1, 1))\n", + " (1): Mish(inplace=True)\n", + " (2): Conv2d(36, 144, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " (_pointwise): Sequential(\n", + " (0): Conv2d(144, 40, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(40, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (4): MBConvBlock(\n", + " (_inverted_bottleneck): Sequential(\n", + " (0): Conv2d(40, 240, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(240, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " (2): Mish(inplace=True)\n", + " )\n", + " (_depthwise): Sequential(\n", + " (0): Conv2d(240, 240, kernel_size=(5, 5), stride=(1, 1), groups=240, bias=False)\n", + " (1): BatchNorm2d(240, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " (2): Mish(inplace=True)\n", + " )\n", + " (_squeeze_excite): Sequential(\n", + " (0): Conv2d(240, 60, kernel_size=(1, 1), stride=(1, 1))\n", + " (1): Mish(inplace=True)\n", + " (2): Conv2d(60, 240, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " (_pointwise): Sequential(\n", + " (0): Conv2d(240, 40, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(40, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (5): MBConvBlock(\n", + " (_inverted_bottleneck): Sequential(\n", + " (0): Conv2d(40, 240, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(240, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " (2): Mish(inplace=True)\n", + " )\n", + " (_depthwise): Sequential(\n", + " (0): Conv2d(240, 240, kernel_size=(3, 3), stride=(2, 2), groups=240, bias=False)\n", + " (1): BatchNorm2d(240, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " (2): Mish(inplace=True)\n", + " )\n", + " (_squeeze_excite): Sequential(\n", + " (0): Conv2d(240, 60, kernel_size=(1, 1), stride=(1, 1))\n", + " (1): Mish(inplace=True)\n", + " (2): Conv2d(60, 240, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " (_pointwise): Sequential(\n", + " (0): Conv2d(240, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(80, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (6): MBConvBlock(\n", + " (_inverted_bottleneck): Sequential(\n", + " (0): Conv2d(80, 480, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(480, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " (2): Mish(inplace=True)\n", + " )\n", + " (_depthwise): Sequential(\n", + " (0): Conv2d(480, 480, kernel_size=(3, 3), stride=(1, 1), groups=480, bias=False)\n", + " (1): BatchNorm2d(480, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " (2): Mish(inplace=True)\n", + " )\n", + " (_squeeze_excite): Sequential(\n", + " (0): Conv2d(480, 120, kernel_size=(1, 1), stride=(1, 1))\n", + " (1): Mish(inplace=True)\n", + " (2): Conv2d(120, 480, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " (_pointwise): Sequential(\n", + " (0): Conv2d(480, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(80, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (7): MBConvBlock(\n", + " (_inverted_bottleneck): Sequential(\n", + " (0): Conv2d(80, 480, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(480, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " (2): Mish(inplace=True)\n", + " )\n", + " (_depthwise): Sequential(\n", + " (0): Conv2d(480, 480, kernel_size=(3, 3), stride=(1, 1), groups=480, bias=False)\n", + " (1): BatchNorm2d(480, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " (2): Mish(inplace=True)\n", + " )\n", + " (_squeeze_excite): Sequential(\n", + " (0): Conv2d(480, 120, kernel_size=(1, 1), stride=(1, 1))\n", + " (1): Mish(inplace=True)\n", + " (2): Conv2d(120, 480, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " (_pointwise): Sequential(\n", + " (0): Conv2d(480, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(80, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (8): MBConvBlock(\n", + " (_inverted_bottleneck): Sequential(\n", + " (0): Conv2d(80, 480, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(480, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " (2): Mish(inplace=True)\n", + " )\n", + " (_depthwise): Sequential(\n", + " (0): Conv2d(480, 480, kernel_size=(5, 5), stride=(1, 1), groups=480, bias=False)\n", + " (1): BatchNorm2d(480, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " (2): Mish(inplace=True)\n", + " )\n", + " (_squeeze_excite): Sequential(\n", + " (0): Conv2d(480, 120, kernel_size=(1, 1), stride=(1, 1))\n", + " (1): Mish(inplace=True)\n", + " (2): Conv2d(120, 480, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " (_pointwise): Sequential(\n", + " (0): Conv2d(480, 112, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(112, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (9): MBConvBlock(\n", + " (_inverted_bottleneck): Sequential(\n", + " (0): Conv2d(112, 672, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(672, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " (2): Mish(inplace=True)\n", + " )\n", + " (_depthwise): Sequential(\n", + " (0): Conv2d(672, 672, kernel_size=(5, 5), stride=(1, 1), groups=672, bias=False)\n", + " (1): BatchNorm2d(672, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " (2): Mish(inplace=True)\n", + " )\n", + " (_squeeze_excite): Sequential(\n", + " (0): Conv2d(672, 168, kernel_size=(1, 1), stride=(1, 1))\n", + " (1): Mish(inplace=True)\n", + " (2): Conv2d(168, 672, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " (_pointwise): Sequential(\n", + " (0): Conv2d(672, 112, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(112, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (10): MBConvBlock(\n", + " (_inverted_bottleneck): Sequential(\n", + " (0): Conv2d(112, 672, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(672, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " (2): Mish(inplace=True)\n", + " )\n", + " (_depthwise): Sequential(\n", + " (0): Conv2d(672, 672, kernel_size=(5, 5), stride=(1, 1), groups=672, bias=False)\n", + " (1): BatchNorm2d(672, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " (2): Mish(inplace=True)\n", + " )\n", + " (_squeeze_excite): Sequential(\n", + " (0): Conv2d(672, 168, kernel_size=(1, 1), stride=(1, 1))\n", + " (1): Mish(inplace=True)\n", + " (2): Conv2d(168, 672, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " (_pointwise): Sequential(\n", + " (0): Conv2d(672, 112, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(112, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (11): MBConvBlock(\n", + " (_inverted_bottleneck): Sequential(\n", + " (0): Conv2d(112, 672, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(672, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " (2): Mish(inplace=True)\n", + " )\n", + " (_depthwise): Sequential(\n", + " (0): Conv2d(672, 672, kernel_size=(5, 5), stride=(2, 2), groups=672, bias=False)\n", + " (1): BatchNorm2d(672, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " (2): Mish(inplace=True)\n", + " )\n", + " (_squeeze_excite): Sequential(\n", + " (0): Conv2d(672, 168, kernel_size=(1, 1), stride=(1, 1))\n", + " (1): Mish(inplace=True)\n", + " (2): Conv2d(168, 672, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " (_pointwise): Sequential(\n", + " (0): Conv2d(672, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(192, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (12): MBConvBlock(\n", + " (_inverted_bottleneck): Sequential(\n", + " (0): Conv2d(192, 1152, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(1152, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " (2): Mish(inplace=True)\n", + " )\n", + " (_depthwise): Sequential(\n", + " (0): Conv2d(1152, 1152, kernel_size=(5, 5), stride=(1, 1), groups=1152, bias=False)\n", + " (1): BatchNorm2d(1152, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " (2): Mish(inplace=True)\n", + " )\n", + " (_squeeze_excite): Sequential(\n", + " (0): Conv2d(1152, 288, kernel_size=(1, 1), stride=(1, 1))\n", + " (1): Mish(inplace=True)\n", + " (2): Conv2d(288, 1152, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " (_pointwise): Sequential(\n", + " (0): Conv2d(1152, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(192, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (13): MBConvBlock(\n", + " (_inverted_bottleneck): Sequential(\n", + " (0): Conv2d(192, 1152, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(1152, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " (2): Mish(inplace=True)\n", + " )\n", + " (_depthwise): Sequential(\n", + " (0): Conv2d(1152, 1152, kernel_size=(5, 5), stride=(1, 1), groups=1152, bias=False)\n", + " (1): BatchNorm2d(1152, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " (2): Mish(inplace=True)\n", + " )\n", + " (_squeeze_excite): Sequential(\n", + " (0): Conv2d(1152, 288, kernel_size=(1, 1), stride=(1, 1))\n", + " (1): Mish(inplace=True)\n", + " (2): Conv2d(288, 1152, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " (_pointwise): Sequential(\n", + " (0): Conv2d(1152, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(192, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (14): MBConvBlock(\n", + " (_inverted_bottleneck): Sequential(\n", + " (0): Conv2d(192, 1152, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(1152, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " (2): Mish(inplace=True)\n", + " )\n", + " (_depthwise): Sequential(\n", + " (0): Conv2d(1152, 1152, kernel_size=(5, 5), stride=(1, 1), groups=1152, bias=False)\n", + " (1): BatchNorm2d(1152, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " (2): Mish(inplace=True)\n", + " )\n", + " (_squeeze_excite): Sequential(\n", + " (0): Conv2d(1152, 288, kernel_size=(1, 1), stride=(1, 1))\n", + " (1): Mish(inplace=True)\n", + " (2): Conv2d(288, 1152, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " (_pointwise): Sequential(\n", + " (0): Conv2d(1152, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(192, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (15): MBConvBlock(\n", + " (_inverted_bottleneck): Sequential(\n", + " (0): Conv2d(192, 1152, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(1152, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " (2): Mish(inplace=True)\n", + " )\n", + " (_depthwise): Sequential(\n", + " (0): Conv2d(1152, 1152, kernel_size=(3, 3), stride=(1, 1), groups=1152, bias=False)\n", + " (1): BatchNorm2d(1152, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " (2): Mish(inplace=True)\n", + " )\n", + " (_squeeze_excite): Sequential(\n", + " (0): Conv2d(1152, 288, kernel_size=(1, 1), stride=(1, 1))\n", + " (1): Mish(inplace=True)\n", + " (2): Conv2d(288, 1152, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " (_pointwise): Sequential(\n", + " (0): Conv2d(1152, 320, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(320, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " )\n", + " (_conv_head): Sequential(\n", + " (0): Conv2d(320, 1280, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(1280, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " )\n", + ")" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "en.cuda()" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([4, 1280, 18, 20])" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "en(t).shape" ] }, { @@ -210,7 +742,7 @@ "metadata": {}, "outputs": [], "source": [ - "caption = torch.randint(0, 1000, (16, 200)).cuda()" + "o = v(t)" ] }, { @@ -219,7 +751,7 @@ "metadata": {}, "outputs": [], "source": [ - "transformer_decoder(caption, context = o).shape # (1, 1024, 20000)" + "caption = torch.randint(0, 90, (16, 690)).cuda()" ] }, { @@ -228,7 +760,7 @@ "metadata": {}, "outputs": [], "source": [ - "from text_recognizer.networks.encoders.efficientnet.efficientnet import EfficientNet" + "o.shape" ] }, { @@ -237,7 +769,7 @@ "metadata": {}, "outputs": [], "source": [ - "en = EfficientNet()" + "caption.shape" ] }, { @@ -246,7 +778,7 @@ "metadata": {}, "outputs": [], "source": [ - "en.cuda()" + "o = torch.randn(16, 20 * 18, 128).cuda()" ] }, { @@ -255,7 +787,7 @@ "metadata": {}, "outputs": [], "source": [ - "summary(en, (1, 576, 640))" + "caption = torch.randint(0, 1000, (16, 200)).cuda()" ] }, { @@ -264,7 +796,7 @@ "metadata": {}, "outputs": [], "source": [ - "type(efficient_transformer)" + "transformer_decoder(caption, context = o).shape # (1, 1024, 20000)" ] }, { diff --git a/notebooks/03-look-at-iam-paragraphs.ipynb b/notebooks/03-look-at-iam-paragraphs.ipynb index 37fef04..315b7bf 100644 --- a/notebooks/03-look-at-iam-paragraphs.ipynb +++ b/notebooks/03-look-at-iam-paragraphs.ipynb @@ -40,7 +40,7 @@ }, { "cell_type": "code", - "execution_count": 54, + "execution_count": 2, "id": "726ac25b", "metadata": {}, "outputs": [], @@ -57,7 +57,7 @@ }, { "cell_type": "code", - "execution_count": 64, + "execution_count": 3, "id": "c6188bce", "metadata": { "scrolled": true @@ -67,13 +67,13 @@ "name": "stderr", "output_type": "stream", "text": [ - "2021-06-27 20:17:40.498 | DEBUG | text_recognizer.data.mappings:_configure_wordpiece_processor:104 - Using data dir: /home/aktersnurra/projects/text-recognizer/data/downloaded/iam/iamdb\n", - "2021-06-27 20:17:40.682 | DEBUG | text_recognizer.data.mappings:_configure_wordpiece_processor:104 - Using data dir: /home/aktersnurra/projects/text-recognizer/data/downloaded/iam/iamdb\n", - "2021-06-27 20:17:40.777 | INFO | text_recognizer.data.iam_paragraphs:setup:111 - Loading IAM paragraph regions and lines for None...\n", - "2021-06-27 20:17:54.542 | DEBUG | text_recognizer.data.mappings:_configure_wordpiece_processor:104 - Using data dir: /home/aktersnurra/projects/text-recognizer/data/downloaded/iam/iamdb\n", - "2021-06-27 20:17:56.911 | DEBUG | text_recognizer.data.mappings:_configure_wordpiece_processor:104 - Using data dir: /home/aktersnurra/projects/text-recognizer/data/downloaded/iam/iamdb\n", - "2021-06-27 20:17:57.147 | INFO | text_recognizer.data.iam_synthetic_paragraphs:setup:75 - IAM Synthetic dataset steup for stage None...\n", - "2021-06-27 20:18:07.707 | DEBUG | text_recognizer.data.mappings:_configure_wordpiece_processor:104 - Using data dir: /home/aktersnurra/projects/text-recognizer/data/downloaded/iam/iamdb\n" + "2021-06-27 20:59:27.366 | DEBUG | text_recognizer.data.mappings:_configure_wordpiece_processor:104 - Using data dir: /home/aktersnurra/projects/text-recognizer/data/downloaded/iam/iamdb\n", + "2021-06-27 20:59:27.464 | DEBUG | text_recognizer.data.mappings:_configure_wordpiece_processor:104 - Using data dir: /home/aktersnurra/projects/text-recognizer/data/downloaded/iam/iamdb\n", + "2021-06-27 20:59:27.559 | INFO | text_recognizer.data.iam_paragraphs:setup:111 - Loading IAM paragraph regions and lines for None...\n", + "2021-06-27 20:59:40.932 | DEBUG | text_recognizer.data.mappings:_configure_wordpiece_processor:104 - Using data dir: /home/aktersnurra/projects/text-recognizer/data/downloaded/iam/iamdb\n", + "2021-06-27 20:59:43.173 | DEBUG | text_recognizer.data.mappings:_configure_wordpiece_processor:104 - Using data dir: /home/aktersnurra/projects/text-recognizer/data/downloaded/iam/iamdb\n", + "2021-06-27 20:59:43.267 | INFO | text_recognizer.data.iam_synthetic_paragraphs:setup:75 - IAM Synthetic dataset steup for stage None...\n", + "2021-06-27 20:59:53.470 | DEBUG | text_recognizer.data.mappings:_configure_wordpiece_processor:104 - Using data dir: /home/aktersnurra/projects/text-recognizer/data/downloaded/iam/iamdb\n" ] }, { @@ -84,8 +84,8 @@ "Num classes: 1006\n", "Dims: (1, 576, 640)\n", "Output dims: (682, 1)\n", - "Train/val/test sizes: 19907, 262, 231\n", - "Train Batch x stats: (torch.Size([8, 1, 576, 640]), torch.float32, tensor(0.), tensor(0.0105), tensor(0.0575), tensor(1.))\n", + "Train/val/test sizes: 19957, 262, 231\n", + "Train Batch x stats: (torch.Size([8, 1, 576, 640]), torch.float32, tensor(0.), tensor(0.0111), tensor(0.0604), tensor(1.))\n", "Train Batch y stats: (torch.Size([8, 451]), torch.int64, tensor(1), tensor(1004))\n", "Test Batch x stats: (torch.Size([8, 1, 576, 640]), torch.float32, tensor(0.), tensor(0.0315), tensor(0.0799), tensor(0.9098))\n", "Test Batch y stats: (torch.Size([8, 451]), torch.int64, tensor(1), tensor(1003))\n", @@ -100,6 +100,27 @@ "print(dataset)" ] }, + { + "cell_type": "code", + "execution_count": 4, + "id": "55b26b5d", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "1006" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(dataset.mapping)" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/text_recognizer/callbacks/__init__.py b/text_recognizer/callbacks/__init__.py index e69de29..82d8ce3 100644 --- a/text_recognizer/callbacks/__init__.py +++ b/text_recognizer/callbacks/__init__.py @@ -0,0 +1 @@ +"""Module for PyTorch Lightning callbacks.""" diff --git a/text_recognizer/callbacks/wandb_callbacks.py b/text_recognizer/callbacks/wandb_callbacks.py index 900c3b1..4186b4a 100644 --- a/text_recognizer/callbacks/wandb_callbacks.py +++ b/text_recognizer/callbacks/wandb_callbacks.py @@ -29,7 +29,7 @@ class WatchModel(Callback): log: str = attr.ib(default="gradients") log_freq: int = attr.ib(default=100) - def __attrs_pre_init__(self): + def __attrs_pre_init__(self) -> None: super().__init__() def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: @@ -44,7 +44,7 @@ class UploadCodeAsArtifact(Callback): project_dir: Path = attr.ib(converter=Path) - def __attrs_pre_init__(self): + def __attrs_pre_init__(self) -> None: super().__init__() def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: @@ -65,7 +65,7 @@ class UploadCheckpointAsArtifact(Callback): ckpt_dir: Path = attr.ib(converter=Path) upload_best_only: bool = attr.ib() - def __attrs_pre_init__(self): + def __attrs_pre_init__(self) -> None: super().__init__() def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None: @@ -90,7 +90,7 @@ class LogTextPredictions(Callback): num_samples: int = attr.ib(default=8) ready: bool = attr.ib(default=True) - def __attrs_pre_init__(self): + def __attrs_pre_init__(self) -> None: super().__init__() def on_sanity_check_start( diff --git a/text_recognizer/criterions/__init__.py b/text_recognizer/criterions/__init__.py new file mode 100644 index 0000000..5b0a7ab --- /dev/null +++ b/text_recognizer/criterions/__init__.py @@ -0,0 +1 @@ +"""Module with custom loss functions.""" diff --git a/text_recognizer/criterions/label_smoothing_loss.py b/text_recognizer/criterions/label_smoothing_loss.py new file mode 100644 index 0000000..40a7609 --- /dev/null +++ b/text_recognizer/criterions/label_smoothing_loss.py @@ -0,0 +1,42 @@ +"""Implementations of custom loss functions.""" +import torch +from torch import nn +from torch import Tensor +import torch.nn.functional as F + + +class LabelSmoothingLoss(nn.Module): + """Label smoothing cross entropy loss.""" + + def __init__( + self, label_smoothing: float, vocab_size: int, ignore_index: int = -100 + ) -> None: + assert 0.0 < label_smoothing <= 1.0 + self.ignore_index = ignore_index + super().__init__() + + smoothing_value = label_smoothing / (vocab_size - 2) + one_hot = torch.full((vocab_size,), smoothing_value) + one_hot[self.ignore_index] = 0 + self.register_buffer("one_hot", one_hot.unsqueeze(0)) + + self.confidence = 1.0 - label_smoothing + + def forward(self, output: Tensor, targets: Tensor) -> Tensor: + """Computes the loss. + + Args: + output (Tensor): Predictions from the network. + targets (Tensor): Ground truth. + + Shapes: + outpus: Batch size x num classes + targets: Batch size + + Returns: + Tensor: Label smoothing loss. + """ + model_prob = self.one_hot.repeat(targets.size(0), 1) + model_prob.scatter_(1, targets.unsqueeze(1), self.confidence) + model_prob.masked_fill_((targets == self.ignore_index).unsqueeze(1), 0) + return F.kl_div(output, model_prob, reduction="sum") diff --git a/text_recognizer/data/base_data_module.py b/text_recognizer/data/base_data_module.py index 8b5c188..de5628f 100644 --- a/text_recognizer/data/base_data_module.py +++ b/text_recognizer/data/base_data_module.py @@ -2,7 +2,8 @@ from pathlib import Path from typing import Dict -import pytorch_lightning as pl +import attr +import pytorch_lightning as LightningDataModule from torch.utils.data import DataLoader @@ -14,14 +15,17 @@ def load_and_print_info(data_module_class: type) -> None: print(dataset) -class BaseDataModule(pl.LightningDataModule): +@attr.s +class BaseDataModule(LightningDataModule): """Base PyTorch Lightning DataModule.""" - def __init__(self, batch_size: int = 128, num_workers: int = 0) -> None: + batch_size: int = attr.ib(default=16) + num_workers: int = attr.ib(default=0) + + def __attrs_pre_init__(self) -> None: super().__init__() - self.batch_size = batch_size - self.num_workers = num_workers + def __attrs_post_init__(self) -> None: # Placeholders for subclasses. self.dims = None self.output_dims = None diff --git a/text_recognizer/data/base_dataset.py b/text_recognizer/data/base_dataset.py index 8d644d4..4318dfb 100644 --- a/text_recognizer/data/base_dataset.py +++ b/text_recognizer/data/base_dataset.py @@ -1,11 +1,13 @@ """Base PyTorch Dataset class.""" from typing import Any, Callable, Dict, Sequence, Tuple, Union +import attr import torch from torch import Tensor from torch.utils.data import Dataset +@attr.s class BaseDataset(Dataset): """ Base Dataset class that processes data and targets through optional transfroms. @@ -18,19 +20,17 @@ class BaseDataset(Dataset): target transforms. """ - def __init__( - self, - data: Union[Sequence, Tensor], - targets: Union[Sequence, Tensor], - transform: Callable = None, - target_transform: Callable = None, - ) -> None: - if len(data) != len(targets): + data: Union[Sequence, Tensor] = attr.ib() + targets: Union[Sequence, Tensor] = attr.ib() + transform: Callable = attr.ib() + target_transform: Callable = attr.ib() + + def __attrs_pre_init__(self) -> None: + super().__init__() + + def __attrs_post_init__(self) -> None: + if len(self.data) != len(self.targets): raise ValueError("Data and targets must be of equal length.") - self.data = data - self.targets = targets - self.transform = transform - self.target_transform = target_transform def __len__(self) -> int: """Return the length of the dataset.""" diff --git a/text_recognizer/models/__init__.py b/text_recognizer/models/__init__.py index 5ac2510..1982daf 100644 --- a/text_recognizer/models/__init__.py +++ b/text_recognizer/models/__init__.py @@ -1,3 +1 @@ """PyTorch Lightning models modules.""" -from .transformer import LitTransformerModel -from .vqvae import LitVQVAEModel diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index 4e803eb..8dc7a36 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -1,5 +1,5 @@ """Base PyTorch Lightning model.""" -from typing import Any, Dict, List, Union, Tuple, Type +from typing import Any, Dict, List, Tuple, Type import attr import hydra @@ -13,7 +13,7 @@ import torchmetrics @attr.s -class LitBaseModel(pl.LightningModule): +class BaseLitModel(pl.LightningModule): """Abstract PyTorch Lightning class.""" network: Type[nn.Module] = attr.ib() @@ -30,18 +30,17 @@ class LitBaseModel(pl.LightningModule): val_acc = attr.ib(init=False) test_acc = attr.ib(init=False) - def __attrs_pre_init__(self): + def __attrs_pre_init__(self) -> None: super().__init__() - def __attrs_post_init__(self): - self.loss_fn = self.configure_criterion() + def __attrs_post_init__(self) -> None: + self.loss_fn = self._configure_criterion() # Accuracy metric self.train_acc = torchmetrics.Accuracy() self.val_acc = torchmetrics.Accuracy() self.test_acc = torchmetrics.Accuracy() - @staticmethod def configure_criterion(self) -> Type[nn.Module]: """Returns a loss functions.""" log.info(f"Instantiating criterion <{self.criterion_config._target_}>") diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py index 6be0ac5..ea54d83 100644 --- a/text_recognizer/models/transformer.py +++ b/text_recognizer/models/transformer.py @@ -1,27 +1,35 @@ """PyTorch Lightning model for base Transformers.""" from typing import Dict, List, Optional, Union, Tuple, Type +import attr from omegaconf import DictConfig from torch import nn, Tensor from text_recognizer.data.emnist import emnist_mapping +from text_recognizer.data.mappings import AbstractMapping from text_recognizer.models.metrics import CharacterErrorRate from text_recognizer.models.base import LitBaseModel -class LitTransformerModel(LitBaseModel): +@attr.s +class TransformerLitModel(LitBaseModel): """A PyTorch Lightning model for transformer networks.""" - def __init__( - self, - network: Type[nn.Module], - optimizer: Union[DictConfig, Dict], - lr_scheduler: Union[DictConfig, Dict], - criterion: Union[DictConfig, Dict], - monitor: str = "val_loss", - mapping: Optional[List[str]] = None, - ) -> None: - super().__init__(network, optimizer, lr_scheduler, criterion, monitor) + network: Type[nn.Module] = attr.ib() + criterion_config: DictConfig = attr.ib(converter=DictConfig) + optimizer_config: DictConfig = attr.ib(converter=DictConfig) + lr_scheduler_config: DictConfig = attr.ib(converter=DictConfig) + monitor: str = attr.ib() + mapping: Type[AbstractMapping] = attr.ib() + + def __attrs_post_init__(self) -> None: + super().__init__( + network=self.network, + optimizer_config=self.optimizer_config, + lr_scheduler_config=self.lr_scheduler_config, + criterion_config=self.criterion_config, + monitor=self.monitor, + ) self.mapping, ignore_tokens = self.configure_mapping(mapping) self.val_cer = CharacterErrorRate(ignore_tokens) self.test_cer = CharacterErrorRate(ignore_tokens) diff --git a/text_recognizer/models/vqvae.py b/text_recognizer/models/vqvae.py index 18e8691..7dc950f 100644 --- a/text_recognizer/models/vqvae.py +++ b/text_recognizer/models/vqvae.py @@ -18,7 +18,7 @@ class LitVQVAEModel(LitBaseModel): optimizer: Union[DictConfig, Dict], lr_scheduler: Union[DictConfig, Dict], criterion: Union[DictConfig, Dict], - monitor: str = "val_loss", + monitor: str = "val/loss", *args: Any, **kwargs: Dict, ) -> None: @@ -50,7 +50,7 @@ class LitVQVAEModel(LitBaseModel): reconstructions, vq_loss = self.network(data) loss = self.loss_fn(reconstructions, data) loss += vq_loss - self.log("train_loss", loss) + self.log("train/loss", loss) return loss def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: @@ -59,7 +59,7 @@ class LitVQVAEModel(LitBaseModel): reconstructions, vq_loss = self.network(data) loss = self.loss_fn(reconstructions, data) loss += vq_loss - self.log("val_loss", loss, prog_bar=True) + self.log("val/loss", loss, prog_bar=True) title = "val_pred_examples" self._log_prediction(data, reconstructions, title) diff --git a/text_recognizer/networks/cnn_tranformer.py b/text_recognizer/networks/cnn_tranformer.py new file mode 100644 index 0000000..da69311 --- /dev/null +++ b/text_recognizer/networks/cnn_tranformer.py @@ -0,0 +1,14 @@ +"""Vision transformer for character recognition.""" +from typing import Type + +import attr +from torch import nn, Tensor + + +@attr.s +class CnnTransformer(nn.Module): + def __attrs_pre_init__(self) -> None: + super().__init__() + + backbone: Type[nn.Module] = attr.ib() + head = Type[nn.Module] = attr.ib() diff --git a/text_recognizer/networks/loss/__init__.py b/text_recognizer/networks/loss/__init__.py deleted file mode 100644 index cb83608..0000000 --- a/text_recognizer/networks/loss/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -"""Loss module.""" -from .loss import LabelSmoothingCrossEntropy diff --git a/text_recognizer/networks/loss/label_smoothing_loss.py b/text_recognizer/networks/loss/label_smoothing_loss.py deleted file mode 100644 index 40a7609..0000000 --- a/text_recognizer/networks/loss/label_smoothing_loss.py +++ /dev/null @@ -1,42 +0,0 @@ -"""Implementations of custom loss functions.""" -import torch -from torch import nn -from torch import Tensor -import torch.nn.functional as F - - -class LabelSmoothingLoss(nn.Module): - """Label smoothing cross entropy loss.""" - - def __init__( - self, label_smoothing: float, vocab_size: int, ignore_index: int = -100 - ) -> None: - assert 0.0 < label_smoothing <= 1.0 - self.ignore_index = ignore_index - super().__init__() - - smoothing_value = label_smoothing / (vocab_size - 2) - one_hot = torch.full((vocab_size,), smoothing_value) - one_hot[self.ignore_index] = 0 - self.register_buffer("one_hot", one_hot.unsqueeze(0)) - - self.confidence = 1.0 - label_smoothing - - def forward(self, output: Tensor, targets: Tensor) -> Tensor: - """Computes the loss. - - Args: - output (Tensor): Predictions from the network. - targets (Tensor): Ground truth. - - Shapes: - outpus: Batch size x num classes - targets: Batch size - - Returns: - Tensor: Label smoothing loss. - """ - model_prob = self.one_hot.repeat(targets.size(0), 1) - model_prob.scatter_(1, targets.unsqueeze(1), self.confidence) - model_prob.masked_fill_((targets == self.ignore_index).unsqueeze(1), 0) - return F.kl_div(output, model_prob, reduction="sum") diff --git a/text_recognizer/networks/util.py b/text_recognizer/networks/util.py index 05b10a8..109bf4d 100644 --- a/text_recognizer/networks/util.py +++ b/text_recognizer/networks/util.py @@ -1,10 +1,6 @@ """Miscellaneous neural network functionality.""" -import importlib -from pathlib import Path -from typing import Dict, NamedTuple, Union, Type +from typing import Type -from loguru import logger -import torch from torch import nn @@ -19,6 +15,7 @@ def activation_function(activation: str) -> Type[nn.Module]: ["none", nn.Identity()], ["relu", nn.ReLU(inplace=True)], ["selu", nn.SELU(inplace=True)], + ["mish", nn.Mish(inplace=True)], ] ) return activation_fns[activation.lower()] -- cgit v1.2.3-70-g09d2