diff options
| author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-07-06 17:42:53 +0200 | 
|---|---|---|
| committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-07-06 17:42:53 +0200 | 
| commit | eb5b206f7e1b08435378d2a02395307be55ee6f1 (patch) | |
| tree | 0cd30234afab698eb632b20a7da97e3bc7e98882 | |
| parent | 4d1f2cef39688871d2caafce42a09316381a27ae (diff) | |
Refactoring data with attrs and refactor conf for hydra
| -rw-r--r-- | notebooks/00-scratch-pad.ipynb | 275 | ||||
| -rw-r--r-- | text_recognizer/callbacks/wandb_callbacks.py | 95 | ||||
| -rw-r--r-- | text_recognizer/data/base_data_module.py | 29 | ||||
| -rw-r--r-- | text_recognizer/data/emnist.py | 22 | ||||
| -rw-r--r-- | text_recognizer/data/emnist_lines.py | 35 | ||||
| -rw-r--r-- | text_recognizer/data/iam.py | 6 | ||||
| -rw-r--r-- | text_recognizer/data/iam_extended_paragraphs.py | 33 | ||||
| -rw-r--r-- | text_recognizer/data/iam_lines.py | 22 | ||||
| -rw-r--r-- | text_recognizer/data/iam_paragraphs.py | 32 | ||||
| -rw-r--r-- | text_recognizer/models/base.py | 5 | ||||
| -rw-r--r-- | text_recognizer/models/metrics.py | 15 | ||||
| -rw-r--r-- | text_recognizer/models/transformer.py | 26 | ||||
| -rw-r--r-- | text_recognizer/models/vqvae.py | 34 | ||||
| -rw-r--r-- | text_recognizer/networks/util.py | 2 | ||||
| -rw-r--r-- | training/conf/datamodule/iam_extended_paragraphs.yaml | 5 | ||||
| -rw-r--r-- | training/conf/dataset/iam_extended_paragraphs.yaml | 6 | ||||
| -rw-r--r-- | training/conf/lr_scheduler/one_cycle.yaml | 23 | ||||
| -rw-r--r-- | training/conf/model/lit_vqvae.yaml | 2 | ||||
| -rw-r--r-- | training/run.py | 1 | ||||
| -rw-r--r-- | training/utils.py | 2 | 
20 files changed, 468 insertions, 202 deletions
| diff --git a/notebooks/00-scratch-pad.ipynb b/notebooks/00-scratch-pad.ipynb index 16c6533..1e30038 100644 --- a/notebooks/00-scratch-pad.ipynb +++ b/notebooks/00-scratch-pad.ipynb @@ -2,7 +2,7 @@   "cells": [    {     "cell_type": "code", -   "execution_count": 1, +   "execution_count": 12,     "metadata": {},     "outputs": [],     "source": [ @@ -30,7 +30,7 @@    },    {     "cell_type": "code", -   "execution_count": 1, +   "execution_count": 13,     "metadata": {},     "outputs": [],     "source": [ @@ -39,7 +39,7 @@    },    {     "cell_type": "code", -   "execution_count": 2, +   "execution_count": 14,     "metadata": {},     "outputs": [],     "source": [ @@ -48,41 +48,280 @@    },    {     "cell_type": "code", -   "execution_count": 9, +   "execution_count": 44,     "metadata": {},     "outputs": [],     "source": [      "@attr.s\n", -    "class B:\n", -    "    batch_size = attr.ib()\n", -    "    num_workers = attr.ib()" +    "class B(nn.Module):\n", +    "    input_dim = attr.ib()\n", +    "    hidden = attr.ib()\n", +    "    xx = attr.ib(init=False, default=\"hek\")\n", +    "    \n", +    "    def __attrs_post_init__(self):\n", +    "        super().__init__()\n", +    "        self.fc = nn.Linear(self.input_dim, self.hidden)\n", +    "        self.xx = \"da\"\n", +    "    \n", +    "    def forward(self, x):\n", +    "        return self.fc(x)"     ]    },    {     "cell_type": "code", -   "execution_count": 10, +   "execution_count": 49,     "metadata": {},     "outputs": [],     "source": [ -    "@attr.s\n", -    "class T(B):\n", +    "def f(x):\n", +    "    return 2\n",      "\n", -    "    def __attrs_post_init__(self) -> None:\n", -    "        super().__init__(self.batch_size, self.num_workers)\n", -    "        self.hej = None\n", +    "@attr.s(auto_attribs=True)\n", +    "class T(B):\n",      "    \n", -    "    batch_size = attr.ib()\n", -    "    num_workers = attr.ib()\n", -    "    h: Path = attr.ib(converter=Path)" +    "    h: Path = attr.ib(converter=Path)\n", +    "    p: int = attr.ib(init=False, default=f(3))"     ]    },    {     "cell_type": "code", -   "execution_count": 11, +   "execution_count": 53, +   "metadata": {}, +   "outputs": [ +    { +     "ename": "TypeError", +     "evalue": "__init__() missing 1 required positional argument: 'hidden'", +     "output_type": "error", +     "traceback": [ +      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", +      "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)", +      "\u001b[0;32m<ipython-input-53-ef8b390156f4>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mt\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mT\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput_dim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m16\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mh\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"hej\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", +      "\u001b[0;31mTypeError\u001b[0m: __init__() missing 1 required positional argument: 'hidden'" +     ] +    } +   ], +   "source": [ +    "t = T(input_dim=16, h=\"hej\")" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": 51, +   "metadata": {}, +   "outputs": [ +    { +     "data": { +      "text/plain": [ +       "'da'" +      ] +     }, +     "execution_count": 51, +     "metadata": {}, +     "output_type": "execute_result" +    } +   ], +   "source": [ +    "t.xx" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": 52, +   "metadata": {}, +   "outputs": [ +    { +     "data": { +      "text/plain": [ +       "2" +      ] +     }, +     "execution_count": 52, +     "metadata": {}, +     "output_type": "execute_result" +    } +   ], +   "source": [ +    "t.p" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": 19, +   "metadata": {}, +   "outputs": [ +    { +     "data": { +      "text/plain": [ +       "16" +      ] +     }, +     "execution_count": 19, +     "metadata": {}, +     "output_type": "execute_result" +    } +   ], +   "source": [ +    "t.input_dim" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": 20, +   "metadata": {}, +   "outputs": [], +   "source": [ +    "x = torch.rand(16, 16)" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": 21, +   "metadata": {}, +   "outputs": [ +    { +     "data": { +      "text/plain": [ +       "torch.Size([16, 16])" +      ] +     }, +     "execution_count": 21, +     "metadata": {}, +     "output_type": "execute_result" +    } +   ], +   "source": [ +    "x.shape" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": 23, +   "metadata": {}, +   "outputs": [ +    { +     "data": { +      "text/plain": [ +       "T(input_dim=16, hidden=24, h=PosixPath('hej'))" +      ] +     }, +     "execution_count": 23, +     "metadata": {}, +     "output_type": "execute_result" +    } +   ], +   "source": [ +    "t.cuda()" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": 24,     "metadata": {},     "outputs": [],     "source": [ -    "t = T(batch_size=16, num_workers=2, h=\"hej\")" +    "x = x.cuda()" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": 25, +   "metadata": { +    "scrolled": true +   }, +   "outputs": [ +    { +     "data": { +      "text/plain": [ +       "tensor([[ 3.6047e-01,  1.0200e+00,  3.6786e-01,  1.6077e-01,  3.9281e-02,\n", +       "          3.2830e-01,  1.3433e-01, -9.0334e-02, -3.8712e-01,  8.1547e-01,\n", +       "         -5.4483e-01, -9.7471e-01,  3.3706e-01, -9.5283e-01, -1.6271e-01,\n", +       "          3.8504e-01, -5.0106e-01, -4.8638e-01,  3.7033e-01, -4.9557e-01,\n", +       "          2.6555e-01,  5.1245e-01,  6.6751e-01, -2.6291e-01],\n", +       "        [ 1.3811e-01,  7.4522e-01,  4.9935e-01,  3.3878e-01,  1.8501e-01,\n", +       "          2.2269e-02, -2.0328e-01,  1.4629e-01, -2.2957e-01,  4.1197e-01,\n", +       "         -1.9555e-01, -4.7609e-01,  9.0206e-02, -8.8568e-01, -2.1618e-01,\n", +       "          2.8882e-01, -5.4335e-01, -6.6301e-01,  4.9990e-01, -4.0144e-01,\n", +       "          3.6403e-01,  5.3901e-01,  8.6665e-01, -7.8312e-02],\n", +       "        [ 1.6493e-02,  4.6157e-01,  2.9500e-02,  2.4190e-01,  6.5753e-01,\n", +       "          4.3770e-02, -5.3773e-02,  1.8183e-01, -2.5983e-02,  4.1634e-01,\n", +       "         -3.5218e-01, -5.6129e-01,  4.1452e-01, -1.2265e+00, -5.8544e-01,\n", +       "          3.6382e-01, -6.4090e-01, -5.8679e-01,  4.3489e-02, -1.1233e-01,\n", +       "          3.1175e-01,  4.2857e-01,  1.6501e-01, -2.4118e-01],\n", +       "        [ 9.2361e-02,  6.0196e-01,  1.3081e-02, -8.1091e-02,  4.2342e-01,\n", +       "         -8.8457e-02, -8.1851e-02, -1.1562e-01, -1.5049e-01,  4.9972e-01,\n", +       "         -3.0432e-01, -7.8619e-01,  2.1060e-01, -1.0598e+00, -4.6542e-01,\n", +       "          4.2382e-01, -6.5671e-01, -4.8589e-01,  5.5977e-02, -2.9478e-02,\n", +       "          8.5718e-02,  4.7685e-01,  4.8351e-01, -2.8142e-01],\n", +       "        [ 1.3377e-01,  5.4434e-01,  3.4505e-01,  1.1307e-01,  4.4057e-01,\n", +       "         -7.6075e-03,  1.3841e-01, -1.1497e-01, -1.3177e-01,  8.0254e-01,\n", +       "         -3.0627e-01, -6.8437e-01,  1.9035e-01, -1.0208e+00, -1.3259e-01,\n", +       "          5.3231e-01, -4.7814e-01, -5.1266e-01,  2.4646e-02, -3.0552e-01,\n", +       "          2.7398e-01,  5.8269e-01,  6.5481e-01, -4.2041e-01],\n", +       "        [ 1.9604e-01,  4.0597e-01,  1.9071e-01, -2.5535e-01,  1.1915e-01,\n", +       "         -6.7129e-02,  5.4386e-03, -8.2196e-02, -4.2803e-01,  7.0287e-01,\n", +       "         -3.0026e-01, -7.6001e-01, -5.1471e-03, -7.0283e-01, -9.2978e-02,\n", +       "          1.2243e-01, -1.8398e-01, -4.7374e-01,  2.7978e-01, -3.6962e-01,\n", +       "          5.6046e-02,  4.1773e-01,  4.9894e-01, -3.1945e-01],\n", +       "        [ 1.2657e-01,  3.3224e-01,  6.2830e-02,  1.5718e-01,  4.8844e-01,\n", +       "         -1.1476e-01, -1.5044e-01,  2.5265e-02, -2.0351e-01,  5.5770e-01,\n", +       "         -3.6036e-01, -7.4406e-01,  1.6962e-01, -9.6185e-01, -2.9334e-01,\n", +       "          2.2584e-01, -4.1169e-01, -5.2146e-01,  2.3314e-01, -1.3668e-01,\n", +       "         -1.9598e-02,  3.8727e-01,  3.6892e-01, -3.3071e-01],\n", +       "        [ 5.2178e-01,  6.9704e-01,  5.0093e-01,  1.1157e-01,  8.0012e-02,\n", +       "          3.6931e-01, -6.4927e-02,  1.1126e-01, -2.5117e-01,  5.3017e-01,\n", +       "         -2.6488e-01, -8.4056e-01,  2.2374e-01, -6.6831e-01, -1.9402e-01,\n", +       "          7.4174e-02, -4.7763e-01, -2.6912e-01,  5.1009e-01, -5.4239e-01,\n", +       "          3.0123e-01,  3.7529e-01,  4.1625e-01, -2.0141e-01],\n", +       "        [ 3.7968e-01,  4.9387e-01,  3.6786e-01, -1.3131e-01,  2.4445e-02,\n", +       "          2.2155e-01, -4.0087e-02, -1.4872e-01, -5.5030e-01,  6.8958e-01,\n", +       "         -3.8156e-01, -7.5760e-01,  3.2085e-01, -6.4571e-01,  1.1268e-03,\n", +       "          3.4251e-02, -2.6440e-01, -2.6374e-01,  5.9787e-01, -4.6502e-01,\n", +       "          2.0074e-01,  4.5471e-01,  2.4238e-01, -4.3247e-01],\n", +       "        [ 2.9364e-01,  4.8659e-01,  9.0845e-02,  1.6348e-01,  5.7636e-01,\n", +       "          4.5485e-01, -1.6781e-01, -1.4557e-01, -8.8814e-02,  6.6351e-01,\n", +       "         -5.3669e-01, -8.2818e-01,  6.0474e-01, -9.4558e-01, -3.0133e-01,\n", +       "          3.0310e-01, -5.2493e-01, -2.5948e-01,  1.5857e-01, -4.2695e-01,\n", +       "          2.1311e-01,  4.6502e-01,  8.7946e-02, -5.5815e-01],\n", +       "        [ 9.2208e-02,  2.9731e-01,  3.3849e-01, -5.1049e-02,  2.7834e-01,\n", +       "         -1.1120e-01,  1.1835e-01,  1.3665e-01, -2.1291e-01,  3.5107e-01,\n", +       "         -9.8108e-02, -5.0180e-01,  2.9894e-01, -7.7726e-01, -8.1317e-02,\n", +       "          3.5704e-01, -3.6759e-01, -2.2148e-01,  1.1019e-01, -1.4452e-02,\n", +       "          1.5092e-02,  3.3405e-01,  1.2765e-01, -4.0411e-01],\n", +       "        [ 2.8927e-02,  4.4180e-01,  1.0994e-01,  5.6124e-01,  4.7174e-01,\n", +       "          1.9914e-01, -9.5047e-02,  3.1277e-02, -1.8656e-01,  5.0631e-01,\n", +       "         -3.4353e-01, -5.7425e-01,  4.3409e-01, -8.3343e-01, -1.1627e-01,\n", +       "          3.1852e-02, -4.1274e-01, -2.6756e-01,  4.9652e-01, -2.6137e-01,\n", +       "          2.8559e-02,  3.0587e-01,  3.6717e-01, -4.4303e-01],\n", +       "        [-1.0741e-01,  1.3539e-01,  1.5746e-01,  2.1208e-01,  6.3745e-01,\n", +       "         -2.1864e-01, -1.8820e-01,  2.1184e-01, -3.6832e-02,  3.0890e-01,\n", +       "         -2.4719e-03, -3.3573e-01,  1.8479e-01, -9.2119e-01, -2.3361e-01,\n", +       "          8.9827e-02, -5.4372e-01, -4.4935e-01,  3.2967e-01, -9.2807e-02,\n", +       "          9.9241e-02,  4.1705e-01,  2.4728e-01, -4.8119e-01],\n", +       "        [ 2.8125e-01,  5.3276e-01,  5.0110e-02,  2.0471e-01,  5.7750e-01,\n", +       "          4.6670e-02, -2.1400e-01,  6.8794e-03, -6.8737e-02,  4.2138e-01,\n", +       "         -3.1261e-01, -7.3709e-01,  4.2001e-01, -9.9757e-01, -4.8091e-01,\n", +       "          2.9960e-01, -6.2133e-01, -4.0566e-01,  3.2191e-01, -1.0219e-02,\n", +       "          1.2901e-01,  3.9601e-01,  1.6291e-01, -3.3871e-01],\n", +       "        [ 2.9181e-01,  5.5400e-01,  3.0462e-01,  2.2431e-02,  2.8480e-01,\n", +       "          4.4624e-01, -2.8859e-01, -1.4629e-01, -4.3573e-02,  2.9742e-01,\n", +       "         -1.0100e-01, -4.3070e-01,  4.6713e-01, -3.7132e-01, -8.6748e-02,\n", +       "          2.5666e-01, -3.5361e-01, -2.3917e-02,  3.0071e-01, -3.2420e-01,\n", +       "          1.3375e-01,  3.4475e-01,  3.0642e-01, -4.3496e-01],\n", +       "        [-7.7723e-04,  2.3828e-01,  2.3124e-01,  4.1347e-01,  6.8455e-01,\n", +       "         -9.8319e-03,  1.3403e-01,  1.8460e-02, -1.4025e-01,  5.9780e-01,\n", +       "         -3.7015e-01, -5.7865e-01,  4.9211e-01, -1.1262e+00, -2.1693e-01,\n", +       "          3.2002e-01, -2.9313e-01, -3.1941e-01,  9.8446e-02, -6.2767e-02,\n", +       "         -9.8636e-03,  3.5712e-01,  2.8833e-01, -5.3506e-01]], device='cuda:0',\n", +       "       grad_fn=<AddmmBackward>)" +      ] +     }, +     "execution_count": 25, +     "metadata": {}, +     "output_type": "execute_result" +    } +   ], +   "source": [ +    "t(x)"     ]    },    { diff --git a/text_recognizer/callbacks/wandb_callbacks.py b/text_recognizer/callbacks/wandb_callbacks.py index 4186b4a..d9d81f6 100644 --- a/text_recognizer/callbacks/wandb_callbacks.py +++ b/text_recognizer/callbacks/wandb_callbacks.py @@ -93,6 +93,40 @@ class LogTextPredictions(Callback):      def __attrs_pre_init__(self) -> None:          super().__init__() +    def _log_predictions( +        stage: str, trainer: Trainer, pl_module: LightningModule +    ) -> None: +        """Logs the predicted text contained in the images.""" +        if not self.ready: +            return None + +        logger = get_wandb_logger(trainer) +        experiment = logger.experiment + +        # Get a validation batch from the validation dataloader. +        samples = next(iter(trainer.datamodule.val_dataloader())) +        imgs, labels = samples + +        imgs = imgs.to(device=pl_module.device) +        logits = pl_module(imgs) + +        mapping = pl_module.mapping +        experiment.log( +            { +                f"OCR/{experiment.name}/{stage}": [ +                    wandb.Image( +                        img, +                        caption=f"Pred: {mapping.get_text(pred)}, Label: {mapping.get_text(label)}", +                    ) +                    for img, pred, label in zip( +                        imgs[: self.num_samples], +                        logits[: self.num_samples], +                        labels[: self.num_samples], +                    ) +                ] +            } +        ) +      def on_sanity_check_start(          self, trainer: Trainer, pl_module: LightningModule      ) -> None: @@ -107,6 +141,27 @@ class LogTextPredictions(Callback):          self, trainer: Trainer, pl_module: LightningModule      ) -> None:          """Logs predictions on validation epoch end.""" +        self._log_predictions(stage="val", trainer=trainer, pl_module=pl_module) + +    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: +        """Logs predictions on train epoch end.""" +        self._log_predictions(stage="test", trainer=trainer, pl_module=pl_module) + + +@attr.s +class LogReconstuctedImages(Callback): +    """Log reconstructions of images.""" + +    num_samples: int = attr.ib(default=8) +    ready: bool = attr.ib(default=True) + +    def __attrs_pre_init__(self) -> None: +        super().__init__() + +    def _log_reconstruction( +        self, stage: str, trainer: Trainer, pl_module: LightningModule +    ) -> None: +        """Logs the reconstructions."""          if not self.ready:              return None @@ -115,24 +170,42 @@ class LogTextPredictions(Callback):          # Get a validation batch from the validation dataloader.          samples = next(iter(trainer.datamodule.val_dataloader())) -        imgs, labels = samples +        imgs, _ = samples          imgs = imgs.to(device=pl_module.device) -        logits = pl_module(imgs) +        reconstructions = pl_module(imgs) -        mapping = pl_module.mapping          experiment.log(              { -                f"Images/{experiment.name}": [ -                    wandb.Image( -                        img, -                        caption=f"Pred: {mapping.get_text(pred)}, Label: {mapping.get_text(label)}", -                    ) -                    for img, pred, label in zip( +                f"Reconstructions/{experiment.name}/{stage}": [ +                    [ +                        wandb.Image(img), +                        wandb.Image(rec), +                    ] +                    for img, rec in zip(                          imgs[: self.num_samples], -                        logits[: self.num_samples], -                        labels[: self.num_samples], +                        reconstructions[: self.num_samples],                      )                  ]              }          ) + +    def on_sanity_check_start( +        self, trainer: Trainer, pl_module: LightningModule +    ) -> None: +        """Sets ready attribute.""" +        self.ready = False + +    def on_sanity_check_end(self, trainer: Trainer, pl_module: LightningModule) -> None: +        """Start executing this callback only after all validation sanity checks end.""" +        self.ready = True + +    def on_validation_epoch_end( +        self, trainer: Trainer, pl_module: LightningModule +    ) -> None: +        """Logs predictions on validation epoch end.""" +        self._log_reconstruction(stage="val", trainer=trainer, pl_module=pl_module) + +    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: +        """Logs predictions on train epoch end.""" +        self._log_reconstruction(stage="test", trainer=trainer, pl_module=pl_module) diff --git a/text_recognizer/data/base_data_module.py b/text_recognizer/data/base_data_module.py index de5628f..18b1996 100644 --- a/text_recognizer/data/base_data_module.py +++ b/text_recognizer/data/base_data_module.py @@ -1,11 +1,13 @@  """Base lightning DataModule class."""  from pathlib import Path -from typing import Dict +from typing import Any, Dict, Tuple  import attr -import pytorch_lightning as LightningDataModule +from pytorch_lightning import LightningDataModule  from torch.utils.data import DataLoader +from text_recognizer.data.base_dataset import BaseDataset +  def load_and_print_info(data_module_class: type) -> None:      """Load dataset and print dataset information.""" @@ -19,17 +21,20 @@ def load_and_print_info(data_module_class: type) -> None:  class BaseDataModule(LightningDataModule):      """Base PyTorch Lightning DataModule.""" -    batch_size: int = attr.ib(default=16) -    num_workers: int = attr.ib(default=0) -      def __attrs_pre_init__(self) -> None:          super().__init__() -    def __attrs_post_init__(self) -> None: -        # Placeholders for subclasses. -        self.dims = None -        self.output_dims = None -        self.mapping = None +    batch_size: int = attr.ib(default=16) +    num_workers: int = attr.ib(default=0) + +    # Placeholders +    data_train: BaseDataset = attr.ib(init=False, default=None) +    data_val: BaseDataset = attr.ib(init=False, default=None) +    data_test: BaseDataset = attr.ib(init=False, default=None) +    dims: Tuple[int, ...] = attr.ib(init=False, default=None) +    output_dims: Tuple[int, ...] = attr.ib(init=False, default=None) +    mapping: Any = attr.ib(init=False, default=None) +    inverse_mapping: Dict[str, int] = attr.ib(init=False)      @classmethod      def data_dirname(cls) -> Path: @@ -58,9 +63,7 @@ class BaseDataModule(LightningDataModule):              stage (Any): Variable to set splits.          """ -        self.data_train = None -        self.data_val = None -        self.data_test = None +        pass      def train_dataloader(self) -> DataLoader:          """Retun DataLoader for train data.""" diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py index 824b947..d51a42a 100644 --- a/text_recognizer/data/emnist.py +++ b/text_recognizer/data/emnist.py @@ -3,9 +3,10 @@ import json  import os  from pathlib import Path  import shutil -from typing import Dict, List, Optional, Sequence, Tuple +from typing import Callable, Dict, List, Optional, Sequence, Tuple  import zipfile +import attr  import h5py  from loguru import logger  import numpy as np @@ -32,6 +33,7 @@ PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "byclass.h5"  ESSENTIALS_FILENAME = Path(__file__).parents[0].resolve() / "emnist_essentials.json" +@attr.s(auto_attribs=True)  class EMNIST(BaseDataModule):      """Lightning DataModule class for loading EMNIST dataset. @@ -44,18 +46,12 @@ class EMNIST(BaseDataModule):      EMNIST ByClass: 814,255 characters. 62 unbalanced classes.      """ -    def __init__( -        self, batch_size: int = 128, num_workers: int = 0, train_fraction: float = 0.8 -    ) -> None: -        super().__init__(batch_size, num_workers) -        self.train_fraction = train_fraction -        self.mapping, self.inverse_mapping, self.input_shape = emnist_mapping() -        self.data_train = None -        self.data_val = None -        self.data_test = None -        self.transform = T.Compose([T.ToTensor()]) -        self.dims = (1, *self.input_shape) -        self.output_dims = (1,) +    train_fraction: float = attr.ib() +    transform: Callable = attr.ib(init=False, default=T.Compose([T.ToTensor()])) + +    def __attrs_post_init__(self) -> None: +        self.mapping, self.inverse_mapping, input_shape = emnist_mapping() +        self.dims = (1, *input_shape)      def prepare_data(self) -> None:          """Downloads dataset if not present.""" diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py index 9650198..4747508 100644 --- a/text_recognizer/data/emnist_lines.py +++ b/text_recognizer/data/emnist_lines.py @@ -3,6 +3,7 @@ from collections import defaultdict  from pathlib import Path  from typing import Callable, Dict, Tuple +import attr  import h5py  from loguru import logger  import numpy as np @@ -31,31 +32,20 @@ IMAGE_X_PADDING = 28  MAX_OUTPUT_LENGTH = 89  # Same as IAMLines +@attr.s(auto_attribs=True)  class EMNISTLines(BaseDataModule):      """EMNIST Lines dataset: synthetic handwritten lines dataset made from EMNIST,""" -    def __init__( -        self, -        augment: bool = True, -        batch_size: int = 128, -        num_workers: int = 0, -        max_length: int = 32, -        min_overlap: float = 0.0, -        max_overlap: float = 0.33, -        num_train: int = 10_000, -        num_val: int = 2_000, -        num_test: int = 2_000, -    ) -> None: -        super().__init__(batch_size, num_workers) - -        self.augment = augment -        self.max_length = max_length -        self.min_overlap = min_overlap -        self.max_overlap = max_overlap -        self.num_train = num_train -        self.num_val = num_val -        self.num_test = num_test +    augment: bool = attr.ib(default=True) +    max_length: int = attr.ib(default=128) +    min_overlap: float = attr.ib(default=0.0) +    max_overlap: float = attr.ib(default=0.33) +    num_train: int = attr.ib(default=10_000) +    num_val: int = attr.ib(default=2_000) +    num_test: int = attr.ib(default=2_000) +    emnist: EMNIST = attr.ib(init=False, default=None) +    def __attrs_post_init__(self) -> None:          self.emnist = EMNIST()          self.mapping = self.emnist.mapping @@ -75,9 +65,6 @@ class EMNISTLines(BaseDataModule):              raise ValueError("max_length greater than MAX_OUTPUT_LENGTH")          self.output_dims = (MAX_OUTPUT_LENGTH, 1) -        self.data_train: BaseDataset = None -        self.data_val: BaseDataset = None -        self.data_test: BaseDataset = None      @property      def data_filename(self) -> Path: diff --git a/text_recognizer/data/iam.py b/text_recognizer/data/iam.py index 261c8d3..3982c4f 100644 --- a/text_recognizer/data/iam.py +++ b/text_recognizer/data/iam.py @@ -5,6 +5,7 @@ from typing import Any, Dict, List  import xml.etree.ElementTree as ElementTree  import zipfile +import attr  from boltons.cacheutils import cachedproperty  from loguru import logger  import toml @@ -22,6 +23,7 @@ DOWNSAMPLE_FACTOR = 2  # If images were downsampled, the regions must also be.  LINE_REGION_PADDING = 16  # Add this many pixels around the exact coordinates. +@attr.s(auto_attribs=True)  class IAM(BaseDataModule):      """      "The IAM Lines dataset, first published at the ICDAR 1999, contains forms of unconstrained handwritten text, @@ -35,9 +37,7 @@ class IAM(BaseDataModule):          The text lines of all data sets are mutually exclusive, thus each writer has contributed to one set only.      """ -    def __init__(self, batch_size: int = 128, num_workers: int = 0) -> None: -        super().__init__(batch_size, num_workers) -        self.metadata = toml.load(METADATA_FILENAME) +    metadata: Dict = attr.ib(init=False, default=toml.load(METADATA_FILENAME))      def prepare_data(self) -> None:          if self.xml_filenames: diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py index 0a30a42..886e37e 100644 --- a/text_recognizer/data/iam_extended_paragraphs.py +++ b/text_recognizer/data/iam_extended_paragraphs.py @@ -1,4 +1,7 @@  """IAM original and sythetic dataset class.""" +from typing import Dict, List + +import attr  from torch.utils.data import ConcatDataset  from text_recognizer.data.base_dataset import BaseDataset @@ -7,22 +10,26 @@ from text_recognizer.data.iam_paragraphs import IAMParagraphs  from text_recognizer.data.iam_synthetic_paragraphs import IAMSyntheticParagraphs +@attr.s(auto_attribs=True)  class IAMExtendedParagraphs(BaseDataModule): -    def __init__( -        self, -        batch_size: int = 16, -        num_workers: int = 0, -        train_fraction: float = 0.8, -        augment: bool = True, -        word_pieces: bool = False, -    ) -> None: -        super().__init__(batch_size, num_workers) +    train_fraction: float = attr.ib() +    word_pieces: bool = attr.ib(default=False) + +    def __attrs_post_init__(self) -> None:          self.iam_paragraphs = IAMParagraphs( -            batch_size, num_workers, train_fraction, augment, word_pieces, +            self.batch_size, +            self.num_workers, +            self.train_fraction, +            self.augment, +            self.word_pieces,          )          self.iam_synthetic_paragraphs = IAMSyntheticParagraphs( -            batch_size, num_workers, train_fraction, augment, word_pieces, +            self.batch_size, +            self.num_workers, +            self.train_fraction, +            self.augment, +            self.word_pieces,          )          self.dims = self.iam_paragraphs.dims @@ -30,10 +37,6 @@ class IAMExtendedParagraphs(BaseDataModule):          self.mapping = self.iam_paragraphs.mapping          self.inverse_mapping = self.iam_paragraphs.inverse_mapping -        self.data_train: BaseDataset = None -        self.data_val: BaseDataset = None -        self.data_test: BaseDataset = None -      def prepare_data(self) -> None:          """Prepares the paragraphs data."""          self.iam_paragraphs.prepare_data() diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py index 9c78a22..e45e5c8 100644 --- a/text_recognizer/data/iam_lines.py +++ b/text_recognizer/data/iam_lines.py @@ -7,8 +7,9 @@ dataset.  import json  from pathlib import Path  import random -from typing import List, Sequence, Tuple +from typing import Dict, List, Sequence, Tuple +import attr  from loguru import logger  from PIL import Image, ImageFile, ImageOps  import numpy as np @@ -35,26 +36,17 @@ IMAGE_HEIGHT = 56  IMAGE_WIDTH = 1024 +@attr.s(auto_attribs=True)  class IAMLines(BaseDataModule):      """IAM handwritten lines dataset.""" -    def __init__( -        self, -        augment: bool = True, -        fraction: float = 0.8, -        batch_size: int = 128, -        num_workers: int = 0, -    ) -> None: -        # TODO: add transforms -        super().__init__(batch_size, num_workers) -        self.augment = augment -        self.fraction = fraction +    augment: bool = attr.ib(default=True) +    fraction: float = attr.ib(default=0.8) + +    def __attrs_post_init__(self) -> None:          self.mapping, self.inverse_mapping, _ = emnist_mapping()          self.dims = (1, IMAGE_HEIGHT, IMAGE_WIDTH)          self.output_dims = (89, 1) -        self.data_train: BaseDataset = None -        self.data_val: BaseDataset = None -        self.data_test: BaseDataset = None      def prepare_data(self) -> None:          """Creates the IAM lines dataset if not existing.""" diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py index fe60e99..445b788 100644 --- a/text_recognizer/data/iam_paragraphs.py +++ b/text_recognizer/data/iam_paragraphs.py @@ -3,6 +3,7 @@ import json  from pathlib import Path  from typing import Dict, List, Optional, Sequence, Tuple +import attr  from loguru import logger  import numpy as np  from PIL import Image, ImageOps @@ -33,33 +34,25 @@ IMAGE_WIDTH = 1280 // IMAGE_SCALE_FACTOR  MAX_LABEL_LENGTH = 682 +@attr.s(auto_attribs=True)  class IAMParagraphs(BaseDataModule):      """IAM handwriting database paragraphs.""" -    def __init__( -        self, -        batch_size: int = 16, -        num_workers: int = 0, -        train_fraction: float = 0.8, -        augment: bool = True, -        word_pieces: bool = False, -    ) -> None: -        super().__init__(batch_size, num_workers) -        self.augment = augment -        self.word_pieces = word_pieces +    augment: bool = attr.ib(default=True) +    train_fraction: float = attr.ib(default=0.8) +    word_pieces: bool = attr.ib(default=False) + +    def __attrs_post_init__(self) -> None:          self.mapping, self.inverse_mapping, _ = emnist_mapping(              extra_symbols=[NEW_LINE_TOKEN]          ) -        if word_pieces: +        if self.word_pieces:              self.mapping = WordPieceMapping()          self.train_fraction = train_fraction          self.dims = (1, IMAGE_HEIGHT, IMAGE_WIDTH)          self.output_dims = (MAX_LABEL_LENGTH, 1) -        self.data_train: BaseDataset = None -        self.data_val: BaseDataset = None -        self.data_test: BaseDataset = None      def prepare_data(self) -> None:          """Create data for training/testing.""" @@ -166,7 +159,10 @@ def get_dataset_properties() -> Dict:              "min": min(_get_property_values("num_lines")),              "max": max(_get_property_values("num_lines")),          }, -        "crop_shape": {"min": crop_shapes.min(axis=0), "max": crop_shapes.max(axis=0),}, +        "crop_shape": { +            "min": crop_shapes.min(axis=0), +            "max": crop_shapes.max(axis=0), +        },          "aspect_ratio": {              "min": aspect_ratio.min(axis=0),              "max": aspect_ratio.max(axis=0), @@ -287,7 +283,9 @@ def get_transform(image_shape: Tuple[int, int], augment: bool) -> T.Compose:              ),              T.ColorJitter(brightness=(0.8, 1.6)),              T.RandomAffine( -                degrees=1, shear=(-10, 10), interpolation=InterpolationMode.BILINEAR, +                degrees=1, +                shear=(-10, 10), +                interpolation=InterpolationMode.BILINEAR,              ),          ]      else: diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index 8dc7a36..f95df0f 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -5,7 +5,7 @@ import attr  import hydra  import loguru.logger as log  from omegaconf import DictConfig -import pytorch_lightning as pl +import pytorch_lightning as LightningModule  import torch  from torch import nn  from torch import Tensor @@ -13,7 +13,7 @@ import torchmetrics  @attr.s -class BaseLitModel(pl.LightningModule): +class BaseLitModel(LightningModule):      """Abstract PyTorch Lightning class."""      network: Type[nn.Module] = attr.ib() @@ -80,7 +80,6 @@ class BaseLitModel(pl.LightningModule):          """Configures optimizer and lr scheduler."""          optimizer = self._configure_optimizer()          scheduler = self._configure_lr_scheduler(optimizer) -          return [optimizer], [scheduler]      def forward(self, data: Tensor) -> Tensor: diff --git a/text_recognizer/models/metrics.py b/text_recognizer/models/metrics.py index 58d0537..4117ae2 100644 --- a/text_recognizer/models/metrics.py +++ b/text_recognizer/models/metrics.py @@ -1,18 +1,23 @@  """Character Error Rate (CER).""" -from typing import Sequence +from typing import Set, Sequence +import attr  import editdistance  import torch  from torch import Tensor -import torchmetrics +from torchmetrics import Metric -class CharacterErrorRate(torchmetrics.Metric): +@attr.s +class CharacterErrorRate(Metric):      """Character error rate metric, computed using Levenshtein distance.""" -    def __init__(self, ignore_tokens: Sequence[int], *args) -> None: +    ignore_tokens: Set = attr.ib(converter=set) +    error: Tensor = attr.ib(init=False) +    total: Tensor = attr.ib(init=False) + +    def __attrs_post_init__(self) -> None:          super().__init__() -        self.ignore_tokens = set(ignore_tokens)          self.add_state("error", default=torch.tensor(0.0), dist_reduce_fx="sum")          self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py index ea54d83..8c9fe8a 100644 --- a/text_recognizer/models/transformer.py +++ b/text_recognizer/models/transformer.py @@ -2,35 +2,24 @@  from typing import Dict, List, Optional, Union, Tuple, Type  import attr +import hydra  from omegaconf import DictConfig  from torch import nn, Tensor  from text_recognizer.data.emnist import emnist_mapping  from text_recognizer.data.mappings import AbstractMapping  from text_recognizer.models.metrics import CharacterErrorRate -from text_recognizer.models.base import LitBaseModel +from text_recognizer.models.base import BaseLitModel -@attr.s -class TransformerLitModel(LitBaseModel): +@attr.s(auto_attribs=True) +class TransformerLitModel(BaseLitModel):      """A PyTorch Lightning model for transformer networks.""" -    network: Type[nn.Module] = attr.ib() -    criterion_config: DictConfig = attr.ib(converter=DictConfig) -    optimizer_config: DictConfig = attr.ib(converter=DictConfig) -    lr_scheduler_config: DictConfig = attr.ib(converter=DictConfig) -    monitor: str = attr.ib() -    mapping: Type[AbstractMapping] = attr.ib() +    mapping_config: DictConfig = attr.ib(converter=DictConfig)      def __attrs_post_init__(self) -> None: -        super().__init__( -            network=self.network, -            optimizer_config=self.optimizer_config, -            lr_scheduler_config=self.lr_scheduler_config, -            criterion_config=self.criterion_config, -            monitor=self.monitor, -        ) -        self.mapping, ignore_tokens = self.configure_mapping(mapping) +        self.mapping, ignore_tokens = self._configure_mapping()          self.val_cer = CharacterErrorRate(ignore_tokens)          self.test_cer = CharacterErrorRate(ignore_tokens) @@ -39,9 +28,10 @@ class TransformerLitModel(LitBaseModel):          return self.network.predict(data)      @staticmethod -    def configure_mapping(mapping: Optional[List[str]]) -> Tuple[List[str], List[int]]: +    def _configure_mapping() -> Tuple[Type[AbstractMapping], List[int]]:          """Configure mapping."""          # TODO: Fix me!!! +        # Load config with hydra          mapping, inverse_mapping, _ = emnist_mapping(["\n"])          start_index = inverse_mapping["<s>"]          end_index = inverse_mapping["<e>"] diff --git a/text_recognizer/models/vqvae.py b/text_recognizer/models/vqvae.py index 7dc950f..0172163 100644 --- a/text_recognizer/models/vqvae.py +++ b/text_recognizer/models/vqvae.py @@ -1,49 +1,23 @@  """PyTorch Lightning model for base Transformers."""  from typing import Any, Dict, Union, Tuple, Type +import attr  from omegaconf import DictConfig  from torch import nn  from torch import Tensor  import wandb -from text_recognizer.models.base import LitBaseModel +from text_recognizer.models.base import BaseLitModel -class LitVQVAEModel(LitBaseModel): +@attr.s(auto_attribs=True) +class VQVAELitModel(BaseLitModel):      """A PyTorch Lightning model for transformer networks.""" -    def __init__( -        self, -        network: Type[nn.Module], -        optimizer: Union[DictConfig, Dict], -        lr_scheduler: Union[DictConfig, Dict], -        criterion: Union[DictConfig, Dict], -        monitor: str = "val/loss", -        *args: Any, -        **kwargs: Dict, -    ) -> None: -        super().__init__(network, optimizer, lr_scheduler, criterion, monitor) -      def forward(self, data: Tensor) -> Tensor:          """Forward pass with the transformer network."""          return self.network.predict(data) -    def _log_prediction( -        self, data: Tensor, reconstructions: Tensor, title: str -    ) -> None: -        """Logs prediction on image with wandb.""" -        try: -            self.logger.experiment.log( -                { -                    title: [ -                        wandb.Image(data[0]), -                        wandb.Image(reconstructions[0]), -                    ] -                } -            ) -        except AttributeError: -            pass -      def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:          """Training step."""          data, _ = batch diff --git a/text_recognizer/networks/util.py b/text_recognizer/networks/util.py index 109bf4d..85094f1 100644 --- a/text_recognizer/networks/util.py +++ b/text_recognizer/networks/util.py @@ -1,4 +1,4 @@ -"""Miscellaneous neural network functionality.""" +"""Miscellaneous neural network utility functionality."""  from typing import Type  from torch import nn diff --git a/training/conf/datamodule/iam_extended_paragraphs.yaml b/training/conf/datamodule/iam_extended_paragraphs.yaml new file mode 100644 index 0000000..3070b56 --- /dev/null +++ b/training/conf/datamodule/iam_extended_paragraphs.yaml @@ -0,0 +1,5 @@ +_target_: text_recognizer.data.iam_extended_paragraphs.IAMExtendedParagraphs +batch_size: 32 +num_workers: 12 +train_fraction: 0.8 +augment: true diff --git a/training/conf/dataset/iam_extended_paragraphs.yaml b/training/conf/dataset/iam_extended_paragraphs.yaml deleted file mode 100644 index 6439a15..0000000 --- a/training/conf/dataset/iam_extended_paragraphs.yaml +++ /dev/null @@ -1,6 +0,0 @@ -type: IAMExtendedParagraphs -args: -  batch_size: 32 -  num_workers: 12 -  train_fraction: 0.8 -  augment: true diff --git a/training/conf/lr_scheduler/one_cycle.yaml b/training/conf/lr_scheduler/one_cycle.yaml index 60a6f27..e8cb5c4 100644 --- a/training/conf/lr_scheduler/one_cycle.yaml +++ b/training/conf/lr_scheduler/one_cycle.yaml @@ -1,8 +1,15 @@ -type: OneCycleLR -args: -  interval: step -  max_lr: 1.0e-3 -  three_phase: true -  epochs: 64 -  steps_per_epoch: 633 # num_samples / batch_size -monitor: val_loss +_target_: torch.optim.lr_scheduler.OneCycleLR +max_lr: 1.0e-3 +total_steps: None +epochs: None +steps_per_epoch: None +pct_start: 0.3 +anneal_strategy: 'cos' +cycle_momentum: True +base_momentum: 0.85 +max_momentum: 0.95 +div_factor: 25.0 +final_div_factor: 10000.0 +three_phase: true +last_epoch: -1 +verbose: false diff --git a/training/conf/model/lit_vqvae.yaml b/training/conf/model/lit_vqvae.yaml index 7136dbd..6be37e5 100644 --- a/training/conf/model/lit_vqvae.yaml +++ b/training/conf/model/lit_vqvae.yaml @@ -1,3 +1,3 @@ -type: LitVQVAEModel +_target_: text_recognizer.models.vqvae.VQVAELitModel  args:    mapping: sentence_piece diff --git a/training/run.py b/training/run.py index 5f7c927..31da666 100644 --- a/training/run.py +++ b/training/run.py @@ -51,6 +51,7 @@ def run(config: DictConfig) -> Optional[float]:      )      # Log hyperparameters +    log.info("Logging hyperparameters")      utils.log_hyperparameters(config=config, model=model, trainer=trainer)      if config.debug: diff --git a/training/utils.py b/training/utils.py index 4c31dc3..140d97e 100644 --- a/training/utils.py +++ b/training/utils.py @@ -1,4 +1,4 @@ -"""Util functions for training hydra configs and pytorch lightning.""" +"""Util functions for training with hydra and pytorch lightning."""  from typing import Any, List, Type  import warnings |