summaryrefslogtreecommitdiff
path: root/training/conf
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-08-08 19:59:55 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-08-08 19:59:55 +0200
commit240f5e9f20032e82515fa66ce784619527d1041e (patch)
treeb002d28bbfc9abe9b6af090f7db60bea0aeed6e8 /training/conf
parentd12f70402371dda586d457af2a3df7fb5b3130ad (diff)
Add VQGAN and loss function
Diffstat (limited to 'training/conf')
-rw-r--r--training/conf/config.yaml12
-rw-r--r--training/conf/criterion/mae.yaml2
-rw-r--r--training/conf/criterion/vqgan_loss.yaml12
-rw-r--r--training/conf/experiment/vqgan.yaml55
-rw-r--r--training/conf/experiment/vqvae.yaml11
-rw-r--r--training/conf/experiment/vqvae_pixelcnn.yaml24
-rw-r--r--training/conf/lr_scheduler/cosine_annealing.yaml13
-rw-r--r--training/conf/lr_scheduler/one_cycle.yaml37
-rw-r--r--training/conf/network/decoder/pixelcnn_decoder.yaml (renamed from training/conf/network/encoder/pixelcnn_decoder.yaml)2
-rw-r--r--training/conf/network/decoder/vae_decoder.yaml2
-rw-r--r--training/conf/network/encoder/pixelcnn_encoder.yaml (renamed from training/conf/network/decoder/pixelcnn_encoder.yaml)2
-rw-r--r--training/conf/network/encoder/vae_encoder.yaml2
-rw-r--r--training/conf/network/vqvae.yaml2
-rw-r--r--training/conf/network/vqvae_pixelcnn.yaml2
-rw-r--r--training/conf/optimizer/madgrad.yaml13
15 files changed, 147 insertions, 44 deletions
diff --git a/training/conf/config.yaml b/training/conf/config.yaml
index c606366..5897d87 100644
--- a/training/conf/config.yaml
+++ b/training/conf/config.yaml
@@ -6,11 +6,13 @@ defaults:
- datamodule: iam_extended_paragraphs
- hydra: default
- logger: wandb
- - lr_scheduler: one_cycle
+ - lr_schedulers:
+ - one_cycle
- mapping: word_piece
- model: lit_transformer
- network: conv_transformer
- - optimizer: madgrad
+ - optimizers:
+ - madgrad
- trainer: default
seed: 4711
@@ -32,7 +34,9 @@ work_dir: ${hydra:runtime.cwd}
debug: False
# pretty print config at the start of the run using Rich library
-print_config: True
+print_config: false
# disable python warnings if they annoy you
-ignore_warnings: True
+ignore_warnings: true
+
+summary: null # [1, 576, 640]
diff --git a/training/conf/criterion/mae.yaml b/training/conf/criterion/mae.yaml
new file mode 100644
index 0000000..cb07467
--- /dev/null
+++ b/training/conf/criterion/mae.yaml
@@ -0,0 +1,2 @@
+_target_: torch.nn.L1Loss
+reduction: mean
diff --git a/training/conf/criterion/vqgan_loss.yaml b/training/conf/criterion/vqgan_loss.yaml
new file mode 100644
index 0000000..a1c886e
--- /dev/null
+++ b/training/conf/criterion/vqgan_loss.yaml
@@ -0,0 +1,12 @@
+_target_: text_recognizer.criterions.vqgan_loss.VQGANLoss
+reconstruction_loss:
+ _target_: torch.nn.L1Loss
+ reduction: mean
+discriminator:
+ _target_: text_recognizer.criterions.n_layer_discriminator.NLayerDiscriminator
+ in_channels: 1
+ num_channels: 32
+ num_layers: 3
+vq_loss_weight: 1.0
+discriminator_weight: 1.0
+
diff --git a/training/conf/experiment/vqgan.yaml b/training/conf/experiment/vqgan.yaml
new file mode 100644
index 0000000..3d97892
--- /dev/null
+++ b/training/conf/experiment/vqgan.yaml
@@ -0,0 +1,55 @@
+# @package _global_
+
+defaults:
+ - override /network: vqvae
+ - override /criterion: vqgan_loss
+ - override /model: lit_vqgan
+ - override /callbacks: wandb_vae
+ - override /lr_schedulers: null
+
+datamodule:
+ batch_size: 8
+
+lr_schedulers:
+ - generator:
+ T_max: 256
+ eta_min: 0.0
+ last_epoch: -1
+
+ interval: epoch
+ monitor: val/loss
+
+ - discriminator:
+ T_max: 256
+ eta_min: 0.0
+ last_epoch: -1
+
+ interval: epoch
+ monitor: val/loss
+
+optimizer:
+ - generator:
+ _target_: torch.optim.lr_scheduler.CosineAnnealingLR
+ T_max: 256
+ eta_min: 0.0
+ last_epoch: -1
+
+ interval: epoch
+ monitor: val/loss
+ parameters: network
+
+ - discriminator:
+ _target_: torch.optim.lr_scheduler.CosineAnnealingLR
+ T_max: 256
+ eta_min: 0.0
+ last_epoch: -1
+
+ interval: epoch
+ monitor: val/loss
+ parameters: loss_fn
+
+trainer:
+ max_epochs: 256
+ # gradient_clip_val: 0.25
+
+summary: null
diff --git a/training/conf/experiment/vqvae.yaml b/training/conf/experiment/vqvae.yaml
index 7a9e643..397a039 100644
--- a/training/conf/experiment/vqvae.yaml
+++ b/training/conf/experiment/vqvae.yaml
@@ -2,17 +2,18 @@
defaults:
- override /network: vqvae
- - override /criterion: mse
+ - override /criterion: mae
- override /model: lit_vqvae
- override /callbacks: wandb_vae
- - override /lr_scheduler: cosine_annealing
+ - override /lr_schedulers:
+ - cosine_annealing
trainer:
- max_epochs: 64
+ max_epochs: 256
# gradient_clip_val: 0.25
datamodule:
- batch_size: 16
+ batch_size: 8
# lr_scheduler:
# epochs: 64
@@ -21,4 +22,4 @@ datamodule:
# optimizer:
# lr: 1.0e-3
-summary: [1, 576, 640]
+summary: null
diff --git a/training/conf/experiment/vqvae_pixelcnn.yaml b/training/conf/experiment/vqvae_pixelcnn.yaml
new file mode 100644
index 0000000..4fae782
--- /dev/null
+++ b/training/conf/experiment/vqvae_pixelcnn.yaml
@@ -0,0 +1,24 @@
+# @package _global_
+
+defaults:
+ - override /network: vqvae_pixelcnn
+ - override /criterion: mae
+ - override /model: lit_vqvae
+ - override /callbacks: wandb_vae
+ - override /lr_schedulers:
+ - cosine_annealing
+
+trainer:
+ max_epochs: 256
+ # gradient_clip_val: 0.25
+
+datamodule:
+ batch_size: 8
+
+# lr_scheduler:
+ # epochs: 64
+ # steps_per_epoch: 1245
+
+# optimizer:
+ # lr: 1.0e-3
+
diff --git a/training/conf/lr_scheduler/cosine_annealing.yaml b/training/conf/lr_scheduler/cosine_annealing.yaml
index 62667bb..c53ee3a 100644
--- a/training/conf/lr_scheduler/cosine_annealing.yaml
+++ b/training/conf/lr_scheduler/cosine_annealing.yaml
@@ -1,7 +1,8 @@
-_target_: torch.optim.lr_scheduler.CosineAnnealingLR
-T_max: 64
-eta_min: 0.0
-last_epoch: -1
+cosine_annealing:
+ _target_: torch.optim.lr_scheduler.CosineAnnealingLR
+ T_max: 256
+ eta_min: 0.0
+ last_epoch: -1
-interval: epoch
-monitor: val/loss
+ interval: epoch
+ monitor: val/loss
diff --git a/training/conf/lr_scheduler/one_cycle.yaml b/training/conf/lr_scheduler/one_cycle.yaml
index fb5987a..c60577a 100644
--- a/training/conf/lr_scheduler/one_cycle.yaml
+++ b/training/conf/lr_scheduler/one_cycle.yaml
@@ -1,19 +1,20 @@
-_target_: torch.optim.lr_scheduler.OneCycleLR
-max_lr: 1.0e-3
-total_steps: null
-epochs: 512
-steps_per_epoch: 4992
-pct_start: 0.3
-anneal_strategy: cos
-cycle_momentum: true
-base_momentum: 0.85
-max_momentum: 0.95
-div_factor: 25.0
-final_div_factor: 10000.0
-three_phase: true
-last_epoch: -1
-verbose: false
+onc_cycle:
+ _target_: torch.optim.lr_scheduler.OneCycleLR
+ max_lr: 1.0e-3
+ total_steps: null
+ epochs: 512
+ steps_per_epoch: 4992
+ pct_start: 0.3
+ anneal_strategy: cos
+ cycle_momentum: true
+ base_momentum: 0.85
+ max_momentum: 0.95
+ div_factor: 25.0
+ final_div_factor: 10000.0
+ three_phase: true
+ last_epoch: -1
+ verbose: false
-# Non-class arguments
-interval: step
-monitor: val/loss
+ # Non-class arguments
+ interval: step
+ monitor: val/loss
diff --git a/training/conf/network/encoder/pixelcnn_decoder.yaml b/training/conf/network/decoder/pixelcnn_decoder.yaml
index 3895164..cdddb7a 100644
--- a/training/conf/network/encoder/pixelcnn_decoder.yaml
+++ b/training/conf/network/decoder/pixelcnn_decoder.yaml
@@ -1,5 +1,5 @@
_target_: text_recognizer.networks.vqvae.pixelcnn.Decoder
out_channels: 1
hidden_dim: 8
-channels_multipliers: [8, 8, 2, 1]
+channels_multipliers: [8, 2, 1]
dropout_rate: 0.25
diff --git a/training/conf/network/decoder/vae_decoder.yaml b/training/conf/network/decoder/vae_decoder.yaml
index 0a36a54..a5e7286 100644
--- a/training/conf/network/decoder/vae_decoder.yaml
+++ b/training/conf/network/decoder/vae_decoder.yaml
@@ -1,5 +1,5 @@
_target_: text_recognizer.networks.vqvae.decoder.Decoder
out_channels: 1
hidden_dim: 32
-channels_multipliers: [4, 4, 2, 1]
+channels_multipliers: [8, 8, 4, 1]
dropout_rate: 0.25
diff --git a/training/conf/network/decoder/pixelcnn_encoder.yaml b/training/conf/network/encoder/pixelcnn_encoder.yaml
index 47a130d..f12957b 100644
--- a/training/conf/network/decoder/pixelcnn_encoder.yaml
+++ b/training/conf/network/encoder/pixelcnn_encoder.yaml
@@ -1,5 +1,5 @@
_target_: text_recognizer.networks.vqvae.pixelcnn.Encoder
in_channels: 1
hidden_dim: 8
-channels_multipliers: [1, 2, 8, 8]
+channels_multipliers: [1, 2, 8]
dropout_rate: 0.25
diff --git a/training/conf/network/encoder/vae_encoder.yaml b/training/conf/network/encoder/vae_encoder.yaml
index dacd389..58e905d 100644
--- a/training/conf/network/encoder/vae_encoder.yaml
+++ b/training/conf/network/encoder/vae_encoder.yaml
@@ -1,5 +1,5 @@
_target_: text_recognizer.networks.vqvae.encoder.Encoder
in_channels: 1
hidden_dim: 32
-channels_multipliers: [1, 2, 4, 4]
+channels_multipliers: [1, 2, 4, 8, 8]
dropout_rate: 0.25
diff --git a/training/conf/network/vqvae.yaml b/training/conf/network/vqvae.yaml
index d97e9b6..835d0b7 100644
--- a/training/conf/network/vqvae.yaml
+++ b/training/conf/network/vqvae.yaml
@@ -3,7 +3,7 @@ defaults:
- decoder: vae_decoder
_target_: text_recognizer.networks.vqvae.vqvae.VQVAE
-hidden_dim: 128
+hidden_dim: 256
embedding_dim: 32
num_embeddings: 1024
decay: 0.99
diff --git a/training/conf/network/vqvae_pixelcnn.yaml b/training/conf/network/vqvae_pixelcnn.yaml
index 10200bc..cd850af 100644
--- a/training/conf/network/vqvae_pixelcnn.yaml
+++ b/training/conf/network/vqvae_pixelcnn.yaml
@@ -5,5 +5,5 @@ defaults:
_target_: text_recognizer.networks.vqvae.vqvae.VQVAE
hidden_dim: 64
embedding_dim: 32
-num_embeddings: 512
+num_embeddings: 1024
decay: 0.99
diff --git a/training/conf/optimizer/madgrad.yaml b/training/conf/optimizer/madgrad.yaml
index 458b116..a6c059d 100644
--- a/training/conf/optimizer/madgrad.yaml
+++ b/training/conf/optimizer/madgrad.yaml
@@ -1,5 +1,8 @@
-_target_: madgrad.MADGRAD
-lr: 3.0e-4
-momentum: 0.9
-weight_decay: 0
-eps: 1.0e-6
+madgrad:
+ _target_: madgrad.MADGRAD
+ lr: 1.0e-3
+ momentum: 0.9
+ weight_decay: 0
+ eps: 1.0e-6
+
+ parameters: network