author     Gustaf Rydholm <gustaf.rydholm@gmail.com>  2021-08-08 21:43:39 +0200
committer  Gustaf Rydholm <gustaf.rydholm@gmail.com>  2021-08-08 21:43:39 +0200
commit     82f4acabe24e5171c40afa2939a4777ba87bcc30 (patch)
tree       4d327fa26e4662a0447a66375442a9adeb13ea3d
parent     240f5e9f20032e82515fa66ce784619527d1041e (diff)
Add training of VQGAN
-rw-r--r--  notebooks/05c-test-model-end-to-end.ipynb      | 308
-rw-r--r--  text_recognizer/criterions/vqgan_loss.py       |  23
-rw-r--r--  text_recognizer/models/base.py                 |   8
-rw-r--r--  text_recognizer/models/vqgan.py                |  24
-rw-r--r--  training/conf/experiment/vqgan.yaml            |  37
-rw-r--r--  training/conf/model/lit_vqgan.yaml             |   1
-rw-r--r--  training/conf/network/encoder/vae_encoder.yaml |   2
7 files changed, 205 insertions, 198 deletions
diff --git a/notebooks/05c-test-model-end-to-end.ipynb b/notebooks/05c-test-model-end-to-end.ipynb
index 42621da..23361b6 100644
--- a/notebooks/05c-test-model-end-to-end.ipynb
+++ b/notebooks/05c-test-model-end-to-end.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": 32,
"id": "1e40a88b",
"metadata": {},
"outputs": [],
@@ -25,53 +25,7 @@
},
{
"cell_type": "code",
- "execution_count": 2,
- "id": "f40fc669-829c-4de8-83ed-475fc6a0b8c1",
- "metadata": {},
- "outputs": [],
- "source": [
- "class T:\n",
- " def __init__(self):\n",
- " self.network = nn.Linear(1, 1)\n",
- " \n",
- " def get(self):\n",
- " return getattr(self, \"network\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "id": "d2bedf96-5388-4c7a-a048-1b97041cbedc",
- "metadata": {},
- "outputs": [],
- "source": [
- "t = T()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "id": "a6fbe3be-2a9f-4050-a397-7ad982d6cd05",
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "<generator object Module.parameters at 0x7f29ad6d6120>"
- ]
- },
- "execution_count": 5,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "t.get().parameters()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
+ "execution_count": 33,
"id": "38fb3d9d-a163-4b72-981f-f31b51be39f2",
"metadata": {},
"outputs": [],
@@ -83,7 +37,7 @@
},
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": 47,
"id": "74780b21-3313-452b-b580-703cac878416",
"metadata": {},
"outputs": [
@@ -91,49 +45,181 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "encoder:\n",
- " _target_: text_recognizer.networks.vqvae.encoder.Encoder\n",
- " in_channels: 1\n",
- " hidden_dim: 32\n",
- " channels_multipliers:\n",
- " - 1\n",
- " - 2\n",
- " - 4\n",
- " - 4\n",
- " - 4\n",
- " dropout_rate: 0.25\n",
- "decoder:\n",
- " _target_: text_recognizer.networks.vqvae.decoder.Decoder\n",
- " out_channels: 1\n",
- " hidden_dim: 32\n",
- " channels_multipliers:\n",
- " - 4\n",
- " - 4\n",
- " - 4\n",
- " - 2\n",
- " - 1\n",
- " dropout_rate: 0.25\n",
- "_target_: text_recognizer.networks.vqvae.vqvae.VQVAE\n",
- "hidden_dim: 128\n",
- "embedding_dim: 32\n",
- "num_embeddings: 1024\n",
- "decay: 0.99\n",
+ "callbacks:\n",
+ " model_checkpoint:\n",
+ " _target_: pytorch_lightning.callbacks.ModelCheckpoint\n",
+ " monitor: val/loss\n",
+ " save_top_k: 1\n",
+ " save_last: true\n",
+ " mode: min\n",
+ " verbose: false\n",
+ " dirpath: checkpoints/\n",
+ " filename: '{epoch:02d}'\n",
+ " learning_rate_monitor:\n",
+ " _target_: pytorch_lightning.callbacks.LearningRateMonitor\n",
+ " logging_interval: step\n",
+ " log_momentum: false\n",
+ " watch_model:\n",
+ " _target_: callbacks.wandb_callbacks.WatchModel\n",
+ " log: all\n",
+ " log_freq: 100\n",
+ " upload_ckpts_as_artifact:\n",
+ " _target_: callbacks.wandb_callbacks.UploadCheckpointsAsArtifact\n",
+ " ckpt_dir: checkpoints/\n",
+ " upload_best_only: true\n",
+ " log_image_reconstruction:\n",
+ " _target_: callbacks.wandb_callbacks.LogReconstuctedImages\n",
+ " num_samples: 8\n",
+ "criterion:\n",
+ " _target_: text_recognizer.criterions.vqgan_loss.VQGANLoss\n",
+ " reconstruction_loss:\n",
+ " _target_: torch.nn.L1Loss\n",
+ " reduction: mean\n",
+ " discriminator:\n",
+ " _target_: text_recognizer.criterions.n_layer_discriminator.NLayerDiscriminator\n",
+ " in_channels: 1\n",
+ " num_channels: 32\n",
+ " num_layers: 3\n",
+ " vq_loss_weight: 1.0\n",
+ " discriminator_weight: 1.0\n",
+ "datamodule:\n",
+ " _target_: text_recognizer.data.iam_extended_paragraphs.IAMExtendedParagraphs\n",
+ " batch_size: 8\n",
+ " num_workers: 12\n",
+ " train_fraction: 0.8\n",
+ " augment: true\n",
+ " pin_memory: false\n",
+ " word_pieces: true\n",
+ "logger:\n",
+ " wandb:\n",
+ " _target_: pytorch_lightning.loggers.wandb.WandbLogger\n",
+ " project: text-recognizer\n",
+ " name: null\n",
+ " save_dir: .\n",
+ " offline: false\n",
+ " id: null\n",
+ " log_model: false\n",
+ " prefix: ''\n",
+ " job_type: train\n",
+ " group: ''\n",
+ " tags: []\n",
+ "mapping:\n",
+ " _target_: text_recognizer.data.word_piece_mapping.WordPieceMapping\n",
+ " num_features: 1000\n",
+ " tokens: iamdb_1kwp_tokens_1000.txt\n",
+ " lexicon: iamdb_1kwp_lex_1000.txt\n",
+ " data_dir: null\n",
+ " use_words: false\n",
+ " prepend_wordsep: false\n",
+ " special_tokens:\n",
+ " - <s>\n",
+ " - <e>\n",
+ " - <p>\n",
+ " extra_symbols:\n",
+ " - '\n",
"\n",
- "{'encoder': {'_target_': 'text_recognizer.networks.vqvae.encoder.Encoder', 'in_channels': 1, 'hidden_dim': 32, 'channels_multipliers': [1, 2, 4, 4, 4], 'dropout_rate': 0.25}, 'decoder': {'_target_': 'text_recognizer.networks.vqvae.decoder.Decoder', 'out_channels': 1, 'hidden_dim': 32, 'channels_multipliers': [4, 4, 4, 2, 1], 'dropout_rate': 0.25}, '_target_': 'text_recognizer.networks.vqvae.vqvae.VQVAE', 'hidden_dim': 128, 'embedding_dim': 32, 'num_embeddings': 1024, 'decay': 0.99}\n"
+ " '\n",
+ "model:\n",
+ " _target_: text_recognizer.models.vqgan.VQGANLitModel\n",
+ "network:\n",
+ " encoder:\n",
+ " _target_: text_recognizer.networks.vqvae.encoder.Encoder\n",
+ " in_channels: 1\n",
+ " hidden_dim: 32\n",
+ " channels_multipliers:\n",
+ " - 1\n",
+ " - 2\n",
+ " - 4\n",
+ " - 8\n",
+ " - 8\n",
+ " dropout_rate: 0.25\n",
+ " decoder:\n",
+ " _target_: text_recognizer.networks.vqvae.decoder.Decoder\n",
+ " out_channels: 1\n",
+ " hidden_dim: 32\n",
+ " channels_multipliers:\n",
+ " - 8\n",
+ " - 8\n",
+ " - 4\n",
+ " - 1\n",
+ " dropout_rate: 0.25\n",
+ " _target_: text_recognizer.networks.vqvae.vqvae.VQVAE\n",
+ " hidden_dim: 256\n",
+ " embedding_dim: 32\n",
+ " num_embeddings: 1024\n",
+ " decay: 0.99\n",
+ "trainer:\n",
+ " _target_: pytorch_lightning.Trainer\n",
+ " stochastic_weight_avg: false\n",
+ " auto_scale_batch_size: binsearch\n",
+ " auto_lr_find: false\n",
+ " gradient_clip_val: 0\n",
+ " fast_dev_run: false\n",
+ " gpus: 1\n",
+ " precision: 16\n",
+ " max_epochs: 256\n",
+ " terminate_on_nan: true\n",
+ " weights_summary: top\n",
+ " limit_train_batches: 1.0\n",
+ " limit_val_batches: 1.0\n",
+ " limit_test_batches: 1.0\n",
+ " resume_from_checkpoint: null\n",
+ "seed: 4711\n",
+ "tune: false\n",
+ "train: true\n",
+ "test: true\n",
+ "logging: INFO\n",
+ "work_dir: ${hydra:runtime.cwd}\n",
+ "debug: false\n",
+ "print_config: false\n",
+ "ignore_warnings: true\n",
+ "summary: null\n",
+ "lr_schedulers:\n",
+ " generator:\n",
+ " _target_: torch.optim.lr_scheduler.CosineAnnealingLR\n",
+ " T_max: 256\n",
+ " eta_min: 0.0\n",
+ " last_epoch: -1\n",
+ " interval: epoch\n",
+ " monitor: val/loss\n",
+ " discriminator:\n",
+ " _target_: torch.optim.lr_scheduler.CosineAnnealingLR\n",
+ " T_max: 256\n",
+ " eta_min: 0.0\n",
+ " last_epoch: -1\n",
+ " interval: epoch\n",
+ " monitor: val/loss\n",
+ "optimizers:\n",
+ " generator:\n",
+ " _target_: madgrad.MADGRAD\n",
+ " lr: 0.001\n",
+ " momentum: 0.5\n",
+ " weight_decay: 0\n",
+ " eps: 1.0e-06\n",
+ " parameters: network\n",
+ " discriminator:\n",
+ " _target_: madgrad.MADGRAD\n",
+ " lr: 0.001\n",
+ " momentum: 0.5\n",
+ " weight_decay: 0\n",
+ " eps: 1.0e-06\n",
+ " parameters: loss_fn.discriminator\n",
+ "\n",
+ "{'callbacks': {'model_checkpoint': {'_target_': 'pytorch_lightning.callbacks.ModelCheckpoint', 'monitor': 'val/loss', 'save_top_k': 1, 'save_last': True, 'mode': 'min', 'verbose': False, 'dirpath': 'checkpoints/', 'filename': '{epoch:02d}'}, 'learning_rate_monitor': {'_target_': 'pytorch_lightning.callbacks.LearningRateMonitor', 'logging_interval': 'step', 'log_momentum': False}, 'watch_model': {'_target_': 'callbacks.wandb_callbacks.WatchModel', 'log': 'all', 'log_freq': 100}, 'upload_ckpts_as_artifact': {'_target_': 'callbacks.wandb_callbacks.UploadCheckpointsAsArtifact', 'ckpt_dir': 'checkpoints/', 'upload_best_only': True}, 'log_image_reconstruction': {'_target_': 'callbacks.wandb_callbacks.LogReconstuctedImages', 'num_samples': 8}}, 'criterion': {'_target_': 'text_recognizer.criterions.vqgan_loss.VQGANLoss', 'reconstruction_loss': {'_target_': 'torch.nn.L1Loss', 'reduction': 'mean'}, 'discriminator': {'_target_': 'text_recognizer.criterions.n_layer_discriminator.NLayerDiscriminator', 'in_channels': 1, 'num_channels': 32, 'num_layers': 3}, 'vq_loss_weight': 1.0, 'discriminator_weight': 1.0}, 'datamodule': {'_target_': 'text_recognizer.data.iam_extended_paragraphs.IAMExtendedParagraphs', 'batch_size': 8, 'num_workers': 12, 'train_fraction': 0.8, 'augment': True, 'pin_memory': False, 'word_pieces': True}, 'logger': {'wandb': {'_target_': 'pytorch_lightning.loggers.wandb.WandbLogger', 'project': 'text-recognizer', 'name': None, 'save_dir': '.', 'offline': False, 'id': None, 'log_model': False, 'prefix': '', 'job_type': 'train', 'group': '', 'tags': []}}, 'mapping': {'_target_': 'text_recognizer.data.word_piece_mapping.WordPieceMapping', 'num_features': 1000, 'tokens': 'iamdb_1kwp_tokens_1000.txt', 'lexicon': 'iamdb_1kwp_lex_1000.txt', 'data_dir': None, 'use_words': False, 'prepend_wordsep': False, 'special_tokens': ['<s>', '<e>', '<p>'], 'extra_symbols': ['\\n']}, 'model': {'_target_': 'text_recognizer.models.vqgan.VQGANLitModel'}, 'network': {'encoder': {'_target_': 'text_recognizer.networks.vqvae.encoder.Encoder', 'in_channels': 1, 'hidden_dim': 32, 'channels_multipliers': [1, 2, 4, 8, 8], 'dropout_rate': 0.25}, 'decoder': {'_target_': 'text_recognizer.networks.vqvae.decoder.Decoder', 'out_channels': 1, 'hidden_dim': 32, 'channels_multipliers': [8, 8, 4, 1], 'dropout_rate': 0.25}, '_target_': 'text_recognizer.networks.vqvae.vqvae.VQVAE', 'hidden_dim': 256, 'embedding_dim': 32, 'num_embeddings': 1024, 'decay': 0.99}, 'trainer': {'_target_': 'pytorch_lightning.Trainer', 'stochastic_weight_avg': False, 'auto_scale_batch_size': 'binsearch', 'auto_lr_find': False, 'gradient_clip_val': 0, 'fast_dev_run': False, 'gpus': 1, 'precision': 16, 'max_epochs': 256, 'terminate_on_nan': True, 'weights_summary': 'top', 'limit_train_batches': 1.0, 'limit_val_batches': 1.0, 'limit_test_batches': 1.0, 'resume_from_checkpoint': None}, 'seed': 4711, 'tune': False, 'train': True, 'test': True, 'logging': 'INFO', 'work_dir': '${hydra:runtime.cwd}', 'debug': False, 'print_config': False, 'ignore_warnings': True, 'summary': None, 'lr_schedulers': {'generator': {'_target_': 'torch.optim.lr_scheduler.CosineAnnealingLR', 'T_max': 256, 'eta_min': 0.0, 'last_epoch': -1, 'interval': 'epoch', 'monitor': 'val/loss'}, 'discriminator': {'_target_': 'torch.optim.lr_scheduler.CosineAnnealingLR', 'T_max': 256, 'eta_min': 0.0, 'last_epoch': -1, 'interval': 'epoch', 'monitor': 'val/loss'}}, 'optimizers': {'generator': {'_target_': 'madgrad.MADGRAD', 'lr': 0.001, 'momentum': 0.5, 'weight_decay': 0, 'eps': 1e-06, 
'parameters': 'network'}, 'discriminator': {'_target_': 'madgrad.MADGRAD', 'lr': 0.001, 'momentum': 0.5, 'weight_decay': 0, 'eps': 1e-06, 'parameters': 'loss_fn.discriminator'}}}\n"
]
}
],
"source": [
"# context initialization\n",
- "with initialize(config_path=\"../training/conf/network/\", job_name=\"test_app\"):\n",
- " cfg = compose(config_name=\"vqvae\")\n",
+ "with initialize(config_path=\"../training/conf/\", job_name=\"test_app\"):\n",
+ " cfg = compose(config_name=\"config\", overrides=[\"+experiment=vqgan\"])\n",
" print(OmegaConf.to_yaml(cfg))\n",
" print(cfg)"
]
},
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": null,
"id": "205a03e8-7aa1-407f-afa5-92693715b677",
"metadata": {},
"outputs": [],
@@ -143,7 +229,7 @@
},
{
"cell_type": "code",
- "execution_count": 12,
+ "execution_count": null,
"id": "c74384f0-754e-4c29-8f06-339372d6e4c1",
"metadata": {},
"outputs": [],
@@ -153,66 +239,10 @@
},
{
"cell_type": "code",
- "execution_count": 13,
+ "execution_count": null,
"id": "5ebab599-2497-42f8-b54b-1663ee66fde9",
"metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "==========================================================================================\n",
- "Layer (type:depth-idx) Output Shape Param #\n",
- "==========================================================================================\n",
- "├─Encoder: 1-1 [-1, 128, 18, 20] --\n",
- "| └─Sequential: 2-1 [-1, 128, 18, 20] --\n",
- "| | └─Conv2d: 3-1 [-1, 32, 576, 640] 320\n",
- "| | └─Conv2d: 3-2 [-1, 32, 288, 320] 16,416\n",
- "| | └─Mish: 3-3 [-1, 32, 288, 320] --\n",
- "| | └─Conv2d: 3-4 [-1, 64, 144, 160] 32,832\n",
- "| | └─Mish: 3-5 [-1, 64, 144, 160] --\n",
- "| | └─Conv2d: 3-6 [-1, 128, 72, 80] 131,200\n",
- "| | └─Mish: 3-7 [-1, 128, 72, 80] --\n",
- "| | └─Conv2d: 3-8 [-1, 128, 36, 40] 262,272\n",
- "| | └─Mish: 3-9 [-1, 128, 36, 40] --\n",
- "| | └─Conv2d: 3-10 [-1, 128, 18, 20] 262,272\n",
- "| | └─Mish: 3-11 [-1, 128, 18, 20] --\n",
- "| | └─Residual: 3-12 [-1, 128, 18, 20] 164,352\n",
- "| | └─Residual: 3-13 [-1, 128, 18, 20] 164,352\n",
- "├─Conv2d: 1-2 [-1, 32, 18, 20] 4,128\n",
- "├─VectorQuantizer: 1-3 [-1, 32, 18, 20] --\n",
- "├─Conv2d: 1-4 [-1, 128, 18, 20] 4,224\n",
- "├─Decoder: 1-5 [-1, 1, 576, 640] --\n",
- "| └─Sequential: 2-2 [-1, 1, 576, 640] --\n",
- "| | └─Residual: 3-14 [-1, 128, 18, 20] 164,352\n",
- "| | └─Residual: 3-15 [-1, 128, 18, 20] 164,352\n",
- "| | └─ConvTranspose2d: 3-16 [-1, 128, 36, 40] 262,272\n",
- "| | └─Mish: 3-17 [-1, 128, 36, 40] --\n",
- "| | └─ConvTranspose2d: 3-18 [-1, 128, 72, 80] 262,272\n",
- "| | └─Mish: 3-19 [-1, 128, 72, 80] --\n",
- "| | └─ConvTranspose2d: 3-20 [-1, 64, 144, 160] 131,136\n",
- "| | └─Mish: 3-21 [-1, 64, 144, 160] --\n",
- "| | └─ConvTranspose2d: 3-22 [-1, 32, 288, 320] 32,800\n",
- "| | └─Mish: 3-23 [-1, 32, 288, 320] --\n",
- "| | └─ConvTranspose2d: 3-24 [-1, 32, 576, 640] 16,416\n",
- "| | └─Mish: 3-25 [-1, 32, 576, 640] --\n",
- "| | └─Normalize: 3-26 [-1, 32, 576, 640] 64\n",
- "| | └─Mish: 3-27 [-1, 32, 576, 640] --\n",
- "| | └─Conv2d: 3-28 [-1, 1, 576, 640] 289\n",
- "==========================================================================================\n",
- "Total params: 2,076,321\n",
- "Trainable params: 2,076,321\n",
- "Non-trainable params: 0\n",
- "Total mult-adds (G): 17.68\n",
- "==========================================================================================\n",
- "Input size (MB): 1.41\n",
- "Forward/backward pass size (MB): 355.17\n",
- "Params size (MB): 7.92\n",
- "Estimated Total Size (MB): 364.49\n",
- "==========================================================================================\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"summary(net, (1, 576, 640), device=\"cpu\");"
]
diff --git a/text_recognizer/criterions/vqgan_loss.py b/text_recognizer/criterions/vqgan_loss.py
index 8bb568f..87f0f1c 100644
--- a/text_recognizer/criterions/vqgan_loss.py
+++ b/text_recognizer/criterions/vqgan_loss.py
@@ -1,5 +1,5 @@
"""VQGAN loss for PyTorch Lightning."""
-from typing import Dict
+from typing import Dict, Optional
from click.types import Tuple
import torch
@@ -40,9 +40,9 @@ class VQGANLoss(nn.Module):
vq_loss: Tensor,
optimizer_idx: int,
stage: str,
- ) -> Tuple[Tensor, Dict[str, Tensor]]:
+ ) -> Optional[Tuple]:
"""Calculates the VQGAN loss."""
- rec_loss = self.reconstruction_loss(
+ rec_loss: Tensor = self.reconstruction_loss(
data.contiguous(), reconstructions.contiguous()
)
@@ -51,13 +51,13 @@ class VQGANLoss(nn.Module):
logits_fake = self.discriminator(reconstructions.contiguous())
g_loss = -torch.mean(logits_fake)
- loss = (
+ loss: Tensor = (
rec_loss
+ self.discriminator_weight * g_loss
+ self.vq_loss_weight * vq_loss
)
log = {
- f"{stage}/loss": loss,
+ f"{stage}/total_loss": loss,
f"{stage}/vq_loss": vq_loss,
f"{stage}/rec_loss": rec_loss,
f"{stage}/g_loss": g_loss,
@@ -68,18 +68,11 @@ class VQGANLoss(nn.Module):
logits_fake = self.discriminator(reconstructions.contiguous().detach())
logits_real = self.discriminator(data.contiguous().detach())
- d_loss = self.adversarial_loss(
+ d_loss = self.discriminator_weight * self.adversarial_loss(
logits_real=logits_real, logits_fake=logits_fake
)
- loss = (
- rec_loss
- + self.discriminator_weight * d_loss
- + self.vq_loss_weight * vq_loss
- )
+
log = {
- f"{stage}/loss": loss,
- f"{stage}/vq_loss": vq_loss,
- f"{stage}/rec_loss": rec_loss,
f"{stage}/d_loss": d_loss,
}
- return loss, log
+ return d_loss, log
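Note: the loss is now split cleanly by optimizer index. The generator branch returns rec_loss plus the weighted g_loss and vq_loss; the discriminator branch returns only the weighted adversarial term. The rec/vq terms it previously mixed in are constant with respect to the discriminator's parameters, so dropping them changes the reported value, not the gradients. Two asides: `from click.types import Tuple` in the context above picks up Click's composite parameter type rather than `typing.Tuple`, and the body of `adversarial_loss` is not part of this commit — the hinge formulation below is the common VQGAN choice and is only an assumption here:

    import torch
    import torch.nn.functional as F
    from torch import Tensor

    def hinge_d_loss(logits_real: Tensor, logits_fake: Tensor) -> Tensor:
        """Hinge GAN loss: push real logits above +1, fake logits below -1."""
        loss_real = torch.mean(F.relu(1.0 - logits_real))
        loss_fake = torch.mean(F.relu(1.0 + logits_fake))
        return 0.5 * (loss_real + loss_fake)

    # smoke test with random logits
    print(hinge_d_loss(torch.randn(8, 1), torch.randn(8, 1)))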
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index 8b68ed9..94dbde5 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -49,12 +49,14 @@ class BaseLitModel(LightningModule):
"""Configures the optimizer."""
optimizers = []
for optimizer_config in self.optimizer_configs.values():
- network = getattr(self, optimizer_config.parameters)
+ module = self
+ for m in str(optimizer_config.parameters).split("."):
+ module = getattr(module, m)
del optimizer_config.parameters
log.info(f"Instantiating optimizer <{optimizer_config._target_}>")
optimizers.append(
hydra.utils.instantiate(
- self.optimizer_config, params=network.parameters()
+ optimizer_config, params=module.parameters()
)
)
return optimizers
@@ -92,7 +94,7 @@ class BaseLitModel(LightningModule):
) -> Tuple[List[Type[torch.optim.Optimizer]], List[Dict[str, Any]]]:
"""Configures optimizer and lr scheduler."""
optimizers = self._configure_optimizer()
- schedulers = self._configure_lr_scheduler(optimizers)
+ schedulers = self._configure_lr_schedulers(optimizers)
return optimizers, schedulers
def forward(self, data: Tensor) -> Tensor:
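Note: two fixes land in base.py. The stale `self.optimizer_config` reference becomes the loop variable (and `_configure_lr_schedulers` now matches the method's actual name), and `parameters` may be a dotted path such as `loss_fn.discriminator` rather than a direct attribute name. The new loop is equivalent to a reduce over getattr — a self-contained sketch, with illustrative class names that are not from the repo:

    from functools import reduce

    import torch
    import torch.nn as nn

    def resolve_module(root: nn.Module, dotted_path: str) -> nn.Module:
        """Walk '.'-separated attribute names starting from `root`."""
        return reduce(getattr, dotted_path.split("."), root)

    class Loss(nn.Module):
        def __init__(self):
            super().__init__()
            self.discriminator = nn.Linear(4, 1)

    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.network = nn.Linear(4, 4)
            self.loss_fn = Loss()

    disc = resolve_module(Model(), "loss_fn.discriminator")
    optimizer = torch.optim.SGD(disc.parameters(), lr=1e-3)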
diff --git a/text_recognizer/models/vqgan.py b/text_recognizer/models/vqgan.py
index 8ff65cc..80653b6 100644
--- a/text_recognizer/models/vqgan.py
+++ b/text_recognizer/models/vqgan.py
@@ -9,7 +9,7 @@ from text_recognizer.criterions.vqgan_loss import VQGANLoss
@attr.s(auto_attribs=True, eq=False)
-class VQVAELitModel(BaseLitModel):
+class VQGANLitModel(BaseLitModel):
"""A PyTorch Lightning model for transformer networks."""
loss_fn: VQGANLoss = attr.ib()
@@ -26,7 +26,6 @@ class VQVAELitModel(BaseLitModel):
data, _ = batch
reconstructions, vq_loss = self(data)
- loss = self.loss_fn(reconstructions, data)
if optimizer_idx == 0:
loss, log = self.loss_fn(
@@ -81,14 +80,6 @@ class VQVAELitModel(BaseLitModel):
self.log(
"val/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True
)
- self.log(
- "val/rec_loss",
- log["val/rec_loss"],
- prog_bar=True,
- logger=True,
- on_step=True,
- on_epoch=True,
- )
self.log_dict(log)
_, log = self.loss_fn(
@@ -105,24 +96,13 @@ class VQVAELitModel(BaseLitModel):
data, _ = batch
reconstructions, vq_loss = self(data)
- loss, log = self.loss_fn(
+ _, log = self.loss_fn(
data=data,
reconstructions=reconstructions,
vq_loss=vq_loss,
optimizer_idx=0,
stage="test",
)
- self.log(
- "test/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True
- )
- self.log(
- "test/rec_loss",
- log["test/rec_loss"],
- prog_bar=True,
- logger=True,
- on_step=True,
- on_epoch=True,
- )
self.log_dict(log)
_, log = self.loss_fn(
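Note: with the rename to VQGANLitModel (the copied docstring still says "transformer networks"), training alternates between the two configured optimizers; PyTorch Lightning calls training_step once per optimizer and passes `optimizer_idx`, and the stale positional `self.loss_fn(reconstructions, data)` call is dropped. A hedged skeleton of the pattern, condensed from the hunks above (method body only):

    def training_step(self, batch, batch_idx, optimizer_idx):
        data, _ = batch
        reconstructions, vq_loss = self(data)
        loss, log = self.loss_fn(
            data=data,
            reconstructions=reconstructions,
            vq_loss=vq_loss,
            optimizer_idx=optimizer_idx,  # 0: generator, 1: discriminator
            stage="train",
        )
        self.log_dict(log)
        return loss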
diff --git a/training/conf/experiment/vqgan.yaml b/training/conf/experiment/vqgan.yaml
index 3d97892..570e7f9 100644
--- a/training/conf/experiment/vqgan.yaml
+++ b/training/conf/experiment/vqgan.yaml
@@ -5,13 +5,15 @@ defaults:
- override /criterion: vqgan_loss
- override /model: lit_vqgan
- override /callbacks: wandb_vae
+ - override /optimizers: null
- override /lr_schedulers: null
datamodule:
batch_size: 8
lr_schedulers:
- - generator:
+ generator:
+ _target_: torch.optim.lr_scheduler.CosineAnnealingLR
T_max: 256
eta_min: 0.0
last_epoch: -1
@@ -19,7 +21,8 @@ lr_schedulers:
interval: epoch
monitor: val/loss
- - discriminator:
+ discriminator:
+ _target_: torch.optim.lr_scheduler.CosineAnnealingLR
T_max: 256
eta_min: 0.0
last_epoch: -1
@@ -27,26 +30,24 @@ lr_schedulers:
interval: epoch
monitor: val/loss
-optimizer:
- - generator:
- _target_: torch.optim.lr_scheduler.CosineAnnealingLR
- T_max: 256
- eta_min: 0.0
- last_epoch: -1
+optimizers:
+ generator:
+ _target_: madgrad.MADGRAD
+ lr: 2.0e-5
+ momentum: 0.5
+ weight_decay: 0
+ eps: 1.0e-6
- interval: epoch
- monitor: val/loss
parameters: network
- - discriminator:
- _target_: torch.optim.lr_scheduler.CosineAnnealingLR
- T_max: 256
- eta_min: 0.0
- last_epoch: -1
+ discriminator:
+ _target_: madgrad.MADGRAD
+ lr: 2.0e-5
+ momentum: 0.5
+ weight_decay: 0
+ eps: 1.0e-6
- interval: epoch
- monitor: val/loss
- parameters: loss_fn
+ parameters: loss_fn.discriminator
trainer:
max_epochs: 256
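Note: the structural fix here is list-to-mapping. The old `- generator:` form parses as a list of single-key dicts, which breaks the `.values()` iteration in BaseLitModel; plain keys parse as a dict. The stray scheduler fields (`interval`, `monitor`, CosineAnnealingLR targets) also move out of the optimizer block. A small OmegaConf illustration of the parse difference:

    from omegaconf import OmegaConf

    old = OmegaConf.create("optimizers:\n  - generator:\n      lr: 2.0e-5\n")
    new = OmegaConf.create("optimizers:\n  generator:\n    lr: 2.0e-5\n")
    print(type(old.optimizers).__name__)  # ListConfig: no name -> config .values()
    print(type(new.optimizers).__name__)  # DictConfig: iterable per optimizer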
diff --git a/training/conf/model/lit_vqgan.yaml b/training/conf/model/lit_vqgan.yaml
new file mode 100644
index 0000000..9ee1046
--- /dev/null
+++ b/training/conf/model/lit_vqgan.yaml
@@ -0,0 +1 @@
+_target_: text_recognizer.models.vqgan.VQGANLitModel
diff --git a/training/conf/network/encoder/vae_encoder.yaml b/training/conf/network/encoder/vae_encoder.yaml
index 58e905d..099c36a 100644
--- a/training/conf/network/encoder/vae_encoder.yaml
+++ b/training/conf/network/encoder/vae_encoder.yaml
@@ -1,5 +1,5 @@
_target_: text_recognizer.networks.vqvae.encoder.Encoder
in_channels: 1
hidden_dim: 32
-channels_multipliers: [1, 2, 4, 8, 8]
+channels_multipliers: [1, 4, 8, 8]
dropout_rate: 0.25
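Note: per the (now-removed) torchsummary output in the notebook, each channels_multipliers entry corresponds to one stride-2 convolution after the stem (576x640 -> 18x20 over five stages), so trimming the list from five entries to four halves the encoder's spatial downsampling. A quick check under that assumption:

    # assumes one stride-2 conv per multiplier entry, as the old summary showed
    for mults in ([1, 2, 4, 8, 8], [1, 4, 8, 8]):
        factor = 2 ** len(mults)
        print(mults, "->", f"{factor}x (576x640 -> {576 // factor}x{640 // factor})")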