-rw-r--r--  notebooks/05c-test-model-end-to-end.ipynb         | 119
-rw-r--r--  text_recognizer/models/base.py                    |  12
-rw-r--r--  text_recognizer/models/vqvae.py                   |   3
-rw-r--r--  training/conf/experiment/vqvae.yaml               |  11
-rw-r--r--  training/conf/lr_scheduler/cosine_annealing.yaml  |   7
-rw-r--r--  training/conf/lr_scheduler/one_cycle.yaml         |   4
-rw-r--r--  training/conf/model/lit_vqvae.yaml                |   4
-rw-r--r--  training/conf/network/decoder/vae_decoder.yaml    |   2
-rw-r--r--  training/conf/network/encoder/vae_encoder.yaml    |   2
-rw-r--r--  training/conf/network/vqvae.yaml                  |   2
-rw-r--r--  training/conf/optimizer/madgrad.yaml              |   2
11 files changed, 143 insertions, 25 deletions
diff --git a/notebooks/05c-test-model-end-to-end.ipynb b/notebooks/05c-test-model-end-to-end.ipynb
index 7996257..b26a1fe 100644
--- a/notebooks/05c-test-model-end-to-end.ipynb
+++ b/notebooks/05c-test-model-end-to-end.ipynb
@@ -2,10 +2,19 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 4,
    "id": "1e40a88b",
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "The autoreload extension is already loaded. To reload it, use:\n",
+      "  %reload_ext autoreload\n"
+     ]
+    }
+   ],
    "source": [
     "%load_ext autoreload\n",
     "%autoreload 2\n",
@@ -25,7 +34,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 5,
    "id": "38fb3d9d-a163-4b72-981f-f31b51be39f2",
    "metadata": {},
    "outputs": [],
@@ -37,10 +46,46 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 10,
    "id": "74780b21-3313-452b-b580-703cac878416",
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "encoder:\n",
+      "  _target_: text_recognizer.networks.vqvae.encoder.Encoder\n",
+      "  in_channels: 1\n",
+      "  hidden_dim: 32\n",
+      "  channels_multipliers:\n",
+      "  - 1\n",
+      "  - 2\n",
+      "  - 4\n",
+      "  - 4\n",
+      "  - 4\n",
+      "  dropout_rate: 0.25\n",
+      "decoder:\n",
+      "  _target_: text_recognizer.networks.vqvae.decoder.Decoder\n",
+      "  out_channels: 1\n",
+      "  hidden_dim: 32\n",
+      "  channels_multipliers:\n",
+      "  - 4\n",
+      "  - 4\n",
+      "  - 4\n",
+      "  - 2\n",
+      "  - 1\n",
+      "  dropout_rate: 0.25\n",
+      "_target_: text_recognizer.networks.vqvae.vqvae.VQVAE\n",
+      "hidden_dim: 128\n",
+      "embedding_dim: 32\n",
+      "num_embeddings: 1024\n",
+      "decay: 0.99\n",
+      "\n",
+      "{'encoder': {'_target_': 'text_recognizer.networks.vqvae.encoder.Encoder', 'in_channels': 1, 'hidden_dim': 32, 'channels_multipliers': [1, 2, 4, 4, 4], 'dropout_rate': 0.25}, 'decoder': {'_target_': 'text_recognizer.networks.vqvae.decoder.Decoder', 'out_channels': 1, 'hidden_dim': 32, 'channels_multipliers': [4, 4, 4, 2, 1], 'dropout_rate': 0.25}, '_target_': 'text_recognizer.networks.vqvae.vqvae.VQVAE', 'hidden_dim': 128, 'embedding_dim': 32, 'num_embeddings': 1024, 'decay': 0.99}\n"
+     ]
+    }
+   ],
    "source": [
     "# context initialization\n",
     "with initialize(config_path=\"../training/conf/network/\", job_name=\"test_app\"):\n",
@@ -51,7 +96,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 11,
    "id": "205a03e8-7aa1-407f-afa5-92693715b677",
    "metadata": {},
    "outputs": [],
@@ -61,7 +106,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 12,
    "id": "c74384f0-754e-4c29-8f06-339372d6e4c1",
    "metadata": {},
    "outputs": [],
@@ -71,10 +116,66 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 13,
    "id": "5ebab599-2497-42f8-b54b-1663ee66fde9",
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "==========================================================================================\n",
+      "Layer (type:depth-idx)                   Output Shape              Param #\n",
+      "==========================================================================================\n",
+      "├─Encoder: 1-1                           [-1, 128, 18, 20]         --\n",
+      "|    └─Sequential: 2-1                   [-1, 128, 18, 20]         --\n",
+      "|    |    └─Conv2d: 3-1                  [-1, 32, 576, 640]        320\n",
+      "|    |    └─Conv2d: 3-2                  [-1, 32, 288, 320]        16,416\n",
+      "|    |    └─Mish: 3-3                    [-1, 32, 288, 320]        --\n",
+      "|    |    └─Conv2d: 3-4                  [-1, 64, 144, 160]        32,832\n",
+      "|    |    └─Mish: 3-5                    [-1, 64, 144, 160]        --\n",
+      "|    |    └─Conv2d: 3-6                  [-1, 128, 72, 80]         131,200\n",
+      "|    |    └─Mish: 3-7                    [-1, 128, 72, 80]         --\n",
+      "|    |    └─Conv2d: 3-8                  [-1, 128, 36, 40]         262,272\n",
+      "|    |    └─Mish: 3-9                    [-1, 128, 36, 40]         --\n",
+      "|    |    └─Conv2d: 3-10                 [-1, 128, 18, 20]         262,272\n",
+      "|    |    └─Mish: 3-11                   [-1, 128, 18, 20]         --\n",
+      "|    |    └─Residual: 3-12               [-1, 128, 18, 20]         164,352\n",
+      "|    |    └─Residual: 3-13               [-1, 128, 18, 20]         164,352\n",
+      "├─Conv2d: 1-2                            [-1, 32, 18, 20]          4,128\n",
+      "├─VectorQuantizer: 1-3                   [-1, 32, 18, 20]          --\n",
+      "├─Conv2d: 1-4                            [-1, 128, 18, 20]         4,224\n",
+      "├─Decoder: 1-5                           [-1, 1, 576, 640]         --\n",
+      "|    └─Sequential: 2-2                   [-1, 1, 576, 640]         --\n",
+      "|    |    └─Residual: 3-14               [-1, 128, 18, 20]         164,352\n",
+      "|    |    └─Residual: 3-15               [-1, 128, 18, 20]         164,352\n",
+      "|    |    └─ConvTranspose2d: 3-16        [-1, 128, 36, 40]         262,272\n",
+      "|    |    └─Mish: 3-17                   [-1, 128, 36, 40]         --\n",
+      "|    |    └─ConvTranspose2d: 3-18        [-1, 128, 72, 80]         262,272\n",
+      "|    |    └─Mish: 3-19                   [-1, 128, 72, 80]         --\n",
+      "|    |    └─ConvTranspose2d: 3-20        [-1, 64, 144, 160]        131,136\n",
+      "|    |    └─Mish: 3-21                   [-1, 64, 144, 160]        --\n",
+      "|    |    └─ConvTranspose2d: 3-22        [-1, 32, 288, 320]        32,800\n",
+      "|    |    └─Mish: 3-23                   [-1, 32, 288, 320]        --\n",
+      "|    |    └─ConvTranspose2d: 3-24        [-1, 32, 576, 640]        16,416\n",
+      "|    |    └─Mish: 3-25                   [-1, 32, 576, 640]        --\n",
+      "|    |    └─Normalize: 3-26              [-1, 32, 576, 640]        64\n",
+      "|    |    └─Mish: 3-27                   [-1, 32, 576, 640]        --\n",
+      "|    |    └─Conv2d: 3-28                 [-1, 1, 576, 640]         289\n",
+      "==========================================================================================\n",
+      "Total params: 2,076,321\n",
+      "Trainable params: 2,076,321\n",
+      "Non-trainable params: 0\n",
+      "Total mult-adds (G): 17.68\n",
+      "==========================================================================================\n",
+      "Input size (MB): 1.41\n",
+      "Forward/backward pass size (MB): 355.17\n",
+      "Params size (MB): 7.92\n",
+      "Estimated Total Size (MB): 364.49\n",
+      "==========================================================================================\n"
+     ]
+    }
+   ],
    "source": [
     "summary(net, (1, 576, 640), device=\"cpu\");"
    ]
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index 57c5964..ab3fa35 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -26,8 +26,6 @@ class BaseLitModel(LightningModule):
     loss_fn: Type[nn.Module] = attr.ib()
     optimizer_config: DictConfig = attr.ib()
     lr_scheduler_config: DictConfig = attr.ib()
-    interval: str = attr.ib()
-    monitor: str = attr.ib(default="val/loss")
     train_acc: torchmetrics.Accuracy = attr.ib(
         init=False, default=torchmetrics.Accuracy()
     )
@@ -58,12 +56,18 @@
         self, optimizer: Type[torch.optim.Optimizer]
     ) -> Dict[str, Any]:
         """Configures the lr scheduler."""
+        # Extract non-class arguments.
+        monitor = self.lr_scheduler_config.monitor
+        interval = self.lr_scheduler_config.interval
+        del self.lr_scheduler_config.monitor
+        del self.lr_scheduler_config.interval
+
         log.info(
             f"Instantiating learning rate scheduler <{self.lr_scheduler_config._target_}>"
         )
         scheduler = {
-            "monitor": self.monitor,
-            "interval": self.interval,
+            "monitor": monitor,
+            "interval": interval,
             "scheduler": hydra.utils.instantiate(
                 self.lr_scheduler_config, optimizer=optimizer
             ),
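Note on the base.py change above: hydra.utils.instantiate forwards every config key except _target_ as a keyword argument to the target class, so the Lightning-only keys interval and monitor must be stripped from the scheduler config before instantiation; otherwise CosineAnnealingLR.__init__ would reject them with a TypeError. A minimal standalone sketch of the pattern; the SGD optimizer and dummy parameter here are stand-ins for illustration only (the repo uses madgrad.MADGRAD):

    import hydra
    import torch
    from omegaconf import OmegaConf

    # Scheduler config as in cosine_annealing.yaml: constructor arguments
    # for CosineAnnealingLR plus two Lightning-only keys it cannot accept.
    cfg = OmegaConf.create(
        {
            "_target_": "torch.optim.lr_scheduler.CosineAnnealingLR",
            "T_max": 64,
            "eta_min": 0.0,
            "last_epoch": -1,
            "interval": "epoch",
            "monitor": "val/loss",
        }
    )

    # Stand-in optimizer for illustration only.
    optimizer = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=3.0e-4)

    # Extract the non-class arguments before instantiating...
    monitor = cfg.monitor
    interval = cfg.interval
    del cfg.monitor
    del cfg.interval

    # ...then instantiate; Lightning reads "interval" and "monitor" from the
    # returned dict to decide when to step the scheduler and what to track.
    scheduler = {
        "monitor": monitor,
        "interval": interval,
        "scheduler": hydra.utils.instantiate(cfg, optimizer=optimizer),
    }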
diff --git a/text_recognizer/models/vqvae.py b/text_recognizer/models/vqvae.py
index 7f79b78..76b7ba6 100644
--- a/text_recognizer/models/vqvae.py
+++ b/text_recognizer/models/vqvae.py
@@ -23,6 +23,7 @@ class VQVAELitModel(BaseLitModel):
         reconstructions, vq_loss = self(data)
         loss = self.loss_fn(reconstructions, data)
         loss = loss + self.latent_loss_weight * vq_loss
+        self.log("train/vq_loss", vq_loss)
         self.log("train/loss", loss)
         return loss

@@ -32,6 +33,7 @@
         reconstructions, vq_loss = self(data)
         loss = self.loss_fn(reconstructions, data)
         loss = loss + self.latent_loss_weight * vq_loss
+        self.log("val/vq_loss", vq_loss)
         self.log("val/loss", loss, prog_bar=True)

     def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
@@ -40,4 +42,5 @@
         reconstructions, vq_loss = self(data)
         loss = self.loss_fn(reconstructions, data)
         loss = loss + self.latent_loss_weight * vq_loss
+        self.log("test/vq_loss", vq_loss)
         self.log("test/loss", loss)
diff --git a/training/conf/experiment/vqvae.yaml b/training/conf/experiment/vqvae.yaml
index eb40f3b..7a9e643 100644
--- a/training/conf/experiment/vqvae.yaml
+++ b/training/conf/experiment/vqvae.yaml
@@ -5,6 +5,7 @@ defaults:
   - override /criterion: mse
   - override /model: lit_vqvae
   - override /callbacks: wandb_vae
+  - override /lr_scheduler: cosine_annealing

 trainer:
   max_epochs: 64
@@ -13,11 +14,11 @@
 datamodule:
   batch_size: 16

-lr_scheduler:
-  epochs: 64
-  steps_per_epoch: 1245
+# lr_scheduler:
+  # epochs: 64
+  # steps_per_epoch: 1245

-optimizer:
-  lr: 1.0e-3
+# optimizer:
+  # lr: 1.0e-3

 summary: [1, 576, 640]
diff --git a/training/conf/lr_scheduler/cosine_annealing.yaml b/training/conf/lr_scheduler/cosine_annealing.yaml
new file mode 100644
index 0000000..62667bb
--- /dev/null
+++ b/training/conf/lr_scheduler/cosine_annealing.yaml
@@ -0,0 +1,7 @@
+_target_: torch.optim.lr_scheduler.CosineAnnealingLR
+T_max: 64
+eta_min: 0.0
+last_epoch: -1
+
+interval: epoch
+monitor: val/loss
diff --git a/training/conf/lr_scheduler/one_cycle.yaml b/training/conf/lr_scheduler/one_cycle.yaml
index eecee8a..fb5987a 100644
--- a/training/conf/lr_scheduler/one_cycle.yaml
+++ b/training/conf/lr_scheduler/one_cycle.yaml
@@ -13,3 +13,7 @@ final_div_factor: 10000.0
 three_phase: true
 last_epoch: -1
 verbose: false
+
+# Non-class arguments
+interval: step
+monitor: val/loss
diff --git a/training/conf/model/lit_vqvae.yaml b/training/conf/model/lit_vqvae.yaml
index 409fa0d..632668b 100644
--- a/training/conf/model/lit_vqvae.yaml
+++ b/training/conf/model/lit_vqvae.yaml
@@ -1,4 +1,2 @@
 _target_: text_recognizer.models.vqvae.VQVAELitModel
-interval: step
-monitor: val/loss
-latent_loss_weight: 1.0
+latent_loss_weight: 0.25
diff --git a/training/conf/network/decoder/vae_decoder.yaml b/training/conf/network/decoder/vae_decoder.yaml
index b2090b3..0a36a54 100644
--- a/training/conf/network/decoder/vae_decoder.yaml
+++ b/training/conf/network/decoder/vae_decoder.yaml
@@ -1,5 +1,5 @@
 _target_: text_recognizer.networks.vqvae.decoder.Decoder
 out_channels: 1
 hidden_dim: 32
-channels_multipliers: [8, 6, 2, 1]
+channels_multipliers: [4, 4, 2, 1]
 dropout_rate: 0.25
diff --git a/training/conf/network/encoder/vae_encoder.yaml b/training/conf/network/encoder/vae_encoder.yaml
index 5dc6814..dacd389 100644
--- a/training/conf/network/encoder/vae_encoder.yaml
+++ b/training/conf/network/encoder/vae_encoder.yaml
@@ -1,5 +1,5 @@
 _target_: text_recognizer.networks.vqvae.encoder.Encoder
 in_channels: 1
 hidden_dim: 32
-channels_multipliers: [1, 2, 6, 8]
+channels_multipliers: [1, 2, 4, 4]
 dropout_rate: 0.25
diff --git a/training/conf/network/vqvae.yaml b/training/conf/network/vqvae.yaml
index 835d0b7..d97e9b6 100644
--- a/training/conf/network/vqvae.yaml
+++ b/training/conf/network/vqvae.yaml
@@ -3,7 +3,7 @@ defaults:
   - decoder: vae_decoder

 _target_: text_recognizer.networks.vqvae.vqvae.VQVAE
-hidden_dim: 256
+hidden_dim: 128
 embedding_dim: 32
 num_embeddings: 1024
 decay: 0.99
diff --git a/training/conf/optimizer/madgrad.yaml b/training/conf/optimizer/madgrad.yaml
index 46b2fff..458b116 100644
--- a/training/conf/optimizer/madgrad.yaml
+++ b/training/conf/optimizer/madgrad.yaml
@@ -1,5 +1,5 @@
 _target_: madgrad.MADGRAD
-lr: 2.0e-4
+lr: 3.0e-4
 momentum: 0.9
 weight_decay: 0
 eps: 1.0e-6
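Note on the vqvae.py change: each step now logs the vector-quantization term separately, and lit_vqvae.yaml lowers its weight in the total loss from 1.0 to 0.25. A minimal sketch of the objective those logs track, assuming the MSE criterion selected by the experiment config; model and data are placeholders:

    import torch.nn.functional as F

    latent_loss_weight = 0.25  # training/conf/model/lit_vqvae.yaml after this commit

    def vqvae_loss(model, data):
        # The VQVAE forward pass returns reconstructions together with the
        # vector-quantization loss from the codebook lookup.
        reconstructions, vq_loss = model(data)
        recon_loss = F.mse_loss(reconstructions, data)  # criterion: mse
        loss = recon_loss + latent_loss_weight * vq_loss
        return loss, vq_loss  # vq_loss is what the new */vq_loss logs record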