summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--notebooks/00-scratch-pad.ipynb275
-rw-r--r--text_recognizer/callbacks/wandb_callbacks.py95
-rw-r--r--text_recognizer/data/base_data_module.py29
-rw-r--r--text_recognizer/data/emnist.py22
-rw-r--r--text_recognizer/data/emnist_lines.py35
-rw-r--r--text_recognizer/data/iam.py6
-rw-r--r--text_recognizer/data/iam_extended_paragraphs.py33
-rw-r--r--text_recognizer/data/iam_lines.py22
-rw-r--r--text_recognizer/data/iam_paragraphs.py32
-rw-r--r--text_recognizer/models/base.py5
-rw-r--r--text_recognizer/models/metrics.py15
-rw-r--r--text_recognizer/models/transformer.py26
-rw-r--r--text_recognizer/models/vqvae.py34
-rw-r--r--text_recognizer/networks/util.py2
-rw-r--r--training/conf/datamodule/iam_extended_paragraphs.yaml5
-rw-r--r--training/conf/dataset/iam_extended_paragraphs.yaml6
-rw-r--r--training/conf/lr_scheduler/one_cycle.yaml23
-rw-r--r--training/conf/model/lit_vqvae.yaml2
-rw-r--r--training/run.py1
-rw-r--r--training/utils.py2
20 files changed, 468 insertions, 202 deletions
diff --git a/notebooks/00-scratch-pad.ipynb b/notebooks/00-scratch-pad.ipynb
index 16c6533..1e30038 100644
--- a/notebooks/00-scratch-pad.ipynb
+++ b/notebooks/00-scratch-pad.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
@@ -30,7 +30,7 @@
},
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
@@ -39,7 +39,7 @@
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
@@ -48,41 +48,280 @@
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": 44,
"metadata": {},
"outputs": [],
"source": [
"@attr.s\n",
- "class B:\n",
- " batch_size = attr.ib()\n",
- " num_workers = attr.ib()"
+ "class B(nn.Module):\n",
+ " input_dim = attr.ib()\n",
+ " hidden = attr.ib()\n",
+ " xx = attr.ib(init=False, default=\"hek\")\n",
+ " \n",
+ " def __attrs_post_init__(self):\n",
+ " super().__init__()\n",
+ " self.fc = nn.Linear(self.input_dim, self.hidden)\n",
+ " self.xx = \"da\"\n",
+ " \n",
+ " def forward(self, x):\n",
+ " return self.fc(x)"
]
},
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": 49,
"metadata": {},
"outputs": [],
"source": [
- "@attr.s\n",
- "class T(B):\n",
+ "def f(x):\n",
+ " return 2\n",
"\n",
- " def __attrs_post_init__(self) -> None:\n",
- " super().__init__(self.batch_size, self.num_workers)\n",
- " self.hej = None\n",
+ "@attr.s(auto_attribs=True)\n",
+ "class T(B):\n",
" \n",
- " batch_size = attr.ib()\n",
- " num_workers = attr.ib()\n",
- " h: Path = attr.ib(converter=Path)"
+ " h: Path = attr.ib(converter=Path)\n",
+ " p: int = attr.ib(init=False, default=f(3))"
]
},
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": 53,
+ "metadata": {},
+ "outputs": [
+ {
+ "ename": "TypeError",
+ "evalue": "__init__() missing 1 required positional argument: 'hidden'",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
+ "\u001b[0;32m<ipython-input-53-ef8b390156f4>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mt\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mT\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput_dim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m16\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mh\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"hej\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
+ "\u001b[0;31mTypeError\u001b[0m: __init__() missing 1 required positional argument: 'hidden'"
+ ]
+ }
+ ],
+ "source": [
+ "t = T(input_dim=16, h=\"hej\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 51,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'da'"
+ ]
+ },
+ "execution_count": 51,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "t.xx"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 52,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "2"
+ ]
+ },
+ "execution_count": 52,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "t.p"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "16"
+ ]
+ },
+ "execution_count": 19,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "t.input_dim"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "x = torch.rand(16, 16)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([16, 16])"
+ ]
+ },
+ "execution_count": 21,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "x.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "T(input_dim=16, hidden=24, h=PosixPath('hej'))"
+ ]
+ },
+ "execution_count": 23,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "t.cuda()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
- "t = T(batch_size=16, num_workers=2, h=\"hej\")"
+ "x = x.cuda()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[ 3.6047e-01, 1.0200e+00, 3.6786e-01, 1.6077e-01, 3.9281e-02,\n",
+ " 3.2830e-01, 1.3433e-01, -9.0334e-02, -3.8712e-01, 8.1547e-01,\n",
+ " -5.4483e-01, -9.7471e-01, 3.3706e-01, -9.5283e-01, -1.6271e-01,\n",
+ " 3.8504e-01, -5.0106e-01, -4.8638e-01, 3.7033e-01, -4.9557e-01,\n",
+ " 2.6555e-01, 5.1245e-01, 6.6751e-01, -2.6291e-01],\n",
+ " [ 1.3811e-01, 7.4522e-01, 4.9935e-01, 3.3878e-01, 1.8501e-01,\n",
+ " 2.2269e-02, -2.0328e-01, 1.4629e-01, -2.2957e-01, 4.1197e-01,\n",
+ " -1.9555e-01, -4.7609e-01, 9.0206e-02, -8.8568e-01, -2.1618e-01,\n",
+ " 2.8882e-01, -5.4335e-01, -6.6301e-01, 4.9990e-01, -4.0144e-01,\n",
+ " 3.6403e-01, 5.3901e-01, 8.6665e-01, -7.8312e-02],\n",
+ " [ 1.6493e-02, 4.6157e-01, 2.9500e-02, 2.4190e-01, 6.5753e-01,\n",
+ " 4.3770e-02, -5.3773e-02, 1.8183e-01, -2.5983e-02, 4.1634e-01,\n",
+ " -3.5218e-01, -5.6129e-01, 4.1452e-01, -1.2265e+00, -5.8544e-01,\n",
+ " 3.6382e-01, -6.4090e-01, -5.8679e-01, 4.3489e-02, -1.1233e-01,\n",
+ " 3.1175e-01, 4.2857e-01, 1.6501e-01, -2.4118e-01],\n",
+ " [ 9.2361e-02, 6.0196e-01, 1.3081e-02, -8.1091e-02, 4.2342e-01,\n",
+ " -8.8457e-02, -8.1851e-02, -1.1562e-01, -1.5049e-01, 4.9972e-01,\n",
+ " -3.0432e-01, -7.8619e-01, 2.1060e-01, -1.0598e+00, -4.6542e-01,\n",
+ " 4.2382e-01, -6.5671e-01, -4.8589e-01, 5.5977e-02, -2.9478e-02,\n",
+ " 8.5718e-02, 4.7685e-01, 4.8351e-01, -2.8142e-01],\n",
+ " [ 1.3377e-01, 5.4434e-01, 3.4505e-01, 1.1307e-01, 4.4057e-01,\n",
+ " -7.6075e-03, 1.3841e-01, -1.1497e-01, -1.3177e-01, 8.0254e-01,\n",
+ " -3.0627e-01, -6.8437e-01, 1.9035e-01, -1.0208e+00, -1.3259e-01,\n",
+ " 5.3231e-01, -4.7814e-01, -5.1266e-01, 2.4646e-02, -3.0552e-01,\n",
+ " 2.7398e-01, 5.8269e-01, 6.5481e-01, -4.2041e-01],\n",
+ " [ 1.9604e-01, 4.0597e-01, 1.9071e-01, -2.5535e-01, 1.1915e-01,\n",
+ " -6.7129e-02, 5.4386e-03, -8.2196e-02, -4.2803e-01, 7.0287e-01,\n",
+ " -3.0026e-01, -7.6001e-01, -5.1471e-03, -7.0283e-01, -9.2978e-02,\n",
+ " 1.2243e-01, -1.8398e-01, -4.7374e-01, 2.7978e-01, -3.6962e-01,\n",
+ " 5.6046e-02, 4.1773e-01, 4.9894e-01, -3.1945e-01],\n",
+ " [ 1.2657e-01, 3.3224e-01, 6.2830e-02, 1.5718e-01, 4.8844e-01,\n",
+ " -1.1476e-01, -1.5044e-01, 2.5265e-02, -2.0351e-01, 5.5770e-01,\n",
+ " -3.6036e-01, -7.4406e-01, 1.6962e-01, -9.6185e-01, -2.9334e-01,\n",
+ " 2.2584e-01, -4.1169e-01, -5.2146e-01, 2.3314e-01, -1.3668e-01,\n",
+ " -1.9598e-02, 3.8727e-01, 3.6892e-01, -3.3071e-01],\n",
+ " [ 5.2178e-01, 6.9704e-01, 5.0093e-01, 1.1157e-01, 8.0012e-02,\n",
+ " 3.6931e-01, -6.4927e-02, 1.1126e-01, -2.5117e-01, 5.3017e-01,\n",
+ " -2.6488e-01, -8.4056e-01, 2.2374e-01, -6.6831e-01, -1.9402e-01,\n",
+ " 7.4174e-02, -4.7763e-01, -2.6912e-01, 5.1009e-01, -5.4239e-01,\n",
+ " 3.0123e-01, 3.7529e-01, 4.1625e-01, -2.0141e-01],\n",
+ " [ 3.7968e-01, 4.9387e-01, 3.6786e-01, -1.3131e-01, 2.4445e-02,\n",
+ " 2.2155e-01, -4.0087e-02, -1.4872e-01, -5.5030e-01, 6.8958e-01,\n",
+ " -3.8156e-01, -7.5760e-01, 3.2085e-01, -6.4571e-01, 1.1268e-03,\n",
+ " 3.4251e-02, -2.6440e-01, -2.6374e-01, 5.9787e-01, -4.6502e-01,\n",
+ " 2.0074e-01, 4.5471e-01, 2.4238e-01, -4.3247e-01],\n",
+ " [ 2.9364e-01, 4.8659e-01, 9.0845e-02, 1.6348e-01, 5.7636e-01,\n",
+ " 4.5485e-01, -1.6781e-01, -1.4557e-01, -8.8814e-02, 6.6351e-01,\n",
+ " -5.3669e-01, -8.2818e-01, 6.0474e-01, -9.4558e-01, -3.0133e-01,\n",
+ " 3.0310e-01, -5.2493e-01, -2.5948e-01, 1.5857e-01, -4.2695e-01,\n",
+ " 2.1311e-01, 4.6502e-01, 8.7946e-02, -5.5815e-01],\n",
+ " [ 9.2208e-02, 2.9731e-01, 3.3849e-01, -5.1049e-02, 2.7834e-01,\n",
+ " -1.1120e-01, 1.1835e-01, 1.3665e-01, -2.1291e-01, 3.5107e-01,\n",
+ " -9.8108e-02, -5.0180e-01, 2.9894e-01, -7.7726e-01, -8.1317e-02,\n",
+ " 3.5704e-01, -3.6759e-01, -2.2148e-01, 1.1019e-01, -1.4452e-02,\n",
+ " 1.5092e-02, 3.3405e-01, 1.2765e-01, -4.0411e-01],\n",
+ " [ 2.8927e-02, 4.4180e-01, 1.0994e-01, 5.6124e-01, 4.7174e-01,\n",
+ " 1.9914e-01, -9.5047e-02, 3.1277e-02, -1.8656e-01, 5.0631e-01,\n",
+ " -3.4353e-01, -5.7425e-01, 4.3409e-01, -8.3343e-01, -1.1627e-01,\n",
+ " 3.1852e-02, -4.1274e-01, -2.6756e-01, 4.9652e-01, -2.6137e-01,\n",
+ " 2.8559e-02, 3.0587e-01, 3.6717e-01, -4.4303e-01],\n",
+ " [-1.0741e-01, 1.3539e-01, 1.5746e-01, 2.1208e-01, 6.3745e-01,\n",
+ " -2.1864e-01, -1.8820e-01, 2.1184e-01, -3.6832e-02, 3.0890e-01,\n",
+ " -2.4719e-03, -3.3573e-01, 1.8479e-01, -9.2119e-01, -2.3361e-01,\n",
+ " 8.9827e-02, -5.4372e-01, -4.4935e-01, 3.2967e-01, -9.2807e-02,\n",
+ " 9.9241e-02, 4.1705e-01, 2.4728e-01, -4.8119e-01],\n",
+ " [ 2.8125e-01, 5.3276e-01, 5.0110e-02, 2.0471e-01, 5.7750e-01,\n",
+ " 4.6670e-02, -2.1400e-01, 6.8794e-03, -6.8737e-02, 4.2138e-01,\n",
+ " -3.1261e-01, -7.3709e-01, 4.2001e-01, -9.9757e-01, -4.8091e-01,\n",
+ " 2.9960e-01, -6.2133e-01, -4.0566e-01, 3.2191e-01, -1.0219e-02,\n",
+ " 1.2901e-01, 3.9601e-01, 1.6291e-01, -3.3871e-01],\n",
+ " [ 2.9181e-01, 5.5400e-01, 3.0462e-01, 2.2431e-02, 2.8480e-01,\n",
+ " 4.4624e-01, -2.8859e-01, -1.4629e-01, -4.3573e-02, 2.9742e-01,\n",
+ " -1.0100e-01, -4.3070e-01, 4.6713e-01, -3.7132e-01, -8.6748e-02,\n",
+ " 2.5666e-01, -3.5361e-01, -2.3917e-02, 3.0071e-01, -3.2420e-01,\n",
+ " 1.3375e-01, 3.4475e-01, 3.0642e-01, -4.3496e-01],\n",
+ " [-7.7723e-04, 2.3828e-01, 2.3124e-01, 4.1347e-01, 6.8455e-01,\n",
+ " -9.8319e-03, 1.3403e-01, 1.8460e-02, -1.4025e-01, 5.9780e-01,\n",
+ " -3.7015e-01, -5.7865e-01, 4.9211e-01, -1.1262e+00, -2.1693e-01,\n",
+ " 3.2002e-01, -2.9313e-01, -3.1941e-01, 9.8446e-02, -6.2767e-02,\n",
+ " -9.8636e-03, 3.5712e-01, 2.8833e-01, -5.3506e-01]], device='cuda:0',\n",
+ " grad_fn=<AddmmBackward>)"
+ ]
+ },
+ "execution_count": 25,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "t(x)"
]
},
{
diff --git a/text_recognizer/callbacks/wandb_callbacks.py b/text_recognizer/callbacks/wandb_callbacks.py
index 4186b4a..d9d81f6 100644
--- a/text_recognizer/callbacks/wandb_callbacks.py
+++ b/text_recognizer/callbacks/wandb_callbacks.py
@@ -93,6 +93,40 @@ class LogTextPredictions(Callback):
def __attrs_pre_init__(self) -> None:
super().__init__()
+ def _log_predictions(
+ stage: str, trainer: Trainer, pl_module: LightningModule
+ ) -> None:
+ """Logs the predicted text contained in the images."""
+ if not self.ready:
+ return None
+
+ logger = get_wandb_logger(trainer)
+ experiment = logger.experiment
+
+ # Get a validation batch from the validation dataloader.
+ samples = next(iter(trainer.datamodule.val_dataloader()))
+ imgs, labels = samples
+
+ imgs = imgs.to(device=pl_module.device)
+ logits = pl_module(imgs)
+
+ mapping = pl_module.mapping
+ experiment.log(
+ {
+ f"OCR/{experiment.name}/{stage}": [
+ wandb.Image(
+ img,
+ caption=f"Pred: {mapping.get_text(pred)}, Label: {mapping.get_text(label)}",
+ )
+ for img, pred, label in zip(
+ imgs[: self.num_samples],
+ logits[: self.num_samples],
+ labels[: self.num_samples],
+ )
+ ]
+ }
+ )
+
def on_sanity_check_start(
self, trainer: Trainer, pl_module: LightningModule
) -> None:
@@ -107,6 +141,27 @@ class LogTextPredictions(Callback):
self, trainer: Trainer, pl_module: LightningModule
) -> None:
"""Logs predictions on validation epoch end."""
+ self._log_predictions(stage="val", trainer=trainer, pl_module=pl_module)
+
+ def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
+ """Logs predictions on train epoch end."""
+ self._log_predictions(stage="test", trainer=trainer, pl_module=pl_module)
+
+
+@attr.s
+class LogReconstuctedImages(Callback):
+ """Log reconstructions of images."""
+
+ num_samples: int = attr.ib(default=8)
+ ready: bool = attr.ib(default=True)
+
+ def __attrs_pre_init__(self) -> None:
+ super().__init__()
+
+ def _log_reconstruction(
+ self, stage: str, trainer: Trainer, pl_module: LightningModule
+ ) -> None:
+ """Logs the reconstructions."""
if not self.ready:
return None
@@ -115,24 +170,42 @@ class LogTextPredictions(Callback):
# Get a validation batch from the validation dataloader.
samples = next(iter(trainer.datamodule.val_dataloader()))
- imgs, labels = samples
+ imgs, _ = samples
imgs = imgs.to(device=pl_module.device)
- logits = pl_module(imgs)
+ reconstructions = pl_module(imgs)
- mapping = pl_module.mapping
experiment.log(
{
- f"Images/{experiment.name}": [
- wandb.Image(
- img,
- caption=f"Pred: {mapping.get_text(pred)}, Label: {mapping.get_text(label)}",
- )
- for img, pred, label in zip(
+ f"Reconstructions/{experiment.name}/{stage}": [
+ [
+ wandb.Image(img),
+ wandb.Image(rec),
+ ]
+ for img, rec in zip(
imgs[: self.num_samples],
- logits[: self.num_samples],
- labels[: self.num_samples],
+ reconstructions[: self.num_samples],
)
]
}
)
+
+ def on_sanity_check_start(
+ self, trainer: Trainer, pl_module: LightningModule
+ ) -> None:
+ """Sets ready attribute."""
+ self.ready = False
+
+ def on_sanity_check_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
+ """Start executing this callback only after all validation sanity checks end."""
+ self.ready = True
+
+ def on_validation_epoch_end(
+ self, trainer: Trainer, pl_module: LightningModule
+ ) -> None:
+ """Logs predictions on validation epoch end."""
+ self._log_reconstruction(stage="val", trainer=trainer, pl_module=pl_module)
+
+ def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
+ """Logs predictions on train epoch end."""
+ self._log_reconstruction(stage="test", trainer=trainer, pl_module=pl_module)
diff --git a/text_recognizer/data/base_data_module.py b/text_recognizer/data/base_data_module.py
index de5628f..18b1996 100644
--- a/text_recognizer/data/base_data_module.py
+++ b/text_recognizer/data/base_data_module.py
@@ -1,11 +1,13 @@
"""Base lightning DataModule class."""
from pathlib import Path
-from typing import Dict
+from typing import Any, Dict, Tuple
import attr
-import pytorch_lightning as LightningDataModule
+from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
+from text_recognizer.data.base_dataset import BaseDataset
+
def load_and_print_info(data_module_class: type) -> None:
"""Load dataset and print dataset information."""
@@ -19,17 +21,20 @@ def load_and_print_info(data_module_class: type) -> None:
class BaseDataModule(LightningDataModule):
"""Base PyTorch Lightning DataModule."""
- batch_size: int = attr.ib(default=16)
- num_workers: int = attr.ib(default=0)
-
def __attrs_pre_init__(self) -> None:
super().__init__()
- def __attrs_post_init__(self) -> None:
- # Placeholders for subclasses.
- self.dims = None
- self.output_dims = None
- self.mapping = None
+ batch_size: int = attr.ib(default=16)
+ num_workers: int = attr.ib(default=0)
+
+ # Placeholders
+ data_train: BaseDataset = attr.ib(init=False, default=None)
+ data_val: BaseDataset = attr.ib(init=False, default=None)
+ data_test: BaseDataset = attr.ib(init=False, default=None)
+ dims: Tuple[int, ...] = attr.ib(init=False, default=None)
+ output_dims: Tuple[int, ...] = attr.ib(init=False, default=None)
+ mapping: Any = attr.ib(init=False, default=None)
+ inverse_mapping: Dict[str, int] = attr.ib(init=False)
@classmethod
def data_dirname(cls) -> Path:
@@ -58,9 +63,7 @@ class BaseDataModule(LightningDataModule):
stage (Any): Variable to set splits.
"""
- self.data_train = None
- self.data_val = None
- self.data_test = None
+ pass
def train_dataloader(self) -> DataLoader:
"""Retun DataLoader for train data."""
diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py
index 824b947..d51a42a 100644
--- a/text_recognizer/data/emnist.py
+++ b/text_recognizer/data/emnist.py
@@ -3,9 +3,10 @@ import json
import os
from pathlib import Path
import shutil
-from typing import Dict, List, Optional, Sequence, Tuple
+from typing import Callable, Dict, List, Optional, Sequence, Tuple
import zipfile
+import attr
import h5py
from loguru import logger
import numpy as np
@@ -32,6 +33,7 @@ PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "byclass.h5"
ESSENTIALS_FILENAME = Path(__file__).parents[0].resolve() / "emnist_essentials.json"
+@attr.s(auto_attribs=True)
class EMNIST(BaseDataModule):
"""Lightning DataModule class for loading EMNIST dataset.
@@ -44,18 +46,12 @@ class EMNIST(BaseDataModule):
EMNIST ByClass: 814,255 characters. 62 unbalanced classes.
"""
- def __init__(
- self, batch_size: int = 128, num_workers: int = 0, train_fraction: float = 0.8
- ) -> None:
- super().__init__(batch_size, num_workers)
- self.train_fraction = train_fraction
- self.mapping, self.inverse_mapping, self.input_shape = emnist_mapping()
- self.data_train = None
- self.data_val = None
- self.data_test = None
- self.transform = T.Compose([T.ToTensor()])
- self.dims = (1, *self.input_shape)
- self.output_dims = (1,)
+ train_fraction: float = attr.ib()
+ transform: Callable = attr.ib(init=False, default=T.Compose([T.ToTensor()]))
+
+ def __attrs_post_init__(self) -> None:
+ self.mapping, self.inverse_mapping, input_shape = emnist_mapping()
+ self.dims = (1, *input_shape)
def prepare_data(self) -> None:
"""Downloads dataset if not present."""
diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py
index 9650198..4747508 100644
--- a/text_recognizer/data/emnist_lines.py
+++ b/text_recognizer/data/emnist_lines.py
@@ -3,6 +3,7 @@ from collections import defaultdict
from pathlib import Path
from typing import Callable, Dict, Tuple
+import attr
import h5py
from loguru import logger
import numpy as np
@@ -31,31 +32,20 @@ IMAGE_X_PADDING = 28
MAX_OUTPUT_LENGTH = 89 # Same as IAMLines
+@attr.s(auto_attribs=True)
class EMNISTLines(BaseDataModule):
"""EMNIST Lines dataset: synthetic handwritten lines dataset made from EMNIST,"""
- def __init__(
- self,
- augment: bool = True,
- batch_size: int = 128,
- num_workers: int = 0,
- max_length: int = 32,
- min_overlap: float = 0.0,
- max_overlap: float = 0.33,
- num_train: int = 10_000,
- num_val: int = 2_000,
- num_test: int = 2_000,
- ) -> None:
- super().__init__(batch_size, num_workers)
-
- self.augment = augment
- self.max_length = max_length
- self.min_overlap = min_overlap
- self.max_overlap = max_overlap
- self.num_train = num_train
- self.num_val = num_val
- self.num_test = num_test
+ augment: bool = attr.ib(default=True)
+ max_length: int = attr.ib(default=128)
+ min_overlap: float = attr.ib(default=0.0)
+ max_overlap: float = attr.ib(default=0.33)
+ num_train: int = attr.ib(default=10_000)
+ num_val: int = attr.ib(default=2_000)
+ num_test: int = attr.ib(default=2_000)
+ emnist: EMNIST = attr.ib(init=False, default=None)
+ def __attrs_post_init__(self) -> None:
self.emnist = EMNIST()
self.mapping = self.emnist.mapping
@@ -75,9 +65,6 @@ class EMNISTLines(BaseDataModule):
raise ValueError("max_length greater than MAX_OUTPUT_LENGTH")
self.output_dims = (MAX_OUTPUT_LENGTH, 1)
- self.data_train: BaseDataset = None
- self.data_val: BaseDataset = None
- self.data_test: BaseDataset = None
@property
def data_filename(self) -> Path:
diff --git a/text_recognizer/data/iam.py b/text_recognizer/data/iam.py
index 261c8d3..3982c4f 100644
--- a/text_recognizer/data/iam.py
+++ b/text_recognizer/data/iam.py
@@ -5,6 +5,7 @@ from typing import Any, Dict, List
import xml.etree.ElementTree as ElementTree
import zipfile
+import attr
from boltons.cacheutils import cachedproperty
from loguru import logger
import toml
@@ -22,6 +23,7 @@ DOWNSAMPLE_FACTOR = 2 # If images were downsampled, the regions must also be.
LINE_REGION_PADDING = 16 # Add this many pixels around the exact coordinates.
+@attr.s(auto_attribs=True)
class IAM(BaseDataModule):
"""
"The IAM Lines dataset, first published at the ICDAR 1999, contains forms of unconstrained handwritten text,
@@ -35,9 +37,7 @@ class IAM(BaseDataModule):
The text lines of all data sets are mutually exclusive, thus each writer has contributed to one set only.
"""
- def __init__(self, batch_size: int = 128, num_workers: int = 0) -> None:
- super().__init__(batch_size, num_workers)
- self.metadata = toml.load(METADATA_FILENAME)
+ metadata: Dict = attr.ib(init=False, default=toml.load(METADATA_FILENAME))
def prepare_data(self) -> None:
if self.xml_filenames:
diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py
index 0a30a42..886e37e 100644
--- a/text_recognizer/data/iam_extended_paragraphs.py
+++ b/text_recognizer/data/iam_extended_paragraphs.py
@@ -1,4 +1,7 @@
"""IAM original and sythetic dataset class."""
+from typing import Dict, List
+
+import attr
from torch.utils.data import ConcatDataset
from text_recognizer.data.base_dataset import BaseDataset
@@ -7,22 +10,26 @@ from text_recognizer.data.iam_paragraphs import IAMParagraphs
from text_recognizer.data.iam_synthetic_paragraphs import IAMSyntheticParagraphs
+@attr.s(auto_attribs=True)
class IAMExtendedParagraphs(BaseDataModule):
- def __init__(
- self,
- batch_size: int = 16,
- num_workers: int = 0,
- train_fraction: float = 0.8,
- augment: bool = True,
- word_pieces: bool = False,
- ) -> None:
- super().__init__(batch_size, num_workers)
+ train_fraction: float = attr.ib()
+ word_pieces: bool = attr.ib(default=False)
+
+ def __attrs_post_init__(self) -> None:
self.iam_paragraphs = IAMParagraphs(
- batch_size, num_workers, train_fraction, augment, word_pieces,
+ self.batch_size,
+ self.num_workers,
+ self.train_fraction,
+ self.augment,
+ self.word_pieces,
)
self.iam_synthetic_paragraphs = IAMSyntheticParagraphs(
- batch_size, num_workers, train_fraction, augment, word_pieces,
+ self.batch_size,
+ self.num_workers,
+ self.train_fraction,
+ self.augment,
+ self.word_pieces,
)
self.dims = self.iam_paragraphs.dims
@@ -30,10 +37,6 @@ class IAMExtendedParagraphs(BaseDataModule):
self.mapping = self.iam_paragraphs.mapping
self.inverse_mapping = self.iam_paragraphs.inverse_mapping
- self.data_train: BaseDataset = None
- self.data_val: BaseDataset = None
- self.data_test: BaseDataset = None
-
def prepare_data(self) -> None:
"""Prepares the paragraphs data."""
self.iam_paragraphs.prepare_data()
diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py
index 9c78a22..e45e5c8 100644
--- a/text_recognizer/data/iam_lines.py
+++ b/text_recognizer/data/iam_lines.py
@@ -7,8 +7,9 @@ dataset.
import json
from pathlib import Path
import random
-from typing import List, Sequence, Tuple
+from typing import Dict, List, Sequence, Tuple
+import attr
from loguru import logger
from PIL import Image, ImageFile, ImageOps
import numpy as np
@@ -35,26 +36,17 @@ IMAGE_HEIGHT = 56
IMAGE_WIDTH = 1024
+@attr.s(auto_attribs=True)
class IAMLines(BaseDataModule):
"""IAM handwritten lines dataset."""
- def __init__(
- self,
- augment: bool = True,
- fraction: float = 0.8,
- batch_size: int = 128,
- num_workers: int = 0,
- ) -> None:
- # TODO: add transforms
- super().__init__(batch_size, num_workers)
- self.augment = augment
- self.fraction = fraction
+ augment: bool = attr.ib(default=True)
+ fraction: float = attr.ib(default=0.8)
+
+ def __attrs_post_init__(self) -> None:
self.mapping, self.inverse_mapping, _ = emnist_mapping()
self.dims = (1, IMAGE_HEIGHT, IMAGE_WIDTH)
self.output_dims = (89, 1)
- self.data_train: BaseDataset = None
- self.data_val: BaseDataset = None
- self.data_test: BaseDataset = None
def prepare_data(self) -> None:
"""Creates the IAM lines dataset if not existing."""
diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py
index fe60e99..445b788 100644
--- a/text_recognizer/data/iam_paragraphs.py
+++ b/text_recognizer/data/iam_paragraphs.py
@@ -3,6 +3,7 @@ import json
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple
+import attr
from loguru import logger
import numpy as np
from PIL import Image, ImageOps
@@ -33,33 +34,25 @@ IMAGE_WIDTH = 1280 // IMAGE_SCALE_FACTOR
MAX_LABEL_LENGTH = 682
+@attr.s(auto_attribs=True)
class IAMParagraphs(BaseDataModule):
"""IAM handwriting database paragraphs."""
- def __init__(
- self,
- batch_size: int = 16,
- num_workers: int = 0,
- train_fraction: float = 0.8,
- augment: bool = True,
- word_pieces: bool = False,
- ) -> None:
- super().__init__(batch_size, num_workers)
- self.augment = augment
- self.word_pieces = word_pieces
+ augment: bool = attr.ib(default=True)
+ train_fraction: float = attr.ib(default=0.8)
+ word_pieces: bool = attr.ib(default=False)
+
+ def __attrs_post_init__(self) -> None:
self.mapping, self.inverse_mapping, _ = emnist_mapping(
extra_symbols=[NEW_LINE_TOKEN]
)
- if word_pieces:
+ if self.word_pieces:
self.mapping = WordPieceMapping()
self.train_fraction = train_fraction
self.dims = (1, IMAGE_HEIGHT, IMAGE_WIDTH)
self.output_dims = (MAX_LABEL_LENGTH, 1)
- self.data_train: BaseDataset = None
- self.data_val: BaseDataset = None
- self.data_test: BaseDataset = None
def prepare_data(self) -> None:
"""Create data for training/testing."""
@@ -166,7 +159,10 @@ def get_dataset_properties() -> Dict:
"min": min(_get_property_values("num_lines")),
"max": max(_get_property_values("num_lines")),
},
- "crop_shape": {"min": crop_shapes.min(axis=0), "max": crop_shapes.max(axis=0),},
+ "crop_shape": {
+ "min": crop_shapes.min(axis=0),
+ "max": crop_shapes.max(axis=0),
+ },
"aspect_ratio": {
"min": aspect_ratio.min(axis=0),
"max": aspect_ratio.max(axis=0),
@@ -287,7 +283,9 @@ def get_transform(image_shape: Tuple[int, int], augment: bool) -> T.Compose:
),
T.ColorJitter(brightness=(0.8, 1.6)),
T.RandomAffine(
- degrees=1, shear=(-10, 10), interpolation=InterpolationMode.BILINEAR,
+ degrees=1,
+ shear=(-10, 10),
+ interpolation=InterpolationMode.BILINEAR,
),
]
else:
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index 8dc7a36..f95df0f 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -5,7 +5,7 @@ import attr
import hydra
import loguru.logger as log
from omegaconf import DictConfig
-import pytorch_lightning as pl
+import pytorch_lightning as LightningModule
import torch
from torch import nn
from torch import Tensor
@@ -13,7 +13,7 @@ import torchmetrics
@attr.s
-class BaseLitModel(pl.LightningModule):
+class BaseLitModel(LightningModule):
"""Abstract PyTorch Lightning class."""
network: Type[nn.Module] = attr.ib()
@@ -80,7 +80,6 @@ class BaseLitModel(pl.LightningModule):
"""Configures optimizer and lr scheduler."""
optimizer = self._configure_optimizer()
scheduler = self._configure_lr_scheduler(optimizer)
-
return [optimizer], [scheduler]
def forward(self, data: Tensor) -> Tensor:
diff --git a/text_recognizer/models/metrics.py b/text_recognizer/models/metrics.py
index 58d0537..4117ae2 100644
--- a/text_recognizer/models/metrics.py
+++ b/text_recognizer/models/metrics.py
@@ -1,18 +1,23 @@
"""Character Error Rate (CER)."""
-from typing import Sequence
+from typing import Set, Sequence
+import attr
import editdistance
import torch
from torch import Tensor
-import torchmetrics
+from torchmetrics import Metric
-class CharacterErrorRate(torchmetrics.Metric):
+@attr.s
+class CharacterErrorRate(Metric):
"""Character error rate metric, computed using Levenshtein distance."""
- def __init__(self, ignore_tokens: Sequence[int], *args) -> None:
+ ignore_tokens: Set = attr.ib(converter=set)
+ error: Tensor = attr.ib(init=False)
+ total: Tensor = attr.ib(init=False)
+
+ def __attrs_post_init__(self) -> None:
super().__init__()
- self.ignore_tokens = set(ignore_tokens)
self.add_state("error", default=torch.tensor(0.0), dist_reduce_fx="sum")
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index ea54d83..8c9fe8a 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -2,35 +2,24 @@
from typing import Dict, List, Optional, Union, Tuple, Type
import attr
+import hydra
from omegaconf import DictConfig
from torch import nn, Tensor
from text_recognizer.data.emnist import emnist_mapping
from text_recognizer.data.mappings import AbstractMapping
from text_recognizer.models.metrics import CharacterErrorRate
-from text_recognizer.models.base import LitBaseModel
+from text_recognizer.models.base import BaseLitModel
-@attr.s
-class TransformerLitModel(LitBaseModel):
+@attr.s(auto_attribs=True)
+class TransformerLitModel(BaseLitModel):
"""A PyTorch Lightning model for transformer networks."""
- network: Type[nn.Module] = attr.ib()
- criterion_config: DictConfig = attr.ib(converter=DictConfig)
- optimizer_config: DictConfig = attr.ib(converter=DictConfig)
- lr_scheduler_config: DictConfig = attr.ib(converter=DictConfig)
- monitor: str = attr.ib()
- mapping: Type[AbstractMapping] = attr.ib()
+ mapping_config: DictConfig = attr.ib(converter=DictConfig)
def __attrs_post_init__(self) -> None:
- super().__init__(
- network=self.network,
- optimizer_config=self.optimizer_config,
- lr_scheduler_config=self.lr_scheduler_config,
- criterion_config=self.criterion_config,
- monitor=self.monitor,
- )
- self.mapping, ignore_tokens = self.configure_mapping(mapping)
+ self.mapping, ignore_tokens = self._configure_mapping()
self.val_cer = CharacterErrorRate(ignore_tokens)
self.test_cer = CharacterErrorRate(ignore_tokens)
@@ -39,9 +28,10 @@ class TransformerLitModel(LitBaseModel):
return self.network.predict(data)
@staticmethod
- def configure_mapping(mapping: Optional[List[str]]) -> Tuple[List[str], List[int]]:
+ def _configure_mapping() -> Tuple[Type[AbstractMapping], List[int]]:
"""Configure mapping."""
# TODO: Fix me!!!
+ # Load config with hydra
mapping, inverse_mapping, _ = emnist_mapping(["\n"])
start_index = inverse_mapping["<s>"]
end_index = inverse_mapping["<e>"]
diff --git a/text_recognizer/models/vqvae.py b/text_recognizer/models/vqvae.py
index 7dc950f..0172163 100644
--- a/text_recognizer/models/vqvae.py
+++ b/text_recognizer/models/vqvae.py
@@ -1,49 +1,23 @@
"""PyTorch Lightning model for base Transformers."""
from typing import Any, Dict, Union, Tuple, Type
+import attr
from omegaconf import DictConfig
from torch import nn
from torch import Tensor
import wandb
-from text_recognizer.models.base import LitBaseModel
+from text_recognizer.models.base import BaseLitModel
-class LitVQVAEModel(LitBaseModel):
+@attr.s(auto_attribs=True)
+class VQVAELitModel(BaseLitModel):
"""A PyTorch Lightning model for transformer networks."""
- def __init__(
- self,
- network: Type[nn.Module],
- optimizer: Union[DictConfig, Dict],
- lr_scheduler: Union[DictConfig, Dict],
- criterion: Union[DictConfig, Dict],
- monitor: str = "val/loss",
- *args: Any,
- **kwargs: Dict,
- ) -> None:
- super().__init__(network, optimizer, lr_scheduler, criterion, monitor)
-
def forward(self, data: Tensor) -> Tensor:
"""Forward pass with the transformer network."""
return self.network.predict(data)
- def _log_prediction(
- self, data: Tensor, reconstructions: Tensor, title: str
- ) -> None:
- """Logs prediction on image with wandb."""
- try:
- self.logger.experiment.log(
- {
- title: [
- wandb.Image(data[0]),
- wandb.Image(reconstructions[0]),
- ]
- }
- )
- except AttributeError:
- pass
-
def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
"""Training step."""
data, _ = batch
diff --git a/text_recognizer/networks/util.py b/text_recognizer/networks/util.py
index 109bf4d..85094f1 100644
--- a/text_recognizer/networks/util.py
+++ b/text_recognizer/networks/util.py
@@ -1,4 +1,4 @@
-"""Miscellaneous neural network functionality."""
+"""Miscellaneous neural network utility functionality."""
from typing import Type
from torch import nn
diff --git a/training/conf/datamodule/iam_extended_paragraphs.yaml b/training/conf/datamodule/iam_extended_paragraphs.yaml
new file mode 100644
index 0000000..3070b56
--- /dev/null
+++ b/training/conf/datamodule/iam_extended_paragraphs.yaml
@@ -0,0 +1,5 @@
+_target_: text_recognizer.data.iam_extended_paragraphs.IAMExtendedParagraphs
+batch_size: 32
+num_workers: 12
+train_fraction: 0.8
+augment: true
diff --git a/training/conf/dataset/iam_extended_paragraphs.yaml b/training/conf/dataset/iam_extended_paragraphs.yaml
deleted file mode 100644
index 6439a15..0000000
--- a/training/conf/dataset/iam_extended_paragraphs.yaml
+++ /dev/null
@@ -1,6 +0,0 @@
-type: IAMExtendedParagraphs
-args:
- batch_size: 32
- num_workers: 12
- train_fraction: 0.8
- augment: true
diff --git a/training/conf/lr_scheduler/one_cycle.yaml b/training/conf/lr_scheduler/one_cycle.yaml
index 60a6f27..e8cb5c4 100644
--- a/training/conf/lr_scheduler/one_cycle.yaml
+++ b/training/conf/lr_scheduler/one_cycle.yaml
@@ -1,8 +1,15 @@
-type: OneCycleLR
-args:
- interval: step
- max_lr: 1.0e-3
- three_phase: true
- epochs: 64
- steps_per_epoch: 633 # num_samples / batch_size
-monitor: val_loss
+_target_: torch.optim.lr_scheduler.OneCycleLR
+max_lr: 1.0e-3
+total_steps: None
+epochs: None
+steps_per_epoch: None
+pct_start: 0.3
+anneal_strategy: 'cos'
+cycle_momentum: True
+base_momentum: 0.85
+max_momentum: 0.95
+div_factor: 25.0
+final_div_factor: 10000.0
+three_phase: true
+last_epoch: -1
+verbose: false
diff --git a/training/conf/model/lit_vqvae.yaml b/training/conf/model/lit_vqvae.yaml
index 7136dbd..6be37e5 100644
--- a/training/conf/model/lit_vqvae.yaml
+++ b/training/conf/model/lit_vqvae.yaml
@@ -1,3 +1,3 @@
-type: LitVQVAEModel
+_target_: text_recognizer.models.vqvae.VQVAELitModel
args:
mapping: sentence_piece
diff --git a/training/run.py b/training/run.py
index 5f7c927..31da666 100644
--- a/training/run.py
+++ b/training/run.py
@@ -51,6 +51,7 @@ def run(config: DictConfig) -> Optional[float]:
)
# Log hyperparameters
+ log.info("Logging hyperparameters")
utils.log_hyperparameters(config=config, model=model, trainer=trainer)
if config.debug:
diff --git a/training/utils.py b/training/utils.py
index 4c31dc3..140d97e 100644
--- a/training/utils.py
+++ b/training/utils.py
@@ -1,4 +1,4 @@
-"""Util functions for training hydra configs and pytorch lightning."""
+"""Util functions for training with hydra and pytorch lightning."""
from typing import Any, List, Type
import warnings