From 82f4acabe24e5171c40afa2939a4777ba87bcc30 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sun, 8 Aug 2021 21:43:39 +0200 Subject: Add training of VQGAN --- training/conf/experiment/vqgan.yaml | 37 +++++++++++++------------- training/conf/model/lit_vqgan.yaml | 1 + training/conf/network/encoder/vae_encoder.yaml | 2 +- 3 files changed, 21 insertions(+), 19 deletions(-) create mode 100644 training/conf/model/lit_vqgan.yaml (limited to 'training') diff --git a/training/conf/experiment/vqgan.yaml b/training/conf/experiment/vqgan.yaml index 3d97892..570e7f9 100644 --- a/training/conf/experiment/vqgan.yaml +++ b/training/conf/experiment/vqgan.yaml @@ -5,13 +5,15 @@ defaults: - override /criterion: vqgan_loss - override /model: lit_vqgan - override /callbacks: wandb_vae + - override /optimizers: null - override /lr_schedulers: null datamodule: batch_size: 8 lr_schedulers: - - generator: + generator: + _target_: torch.optim.lr_scheduler.CosineAnnealingLR T_max: 256 eta_min: 0.0 last_epoch: -1 @@ -19,7 +21,8 @@ lr_schedulers: interval: epoch monitor: val/loss - - discriminator: + discriminator: + _target_: torch.optim.lr_scheduler.CosineAnnealingLR T_max: 256 eta_min: 0.0 last_epoch: -1 @@ -27,26 +30,24 @@ lr_schedulers: interval: epoch monitor: val/loss -optimizer: - - generator: - _target_: torch.optim.lr_scheduler.CosineAnnealingLR - T_max: 256 - eta_min: 0.0 - last_epoch: -1 +optimizers: + generator: + _target_: madgrad.MADGRAD + lr: 2.0e-5 + momentum: 0.5 + weight_decay: 0 + eps: 1.0e-6 - interval: epoch - monitor: val/loss parameters: network - - discriminator: - _target_: torch.optim.lr_scheduler.CosineAnnealingLR - T_max: 256 - eta_min: 0.0 - last_epoch: -1 + discriminator: + _target_: madgrad.MADGRAD + lr: 2.0e-5 + momentum: 0.5 + weight_decay: 0 + eps: 1.0e-6 - interval: epoch - monitor: val/loss - parameters: loss_fn + parameters: loss_fn.discriminator trainer: max_epochs: 256 diff --git a/training/conf/model/lit_vqgan.yaml b/training/conf/model/lit_vqgan.yaml new file mode 100644 index 0000000..9ee1046 --- /dev/null +++ b/training/conf/model/lit_vqgan.yaml @@ -0,0 +1 @@ +_target_: text_recognizer.models.vqgan.VQGANLitModel diff --git a/training/conf/network/encoder/vae_encoder.yaml b/training/conf/network/encoder/vae_encoder.yaml index 58e905d..099c36a 100644 --- a/training/conf/network/encoder/vae_encoder.yaml +++ b/training/conf/network/encoder/vae_encoder.yaml @@ -1,5 +1,5 @@ _target_: text_recognizer.networks.vqvae.encoder.Encoder in_channels: 1 hidden_dim: 32 -channels_multipliers: [1, 2, 4, 8, 8] +channels_multipliers: [1, 4, 8, 8] dropout_rate: 0.25 -- cgit v1.2.3-70-g09d2