diff options
author | aktersnurra <gustaf.rydholm@gmail.com> | 2020-11-12 23:42:03 +0100 |
---|---|---|
committer | aktersnurra <gustaf.rydholm@gmail.com> | 2020-11-12 23:42:03 +0100 |
commit | 8fdb6435e15703fa5b76df19728d905650ee1aef (patch) | |
tree | be3bec9e5cab4ef7f9d94528d102e57ce9b16c3f /src/notebooks/00-testing-stuff-out.ipynb | |
parent | dc28cbe2b4ed77be92ee8b2b69a20689c3bf02a4 (diff) | |
parent | 6cb08a110620ee09fe9d8a5d008197a801d025df (diff) |
Working cnn transformer.
Diffstat (limited to 'src/notebooks/00-testing-stuff-out.ipynb')
-rw-r--r-- | src/notebooks/00-testing-stuff-out.ipynb | 919 |
1 files changed, 632 insertions, 287 deletions
diff --git a/src/notebooks/00-testing-stuff-out.ipynb b/src/notebooks/00-testing-stuff-out.ipynb index 62e549c..3686dcd 100644 --- a/src/notebooks/00-testing-stuff-out.ipynb +++ b/src/notebooks/00-testing-stuff-out.ipynb @@ -2,18 +2,9 @@ "cells": [ { "cell_type": "code", - "execution_count": 11, + "execution_count": 6, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "The autoreload extension is already loaded. To reload it, use:\n", - " %reload_ext autoreload\n" - ] - } - ], + "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2\n", @@ -50,7 +41,56 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from pathlib import Path" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "Path(\"/home/akternurra/Documents/projects/quest-for-general-artifical-intelligence/projects/text-recognizer/src/training/experiments/TransformerModel_EmnistLinesDataset_CNNTransformer/1112_081300/model/best.pt\").exists()" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "False" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "Path(\"/home/akternurra/Documents/projects/quest-for-general-artifical-intelligence/projects/text-recognizer/src/training/experiments/TransformerModel_EmnistLinesDataset_CNNTransformer/1112_201649/model/best.pt\").exists()" + ] + }, + { + "cell_type": "code", + "execution_count": 18, "metadata": {}, "outputs": [], "source": [ @@ -63,13 +103,13 @@ " width_factor=1,\n", " dropout_rate= 0.2,\n", " activation= \"SELU\",\n", - " use_decoder= True,\n", + " use_decoder= False,\n", ")" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 19, "metadata": {}, "outputs": [], "source": [ @@ -78,7 +118,7 @@ }, { "cell_type": "code", - "execution_count": 39, + "execution_count": 20, "metadata": {}, "outputs": [], "source": [ @@ -89,7 +129,7 @@ }, { "cell_type": "code", - "execution_count": 40, + "execution_count": 21, "metadata": {}, "outputs": [ { @@ -97,78 +137,10 @@ "text/plain": [ "Sequential(\n", " (0): SELU(inplace=True)\n", - " (1): Sequential(\n", - " (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " (1): Sequential(\n", - " (0): WideBlock(\n", - " (activation): SELU(inplace=True)\n", - " (blocks): Sequential(\n", - " (0): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (1): SELU(inplace=True)\n", - " (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " (3): Dropout(p=0.2, inplace=False)\n", - " (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (5): SELU(inplace=True)\n", - " (6): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " )\n", - " )\n", - " (2): Sequential(\n", - " (0): WideBlock(\n", - " (activation): SELU(inplace=True)\n", - " (blocks): Sequential(\n", - " (0): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (1): SELU(inplace=True)\n", - " (2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " (3): Dropout(p=0.2, inplace=False)\n", - " (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (5): SELU(inplace=True)\n", - " (6): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", - " )\n", - " (shortcut): Sequential(\n", - " (0): Conv2d(32, 64, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", - " )\n", - " )\n", - " )\n", - " (3): Sequential(\n", - " (0): WideBlock(\n", - " (activation): SELU(inplace=True)\n", - " (blocks): Sequential(\n", - " (0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (1): SELU(inplace=True)\n", - " (2): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " (3): Dropout(p=0.2, inplace=False)\n", - " (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (5): SELU(inplace=True)\n", - " (6): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", - " )\n", - " (shortcut): Sequential(\n", - " (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", - " )\n", - " )\n", - " )\n", - " (4): Sequential(\n", - " (0): WideBlock(\n", - " (activation): SELU(inplace=True)\n", - " (blocks): Sequential(\n", - " (0): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (1): SELU(inplace=True)\n", - " (2): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " (3): Dropout(p=0.2, inplace=False)\n", - " (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (5): SELU(inplace=True)\n", - " (6): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", - " )\n", - " (shortcut): Sequential(\n", - " (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", - " )\n", - " )\n", - " )\n", - " )\n", ")" ] }, - "execution_count": 40, + "execution_count": 21, "metadata": {}, "output_type": "execute_result" } @@ -179,94 +151,302 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 86, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "==========================================================================================\n", - "Layer (type:depth-idx) Output Shape Param #\n", - "==========================================================================================\n", - "├─Sequential: 1-1 [-1, 256, 4, 119] --\n", - "| └─Conv2d: 2-1 [-1, 32, 28, 952] 288\n", - "| └─Sequential: 2-2 [-1, 32, 28, 952] --\n", - "| | └─WideBlock: 3-1 [-1, 32, 28, 952] 18,560\n", - "| └─Sequential: 2-3 [-1, 64, 14, 476] --\n", - "| | └─WideBlock: 3-2 [-1, 64, 14, 476] 57,536\n", - "| └─Sequential: 2-4 [-1, 128, 7, 238] --\n", - "| | └─WideBlock: 3-3 [-1, 128, 7, 238] 229,760\n", - "| └─Sequential: 2-5 [-1, 256, 4, 119] --\n", - "| | └─WideBlock: 3-4 [-1, 256, 4, 119] 918,272\n", - "├─Sequential: 1-2 [-1, 80] --\n", - "| └─BatchNorm2d: 2-6 [-1, 256, 4, 119] 512\n", - "├─SELU: 1-3 [-1, 256, 4, 119] --\n", - "├─Sequential: 1 [] --\n", - "| └─SELU: 2-7 [-1, 256, 4, 119] --\n", - "| └─Reduce: 2-8 [-1, 256] --\n", - "| └─Linear: 2-9 [-1, 80] 20,560\n", - "==========================================================================================\n", - "Total params: 1,245,488\n", - "Trainable params: 1,245,488\n", + "===============================================================================================\n", + "Layer (type:depth-idx) Output Shape Param #\n", + "===============================================================================================\n", + "├─Sequential: 1-1 [-1, 256, 4, 119] --\n", + "| └─Conv2d: 2-1 [-1, 32, 28, 952] 288\n", + "| └─Sequential: 2-2 [-1, 32, 28, 952] --\n", + "| | └─WideBlock: 3-1 [-1, 32, 28, 952] --\n", + "| | | └─Sequential: 4-1 [-1, 32, 28, 952] --\n", + "| | | | └─BatchNorm2d: 5-1 [-1, 32, 28, 952] 64\n", + "| | | └─SELU: 4-2 [-1, 32, 28, 952] --\n", + "| | | └─Sequential: 4 [] --\n", + "| | | | └─SELU: 5-2 [-1, 32, 28, 952] --\n", + "| | | | └─Conv2d: 5-3 [-1, 32, 28, 952] 9,216\n", + "| | | | └─Dropout: 5-4 [-1, 32, 28, 952] --\n", + "| | | | └─BatchNorm2d: 5-5 [-1, 32, 28, 952] 64\n", + "| | | └─SELU: 4-3 [-1, 32, 28, 952] --\n", + "| | | └─Sequential: 4 [] --\n", + "| | | | └─SELU: 5-6 [-1, 32, 28, 952] --\n", + "| | | | └─Conv2d: 5-7 [-1, 32, 28, 952] 9,216\n", + "| └─Sequential: 2-3 [-1, 64, 14, 476] --\n", + "| | └─WideBlock: 3-2 [-1, 64, 14, 476] --\n", + "| | | └─Sequential: 4-4 [-1, 64, 14, 476] --\n", + "| | | | └─Conv2d: 5-8 [-1, 64, 14, 476] 2,048\n", + "| | | └─Sequential: 4-5 [-1, 64, 14, 476] --\n", + "| | | | └─BatchNorm2d: 5-9 [-1, 32, 28, 952] 64\n", + "| | | └─SELU: 4-6 [-1, 32, 28, 952] --\n", + "| | | └─Sequential: 4 [] --\n", + "| | | | └─SELU: 5-10 [-1, 32, 28, 952] --\n", + "| | | | └─Conv2d: 5-11 [-1, 64, 28, 952] 18,432\n", + "| | | | └─Dropout: 5-12 [-1, 64, 28, 952] --\n", + "| | | | └─BatchNorm2d: 5-13 [-1, 64, 28, 952] 128\n", + "| | | └─SELU: 4-7 [-1, 64, 28, 952] --\n", + "| | | └─Sequential: 4 [] --\n", + "| | | | └─SELU: 5-14 [-1, 64, 28, 952] --\n", + "| | | | └─Conv2d: 5-15 [-1, 64, 14, 476] 36,864\n", + "| └─Sequential: 2-4 [-1, 128, 7, 238] --\n", + "| | └─WideBlock: 3-3 [-1, 128, 7, 238] --\n", + "| | | └─Sequential: 4-8 [-1, 128, 7, 238] --\n", + "| | | | └─Conv2d: 5-16 [-1, 128, 7, 238] 8,192\n", + "| | | └─Sequential: 4-9 [-1, 128, 7, 238] --\n", + "| | | | └─BatchNorm2d: 5-17 [-1, 64, 14, 476] 128\n", + "| | | └─SELU: 4-10 [-1, 64, 14, 476] --\n", + "| | | └─Sequential: 4 [] --\n", + "| | | | └─SELU: 5-18 [-1, 64, 14, 476] --\n", + "| | | | └─Conv2d: 5-19 [-1, 128, 14, 476] 73,728\n", + "| | | | └─Dropout: 5-20 [-1, 128, 14, 476] --\n", + "| | | | └─BatchNorm2d: 5-21 [-1, 128, 14, 476] 256\n", + "| | | └─SELU: 4-11 [-1, 128, 14, 476] --\n", + "| | | └─Sequential: 4 [] --\n", + "| | | | └─SELU: 5-22 [-1, 128, 14, 476] --\n", + "| | | | └─Conv2d: 5-23 [-1, 128, 7, 238] 147,456\n", + "| └─Sequential: 2-5 [-1, 256, 4, 119] --\n", + "| | └─WideBlock: 3-4 [-1, 256, 4, 119] --\n", + "| | | └─Sequential: 4-12 [-1, 256, 4, 119] --\n", + "| | | | └─Conv2d: 5-24 [-1, 256, 4, 119] 32,768\n", + "| | | └─Sequential: 4-13 [-1, 256, 4, 119] --\n", + "| | | | └─BatchNorm2d: 5-25 [-1, 128, 7, 238] 256\n", + "| | | └─SELU: 4-14 [-1, 128, 7, 238] --\n", + "| | | └─Sequential: 4 [] --\n", + "| | | | └─SELU: 5-26 [-1, 128, 7, 238] --\n", + "| | | | └─Conv2d: 5-27 [-1, 256, 7, 238] 294,912\n", + "| | | | └─Dropout: 5-28 [-1, 256, 7, 238] --\n", + "| | | | └─BatchNorm2d: 5-29 [-1, 256, 7, 238] 512\n", + "| | | └─SELU: 4-15 [-1, 256, 7, 238] --\n", + "| | | └─Sequential: 4 [] --\n", + "| | | | └─SELU: 5-30 [-1, 256, 7, 238] --\n", + "| | | | └─Conv2d: 5-31 [-1, 256, 4, 119] 589,824\n", + "===============================================================================================\n", + "Total params: 1,224,416\n", + "Trainable params: 1,224,416\n", "Non-trainable params: 0\n", - "Total mult-adds (M): 12.61\n", - "==========================================================================================\n", + "Total mult-adds (G): 2.79\n", + "===============================================================================================\n", "Input size (MB): 0.10\n", - "Forward/backward pass size (MB): 7.44\n", - "Params size (MB): 4.75\n", - "Estimated Total Size (MB): 12.29\n", - "==========================================================================================\n" + "Forward/backward pass size (MB): 101.10\n", + "Params size (MB): 4.67\n", + "Estimated Total Size (MB): 105.88\n", + "===============================================================================================\n" ] }, { "data": { "text/plain": [ - "==========================================================================================\n", - "Layer (type:depth-idx) Output Shape Param #\n", - "==========================================================================================\n", - "├─Sequential: 1-1 [-1, 256, 4, 119] --\n", - "| └─Conv2d: 2-1 [-1, 32, 28, 952] 288\n", - "| └─Sequential: 2-2 [-1, 32, 28, 952] --\n", - "| | └─WideBlock: 3-1 [-1, 32, 28, 952] 18,560\n", - "| └─Sequential: 2-3 [-1, 64, 14, 476] --\n", - "| | └─WideBlock: 3-2 [-1, 64, 14, 476] 57,536\n", - "| └─Sequential: 2-4 [-1, 128, 7, 238] --\n", - "| | └─WideBlock: 3-3 [-1, 128, 7, 238] 229,760\n", - "| └─Sequential: 2-5 [-1, 256, 4, 119] --\n", - "| | └─WideBlock: 3-4 [-1, 256, 4, 119] 918,272\n", - "├─Sequential: 1-2 [-1, 80] --\n", - "| └─BatchNorm2d: 2-6 [-1, 256, 4, 119] 512\n", - "├─SELU: 1-3 [-1, 256, 4, 119] --\n", - "├─Sequential: 1 [] --\n", - "| └─SELU: 2-7 [-1, 256, 4, 119] --\n", - "| └─Reduce: 2-8 [-1, 256] --\n", - "| └─Linear: 2-9 [-1, 80] 20,560\n", - "==========================================================================================\n", - "Total params: 1,245,488\n", - "Trainable params: 1,245,488\n", + "===============================================================================================\n", + "Layer (type:depth-idx) Output Shape Param #\n", + "===============================================================================================\n", + "├─Sequential: 1-1 [-1, 256, 4, 119] --\n", + "| └─Conv2d: 2-1 [-1, 32, 28, 952] 288\n", + "| └─Sequential: 2-2 [-1, 32, 28, 952] --\n", + "| | └─WideBlock: 3-1 [-1, 32, 28, 952] --\n", + "| | | └─Sequential: 4-1 [-1, 32, 28, 952] --\n", + "| | | | └─BatchNorm2d: 5-1 [-1, 32, 28, 952] 64\n", + "| | | └─SELU: 4-2 [-1, 32, 28, 952] --\n", + "| | | └─Sequential: 4 [] --\n", + "| | | | └─SELU: 5-2 [-1, 32, 28, 952] --\n", + "| | | | └─Conv2d: 5-3 [-1, 32, 28, 952] 9,216\n", + "| | | | └─Dropout: 5-4 [-1, 32, 28, 952] --\n", + "| | | | └─BatchNorm2d: 5-5 [-1, 32, 28, 952] 64\n", + "| | | └─SELU: 4-3 [-1, 32, 28, 952] --\n", + "| | | └─Sequential: 4 [] --\n", + "| | | | └─SELU: 5-6 [-1, 32, 28, 952] --\n", + "| | | | └─Conv2d: 5-7 [-1, 32, 28, 952] 9,216\n", + "| └─Sequential: 2-3 [-1, 64, 14, 476] --\n", + "| | └─WideBlock: 3-2 [-1, 64, 14, 476] --\n", + "| | | └─Sequential: 4-4 [-1, 64, 14, 476] --\n", + "| | | | └─Conv2d: 5-8 [-1, 64, 14, 476] 2,048\n", + "| | | └─Sequential: 4-5 [-1, 64, 14, 476] --\n", + "| | | | └─BatchNorm2d: 5-9 [-1, 32, 28, 952] 64\n", + "| | | └─SELU: 4-6 [-1, 32, 28, 952] --\n", + "| | | └─Sequential: 4 [] --\n", + "| | | | └─SELU: 5-10 [-1, 32, 28, 952] --\n", + "| | | | └─Conv2d: 5-11 [-1, 64, 28, 952] 18,432\n", + "| | | | └─Dropout: 5-12 [-1, 64, 28, 952] --\n", + "| | | | └─BatchNorm2d: 5-13 [-1, 64, 28, 952] 128\n", + "| | | └─SELU: 4-7 [-1, 64, 28, 952] --\n", + "| | | └─Sequential: 4 [] --\n", + "| | | | └─SELU: 5-14 [-1, 64, 28, 952] --\n", + "| | | | └─Conv2d: 5-15 [-1, 64, 14, 476] 36,864\n", + "| └─Sequential: 2-4 [-1, 128, 7, 238] --\n", + "| | └─WideBlock: 3-3 [-1, 128, 7, 238] --\n", + "| | | └─Sequential: 4-8 [-1, 128, 7, 238] --\n", + "| | | | └─Conv2d: 5-16 [-1, 128, 7, 238] 8,192\n", + "| | | └─Sequential: 4-9 [-1, 128, 7, 238] --\n", + "| | | | └─BatchNorm2d: 5-17 [-1, 64, 14, 476] 128\n", + "| | | └─SELU: 4-10 [-1, 64, 14, 476] --\n", + "| | | └─Sequential: 4 [] --\n", + "| | | | └─SELU: 5-18 [-1, 64, 14, 476] --\n", + "| | | | └─Conv2d: 5-19 [-1, 128, 14, 476] 73,728\n", + "| | | | └─Dropout: 5-20 [-1, 128, 14, 476] --\n", + "| | | | └─BatchNorm2d: 5-21 [-1, 128, 14, 476] 256\n", + "| | | └─SELU: 4-11 [-1, 128, 14, 476] --\n", + "| | | └─Sequential: 4 [] --\n", + "| | | | └─SELU: 5-22 [-1, 128, 14, 476] --\n", + "| | | | └─Conv2d: 5-23 [-1, 128, 7, 238] 147,456\n", + "| └─Sequential: 2-5 [-1, 256, 4, 119] --\n", + "| | └─WideBlock: 3-4 [-1, 256, 4, 119] --\n", + "| | | └─Sequential: 4-12 [-1, 256, 4, 119] --\n", + "| | | | └─Conv2d: 5-24 [-1, 256, 4, 119] 32,768\n", + "| | | └─Sequential: 4-13 [-1, 256, 4, 119] --\n", + "| | | | └─BatchNorm2d: 5-25 [-1, 128, 7, 238] 256\n", + "| | | └─SELU: 4-14 [-1, 128, 7, 238] --\n", + "| | | └─Sequential: 4 [] --\n", + "| | | | └─SELU: 5-26 [-1, 128, 7, 238] --\n", + "| | | | └─Conv2d: 5-27 [-1, 256, 7, 238] 294,912\n", + "| | | | └─Dropout: 5-28 [-1, 256, 7, 238] --\n", + "| | | | └─BatchNorm2d: 5-29 [-1, 256, 7, 238] 512\n", + "| | | └─SELU: 4-15 [-1, 256, 7, 238] --\n", + "| | | └─Sequential: 4 [] --\n", + "| | | | └─SELU: 5-30 [-1, 256, 7, 238] --\n", + "| | | | └─Conv2d: 5-31 [-1, 256, 4, 119] 589,824\n", + "===============================================================================================\n", + "Total params: 1,224,416\n", + "Trainable params: 1,224,416\n", "Non-trainable params: 0\n", - "Total mult-adds (M): 12.61\n", - "==========================================================================================\n", + "Total mult-adds (G): 2.79\n", + "===============================================================================================\n", "Input size (MB): 0.10\n", - "Forward/backward pass size (MB): 7.44\n", - "Params size (MB): 4.75\n", - "Estimated Total Size (MB): 12.29\n", - "==========================================================================================" + "Forward/backward pass size (MB): 101.10\n", + "Params size (MB): 4.67\n", + "Estimated Total Size (MB): 105.88\n", + "===============================================================================================" + ] + }, + "execution_count": 86, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "summary(wr, (1, 28, 952), device=\"cpu\", depth=7)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "a = torch.rand(1, 1, 28, 952)" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": {}, + "outputs": [], + "source": [ + "b = wr(a)" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [], + "source": [ + "from einops import rearrange" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [], + "source": [ + "b = rearrange(b, \"b c h w -> b w c h\")" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [], + "source": [ + "c = nn.AdaptiveAvgPool2d((None, 1))" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": {}, + "outputs": [], + "source": [ + "d = c(b)" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1, 119, 256, 1])" ] }, - "execution_count": 8, + "execution_count": 43, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "summary(wr, (1, 28, 952), device=\"cpu\", depth=3)" + "d.shape" ] }, { "cell_type": "code", - "execution_count": 64, + "execution_count": 44, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1, 119, 256])" + ] + }, + "execution_count": 44, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "d.squeeze(3).shape" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1, 256, 4, 119])" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "b.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 15, "metadata": {}, "outputs": [], "source": [ @@ -533,7 +713,7 @@ }, { "cell_type": "code", - "execution_count": 74, + "execution_count": 45, "metadata": {}, "outputs": [], "source": [ @@ -542,16 +722,36 @@ }, { "cell_type": "code", - "execution_count": 113, + "execution_count": 57, "metadata": {}, "outputs": [], "source": [ - "dnet = DenseNet(12, (6, 8, 10, 6), 1, 24, 80, 4, 0, False)" + "dnet = DenseNet(12, (6, 12, 10), 1, 24, 80, 4, 0, True)" + ] + }, + { + "cell_type": "code", + "execution_count": 58, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "27.0" + ] + }, + "execution_count": 58, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "216 / 8" ] }, { "cell_type": "code", - "execution_count": 114, + "execution_count": 59, "metadata": {}, "outputs": [ { @@ -561,31 +761,31 @@ "==========================================================================================\n", "Layer (type:depth-idx) Output Shape Param #\n", "==========================================================================================\n", - "├─Sequential: 1-1 [-1, 168, 3, 119] --\n", + "├─Sequential: 1-1 [-1, 80] --\n", "| └─Conv2d: 2-1 [-1, 24, 28, 952] 216\n", "| └─BatchNorm2d: 2-2 [-1, 24, 28, 952] 48\n", "| └─ReLU: 2-3 [-1, 24, 28, 952] --\n", "| └─_DenseBlock: 2-4 [-1, 96, 28, 952] --\n", "| └─_Transition: 2-5 [-1, 48, 14, 476] --\n", "| | └─Sequential: 3-1 [-1, 48, 14, 476] 4,800\n", - "| └─_DenseBlock: 2-6 [-1, 144, 14, 476] --\n", - "| └─_Transition: 2-7 [-1, 72, 7, 238] --\n", - "| | └─Sequential: 3-2 [-1, 72, 7, 238] 10,656\n", - "| └─_DenseBlock: 2-8 [-1, 192, 7, 238] --\n", - "| └─_Transition: 2-9 [-1, 96, 3, 119] --\n", - "| | └─Sequential: 3-3 [-1, 96, 3, 119] 18,816\n", - "| └─_DenseBlock: 2-10 [-1, 168, 3, 119] --\n", - "| └─ReLU: 2-11 [-1, 168, 3, 119] --\n", + "| └─_DenseBlock: 2-6 [-1, 192, 14, 476] --\n", + "| └─_Transition: 2-7 [-1, 96, 7, 238] --\n", + "| | └─Sequential: 3-2 [-1, 96, 7, 238] 18,816\n", + "| └─_DenseBlock: 2-8 [-1, 216, 7, 238] --\n", + "| └─ReLU: 2-9 [-1, 216, 7, 238] --\n", + "| └─AdaptiveAvgPool2d: 2-10 [-1, 216, 1, 1] --\n", + "| └─Rearrange: 2-11 [-1, 216] --\n", + "| └─Linear: 2-12 [-1, 80] 17,360\n", "==========================================================================================\n", - "Total params: 34,536\n", - "Trainable params: 34,536\n", + "Total params: 41,240\n", + "Trainable params: 41,240\n", "Non-trainable params: 0\n", - "Total mult-adds (M): 229.41\n", + "Total mult-adds (M): 252.43\n", "==========================================================================================\n", "Input size (MB): 0.10\n", "Forward/backward pass size (MB): 53.69\n", - "Params size (MB): 0.13\n", - "Estimated Total Size (MB): 53.92\n", + "Params size (MB): 0.16\n", + "Estimated Total Size (MB): 53.95\n", "==========================================================================================\n" ] }, @@ -595,35 +795,35 @@ "==========================================================================================\n", "Layer (type:depth-idx) Output Shape Param #\n", "==========================================================================================\n", - "├─Sequential: 1-1 [-1, 168, 3, 119] --\n", + "├─Sequential: 1-1 [-1, 80] --\n", "| └─Conv2d: 2-1 [-1, 24, 28, 952] 216\n", "| └─BatchNorm2d: 2-2 [-1, 24, 28, 952] 48\n", "| └─ReLU: 2-3 [-1, 24, 28, 952] --\n", "| └─_DenseBlock: 2-4 [-1, 96, 28, 952] --\n", "| └─_Transition: 2-5 [-1, 48, 14, 476] --\n", "| | └─Sequential: 3-1 [-1, 48, 14, 476] 4,800\n", - "| └─_DenseBlock: 2-6 [-1, 144, 14, 476] --\n", - "| └─_Transition: 2-7 [-1, 72, 7, 238] --\n", - "| | └─Sequential: 3-2 [-1, 72, 7, 238] 10,656\n", - "| └─_DenseBlock: 2-8 [-1, 192, 7, 238] --\n", - "| └─_Transition: 2-9 [-1, 96, 3, 119] --\n", - "| | └─Sequential: 3-3 [-1, 96, 3, 119] 18,816\n", - "| └─_DenseBlock: 2-10 [-1, 168, 3, 119] --\n", - "| └─ReLU: 2-11 [-1, 168, 3, 119] --\n", + "| └─_DenseBlock: 2-6 [-1, 192, 14, 476] --\n", + "| └─_Transition: 2-7 [-1, 96, 7, 238] --\n", + "| | └─Sequential: 3-2 [-1, 96, 7, 238] 18,816\n", + "| └─_DenseBlock: 2-8 [-1, 216, 7, 238] --\n", + "| └─ReLU: 2-9 [-1, 216, 7, 238] --\n", + "| └─AdaptiveAvgPool2d: 2-10 [-1, 216, 1, 1] --\n", + "| └─Rearrange: 2-11 [-1, 216] --\n", + "| └─Linear: 2-12 [-1, 80] 17,360\n", "==========================================================================================\n", - "Total params: 34,536\n", - "Trainable params: 34,536\n", + "Total params: 41,240\n", + "Trainable params: 41,240\n", "Non-trainable params: 0\n", - "Total mult-adds (M): 229.41\n", + "Total mult-adds (M): 252.43\n", "==========================================================================================\n", "Input size (MB): 0.10\n", "Forward/backward pass size (MB): 53.69\n", - "Params size (MB): 0.13\n", - "Estimated Total Size (MB): 53.92\n", + "Params size (MB): 0.16\n", + "Estimated Total Size (MB): 53.95\n", "==========================================================================================" ] }, - "execution_count": 114, + "execution_count": 59, "metadata": {}, "output_type": "execute_result" } @@ -634,6 +834,37 @@ }, { "cell_type": "code", + "execution_count": 84, + "metadata": {}, + "outputs": [], + "source": [ + " backbone = nn.Sequential(\n", + " *list(dnet.children())[:][:-4]\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 85, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Sequential()" + ] + }, + "execution_count": 85, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "backbone" + ] + }, + { + "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [], @@ -821,166 +1052,280 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 59, + "metadata": {}, + "outputs": [], + "source": [ + "pred = torch.Tensor([1,1,1,1,1, 81, 1, 79, 79, 79, 2,1,1,1,1, 81, 1, 79, 79, 79, 1,1,1,1,1, 81, 79, 79, 79, 79]).long()\n", + "target = torch.Tensor([1,1,1,1,1, 81, 79, 79, 79, 79, 1,1,1,1,1, 81, 79, 79, 79, 79, 1,1,1,1,1, 81, 79, 79, 79, 79]).long()" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [], + "source": [ + "from text_recognizer.models.metrics import accuracy" + ] + }, + { + "cell_type": "code", + "execution_count": 61, + "metadata": {}, + "outputs": [], + "source": [ + "pad_indcies = torch.nonzero(target == 79, as_tuple=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 76, + "metadata": {}, + "outputs": [], + "source": [ + "t1 = torch.nonzero(target == 81, as_tuple=False).squeeze(1)" + ] + }, + { + "cell_type": "code", + "execution_count": 66, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "15.0" + "30" ] }, - "execution_count": 8, + "execution_count": 66, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "120 / 8" + "target.shape[0]" ] }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 84, + "metadata": {}, + "outputs": [], + "source": [ + "t2 = torch.arange(10, target.shape[0] + 1, 10)" + ] + }, + { + "cell_type": "code", + "execution_count": 85, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "120" + "tensor([10, 20, 30])" ] }, - "execution_count": 27, + "execution_count": 85, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "2 * 60" + "t2" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 89, "metadata": {}, "outputs": [], "source": [ - "import yaml" + "for start, stop in zip(t1, t2):\n", + " pred[start+1:stop] = 79" ] }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 90, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([ 1, 1, 1, 1, 1, 81, 79, 79, 79, 79, 2, 1, 1, 1, 1, 81, 79, 79,\n", + " 79, 79, 1, 1, 1, 1, 1, 81, 79, 79, 79, 79])" + ] + }, + "execution_count": 90, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "path = \"../training/experiments/cnn_transformer.yml\"" + "pred" ] }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 88, + "metadata": {}, + "outputs": [ + { + "ename": "SyntaxError", + "evalue": "invalid syntax (<ipython-input-88-b8a4aef86401>, line 1)", + "output_type": "error", + "traceback": [ + "\u001b[0;36m File \u001b[0;32m\"<ipython-input-88-b8a4aef86401>\"\u001b[0;36m, line \u001b[0;32m1\u001b[0m\n\u001b[0;31m [pred[start+1:stop] = 79 for start, stop in zip(t1, t2)]\u001b[0m\n\u001b[0m ^\u001b[0m\n\u001b[0;31mSyntaxError\u001b[0m\u001b[0;31m:\u001b[0m invalid syntax\n" + ] + } + ], + "source": [ + "[pred[start+1:stop] = 79 for start, stop in zip(t1, t2)]" + ] + }, + { + "cell_type": "code", + "execution_count": 69, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[ 6],\n", + " [ 7],\n", + " [ 8],\n", + " [ 9],\n", + " [16],\n", + " [17],\n", + " [18],\n", + " [19],\n", + " [26],\n", + " [27],\n", + " [28],\n", + " [29]])" + ] + }, + "execution_count": 69, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pad_indcies" + ] + }, + { + "cell_type": "code", + "execution_count": 71, + "metadata": {}, + "outputs": [ + { + "ename": "TypeError", + "evalue": "only integer tensors of a single element can be converted to an index", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m<ipython-input-71-39b5cc3b1445>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mpred\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mpad_indcies\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mpad_indcies\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m79\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;31mTypeError\u001b[0m: only integer tensors of a single element can be converted to an index" + ] + } + ], + "source": [ + "pred[pad_indcies:pad_indcies] = 79" + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([20])" + ] + }, + "execution_count": 50, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pred.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([20])" + ] + }, + "execution_count": 51, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "target.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 91, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.0" + ] + }, + "execution_count": 91, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "accuracy(pred, target)" + ] + }, + { + "cell_type": "code", + "execution_count": 92, "metadata": {}, "outputs": [], "source": [ - "with open(path, \"r\") as f:\n", - " f = yaml.safe_load(f)" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'experiment_group': 'Transformer Experiments',\n", - " 'experiments': [{'train_args': {'transformer_model': True,\n", - " 'batch_size': 16,\n", - " 'max_epochs': 128,\n", - " 'input_shape': [[1, 28, 952], [92]]},\n", - " 'dataset': {'type': 'EmnistLinesDataset',\n", - " 'args': {'subsample_fraction': None,\n", - " 'transform': [{'type': 'ToPILImage', 'args': None},\n", - " {'type': 'Resize', 'args': {'size': [28, 952]}},\n", - " {'type': 'ToTensor', 'args': None}],\n", - " 'max_length': 97,\n", - " 'min_overlap': 0.0,\n", - " 'max_overlap': 0.33,\n", - " 'num_samples': 1,\n", - " 'seed': 4711,\n", - " 'init_token': '<sos>',\n", - " 'pad_token': '_',\n", - " 'eos_token': '<eos>',\n", - " 'target_transform': [{'type': 'AddTokens',\n", - " 'args': {'init_token': '<sos>',\n", - " 'eos_token': '<eos>',\n", - " 'pad_token': '_'}}]},\n", - " 'train_args': {'num_workers': 8,\n", - " 'train_fraction': 0.85,\n", - " 'batch_size': 16}},\n", - " 'model': 'VisionTransformerModel',\n", - " 'metrics': ['accuracy'],\n", - " 'network': {'type': 'CNNTransformer',\n", - " 'args': {'backbone': 'DenseNet',\n", - " 'backbone_args': {'growth_rate': 8,\n", - " 'block_config': [4, 6, 8, 6],\n", - " 'in_channels': 1,\n", - " 'base_channels': 24,\n", - " 'num_classes': 256,\n", - " 'bn_size': 4,\n", - " 'dropout_rate': 0.1,\n", - " 'classifier': False,\n", - " 'activation': 'elu'},\n", - " 'num_encoder_layers': 3,\n", - " 'num_decoder_layers': 3,\n", - " 'hidden_dim': 256,\n", - " 'vocab_size': 82,\n", - " 'num_heads': 8,\n", - " 'max_len': 99,\n", - " 'expansion_dim': 512,\n", - " 'mlp_dim': 256,\n", - " 'spatial_dim': 357,\n", - " 'dropout_rate': 0.1,\n", - " 'trg_pad_index': 79,\n", - " 'activation': 'gelu'}},\n", - " 'criterion': {'type': 'CrossEntropyLoss', 'args': {'ignore_index': 79}},\n", - " 'optimizer': {'type': 'AdamW',\n", - " 'args': {'lr': 0.0003,\n", - " 'betas': [0.9, 0.999],\n", - " 'eps': 1e-08,\n", - " 'weight_decay': 3e-06,\n", - " 'amsgrad': False}},\n", - " 'lr_scheduler': {'type': 'OneCycleLR',\n", - " 'args': {'max_lr': 0.0007,\n", - " 'epochs': 128,\n", - " 'anneal_strategy': 'cos',\n", - " 'pct_start': 0.475,\n", - " 'cycle_momentum': True,\n", - " 'base_momentum': 0.85,\n", - " 'max_momentum': 0.9,\n", - " 'div_factor': 10,\n", - " 'final_div_factor': 10000,\n", - " 'interval': 'step'}},\n", - " 'callbacks': ['Checkpoint',\n", - " 'ProgressBar',\n", - " 'WandbCallback',\n", - " 'WandbImageLogger'],\n", - " 'callback_args': {'Checkpoint': {'monitor': 'val_loss', 'mode': 'min'},\n", - " 'ProgressBar': {'epochs': 128},\n", - " 'WandbCallback': {'log_batch_frequency': 10},\n", - " 'WandbImageLogger': {'num_examples': 6}},\n", - " 'test_metric': 'test_accuracy'}]}" - ] - }, - "execution_count": 27, + "acc = (pred == target).sum().float() / target.shape[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 93, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor(0.9667)" + ] + }, + "execution_count": 93, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "f" + "acc" ] }, { |