diff options
Diffstat (limited to 'src')
28 files changed, 1212 insertions, 218 deletions
diff --git a/src/notebooks/00-testing-stuff-out.ipynb b/src/notebooks/00-testing-stuff-out.ipynb index 49ca4c4..3f008c3 100644 --- a/src/notebooks/00-testing-stuff-out.ipynb +++ b/src/notebooks/00-testing-stuff-out.ipynb @@ -2,11 +2,120 @@ "cells": [ { "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", + "%matplotlib inline\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "from PIL import Image\n", + "import torch\n", + "from importlib.util import find_spec\n", + "if find_spec(\"text_recognizer\") is None:\n", + " import sys\n", + " sys.path.append('..')" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch.cuda.is_available()" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.nn.modules.activation.SELU" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch.nn.SELU" + ] + }, + { + "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ - "import torch" + "a = \"nNone\"" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "b = a or \"relu\"" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'nnone'" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "b.lower()" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'nNone'" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "b" ] }, { @@ -986,28 +1095,16 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 51, "metadata": {}, - "outputs": [ - { - "ename": "ModuleNotFoundError", - "evalue": "No module named 'tqdm.auto.tqdm'; 'tqdm.auto' is not a package", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m<ipython-input-20-68e3c8bf3e1f>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mtqdm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mauto\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtqdm\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mtqdm_auto\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'tqdm.auto.tqdm'; 'tqdm.auto' is not a package" - ] - } - ], + "outputs": [], "source": [ - "import tqdm.auto.tqdm as tqdm_auto" + "import tqdm" ] }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 52, "metadata": {}, "outputs": [ { @@ -1016,7 +1113,7 @@ "tqdm.notebook.tqdm_notebook" ] }, - "execution_count": 19, + "execution_count": 52, "metadata": {}, "output_type": "execute_result" } @@ -1027,25 +1124,50 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "tqdm.auto.tqdm" + ] + }, + { + "cell_type": "code", + "execution_count": 54, "metadata": {}, "outputs": [], "source": [ "def test():\n", - " for i in range(9):\n", + " for i in tqdm.auto.tqdm(range(9)):\n", " pass\n", - " print(i)" + " print(i)\n", + " " ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 55, "metadata": {}, "outputs": [ { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "e1d3b25d4ee141e882e316ec54e79d60", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { "name": "stdout", "output_type": "stream", "text": [ + "\n", "8\n" ] } @@ -1056,6 +1178,479 @@ }, { "cell_type": "code", + "execution_count": 58, + "metadata": {}, + "outputs": [], + "source": [ + "from time import sleep" + ] + }, + { + "cell_type": "code", + "execution_count": 71, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "41b743273ce14236bcb65782dbcd2e75", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "pbar = tqdm.auto.tqdm([\"a\", \"b\", \"c\", \"d\"], leave=True)\n", + "for char in pbar:\n", + " pbar.set_description(\"Processing %s\" % char)\n", + "# pbar.set_prefix()\n", + " sleep(0.25)\n", + "pbar.set_postfix({\"hej\": 0.32})" + ] + }, + { + "cell_type": "code", + "execution_count": 83, + "metadata": {}, + "outputs": [], + "source": [ + "pbar.close()" + ] + }, + { + "cell_type": "code", + "execution_count": 96, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "cb5ad8d6109f4b1495b8fc7422bafd01", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "with tqdm.auto.tqdm(total=10, bar_format=\"{postfix[0]} {postfix[1][value]:>8.2g}\",\n", + " postfix=[\"Batch\", dict(value=0)]) as t:\n", + " for i in range(10):\n", + " sleep(0.1)\n", + "# t.postfix[2][\"value\"] = 3 \n", + " t.postfix[1][\"value\"] = i / 2\n", + " t.update()" + ] + }, + { + "cell_type": "code", + "execution_count": 99, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "0b341d49ad074823881e84a538bcad0c", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "with tqdm.auto.tqdm(total=100, leave=True) as pbar:\n", + " for i in range(2):\n", + " for i in range(10):\n", + " sleep(0.1)\n", + " pbar.update(10)\n", + " pbar.set_postfix({\"adaf\": 23})\n", + " pbar.set_postfix({\"hej\": 0.32})\n", + " pbar.reset()" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "from text_recognizer.networks.residual_network import IdentityBlock, ResidualBlock, BasicBlock, BottleNeckBlock, ResidualLayer, Encoder, ResidualNetwork" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "IdentityBlock(\n", + " (blocks): Identity()\n", + " (activation_fn): ReLU(inplace=True)\n", + " (shortcut): Identity()\n", + ")" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "IdentityBlock(32, 64)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "ResidualBlock(\n", + " (blocks): Identity()\n", + " (activation_fn): ReLU(inplace=True)\n", + " (shortcut): Sequential(\n", + " (0): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + ")" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ResidualBlock(32, 64)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "BasicBlock(\n", + " (blocks): Sequential(\n", + " (0): Sequential(\n", + " (0): Conv2dAuto(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " (1): ReLU(inplace=True)\n", + " (2): Sequential(\n", + " (0): Conv2dAuto(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (activation_fn): ReLU(inplace=True)\n", + " (shortcut): Sequential(\n", + " (0): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + ")\n" + ] + } + ], + "source": [ + "dummy = torch.ones((1, 32, 224, 224))\n", + "\n", + "block = BasicBlock(32, 64)\n", + "block(dummy).shape\n", + "print(block)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "BottleNeckBlock(\n", + " (blocks): Sequential(\n", + " (0): Sequential(\n", + " (0): Conv2dAuto(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " (1): ReLU(inplace=True)\n", + " (2): Sequential(\n", + " (0): Conv2dAuto(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " (3): ReLU(inplace=True)\n", + " (4): Sequential(\n", + " (0): Conv2dAuto(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (activation_fn): ReLU(inplace=True)\n", + " (shortcut): Sequential(\n", + " (0): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + ")\n" + ] + } + ], + "source": [ + "dummy = torch.ones((1, 32, 10, 10))\n", + "\n", + "block = BottleNeckBlock(32, 64)\n", + "block(dummy).shape\n", + "print(block)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1, 128, 24, 24])" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dummy = torch.ones((1, 64, 48, 48))\n", + "\n", + "layer = ResidualLayer(64, 128, block=BasicBlock, num_blocks=3)\n", + "layer(dummy).shape" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[(64, 128), (128, 256), (256, 512)]" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "blocks_sizes=[64, 128, 256, 512]\n", + "list(zip(blocks_sizes, blocks_sizes[1:]))" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "e = Encoder(depths=[1, 1])" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "from torchsummary import summary" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "----------------------------------------------------------------\n", + " Layer (type) Output Shape Param #\n", + "================================================================\n", + " Conv2d-1 [-1, 32, 15, 15] 800\n", + " BatchNorm2d-2 [-1, 32, 15, 15] 64\n", + " ReLU-3 [-1, 32, 15, 15] 0\n", + " MaxPool2d-4 [-1, 32, 8, 8] 0\n", + " Conv2dAuto-5 [-1, 32, 8, 8] 9,216\n", + " BatchNorm2d-6 [-1, 32, 8, 8] 64\n", + " ReLU-7 [-1, 32, 8, 8] 0\n", + " ReLU-8 [-1, 32, 8, 8] 0\n", + " Conv2dAuto-9 [-1, 32, 8, 8] 9,216\n", + " BatchNorm2d-10 [-1, 32, 8, 8] 64\n", + " ReLU-11 [-1, 32, 8, 8] 0\n", + " ReLU-12 [-1, 32, 8, 8] 0\n", + " BasicBlock-13 [-1, 32, 8, 8] 0\n", + " ResidualLayer-14 [-1, 32, 8, 8] 0\n", + " Conv2d-15 [-1, 64, 4, 4] 2,048\n", + " BatchNorm2d-16 [-1, 64, 4, 4] 128\n", + " Conv2dAuto-17 [-1, 64, 4, 4] 18,432\n", + " BatchNorm2d-18 [-1, 64, 4, 4] 128\n", + " ReLU-19 [-1, 64, 4, 4] 0\n", + " ReLU-20 [-1, 64, 4, 4] 0\n", + " Conv2dAuto-21 [-1, 64, 4, 4] 36,864\n", + " BatchNorm2d-22 [-1, 64, 4, 4] 128\n", + " ReLU-23 [-1, 64, 4, 4] 0\n", + " ReLU-24 [-1, 64, 4, 4] 0\n", + " BasicBlock-25 [-1, 64, 4, 4] 0\n", + " ResidualLayer-26 [-1, 64, 4, 4] 0\n", + "================================================================\n", + "Total params: 77,152\n", + "Trainable params: 77,152\n", + "Non-trainable params: 0\n", + "----------------------------------------------------------------\n", + "Input size (MB): 0.00\n", + "Forward/backward pass size (MB): 0.43\n", + "Params size (MB): 0.29\n", + "Estimated Total Size (MB): 0.73\n", + "----------------------------------------------------------------\n" + ] + } + ], + "source": [ + "summary(e, (1, 28, 28), device=\"cpu\")" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "resnet = ResidualNetwork(1, 80, activation=\"selu\")" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "----------------------------------------------------------------\n", + " Layer (type) Output Shape Param #\n", + "================================================================\n", + " Conv2d-1 [-1, 32, 15, 15] 800\n", + " BatchNorm2d-2 [-1, 32, 15, 15] 64\n", + " SELU-3 [-1, 32, 15, 15] 0\n", + " MaxPool2d-4 [-1, 32, 8, 8] 0\n", + " Conv2dAuto-5 [-1, 32, 8, 8] 9,216\n", + " BatchNorm2d-6 [-1, 32, 8, 8] 64\n", + " SELU-7 [-1, 32, 8, 8] 0\n", + " SELU-8 [-1, 32, 8, 8] 0\n", + " Conv2dAuto-9 [-1, 32, 8, 8] 9,216\n", + " BatchNorm2d-10 [-1, 32, 8, 8] 64\n", + " SELU-11 [-1, 32, 8, 8] 0\n", + " SELU-12 [-1, 32, 8, 8] 0\n", + " BasicBlock-13 [-1, 32, 8, 8] 0\n", + " Conv2dAuto-14 [-1, 32, 8, 8] 9,216\n", + " BatchNorm2d-15 [-1, 32, 8, 8] 64\n", + " SELU-16 [-1, 32, 8, 8] 0\n", + " SELU-17 [-1, 32, 8, 8] 0\n", + " Conv2dAuto-18 [-1, 32, 8, 8] 9,216\n", + " BatchNorm2d-19 [-1, 32, 8, 8] 64\n", + " SELU-20 [-1, 32, 8, 8] 0\n", + " SELU-21 [-1, 32, 8, 8] 0\n", + " BasicBlock-22 [-1, 32, 8, 8] 0\n", + " ResidualLayer-23 [-1, 32, 8, 8] 0\n", + " Conv2d-24 [-1, 64, 4, 4] 2,048\n", + " BatchNorm2d-25 [-1, 64, 4, 4] 128\n", + " Conv2dAuto-26 [-1, 64, 4, 4] 18,432\n", + " BatchNorm2d-27 [-1, 64, 4, 4] 128\n", + " SELU-28 [-1, 64, 4, 4] 0\n", + " SELU-29 [-1, 64, 4, 4] 0\n", + " Conv2dAuto-30 [-1, 64, 4, 4] 36,864\n", + " BatchNorm2d-31 [-1, 64, 4, 4] 128\n", + " SELU-32 [-1, 64, 4, 4] 0\n", + " SELU-33 [-1, 64, 4, 4] 0\n", + " BasicBlock-34 [-1, 64, 4, 4] 0\n", + " Conv2dAuto-35 [-1, 64, 4, 4] 36,864\n", + " BatchNorm2d-36 [-1, 64, 4, 4] 128\n", + " SELU-37 [-1, 64, 4, 4] 0\n", + " SELU-38 [-1, 64, 4, 4] 0\n", + " Conv2dAuto-39 [-1, 64, 4, 4] 36,864\n", + " BatchNorm2d-40 [-1, 64, 4, 4] 128\n", + " SELU-41 [-1, 64, 4, 4] 0\n", + " SELU-42 [-1, 64, 4, 4] 0\n", + " BasicBlock-43 [-1, 64, 4, 4] 0\n", + " ResidualLayer-44 [-1, 64, 4, 4] 0\n", + " Encoder-45 [-1, 64, 4, 4] 0\n", + " Reduce-46 [-1, 64] 0\n", + " Linear-47 [-1, 80] 5,200\n", + " Decoder-48 [-1, 80] 0\n", + "================================================================\n", + "Total params: 174,896\n", + "Trainable params: 174,896\n", + "Non-trainable params: 0\n", + "----------------------------------------------------------------\n", + "Input size (MB): 0.00\n", + "Forward/backward pass size (MB): 0.65\n", + "Params size (MB): 0.67\n", + "Estimated Total Size (MB): 1.32\n", + "----------------------------------------------------------------\n" + ] + } + ], + "source": [ + "summary(resnet, (1, 28, 28), device=\"cpu\")" + ] + }, + { + "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], diff --git a/src/notebooks/01-look-at-emnist.ipynb b/src/notebooks/01-look-at-emnist.ipynb index a68b418..8648afb 100644 --- a/src/notebooks/01-look-at-emnist.ipynb +++ b/src/notebooks/01-look-at-emnist.ipynb @@ -31,12 +31,31 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ - "dataset = EmnistDataset()\n", - "dataset.load_emnist_dataset()" + "dataset = EmnistDataset()" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Tensor" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "type(dataset.data)" ] }, { diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py index 96f84e5..49ebad3 100644 --- a/src/text_recognizer/datasets/emnist_dataset.py +++ b/src/text_recognizer/datasets/emnist_dataset.py @@ -8,6 +8,7 @@ from loguru import logger import numpy as np from PIL import Image import torch +from torch import Tensor from torch.utils.data import DataLoader, Dataset from torchvision.datasets import EMNIST from torchvision.transforms import Compose, Normalize, ToTensor @@ -183,12 +184,8 @@ class EmnistDataset(Dataset): self.input_shape = self._mapper.input_shape self.num_classes = self._mapper.num_classes - # Placeholders - self.data = None - self.targets = None - # Load dataset. - self.load_emnist_dataset() + self.data, self.targets = self.load_emnist_dataset() @property def mapper(self) -> EmnistMapper: @@ -199,9 +196,7 @@ class EmnistDataset(Dataset): """Returns the length of the dataset.""" return len(self.data) - def __getitem__( - self, index: Union[int, torch.Tensor] - ) -> Tuple[torch.Tensor, torch.Tensor]: + def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]: """Fetches samples from the dataset. Args: @@ -239,11 +234,13 @@ class EmnistDataset(Dataset): f"Mapping: {self.mapper.mapping}\n" ) - def _sample_to_balance(self) -> None: + def _sample_to_balance( + self, data: Tensor, targets: Tensor + ) -> Tuple[np.ndarray, np.ndarray]: """Because the dataset is not balanced, we take at most the mean number of instances per class.""" np.random.seed(self.seed) - x = self.data - y = self.targets + x = data + y = targets num_to_sample = int(np.bincount(y.flatten()).mean()) all_sampled_indices = [] for label in np.unique(y.flatten()): @@ -253,20 +250,22 @@ class EmnistDataset(Dataset): indices = np.concatenate(all_sampled_indices) x_sampled = x[indices] y_sampled = y[indices] - self.data = x_sampled - self.targets = y_sampled + data = x_sampled + targets = y_sampled + return data, targets - def _subsample(self) -> None: + def _subsample(self, data: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]: """Subsamples the dataset to the specified fraction.""" - x = self.data - y = self.targets + x = data + y = targets num_samples = int(x.shape[0] * self.subsample_fraction) x_sampled = x[:num_samples] y_sampled = y[:num_samples] self.data = x_sampled self.targets = y_sampled + return data, targets - def load_emnist_dataset(self) -> None: + def load_emnist_dataset(self) -> Tuple[Tensor, Tensor]: """Fetch the EMNIST dataset.""" dataset = EMNIST( root=DATA_DIRNAME, @@ -277,11 +276,13 @@ class EmnistDataset(Dataset): target_transform=None, ) - self.data = dataset.data - self.targets = dataset.targets + data = dataset.data + targets = dataset.targets if self.sample_to_balance: - self._sample_to_balance() + data, targets = self._sample_to_balance(data, targets) if self.subsample_fraction is not None: - self._subsample() + data, targets = self._subsample(data, targets) + + return data, targets diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py index d64a991..b0617f5 100644 --- a/src/text_recognizer/datasets/emnist_lines_dataset.py +++ b/src/text_recognizer/datasets/emnist_lines_dataset.py @@ -8,6 +8,7 @@ import h5py from loguru import logger import numpy as np import torch +from torch import Tensor from torch.utils.data import DataLoader, Dataset from torchvision.transforms import Compose, Normalize, ToTensor @@ -87,16 +88,14 @@ class EmnistLinesDataset(Dataset): """Returns the length of the dataset.""" return len(self.data) - def __getitem__( - self, index: Union[int, torch.Tensor] - ) -> Tuple[torch.Tensor, torch.Tensor]: + def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]: """Fetches data, target pair of the dataset for a given and index or indices. Args: - index (Union[int, torch.Tensor]): Either a list or int of indices/index. + index (Union[int, Tensor]): Either a list or int of indices/index. Returns: - Tuple[torch.Tensor, torch.Tensor]: Data target pair. + Tuple[Tensor, Tensor]: Data target pair. """ if torch.is_tensor(index): diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py index 6d40b49..74fd223 100644 --- a/src/text_recognizer/models/base.py +++ b/src/text_recognizer/models/base.py @@ -53,8 +53,8 @@ class Model(ABC): """ - # Fetch data loaders and dataset info. - dataset_name, self._data_loaders, self._mapper = self._load_data_loader( + # Configure data loaders and dataset info. + dataset_name, self._data_loaders, self._mapper = self._configure_data_loader( data_loader_args ) self._input_shape = self._mapper.input_shape @@ -70,16 +70,19 @@ class Model(ABC): else: self._device = device - # Load network. - self._network, self._network_args = self._load_network(network_fn, network_args) + # Configure network. + self._network, self._network_args = self._configure_network( + network_fn, network_args + ) # To device. self._network.to(self._device) - # Set training objects. - self._criterion = self._load_criterion(criterion, criterion_args) - self._optimizer = self._load_optimizer(optimizer, optimizer_args) - self._lr_scheduler = self._load_lr_scheduler(lr_scheduler, lr_scheduler_args) + # Configure training objects. + self._criterion = self._configure_criterion(criterion, criterion_args) + self._optimizer, self._lr_scheduler = self._configure_optimizers( + optimizer, optimizer_args, lr_scheduler, lr_scheduler_args + ) # Experiment directory. self.model_dir = None @@ -87,7 +90,7 @@ class Model(ABC): # Flag for stopping training. self.stop_training = False - def _load_data_loader( + def _configure_data_loader( self, data_loader_args: Optional[Dict] ) -> Tuple[str, Dict, EmnistMapper]: """Loads data loader, dataset name, and dataset mapper.""" @@ -102,7 +105,7 @@ class Model(ABC): data_loaders = None return dataset_name, data_loaders, mapper - def _load_network( + def _configure_network( self, network_fn: Type[nn.Module], network_args: Optional[Dict] ) -> Tuple[Type[nn.Module], Dict]: """Loads the network.""" @@ -113,7 +116,7 @@ class Model(ABC): network = network_fn(**network_args) return network, network_args - def _load_criterion( + def _configure_criterion( self, criterion: Optional[Callable], criterion_args: Optional[Dict] ) -> Optional[Callable]: """Loads the criterion.""" @@ -123,27 +126,27 @@ class Model(ABC): _criterion = None return _criterion - def _load_optimizer( - self, optimizer: Optional[Callable], optimizer_args: Optional[Dict] - ) -> Optional[Callable]: - """Loads the optimizer.""" + def _configure_optimizers( + self, + optimizer: Optional[Callable], + optimizer_args: Optional[Dict], + lr_scheduler: Optional[Callable], + lr_scheduler_args: Optional[Dict], + ) -> Tuple[Optional[Callable], Optional[Callable]]: + """Loads the optimizers.""" if optimizer is not None: _optimizer = optimizer(self._network.parameters(), **optimizer_args) else: _optimizer = None - return _optimizer - def _load_lr_scheduler( - self, lr_scheduler: Optional[Callable], lr_scheduler_args: Optional[Dict] - ) -> Optional[Callable]: - """Loads learning rate scheduler.""" if self._optimizer and lr_scheduler is not None: if "OneCycleLR" in str(lr_scheduler): lr_scheduler_args["steps_per_epoch"] = len(self._data_loaders["train"]) _lr_scheduler = lr_scheduler(self._optimizer, **lr_scheduler_args) else: _lr_scheduler = None - return _lr_scheduler + + return _optimizer, _lr_scheduler @property def __name__(self) -> str: diff --git a/src/text_recognizer/models/character_model.py b/src/text_recognizer/models/character_model.py index 0a0ab2d..0fd7afd 100644 --- a/src/text_recognizer/models/character_model.py +++ b/src/text_recognizer/models/character_model.py @@ -44,6 +44,7 @@ class CharacterModel(Model): self.tensor_transform = ToTensor() self.softmax = nn.Softmax(dim=0) + @torch.no_grad() def predict_on_image( self, image: Union[np.ndarray, torch.Tensor] ) -> Tuple[str, float]: @@ -64,10 +65,9 @@ class CharacterModel(Model): # If the image is an unscaled tensor. image = image.type("torch.FloatTensor") / 255 - with torch.no_grad(): - # Put the image tensor on the device the model weights are on. - image = image.to(self.device) - logits = self.network(image) + # Put the image tensor on the device the model weights are on. + image = image.to(self.device) + logits = self.network(image) prediction = self.softmax(logits.data.squeeze()) diff --git a/src/text_recognizer/networks/__init__.py b/src/text_recognizer/networks/__init__.py index e6b6946..a83ca35 100644 --- a/src/text_recognizer/networks/__init__.py +++ b/src/text_recognizer/networks/__init__.py @@ -1,5 +1,6 @@ """Network modules.""" from .lenet import LeNet from .mlp import MLP +from .residual_network import ResidualNetwork -__all__ = ["MLP", "LeNet"] +__all__ = ["MLP", "LeNet", "ResidualNetwork"] diff --git a/src/text_recognizer/networks/lenet.py b/src/text_recognizer/networks/lenet.py index cbc58fc..91d3f2c 100644 --- a/src/text_recognizer/networks/lenet.py +++ b/src/text_recognizer/networks/lenet.py @@ -5,6 +5,8 @@ from einops.layers.torch import Rearrange import torch from torch import nn +from text_recognizer.networks.misc import activation_function + class LeNet(nn.Module): """LeNet network.""" @@ -16,8 +18,7 @@ class LeNet(nn.Module): hidden_size: Tuple[int, ...] = (9216, 128), dropout_rate: float = 0.2, output_size: int = 10, - activation_fn: Optional[Callable] = None, - activation_fn_args: Optional[Dict] = None, + activation_fn: Optional[str] = "relu", ) -> None: """The LeNet network. @@ -28,18 +29,12 @@ class LeNet(nn.Module): Defaults to (9216, 128). dropout_rate (float): The dropout rate. Defaults to 0.2. output_size (int): Number of classes. Defaults to 10. - activation_fn (Optional[Callable]): The non-linear activation function. Defaults to - nn.ReLU(inplace). - activation_fn_args (Optional[Dict]): The arguments for the activation function. Defaults to None. + activation_fn (Optional[str]): The name of non-linear activation function. Defaults to relu. """ super().__init__() - if activation_fn is not None: - activation_fn_args = activation_fn_args or {} - activation_fn = getattr(nn, activation_fn)(**activation_fn_args) - else: - activation_fn = nn.ReLU(inplace=True) + activation_fn = activation_function(activation_fn) self.layers = [ nn.Conv2d( @@ -66,7 +61,7 @@ class LeNet(nn.Module): self.layers = nn.Sequential(*self.layers) def forward(self, x: torch.Tensor) -> torch.Tensor: - """The feedforward.""" + """The feedforward pass.""" # If batch dimenstion is missing, it needs to be added. if len(x.shape) == 3: x = x.unsqueeze(0) diff --git a/src/text_recognizer/networks/misc.py b/src/text_recognizer/networks/misc.py index 2fbab8f..6f61b5d 100644 --- a/src/text_recognizer/networks/misc.py +++ b/src/text_recognizer/networks/misc.py @@ -1,9 +1,9 @@ """Miscellaneous neural network functionality.""" -from typing import Tuple +from typing import Tuple, Type from einops import rearrange import torch -from torch.nn import Unfold +from torch import nn def sliding_window( @@ -20,10 +20,24 @@ def sliding_window( torch.Tensor: A tensor with the shape (batch, patches, height, width). """ - unfold = Unfold(kernel_size=patch_size, stride=stride) + unfold = nn.Unfold(kernel_size=patch_size, stride=stride) # Preform the slidning window, unsqueeze as the channel dimesion is lost. patches = unfold(images).unsqueeze(1) patches = rearrange( patches, "b c (h w) t -> b t c h w", h=patch_size[0], w=patch_size[1] ) return patches + + +def activation_function(activation: str) -> Type[nn.Module]: + """Returns the callable activation function.""" + activation_fns = nn.ModuleDict( + [ + ["gelu", nn.GELU()], + ["leaky_relu", nn.LeakyReLU(negative_slope=1.0e-2, inplace=True)], + ["none", nn.Identity()], + ["relu", nn.ReLU(inplace=True)], + ["selu", nn.SELU(inplace=True)], + ] + ) + return activation_fns[activation.lower()] diff --git a/src/text_recognizer/networks/mlp.py b/src/text_recognizer/networks/mlp.py index ac2c825..acebdaa 100644 --- a/src/text_recognizer/networks/mlp.py +++ b/src/text_recognizer/networks/mlp.py @@ -5,6 +5,8 @@ from einops.layers.torch import Rearrange import torch from torch import nn +from text_recognizer.networks.misc import activation_function + class MLP(nn.Module): """Multi layered perceptron network.""" @@ -16,8 +18,7 @@ class MLP(nn.Module): hidden_size: Union[int, List] = 128, num_layers: int = 3, dropout_rate: float = 0.2, - activation_fn: Optional[Callable] = None, - activation_fn_args: Optional[Dict] = None, + activation_fn: str = "relu", ) -> None: """Initialization of the MLP network. @@ -27,18 +28,13 @@ class MLP(nn.Module): hidden_size (Union[int, List]): The number of `neurons` in each hidden layer. Defaults to 128. num_layers (int): The number of hidden layers. Defaults to 3. dropout_rate (float): The dropout rate at each layer. Defaults to 0.2. - activation_fn (Optional[Callable]): The activation function in the hidden layers. Defaults to - None. - activation_fn_args (Optional[Dict]): The arguments for the activation function. Defaults to None. + activation_fn (str): Name of the activation function in the hidden layers. Defaults to + relu. """ super().__init__() - if activation_fn is not None: - activation_fn_args = activation_fn_args or {} - activation_fn = getattr(nn, activation_fn)(**activation_fn_args) - else: - activation_fn = nn.ReLU(inplace=True) + activation_fn = activation_function(activation_fn) if isinstance(hidden_size, int): hidden_size = [hidden_size] * num_layers @@ -65,7 +61,7 @@ class MLP(nn.Module): self.layers = nn.Sequential(*self.layers) def forward(self, x: torch.Tensor) -> torch.Tensor: - """The feedforward.""" + """The feedforward pass.""" # If batch dimenstion is missing, it needs to be added. if len(x.shape) == 3: x = x.unsqueeze(0) diff --git a/src/text_recognizer/networks/residual_network.py b/src/text_recognizer/networks/residual_network.py index 23394b0..47e351a 100644 --- a/src/text_recognizer/networks/residual_network.py +++ b/src/text_recognizer/networks/residual_network.py @@ -1 +1,315 @@ """Residual CNN.""" +from functools import partial +from typing import Callable, Dict, List, Optional, Type, Union + +from einops.layers.torch import Rearrange, Reduce +import torch +from torch import nn +from torch import Tensor + +from text_recognizer.networks.misc import activation_function + + +class Conv2dAuto(nn.Conv2d): + """Convolution with auto padding based on kernel size.""" + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.padding = (self.kernel_size[0] // 2, self.kernel_size[1] // 2) + + +def conv_bn(in_channels: int, out_channels: int, *args, **kwargs) -> nn.Sequential: + """3x3 convolution with batch norm.""" + conv3x3 = partial(Conv2dAuto, kernel_size=3, bias=False,) + return nn.Sequential( + conv3x3(in_channels, out_channels, *args, **kwargs), + nn.BatchNorm2d(out_channels), + ) + + +class IdentityBlock(nn.Module): + """Residual with identity block.""" + + def __init__( + self, in_channels: int, out_channels: int, activation: str = "relu" + ) -> None: + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.blocks = nn.Identity() + self.activation_fn = activation_function(activation) + self.shortcut = nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + """Forward pass.""" + residual = x + if self.apply_shortcut: + residual = self.shortcut(x) + x = self.blocks(x) + x += residual + x = self.activation_fn(x) + return x + + @property + def apply_shortcut(self) -> bool: + """Check if shortcut should be applied.""" + return self.in_channels != self.out_channels + + +class ResidualBlock(IdentityBlock): + """Residual with nonlinear shortcut.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + expansion: int = 1, + downsampling: int = 1, + *args, + **kwargs + ) -> None: + """Short summary. + + Args: + in_channels (int): Number of in channels. + out_channels (int): umber of out channels. + expansion (int): Expansion factor of the out channels. Defaults to 1. + downsampling (int): Downsampling factor used in stride. Defaults to 1. + *args (type): Extra arguments. + **kwargs (type): Extra key value arguments. + + """ + super().__init__(in_channels, out_channels, *args, **kwargs) + self.expansion = expansion + self.downsampling = downsampling + + self.shortcut = ( + nn.Sequential( + nn.Conv2d( + in_channels=self.in_channels, + out_channels=self.expanded_channels, + kernel_size=1, + stride=self.downsampling, + bias=False, + ), + nn.BatchNorm2d(self.expanded_channels), + ) + if self.apply_shortcut + else None + ) + + @property + def expanded_channels(self) -> int: + """Computes the expanded output channels.""" + return self.out_channels * self.expansion + + @property + def apply_shortcut(self) -> bool: + """Check if shortcut should be applied.""" + return self.in_channels != self.expanded_channels + + +class BasicBlock(ResidualBlock): + """Basic ResNet block.""" + + expansion = 1 + + def __init__(self, in_channels: int, out_channels: int, *args, **kwargs) -> None: + super().__init__(in_channels, out_channels, *args, **kwargs) + self.blocks = nn.Sequential( + conv_bn( + in_channels=self.in_channels, + out_channels=self.out_channels, + bias=False, + stride=self.downsampling, + ), + self.activation_fn, + conv_bn( + in_channels=self.out_channels, + out_channels=self.expanded_channels, + bias=False, + ), + ) + + +class BottleNeckBlock(ResidualBlock): + """Bottleneck block to increase depth while minimizing parameter size.""" + + expansion = 4 + + def __init__(self, in_channels: int, out_channels: int, *args, **kwargs) -> None: + super().__init__(in_channels, out_channels, *args, **kwargs) + self.blocks = nn.Sequential( + conv_bn( + in_channels=self.in_channels, + out_channels=self.out_channels, + kernel_size=1, + ), + self.activation_fn, + conv_bn( + in_channels=self.out_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=self.downsampling, + ), + self.activation_fn, + conv_bn( + in_channels=self.out_channels, + out_channels=self.expanded_channels, + kernel_size=1, + ), + ) + + +class ResidualLayer(nn.Module): + """ResNet layer.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + block: BasicBlock = BasicBlock, + num_blocks: int = 1, + *args, + **kwargs + ) -> None: + super().__init__() + downsampling = 2 if in_channels != out_channels else 1 + self.blocks = nn.Sequential( + block( + in_channels, out_channels, *args, **kwargs, downsampling=downsampling + ), + *[ + block( + out_channels * block.expansion, + out_channels, + downsampling=1, + *args, + **kwargs + ) + for _ in range(num_blocks - 1) + ] + ) + + def forward(self, x: Tensor) -> Tensor: + """Forward pass.""" + x = self.blocks(x) + return x + + +class Encoder(nn.Module): + """Encoder network.""" + + def __init__( + self, + in_channels: int = 1, + block_sizes: List[int] = (32, 64), + depths: List[int] = (2, 2), + activation: str = "relu", + block: Type[nn.Module] = BasicBlock, + *args, + **kwargs + ) -> None: + super().__init__() + + self.block_sizes = block_sizes + self.depths = depths + self.activation = activation + + self.gate = nn.Sequential( + nn.Conv2d( + in_channels=in_channels, + out_channels=self.block_sizes[0], + kernel_size=3, + stride=2, + padding=3, + bias=False, + ), + nn.BatchNorm2d(self.block_sizes[0]), + activation_function(self.activation), + nn.MaxPool2d(kernel_size=3, stride=2, padding=1), + ) + + self.blocks = self._configure_blocks(block) + + def _configure_blocks( + self, block: Type[nn.Module], *args, **kwargs + ) -> nn.Sequential: + channels = [self.block_sizes[0]] + list( + zip(self.block_sizes, self.block_sizes[1:]) + ) + blocks = [ + ResidualLayer( + in_channels=channels[0], + out_channels=channels[0], + num_blocks=self.depths[0], + block=block, + activation=self.activation, + *args, + **kwargs + ) + ] + blocks += [ + ResidualLayer( + in_channels=in_channels * block.expansion, + out_channels=out_channels, + num_blocks=num_blocks, + block=block, + activation=self.activation, + *args, + **kwargs + ) + for (in_channels, out_channels), num_blocks in zip( + channels[1:], self.depths[1:] + ) + ] + + return nn.Sequential(*blocks) + + def forward(self, x: Tensor) -> Tensor: + """Forward pass.""" + # If batch dimenstion is missing, it needs to be added. + if len(x.shape) == 3: + x = x.unsqueeze(0) + x = self.gate(x) + return self.blocks(x) + + +class Decoder(nn.Module): + """Classification head.""" + + def __init__(self, in_features: int, num_classes: int = 80) -> None: + super().__init__() + self.decoder = nn.Sequential( + Reduce("b c h w -> b c", "mean"), + nn.Linear(in_features=in_features, out_features=num_classes), + ) + + def forward(self, x: Tensor) -> Tensor: + """Forward pass.""" + return self.decoder(x) + + +class ResidualNetwork(nn.Module): + """Full residual network.""" + + def __init__(self, in_channels: int, num_classes: int, *args, **kwargs) -> None: + super().__init__() + self.encoder = Encoder(in_channels, *args, **kwargs) + self.decoder = Decoder( + in_features=self.encoder.blocks[-1].blocks[-1].expanded_channels, + num_classes=num_classes, + ) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def forward(self, x: Tensor) -> Tensor: + """Forward pass.""" + x = self.encoder(x) + x = self.decoder(x) + return x diff --git a/src/text_recognizer/weights/CharacterModel_EmnistDataset_LeNet_weights.pt b/src/text_recognizer/weights/CharacterModel_EmnistDataset_LeNet_weights.pt Binary files differindex 81ef9be..676eb44 100644 --- a/src/text_recognizer/weights/CharacterModel_EmnistDataset_LeNet_weights.pt +++ b/src/text_recognizer/weights/CharacterModel_EmnistDataset_LeNet_weights.pt diff --git a/src/text_recognizer/weights/CharacterModel_EmnistDataset_MLP_weights.pt b/src/text_recognizer/weights/CharacterModel_EmnistDataset_MLP_weights.pt Binary files differindex 49bd166..86cf103 100644 --- a/src/text_recognizer/weights/CharacterModel_EmnistDataset_MLP_weights.pt +++ b/src/text_recognizer/weights/CharacterModel_EmnistDataset_MLP_weights.pt diff --git a/src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetwork_weights.pt b/src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetwork_weights.pt Binary files differnew file mode 100644 index 0000000..008beb2 --- /dev/null +++ b/src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetwork_weights.pt diff --git a/src/training/experiments/sample_experiment.yml b/src/training/experiments/sample_experiment.yml index 355305c..bae02ac 100644 --- a/src/training/experiments/sample_experiment.yml +++ b/src/training/experiments/sample_experiment.yml @@ -9,25 +9,32 @@ experiments: seed: 4711 data_loader_args: splits: [train, val] - batch_size: 256 shuffle: true num_workers: 8 cuda: true model: CharacterModel metrics: [accuracy] - network: MLP + # network: MLP + # network_args: + # input_size: 784 + # hidden_size: 512 + # output_size: 80 + # num_layers: 3 + # dropout_rate: 0 + # activation_fn: SELU + network: ResidualNetwork network_args: - input_size: 784 - output_size: 62 - num_layers: 3 - activation_fn: GELU + in_channels: 1 + num_classes: 80 + depths: [1, 1] + block_sizes: [128, 256] # network: LeNet # network_args: # output_size: 62 # activation_fn: GELU train_args: batch_size: 256 - epochs: 16 + epochs: 32 criterion: CrossEntropyLoss criterion_args: weight: null @@ -43,20 +50,24 @@ experiments: # centered: false optimizer: AdamW optimizer_args: - lr: 1.e-2 + lr: 1.e-03 betas: [0.9, 0.999] eps: 1.e-08 - weight_decay: 0 + # weight_decay: 5.e-4 amsgrad: false # lr_scheduler: null lr_scheduler: OneCycleLR lr_scheduler_args: - max_lr: 1.e-3 - epochs: 16 - callbacks: [Checkpoint, EarlyStopping, WandbCallback, WandbImageLogger, OneCycleLR] + max_lr: 1.e-03 + epochs: 32 + anneal_strategy: linear + callbacks: [Checkpoint, ProgressBar, EarlyStopping, WandbCallback, WandbImageLogger, OneCycleLR] callback_args: Checkpoint: monitor: val_accuracy + ProgressBar: + epochs: 32 + log_batch_frequency: 100 EarlyStopping: monitor: val_loss min_delta: 0.0 @@ -68,5 +79,5 @@ experiments: num_examples: 4 OneCycleLR: null - verbosity: 2 # 0, 1, 2 + verbosity: 1 # 0, 1, 2 resume_experiment: null diff --git a/src/training/prepare_experiments.py b/src/training/prepare_experiments.py index 97c0304..4c3f9ba 100644 --- a/src/training/prepare_experiments.py +++ b/src/training/prepare_experiments.py @@ -7,11 +7,11 @@ from loguru import logger import yaml -# flake8: noqa: S404,S607,S603 def run_experiments(experiments_filename: str) -> None: """Run experiment from file.""" with open(experiments_filename) as f: experiments_config = yaml.safe_load(f) + num_experiments = len(experiments_config["experiments"]) for index in range(num_experiments): experiment_config = experiments_config["experiments"][index] @@ -27,10 +27,10 @@ def run_experiments(experiments_filename: str) -> None: type=str, help="Filename of Yaml file of experiments to run.", ) -def main(experiments_filename: str) -> None: +def run_cli(experiments_filename: str) -> None: """Parse command-line arguments and run experiments from provided file.""" run_experiments(experiments_filename) if __name__ == "__main__": - main() + run_cli() diff --git a/src/training/run_experiment.py b/src/training/run_experiment.py index d278dc2..8c063ff 100644 --- a/src/training/run_experiment.py +++ b/src/training/run_experiment.py @@ -6,18 +6,20 @@ import json import os from pathlib import Path import re -from typing import Callable, Dict, Tuple +from typing import Callable, Dict, Tuple, Type import click from loguru import logger import torch from tqdm import tqdm -from training.callbacks import CallbackList from training.gpu_manager import GPUManager -from training.train import Trainer +from training.trainer.callbacks import CallbackList +from training.trainer.train import Trainer import wandb import yaml +from text_recognizer.models import Model + EXPERIMENTS_DIRNAME = Path(__file__).parents[0].resolve() / "experiments" @@ -35,7 +37,7 @@ def get_level(experiment_config: Dict) -> int: return 10 -def create_experiment_dir(model: Callable, experiment_config: Dict) -> Path: +def create_experiment_dir(model: Type[Model], experiment_config: Dict) -> Path: """Create new experiment.""" EXPERIMENTS_DIRNAME.mkdir(parents=True, exist_ok=True) experiment_dir = EXPERIMENTS_DIRNAME / model.__name__ @@ -67,6 +69,8 @@ def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict] """Loads all modules and arguments.""" # Import the data loader arguments. data_loader_args = experiment_config.get("data_loader_args", {}) + train_args = experiment_config.get("train_args", {}) + data_loader_args["batch_size"] = train_args["batch_size"] data_loader_args["dataset"] = experiment_config["dataset"] data_loader_args["dataset_args"] = experiment_config.get("dataset_args", {}) @@ -94,7 +98,7 @@ def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict] optimizer_args = experiment_config.get("optimizer_args", {}) # Callbacks - callback_modules = importlib.import_module("training.callbacks") + callback_modules = importlib.import_module("training.trainer.callbacks") callbacks = [ getattr(callback_modules, callback)( **check_args(experiment_config["callback_args"][callback]) @@ -208,6 +212,7 @@ def run_experiment( with open(str(config_path), "w") as f: yaml.dump(experiment_config, f) + # Train the model. trainer = Trainer( model=model, model_dir=model_dir, @@ -247,7 +252,7 @@ def run_experiment( @click.option( "--nowandb", is_flag=False, help="If true, do not use wandb for this run." ) -def main(experiment_config: str, gpu: int, save: bool, nowandb: bool) -> None: +def run_cli(experiment_config: str, gpu: int, save: bool, nowandb: bool) -> None: """Run experiment.""" if gpu < 0: gpu_manager = GPUManager(True) @@ -260,4 +265,4 @@ def main(experiment_config: str, gpu: int, save: bool, nowandb: bool) -> None: if __name__ == "__main__": - main() + run_cli() diff --git a/src/training/trainer/__init__.py b/src/training/trainer/__init__.py new file mode 100644 index 0000000..de41bfb --- /dev/null +++ b/src/training/trainer/__init__.py @@ -0,0 +1,2 @@ +"""Trainer modules.""" +from .train import Trainer diff --git a/src/training/callbacks/__init__.py b/src/training/trainer/callbacks/__init__.py index fbcc285..5942276 100644 --- a/src/training/callbacks/__init__.py +++ b/src/training/trainer/callbacks/__init__.py @@ -2,6 +2,7 @@ from .base import Callback, CallbackList, Checkpoint from .early_stopping import EarlyStopping from .lr_schedulers import CyclicLR, MultiStepLR, OneCycleLR, ReduceLROnPlateau, StepLR +from .progress_bar import ProgressBar from .wandb_callbacks import WandbCallback, WandbImageLogger __all__ = [ @@ -14,6 +15,7 @@ __all__ = [ "CyclicLR", "MultiStepLR", "OneCycleLR", + "ProgressBar", "ReduceLROnPlateau", "StepLR", ] diff --git a/src/training/callbacks/base.py b/src/training/trainer/callbacks/base.py index e0d91e6..8df94f3 100644 --- a/src/training/callbacks/base.py +++ b/src/training/trainer/callbacks/base.py @@ -1,7 +1,7 @@ """Metaclass for callback functions.""" from enum import Enum -from typing import Callable, Dict, List, Type, Union +from typing import Callable, Dict, List, Optional, Type, Union from loguru import logger import numpy as np @@ -36,27 +36,29 @@ class Callback: """Called when fit ends.""" pass - def on_epoch_begin(self, epoch: int, logs: Dict = {}) -> None: + def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None) -> None: """Called at the beginning of an epoch. Only used in training mode.""" pass - def on_epoch_end(self, epoch: int, logs: Dict = {}) -> None: + def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: """Called at the end of an epoch. Only used in training mode.""" pass - def on_train_batch_begin(self, batch: int, logs: Dict = {}) -> None: + def on_train_batch_begin(self, batch: int, logs: Optional[Dict] = None) -> None: """Called at the beginning of an epoch.""" pass - def on_train_batch_end(self, batch: int, logs: Dict = {}) -> None: + def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: """Called at the end of an epoch.""" pass - def on_validation_batch_begin(self, batch: int, logs: Dict = {}) -> None: + def on_validation_batch_begin( + self, batch: int, logs: Optional[Dict] = None + ) -> None: """Called at the beginning of an epoch.""" pass - def on_validation_batch_end(self, batch: int, logs: Dict = {}) -> None: + def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: """Called at the end of an epoch.""" pass @@ -102,18 +104,18 @@ class CallbackList: for callback in self._callbacks: callback.on_fit_end() - def on_epoch_begin(self, epoch: int, logs: Dict = {}) -> None: + def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None) -> None: """Called at the beginning of an epoch.""" for callback in self._callbacks: callback.on_epoch_begin(epoch, logs) - def on_epoch_end(self, epoch: int, logs: Dict = {}) -> None: + def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: """Called at the end of an epoch.""" for callback in self._callbacks: callback.on_epoch_end(epoch, logs) def _call_batch_hook( - self, mode: str, hook: str, batch: int, logs: Dict = {} + self, mode: str, hook: str, batch: int, logs: Optional[Dict] = None ) -> None: """Helper function for all batch_{begin | end} methods.""" if hook == "begin": @@ -123,39 +125,45 @@ class CallbackList: else: raise ValueError(f"Unrecognized hook {hook}.") - def _call_batch_begin_hook(self, mode: str, batch: int, logs: Dict = {}) -> None: + def _call_batch_begin_hook( + self, mode: str, batch: int, logs: Optional[Dict] = None + ) -> None: """Helper function for all `on_*_batch_begin` methods.""" hook_name = f"on_{mode}_batch_begin" self._call_batch_hook_helper(hook_name, batch, logs) - def _call_batch_end_hook(self, mode: str, batch: int, logs: Dict = {}) -> None: + def _call_batch_end_hook( + self, mode: str, batch: int, logs: Optional[Dict] = None + ) -> None: """Helper function for all `on_*_batch_end` methods.""" hook_name = f"on_{mode}_batch_end" self._call_batch_hook_helper(hook_name, batch, logs) def _call_batch_hook_helper( - self, hook_name: str, batch: int, logs: Dict = {} + self, hook_name: str, batch: int, logs: Optional[Dict] = None ) -> None: """Helper function for `on_*_batch_begin` methods.""" for callback in self._callbacks: hook = getattr(callback, hook_name) hook(batch, logs) - def on_train_batch_begin(self, batch: int, logs: Dict = {}) -> None: + def on_train_batch_begin(self, batch: int, logs: Optional[Dict] = None) -> None: """Called at the beginning of an epoch.""" - self._call_batch_hook(self.mode_keys.TRAIN, "begin", batch) + self._call_batch_hook(self.mode_keys.TRAIN, "begin", batch, logs) - def on_train_batch_end(self, batch: int, logs: Dict = {}) -> None: + def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: """Called at the end of an epoch.""" - self._call_batch_hook(self.mode_keys.TRAIN, "end", batch) + self._call_batch_hook(self.mode_keys.TRAIN, "end", batch, logs) - def on_validation_batch_begin(self, batch: int, logs: Dict = {}) -> None: + def on_validation_batch_begin( + self, batch: int, logs: Optional[Dict] = None + ) -> None: """Called at the beginning of an epoch.""" - self._call_batch_hook(self.mode_keys.VALIDATION, "begin", batch) + self._call_batch_hook(self.mode_keys.VALIDATION, "begin", batch, logs) - def on_validation_batch_end(self, batch: int, logs: Dict = {}) -> None: + def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: """Called at the end of an epoch.""" - self._call_batch_hook(self.mode_keys.VALIDATION, "end", batch) + self._call_batch_hook(self.mode_keys.VALIDATION, "end", batch, logs) def __iter__(self) -> iter: """Iter function for callback list.""" diff --git a/src/training/callbacks/early_stopping.py b/src/training/trainer/callbacks/early_stopping.py index c9b7907..02b431f 100644 --- a/src/training/callbacks/early_stopping.py +++ b/src/training/trainer/callbacks/early_stopping.py @@ -4,7 +4,8 @@ from typing import Dict, Union from loguru import logger import numpy as np import torch -from training.callbacks import Callback +from torch import Tensor +from training.trainer.callbacks import Callback class EarlyStopping(Callback): @@ -95,7 +96,7 @@ class EarlyStopping(Callback): f"Stopped training at epoch {self.stopped_epoch + 1} with early stopping." ) - def get_monitor_value(self, logs: Dict) -> Union[torch.Tensor, None]: + def get_monitor_value(self, logs: Dict) -> Union[Tensor, None]: """Extracts the monitor value.""" monitor_value = logs.get(self.monitor) if monitor_value is None: diff --git a/src/training/callbacks/lr_schedulers.py b/src/training/trainer/callbacks/lr_schedulers.py index 00c7e9b..ba2226a 100644 --- a/src/training/callbacks/lr_schedulers.py +++ b/src/training/trainer/callbacks/lr_schedulers.py @@ -1,7 +1,7 @@ """Callbacks for learning rate schedulers.""" from typing import Callable, Dict, List, Optional, Type -from training.callbacks import Callback +from training.trainer.callbacks import Callback from text_recognizer.models import Model @@ -19,7 +19,7 @@ class StepLR(Callback): self.model = model self.lr_scheduler = self.model.lr_scheduler - def on_epoch_end(self, epoch: int, logs: Dict = {}) -> None: + def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: """Takes a step at the end of every epoch.""" self.lr_scheduler.step() @@ -37,7 +37,7 @@ class MultiStepLR(Callback): self.model = model self.lr_scheduler = self.model.lr_scheduler - def on_epoch_end(self, epoch: int, logs: Dict = {}) -> None: + def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: """Takes a step at the end of every epoch.""" self.lr_scheduler.step() @@ -55,7 +55,7 @@ class ReduceLROnPlateau(Callback): self.model = model self.lr_scheduler = self.model.lr_scheduler - def on_epoch_end(self, epoch: int, logs: Dict = {}) -> None: + def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: """Takes a step at the end of every epoch.""" val_loss = logs["val_loss"] self.lr_scheduler.step(val_loss) @@ -74,7 +74,7 @@ class CyclicLR(Callback): self.model = model self.lr_scheduler = self.model.lr_scheduler - def on_train_batch_end(self, batch: int, logs: Dict = {}) -> None: + def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: """Takes a step at the end of every training batch.""" self.lr_scheduler.step() @@ -92,6 +92,6 @@ class OneCycleLR(Callback): self.model = model self.lr_scheduler = self.model.lr_scheduler - def on_train_batch_end(self, batch: int, logs: Dict = {}) -> None: + def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: """Takes a step at the end of every training batch.""" self.lr_scheduler.step() diff --git a/src/training/trainer/callbacks/progress_bar.py b/src/training/trainer/callbacks/progress_bar.py new file mode 100644 index 0000000..1970747 --- /dev/null +++ b/src/training/trainer/callbacks/progress_bar.py @@ -0,0 +1,61 @@ +"""Progress bar callback for the training loop.""" +from typing import Dict, Optional + +from tqdm import tqdm +from training.trainer.callbacks import Callback + + +class ProgressBar(Callback): + """A TQDM progress bar for the training loop.""" + + def __init__(self, epochs: int, log_batch_frequency: int = None) -> None: + """Initializes the tqdm callback.""" + self.epochs = epochs + self.log_batch_frequency = log_batch_frequency + self.progress_bar = None + self.val_metrics = {} + + def _configure_progress_bar(self) -> None: + """Configures the tqdm progress bar with custom bar format.""" + self.progress_bar = tqdm( + total=len(self.model.data_loaders["train"]), + leave=True, + unit="step", + mininterval=self.log_batch_frequency, + bar_format="{desc} |{bar:30}| {n_fmt}/{total_fmt} ETA: {remaining} {rate_fmt}{postfix}", + ) + + def _key_abbreviations(self, logs: Dict) -> Dict: + """Changes the length of keys, so that the progress bar fits better.""" + + def rename(key: str) -> str: + """Renames accuracy to acc.""" + return key.replace("accuracy", "acc") + + return {rename(key): value for key, value in logs.items()} + + def on_fit_begin(self) -> None: + """Creates a tqdm progress bar.""" + self._configure_progress_bar() + + def on_epoch_begin(self, epoch: int, logs: Optional[Dict]) -> None: + """Updates the description with the current epoch.""" + self.progress_bar.reset() + self.progress_bar.set_description(f"Epoch {epoch}/{self.epochs}") + + def on_epoch_end(self, epoch: int, logs: Dict) -> None: + """At the end of each epoch, the validation metrics are updated to the progress bar.""" + self.val_metrics = logs + self.progress_bar.set_postfix(**self._key_abbreviations(logs)) + self.progress_bar.update() + + def on_train_batch_end(self, batch: int, logs: Dict) -> None: + """Updates the progress bar for each training step.""" + if self.val_metrics: + logs.update(self.val_metrics) + self.progress_bar.set_postfix(**self._key_abbreviations(logs)) + self.progress_bar.update() + + def on_fit_end(self) -> None: + """Closes the tqdm progress bar.""" + self.progress_bar.close() diff --git a/src/training/callbacks/wandb_callbacks.py b/src/training/trainer/callbacks/wandb_callbacks.py index 6ada6df..e44c745 100644 --- a/src/training/callbacks/wandb_callbacks.py +++ b/src/training/trainer/callbacks/wandb_callbacks.py @@ -1,9 +1,9 @@ -"""Callbacks using wandb.""" +"""Callback for W&B.""" from typing import Callable, Dict, List, Optional, Type import numpy as np from torchvision.transforms import Compose, ToTensor -from training.callbacks import Callback +from training.trainer.callbacks import Callback import wandb from text_recognizer.datasets import Transpose @@ -28,12 +28,12 @@ class WandbCallback(Callback): if self.log_batch_frequency and batch % self.log_batch_frequency == 0: wandb.log(logs, commit=True) - def on_train_batch_end(self, batch: int, logs: Dict = {}) -> None: + def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: """Logs training metrics.""" if logs is not None: self._on_batch_end(batch, logs) - def on_validation_batch_end(self, batch: int, logs: Dict = {}) -> None: + def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: """Logs validation metrics.""" if logs is not None: self._on_batch_end(batch, logs) diff --git a/src/training/population_based_training/__init__.py b/src/training/trainer/population_based_training/__init__.py index 868d739..868d739 100644 --- a/src/training/population_based_training/__init__.py +++ b/src/training/trainer/population_based_training/__init__.py diff --git a/src/training/population_based_training/population_based_training.py b/src/training/trainer/population_based_training/population_based_training.py index 868d739..868d739 100644 --- a/src/training/population_based_training/population_based_training.py +++ b/src/training/trainer/population_based_training/population_based_training.py diff --git a/src/training/train.py b/src/training/trainer/train.py index aaa0430..a75ae8f 100644 --- a/src/training/train.py +++ b/src/training/trainer/train.py @@ -7,9 +7,9 @@ from typing import Dict, List, Optional, Tuple, Type from loguru import logger import numpy as np import torch -from tqdm import tqdm, trange -from training.callbacks import Callback, CallbackList -from training.util import RunningAverage +from torch import Tensor +from training.trainer.callbacks import Callback, CallbackList +from training.trainer.util import RunningAverage import wandb from text_recognizer.models import Model @@ -46,11 +46,11 @@ class Trainer: self.model_dir = model_dir self.checkpoint_path = checkpoint_path self.start_epoch = 1 - self.epochs = train_args["epochs"] + self.start_epoch + self.epochs = train_args["epochs"] self.callbacks = callbacks if self.checkpoint_path is not None: - self.start_epoch = self.model.load_checkpoint(self.checkpoint_path) + 1 + self.start_epoch = self.model.load_checkpoint(self.checkpoint_path) # Parse the name of the experiment. experiment_dir = str(self.model_dir.parents[1]).split("/") @@ -59,7 +59,7 @@ class Trainer: def training_step( self, batch: int, - samples: Tuple[torch.Tensor, torch.Tensor], + samples: Tuple[Tensor, Tensor], loss_avg: Type[RunningAverage], ) -> Dict: """Performs the training step.""" @@ -108,27 +108,16 @@ class Trainer: data_loader = self.model.data_loaders["train"] - with tqdm( - total=len(data_loader), - leave=False, - unit="step", - bar_format="{n_fmt}/{total_fmt} |{bar:30}| {remaining} {rate_inv_fmt}{postfix}", - ) as t: - for batch, samples in enumerate(data_loader): - self.callbacks.on_train_batch_begin(batch) - - metrics = self.training_step(batch, samples, loss_avg) - - self.callbacks.on_train_batch_end(batch, logs=metrics) - - # Update Tqdm progress bar. - t.set_postfix(**metrics) - t.update() + for batch, samples in enumerate(data_loader): + self.callbacks.on_train_batch_begin(batch) + metrics = self.training_step(batch, samples, loss_avg) + self.callbacks.on_train_batch_end(batch, logs=metrics) + @torch.no_grad() def validation_step( self, batch: int, - samples: Tuple[torch.Tensor, torch.Tensor], + samples: Tuple[Tensor, Tensor], loss_avg: Type[RunningAverage], ) -> Dict: """Performs the validation step.""" @@ -158,6 +147,12 @@ class Trainer: return metrics + def _log_val_metric(self, metrics_mean: Dict, epoch: Optional[int] = None) -> None: + log_str = "Validation metrics " + (f"at epoch {epoch} - " if epoch else " - ") + logger.debug( + log_str + " - ".join(f"{k}: {v:.4f}" for k, v in metrics_mean.items()) + ) + def validate(self, epoch: Optional[int] = None) -> Dict: """Runs the validation loop for one epoch.""" # Set model to eval mode. @@ -172,41 +167,18 @@ class Trainer: # Summary for the current eval loop. summary = [] - with tqdm( - total=len(data_loader), - leave=False, - unit="step", - bar_format="{n_fmt}/{total_fmt} |{bar:30}| {remaining} {rate_inv_fmt}{postfix}", - ) as t: - with torch.no_grad(): - for batch, samples in enumerate(data_loader): - self.callbacks.on_validation_batch_begin(batch) - - metrics = self.validation_step(batch, samples, loss_avg) - - self.callbacks.on_validation_batch_end(batch, logs=metrics) - - summary.append(metrics) - - # Update Tqdm progress bar. - t.set_postfix(**metrics) - t.update() + for batch, samples in enumerate(data_loader): + self.callbacks.on_validation_batch_begin(batch) + metrics = self.validation_step(batch, samples, loss_avg) + self.callbacks.on_validation_batch_end(batch, logs=metrics) + summary.append(metrics) # Compute mean of all metrics. metrics_mean = { "val_" + metric: np.mean([x[metric] for x in summary]) for metric in summary[0] } - if epoch: - logger.debug( - f"Validation metrics at epoch {epoch} - " - + " - ".join(f"{k}: {v:.4f}" for k, v in metrics_mean.items()) - ) - else: - logger.debug( - "Validation metrics - " - + " - ".join(f"{k}: {v:.4f}" for k, v in metrics_mean.items()) - ) + self._log_val_metric(metrics_mean, epoch) return metrics_mean @@ -214,19 +186,14 @@ class Trainer: """Runs the training and evaluation loop.""" logger.debug(f"Running an experiment called {self.experiment_name}.") + + # Set start time. t_start = time.time() self.callbacks.on_fit_begin() - # TODO: fix progress bar as callback. # Run the training loop. - for epoch in trange( - self.start_epoch, - self.epochs, - leave=False, - bar_format="{desc}: {n_fmt}/{total_fmt} |{bar:30}| {remaining}{postfix}", - desc="Epoch", - ): + for epoch in range(self.start_epoch, self.epochs + 1): self.callbacks.on_epoch_begin(epoch) # Perform one training pass over the training set. diff --git a/src/training/util.py b/src/training/trainer/util.py index 132b2dc..132b2dc 100644 --- a/src/training/util.py +++ b/src/training/trainer/util.py |