From 263f2b7158d76bc0adad45309625910c0fa7b1fe Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Fri, 6 Aug 2021 14:19:37 +0200
Subject: Remove lr args from model, add Cosine lr, fix to vqvae stack

---
 notebooks/05c-test-model-end-to-end.ipynb        | 119 +++++++++++++++++++++--
 text_recognizer/models/base.py                   |  12 ++-
 text_recognizer/models/vqvae.py                  |   3 +
 training/conf/experiment/vqvae.yaml              |  11 ++-
 training/conf/lr_scheduler/cosine_annealing.yaml |   7 ++
 training/conf/lr_scheduler/one_cycle.yaml        |   4 +
 training/conf/model/lit_vqvae.yaml               |   4 +-
 training/conf/network/decoder/vae_decoder.yaml   |   2 +-
 training/conf/network/encoder/vae_encoder.yaml   |   2 +-
 training/conf/network/vqvae.yaml                 |   2 +-
 training/conf/optimizer/madgrad.yaml             |   2 +-
 11 files changed, 143 insertions(+), 25 deletions(-)
 create mode 100644 training/conf/lr_scheduler/cosine_annealing.yaml

diff --git a/notebooks/05c-test-model-end-to-end.ipynb b/notebooks/05c-test-model-end-to-end.ipynb
index 7996257..b26a1fe 100644
--- a/notebooks/05c-test-model-end-to-end.ipynb
+++ b/notebooks/05c-test-model-end-to-end.ipynb
@@ -2,10 +2,19 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 4,
    "id": "1e40a88b",
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "The autoreload extension is already loaded. To reload it, use:\n",
+      "  %reload_ext autoreload\n"
+     ]
+    }
+   ],
    "source": [
     "%load_ext autoreload\n",
     "%autoreload 2\n",
@@ -25,7 +34,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 5,
    "id": "38fb3d9d-a163-4b72-981f-f31b51be39f2",
    "metadata": {},
    "outputs": [],
@@ -37,10 +46,46 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 10,
    "id": "74780b21-3313-452b-b580-703cac878416",
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "encoder:\n",
+      "  _target_: text_recognizer.networks.vqvae.encoder.Encoder\n",
+      "  in_channels: 1\n",
+      "  hidden_dim: 32\n",
+      "  channels_multipliers:\n",
+      "  - 1\n",
+      "  - 2\n",
+      "  - 4\n",
+      "  - 4\n",
+      "  - 4\n",
+      "  dropout_rate: 0.25\n",
+      "decoder:\n",
+      "  _target_: text_recognizer.networks.vqvae.decoder.Decoder\n",
+      "  out_channels: 1\n",
+      "  hidden_dim: 32\n",
+      "  channels_multipliers:\n",
+      "  - 4\n",
+      "  - 4\n",
+      "  - 4\n",
+      "  - 2\n",
+      "  - 1\n",
+      "  dropout_rate: 0.25\n",
+      "_target_: text_recognizer.networks.vqvae.vqvae.VQVAE\n",
+      "hidden_dim: 128\n",
+      "embedding_dim: 32\n",
+      "num_embeddings: 1024\n",
+      "decay: 0.99\n",
+      "\n",
+      "{'encoder': {'_target_': 'text_recognizer.networks.vqvae.encoder.Encoder', 'in_channels': 1, 'hidden_dim': 32, 'channels_multipliers': [1, 2, 4, 4, 4], 'dropout_rate': 0.25}, 'decoder': {'_target_': 'text_recognizer.networks.vqvae.decoder.Decoder', 'out_channels': 1, 'hidden_dim': 32, 'channels_multipliers': [4, 4, 4, 2, 1], 'dropout_rate': 0.25}, '_target_': 'text_recognizer.networks.vqvae.vqvae.VQVAE', 'hidden_dim': 128, 'embedding_dim': 32, 'num_embeddings': 1024, 'decay': 0.99}\n"
+     ]
+    }
+   ],
    "source": [
     "# context initialization\n",
     "with initialize(config_path=\"../training/conf/network/\", job_name=\"test_app\"):\n",
@@ -51,7 +96,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 11,
    "id": "205a03e8-7aa1-407f-afa5-92693715b677",
    "metadata": {},
    "outputs": [],
@@ -61,7 +106,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 12,
    "id": "c74384f0-754e-4c29-8f06-339372d6e4c1",
    "metadata": {},
    "outputs": [],
@@ -71,10 +116,66 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 13,
    "id": "5ebab599-2497-42f8-b54b-1663ee66fde9",
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "==========================================================================================\n",
+      "Layer (type:depth-idx)                   Output Shape              Param #\n",
+      "==========================================================================================\n",
+      "├─Encoder: 1-1                           [-1, 128, 18, 20]         --\n",
+      "|    └─Sequential: 2-1                   [-1, 128, 18, 20]         --\n",
+      "|    |    └─Conv2d: 3-1                  [-1, 32, 576, 640]        320\n",
+      "|    |    └─Conv2d: 3-2                  [-1, 32, 288, 320]        16,416\n",
+      "|    |    └─Mish: 3-3                    [-1, 32, 288, 320]        --\n",
+      "|    |    └─Conv2d: 3-4                  [-1, 64, 144, 160]        32,832\n",
+      "|    |    └─Mish: 3-5                    [-1, 64, 144, 160]        --\n",
+      "|    |    └─Conv2d: 3-6                  [-1, 128, 72, 80]         131,200\n",
+      "|    |    └─Mish: 3-7                    [-1, 128, 72, 80]         --\n",
+      "|    |    └─Conv2d: 3-8                  [-1, 128, 36, 40]         262,272\n",
+      "|    |    └─Mish: 3-9                    [-1, 128, 36, 40]         --\n",
+      "|    |    └─Conv2d: 3-10                 [-1, 128, 18, 20]         262,272\n",
+      "|    |    └─Mish: 3-11                   [-1, 128, 18, 20]         --\n",
+      "|    |    └─Residual: 3-12               [-1, 128, 18, 20]         164,352\n",
+      "|    |    └─Residual: 3-13               [-1, 128, 18, 20]         164,352\n",
+      "├─Conv2d: 1-2                            [-1, 32, 18, 20]          4,128\n",
+      "├─VectorQuantizer: 1-3                   [-1, 32, 18, 20]          --\n",
+      "├─Conv2d: 1-4                            [-1, 128, 18, 20]         4,224\n",
+      "├─Decoder: 1-5                           [-1, 1, 576, 640]         --\n",
+      "|    └─Sequential: 2-2                   [-1, 1, 576, 640]         --\n",
+      "|    |    └─Residual: 3-14               [-1, 128, 18, 20]         164,352\n",
+      "|    |    └─Residual: 3-15               [-1, 128, 18, 20]         164,352\n",
+      "|    |    └─ConvTranspose2d: 3-16        [-1, 128, 36, 40]         262,272\n",
+      "|    |    └─Mish: 3-17                   [-1, 128, 36, 40]         --\n",
+      "|    |    └─ConvTranspose2d: 3-18        [-1, 128, 72, 80]         262,272\n",
+      "|    |    └─Mish: 3-19                   [-1, 128, 72, 80]         --\n",
+      "|    |    └─ConvTranspose2d: 3-20        [-1, 64, 144, 160]        131,136\n",
+      "|    |    └─Mish: 3-21                   [-1, 64, 144, 160]        --\n",
+      "|    |    └─ConvTranspose2d: 3-22        [-1, 32, 288, 320]        32,800\n",
+      "|    |    └─Mish: 3-23                   [-1, 32, 288, 320]        --\n",
+      "|    |    └─ConvTranspose2d: 3-24        [-1, 32, 576, 640]        16,416\n",
+      "|    |    └─Mish: 3-25                   [-1, 32, 576, 640]        --\n",
+      "|    |    └─Normalize: 3-26              [-1, 32, 576, 640]        64\n",
+      "|    |    └─Mish: 3-27                   [-1, 32, 576, 640]        --\n",
+      "|    |    └─Conv2d: 3-28                 [-1, 1, 576, 640]         289\n",
+      "==========================================================================================\n",
+      "Total params: 2,076,321\n",
+      "Trainable params: 2,076,321\n",
+      "Non-trainable params: 0\n",
+      "Total mult-adds (G): 17.68\n",
+      "==========================================================================================\n",
+      "Input size (MB): 1.41\n",
+      "Forward/backward pass size (MB): 355.17\n",
+      "Params size (MB): 7.92\n",
+      "Estimated Total Size (MB): 364.49\n",
+      "==========================================================================================\n"
+     ]
+    }
+   ],
    "source": [
     "summary(net, (1, 576, 640), device=\"cpu\");"
    ]
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index 57c5964..ab3fa35 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -26,8 +26,6 @@ class BaseLitModel(LightningModule):
     loss_fn: Type[nn.Module] = attr.ib()
     optimizer_config: DictConfig = attr.ib()
     lr_scheduler_config: DictConfig = attr.ib()
-    interval: str = attr.ib()
-    monitor: str = attr.ib(default="val/loss")
     train_acc: torchmetrics.Accuracy = attr.ib(
         init=False, default=torchmetrics.Accuracy()
     )
@@ -58,12 +56,18 @@ class BaseLitModel(LightningModule):
         self, optimizer: Type[torch.optim.Optimizer]
     ) -> Dict[str, Any]:
         """Configures the lr scheduler."""
+        # Extract non-class arguments.
+        monitor = self.lr_scheduler_config.monitor
+        interval = self.lr_scheduler_config.interval
+        del self.lr_scheduler_config.monitor
+        del self.lr_scheduler_config.interval
+
         log.info(
             f"Instantiating learning rate scheduler <{self.lr_scheduler_config._target_}>"
         )
         scheduler = {
-            "monitor": self.monitor,
-            "interval": self.interval,
+            "monitor": monitor,
+            "interval": interval,
             "scheduler": hydra.utils.instantiate(
                 self.lr_scheduler_config, optimizer=optimizer
             ),
diff --git a/text_recognizer/models/vqvae.py b/text_recognizer/models/vqvae.py
index 7f79b78..76b7ba6 100644
--- a/text_recognizer/models/vqvae.py
+++ b/text_recognizer/models/vqvae.py
@@ -23,6 +23,7 @@ class VQVAELitModel(BaseLitModel):
         reconstructions, vq_loss = self(data)
         loss = self.loss_fn(reconstructions, data)
         loss = loss + self.latent_loss_weight * vq_loss
+        self.log("train/vq_loss", vq_loss)
         self.log("train/loss", loss)
         return loss
 
@@ -32,6 +33,7 @@ class VQVAELitModel(BaseLitModel):
         reconstructions, vq_loss = self(data)
         loss = self.loss_fn(reconstructions, data)
         loss = loss + self.latent_loss_weight * vq_loss
+        self.log("val/vq_loss", vq_loss)
         self.log("val/loss", loss, prog_bar=True)
 
     def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
@@ -40,4 +42,5 @@ class VQVAELitModel(BaseLitModel):
         reconstructions, vq_loss = self(data)
         loss = self.loss_fn(reconstructions, data)
         loss = loss + self.latent_loss_weight * vq_loss
+        self.log("test/vq_loss", vq_loss)
         self.log("test/loss", loss)
diff --git a/training/conf/experiment/vqvae.yaml b/training/conf/experiment/vqvae.yaml
index eb40f3b..7a9e643 100644
--- a/training/conf/experiment/vqvae.yaml
+++ b/training/conf/experiment/vqvae.yaml
@@ -5,6 +5,7 @@ defaults:
   - override /criterion: mse
   - override /model: lit_vqvae
   - override /callbacks: wandb_vae
+  - override /lr_scheduler: cosine_annealing
 
 trainer:
   max_epochs: 64
@@ -13,11 +14,11 @@ trainer:
 datamodule:
   batch_size: 16
 
-lr_scheduler:
-  epochs: 64
-  steps_per_epoch: 1245
+# lr_scheduler:
+  # epochs: 64
+  # steps_per_epoch: 1245
 
-optimizer:
-  lr: 1.0e-3
+# optimizer:
+  # lr: 1.0e-3
 
 summary: [1, 576, 640]
diff --git a/training/conf/lr_scheduler/cosine_annealing.yaml b/training/conf/lr_scheduler/cosine_annealing.yaml
new file mode 100644
index 0000000..62667bb
--- /dev/null
+++ b/training/conf/lr_scheduler/cosine_annealing.yaml
@@ -0,0 +1,7 @@
+_target_: torch.optim.lr_scheduler.CosineAnnealingLR
+T_max: 64
+eta_min: 0.0
+last_epoch: -1
+
+interval: epoch
+monitor: val/loss
diff --git a/training/conf/lr_scheduler/one_cycle.yaml b/training/conf/lr_scheduler/one_cycle.yaml
index eecee8a..fb5987a 100644
--- a/training/conf/lr_scheduler/one_cycle.yaml
+++ b/training/conf/lr_scheduler/one_cycle.yaml
@@ -13,3 +13,7 @@ final_div_factor: 10000.0
 three_phase: true
 last_epoch: -1
 verbose: false
+
+# Non-class arguments
+interval: step
+monitor: val/loss
diff --git a/training/conf/model/lit_vqvae.yaml b/training/conf/model/lit_vqvae.yaml
index 409fa0d..632668b 100644
--- a/training/conf/model/lit_vqvae.yaml
+++ b/training/conf/model/lit_vqvae.yaml
@@ -1,4 +1,2 @@
 _target_: text_recognizer.models.vqvae.VQVAELitModel
-interval: step
-monitor: val/loss
-latent_loss_weight: 1.0
+latent_loss_weight: 0.25
diff --git a/training/conf/network/decoder/vae_decoder.yaml b/training/conf/network/decoder/vae_decoder.yaml
index b2090b3..0a36a54 100644
--- a/training/conf/network/decoder/vae_decoder.yaml
+++ b/training/conf/network/decoder/vae_decoder.yaml
@@ -1,5 +1,5 @@
 _target_: text_recognizer.networks.vqvae.decoder.Decoder
 out_channels: 1 
 hidden_dim: 32
-channels_multipliers: [8, 6, 2, 1]
+channels_multipliers: [4, 4, 2, 1]
 dropout_rate: 0.25
diff --git a/training/conf/network/encoder/vae_encoder.yaml b/training/conf/network/encoder/vae_encoder.yaml
index 5dc6814..dacd389 100644
--- a/training/conf/network/encoder/vae_encoder.yaml
+++ b/training/conf/network/encoder/vae_encoder.yaml
@@ -1,5 +1,5 @@
 _target_: text_recognizer.networks.vqvae.encoder.Encoder
 in_channels: 1 
 hidden_dim: 32
-channels_multipliers: [1, 2, 6, 8]
+channels_multipliers: [1, 2, 4, 4]
 dropout_rate: 0.25
diff --git a/training/conf/network/vqvae.yaml b/training/conf/network/vqvae.yaml
index 835d0b7..d97e9b6 100644
--- a/training/conf/network/vqvae.yaml
+++ b/training/conf/network/vqvae.yaml
@@ -3,7 +3,7 @@ defaults:
   - decoder: vae_decoder
 
 _target_: text_recognizer.networks.vqvae.vqvae.VQVAE
-hidden_dim: 256
+hidden_dim: 128
 embedding_dim: 32
 num_embeddings: 1024
 decay: 0.99
diff --git a/training/conf/optimizer/madgrad.yaml b/training/conf/optimizer/madgrad.yaml
index 46b2fff..458b116 100644
--- a/training/conf/optimizer/madgrad.yaml
+++ b/training/conf/optimizer/madgrad.yaml
@@ -1,5 +1,5 @@
 _target_: madgrad.MADGRAD
-lr: 2.0e-4
+lr: 3.0e-4
 momentum: 0.9
 weight_decay: 0
 eps: 1.0e-6
-- 
cgit v1.2.3-70-g09d2