author    Gustaf Rydholm <gustaf.rydholm@gmail.com>  2021-08-04 05:03:51 +0200
committer Gustaf Rydholm <gustaf.rydholm@gmail.com>  2021-08-04 05:03:51 +0200
commit    d3afa310f77f47553586eeee58e3d3345a754e2c (patch)
tree      08b7de1daf2550852d0a1e4d4d75202f14bb03d4 /training
parent    65d5f6c694e73792e40ed693a1381a792da8d277 (diff)
New VQVAE
Diffstat (limited to 'training')
 training/callbacks/wandb_callbacks.py                    | 69
 training/conf/callbacks/wandb_image_reconstructions.yaml |  3
 training/conf/callbacks/wandb_vae.yaml                   |  6
 training/conf/config.yaml                                |  2
 training/conf/experiment/vqvae.yaml                      | 20
 training/conf/experiment/vqvae_experiment.yaml           | 13
 training/conf/model/lit_vqvae.yaml                       |  4
 training/conf/network/conv_transformer.yaml              |  2
 training/conf/network/decoder/transformer_decoder.yaml   |  4
 training/conf/network/vqvae.yaml                         | 21
 10 files changed, 91 insertions, 53 deletions
diff --git a/training/callbacks/wandb_callbacks.py b/training/callbacks/wandb_callbacks.py
index 906531f..c750e4b 100644
--- a/training/callbacks/wandb_callbacks.py
+++ b/training/callbacks/wandb_callbacks.py
@@ -5,6 +5,7 @@ import wandb
from pytorch_lightning import Callback, LightningModule, Trainer
from pytorch_lightning.loggers import LoggerCollection, WandbLogger
from pytorch_lightning.utilities import rank_zero_only
+from torch.utils.data import DataLoader


def get_wandb_logger(trainer: Trainer) -> WandbLogger:
@@ -86,7 +87,11 @@ class LogTextPredictions(Callback):
        self.ready = False

    def _log_predictions(
-        self, stage: str, trainer: Trainer, pl_module: LightningModule
+        self,
+        stage: str,
+        trainer: Trainer,
+        pl_module: LightningModule,
+        dataloader: DataLoader,
    ) -> None:
        """Logs the predicted text contained in the images."""
        if not self.ready:
@@ -96,22 +101,20 @@ class LogTextPredictions(Callback):
        experiment = logger.experiment

        # Get a batch from the given dataloader.
-        samples = next(iter(trainer.datamodule.val_dataloader()))
+        samples = next(iter(dataloader))
        imgs, labels = samples

        imgs = imgs.to(device=pl_module.device)
        logits = pl_module(imgs)

        mapping = pl_module.mapping
-        columns = ["id", "image", "prediction", "truth"]
+        columns = ["image", "prediction", "truth"]
        data = [
-            [id, wandb.Image(img), mapping.get_text(pred), mapping.get_text(label)]
-            for id, (img, pred, label) in enumerate(
-                zip(
-                    imgs[: self.num_samples],
-                    logits[: self.num_samples],
-                    labels[: self.num_samples],
-                )
+            [wandb.Image(img), mapping.get_text(pred), mapping.get_text(label)]
+            for img, pred, label in zip(
+                imgs[: self.num_samples],
+                logits[: self.num_samples],
+                labels[: self.num_samples],
            )
        ]
@@ -133,11 +136,17 @@ class LogTextPredictions(Callback):
        self, trainer: Trainer, pl_module: LightningModule
    ) -> None:
        """Logs predictions on validation epoch end."""
-        self._log_predictions(stage="val", trainer=trainer, pl_module=pl_module)
+        dataloader = trainer.datamodule.val_dataloader()
+        self._log_predictions(
+            stage="val", trainer=trainer, pl_module=pl_module, dataloader=dataloader
+        )

    def on_test_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Logs predictions on test epoch end."""
-        self._log_predictions(stage="test", trainer=trainer, pl_module=pl_module)
+        dataloader = trainer.datamodule.test_dataloader()
+        self._log_predictions(
+            stage="test", trainer=trainer, pl_module=pl_module, dataloader=dataloader
+        )


class LogReconstuctedImages(Callback):
@@ -148,7 +157,11 @@ class LogReconstuctedImages(Callback):
        self.ready = False

    def _log_reconstruction(
-        self, stage: str, trainer: Trainer, pl_module: LightningModule
+        self,
+        stage: str,
+        trainer: Trainer,
+        pl_module: LightningModule,
+        dataloader: DataLoader,
    ) -> None:
        """Logs the reconstructions."""
        if not self.ready:
@@ -158,20 +171,24 @@ class LogReconstuctedImages(Callback):
        experiment = logger.experiment

        # Get a batch from the given dataloader.
-        samples = next(iter(trainer.datamodule.val_dataloader()))
+        samples = next(iter(dataloader))
        imgs, _ = samples

+        columns = ["input", "reconstruction"]
        imgs = imgs.to(device=pl_module.device)
-        reconstructions = pl_module(imgs)
+        reconstructions = pl_module(imgs)[0]

+        data = [
+            [wandb.Image(img), wandb.Image(rec)]
+            for img, rec in zip(
+                imgs[: self.num_samples], reconstructions[: self.num_samples]
+            )
+        ]
        experiment.log(
            {
-                f"Reconstructions/{experiment.name}/{stage}": [
-                    [wandb.Image(img), wandb.Image(rec),]
-                    for img, rec in zip(
-                        imgs[: self.num_samples], reconstructions[: self.num_samples],
-                    )
-                ]
+                f"Reconstructions/{experiment.name}/{stage}": wandb.Table(
+                    data=data, columns=columns
+                )
            }
        )
@@ -189,8 +206,14 @@ class LogReconstuctedImages(Callback):
        self, trainer: Trainer, pl_module: LightningModule
    ) -> None:
        """Logs reconstructions on validation epoch end."""
-        self._log_reconstruction(stage="val", trainer=trainer, pl_module=pl_module)
+        dataloader = trainer.datamodule.val_dataloader()
+        self._log_reconstruction(
+            stage="val", trainer=trainer, pl_module=pl_module, dataloader=dataloader
+        )

    def on_test_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Logs reconstructions on test epoch end."""
-        self._log_reconstruction(stage="test", trainer=trainer, pl_module=pl_module)
+        dataloader = trainer.datamodule.test_dataloader()
+        self._log_reconstruction(
+            stage="test", trainer=trainer, pl_module=pl_module, dataloader=dataloader
+        )
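Taken together, the callback changes do two things: each epoch-end hook now looks up the stage-appropriate dataloader and passes it down explicitly, and _log_reconstruction indexes pl_module(imgs)[0], which implies the new VQ-VAE LitModel's forward returns a tuple with the reconstruction first. A minimal sketch of that assumed contract; the second element and its meaning are guesses, since only the tuple indexing is visible in the diff:

import torch
from typing import Tuple
from torch import Tensor, nn


class VQVAELitModelSketch(nn.Module):
    """Hypothetical stand-in for text_recognizer.models.vqvae.VQVAELitModel."""

    def forward(self, imgs: Tensor) -> Tuple[Tensor, Tensor]:
        reconstruction = imgs.clone()    # placeholder for encode -> quantize -> decode
        latent_loss = torch.tensor(0.0)  # placeholder for the VQ commitment loss
        return reconstruction, latent_loss


# Mirrors the callback: only the first element is logged as an image.
reconstructions = VQVAELitModelSketch()(torch.rand(8, 1, 64, 64))[0]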
diff --git a/training/conf/callbacks/wandb_image_reconstructions.yaml b/training/conf/callbacks/wandb_image_reconstructions.yaml
index e69de29..6cc4ada 100644
--- a/training/conf/callbacks/wandb_image_reconstructions.yaml
+++ b/training/conf/callbacks/wandb_image_reconstructions.yaml
@@ -0,0 +1,3 @@
+log_image_reconstruction:
+  _target_: callbacks.wandb_callbacks.LogReconstuctedImages
+  num_samples: 8
diff --git a/training/conf/callbacks/wandb_vae.yaml b/training/conf/callbacks/wandb_vae.yaml
new file mode 100644
index 0000000..609a8e8
--- /dev/null
+++ b/training/conf/callbacks/wandb_vae.yaml
@@ -0,0 +1,6 @@
+defaults:
+  - default
+  - wandb_watch
+  - wandb_code
+  - wandb_checkpoints
+  - wandb_image_reconstructions
diff --git a/training/conf/config.yaml b/training/conf/config.yaml
index 782bcbb..6b74502 100644
--- a/training/conf/config.yaml
+++ b/training/conf/config.yaml
@@ -1,3 +1,5 @@
+# @package _global_
+
defaults:
  - callbacks: wandb_ocr
  - criterion: label_smoothing
diff --git a/training/conf/experiment/vqvae.yaml b/training/conf/experiment/vqvae.yaml
new file mode 100644
index 0000000..13e5f34
--- /dev/null
+++ b/training/conf/experiment/vqvae.yaml
@@ -0,0 +1,20 @@
+# @package _global_
+
+defaults:
+  - override /network: vqvae
+  - override /criterion: mse
+  - override /model: lit_vqvae
+  - override /callbacks: wandb_vae
+
+trainer:
+  max_epochs: 64
+
+datamodule:
+  batch_size: 32
+
+lr_scheduler:
+  epochs: 64
+  steps_per_epoch: 624
+
+optimizer:
+  lr: 1.0e-2
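A note on the scheduler numbers: epochs and steps_per_epoch must agree with trainer.max_epochs and the real number of optimizer steps per epoch, because a one-cycle schedule (the deleted experiment file below overrode lr_scheduler: one_cycle) precomputes its total step budget from them. A sketch with PyTorch's OneCycleLR; the dataset size is hypothetical, chosen so that a batch size of 32 yields 624 steps:

import math

import torch

batch_size = 32
num_train_samples = 19_968  # hypothetical; 19_968 / 32 = 624
steps_per_epoch = math.ceil(num_train_samples / batch_size)

optimizer = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=1.0e-2)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=1.0e-2,
    epochs=64,  # matches trainer.max_epochs above
    steps_per_epoch=steps_per_epoch,
)
# Lightning then calls scheduler.step() once per optimizer step
# when the model config sets interval: step.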
diff --git a/training/conf/experiment/vqvae_experiment.yaml b/training/conf/experiment/vqvae_experiment.yaml
deleted file mode 100644
index 0858c3d..0000000
--- a/training/conf/experiment/vqvae_experiment.yaml
+++ /dev/null
@@ -1,13 +0,0 @@
-defaults:
-  - override /network: vqvae
-  - override /criterion: mse
-  - override /optimizer: madgrad
-  - override /lr_scheduler: one_cycle
-  - override /model: lit_vqvae
-  - override /dataset: iam_extended_paragraphs
-  - override /trainer: default
-  - override /callbacks:
-      - wandb
-
-load_checkpoint: null
-logging: INFO
diff --git a/training/conf/model/lit_vqvae.yaml b/training/conf/model/lit_vqvae.yaml
index b337fe6..8837573 100644
--- a/training/conf/model/lit_vqvae.yaml
+++ b/training/conf/model/lit_vqvae.yaml
@@ -1,2 +1,4 @@
_target_: text_recognizer.models.vqvae.VQVAELitModel
-mapping: sentence_piece
+interval: step
+monitor: val/loss
+latent_loss_weight: 0.25
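The new latent_loss_weight: 0.25 takes over the role of the old network-level beta: 0.25 (removed from network/vqvae.yaml below): in the standard VQ-VAE objective the latent/commitment term is scaled before being added to the reconstruction term. A sketch of how the LitModel presumably combines the two, using the textbook formulation rather than code from the repository (the mse criterion comes from the experiment override above):

from torch import Tensor
from torch.nn import functional as F


def vqvae_loss(
    reconstruction: Tensor,
    target: Tensor,
    latent_loss: Tensor,
    latent_loss_weight: float = 0.25,
) -> Tensor:
    """Reconstruction term plus weighted latent term (standard VQ-VAE objective)."""
    recon_loss = F.mse_loss(reconstruction, target)
    return recon_loss + latent_loss_weight * latent_loss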
diff --git a/training/conf/network/conv_transformer.yaml b/training/conf/network/conv_transformer.yaml
index f76e892..d3a3b0f 100644
--- a/training/conf/network/conv_transformer.yaml
+++ b/training/conf/network/conv_transformer.yaml
@@ -4,7 +4,7 @@ defaults:
_target_: text_recognizer.networks.conv_transformer.ConvTransformer
input_dims: [1, 576, 640]
-hidden_dim: 96
+hidden_dim: 128
dropout_rate: 0.2
num_classes: 1006
pad_index: 1002
diff --git a/training/conf/network/decoder/transformer_decoder.yaml b/training/conf/network/decoder/transformer_decoder.yaml
index eb80f64..c326c04 100644
--- a/training/conf/network/decoder/transformer_decoder.yaml
+++ b/training/conf/network/decoder/transformer_decoder.yaml
@@ -2,12 +2,12 @@ defaults:
  - rotary_emb: null

_target_: text_recognizer.networks.transformer.Decoder
-dim: 96
+dim: 128
depth: 2
num_heads: 8
attn_fn: text_recognizer.networks.transformer.attention.Attention
attn_kwargs:
-  dim_head: 16
+  dim_head: 64
  dropout_rate: 0.2
norm_fn: torch.nn.LayerNorm
ff_fn: text_recognizer.networks.transformer.mlp.FeedForward
diff --git a/training/conf/network/vqvae.yaml b/training/conf/network/vqvae.yaml
index 22eebf8..5a5c066 100644
--- a/training/conf/network/vqvae.yaml
+++ b/training/conf/network/vqvae.yaml
@@ -1,13 +1,8 @@
-type: VQVAE
-args:
-  in_channels: 1
-  channels: [64, 96]
-  kernel_sizes: [4, 4]
-  strides: [2, 2]
-  num_residual_layers: 2
-  embedding_dim: 64
-  num_embeddings: 256
-  upsampling: null
-  beta: 0.25
-  activation: leaky_relu
-  dropout_rate: 0.2
+_target_: text_recognizer.networks.vqvae.VQVAE
+in_channels: 1
+res_channels: 32
+num_residual_layers: 2
+embedding_dim: 64
+num_embeddings: 512
+decay: 0.99
+activation: mish
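The move from the old type/args scheme to a flat _target_ config is what lets Hydra build the network directly (the new decay key also suggests EMA codebook updates, though that is inferred from the name alone). A sketch of the instantiation path using Hydra's standard hydra.utils.instantiate; the VQVAE constructor signature is inferred from the config keys:

from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "_target_": "text_recognizer.networks.vqvae.VQVAE",
        "in_channels": 1,
        "res_channels": 32,
        "num_residual_layers": 2,
        "embedding_dim": 64,
        "num_embeddings": 512,
        "decay": 0.99,
        "activation": "mish",
    }
)
network = instantiate(cfg)  # equivalent to VQVAE(in_channels=1, res_channels=32, ...)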