From 4d1f2cef39688871d2caafce42a09316381a27ae Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Mon, 5 Jul 2021 23:05:25 +0200 Subject: Refactor with attr, working on cnn+transformer network --- notebooks/00-scratch-pad.ipynb | 644 +++++++++++++++++++++++++++--- notebooks/03-look-at-iam-paragraphs.ipynb | 43 +- 2 files changed, 620 insertions(+), 67 deletions(-) (limited to 'notebooks') diff --git a/notebooks/00-scratch-pad.ipynb b/notebooks/00-scratch-pad.ipynb index 2ade2bb..16c6533 100644 --- a/notebooks/00-scratch-pad.ipynb +++ b/notebooks/00-scratch-pad.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -30,106 +30,244 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ - "from text_recognizer.networks.encoders.efficientnet.efficientnet import EfficientNet" + "from pathlib import Path" ] }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "scrolled": false - }, + "execution_count": 2, + "metadata": {}, "outputs": [], "source": [ - "en = EfficientNet(\"b0\")" + "import attr" ] }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "scrolled": false - }, + "execution_count": 9, + "metadata": {}, "outputs": [], "source": [ - "summary(en, (1, 224, 224));" + "@attr.s\n", + "class B:\n", + " batch_size = attr.ib()\n", + " num_workers = attr.ib()" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ - "torch.cuda.is_available()" + "@attr.s\n", + "class T(B):\n", + "\n", + " def __attrs_post_init__(self) -> None:\n", + " super().__init__(self.batch_size, self.num_workers)\n", + " self.hej = None\n", + " \n", + " batch_size = attr.ib()\n", + " num_workers = attr.ib()\n", + " h: Path = attr.ib(converter=Path)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ - "decoder = Decoder(dim=128, depth=2, num_heads=8, ff_kwargs={}, attn_kwargs={}, cross_attend=True)" + "t = T(batch_size=16, num_workers=2, h=\"hej\")" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "PosixPath('hej')" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "decoder.cuda()" + "t.h" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "16" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "transformer_decoder = Transformer(num_tokens=1000, max_seq_len=690, attn_layers=decoder, emb_dim=128, emb_dropout=0.1)" + "t.batch_size" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "PosixPath('hej')" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "transformer_decoder.cuda()" + "t.h" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "../text_recognizer/__init__.py\n", + "../text_recognizer/callbacks/__init__.py\n", + "../text_recognizer/callbacks/wandb_callbacks.py\n", + "../text_recognizer/data/image_utils.py\n", + "../text_recognizer/data/emnist.py\n", + "../text_recognizer/data/iam_lines.py\n", + "../text_recognizer/data/download_utils.py\n", + "../text_recognizer/data/mappings.py\n", + "../text_recognizer/data/iam_preprocessor.py\n", + "../text_recognizer/data/__init__.py\n", + "../text_recognizer/data/make_wordpieces.py\n", + "../text_recognizer/data/iam_paragraphs.py\n", + "../text_recognizer/data/sentence_generator.py\n", + "../text_recognizer/data/emnist_lines.py\n", + "../text_recognizer/data/build_transitions.py\n", + "../text_recognizer/data/base_dataset.py\n", + "../text_recognizer/data/base_data_module.py\n", + "../text_recognizer/data/iam.py\n", + "../text_recognizer/data/iam_synthetic_paragraphs.py\n", + "../text_recognizer/data/transforms.py\n", + "../text_recognizer/data/iam_extended_paragraphs.py\n", + "../text_recognizer/networks/__init__.py\n", + "../text_recognizer/networks/util.py\n", + "../text_recognizer/networks/cnn_tranformer.py\n", + "../text_recognizer/networks/encoders/__init__.py\n", + "../text_recognizer/networks/encoders/efficientnet/efficientnet.py\n", + "../text_recognizer/networks/encoders/efficientnet/__init__.py\n", + "../text_recognizer/networks/encoders/efficientnet/utils.py\n", + "../text_recognizer/networks/encoders/efficientnet/mbconv.py\n", + "../text_recognizer/networks/loss/__init__.py\n", + "../text_recognizer/networks/loss/label_smoothing_loss.py\n", + "../text_recognizer/networks/vqvae/__init__.py\n", + "../text_recognizer/networks/vqvae/decoder.py\n", + "../text_recognizer/networks/vqvae/vqvae.py\n", + "../text_recognizer/networks/vqvae/vector_quantizer.py\n", + "../text_recognizer/networks/vqvae/encoder.py\n", + "../text_recognizer/networks/transformer/__init__.py\n", + "../text_recognizer/networks/transformer/layers.py\n", + "../text_recognizer/networks/transformer/residual.py\n", + "../text_recognizer/networks/transformer/attention.py\n", + "../text_recognizer/networks/transformer/transformer.py\n", + "../text_recognizer/networks/transformer/vit.py\n", + "../text_recognizer/networks/transformer/mlp.py\n", + "../text_recognizer/networks/transformer/norm.py\n", + "../text_recognizer/networks/transformer/positional_encodings/positional_encoding.py\n", + "../text_recognizer/networks/transformer/positional_encodings/__init__.py\n", + "../text_recognizer/networks/transformer/positional_encodings/absolute_embedding.py\n", + "../text_recognizer/networks/transformer/positional_encodings/rotary_embedding.py\n", + "../text_recognizer/networks/transformer/nystromer/__init__.py\n", + "../text_recognizer/networks/transformer/nystromer/nystromer.py\n", + "../text_recognizer/networks/transformer/nystromer/attention.py\n", + "../text_recognizer/models/__init__.py\n", + "../text_recognizer/models/base.py\n", + "../text_recognizer/models/vqvae.py\n", + "../text_recognizer/models/transformer.py\n", + "../text_recognizer/models/dino.py\n", + "../text_recognizer/models/metrics.py\n" + ] + } + ], + "source": [ + "for f in Path(\"../text_recognizer\").glob(\"**/*.py\"):\n", + " print(f)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "Path(\"..\").glob(\"**/*.py\")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ - "efficient_transformer = Nystromer(\n", - " dim = 64,\n", - " depth = 4,\n", - " num_heads = 8,\n", - " num_landmarks = 64\n", - ")" + "from text_recognizer.networks.encoders.efficientnet.efficientnet import EfficientNet" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "en = EfficientNet(\"b0\")" ] }, { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "scrolled": false + }, "outputs": [], "source": [ - "v = ViT(\n", - " dim = 64,\n", - " image_size = (576, 640),\n", - " patch_size = (32, 32),\n", - " transformer = efficient_transformer\n", - ").cuda()" + "summary(en, (1, 224, 224));" ] }, { @@ -138,7 +276,7 @@ "metadata": {}, "outputs": [], "source": [ - "t = torch.randn(8, 1, 576, 640).cuda()" + "torch.cuda.is_available()" ] }, { @@ -147,7 +285,7 @@ "metadata": {}, "outputs": [], "source": [ - "en.cuda()" + "decoder = Decoder(dim=128, depth=2, num_heads=8, ff_kwargs={}, attn_kwargs={}, cross_attend=True)" ] }, { @@ -156,7 +294,7 @@ "metadata": {}, "outputs": [], "source": [ - "en(t)" + "decoder.cuda()" ] }, { @@ -165,7 +303,7 @@ "metadata": {}, "outputs": [], "source": [ - "o = v(t)" + "transformer_decoder = Transformer(num_tokens=1003, max_seq_len=451, attn_layers=decoder, emb_dim=128, emb_dropout=0.1)" ] }, { @@ -174,7 +312,7 @@ "metadata": {}, "outputs": [], "source": [ - "caption = torch.randint(0, 90, (16, 690)).cuda()" + "transformer_decoder.cuda()" ] }, { @@ -183,7 +321,12 @@ "metadata": {}, "outputs": [], "source": [ - "o.shape" + "efficient_transformer = Nystromer(\n", + " dim = 64,\n", + " depth = 4,\n", + " num_heads = 8,\n", + " num_landmarks = 64\n", + ")" ] }, { @@ -192,16 +335,405 @@ "metadata": {}, "outputs": [], "source": [ - "caption.shape" + "v = ViT(\n", + " dim = 64,\n", + " image_size = (576, 640),\n", + " patch_size = (32, 32),\n", + " transformer = efficient_transformer\n", + ").cuda()" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ - "o = torch.randn(16, 20 * 18, 128).cuda()" + "t = torch.randn(4, 1, 576, 640).cuda()" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "EfficientNet(\n", + " (_conv_stem): Sequential(\n", + " (0): ZeroPad2d(padding=(0, 1, 0, 1), value=0.0)\n", + " (1): Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)\n", + " (2): BatchNorm2d(32, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " (3): Mish(inplace=True)\n", + " )\n", + " (_blocks): ModuleList(\n", + " (0): MBConvBlock(\n", + " (_depthwise): Sequential(\n", + " (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), groups=32, bias=False)\n", + " (1): BatchNorm2d(32, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " (2): Mish(inplace=True)\n", + " )\n", + " (_squeeze_excite): Sequential(\n", + " (0): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1))\n", + " (1): Mish(inplace=True)\n", + " (2): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " (_pointwise): Sequential(\n", + " (0): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(16, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (1): MBConvBlock(\n", + " (_inverted_bottleneck): Sequential(\n", + " (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(96, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " (2): Mish(inplace=True)\n", + " )\n", + " (_depthwise): Sequential(\n", + " (0): Conv2d(96, 96, kernel_size=(3, 3), stride=(2, 2), groups=96, bias=False)\n", + " (1): BatchNorm2d(96, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " (2): Mish(inplace=True)\n", + " )\n", + " (_squeeze_excite): Sequential(\n", + " (0): Conv2d(96, 24, kernel_size=(1, 1), stride=(1, 1))\n", + " (1): Mish(inplace=True)\n", + " (2): Conv2d(24, 96, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " (_pointwise): Sequential(\n", + " (0): Conv2d(96, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(24, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (2): MBConvBlock(\n", + " (_inverted_bottleneck): Sequential(\n", + " (0): Conv2d(24, 144, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(144, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " (2): Mish(inplace=True)\n", + " )\n", + " (_depthwise): Sequential(\n", + " (0): Conv2d(144, 144, kernel_size=(3, 3), stride=(1, 1), groups=144, bias=False)\n", + " (1): BatchNorm2d(144, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " (2): Mish(inplace=True)\n", + " )\n", + " (_squeeze_excite): Sequential(\n", + " (0): Conv2d(144, 36, kernel_size=(1, 1), stride=(1, 1))\n", + " (1): Mish(inplace=True)\n", + " (2): Conv2d(36, 144, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " (_pointwise): Sequential(\n", + " (0): Conv2d(144, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(24, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (3): MBConvBlock(\n", + " (_inverted_bottleneck): Sequential(\n", + " (0): Conv2d(24, 144, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(144, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " (2): Mish(inplace=True)\n", + " )\n", + " (_depthwise): Sequential(\n", + " (0): Conv2d(144, 144, kernel_size=(5, 5), stride=(2, 2), groups=144, bias=False)\n", + " (1): BatchNorm2d(144, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " (2): Mish(inplace=True)\n", + " )\n", + " (_squeeze_excite): Sequential(\n", + " (0): Conv2d(144, 36, kernel_size=(1, 1), stride=(1, 1))\n", + " (1): Mish(inplace=True)\n", + " (2): Conv2d(36, 144, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " (_pointwise): Sequential(\n", + " (0): Conv2d(144, 40, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(40, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (4): MBConvBlock(\n", + " (_inverted_bottleneck): Sequential(\n", + " (0): Conv2d(40, 240, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(240, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " (2): Mish(inplace=True)\n", + " )\n", + " (_depthwise): Sequential(\n", + " (0): Conv2d(240, 240, kernel_size=(5, 5), stride=(1, 1), groups=240, bias=False)\n", + " (1): BatchNorm2d(240, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " (2): Mish(inplace=True)\n", + " )\n", + " (_squeeze_excite): Sequential(\n", + " (0): Conv2d(240, 60, kernel_size=(1, 1), stride=(1, 1))\n", + " (1): Mish(inplace=True)\n", + " (2): Conv2d(60, 240, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " (_pointwise): Sequential(\n", + " (0): Conv2d(240, 40, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(40, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (5): MBConvBlock(\n", + " (_inverted_bottleneck): Sequential(\n", + " (0): Conv2d(40, 240, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(240, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " (2): Mish(inplace=True)\n", + " )\n", + " (_depthwise): Sequential(\n", + " (0): Conv2d(240, 240, kernel_size=(3, 3), stride=(2, 2), groups=240, bias=False)\n", + " (1): BatchNorm2d(240, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " (2): Mish(inplace=True)\n", + " )\n", + " (_squeeze_excite): Sequential(\n", + " (0): Conv2d(240, 60, kernel_size=(1, 1), stride=(1, 1))\n", + " (1): Mish(inplace=True)\n", + " (2): Conv2d(60, 240, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " (_pointwise): Sequential(\n", + " (0): Conv2d(240, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(80, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (6): MBConvBlock(\n", + " (_inverted_bottleneck): Sequential(\n", + " (0): Conv2d(80, 480, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(480, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " (2): Mish(inplace=True)\n", + " )\n", + " (_depthwise): Sequential(\n", + " (0): Conv2d(480, 480, kernel_size=(3, 3), stride=(1, 1), groups=480, bias=False)\n", + " (1): BatchNorm2d(480, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " (2): Mish(inplace=True)\n", + " )\n", + " (_squeeze_excite): Sequential(\n", + " (0): Conv2d(480, 120, kernel_size=(1, 1), stride=(1, 1))\n", + " (1): Mish(inplace=True)\n", + " (2): Conv2d(120, 480, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " (_pointwise): Sequential(\n", + " (0): Conv2d(480, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(80, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (7): MBConvBlock(\n", + " (_inverted_bottleneck): Sequential(\n", + " (0): Conv2d(80, 480, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(480, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " (2): Mish(inplace=True)\n", + " )\n", + " (_depthwise): Sequential(\n", + " (0): Conv2d(480, 480, kernel_size=(3, 3), stride=(1, 1), groups=480, bias=False)\n", + " (1): BatchNorm2d(480, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " (2): Mish(inplace=True)\n", + " )\n", + " (_squeeze_excite): Sequential(\n", + " (0): Conv2d(480, 120, kernel_size=(1, 1), stride=(1, 1))\n", + " (1): Mish(inplace=True)\n", + " (2): Conv2d(120, 480, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " (_pointwise): Sequential(\n", + " (0): Conv2d(480, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(80, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (8): MBConvBlock(\n", + " (_inverted_bottleneck): Sequential(\n", + " (0): Conv2d(80, 480, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(480, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " (2): Mish(inplace=True)\n", + " )\n", + " (_depthwise): Sequential(\n", + " (0): Conv2d(480, 480, kernel_size=(5, 5), stride=(1, 1), groups=480, bias=False)\n", + " (1): BatchNorm2d(480, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " (2): Mish(inplace=True)\n", + " )\n", + " (_squeeze_excite): Sequential(\n", + " (0): Conv2d(480, 120, kernel_size=(1, 1), stride=(1, 1))\n", + " (1): Mish(inplace=True)\n", + " (2): Conv2d(120, 480, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " (_pointwise): Sequential(\n", + " (0): Conv2d(480, 112, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(112, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (9): MBConvBlock(\n", + " (_inverted_bottleneck): Sequential(\n", + " (0): Conv2d(112, 672, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(672, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " (2): Mish(inplace=True)\n", + " )\n", + " (_depthwise): Sequential(\n", + " (0): Conv2d(672, 672, kernel_size=(5, 5), stride=(1, 1), groups=672, bias=False)\n", + " (1): BatchNorm2d(672, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " (2): Mish(inplace=True)\n", + " )\n", + " (_squeeze_excite): Sequential(\n", + " (0): Conv2d(672, 168, kernel_size=(1, 1), stride=(1, 1))\n", + " (1): Mish(inplace=True)\n", + " (2): Conv2d(168, 672, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " (_pointwise): Sequential(\n", + " (0): Conv2d(672, 112, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(112, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (10): MBConvBlock(\n", + " (_inverted_bottleneck): Sequential(\n", + " (0): Conv2d(112, 672, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(672, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " (2): Mish(inplace=True)\n", + " )\n", + " (_depthwise): Sequential(\n", + " (0): Conv2d(672, 672, kernel_size=(5, 5), stride=(1, 1), groups=672, bias=False)\n", + " (1): BatchNorm2d(672, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " (2): Mish(inplace=True)\n", + " )\n", + " (_squeeze_excite): Sequential(\n", + " (0): Conv2d(672, 168, kernel_size=(1, 1), stride=(1, 1))\n", + " (1): Mish(inplace=True)\n", + " (2): Conv2d(168, 672, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " (_pointwise): Sequential(\n", + " (0): Conv2d(672, 112, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(112, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (11): MBConvBlock(\n", + " (_inverted_bottleneck): Sequential(\n", + " (0): Conv2d(112, 672, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(672, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " (2): Mish(inplace=True)\n", + " )\n", + " (_depthwise): Sequential(\n", + " (0): Conv2d(672, 672, kernel_size=(5, 5), stride=(2, 2), groups=672, bias=False)\n", + " (1): BatchNorm2d(672, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " (2): Mish(inplace=True)\n", + " )\n", + " (_squeeze_excite): Sequential(\n", + " (0): Conv2d(672, 168, kernel_size=(1, 1), stride=(1, 1))\n", + " (1): Mish(inplace=True)\n", + " (2): Conv2d(168, 672, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " (_pointwise): Sequential(\n", + " (0): Conv2d(672, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(192, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (12): MBConvBlock(\n", + " (_inverted_bottleneck): Sequential(\n", + " (0): Conv2d(192, 1152, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(1152, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " (2): Mish(inplace=True)\n", + " )\n", + " (_depthwise): Sequential(\n", + " (0): Conv2d(1152, 1152, kernel_size=(5, 5), stride=(1, 1), groups=1152, bias=False)\n", + " (1): BatchNorm2d(1152, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " (2): Mish(inplace=True)\n", + " )\n", + " (_squeeze_excite): Sequential(\n", + " (0): Conv2d(1152, 288, kernel_size=(1, 1), stride=(1, 1))\n", + " (1): Mish(inplace=True)\n", + " (2): Conv2d(288, 1152, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " (_pointwise): Sequential(\n", + " (0): Conv2d(1152, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(192, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (13): MBConvBlock(\n", + " (_inverted_bottleneck): Sequential(\n", + " (0): Conv2d(192, 1152, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(1152, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " (2): Mish(inplace=True)\n", + " )\n", + " (_depthwise): Sequential(\n", + " (0): Conv2d(1152, 1152, kernel_size=(5, 5), stride=(1, 1), groups=1152, bias=False)\n", + " (1): BatchNorm2d(1152, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " (2): Mish(inplace=True)\n", + " )\n", + " (_squeeze_excite): Sequential(\n", + " (0): Conv2d(1152, 288, kernel_size=(1, 1), stride=(1, 1))\n", + " (1): Mish(inplace=True)\n", + " (2): Conv2d(288, 1152, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " (_pointwise): Sequential(\n", + " (0): Conv2d(1152, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(192, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (14): MBConvBlock(\n", + " (_inverted_bottleneck): Sequential(\n", + " (0): Conv2d(192, 1152, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(1152, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " (2): Mish(inplace=True)\n", + " )\n", + " (_depthwise): Sequential(\n", + " (0): Conv2d(1152, 1152, kernel_size=(5, 5), stride=(1, 1), groups=1152, bias=False)\n", + " (1): BatchNorm2d(1152, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " (2): Mish(inplace=True)\n", + " )\n", + " (_squeeze_excite): Sequential(\n", + " (0): Conv2d(1152, 288, kernel_size=(1, 1), stride=(1, 1))\n", + " (1): Mish(inplace=True)\n", + " (2): Conv2d(288, 1152, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " (_pointwise): Sequential(\n", + " (0): Conv2d(1152, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(192, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (15): MBConvBlock(\n", + " (_inverted_bottleneck): Sequential(\n", + " (0): Conv2d(192, 1152, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(1152, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " (2): Mish(inplace=True)\n", + " )\n", + " (_depthwise): Sequential(\n", + " (0): Conv2d(1152, 1152, kernel_size=(3, 3), stride=(1, 1), groups=1152, bias=False)\n", + " (1): BatchNorm2d(1152, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " (2): Mish(inplace=True)\n", + " )\n", + " (_squeeze_excite): Sequential(\n", + " (0): Conv2d(1152, 288, kernel_size=(1, 1), stride=(1, 1))\n", + " (1): Mish(inplace=True)\n", + " (2): Conv2d(288, 1152, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " (_pointwise): Sequential(\n", + " (0): Conv2d(1152, 320, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(320, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " )\n", + " (_conv_head): Sequential(\n", + " (0): Conv2d(320, 1280, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(1280, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " )\n", + ")" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "en.cuda()" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([4, 1280, 18, 20])" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "en(t).shape" ] }, { @@ -210,7 +742,7 @@ "metadata": {}, "outputs": [], "source": [ - "caption = torch.randint(0, 1000, (16, 200)).cuda()" + "o = v(t)" ] }, { @@ -219,7 +751,7 @@ "metadata": {}, "outputs": [], "source": [ - "transformer_decoder(caption, context = o).shape # (1, 1024, 20000)" + "caption = torch.randint(0, 90, (16, 690)).cuda()" ] }, { @@ -228,7 +760,7 @@ "metadata": {}, "outputs": [], "source": [ - "from text_recognizer.networks.encoders.efficientnet.efficientnet import EfficientNet" + "o.shape" ] }, { @@ -237,7 +769,7 @@ "metadata": {}, "outputs": [], "source": [ - "en = EfficientNet()" + "caption.shape" ] }, { @@ -246,7 +778,7 @@ "metadata": {}, "outputs": [], "source": [ - "en.cuda()" + "o = torch.randn(16, 20 * 18, 128).cuda()" ] }, { @@ -255,7 +787,7 @@ "metadata": {}, "outputs": [], "source": [ - "summary(en, (1, 576, 640))" + "caption = torch.randint(0, 1000, (16, 200)).cuda()" ] }, { @@ -264,7 +796,7 @@ "metadata": {}, "outputs": [], "source": [ - "type(efficient_transformer)" + "transformer_decoder(caption, context = o).shape # (1, 1024, 20000)" ] }, { diff --git a/notebooks/03-look-at-iam-paragraphs.ipynb b/notebooks/03-look-at-iam-paragraphs.ipynb index 37fef04..315b7bf 100644 --- a/notebooks/03-look-at-iam-paragraphs.ipynb +++ b/notebooks/03-look-at-iam-paragraphs.ipynb @@ -40,7 +40,7 @@ }, { "cell_type": "code", - "execution_count": 54, + "execution_count": 2, "id": "726ac25b", "metadata": {}, "outputs": [], @@ -57,7 +57,7 @@ }, { "cell_type": "code", - "execution_count": 64, + "execution_count": 3, "id": "c6188bce", "metadata": { "scrolled": true @@ -67,13 +67,13 @@ "name": "stderr", "output_type": "stream", "text": [ - "2021-06-27 20:17:40.498 | DEBUG | text_recognizer.data.mappings:_configure_wordpiece_processor:104 - Using data dir: /home/aktersnurra/projects/text-recognizer/data/downloaded/iam/iamdb\n", - "2021-06-27 20:17:40.682 | DEBUG | text_recognizer.data.mappings:_configure_wordpiece_processor:104 - Using data dir: /home/aktersnurra/projects/text-recognizer/data/downloaded/iam/iamdb\n", - "2021-06-27 20:17:40.777 | INFO | text_recognizer.data.iam_paragraphs:setup:111 - Loading IAM paragraph regions and lines for None...\n", - "2021-06-27 20:17:54.542 | DEBUG | text_recognizer.data.mappings:_configure_wordpiece_processor:104 - Using data dir: /home/aktersnurra/projects/text-recognizer/data/downloaded/iam/iamdb\n", - "2021-06-27 20:17:56.911 | DEBUG | text_recognizer.data.mappings:_configure_wordpiece_processor:104 - Using data dir: /home/aktersnurra/projects/text-recognizer/data/downloaded/iam/iamdb\n", - "2021-06-27 20:17:57.147 | INFO | text_recognizer.data.iam_synthetic_paragraphs:setup:75 - IAM Synthetic dataset steup for stage None...\n", - "2021-06-27 20:18:07.707 | DEBUG | text_recognizer.data.mappings:_configure_wordpiece_processor:104 - Using data dir: /home/aktersnurra/projects/text-recognizer/data/downloaded/iam/iamdb\n" + "2021-06-27 20:59:27.366 | DEBUG | text_recognizer.data.mappings:_configure_wordpiece_processor:104 - Using data dir: /home/aktersnurra/projects/text-recognizer/data/downloaded/iam/iamdb\n", + "2021-06-27 20:59:27.464 | DEBUG | text_recognizer.data.mappings:_configure_wordpiece_processor:104 - Using data dir: /home/aktersnurra/projects/text-recognizer/data/downloaded/iam/iamdb\n", + "2021-06-27 20:59:27.559 | INFO | text_recognizer.data.iam_paragraphs:setup:111 - Loading IAM paragraph regions and lines for None...\n", + "2021-06-27 20:59:40.932 | DEBUG | text_recognizer.data.mappings:_configure_wordpiece_processor:104 - Using data dir: /home/aktersnurra/projects/text-recognizer/data/downloaded/iam/iamdb\n", + "2021-06-27 20:59:43.173 | DEBUG | text_recognizer.data.mappings:_configure_wordpiece_processor:104 - Using data dir: /home/aktersnurra/projects/text-recognizer/data/downloaded/iam/iamdb\n", + "2021-06-27 20:59:43.267 | INFO | text_recognizer.data.iam_synthetic_paragraphs:setup:75 - IAM Synthetic dataset steup for stage None...\n", + "2021-06-27 20:59:53.470 | DEBUG | text_recognizer.data.mappings:_configure_wordpiece_processor:104 - Using data dir: /home/aktersnurra/projects/text-recognizer/data/downloaded/iam/iamdb\n" ] }, { @@ -84,8 +84,8 @@ "Num classes: 1006\n", "Dims: (1, 576, 640)\n", "Output dims: (682, 1)\n", - "Train/val/test sizes: 19907, 262, 231\n", - "Train Batch x stats: (torch.Size([8, 1, 576, 640]), torch.float32, tensor(0.), tensor(0.0105), tensor(0.0575), tensor(1.))\n", + "Train/val/test sizes: 19957, 262, 231\n", + "Train Batch x stats: (torch.Size([8, 1, 576, 640]), torch.float32, tensor(0.), tensor(0.0111), tensor(0.0604), tensor(1.))\n", "Train Batch y stats: (torch.Size([8, 451]), torch.int64, tensor(1), tensor(1004))\n", "Test Batch x stats: (torch.Size([8, 1, 576, 640]), torch.float32, tensor(0.), tensor(0.0315), tensor(0.0799), tensor(0.9098))\n", "Test Batch y stats: (torch.Size([8, 451]), torch.int64, tensor(1), tensor(1003))\n", @@ -100,6 +100,27 @@ "print(dataset)" ] }, + { + "cell_type": "code", + "execution_count": 4, + "id": "55b26b5d", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "1006" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(dataset.mapping)" + ] + }, { "cell_type": "code", "execution_count": null, -- cgit v1.2.3-70-g09d2