diff options
Diffstat (limited to 'training/conf')
-rw-r--r-- | training/conf/callbacks/wandb_code.yaml | 2 | ||||
-rw-r--r-- | training/conf/callbacks/wandb_config.yaml | 2 | ||||
-rw-r--r-- | training/conf/callbacks/wandb_htr.yaml | 2 | ||||
-rw-r--r-- | training/conf/callbacks/wandb_vae.yaml | 2 | ||||
-rw-r--r-- | training/conf/criterion/vqgan_loss.yaml | 2 | ||||
-rw-r--r-- | training/conf/experiment/vqgan.yaml | 40 | ||||
-rw-r--r-- | training/conf/lr_schedulers/one_cycle.yaml | 2 | ||||
-rw-r--r-- | training/conf/network/decoder/vae_decoder.yaml | 4 | ||||
-rw-r--r-- | training/conf/network/encoder/vae_encoder.yaml | 4 | ||||
-rw-r--r-- | training/conf/network/vqvae.yaml | 6 |
10 files changed, 39 insertions, 27 deletions
diff --git a/training/conf/callbacks/wandb_code.yaml b/training/conf/callbacks/wandb_code.yaml deleted file mode 100644 index 012cdce..0000000 --- a/training/conf/callbacks/wandb_code.yaml +++ /dev/null @@ -1,2 +0,0 @@ -upload_code_as_artifact: - _target_: callbacks.wandb_callbacks.UploadCodeAsArtifact diff --git a/training/conf/callbacks/wandb_config.yaml b/training/conf/callbacks/wandb_config.yaml new file mode 100644 index 0000000..747a7c6 --- /dev/null +++ b/training/conf/callbacks/wandb_config.yaml @@ -0,0 +1,2 @@ +upload_code_as_artifact: + _target_: callbacks.wandb_callbacks.UploadConfigAsArtifact diff --git a/training/conf/callbacks/wandb_htr.yaml b/training/conf/callbacks/wandb_htr.yaml index 44adb71..f8c1ef7 100644 --- a/training/conf/callbacks/wandb_htr.yaml +++ b/training/conf/callbacks/wandb_htr.yaml @@ -1,6 +1,6 @@ defaults: - default - wandb_watch - - wandb_code + - wandb_config - wandb_checkpoints - wandb_htr_predictions diff --git a/training/conf/callbacks/wandb_vae.yaml b/training/conf/callbacks/wandb_vae.yaml index c7b09b0..ffc467f 100644 --- a/training/conf/callbacks/wandb_vae.yaml +++ b/training/conf/callbacks/wandb_vae.yaml @@ -3,4 +3,4 @@ defaults: - wandb_watch - wandb_checkpoints - wandb_image_reconstructions - # - wandb_code + - wandb_config diff --git a/training/conf/criterion/vqgan_loss.yaml b/training/conf/criterion/vqgan_loss.yaml index a1c886e..f983f6f 100644 --- a/training/conf/criterion/vqgan_loss.yaml +++ b/training/conf/criterion/vqgan_loss.yaml @@ -1,6 +1,6 @@ _target_: text_recognizer.criterions.vqgan_loss.VQGANLoss reconstruction_loss: - _target_: torch.nn.L1Loss + _target_: torch.nn.MSELoss reduction: mean discriminator: _target_: text_recognizer.criterions.n_layer_discriminator.NLayerDiscriminator diff --git a/training/conf/experiment/vqgan.yaml b/training/conf/experiment/vqgan.yaml index 40af15a..34d8f84 100644 --- a/training/conf/experiment/vqgan.yaml +++ b/training/conf/experiment/vqgan.yaml @@ -16,29 +16,41 @@ criterion: discriminator: _target_: text_recognizer.criterions.n_layer_discriminator.NLayerDiscriminator in_channels: 1 - num_channels: 32 + num_channels: 64 num_layers: 3 - vq_loss_weight: 0.8 - discriminator_weight: 0.8 + vq_loss_weight: 0.25 + discriminator_weight: 1.0 discriminator_factor: 1.0 - discriminator_iter_start: 2e4 + discriminator_iter_start: 2.0e4 datamodule: - batch_size: 8 + batch_size: 12 lr_schedulers: generator: - _target_: torch.optim.lr_scheduler.CosineAnnealingLR - T_max: 256 - eta_min: 0.0 + _target_: torch.optim.lr_scheduler.OneCycleLR + max_lr: 3.0e-4 + total_steps: null + epochs: 64 + steps_per_epoch: 1685 + pct_start: 0.1 + anneal_strategy: cos + cycle_momentum: true + base_momentum: 0.85 + max_momentum: 0.95 + div_factor: 1.0e2 + final_div_factor: 1.0e4 + three_phase: true last_epoch: -1 + verbose: false - interval: epoch + # Non-class arguments + interval: step monitor: val/loss discriminator: _target_: torch.optim.lr_scheduler.CosineAnnealingLR - T_max: 256 + T_max: 64 eta_min: 0.0 last_epoch: -1 @@ -48,10 +60,10 @@ lr_schedulers: optimizers: generator: _target_: madgrad.MADGRAD - lr: 4.5e-6 + lr: 1.0e-4 momentum: 0.5 weight_decay: 0 - eps: 1.0e-6 + eps: 1.0e-7 parameters: network @@ -65,7 +77,7 @@ optimizers: parameters: loss_fn.discriminator trainer: - max_epochs: 256 - # gradient_clip_val: 0.25 + max_epochs: 64 + # gradient_clip_val: 1.0e1 summary: null diff --git a/training/conf/lr_schedulers/one_cycle.yaml b/training/conf/lr_schedulers/one_cycle.yaml index c60577a..801a01f 100644 --- a/training/conf/lr_schedulers/one_cycle.yaml +++ b/training/conf/lr_schedulers/one_cycle.yaml @@ -1,4 +1,4 @@ -onc_cycle: +one_cycle: _target_: torch.optim.lr_scheduler.OneCycleLR max_lr: 1.0e-3 total_steps: null diff --git a/training/conf/network/decoder/vae_decoder.yaml b/training/conf/network/decoder/vae_decoder.yaml index a5e7286..7558ff0 100644 --- a/training/conf/network/decoder/vae_decoder.yaml +++ b/training/conf/network/decoder/vae_decoder.yaml @@ -1,5 +1,5 @@ _target_: text_recognizer.networks.vqvae.decoder.Decoder out_channels: 1 hidden_dim: 32 -channels_multipliers: [8, 8, 4, 1] -dropout_rate: 0.25 +channels_multipliers: [4, 4, 2, 1] +dropout_rate: 0.0 diff --git a/training/conf/network/encoder/vae_encoder.yaml b/training/conf/network/encoder/vae_encoder.yaml index 099c36a..b32f425 100644 --- a/training/conf/network/encoder/vae_encoder.yaml +++ b/training/conf/network/encoder/vae_encoder.yaml @@ -1,5 +1,5 @@ _target_: text_recognizer.networks.vqvae.encoder.Encoder in_channels: 1 hidden_dim: 32 -channels_multipliers: [1, 4, 8, 8] -dropout_rate: 0.25 +channels_multipliers: [1, 2, 4, 4] +dropout_rate: 0.0 diff --git a/training/conf/network/vqvae.yaml b/training/conf/network/vqvae.yaml index 835d0b7..936e575 100644 --- a/training/conf/network/vqvae.yaml +++ b/training/conf/network/vqvae.yaml @@ -3,7 +3,7 @@ defaults: - decoder: vae_decoder _target_: text_recognizer.networks.vqvae.vqvae.VQVAE -hidden_dim: 256 -embedding_dim: 32 -num_embeddings: 1024 +hidden_dim: 128 +embedding_dim: 64 +num_embeddings: 2048 decay: 0.99 |