summaryrefslogtreecommitdiff
path: root/src/notebooks/00-testing-stuff-out.ipynb
diff options
context:
space:
mode:
authoraktersnurra <gustaf.rydholm@gmail.com>2020-11-12 23:42:03 +0100
committeraktersnurra <gustaf.rydholm@gmail.com>2020-11-12 23:42:03 +0100
commit8fdb6435e15703fa5b76df19728d905650ee1aef (patch)
treebe3bec9e5cab4ef7f9d94528d102e57ce9b16c3f /src/notebooks/00-testing-stuff-out.ipynb
parentdc28cbe2b4ed77be92ee8b2b69a20689c3bf02a4 (diff)
parent6cb08a110620ee09fe9d8a5d008197a801d025df (diff)
Working cnn transformer.
Diffstat (limited to 'src/notebooks/00-testing-stuff-out.ipynb')
-rw-r--r--src/notebooks/00-testing-stuff-out.ipynb919
1 files changed, 632 insertions, 287 deletions
diff --git a/src/notebooks/00-testing-stuff-out.ipynb b/src/notebooks/00-testing-stuff-out.ipynb
index 62e549c..3686dcd 100644
--- a/src/notebooks/00-testing-stuff-out.ipynb
+++ b/src/notebooks/00-testing-stuff-out.ipynb
@@ -2,18 +2,9 @@
"cells": [
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": 6,
"metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "The autoreload extension is already loaded. To reload it, use:\n",
- " %reload_ext autoreload\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
@@ -50,7 +41,56 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from pathlib import Path"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "True"
+ ]
+ },
+ "execution_count": 3,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "Path(\"/home/akternurra/Documents/projects/quest-for-general-artifical-intelligence/projects/text-recognizer/src/training/experiments/TransformerModel_EmnistLinesDataset_CNNTransformer/1112_081300/model/best.pt\").exists()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "False"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "Path(\"/home/akternurra/Documents/projects/quest-for-general-artifical-intelligence/projects/text-recognizer/src/training/experiments/TransformerModel_EmnistLinesDataset_CNNTransformer/1112_201649/model/best.pt\").exists()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
@@ -63,13 +103,13 @@
" width_factor=1,\n",
" dropout_rate= 0.2,\n",
" activation= \"SELU\",\n",
- " use_decoder= True,\n",
+ " use_decoder= False,\n",
")"
]
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
@@ -78,7 +118,7 @@
},
{
"cell_type": "code",
- "execution_count": 39,
+ "execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
@@ -89,7 +129,7 @@
},
{
"cell_type": "code",
- "execution_count": 40,
+ "execution_count": 21,
"metadata": {},
"outputs": [
{
@@ -97,78 +137,10 @@
"text/plain": [
"Sequential(\n",
" (0): SELU(inplace=True)\n",
- " (1): Sequential(\n",
- " (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
- " (1): Sequential(\n",
- " (0): WideBlock(\n",
- " (activation): SELU(inplace=True)\n",
- " (blocks): Sequential(\n",
- " (0): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
- " (1): SELU(inplace=True)\n",
- " (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
- " (3): Dropout(p=0.2, inplace=False)\n",
- " (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
- " (5): SELU(inplace=True)\n",
- " (6): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
- " )\n",
- " )\n",
- " )\n",
- " (2): Sequential(\n",
- " (0): WideBlock(\n",
- " (activation): SELU(inplace=True)\n",
- " (blocks): Sequential(\n",
- " (0): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
- " (1): SELU(inplace=True)\n",
- " (2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
- " (3): Dropout(p=0.2, inplace=False)\n",
- " (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
- " (5): SELU(inplace=True)\n",
- " (6): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
- " )\n",
- " (shortcut): Sequential(\n",
- " (0): Conv2d(32, 64, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
- " )\n",
- " )\n",
- " )\n",
- " (3): Sequential(\n",
- " (0): WideBlock(\n",
- " (activation): SELU(inplace=True)\n",
- " (blocks): Sequential(\n",
- " (0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
- " (1): SELU(inplace=True)\n",
- " (2): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
- " (3): Dropout(p=0.2, inplace=False)\n",
- " (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
- " (5): SELU(inplace=True)\n",
- " (6): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
- " )\n",
- " (shortcut): Sequential(\n",
- " (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
- " )\n",
- " )\n",
- " )\n",
- " (4): Sequential(\n",
- " (0): WideBlock(\n",
- " (activation): SELU(inplace=True)\n",
- " (blocks): Sequential(\n",
- " (0): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
- " (1): SELU(inplace=True)\n",
- " (2): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
- " (3): Dropout(p=0.2, inplace=False)\n",
- " (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
- " (5): SELU(inplace=True)\n",
- " (6): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
- " )\n",
- " (shortcut): Sequential(\n",
- " (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
- " )\n",
- " )\n",
- " )\n",
- " )\n",
")"
]
},
- "execution_count": 40,
+ "execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
@@ -179,94 +151,302 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": 86,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "==========================================================================================\n",
- "Layer (type:depth-idx) Output Shape Param #\n",
- "==========================================================================================\n",
- "├─Sequential: 1-1 [-1, 256, 4, 119] --\n",
- "| └─Conv2d: 2-1 [-1, 32, 28, 952] 288\n",
- "| └─Sequential: 2-2 [-1, 32, 28, 952] --\n",
- "| | └─WideBlock: 3-1 [-1, 32, 28, 952] 18,560\n",
- "| └─Sequential: 2-3 [-1, 64, 14, 476] --\n",
- "| | └─WideBlock: 3-2 [-1, 64, 14, 476] 57,536\n",
- "| └─Sequential: 2-4 [-1, 128, 7, 238] --\n",
- "| | └─WideBlock: 3-3 [-1, 128, 7, 238] 229,760\n",
- "| └─Sequential: 2-5 [-1, 256, 4, 119] --\n",
- "| | └─WideBlock: 3-4 [-1, 256, 4, 119] 918,272\n",
- "├─Sequential: 1-2 [-1, 80] --\n",
- "| └─BatchNorm2d: 2-6 [-1, 256, 4, 119] 512\n",
- "├─SELU: 1-3 [-1, 256, 4, 119] --\n",
- "├─Sequential: 1 [] --\n",
- "| └─SELU: 2-7 [-1, 256, 4, 119] --\n",
- "| └─Reduce: 2-8 [-1, 256] --\n",
- "| └─Linear: 2-9 [-1, 80] 20,560\n",
- "==========================================================================================\n",
- "Total params: 1,245,488\n",
- "Trainable params: 1,245,488\n",
+ "===============================================================================================\n",
+ "Layer (type:depth-idx) Output Shape Param #\n",
+ "===============================================================================================\n",
+ "├─Sequential: 1-1 [-1, 256, 4, 119] --\n",
+ "| └─Conv2d: 2-1 [-1, 32, 28, 952] 288\n",
+ "| └─Sequential: 2-2 [-1, 32, 28, 952] --\n",
+ "| | └─WideBlock: 3-1 [-1, 32, 28, 952] --\n",
+ "| | | └─Sequential: 4-1 [-1, 32, 28, 952] --\n",
+ "| | | | └─BatchNorm2d: 5-1 [-1, 32, 28, 952] 64\n",
+ "| | | └─SELU: 4-2 [-1, 32, 28, 952] --\n",
+ "| | | └─Sequential: 4 [] --\n",
+ "| | | | └─SELU: 5-2 [-1, 32, 28, 952] --\n",
+ "| | | | └─Conv2d: 5-3 [-1, 32, 28, 952] 9,216\n",
+ "| | | | └─Dropout: 5-4 [-1, 32, 28, 952] --\n",
+ "| | | | └─BatchNorm2d: 5-5 [-1, 32, 28, 952] 64\n",
+ "| | | └─SELU: 4-3 [-1, 32, 28, 952] --\n",
+ "| | | └─Sequential: 4 [] --\n",
+ "| | | | └─SELU: 5-6 [-1, 32, 28, 952] --\n",
+ "| | | | └─Conv2d: 5-7 [-1, 32, 28, 952] 9,216\n",
+ "| └─Sequential: 2-3 [-1, 64, 14, 476] --\n",
+ "| | └─WideBlock: 3-2 [-1, 64, 14, 476] --\n",
+ "| | | └─Sequential: 4-4 [-1, 64, 14, 476] --\n",
+ "| | | | └─Conv2d: 5-8 [-1, 64, 14, 476] 2,048\n",
+ "| | | └─Sequential: 4-5 [-1, 64, 14, 476] --\n",
+ "| | | | └─BatchNorm2d: 5-9 [-1, 32, 28, 952] 64\n",
+ "| | | └─SELU: 4-6 [-1, 32, 28, 952] --\n",
+ "| | | └─Sequential: 4 [] --\n",
+ "| | | | └─SELU: 5-10 [-1, 32, 28, 952] --\n",
+ "| | | | └─Conv2d: 5-11 [-1, 64, 28, 952] 18,432\n",
+ "| | | | └─Dropout: 5-12 [-1, 64, 28, 952] --\n",
+ "| | | | └─BatchNorm2d: 5-13 [-1, 64, 28, 952] 128\n",
+ "| | | └─SELU: 4-7 [-1, 64, 28, 952] --\n",
+ "| | | └─Sequential: 4 [] --\n",
+ "| | | | └─SELU: 5-14 [-1, 64, 28, 952] --\n",
+ "| | | | └─Conv2d: 5-15 [-1, 64, 14, 476] 36,864\n",
+ "| └─Sequential: 2-4 [-1, 128, 7, 238] --\n",
+ "| | └─WideBlock: 3-3 [-1, 128, 7, 238] --\n",
+ "| | | └─Sequential: 4-8 [-1, 128, 7, 238] --\n",
+ "| | | | └─Conv2d: 5-16 [-1, 128, 7, 238] 8,192\n",
+ "| | | └─Sequential: 4-9 [-1, 128, 7, 238] --\n",
+ "| | | | └─BatchNorm2d: 5-17 [-1, 64, 14, 476] 128\n",
+ "| | | └─SELU: 4-10 [-1, 64, 14, 476] --\n",
+ "| | | └─Sequential: 4 [] --\n",
+ "| | | | └─SELU: 5-18 [-1, 64, 14, 476] --\n",
+ "| | | | └─Conv2d: 5-19 [-1, 128, 14, 476] 73,728\n",
+ "| | | | └─Dropout: 5-20 [-1, 128, 14, 476] --\n",
+ "| | | | └─BatchNorm2d: 5-21 [-1, 128, 14, 476] 256\n",
+ "| | | └─SELU: 4-11 [-1, 128, 14, 476] --\n",
+ "| | | └─Sequential: 4 [] --\n",
+ "| | | | └─SELU: 5-22 [-1, 128, 14, 476] --\n",
+ "| | | | └─Conv2d: 5-23 [-1, 128, 7, 238] 147,456\n",
+ "| └─Sequential: 2-5 [-1, 256, 4, 119] --\n",
+ "| | └─WideBlock: 3-4 [-1, 256, 4, 119] --\n",
+ "| | | └─Sequential: 4-12 [-1, 256, 4, 119] --\n",
+ "| | | | └─Conv2d: 5-24 [-1, 256, 4, 119] 32,768\n",
+ "| | | └─Sequential: 4-13 [-1, 256, 4, 119] --\n",
+ "| | | | └─BatchNorm2d: 5-25 [-1, 128, 7, 238] 256\n",
+ "| | | └─SELU: 4-14 [-1, 128, 7, 238] --\n",
+ "| | | └─Sequential: 4 [] --\n",
+ "| | | | └─SELU: 5-26 [-1, 128, 7, 238] --\n",
+ "| | | | └─Conv2d: 5-27 [-1, 256, 7, 238] 294,912\n",
+ "| | | | └─Dropout: 5-28 [-1, 256, 7, 238] --\n",
+ "| | | | └─BatchNorm2d: 5-29 [-1, 256, 7, 238] 512\n",
+ "| | | └─SELU: 4-15 [-1, 256, 7, 238] --\n",
+ "| | | └─Sequential: 4 [] --\n",
+ "| | | | └─SELU: 5-30 [-1, 256, 7, 238] --\n",
+ "| | | | └─Conv2d: 5-31 [-1, 256, 4, 119] 589,824\n",
+ "===============================================================================================\n",
+ "Total params: 1,224,416\n",
+ "Trainable params: 1,224,416\n",
"Non-trainable params: 0\n",
- "Total mult-adds (M): 12.61\n",
- "==========================================================================================\n",
+ "Total mult-adds (G): 2.79\n",
+ "===============================================================================================\n",
"Input size (MB): 0.10\n",
- "Forward/backward pass size (MB): 7.44\n",
- "Params size (MB): 4.75\n",
- "Estimated Total Size (MB): 12.29\n",
- "==========================================================================================\n"
+ "Forward/backward pass size (MB): 101.10\n",
+ "Params size (MB): 4.67\n",
+ "Estimated Total Size (MB): 105.88\n",
+ "===============================================================================================\n"
]
},
{
"data": {
"text/plain": [
- "==========================================================================================\n",
- "Layer (type:depth-idx) Output Shape Param #\n",
- "==========================================================================================\n",
- "├─Sequential: 1-1 [-1, 256, 4, 119] --\n",
- "| └─Conv2d: 2-1 [-1, 32, 28, 952] 288\n",
- "| └─Sequential: 2-2 [-1, 32, 28, 952] --\n",
- "| | └─WideBlock: 3-1 [-1, 32, 28, 952] 18,560\n",
- "| └─Sequential: 2-3 [-1, 64, 14, 476] --\n",
- "| | └─WideBlock: 3-2 [-1, 64, 14, 476] 57,536\n",
- "| └─Sequential: 2-4 [-1, 128, 7, 238] --\n",
- "| | └─WideBlock: 3-3 [-1, 128, 7, 238] 229,760\n",
- "| └─Sequential: 2-5 [-1, 256, 4, 119] --\n",
- "| | └─WideBlock: 3-4 [-1, 256, 4, 119] 918,272\n",
- "├─Sequential: 1-2 [-1, 80] --\n",
- "| └─BatchNorm2d: 2-6 [-1, 256, 4, 119] 512\n",
- "├─SELU: 1-3 [-1, 256, 4, 119] --\n",
- "├─Sequential: 1 [] --\n",
- "| └─SELU: 2-7 [-1, 256, 4, 119] --\n",
- "| └─Reduce: 2-8 [-1, 256] --\n",
- "| └─Linear: 2-9 [-1, 80] 20,560\n",
- "==========================================================================================\n",
- "Total params: 1,245,488\n",
- "Trainable params: 1,245,488\n",
+ "===============================================================================================\n",
+ "Layer (type:depth-idx) Output Shape Param #\n",
+ "===============================================================================================\n",
+ "├─Sequential: 1-1 [-1, 256, 4, 119] --\n",
+ "| └─Conv2d: 2-1 [-1, 32, 28, 952] 288\n",
+ "| └─Sequential: 2-2 [-1, 32, 28, 952] --\n",
+ "| | └─WideBlock: 3-1 [-1, 32, 28, 952] --\n",
+ "| | | └─Sequential: 4-1 [-1, 32, 28, 952] --\n",
+ "| | | | └─BatchNorm2d: 5-1 [-1, 32, 28, 952] 64\n",
+ "| | | └─SELU: 4-2 [-1, 32, 28, 952] --\n",
+ "| | | └─Sequential: 4 [] --\n",
+ "| | | | └─SELU: 5-2 [-1, 32, 28, 952] --\n",
+ "| | | | └─Conv2d: 5-3 [-1, 32, 28, 952] 9,216\n",
+ "| | | | └─Dropout: 5-4 [-1, 32, 28, 952] --\n",
+ "| | | | └─BatchNorm2d: 5-5 [-1, 32, 28, 952] 64\n",
+ "| | | └─SELU: 4-3 [-1, 32, 28, 952] --\n",
+ "| | | └─Sequential: 4 [] --\n",
+ "| | | | └─SELU: 5-6 [-1, 32, 28, 952] --\n",
+ "| | | | └─Conv2d: 5-7 [-1, 32, 28, 952] 9,216\n",
+ "| └─Sequential: 2-3 [-1, 64, 14, 476] --\n",
+ "| | └─WideBlock: 3-2 [-1, 64, 14, 476] --\n",
+ "| | | └─Sequential: 4-4 [-1, 64, 14, 476] --\n",
+ "| | | | └─Conv2d: 5-8 [-1, 64, 14, 476] 2,048\n",
+ "| | | └─Sequential: 4-5 [-1, 64, 14, 476] --\n",
+ "| | | | └─BatchNorm2d: 5-9 [-1, 32, 28, 952] 64\n",
+ "| | | └─SELU: 4-6 [-1, 32, 28, 952] --\n",
+ "| | | └─Sequential: 4 [] --\n",
+ "| | | | └─SELU: 5-10 [-1, 32, 28, 952] --\n",
+ "| | | | └─Conv2d: 5-11 [-1, 64, 28, 952] 18,432\n",
+ "| | | | └─Dropout: 5-12 [-1, 64, 28, 952] --\n",
+ "| | | | └─BatchNorm2d: 5-13 [-1, 64, 28, 952] 128\n",
+ "| | | └─SELU: 4-7 [-1, 64, 28, 952] --\n",
+ "| | | └─Sequential: 4 [] --\n",
+ "| | | | └─SELU: 5-14 [-1, 64, 28, 952] --\n",
+ "| | | | └─Conv2d: 5-15 [-1, 64, 14, 476] 36,864\n",
+ "| └─Sequential: 2-4 [-1, 128, 7, 238] --\n",
+ "| | └─WideBlock: 3-3 [-1, 128, 7, 238] --\n",
+ "| | | └─Sequential: 4-8 [-1, 128, 7, 238] --\n",
+ "| | | | └─Conv2d: 5-16 [-1, 128, 7, 238] 8,192\n",
+ "| | | └─Sequential: 4-9 [-1, 128, 7, 238] --\n",
+ "| | | | └─BatchNorm2d: 5-17 [-1, 64, 14, 476] 128\n",
+ "| | | └─SELU: 4-10 [-1, 64, 14, 476] --\n",
+ "| | | └─Sequential: 4 [] --\n",
+ "| | | | └─SELU: 5-18 [-1, 64, 14, 476] --\n",
+ "| | | | └─Conv2d: 5-19 [-1, 128, 14, 476] 73,728\n",
+ "| | | | └─Dropout: 5-20 [-1, 128, 14, 476] --\n",
+ "| | | | └─BatchNorm2d: 5-21 [-1, 128, 14, 476] 256\n",
+ "| | | └─SELU: 4-11 [-1, 128, 14, 476] --\n",
+ "| | | └─Sequential: 4 [] --\n",
+ "| | | | └─SELU: 5-22 [-1, 128, 14, 476] --\n",
+ "| | | | └─Conv2d: 5-23 [-1, 128, 7, 238] 147,456\n",
+ "| └─Sequential: 2-5 [-1, 256, 4, 119] --\n",
+ "| | └─WideBlock: 3-4 [-1, 256, 4, 119] --\n",
+ "| | | └─Sequential: 4-12 [-1, 256, 4, 119] --\n",
+ "| | | | └─Conv2d: 5-24 [-1, 256, 4, 119] 32,768\n",
+ "| | | └─Sequential: 4-13 [-1, 256, 4, 119] --\n",
+ "| | | | └─BatchNorm2d: 5-25 [-1, 128, 7, 238] 256\n",
+ "| | | └─SELU: 4-14 [-1, 128, 7, 238] --\n",
+ "| | | └─Sequential: 4 [] --\n",
+ "| | | | └─SELU: 5-26 [-1, 128, 7, 238] --\n",
+ "| | | | └─Conv2d: 5-27 [-1, 256, 7, 238] 294,912\n",
+ "| | | | └─Dropout: 5-28 [-1, 256, 7, 238] --\n",
+ "| | | | └─BatchNorm2d: 5-29 [-1, 256, 7, 238] 512\n",
+ "| | | └─SELU: 4-15 [-1, 256, 7, 238] --\n",
+ "| | | └─Sequential: 4 [] --\n",
+ "| | | | └─SELU: 5-30 [-1, 256, 7, 238] --\n",
+ "| | | | └─Conv2d: 5-31 [-1, 256, 4, 119] 589,824\n",
+ "===============================================================================================\n",
+ "Total params: 1,224,416\n",
+ "Trainable params: 1,224,416\n",
"Non-trainable params: 0\n",
- "Total mult-adds (M): 12.61\n",
- "==========================================================================================\n",
+ "Total mult-adds (G): 2.79\n",
+ "===============================================================================================\n",
"Input size (MB): 0.10\n",
- "Forward/backward pass size (MB): 7.44\n",
- "Params size (MB): 4.75\n",
- "Estimated Total Size (MB): 12.29\n",
- "=========================================================================================="
+ "Forward/backward pass size (MB): 101.10\n",
+ "Params size (MB): 4.67\n",
+ "Estimated Total Size (MB): 105.88\n",
+ "==============================================================================================="
+ ]
+ },
+ "execution_count": 86,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "summary(wr, (1, 28, 952), device=\"cpu\", depth=7)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "a = torch.rand(1, 1, 28, 952)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 38,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "b = wr(a)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 39,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from einops import rearrange"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 40,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "b = rearrange(b, \"b c h w -> b w c h\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 41,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "c = nn.AdaptiveAvgPool2d((None, 1))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 42,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "d = c(b)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 43,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([1, 119, 256, 1])"
]
},
- "execution_count": 8,
+ "execution_count": 43,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
- "summary(wr, (1, 28, 952), device=\"cpu\", depth=3)"
+ "d.shape"
]
},
{
"cell_type": "code",
- "execution_count": 64,
+ "execution_count": 44,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([1, 119, 256])"
+ ]
+ },
+ "execution_count": 44,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "d.squeeze(3).shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([1, 256, 4, 119])"
+ ]
+ },
+ "execution_count": 26,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "b.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
@@ -533,7 +713,7 @@
},
{
"cell_type": "code",
- "execution_count": 74,
+ "execution_count": 45,
"metadata": {},
"outputs": [],
"source": [
@@ -542,16 +722,36 @@
},
{
"cell_type": "code",
- "execution_count": 113,
+ "execution_count": 57,
"metadata": {},
"outputs": [],
"source": [
- "dnet = DenseNet(12, (6, 8, 10, 6), 1, 24, 80, 4, 0, False)"
+ "dnet = DenseNet(12, (6, 12, 10), 1, 24, 80, 4, 0, True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 58,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "27.0"
+ ]
+ },
+ "execution_count": 58,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "216 / 8"
]
},
{
"cell_type": "code",
- "execution_count": 114,
+ "execution_count": 59,
"metadata": {},
"outputs": [
{
@@ -561,31 +761,31 @@
"==========================================================================================\n",
"Layer (type:depth-idx) Output Shape Param #\n",
"==========================================================================================\n",
- "├─Sequential: 1-1 [-1, 168, 3, 119] --\n",
+ "├─Sequential: 1-1 [-1, 80] --\n",
"| └─Conv2d: 2-1 [-1, 24, 28, 952] 216\n",
"| └─BatchNorm2d: 2-2 [-1, 24, 28, 952] 48\n",
"| └─ReLU: 2-3 [-1, 24, 28, 952] --\n",
"| └─_DenseBlock: 2-4 [-1, 96, 28, 952] --\n",
"| └─_Transition: 2-5 [-1, 48, 14, 476] --\n",
"| | └─Sequential: 3-1 [-1, 48, 14, 476] 4,800\n",
- "| └─_DenseBlock: 2-6 [-1, 144, 14, 476] --\n",
- "| └─_Transition: 2-7 [-1, 72, 7, 238] --\n",
- "| | └─Sequential: 3-2 [-1, 72, 7, 238] 10,656\n",
- "| └─_DenseBlock: 2-8 [-1, 192, 7, 238] --\n",
- "| └─_Transition: 2-9 [-1, 96, 3, 119] --\n",
- "| | └─Sequential: 3-3 [-1, 96, 3, 119] 18,816\n",
- "| └─_DenseBlock: 2-10 [-1, 168, 3, 119] --\n",
- "| └─ReLU: 2-11 [-1, 168, 3, 119] --\n",
+ "| └─_DenseBlock: 2-6 [-1, 192, 14, 476] --\n",
+ "| └─_Transition: 2-7 [-1, 96, 7, 238] --\n",
+ "| | └─Sequential: 3-2 [-1, 96, 7, 238] 18,816\n",
+ "| └─_DenseBlock: 2-8 [-1, 216, 7, 238] --\n",
+ "| └─ReLU: 2-9 [-1, 216, 7, 238] --\n",
+ "| └─AdaptiveAvgPool2d: 2-10 [-1, 216, 1, 1] --\n",
+ "| └─Rearrange: 2-11 [-1, 216] --\n",
+ "| └─Linear: 2-12 [-1, 80] 17,360\n",
"==========================================================================================\n",
- "Total params: 34,536\n",
- "Trainable params: 34,536\n",
+ "Total params: 41,240\n",
+ "Trainable params: 41,240\n",
"Non-trainable params: 0\n",
- "Total mult-adds (M): 229.41\n",
+ "Total mult-adds (M): 252.43\n",
"==========================================================================================\n",
"Input size (MB): 0.10\n",
"Forward/backward pass size (MB): 53.69\n",
- "Params size (MB): 0.13\n",
- "Estimated Total Size (MB): 53.92\n",
+ "Params size (MB): 0.16\n",
+ "Estimated Total Size (MB): 53.95\n",
"==========================================================================================\n"
]
},
@@ -595,35 +795,35 @@
"==========================================================================================\n",
"Layer (type:depth-idx) Output Shape Param #\n",
"==========================================================================================\n",
- "├─Sequential: 1-1 [-1, 168, 3, 119] --\n",
+ "├─Sequential: 1-1 [-1, 80] --\n",
"| └─Conv2d: 2-1 [-1, 24, 28, 952] 216\n",
"| └─BatchNorm2d: 2-2 [-1, 24, 28, 952] 48\n",
"| └─ReLU: 2-3 [-1, 24, 28, 952] --\n",
"| └─_DenseBlock: 2-4 [-1, 96, 28, 952] --\n",
"| └─_Transition: 2-5 [-1, 48, 14, 476] --\n",
"| | └─Sequential: 3-1 [-1, 48, 14, 476] 4,800\n",
- "| └─_DenseBlock: 2-6 [-1, 144, 14, 476] --\n",
- "| └─_Transition: 2-7 [-1, 72, 7, 238] --\n",
- "| | └─Sequential: 3-2 [-1, 72, 7, 238] 10,656\n",
- "| └─_DenseBlock: 2-8 [-1, 192, 7, 238] --\n",
- "| └─_Transition: 2-9 [-1, 96, 3, 119] --\n",
- "| | └─Sequential: 3-3 [-1, 96, 3, 119] 18,816\n",
- "| └─_DenseBlock: 2-10 [-1, 168, 3, 119] --\n",
- "| └─ReLU: 2-11 [-1, 168, 3, 119] --\n",
+ "| └─_DenseBlock: 2-6 [-1, 192, 14, 476] --\n",
+ "| └─_Transition: 2-7 [-1, 96, 7, 238] --\n",
+ "| | └─Sequential: 3-2 [-1, 96, 7, 238] 18,816\n",
+ "| └─_DenseBlock: 2-8 [-1, 216, 7, 238] --\n",
+ "| └─ReLU: 2-9 [-1, 216, 7, 238] --\n",
+ "| └─AdaptiveAvgPool2d: 2-10 [-1, 216, 1, 1] --\n",
+ "| └─Rearrange: 2-11 [-1, 216] --\n",
+ "| └─Linear: 2-12 [-1, 80] 17,360\n",
"==========================================================================================\n",
- "Total params: 34,536\n",
- "Trainable params: 34,536\n",
+ "Total params: 41,240\n",
+ "Trainable params: 41,240\n",
"Non-trainable params: 0\n",
- "Total mult-adds (M): 229.41\n",
+ "Total mult-adds (M): 252.43\n",
"==========================================================================================\n",
"Input size (MB): 0.10\n",
"Forward/backward pass size (MB): 53.69\n",
- "Params size (MB): 0.13\n",
- "Estimated Total Size (MB): 53.92\n",
+ "Params size (MB): 0.16\n",
+ "Estimated Total Size (MB): 53.95\n",
"=========================================================================================="
]
},
- "execution_count": 114,
+ "execution_count": 59,
"metadata": {},
"output_type": "execute_result"
}
@@ -634,6 +834,37 @@
},
{
"cell_type": "code",
+ "execution_count": 84,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ " backbone = nn.Sequential(\n",
+ " *list(dnet.children())[:][:-4]\n",
+ " )"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 85,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Sequential()"
+ ]
+ },
+ "execution_count": 85,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "backbone"
+ ]
+ },
+ {
+ "cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
@@ -821,166 +1052,280 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": 59,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pred = torch.Tensor([1,1,1,1,1, 81, 1, 79, 79, 79, 2,1,1,1,1, 81, 1, 79, 79, 79, 1,1,1,1,1, 81, 79, 79, 79, 79]).long()\n",
+ "target = torch.Tensor([1,1,1,1,1, 81, 79, 79, 79, 79, 1,1,1,1,1, 81, 79, 79, 79, 79, 1,1,1,1,1, 81, 79, 79, 79, 79]).long()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 32,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from text_recognizer.models.metrics import accuracy"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 61,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pad_indcies = torch.nonzero(target == 79, as_tuple=False)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 76,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "t1 = torch.nonzero(target == 81, as_tuple=False).squeeze(1)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 66,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
- "15.0"
+ "30"
]
},
- "execution_count": 8,
+ "execution_count": 66,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
- "120 / 8"
+ "target.shape[0]"
]
},
{
"cell_type": "code",
- "execution_count": 27,
+ "execution_count": 84,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "t2 = torch.arange(10, target.shape[0] + 1, 10)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 85,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
- "120"
+ "tensor([10, 20, 30])"
]
},
- "execution_count": 27,
+ "execution_count": 85,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
- "2 * 60"
+ "t2"
]
},
{
"cell_type": "code",
- "execution_count": 12,
+ "execution_count": 89,
"metadata": {},
"outputs": [],
"source": [
- "import yaml"
+ "for start, stop in zip(t1, t2):\n",
+ " pred[start+1:stop] = 79"
]
},
{
"cell_type": "code",
- "execution_count": 22,
+ "execution_count": 90,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([ 1, 1, 1, 1, 1, 81, 79, 79, 79, 79, 2, 1, 1, 1, 1, 81, 79, 79,\n",
+ " 79, 79, 1, 1, 1, 1, 1, 81, 79, 79, 79, 79])"
+ ]
+ },
+ "execution_count": 90,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
"source": [
- "path = \"../training/experiments/cnn_transformer.yml\""
+ "pred"
]
},
{
"cell_type": "code",
- "execution_count": 26,
+ "execution_count": 88,
+ "metadata": {},
+ "outputs": [
+ {
+ "ename": "SyntaxError",
+ "evalue": "invalid syntax (<ipython-input-88-b8a4aef86401>, line 1)",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;36m File \u001b[0;32m\"<ipython-input-88-b8a4aef86401>\"\u001b[0;36m, line \u001b[0;32m1\u001b[0m\n\u001b[0;31m [pred[start+1:stop] = 79 for start, stop in zip(t1, t2)]\u001b[0m\n\u001b[0m ^\u001b[0m\n\u001b[0;31mSyntaxError\u001b[0m\u001b[0;31m:\u001b[0m invalid syntax\n"
+ ]
+ }
+ ],
+ "source": [
+ "[pred[start+1:stop] = 79 for start, stop in zip(t1, t2)]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 69,
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[ 6],\n",
+ " [ 7],\n",
+ " [ 8],\n",
+ " [ 9],\n",
+ " [16],\n",
+ " [17],\n",
+ " [18],\n",
+ " [19],\n",
+ " [26],\n",
+ " [27],\n",
+ " [28],\n",
+ " [29]])"
+ ]
+ },
+ "execution_count": 69,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "pad_indcies"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 71,
+ "metadata": {},
+ "outputs": [
+ {
+ "ename": "TypeError",
+ "evalue": "only integer tensors of a single element can be converted to an index",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
+ "\u001b[0;32m<ipython-input-71-39b5cc3b1445>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mpred\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mpad_indcies\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mpad_indcies\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m79\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
+ "\u001b[0;31mTypeError\u001b[0m: only integer tensors of a single element can be converted to an index"
+ ]
+ }
+ ],
+ "source": [
+ "pred[pad_indcies:pad_indcies] = 79"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 50,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([20])"
+ ]
+ },
+ "execution_count": 50,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "pred.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 51,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([20])"
+ ]
+ },
+ "execution_count": 51,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "target.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 91,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "0.0"
+ ]
+ },
+ "execution_count": 91,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "accuracy(pred, target)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 92,
"metadata": {},
"outputs": [],
"source": [
- "with open(path, \"r\") as f:\n",
- " f = yaml.safe_load(f)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 27,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "{'experiment_group': 'Transformer Experiments',\n",
- " 'experiments': [{'train_args': {'transformer_model': True,\n",
- " 'batch_size': 16,\n",
- " 'max_epochs': 128,\n",
- " 'input_shape': [[1, 28, 952], [92]]},\n",
- " 'dataset': {'type': 'EmnistLinesDataset',\n",
- " 'args': {'subsample_fraction': None,\n",
- " 'transform': [{'type': 'ToPILImage', 'args': None},\n",
- " {'type': 'Resize', 'args': {'size': [28, 952]}},\n",
- " {'type': 'ToTensor', 'args': None}],\n",
- " 'max_length': 97,\n",
- " 'min_overlap': 0.0,\n",
- " 'max_overlap': 0.33,\n",
- " 'num_samples': 1,\n",
- " 'seed': 4711,\n",
- " 'init_token': '<sos>',\n",
- " 'pad_token': '_',\n",
- " 'eos_token': '<eos>',\n",
- " 'target_transform': [{'type': 'AddTokens',\n",
- " 'args': {'init_token': '<sos>',\n",
- " 'eos_token': '<eos>',\n",
- " 'pad_token': '_'}}]},\n",
- " 'train_args': {'num_workers': 8,\n",
- " 'train_fraction': 0.85,\n",
- " 'batch_size': 16}},\n",
- " 'model': 'VisionTransformerModel',\n",
- " 'metrics': ['accuracy'],\n",
- " 'network': {'type': 'CNNTransformer',\n",
- " 'args': {'backbone': 'DenseNet',\n",
- " 'backbone_args': {'growth_rate': 8,\n",
- " 'block_config': [4, 6, 8, 6],\n",
- " 'in_channels': 1,\n",
- " 'base_channels': 24,\n",
- " 'num_classes': 256,\n",
- " 'bn_size': 4,\n",
- " 'dropout_rate': 0.1,\n",
- " 'classifier': False,\n",
- " 'activation': 'elu'},\n",
- " 'num_encoder_layers': 3,\n",
- " 'num_decoder_layers': 3,\n",
- " 'hidden_dim': 256,\n",
- " 'vocab_size': 82,\n",
- " 'num_heads': 8,\n",
- " 'max_len': 99,\n",
- " 'expansion_dim': 512,\n",
- " 'mlp_dim': 256,\n",
- " 'spatial_dim': 357,\n",
- " 'dropout_rate': 0.1,\n",
- " 'trg_pad_index': 79,\n",
- " 'activation': 'gelu'}},\n",
- " 'criterion': {'type': 'CrossEntropyLoss', 'args': {'ignore_index': 79}},\n",
- " 'optimizer': {'type': 'AdamW',\n",
- " 'args': {'lr': 0.0003,\n",
- " 'betas': [0.9, 0.999],\n",
- " 'eps': 1e-08,\n",
- " 'weight_decay': 3e-06,\n",
- " 'amsgrad': False}},\n",
- " 'lr_scheduler': {'type': 'OneCycleLR',\n",
- " 'args': {'max_lr': 0.0007,\n",
- " 'epochs': 128,\n",
- " 'anneal_strategy': 'cos',\n",
- " 'pct_start': 0.475,\n",
- " 'cycle_momentum': True,\n",
- " 'base_momentum': 0.85,\n",
- " 'max_momentum': 0.9,\n",
- " 'div_factor': 10,\n",
- " 'final_div_factor': 10000,\n",
- " 'interval': 'step'}},\n",
- " 'callbacks': ['Checkpoint',\n",
- " 'ProgressBar',\n",
- " 'WandbCallback',\n",
- " 'WandbImageLogger'],\n",
- " 'callback_args': {'Checkpoint': {'monitor': 'val_loss', 'mode': 'min'},\n",
- " 'ProgressBar': {'epochs': 128},\n",
- " 'WandbCallback': {'log_batch_frequency': 10},\n",
- " 'WandbImageLogger': {'num_examples': 6}},\n",
- " 'test_metric': 'test_accuracy'}]}"
- ]
- },
- "execution_count": 27,
+ "acc = (pred == target).sum().float() / target.shape[0]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 93,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor(0.9667)"
+ ]
+ },
+ "execution_count": 93,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
- "f"
+ "acc"
]
},
{