summaryrefslogtreecommitdiff
path: root/training/conf
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-08-08 21:43:39 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-08-08 21:43:39 +0200
commit82f4acabe24e5171c40afa2939a4777ba87bcc30 (patch)
tree4d327fa26e4662a0447a66375442a9adeb13ea3d /training/conf
parent240f5e9f20032e82515fa66ce784619527d1041e (diff)
Add training of VQGAN
Diffstat (limited to 'training/conf')
-rw-r--r--training/conf/experiment/vqgan.yaml37
-rw-r--r--training/conf/model/lit_vqgan.yaml1
-rw-r--r--training/conf/network/encoder/vae_encoder.yaml2
3 files changed, 21 insertions, 19 deletions
diff --git a/training/conf/experiment/vqgan.yaml b/training/conf/experiment/vqgan.yaml
index 3d97892..570e7f9 100644
--- a/training/conf/experiment/vqgan.yaml
+++ b/training/conf/experiment/vqgan.yaml
@@ -5,13 +5,15 @@ defaults:
- override /criterion: vqgan_loss
- override /model: lit_vqgan
- override /callbacks: wandb_vae
+ - override /optimizers: null
- override /lr_schedulers: null
datamodule:
batch_size: 8
lr_schedulers:
- - generator:
+ generator:
+ _target_: torch.optim.lr_scheduler.CosineAnnealingLR
T_max: 256
eta_min: 0.0
last_epoch: -1
@@ -19,7 +21,8 @@ lr_schedulers:
interval: epoch
monitor: val/loss
- - discriminator:
+ discriminator:
+ _target_: torch.optim.lr_scheduler.CosineAnnealingLR
T_max: 256
eta_min: 0.0
last_epoch: -1
@@ -27,26 +30,24 @@ lr_schedulers:
interval: epoch
monitor: val/loss
-optimizer:
- - generator:
- _target_: torch.optim.lr_scheduler.CosineAnnealingLR
- T_max: 256
- eta_min: 0.0
- last_epoch: -1
+optimizers:
+ generator:
+ _target_: madgrad.MADGRAD
+ lr: 2.0e-5
+ momentum: 0.5
+ weight_decay: 0
+ eps: 1.0e-6
- interval: epoch
- monitor: val/loss
parameters: network
- - discriminator:
- _target_: torch.optim.lr_scheduler.CosineAnnealingLR
- T_max: 256
- eta_min: 0.0
- last_epoch: -1
+ discriminator:
+ _target_: madgrad.MADGRAD
+ lr: 2.0e-5
+ momentum: 0.5
+ weight_decay: 0
+ eps: 1.0e-6
- interval: epoch
- monitor: val/loss
- parameters: loss_fn
+ parameters: loss_fn.discriminator
trainer:
max_epochs: 256
diff --git a/training/conf/model/lit_vqgan.yaml b/training/conf/model/lit_vqgan.yaml
new file mode 100644
index 0000000..9ee1046
--- /dev/null
+++ b/training/conf/model/lit_vqgan.yaml
@@ -0,0 +1 @@
+_target_: text_recognizer.models.vqgan.VQGANLitModel
diff --git a/training/conf/network/encoder/vae_encoder.yaml b/training/conf/network/encoder/vae_encoder.yaml
index 58e905d..099c36a 100644
--- a/training/conf/network/encoder/vae_encoder.yaml
+++ b/training/conf/network/encoder/vae_encoder.yaml
@@ -1,5 +1,5 @@
_target_: text_recognizer.networks.vqvae.encoder.Encoder
in_channels: 1
hidden_dim: 32
-channels_multipliers: [1, 2, 4, 8, 8]
+channels_multipliers: [1, 4, 8, 8]
dropout_rate: 0.25