From d3afa310f77f47553586eeee58e3d3345a754e2c Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm
Date: Wed, 4 Aug 2021 05:03:51 +0200
Subject: New VQVAE

---
 training/callbacks/wandb_callbacks.py             | 69 ++++++++++++++--------
 .../callbacks/wandb_image_reconstructions.yaml    |  3 +
 training/conf/callbacks/wandb_vae.yaml            |  6 ++
 training/conf/config.yaml                         |  2 +
 training/conf/experiment/vqvae.yaml               | 20 +++++++
 training/conf/experiment/vqvae_experiment.yaml    | 13 ----
 training/conf/model/lit_vqvae.yaml                |  4 +-
 training/conf/network/conv_transformer.yaml       |  2 +-
 .../conf/network/decoder/transformer_decoder.yaml |  4 +-
 training/conf/network/vqvae.yaml                  | 21 +++----
 10 files changed, 91 insertions(+), 53 deletions(-)
 create mode 100644 training/conf/callbacks/wandb_vae.yaml
 create mode 100644 training/conf/experiment/vqvae.yaml
 delete mode 100644 training/conf/experiment/vqvae_experiment.yaml
(limited to 'training')

diff --git a/training/callbacks/wandb_callbacks.py b/training/callbacks/wandb_callbacks.py
index 906531f..c750e4b 100644
--- a/training/callbacks/wandb_callbacks.py
+++ b/training/callbacks/wandb_callbacks.py
@@ -5,6 +5,7 @@ import wandb
 from pytorch_lightning import Callback, LightningModule, Trainer
 from pytorch_lightning.loggers import LoggerCollection, WandbLogger
 from pytorch_lightning.utilities import rank_zero_only
+from torch.utils.data import DataLoader
 
 
 def get_wandb_logger(trainer: Trainer) -> WandbLogger:
@@ -86,7 +87,11 @@ class LogTextPredictions(Callback):
         self.ready = False
 
     def _log_predictions(
-        self, stage: str, trainer: Trainer, pl_module: LightningModule
+        self,
+        stage: str,
+        trainer: Trainer,
+        pl_module: LightningModule,
+        dataloader: DataLoader,
     ) -> None:
         """Logs the predicted text contained in the images."""
         if not self.ready:
@@ -96,22 +101,20 @@ class LogTextPredictions(Callback):
         experiment = logger.experiment
 
         # Get a validation batch from the validation dataloader.
-        samples = next(iter(trainer.datamodule.val_dataloader()))
+        samples = next(iter(dataloader))
         imgs, labels = samples
 
         imgs = imgs.to(device=pl_module.device)
         logits = pl_module(imgs)
 
         mapping = pl_module.mapping
-        columns = ["id", "image", "prediction", "truth"]
+        columns = ["image", "prediction", "truth"]
         data = [
-            [id, wandb.Image(img), mapping.get_text(pred), mapping.get_text(label)]
-            for id, (img, pred, label) in enumerate(
-                zip(
-                    imgs[: self.num_samples],
-                    logits[: self.num_samples],
-                    labels[: self.num_samples],
-                )
+            [wandb.Image(img), mapping.get_text(pred), mapping.get_text(label)]
+            for img, pred, label in zip(
+                imgs[: self.num_samples],
+                logits[: self.num_samples],
+                labels[: self.num_samples],
             )
         ]
 
@@ -133,11 +136,17 @@ class LogTextPredictions(Callback):
         self, trainer: Trainer, pl_module: LightningModule
     ) -> None:
         """Logs predictions on validation epoch end."""
-        self._log_predictions(stage="val", trainer=trainer, pl_module=pl_module)
+        dataloader = trainer.datamodule.val_dataloader()
+        self._log_predictions(
+            stage="val", trainer=trainer, pl_module=pl_module, dataloader=dataloader
+        )
 
     def on_test_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
         """Logs predictions on train epoch end."""
-        self._log_predictions(stage="test", trainer=trainer, pl_module=pl_module)
+        dataloader = trainer.datamodule.test_dataloader()
+        self._log_predictions(
+            stage="test", trainer=trainer, pl_module=pl_module, dataloader=dataloader
+        )
 
 
 class LogReconstuctedImages(Callback):
@@ -148,7 +157,11 @@ class LogReconstuctedImages(Callback):
         self.ready = False
 
     def _log_reconstruction(
-        self, stage: str, trainer: Trainer, pl_module: LightningModule
+        self,
+        stage: str,
+        trainer: Trainer,
+        pl_module: LightningModule,
+        dataloader: DataLoader,
     ) -> None:
         """Logs the reconstructions."""
         if not self.ready:
@@ -158,20 +171,24 @@ class LogReconstuctedImages(Callback):
         experiment = logger.experiment
 
         # Get a validation batch from the validation dataloader.
-        samples = next(iter(trainer.datamodule.val_dataloader()))
+        samples = next(iter(dataloader))
         imgs, _ = samples
+        columns = ["input", "reconstruction"]
 
         imgs = imgs.to(device=pl_module.device)
-        reconstructions = pl_module(imgs)
+        reconstructions = pl_module(imgs)[0]
+        data = [
+            [wandb.Image(img), wandb.Image(rec)]
+            for img, rec in zip(
+                imgs[: self.num_samples], reconstructions[: self.num_samples]
+            )
+        ]
 
         experiment.log(
             {
-                f"Reconstructions/{experiment.name}/{stage}": [
-                    [wandb.Image(img), wandb.Image(rec),]
-                    for img, rec in zip(
-                        imgs[: self.num_samples], reconstructions[: self.num_samples],
-                    )
-                ]
+                f"Reconstructions/{experiment.name}/{stage}": wandb.Table(
+                    data=data, columns=columns
+                )
             }
         )
 
@@ -189,8 +206,14 @@ class LogReconstuctedImages(Callback):
         self, trainer: Trainer, pl_module: LightningModule
     ) -> None:
         """Logs predictions on validation epoch end."""
-        self._log_reconstruction(stage="val", trainer=trainer, pl_module=pl_module)
+        dataloader = trainer.datamodule.val_dataloader()
+        self._log_reconstruction(
+            stage="val", trainer=trainer, pl_module=pl_module, dataloader=dataloader
+        )
 
     def on_test_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
         """Logs predictions on train epoch end."""
-        self._log_reconstruction(stage="test", trainer=trainer, pl_module=pl_module)
+        dataloader = trainer.datamodule.test_dataloader()
+        self._log_reconstruction(
+            stage="test", trainer=trainer, pl_module=pl_module, dataloader=dataloader
+        )
diff --git a/training/conf/callbacks/wandb_image_reconstructions.yaml b/training/conf/callbacks/wandb_image_reconstructions.yaml
index e69de29..6cc4ada 100644
--- a/training/conf/callbacks/wandb_image_reconstructions.yaml
+++ b/training/conf/callbacks/wandb_image_reconstructions.yaml
@@ -0,0 +1,3 @@
+log_image_reconstruction:
+  _target_: callbacks.wandb_callbacks.LogReconstuctedImages
+  num_samples: 8
diff --git a/training/conf/callbacks/wandb_vae.yaml b/training/conf/callbacks/wandb_vae.yaml
new file mode 100644
index 0000000..609a8e8
--- /dev/null
+++ b/training/conf/callbacks/wandb_vae.yaml
@@ -0,0 +1,6 @@
+defaults:
+  - default
+  - wandb_watch
+  - wandb_code
+  - wandb_checkpoints
+  - wandb_image_reconstructions
diff --git a/training/conf/config.yaml b/training/conf/config.yaml
index 782bcbb..6b74502 100644
--- a/training/conf/config.yaml
+++ b/training/conf/config.yaml
@@ -1,3 +1,5 @@
+# @package _global_
+
 defaults:
   - callbacks: wandb_ocr
   - criterion: label_smoothing
diff --git a/training/conf/experiment/vqvae.yaml b/training/conf/experiment/vqvae.yaml
new file mode 100644
index 0000000..13e5f34
--- /dev/null
+++ b/training/conf/experiment/vqvae.yaml
@@ -0,0 +1,20 @@
+# @package _global_
+
+defaults:
+  - override /network: vqvae
+  - override /criterion: mse
+  - override /model: lit_vqvae
+  - override /callbacks: wandb_vae
+
+trainer:
+  max_epochs: 64
+
+datamodule:
+  batch_size: 32
+
+lr_scheduler:
+  epochs: 64
+  steps_per_epoch: 624
+
+optimizer:
+  lr: 1.0e-2
diff --git a/training/conf/experiment/vqvae_experiment.yaml b/training/conf/experiment/vqvae_experiment.yaml
deleted file mode 100644
index 0858c3d..0000000
--- a/training/conf/experiment/vqvae_experiment.yaml
+++ /dev/null
@@ -1,13 +0,0 @@
-defaults:
-  - override /network: vqvae
-  - override /criterion: mse
-  - override /optimizer: madgrad
-  - override /lr_scheduler: one_cycle
-  - override /model: lit_vqvae
-  - override /dataset: iam_extended_paragraphs
-  - override /trainer: default
-  - override /callbacks:
-      - wandb
-
-load_checkpoint: null
-logging: INFO
diff --git a/training/conf/model/lit_vqvae.yaml b/training/conf/model/lit_vqvae.yaml
index b337fe6..8837573 100644
--- a/training/conf/model/lit_vqvae.yaml
+++ b/training/conf/model/lit_vqvae.yaml
@@ -1,2 +1,4 @@
 _target_: text_recognizer.models.vqvae.VQVAELitModel
-mapping: sentence_piece
+interval: step
+monitor: val/loss
+latent_loss_weight: 0.25
diff --git a/training/conf/network/conv_transformer.yaml b/training/conf/network/conv_transformer.yaml
index f76e892..d3a3b0f 100644
--- a/training/conf/network/conv_transformer.yaml
+++ b/training/conf/network/conv_transformer.yaml
@@ -4,7 +4,7 @@ defaults:
 
 _target_: text_recognizer.networks.conv_transformer.ConvTransformer
 input_dims: [1, 576, 640]
-hidden_dim: 96
+hidden_dim: 128
 dropout_rate: 0.2
 num_classes: 1006
 pad_index: 1002
diff --git a/training/conf/network/decoder/transformer_decoder.yaml b/training/conf/network/decoder/transformer_decoder.yaml
index eb80f64..c326c04 100644
--- a/training/conf/network/decoder/transformer_decoder.yaml
+++ b/training/conf/network/decoder/transformer_decoder.yaml
@@ -2,12 +2,12 @@ defaults:
   - rotary_emb: null
 
 _target_: text_recognizer.networks.transformer.Decoder
-dim: 96
+dim: 128
 depth: 2
 num_heads: 8
 attn_fn: text_recognizer.networks.transformer.attention.Attention
 attn_kwargs:
-  dim_head: 16
+  dim_head: 64
   dropout_rate: 0.2
 norm_fn: torch.nn.LayerNorm
 ff_fn: text_recognizer.networks.transformer.mlp.FeedForward
diff --git a/training/conf/network/vqvae.yaml b/training/conf/network/vqvae.yaml
index 22eebf8..5a5c066 100644
--- a/training/conf/network/vqvae.yaml
+++ b/training/conf/network/vqvae.yaml
@@ -1,13 +1,8 @@
-type: VQVAE
-args:
-  in_channels: 1
-  channels: [64, 96]
-  kernel_sizes: [4, 4]
-  strides: [2, 2]
-  num_residual_layers: 2
-  embedding_dim: 64
-  num_embeddings: 256
-  upsampling: null
-  beta: 0.25
-  activation: leaky_relu
-  dropout_rate: 0.2
+_target_: text_recognizer.networks.vqvae.VQVAE
+in_channels: 1
+res_channels: 32
+num_residual_layers: 2
+embedding_dim: 64
+num_embeddings: 512
+decay: 0.99
+activation: mish
-- 
cgit v1.2.3-70-g09d2
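
For reference, a minimal standalone sketch of the wandb.Table logging pattern this commit introduces in LogReconstuctedImages. It is an illustration, not code from the repository: model, dataloader, and log_reconstructions are hypothetical stand-ins, and the sketch assumes (as the pl_module(imgs)[0] indexing above implies) that the model's forward pass returns the reconstructions as the first element of a tuple, and that a wandb run is already active.

    import torch
    import wandb

    def log_reconstructions(model, dataloader, num_samples=8, stage="val"):
        """Log input/reconstruction pairs as rows of a wandb table."""
        # One batch is enough for a visual spot check.
        imgs, _ = next(iter(dataloader))
        imgs = imgs.to(model.device)
        with torch.no_grad():
            # Assumption: forward returns (reconstructions, ...).
            reconstructions = model(imgs)[0]
        data = [
            [wandb.Image(img), wandb.Image(rec)]
            for img, rec in zip(imgs[:num_samples], reconstructions[:num_samples])
        ]
        wandb.log(
            {
                f"Reconstructions/{stage}": wandb.Table(
                    data=data, columns=["input", "reconstruction"]
                )
            }
        )

Logging a wandb.Table instead of a bare list of images (the previous behavior) keeps each input aligned with its reconstruction in a single browsable row. With the new "# @package _global_" experiment config, training would typically be launched through a Hydra override along the lines of "python run.py +experiment=vqvae"; the entrypoint name here is an assumption, not taken from the commit.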