diff options
Diffstat (limited to 'training/conf')
-rw-r--r-- | training/conf/experiment/vqgan.yaml | 37 | ||||
-rw-r--r-- | training/conf/model/lit_vqgan.yaml | 1 | ||||
-rw-r--r-- | training/conf/network/encoder/vae_encoder.yaml | 2 |
3 files changed, 21 insertions, 19 deletions
diff --git a/training/conf/experiment/vqgan.yaml b/training/conf/experiment/vqgan.yaml index 3d97892..570e7f9 100644 --- a/training/conf/experiment/vqgan.yaml +++ b/training/conf/experiment/vqgan.yaml @@ -5,13 +5,15 @@ defaults: - override /criterion: vqgan_loss - override /model: lit_vqgan - override /callbacks: wandb_vae + - override /optimizers: null - override /lr_schedulers: null datamodule: batch_size: 8 lr_schedulers: - - generator: + generator: + _target_: torch.optim.lr_scheduler.CosineAnnealingLR T_max: 256 eta_min: 0.0 last_epoch: -1 @@ -19,7 +21,8 @@ lr_schedulers: interval: epoch monitor: val/loss - - discriminator: + discriminator: + _target_: torch.optim.lr_scheduler.CosineAnnealingLR T_max: 256 eta_min: 0.0 last_epoch: -1 @@ -27,26 +30,24 @@ lr_schedulers: interval: epoch monitor: val/loss -optimizer: - - generator: - _target_: torch.optim.lr_scheduler.CosineAnnealingLR - T_max: 256 - eta_min: 0.0 - last_epoch: -1 +optimizers: + generator: + _target_: madgrad.MADGRAD + lr: 2.0e-5 + momentum: 0.5 + weight_decay: 0 + eps: 1.0e-6 - interval: epoch - monitor: val/loss parameters: network - - discriminator: - _target_: torch.optim.lr_scheduler.CosineAnnealingLR - T_max: 256 - eta_min: 0.0 - last_epoch: -1 + discriminator: + _target_: madgrad.MADGRAD + lr: 2.0e-5 + momentum: 0.5 + weight_decay: 0 + eps: 1.0e-6 - interval: epoch - monitor: val/loss - parameters: loss_fn + parameters: loss_fn.discriminator trainer: max_epochs: 256 diff --git a/training/conf/model/lit_vqgan.yaml b/training/conf/model/lit_vqgan.yaml new file mode 100644 index 0000000..9ee1046 --- /dev/null +++ b/training/conf/model/lit_vqgan.yaml @@ -0,0 +1 @@ +_target_: text_recognizer.models.vqgan.VQGANLitModel diff --git a/training/conf/network/encoder/vae_encoder.yaml b/training/conf/network/encoder/vae_encoder.yaml index 58e905d..099c36a 100644 --- a/training/conf/network/encoder/vae_encoder.yaml +++ b/training/conf/network/encoder/vae_encoder.yaml @@ -1,5 +1,5 @@ _target_: text_recognizer.networks.vqvae.encoder.Encoder in_channels: 1 hidden_dim: 32 -channels_multipliers: [1, 2, 4, 8, 8] +channels_multipliers: [1, 4, 8, 8] dropout_rate: 0.25 |