summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authoraktersnurra <gustaf.rydholm@gmail.com>2020-08-20 22:18:35 +0200
committeraktersnurra <gustaf.rydholm@gmail.com>2020-08-20 22:18:35 +0200
commit1f459ba19422593de325983040e176f97cf4ffc0 (patch)
tree89fef442d5dbe0c83253e9566d1762f0704f64e2 /src
parent95cbdf5bc1cc9639febda23c28d8f464c998b214 (diff)
A lot of stuff working :D. ResNet implemented!
Diffstat (limited to 'src')
-rw-r--r--src/notebooks/00-testing-stuff-out.ipynb639
-rw-r--r--src/notebooks/01-look-at-emnist.ipynb25
-rw-r--r--src/text_recognizer/datasets/emnist_dataset.py43
-rw-r--r--src/text_recognizer/datasets/emnist_lines_dataset.py9
-rw-r--r--src/text_recognizer/models/base.py45
-rw-r--r--src/text_recognizer/models/character_model.py8
-rw-r--r--src/text_recognizer/networks/__init__.py3
-rw-r--r--src/text_recognizer/networks/lenet.py17
-rw-r--r--src/text_recognizer/networks/misc.py20
-rw-r--r--src/text_recognizer/networks/mlp.py18
-rw-r--r--src/text_recognizer/networks/residual_network.py314
-rw-r--r--src/text_recognizer/weights/CharacterModel_EmnistDataset_LeNet_weights.ptbin14485310 -> 14485362 bytes
-rw-r--r--src/text_recognizer/weights/CharacterModel_EmnistDataset_MLP_weights.ptbin1704174 -> 11625484 bytes
-rw-r--r--src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetwork_weights.ptbin0 -> 28654593 bytes
-rw-r--r--src/training/experiments/sample_experiment.yml37
-rw-r--r--src/training/prepare_experiments.py6
-rw-r--r--src/training/run_experiment.py19
-rw-r--r--src/training/trainer/__init__.py2
-rw-r--r--src/training/trainer/callbacks/__init__.py (renamed from src/training/callbacks/__init__.py)2
-rw-r--r--src/training/trainer/callbacks/base.py (renamed from src/training/callbacks/base.py)50
-rw-r--r--src/training/trainer/callbacks/early_stopping.py (renamed from src/training/callbacks/early_stopping.py)5
-rw-r--r--src/training/trainer/callbacks/lr_schedulers.py (renamed from src/training/callbacks/lr_schedulers.py)12
-rw-r--r--src/training/trainer/callbacks/progress_bar.py61
-rw-r--r--src/training/trainer/callbacks/wandb_callbacks.py (renamed from src/training/callbacks/wandb_callbacks.py)8
-rw-r--r--src/training/trainer/population_based_training/__init__.py (renamed from src/training/population_based_training/__init__.py)0
-rw-r--r--src/training/trainer/population_based_training/population_based_training.py (renamed from src/training/population_based_training/population_based_training.py)0
-rw-r--r--src/training/trainer/train.py (renamed from src/training/train.py)87
-rw-r--r--src/training/trainer/util.py (renamed from src/training/util.py)0
28 files changed, 1212 insertions, 218 deletions
diff --git a/src/notebooks/00-testing-stuff-out.ipynb b/src/notebooks/00-testing-stuff-out.ipynb
index 49ca4c4..3f008c3 100644
--- a/src/notebooks/00-testing-stuff-out.ipynb
+++ b/src/notebooks/00-testing-stuff-out.ipynb
@@ -2,11 +2,120 @@
"cells": [
{
"cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%load_ext autoreload\n",
+ "%autoreload 2\n",
+ "\n",
+ "%matplotlib inline\n",
+ "import matplotlib.pyplot as plt\n",
+ "import numpy as np\n",
+ "from PIL import Image\n",
+ "import torch\n",
+ "from importlib.util import find_spec\n",
+ "if find_spec(\"text_recognizer\") is None:\n",
+ " import sys\n",
+ " sys.path.append('..')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "True"
+ ]
+ },
+ "execution_count": 2,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "torch.cuda.is_available()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.nn.modules.activation.SELU"
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "torch.nn.SELU"
+ ]
+ },
+ {
+ "cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
- "import torch"
+ "a = \"nNone\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "b = a or \"relu\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'nnone'"
+ ]
+ },
+ "execution_count": 15,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "b.lower()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'nNone'"
+ ]
+ },
+ "execution_count": 16,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "b"
]
},
{
@@ -986,28 +1095,16 @@
},
{
"cell_type": "code",
- "execution_count": 20,
+ "execution_count": 51,
"metadata": {},
- "outputs": [
- {
- "ename": "ModuleNotFoundError",
- "evalue": "No module named 'tqdm.auto.tqdm'; 'tqdm.auto' is not a package",
- "output_type": "error",
- "traceback": [
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
- "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
- "\u001b[0;32m<ipython-input-20-68e3c8bf3e1f>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mtqdm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mauto\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtqdm\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mtqdm_auto\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
- "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'tqdm.auto.tqdm'; 'tqdm.auto' is not a package"
- ]
- }
- ],
+ "outputs": [],
"source": [
- "import tqdm.auto.tqdm as tqdm_auto"
+ "import tqdm"
]
},
{
"cell_type": "code",
- "execution_count": 19,
+ "execution_count": 52,
"metadata": {},
"outputs": [
{
@@ -1016,7 +1113,7 @@
"tqdm.notebook.tqdm_notebook"
]
},
- "execution_count": 19,
+ "execution_count": 52,
"metadata": {},
"output_type": "execute_result"
}
@@ -1027,25 +1124,50 @@
},
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "tqdm.auto.tqdm"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 54,
"metadata": {},
"outputs": [],
"source": [
"def test():\n",
- " for i in range(9):\n",
+ " for i in tqdm.auto.tqdm(range(9)):\n",
" pass\n",
- " print(i)"
+ " print(i)\n",
+ " "
]
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": 55,
"metadata": {},
"outputs": [
{
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "e1d3b25d4ee141e882e316ec54e79d60",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
"name": "stdout",
"output_type": "stream",
"text": [
+ "\n",
"8\n"
]
}
@@ -1056,6 +1178,479 @@
},
{
"cell_type": "code",
+ "execution_count": 58,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from time import sleep"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 71,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "41b743273ce14236bcb65782dbcd2e75",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "pbar = tqdm.auto.tqdm([\"a\", \"b\", \"c\", \"d\"], leave=True)\n",
+ "for char in pbar:\n",
+ " pbar.set_description(\"Processing %s\" % char)\n",
+ "# pbar.set_prefix()\n",
+ " sleep(0.25)\n",
+ "pbar.set_postfix({\"hej\": 0.32})"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 83,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pbar.close()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 96,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "cb5ad8d6109f4b1495b8fc7422bafd01",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "with tqdm.auto.tqdm(total=10, bar_format=\"{postfix[0]} {postfix[1][value]:>8.2g}\",\n",
+ " postfix=[\"Batch\", dict(value=0)]) as t:\n",
+ " for i in range(10):\n",
+ " sleep(0.1)\n",
+ "# t.postfix[2][\"value\"] = 3 \n",
+ " t.postfix[1][\"value\"] = i / 2\n",
+ " t.update()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 99,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "0b341d49ad074823881e84a538bcad0c",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(FloatProgress(value=0.0), HTML(value='')))"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "with tqdm.auto.tqdm(total=100, leave=True) as pbar:\n",
+ " for i in range(2):\n",
+ " for i in range(10):\n",
+ " sleep(0.1)\n",
+ " pbar.update(10)\n",
+ " pbar.set_postfix({\"adaf\": 23})\n",
+ " pbar.set_postfix({\"hej\": 0.32})\n",
+ " pbar.reset()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from text_recognizer.networks.residual_network import IdentityBlock, ResidualBlock, BasicBlock, BottleNeckBlock, ResidualLayer, Encoder, ResidualNetwork"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "IdentityBlock(\n",
+ " (blocks): Identity()\n",
+ " (activation_fn): ReLU(inplace=True)\n",
+ " (shortcut): Identity()\n",
+ ")"
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "IdentityBlock(32, 64)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "ResidualBlock(\n",
+ " (blocks): Identity()\n",
+ " (activation_fn): ReLU(inplace=True)\n",
+ " (shortcut): Sequential(\n",
+ " (0): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " )\n",
+ ")"
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "ResidualBlock(32, 64)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "BasicBlock(\n",
+ " (blocks): Sequential(\n",
+ " (0): Sequential(\n",
+ " (0): Conv2dAuto(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " (1): ReLU(inplace=True)\n",
+ " (2): Sequential(\n",
+ " (0): Conv2dAuto(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " )\n",
+ " (activation_fn): ReLU(inplace=True)\n",
+ " (shortcut): Sequential(\n",
+ " (0): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " )\n",
+ ")\n"
+ ]
+ }
+ ],
+ "source": [
+ "dummy = torch.ones((1, 32, 224, 224))\n",
+ "\n",
+ "block = BasicBlock(32, 64)\n",
+ "block(dummy).shape\n",
+ "print(block)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "BottleNeckBlock(\n",
+ " (blocks): Sequential(\n",
+ " (0): Sequential(\n",
+ " (0): Conv2dAuto(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " (1): ReLU(inplace=True)\n",
+ " (2): Sequential(\n",
+ " (0): Conv2dAuto(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " (3): ReLU(inplace=True)\n",
+ " (4): Sequential(\n",
+ " (0): Conv2dAuto(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " )\n",
+ " (activation_fn): ReLU(inplace=True)\n",
+ " (shortcut): Sequential(\n",
+ " (0): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " )\n",
+ ")\n"
+ ]
+ }
+ ],
+ "source": [
+ "dummy = torch.ones((1, 32, 10, 10))\n",
+ "\n",
+ "block = BottleNeckBlock(32, 64)\n",
+ "block(dummy).shape\n",
+ "print(block)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([1, 128, 24, 24])"
+ ]
+ },
+ "execution_count": 16,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "dummy = torch.ones((1, 64, 48, 48))\n",
+ "\n",
+ "layer = ResidualLayer(64, 128, block=BasicBlock, num_blocks=3)\n",
+ "layer(dummy).shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[(64, 128), (128, 256), (256, 512)]"
+ ]
+ },
+ "execution_count": 17,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "blocks_sizes=[64, 128, 256, 512]\n",
+ "list(zip(blocks_sizes, blocks_sizes[1:]))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "e = Encoder(depths=[1, 1])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from torchsummary import summary"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "----------------------------------------------------------------\n",
+ " Layer (type) Output Shape Param #\n",
+ "================================================================\n",
+ " Conv2d-1 [-1, 32, 15, 15] 800\n",
+ " BatchNorm2d-2 [-1, 32, 15, 15] 64\n",
+ " ReLU-3 [-1, 32, 15, 15] 0\n",
+ " MaxPool2d-4 [-1, 32, 8, 8] 0\n",
+ " Conv2dAuto-5 [-1, 32, 8, 8] 9,216\n",
+ " BatchNorm2d-6 [-1, 32, 8, 8] 64\n",
+ " ReLU-7 [-1, 32, 8, 8] 0\n",
+ " ReLU-8 [-1, 32, 8, 8] 0\n",
+ " Conv2dAuto-9 [-1, 32, 8, 8] 9,216\n",
+ " BatchNorm2d-10 [-1, 32, 8, 8] 64\n",
+ " ReLU-11 [-1, 32, 8, 8] 0\n",
+ " ReLU-12 [-1, 32, 8, 8] 0\n",
+ " BasicBlock-13 [-1, 32, 8, 8] 0\n",
+ " ResidualLayer-14 [-1, 32, 8, 8] 0\n",
+ " Conv2d-15 [-1, 64, 4, 4] 2,048\n",
+ " BatchNorm2d-16 [-1, 64, 4, 4] 128\n",
+ " Conv2dAuto-17 [-1, 64, 4, 4] 18,432\n",
+ " BatchNorm2d-18 [-1, 64, 4, 4] 128\n",
+ " ReLU-19 [-1, 64, 4, 4] 0\n",
+ " ReLU-20 [-1, 64, 4, 4] 0\n",
+ " Conv2dAuto-21 [-1, 64, 4, 4] 36,864\n",
+ " BatchNorm2d-22 [-1, 64, 4, 4] 128\n",
+ " ReLU-23 [-1, 64, 4, 4] 0\n",
+ " ReLU-24 [-1, 64, 4, 4] 0\n",
+ " BasicBlock-25 [-1, 64, 4, 4] 0\n",
+ " ResidualLayer-26 [-1, 64, 4, 4] 0\n",
+ "================================================================\n",
+ "Total params: 77,152\n",
+ "Trainable params: 77,152\n",
+ "Non-trainable params: 0\n",
+ "----------------------------------------------------------------\n",
+ "Input size (MB): 0.00\n",
+ "Forward/backward pass size (MB): 0.43\n",
+ "Params size (MB): 0.29\n",
+ "Estimated Total Size (MB): 0.73\n",
+ "----------------------------------------------------------------\n"
+ ]
+ }
+ ],
+ "source": [
+ "summary(e, (1, 28, 28), device=\"cpu\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "resnet = ResidualNetwork(1, 80, activation=\"selu\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "----------------------------------------------------------------\n",
+ " Layer (type) Output Shape Param #\n",
+ "================================================================\n",
+ " Conv2d-1 [-1, 32, 15, 15] 800\n",
+ " BatchNorm2d-2 [-1, 32, 15, 15] 64\n",
+ " SELU-3 [-1, 32, 15, 15] 0\n",
+ " MaxPool2d-4 [-1, 32, 8, 8] 0\n",
+ " Conv2dAuto-5 [-1, 32, 8, 8] 9,216\n",
+ " BatchNorm2d-6 [-1, 32, 8, 8] 64\n",
+ " SELU-7 [-1, 32, 8, 8] 0\n",
+ " SELU-8 [-1, 32, 8, 8] 0\n",
+ " Conv2dAuto-9 [-1, 32, 8, 8] 9,216\n",
+ " BatchNorm2d-10 [-1, 32, 8, 8] 64\n",
+ " SELU-11 [-1, 32, 8, 8] 0\n",
+ " SELU-12 [-1, 32, 8, 8] 0\n",
+ " BasicBlock-13 [-1, 32, 8, 8] 0\n",
+ " Conv2dAuto-14 [-1, 32, 8, 8] 9,216\n",
+ " BatchNorm2d-15 [-1, 32, 8, 8] 64\n",
+ " SELU-16 [-1, 32, 8, 8] 0\n",
+ " SELU-17 [-1, 32, 8, 8] 0\n",
+ " Conv2dAuto-18 [-1, 32, 8, 8] 9,216\n",
+ " BatchNorm2d-19 [-1, 32, 8, 8] 64\n",
+ " SELU-20 [-1, 32, 8, 8] 0\n",
+ " SELU-21 [-1, 32, 8, 8] 0\n",
+ " BasicBlock-22 [-1, 32, 8, 8] 0\n",
+ " ResidualLayer-23 [-1, 32, 8, 8] 0\n",
+ " Conv2d-24 [-1, 64, 4, 4] 2,048\n",
+ " BatchNorm2d-25 [-1, 64, 4, 4] 128\n",
+ " Conv2dAuto-26 [-1, 64, 4, 4] 18,432\n",
+ " BatchNorm2d-27 [-1, 64, 4, 4] 128\n",
+ " SELU-28 [-1, 64, 4, 4] 0\n",
+ " SELU-29 [-1, 64, 4, 4] 0\n",
+ " Conv2dAuto-30 [-1, 64, 4, 4] 36,864\n",
+ " BatchNorm2d-31 [-1, 64, 4, 4] 128\n",
+ " SELU-32 [-1, 64, 4, 4] 0\n",
+ " SELU-33 [-1, 64, 4, 4] 0\n",
+ " BasicBlock-34 [-1, 64, 4, 4] 0\n",
+ " Conv2dAuto-35 [-1, 64, 4, 4] 36,864\n",
+ " BatchNorm2d-36 [-1, 64, 4, 4] 128\n",
+ " SELU-37 [-1, 64, 4, 4] 0\n",
+ " SELU-38 [-1, 64, 4, 4] 0\n",
+ " Conv2dAuto-39 [-1, 64, 4, 4] 36,864\n",
+ " BatchNorm2d-40 [-1, 64, 4, 4] 128\n",
+ " SELU-41 [-1, 64, 4, 4] 0\n",
+ " SELU-42 [-1, 64, 4, 4] 0\n",
+ " BasicBlock-43 [-1, 64, 4, 4] 0\n",
+ " ResidualLayer-44 [-1, 64, 4, 4] 0\n",
+ " Encoder-45 [-1, 64, 4, 4] 0\n",
+ " Reduce-46 [-1, 64] 0\n",
+ " Linear-47 [-1, 80] 5,200\n",
+ " Decoder-48 [-1, 80] 0\n",
+ "================================================================\n",
+ "Total params: 174,896\n",
+ "Trainable params: 174,896\n",
+ "Non-trainable params: 0\n",
+ "----------------------------------------------------------------\n",
+ "Input size (MB): 0.00\n",
+ "Forward/backward pass size (MB): 0.65\n",
+ "Params size (MB): 0.67\n",
+ "Estimated Total Size (MB): 1.32\n",
+ "----------------------------------------------------------------\n"
+ ]
+ }
+ ],
+ "source": [
+ "summary(resnet, (1, 28, 28), device=\"cpu\")"
+ ]
+ },
+ {
+ "cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
diff --git a/src/notebooks/01-look-at-emnist.ipynb b/src/notebooks/01-look-at-emnist.ipynb
index a68b418..8648afb 100644
--- a/src/notebooks/01-look-at-emnist.ipynb
+++ b/src/notebooks/01-look-at-emnist.ipynb
@@ -31,12 +31,31 @@
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
- "dataset = EmnistDataset()\n",
- "dataset.load_emnist_dataset()"
+ "dataset = EmnistDataset()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Tensor"
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "type(dataset.data)"
]
},
{
diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py
index 96f84e5..49ebad3 100644
--- a/src/text_recognizer/datasets/emnist_dataset.py
+++ b/src/text_recognizer/datasets/emnist_dataset.py
@@ -8,6 +8,7 @@ from loguru import logger
import numpy as np
from PIL import Image
import torch
+from torch import Tensor
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import EMNIST
from torchvision.transforms import Compose, Normalize, ToTensor
@@ -183,12 +184,8 @@ class EmnistDataset(Dataset):
self.input_shape = self._mapper.input_shape
self.num_classes = self._mapper.num_classes
- # Placeholders
- self.data = None
- self.targets = None
-
# Load dataset.
- self.load_emnist_dataset()
+ self.data, self.targets = self.load_emnist_dataset()
@property
def mapper(self) -> EmnistMapper:
@@ -199,9 +196,7 @@ class EmnistDataset(Dataset):
"""Returns the length of the dataset."""
return len(self.data)
- def __getitem__(
- self, index: Union[int, torch.Tensor]
- ) -> Tuple[torch.Tensor, torch.Tensor]:
+ def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]:
"""Fetches samples from the dataset.
Args:
@@ -239,11 +234,13 @@ class EmnistDataset(Dataset):
f"Mapping: {self.mapper.mapping}\n"
)
- def _sample_to_balance(self) -> None:
+ def _sample_to_balance(
+ self, data: Tensor, targets: Tensor
+ ) -> Tuple[np.ndarray, np.ndarray]:
"""Because the dataset is not balanced, we take at most the mean number of instances per class."""
np.random.seed(self.seed)
- x = self.data
- y = self.targets
+ x = data
+ y = targets
num_to_sample = int(np.bincount(y.flatten()).mean())
all_sampled_indices = []
for label in np.unique(y.flatten()):
@@ -253,20 +250,22 @@ class EmnistDataset(Dataset):
indices = np.concatenate(all_sampled_indices)
x_sampled = x[indices]
y_sampled = y[indices]
- self.data = x_sampled
- self.targets = y_sampled
+ data = x_sampled
+ targets = y_sampled
+ return data, targets
- def _subsample(self) -> None:
+ def _subsample(self, data: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]:
"""Subsamples the dataset to the specified fraction."""
- x = self.data
- y = self.targets
+ x = data
+ y = targets
num_samples = int(x.shape[0] * self.subsample_fraction)
x_sampled = x[:num_samples]
y_sampled = y[:num_samples]
self.data = x_sampled
self.targets = y_sampled
+ return data, targets
- def load_emnist_dataset(self) -> None:
+ def load_emnist_dataset(self) -> Tuple[Tensor, Tensor]:
"""Fetch the EMNIST dataset."""
dataset = EMNIST(
root=DATA_DIRNAME,
@@ -277,11 +276,13 @@ class EmnistDataset(Dataset):
target_transform=None,
)
- self.data = dataset.data
- self.targets = dataset.targets
+ data = dataset.data
+ targets = dataset.targets
if self.sample_to_balance:
- self._sample_to_balance()
+ data, targets = self._sample_to_balance(data, targets)
if self.subsample_fraction is not None:
- self._subsample()
+ data, targets = self._subsample(data, targets)
+
+ return data, targets
diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py
index d64a991..b0617f5 100644
--- a/src/text_recognizer/datasets/emnist_lines_dataset.py
+++ b/src/text_recognizer/datasets/emnist_lines_dataset.py
@@ -8,6 +8,7 @@ import h5py
from loguru import logger
import numpy as np
import torch
+from torch import Tensor
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose, Normalize, ToTensor
@@ -87,16 +88,14 @@ class EmnistLinesDataset(Dataset):
"""Returns the length of the dataset."""
return len(self.data)
- def __getitem__(
- self, index: Union[int, torch.Tensor]
- ) -> Tuple[torch.Tensor, torch.Tensor]:
+ def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]:
"""Fetches data, target pair of the dataset for a given and index or indices.
Args:
- index (Union[int, torch.Tensor]): Either a list or int of indices/index.
+ index (Union[int, Tensor]): Either a list or int of indices/index.
Returns:
- Tuple[torch.Tensor, torch.Tensor]: Data target pair.
+ Tuple[Tensor, Tensor]: Data target pair.
"""
if torch.is_tensor(index):
diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py
index 6d40b49..74fd223 100644
--- a/src/text_recognizer/models/base.py
+++ b/src/text_recognizer/models/base.py
@@ -53,8 +53,8 @@ class Model(ABC):
"""
- # Fetch data loaders and dataset info.
- dataset_name, self._data_loaders, self._mapper = self._load_data_loader(
+ # Configure data loaders and dataset info.
+ dataset_name, self._data_loaders, self._mapper = self._configure_data_loader(
data_loader_args
)
self._input_shape = self._mapper.input_shape
@@ -70,16 +70,19 @@ class Model(ABC):
else:
self._device = device
- # Load network.
- self._network, self._network_args = self._load_network(network_fn, network_args)
+ # Configure network.
+ self._network, self._network_args = self._configure_network(
+ network_fn, network_args
+ )
# To device.
self._network.to(self._device)
- # Set training objects.
- self._criterion = self._load_criterion(criterion, criterion_args)
- self._optimizer = self._load_optimizer(optimizer, optimizer_args)
- self._lr_scheduler = self._load_lr_scheduler(lr_scheduler, lr_scheduler_args)
+ # Configure training objects.
+ self._criterion = self._configure_criterion(criterion, criterion_args)
+ self._optimizer, self._lr_scheduler = self._configure_optimizers(
+ optimizer, optimizer_args, lr_scheduler, lr_scheduler_args
+ )
# Experiment directory.
self.model_dir = None
@@ -87,7 +90,7 @@ class Model(ABC):
# Flag for stopping training.
self.stop_training = False
- def _load_data_loader(
+ def _configure_data_loader(
self, data_loader_args: Optional[Dict]
) -> Tuple[str, Dict, EmnistMapper]:
"""Loads data loader, dataset name, and dataset mapper."""
@@ -102,7 +105,7 @@ class Model(ABC):
data_loaders = None
return dataset_name, data_loaders, mapper
- def _load_network(
+ def _configure_network(
self, network_fn: Type[nn.Module], network_args: Optional[Dict]
) -> Tuple[Type[nn.Module], Dict]:
"""Loads the network."""
@@ -113,7 +116,7 @@ class Model(ABC):
network = network_fn(**network_args)
return network, network_args
- def _load_criterion(
+ def _configure_criterion(
self, criterion: Optional[Callable], criterion_args: Optional[Dict]
) -> Optional[Callable]:
"""Loads the criterion."""
@@ -123,27 +126,27 @@ class Model(ABC):
_criterion = None
return _criterion
- def _load_optimizer(
- self, optimizer: Optional[Callable], optimizer_args: Optional[Dict]
- ) -> Optional[Callable]:
- """Loads the optimizer."""
+ def _configure_optimizers(
+ self,
+ optimizer: Optional[Callable],
+ optimizer_args: Optional[Dict],
+ lr_scheduler: Optional[Callable],
+ lr_scheduler_args: Optional[Dict],
+ ) -> Tuple[Optional[Callable], Optional[Callable]]:
+ """Loads the optimizers."""
if optimizer is not None:
_optimizer = optimizer(self._network.parameters(), **optimizer_args)
else:
_optimizer = None
- return _optimizer
- def _load_lr_scheduler(
- self, lr_scheduler: Optional[Callable], lr_scheduler_args: Optional[Dict]
- ) -> Optional[Callable]:
- """Loads learning rate scheduler."""
if self._optimizer and lr_scheduler is not None:
if "OneCycleLR" in str(lr_scheduler):
lr_scheduler_args["steps_per_epoch"] = len(self._data_loaders["train"])
_lr_scheduler = lr_scheduler(self._optimizer, **lr_scheduler_args)
else:
_lr_scheduler = None
- return _lr_scheduler
+
+ return _optimizer, _lr_scheduler
@property
def __name__(self) -> str:
diff --git a/src/text_recognizer/models/character_model.py b/src/text_recognizer/models/character_model.py
index 0a0ab2d..0fd7afd 100644
--- a/src/text_recognizer/models/character_model.py
+++ b/src/text_recognizer/models/character_model.py
@@ -44,6 +44,7 @@ class CharacterModel(Model):
self.tensor_transform = ToTensor()
self.softmax = nn.Softmax(dim=0)
+ @torch.no_grad()
def predict_on_image(
self, image: Union[np.ndarray, torch.Tensor]
) -> Tuple[str, float]:
@@ -64,10 +65,9 @@ class CharacterModel(Model):
# If the image is an unscaled tensor.
image = image.type("torch.FloatTensor") / 255
- with torch.no_grad():
- # Put the image tensor on the device the model weights are on.
- image = image.to(self.device)
- logits = self.network(image)
+ # Put the image tensor on the device the model weights are on.
+ image = image.to(self.device)
+ logits = self.network(image)
prediction = self.softmax(logits.data.squeeze())
diff --git a/src/text_recognizer/networks/__init__.py b/src/text_recognizer/networks/__init__.py
index e6b6946..a83ca35 100644
--- a/src/text_recognizer/networks/__init__.py
+++ b/src/text_recognizer/networks/__init__.py
@@ -1,5 +1,6 @@
"""Network modules."""
from .lenet import LeNet
from .mlp import MLP
+from .residual_network import ResidualNetwork
-__all__ = ["MLP", "LeNet"]
+__all__ = ["MLP", "LeNet", "ResidualNetwork"]
diff --git a/src/text_recognizer/networks/lenet.py b/src/text_recognizer/networks/lenet.py
index cbc58fc..91d3f2c 100644
--- a/src/text_recognizer/networks/lenet.py
+++ b/src/text_recognizer/networks/lenet.py
@@ -5,6 +5,8 @@ from einops.layers.torch import Rearrange
import torch
from torch import nn
+from text_recognizer.networks.misc import activation_function
+
class LeNet(nn.Module):
"""LeNet network."""
@@ -16,8 +18,7 @@ class LeNet(nn.Module):
hidden_size: Tuple[int, ...] = (9216, 128),
dropout_rate: float = 0.2,
output_size: int = 10,
- activation_fn: Optional[Callable] = None,
- activation_fn_args: Optional[Dict] = None,
+ activation_fn: Optional[str] = "relu",
) -> None:
"""The LeNet network.
@@ -28,18 +29,12 @@ class LeNet(nn.Module):
Defaults to (9216, 128).
dropout_rate (float): The dropout rate. Defaults to 0.2.
output_size (int): Number of classes. Defaults to 10.
- activation_fn (Optional[Callable]): The non-linear activation function. Defaults to
- nn.ReLU(inplace).
- activation_fn_args (Optional[Dict]): The arguments for the activation function. Defaults to None.
+ activation_fn (Optional[str]): The name of non-linear activation function. Defaults to relu.
"""
super().__init__()
- if activation_fn is not None:
- activation_fn_args = activation_fn_args or {}
- activation_fn = getattr(nn, activation_fn)(**activation_fn_args)
- else:
- activation_fn = nn.ReLU(inplace=True)
+ activation_fn = activation_function(activation_fn)
self.layers = [
nn.Conv2d(
@@ -66,7 +61,7 @@ class LeNet(nn.Module):
self.layers = nn.Sequential(*self.layers)
def forward(self, x: torch.Tensor) -> torch.Tensor:
- """The feedforward."""
+ """The feedforward pass."""
# If batch dimenstion is missing, it needs to be added.
if len(x.shape) == 3:
x = x.unsqueeze(0)
diff --git a/src/text_recognizer/networks/misc.py b/src/text_recognizer/networks/misc.py
index 2fbab8f..6f61b5d 100644
--- a/src/text_recognizer/networks/misc.py
+++ b/src/text_recognizer/networks/misc.py
@@ -1,9 +1,9 @@
"""Miscellaneous neural network functionality."""
-from typing import Tuple
+from typing import Tuple, Type
from einops import rearrange
import torch
-from torch.nn import Unfold
+from torch import nn
def sliding_window(
@@ -20,10 +20,24 @@ def sliding_window(
torch.Tensor: A tensor with the shape (batch, patches, height, width).
"""
- unfold = Unfold(kernel_size=patch_size, stride=stride)
+ unfold = nn.Unfold(kernel_size=patch_size, stride=stride)
# Preform the slidning window, unsqueeze as the channel dimesion is lost.
patches = unfold(images).unsqueeze(1)
patches = rearrange(
patches, "b c (h w) t -> b t c h w", h=patch_size[0], w=patch_size[1]
)
return patches
+
+
+def activation_function(activation: str) -> Type[nn.Module]:
+ """Returns the callable activation function."""
+ activation_fns = nn.ModuleDict(
+ [
+ ["gelu", nn.GELU()],
+ ["leaky_relu", nn.LeakyReLU(negative_slope=1.0e-2, inplace=True)],
+ ["none", nn.Identity()],
+ ["relu", nn.ReLU(inplace=True)],
+ ["selu", nn.SELU(inplace=True)],
+ ]
+ )
+ return activation_fns[activation.lower()]
diff --git a/src/text_recognizer/networks/mlp.py b/src/text_recognizer/networks/mlp.py
index ac2c825..acebdaa 100644
--- a/src/text_recognizer/networks/mlp.py
+++ b/src/text_recognizer/networks/mlp.py
@@ -5,6 +5,8 @@ from einops.layers.torch import Rearrange
import torch
from torch import nn
+from text_recognizer.networks.misc import activation_function
+
class MLP(nn.Module):
"""Multi layered perceptron network."""
@@ -16,8 +18,7 @@ class MLP(nn.Module):
hidden_size: Union[int, List] = 128,
num_layers: int = 3,
dropout_rate: float = 0.2,
- activation_fn: Optional[Callable] = None,
- activation_fn_args: Optional[Dict] = None,
+ activation_fn: str = "relu",
) -> None:
"""Initialization of the MLP network.
@@ -27,18 +28,13 @@ class MLP(nn.Module):
hidden_size (Union[int, List]): The number of `neurons` in each hidden layer. Defaults to 128.
num_layers (int): The number of hidden layers. Defaults to 3.
dropout_rate (float): The dropout rate at each layer. Defaults to 0.2.
- activation_fn (Optional[Callable]): The activation function in the hidden layers. Defaults to
- None.
- activation_fn_args (Optional[Dict]): The arguments for the activation function. Defaults to None.
+ activation_fn (str): Name of the activation function in the hidden layers. Defaults to
+ relu.
"""
super().__init__()
- if activation_fn is not None:
- activation_fn_args = activation_fn_args or {}
- activation_fn = getattr(nn, activation_fn)(**activation_fn_args)
- else:
- activation_fn = nn.ReLU(inplace=True)
+ activation_fn = activation_function(activation_fn)
if isinstance(hidden_size, int):
hidden_size = [hidden_size] * num_layers
@@ -65,7 +61,7 @@ class MLP(nn.Module):
self.layers = nn.Sequential(*self.layers)
def forward(self, x: torch.Tensor) -> torch.Tensor:
- """The feedforward."""
+ """The feedforward pass."""
# If batch dimenstion is missing, it needs to be added.
if len(x.shape) == 3:
x = x.unsqueeze(0)
diff --git a/src/text_recognizer/networks/residual_network.py b/src/text_recognizer/networks/residual_network.py
index 23394b0..47e351a 100644
--- a/src/text_recognizer/networks/residual_network.py
+++ b/src/text_recognizer/networks/residual_network.py
@@ -1 +1,315 @@
"""Residual CNN."""
+from functools import partial
+from typing import Callable, Dict, List, Optional, Type, Union
+
+from einops.layers.torch import Rearrange, Reduce
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.misc import activation_function
+
+
+class Conv2dAuto(nn.Conv2d):
+ """Convolution with auto padding based on kernel size."""
+
+ def __init__(self, *args, **kwargs) -> None:
+ super().__init__(*args, **kwargs)
+ self.padding = (self.kernel_size[0] // 2, self.kernel_size[1] // 2)
+
+
+def conv_bn(in_channels: int, out_channels: int, *args, **kwargs) -> nn.Sequential:
+ """3x3 convolution with batch norm."""
+ conv3x3 = partial(Conv2dAuto, kernel_size=3, bias=False,)
+ return nn.Sequential(
+ conv3x3(in_channels, out_channels, *args, **kwargs),
+ nn.BatchNorm2d(out_channels),
+ )
+
+
+class IdentityBlock(nn.Module):
+ """Residual with identity block."""
+
+ def __init__(
+ self, in_channels: int, out_channels: int, activation: str = "relu"
+ ) -> None:
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.blocks = nn.Identity()
+ self.activation_fn = activation_function(activation)
+ self.shortcut = nn.Identity()
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Forward pass."""
+ residual = x
+ if self.apply_shortcut:
+ residual = self.shortcut(x)
+ x = self.blocks(x)
+ x += residual
+ x = self.activation_fn(x)
+ return x
+
+ @property
+ def apply_shortcut(self) -> bool:
+ """Check if shortcut should be applied."""
+ return self.in_channels != self.out_channels
+
+
+class ResidualBlock(IdentityBlock):
+ """Residual with nonlinear shortcut."""
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ expansion: int = 1,
+ downsampling: int = 1,
+ *args,
+ **kwargs
+ ) -> None:
+ """Short summary.
+
+ Args:
+ in_channels (int): Number of in channels.
+ out_channels (int): umber of out channels.
+ expansion (int): Expansion factor of the out channels. Defaults to 1.
+ downsampling (int): Downsampling factor used in stride. Defaults to 1.
+ *args (type): Extra arguments.
+ **kwargs (type): Extra key value arguments.
+
+ """
+ super().__init__(in_channels, out_channels, *args, **kwargs)
+ self.expansion = expansion
+ self.downsampling = downsampling
+
+ self.shortcut = (
+ nn.Sequential(
+ nn.Conv2d(
+ in_channels=self.in_channels,
+ out_channels=self.expanded_channels,
+ kernel_size=1,
+ stride=self.downsampling,
+ bias=False,
+ ),
+ nn.BatchNorm2d(self.expanded_channels),
+ )
+ if self.apply_shortcut
+ else None
+ )
+
+ @property
+ def expanded_channels(self) -> int:
+ """Computes the expanded output channels."""
+ return self.out_channels * self.expansion
+
+ @property
+ def apply_shortcut(self) -> bool:
+ """Check if shortcut should be applied."""
+ return self.in_channels != self.expanded_channels
+
+
+class BasicBlock(ResidualBlock):
+ """Basic ResNet block."""
+
+ expansion = 1
+
+ def __init__(self, in_channels: int, out_channels: int, *args, **kwargs) -> None:
+ super().__init__(in_channels, out_channels, *args, **kwargs)
+ self.blocks = nn.Sequential(
+ conv_bn(
+ in_channels=self.in_channels,
+ out_channels=self.out_channels,
+ bias=False,
+ stride=self.downsampling,
+ ),
+ self.activation_fn,
+ conv_bn(
+ in_channels=self.out_channels,
+ out_channels=self.expanded_channels,
+ bias=False,
+ ),
+ )
+
+
+class BottleNeckBlock(ResidualBlock):
+ """Bottleneck block to increase depth while minimizing parameter size."""
+
+ expansion = 4
+
+ def __init__(self, in_channels: int, out_channels: int, *args, **kwargs) -> None:
+ super().__init__(in_channels, out_channels, *args, **kwargs)
+ self.blocks = nn.Sequential(
+ conv_bn(
+ in_channels=self.in_channels,
+ out_channels=self.out_channels,
+ kernel_size=1,
+ ),
+ self.activation_fn,
+ conv_bn(
+ in_channels=self.out_channels,
+ out_channels=self.out_channels,
+ kernel_size=3,
+ stride=self.downsampling,
+ ),
+ self.activation_fn,
+ conv_bn(
+ in_channels=self.out_channels,
+ out_channels=self.expanded_channels,
+ kernel_size=1,
+ ),
+ )
+
+
+class ResidualLayer(nn.Module):
+ """ResNet layer."""
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ block: BasicBlock = BasicBlock,
+ num_blocks: int = 1,
+ *args,
+ **kwargs
+ ) -> None:
+ super().__init__()
+ downsampling = 2 if in_channels != out_channels else 1
+ self.blocks = nn.Sequential(
+ block(
+ in_channels, out_channels, *args, **kwargs, downsampling=downsampling
+ ),
+ *[
+ block(
+ out_channels * block.expansion,
+ out_channels,
+ downsampling=1,
+ *args,
+ **kwargs
+ )
+ for _ in range(num_blocks - 1)
+ ]
+ )
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Forward pass."""
+ x = self.blocks(x)
+ return x
+
+
+class Encoder(nn.Module):
+ """Encoder network."""
+
+ def __init__(
+ self,
+ in_channels: int = 1,
+ block_sizes: List[int] = (32, 64),
+ depths: List[int] = (2, 2),
+ activation: str = "relu",
+ block: Type[nn.Module] = BasicBlock,
+ *args,
+ **kwargs
+ ) -> None:
+ super().__init__()
+
+ self.block_sizes = block_sizes
+ self.depths = depths
+ self.activation = activation
+
+ self.gate = nn.Sequential(
+ nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=self.block_sizes[0],
+ kernel_size=3,
+ stride=2,
+ padding=3,
+ bias=False,
+ ),
+ nn.BatchNorm2d(self.block_sizes[0]),
+ activation_function(self.activation),
+ nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
+ )
+
+ self.blocks = self._configure_blocks(block)
+
+ def _configure_blocks(
+ self, block: Type[nn.Module], *args, **kwargs
+ ) -> nn.Sequential:
+ channels = [self.block_sizes[0]] + list(
+ zip(self.block_sizes, self.block_sizes[1:])
+ )
+ blocks = [
+ ResidualLayer(
+ in_channels=channels[0],
+ out_channels=channels[0],
+ num_blocks=self.depths[0],
+ block=block,
+ activation=self.activation,
+ *args,
+ **kwargs
+ )
+ ]
+ blocks += [
+ ResidualLayer(
+ in_channels=in_channels * block.expansion,
+ out_channels=out_channels,
+ num_blocks=num_blocks,
+ block=block,
+ activation=self.activation,
+ *args,
+ **kwargs
+ )
+ for (in_channels, out_channels), num_blocks in zip(
+ channels[1:], self.depths[1:]
+ )
+ ]
+
+ return nn.Sequential(*blocks)
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Forward pass."""
+ # If batch dimenstion is missing, it needs to be added.
+ if len(x.shape) == 3:
+ x = x.unsqueeze(0)
+ x = self.gate(x)
+ return self.blocks(x)
+
+
+class Decoder(nn.Module):
+ """Classification head."""
+
+ def __init__(self, in_features: int, num_classes: int = 80) -> None:
+ super().__init__()
+ self.decoder = nn.Sequential(
+ Reduce("b c h w -> b c", "mean"),
+ nn.Linear(in_features=in_features, out_features=num_classes),
+ )
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Forward pass."""
+ return self.decoder(x)
+
+
+class ResidualNetwork(nn.Module):
+ """Full residual network."""
+
+ def __init__(self, in_channels: int, num_classes: int, *args, **kwargs) -> None:
+ super().__init__()
+ self.encoder = Encoder(in_channels, *args, **kwargs)
+ self.decoder = Decoder(
+ in_features=self.encoder.blocks[-1].blocks[-1].expanded_channels,
+ num_classes=num_classes,
+ )
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Forward pass."""
+ x = self.encoder(x)
+ x = self.decoder(x)
+ return x
diff --git a/src/text_recognizer/weights/CharacterModel_EmnistDataset_LeNet_weights.pt b/src/text_recognizer/weights/CharacterModel_EmnistDataset_LeNet_weights.pt
index 81ef9be..676eb44 100644
--- a/src/text_recognizer/weights/CharacterModel_EmnistDataset_LeNet_weights.pt
+++ b/src/text_recognizer/weights/CharacterModel_EmnistDataset_LeNet_weights.pt
Binary files differ
diff --git a/src/text_recognizer/weights/CharacterModel_EmnistDataset_MLP_weights.pt b/src/text_recognizer/weights/CharacterModel_EmnistDataset_MLP_weights.pt
index 49bd166..86cf103 100644
--- a/src/text_recognizer/weights/CharacterModel_EmnistDataset_MLP_weights.pt
+++ b/src/text_recognizer/weights/CharacterModel_EmnistDataset_MLP_weights.pt
Binary files differ
diff --git a/src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetwork_weights.pt b/src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetwork_weights.pt
new file mode 100644
index 0000000..008beb2
--- /dev/null
+++ b/src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetwork_weights.pt
Binary files differ
diff --git a/src/training/experiments/sample_experiment.yml b/src/training/experiments/sample_experiment.yml
index 355305c..bae02ac 100644
--- a/src/training/experiments/sample_experiment.yml
+++ b/src/training/experiments/sample_experiment.yml
@@ -9,25 +9,32 @@ experiments:
seed: 4711
data_loader_args:
splits: [train, val]
- batch_size: 256
shuffle: true
num_workers: 8
cuda: true
model: CharacterModel
metrics: [accuracy]
- network: MLP
+ # network: MLP
+ # network_args:
+ # input_size: 784
+ # hidden_size: 512
+ # output_size: 80
+ # num_layers: 3
+ # dropout_rate: 0
+ # activation_fn: SELU
+ network: ResidualNetwork
network_args:
- input_size: 784
- output_size: 62
- num_layers: 3
- activation_fn: GELU
+ in_channels: 1
+ num_classes: 80
+ depths: [1, 1]
+ block_sizes: [128, 256]
# network: LeNet
# network_args:
# output_size: 62
# activation_fn: GELU
train_args:
batch_size: 256
- epochs: 16
+ epochs: 32
criterion: CrossEntropyLoss
criterion_args:
weight: null
@@ -43,20 +50,24 @@ experiments:
# centered: false
optimizer: AdamW
optimizer_args:
- lr: 1.e-2
+ lr: 1.e-03
betas: [0.9, 0.999]
eps: 1.e-08
- weight_decay: 0
+ # weight_decay: 5.e-4
amsgrad: false
# lr_scheduler: null
lr_scheduler: OneCycleLR
lr_scheduler_args:
- max_lr: 1.e-3
- epochs: 16
- callbacks: [Checkpoint, EarlyStopping, WandbCallback, WandbImageLogger, OneCycleLR]
+ max_lr: 1.e-03
+ epochs: 32
+ anneal_strategy: linear
+ callbacks: [Checkpoint, ProgressBar, EarlyStopping, WandbCallback, WandbImageLogger, OneCycleLR]
callback_args:
Checkpoint:
monitor: val_accuracy
+ ProgressBar:
+ epochs: 32
+ log_batch_frequency: 100
EarlyStopping:
monitor: val_loss
min_delta: 0.0
@@ -68,5 +79,5 @@ experiments:
num_examples: 4
OneCycleLR:
null
- verbosity: 2 # 0, 1, 2
+ verbosity: 1 # 0, 1, 2
resume_experiment: null
diff --git a/src/training/prepare_experiments.py b/src/training/prepare_experiments.py
index 97c0304..4c3f9ba 100644
--- a/src/training/prepare_experiments.py
+++ b/src/training/prepare_experiments.py
@@ -7,11 +7,11 @@ from loguru import logger
import yaml
-# flake8: noqa: S404,S607,S603
def run_experiments(experiments_filename: str) -> None:
"""Run experiment from file."""
with open(experiments_filename) as f:
experiments_config = yaml.safe_load(f)
+
num_experiments = len(experiments_config["experiments"])
for index in range(num_experiments):
experiment_config = experiments_config["experiments"][index]
@@ -27,10 +27,10 @@ def run_experiments(experiments_filename: str) -> None:
type=str,
help="Filename of Yaml file of experiments to run.",
)
-def main(experiments_filename: str) -> None:
+def run_cli(experiments_filename: str) -> None:
"""Parse command-line arguments and run experiments from provided file."""
run_experiments(experiments_filename)
if __name__ == "__main__":
- main()
+ run_cli()
diff --git a/src/training/run_experiment.py b/src/training/run_experiment.py
index d278dc2..8c063ff 100644
--- a/src/training/run_experiment.py
+++ b/src/training/run_experiment.py
@@ -6,18 +6,20 @@ import json
import os
from pathlib import Path
import re
-from typing import Callable, Dict, Tuple
+from typing import Callable, Dict, Tuple, Type
import click
from loguru import logger
import torch
from tqdm import tqdm
-from training.callbacks import CallbackList
from training.gpu_manager import GPUManager
-from training.train import Trainer
+from training.trainer.callbacks import CallbackList
+from training.trainer.train import Trainer
import wandb
import yaml
+from text_recognizer.models import Model
+
EXPERIMENTS_DIRNAME = Path(__file__).parents[0].resolve() / "experiments"
@@ -35,7 +37,7 @@ def get_level(experiment_config: Dict) -> int:
return 10
-def create_experiment_dir(model: Callable, experiment_config: Dict) -> Path:
+def create_experiment_dir(model: Type[Model], experiment_config: Dict) -> Path:
"""Create new experiment."""
EXPERIMENTS_DIRNAME.mkdir(parents=True, exist_ok=True)
experiment_dir = EXPERIMENTS_DIRNAME / model.__name__
@@ -67,6 +69,8 @@ def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict]
"""Loads all modules and arguments."""
# Import the data loader arguments.
data_loader_args = experiment_config.get("data_loader_args", {})
+ train_args = experiment_config.get("train_args", {})
+ data_loader_args["batch_size"] = train_args["batch_size"]
data_loader_args["dataset"] = experiment_config["dataset"]
data_loader_args["dataset_args"] = experiment_config.get("dataset_args", {})
@@ -94,7 +98,7 @@ def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict]
optimizer_args = experiment_config.get("optimizer_args", {})
# Callbacks
- callback_modules = importlib.import_module("training.callbacks")
+ callback_modules = importlib.import_module("training.trainer.callbacks")
callbacks = [
getattr(callback_modules, callback)(
**check_args(experiment_config["callback_args"][callback])
@@ -208,6 +212,7 @@ def run_experiment(
with open(str(config_path), "w") as f:
yaml.dump(experiment_config, f)
+ # Train the model.
trainer = Trainer(
model=model,
model_dir=model_dir,
@@ -247,7 +252,7 @@ def run_experiment(
@click.option(
"--nowandb", is_flag=False, help="If true, do not use wandb for this run."
)
-def main(experiment_config: str, gpu: int, save: bool, nowandb: bool) -> None:
+def run_cli(experiment_config: str, gpu: int, save: bool, nowandb: bool) -> None:
"""Run experiment."""
if gpu < 0:
gpu_manager = GPUManager(True)
@@ -260,4 +265,4 @@ def main(experiment_config: str, gpu: int, save: bool, nowandb: bool) -> None:
if __name__ == "__main__":
- main()
+ run_cli()
diff --git a/src/training/trainer/__init__.py b/src/training/trainer/__init__.py
new file mode 100644
index 0000000..de41bfb
--- /dev/null
+++ b/src/training/trainer/__init__.py
@@ -0,0 +1,2 @@
+"""Trainer modules."""
+from .train import Trainer
diff --git a/src/training/callbacks/__init__.py b/src/training/trainer/callbacks/__init__.py
index fbcc285..5942276 100644
--- a/src/training/callbacks/__init__.py
+++ b/src/training/trainer/callbacks/__init__.py
@@ -2,6 +2,7 @@
from .base import Callback, CallbackList, Checkpoint
from .early_stopping import EarlyStopping
from .lr_schedulers import CyclicLR, MultiStepLR, OneCycleLR, ReduceLROnPlateau, StepLR
+from .progress_bar import ProgressBar
from .wandb_callbacks import WandbCallback, WandbImageLogger
__all__ = [
@@ -14,6 +15,7 @@ __all__ = [
"CyclicLR",
"MultiStepLR",
"OneCycleLR",
+ "ProgressBar",
"ReduceLROnPlateau",
"StepLR",
]
diff --git a/src/training/callbacks/base.py b/src/training/trainer/callbacks/base.py
index e0d91e6..8df94f3 100644
--- a/src/training/callbacks/base.py
+++ b/src/training/trainer/callbacks/base.py
@@ -1,7 +1,7 @@
"""Metaclass for callback functions."""
from enum import Enum
-from typing import Callable, Dict, List, Type, Union
+from typing import Callable, Dict, List, Optional, Type, Union
from loguru import logger
import numpy as np
@@ -36,27 +36,29 @@ class Callback:
"""Called when fit ends."""
pass
- def on_epoch_begin(self, epoch: int, logs: Dict = {}) -> None:
+ def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None) -> None:
"""Called at the beginning of an epoch. Only used in training mode."""
pass
- def on_epoch_end(self, epoch: int, logs: Dict = {}) -> None:
+ def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
"""Called at the end of an epoch. Only used in training mode."""
pass
- def on_train_batch_begin(self, batch: int, logs: Dict = {}) -> None:
+ def on_train_batch_begin(self, batch: int, logs: Optional[Dict] = None) -> None:
"""Called at the beginning of an epoch."""
pass
- def on_train_batch_end(self, batch: int, logs: Dict = {}) -> None:
+ def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
"""Called at the end of an epoch."""
pass
- def on_validation_batch_begin(self, batch: int, logs: Dict = {}) -> None:
+ def on_validation_batch_begin(
+ self, batch: int, logs: Optional[Dict] = None
+ ) -> None:
"""Called at the beginning of an epoch."""
pass
- def on_validation_batch_end(self, batch: int, logs: Dict = {}) -> None:
+ def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
"""Called at the end of an epoch."""
pass
@@ -102,18 +104,18 @@ class CallbackList:
for callback in self._callbacks:
callback.on_fit_end()
- def on_epoch_begin(self, epoch: int, logs: Dict = {}) -> None:
+ def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None) -> None:
"""Called at the beginning of an epoch."""
for callback in self._callbacks:
callback.on_epoch_begin(epoch, logs)
- def on_epoch_end(self, epoch: int, logs: Dict = {}) -> None:
+ def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
"""Called at the end of an epoch."""
for callback in self._callbacks:
callback.on_epoch_end(epoch, logs)
def _call_batch_hook(
- self, mode: str, hook: str, batch: int, logs: Dict = {}
+ self, mode: str, hook: str, batch: int, logs: Optional[Dict] = None
) -> None:
"""Helper function for all batch_{begin | end} methods."""
if hook == "begin":
@@ -123,39 +125,45 @@ class CallbackList:
else:
raise ValueError(f"Unrecognized hook {hook}.")
- def _call_batch_begin_hook(self, mode: str, batch: int, logs: Dict = {}) -> None:
+ def _call_batch_begin_hook(
+ self, mode: str, batch: int, logs: Optional[Dict] = None
+ ) -> None:
"""Helper function for all `on_*_batch_begin` methods."""
hook_name = f"on_{mode}_batch_begin"
self._call_batch_hook_helper(hook_name, batch, logs)
- def _call_batch_end_hook(self, mode: str, batch: int, logs: Dict = {}) -> None:
+ def _call_batch_end_hook(
+ self, mode: str, batch: int, logs: Optional[Dict] = None
+ ) -> None:
"""Helper function for all `on_*_batch_end` methods."""
hook_name = f"on_{mode}_batch_end"
self._call_batch_hook_helper(hook_name, batch, logs)
def _call_batch_hook_helper(
- self, hook_name: str, batch: int, logs: Dict = {}
+ self, hook_name: str, batch: int, logs: Optional[Dict] = None
) -> None:
"""Helper function for `on_*_batch_begin` methods."""
for callback in self._callbacks:
hook = getattr(callback, hook_name)
hook(batch, logs)
- def on_train_batch_begin(self, batch: int, logs: Dict = {}) -> None:
+ def on_train_batch_begin(self, batch: int, logs: Optional[Dict] = None) -> None:
"""Called at the beginning of an epoch."""
- self._call_batch_hook(self.mode_keys.TRAIN, "begin", batch)
+ self._call_batch_hook(self.mode_keys.TRAIN, "begin", batch, logs)
- def on_train_batch_end(self, batch: int, logs: Dict = {}) -> None:
+ def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
"""Called at the end of an epoch."""
- self._call_batch_hook(self.mode_keys.TRAIN, "end", batch)
+ self._call_batch_hook(self.mode_keys.TRAIN, "end", batch, logs)
- def on_validation_batch_begin(self, batch: int, logs: Dict = {}) -> None:
+ def on_validation_batch_begin(
+ self, batch: int, logs: Optional[Dict] = None
+ ) -> None:
"""Called at the beginning of an epoch."""
- self._call_batch_hook(self.mode_keys.VALIDATION, "begin", batch)
+ self._call_batch_hook(self.mode_keys.VALIDATION, "begin", batch, logs)
- def on_validation_batch_end(self, batch: int, logs: Dict = {}) -> None:
+ def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
"""Called at the end of an epoch."""
- self._call_batch_hook(self.mode_keys.VALIDATION, "end", batch)
+ self._call_batch_hook(self.mode_keys.VALIDATION, "end", batch, logs)
def __iter__(self) -> iter:
"""Iter function for callback list."""
diff --git a/src/training/callbacks/early_stopping.py b/src/training/trainer/callbacks/early_stopping.py
index c9b7907..02b431f 100644
--- a/src/training/callbacks/early_stopping.py
+++ b/src/training/trainer/callbacks/early_stopping.py
@@ -4,7 +4,8 @@ from typing import Dict, Union
from loguru import logger
import numpy as np
import torch
-from training.callbacks import Callback
+from torch import Tensor
+from training.trainer.callbacks import Callback
class EarlyStopping(Callback):
@@ -95,7 +96,7 @@ class EarlyStopping(Callback):
f"Stopped training at epoch {self.stopped_epoch + 1} with early stopping."
)
- def get_monitor_value(self, logs: Dict) -> Union[torch.Tensor, None]:
+ def get_monitor_value(self, logs: Dict) -> Union[Tensor, None]:
"""Extracts the monitor value."""
monitor_value = logs.get(self.monitor)
if monitor_value is None:
diff --git a/src/training/callbacks/lr_schedulers.py b/src/training/trainer/callbacks/lr_schedulers.py
index 00c7e9b..ba2226a 100644
--- a/src/training/callbacks/lr_schedulers.py
+++ b/src/training/trainer/callbacks/lr_schedulers.py
@@ -1,7 +1,7 @@
"""Callbacks for learning rate schedulers."""
from typing import Callable, Dict, List, Optional, Type
-from training.callbacks import Callback
+from training.trainer.callbacks import Callback
from text_recognizer.models import Model
@@ -19,7 +19,7 @@ class StepLR(Callback):
self.model = model
self.lr_scheduler = self.model.lr_scheduler
- def on_epoch_end(self, epoch: int, logs: Dict = {}) -> None:
+ def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
"""Takes a step at the end of every epoch."""
self.lr_scheduler.step()
@@ -37,7 +37,7 @@ class MultiStepLR(Callback):
self.model = model
self.lr_scheduler = self.model.lr_scheduler
- def on_epoch_end(self, epoch: int, logs: Dict = {}) -> None:
+ def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
"""Takes a step at the end of every epoch."""
self.lr_scheduler.step()
@@ -55,7 +55,7 @@ class ReduceLROnPlateau(Callback):
self.model = model
self.lr_scheduler = self.model.lr_scheduler
- def on_epoch_end(self, epoch: int, logs: Dict = {}) -> None:
+ def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
"""Takes a step at the end of every epoch."""
val_loss = logs["val_loss"]
self.lr_scheduler.step(val_loss)
@@ -74,7 +74,7 @@ class CyclicLR(Callback):
self.model = model
self.lr_scheduler = self.model.lr_scheduler
- def on_train_batch_end(self, batch: int, logs: Dict = {}) -> None:
+ def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
"""Takes a step at the end of every training batch."""
self.lr_scheduler.step()
@@ -92,6 +92,6 @@ class OneCycleLR(Callback):
self.model = model
self.lr_scheduler = self.model.lr_scheduler
- def on_train_batch_end(self, batch: int, logs: Dict = {}) -> None:
+ def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
"""Takes a step at the end of every training batch."""
self.lr_scheduler.step()
diff --git a/src/training/trainer/callbacks/progress_bar.py b/src/training/trainer/callbacks/progress_bar.py
new file mode 100644
index 0000000..1970747
--- /dev/null
+++ b/src/training/trainer/callbacks/progress_bar.py
@@ -0,0 +1,61 @@
+"""Progress bar callback for the training loop."""
+from typing import Dict, Optional
+
+from tqdm import tqdm
+from training.trainer.callbacks import Callback
+
+
+class ProgressBar(Callback):
+ """A TQDM progress bar for the training loop."""
+
+ def __init__(self, epochs: int, log_batch_frequency: int = None) -> None:
+ """Initializes the tqdm callback."""
+ self.epochs = epochs
+ self.log_batch_frequency = log_batch_frequency
+ self.progress_bar = None
+ self.val_metrics = {}
+
+ def _configure_progress_bar(self) -> None:
+ """Configures the tqdm progress bar with custom bar format."""
+ self.progress_bar = tqdm(
+ total=len(self.model.data_loaders["train"]),
+ leave=True,
+ unit="step",
+ mininterval=self.log_batch_frequency,
+ bar_format="{desc} |{bar:30}| {n_fmt}/{total_fmt} ETA: {remaining} {rate_fmt}{postfix}",
+ )
+
+ def _key_abbreviations(self, logs: Dict) -> Dict:
+ """Changes the length of keys, so that the progress bar fits better."""
+
+ def rename(key: str) -> str:
+ """Renames accuracy to acc."""
+ return key.replace("accuracy", "acc")
+
+ return {rename(key): value for key, value in logs.items()}
+
+ def on_fit_begin(self) -> None:
+ """Creates a tqdm progress bar."""
+ self._configure_progress_bar()
+
+ def on_epoch_begin(self, epoch: int, logs: Optional[Dict]) -> None:
+ """Updates the description with the current epoch."""
+ self.progress_bar.reset()
+ self.progress_bar.set_description(f"Epoch {epoch}/{self.epochs}")
+
+ def on_epoch_end(self, epoch: int, logs: Dict) -> None:
+ """At the end of each epoch, the validation metrics are updated to the progress bar."""
+ self.val_metrics = logs
+ self.progress_bar.set_postfix(**self._key_abbreviations(logs))
+ self.progress_bar.update()
+
+ def on_train_batch_end(self, batch: int, logs: Dict) -> None:
+ """Updates the progress bar for each training step."""
+ if self.val_metrics:
+ logs.update(self.val_metrics)
+ self.progress_bar.set_postfix(**self._key_abbreviations(logs))
+ self.progress_bar.update()
+
+ def on_fit_end(self) -> None:
+ """Closes the tqdm progress bar."""
+ self.progress_bar.close()
diff --git a/src/training/callbacks/wandb_callbacks.py b/src/training/trainer/callbacks/wandb_callbacks.py
index 6ada6df..e44c745 100644
--- a/src/training/callbacks/wandb_callbacks.py
+++ b/src/training/trainer/callbacks/wandb_callbacks.py
@@ -1,9 +1,9 @@
-"""Callbacks using wandb."""
+"""Callback for W&B."""
from typing import Callable, Dict, List, Optional, Type
import numpy as np
from torchvision.transforms import Compose, ToTensor
-from training.callbacks import Callback
+from training.trainer.callbacks import Callback
import wandb
from text_recognizer.datasets import Transpose
@@ -28,12 +28,12 @@ class WandbCallback(Callback):
if self.log_batch_frequency and batch % self.log_batch_frequency == 0:
wandb.log(logs, commit=True)
- def on_train_batch_end(self, batch: int, logs: Dict = {}) -> None:
+ def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
"""Logs training metrics."""
if logs is not None:
self._on_batch_end(batch, logs)
- def on_validation_batch_end(self, batch: int, logs: Dict = {}) -> None:
+ def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
"""Logs validation metrics."""
if logs is not None:
self._on_batch_end(batch, logs)
diff --git a/src/training/population_based_training/__init__.py b/src/training/trainer/population_based_training/__init__.py
index 868d739..868d739 100644
--- a/src/training/population_based_training/__init__.py
+++ b/src/training/trainer/population_based_training/__init__.py
diff --git a/src/training/population_based_training/population_based_training.py b/src/training/trainer/population_based_training/population_based_training.py
index 868d739..868d739 100644
--- a/src/training/population_based_training/population_based_training.py
+++ b/src/training/trainer/population_based_training/population_based_training.py
diff --git a/src/training/train.py b/src/training/trainer/train.py
index aaa0430..a75ae8f 100644
--- a/src/training/train.py
+++ b/src/training/trainer/train.py
@@ -7,9 +7,9 @@ from typing import Dict, List, Optional, Tuple, Type
from loguru import logger
import numpy as np
import torch
-from tqdm import tqdm, trange
-from training.callbacks import Callback, CallbackList
-from training.util import RunningAverage
+from torch import Tensor
+from training.trainer.callbacks import Callback, CallbackList
+from training.trainer.util import RunningAverage
import wandb
from text_recognizer.models import Model
@@ -46,11 +46,11 @@ class Trainer:
self.model_dir = model_dir
self.checkpoint_path = checkpoint_path
self.start_epoch = 1
- self.epochs = train_args["epochs"] + self.start_epoch
+ self.epochs = train_args["epochs"]
self.callbacks = callbacks
if self.checkpoint_path is not None:
- self.start_epoch = self.model.load_checkpoint(self.checkpoint_path) + 1
+ self.start_epoch = self.model.load_checkpoint(self.checkpoint_path)
# Parse the name of the experiment.
experiment_dir = str(self.model_dir.parents[1]).split("/")
@@ -59,7 +59,7 @@ class Trainer:
def training_step(
self,
batch: int,
- samples: Tuple[torch.Tensor, torch.Tensor],
+ samples: Tuple[Tensor, Tensor],
loss_avg: Type[RunningAverage],
) -> Dict:
"""Performs the training step."""
@@ -108,27 +108,16 @@ class Trainer:
data_loader = self.model.data_loaders["train"]
- with tqdm(
- total=len(data_loader),
- leave=False,
- unit="step",
- bar_format="{n_fmt}/{total_fmt} |{bar:30}| {remaining} {rate_inv_fmt}{postfix}",
- ) as t:
- for batch, samples in enumerate(data_loader):
- self.callbacks.on_train_batch_begin(batch)
-
- metrics = self.training_step(batch, samples, loss_avg)
-
- self.callbacks.on_train_batch_end(batch, logs=metrics)
-
- # Update Tqdm progress bar.
- t.set_postfix(**metrics)
- t.update()
+ for batch, samples in enumerate(data_loader):
+ self.callbacks.on_train_batch_begin(batch)
+ metrics = self.training_step(batch, samples, loss_avg)
+ self.callbacks.on_train_batch_end(batch, logs=metrics)
+ @torch.no_grad()
def validation_step(
self,
batch: int,
- samples: Tuple[torch.Tensor, torch.Tensor],
+ samples: Tuple[Tensor, Tensor],
loss_avg: Type[RunningAverage],
) -> Dict:
"""Performs the validation step."""
@@ -158,6 +147,12 @@ class Trainer:
return metrics
+ def _log_val_metric(self, metrics_mean: Dict, epoch: Optional[int] = None) -> None:
+ log_str = "Validation metrics " + (f"at epoch {epoch} - " if epoch else " - ")
+ logger.debug(
+ log_str + " - ".join(f"{k}: {v:.4f}" for k, v in metrics_mean.items())
+ )
+
def validate(self, epoch: Optional[int] = None) -> Dict:
"""Runs the validation loop for one epoch."""
# Set model to eval mode.
@@ -172,41 +167,18 @@ class Trainer:
# Summary for the current eval loop.
summary = []
- with tqdm(
- total=len(data_loader),
- leave=False,
- unit="step",
- bar_format="{n_fmt}/{total_fmt} |{bar:30}| {remaining} {rate_inv_fmt}{postfix}",
- ) as t:
- with torch.no_grad():
- for batch, samples in enumerate(data_loader):
- self.callbacks.on_validation_batch_begin(batch)
-
- metrics = self.validation_step(batch, samples, loss_avg)
-
- self.callbacks.on_validation_batch_end(batch, logs=metrics)
-
- summary.append(metrics)
-
- # Update Tqdm progress bar.
- t.set_postfix(**metrics)
- t.update()
+ for batch, samples in enumerate(data_loader):
+ self.callbacks.on_validation_batch_begin(batch)
+ metrics = self.validation_step(batch, samples, loss_avg)
+ self.callbacks.on_validation_batch_end(batch, logs=metrics)
+ summary.append(metrics)
# Compute mean of all metrics.
metrics_mean = {
"val_" + metric: np.mean([x[metric] for x in summary])
for metric in summary[0]
}
- if epoch:
- logger.debug(
- f"Validation metrics at epoch {epoch} - "
- + " - ".join(f"{k}: {v:.4f}" for k, v in metrics_mean.items())
- )
- else:
- logger.debug(
- "Validation metrics - "
- + " - ".join(f"{k}: {v:.4f}" for k, v in metrics_mean.items())
- )
+ self._log_val_metric(metrics_mean, epoch)
return metrics_mean
@@ -214,19 +186,14 @@ class Trainer:
"""Runs the training and evaluation loop."""
logger.debug(f"Running an experiment called {self.experiment_name}.")
+
+ # Set start time.
t_start = time.time()
self.callbacks.on_fit_begin()
- # TODO: fix progress bar as callback.
# Run the training loop.
- for epoch in trange(
- self.start_epoch,
- self.epochs,
- leave=False,
- bar_format="{desc}: {n_fmt}/{total_fmt} |{bar:30}| {remaining}{postfix}",
- desc="Epoch",
- ):
+ for epoch in range(self.start_epoch, self.epochs + 1):
self.callbacks.on_epoch_begin(epoch)
# Perform one training pass over the training set.
diff --git a/src/training/util.py b/src/training/trainer/util.py
index 132b2dc..132b2dc 100644
--- a/src/training/util.py
+++ b/src/training/trainer/util.py