diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-11-17 22:44:36 +0100 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-11-17 22:44:36 +0100 |
commit | 29be8113855c80d7c3fb806f030d4e00e62fb3b7 (patch) | |
tree | 2db55f8dc67774d7372dcdebbeea96215440b34e /training/conf/experiment/vqgan.yaml | |
parent | 347a39201af4144e880a4b6f24e1e6bb761ee948 (diff) | |
Update configs
Diffstat (limited to 'training/conf/experiment/vqgan.yaml')
-rw-r--r-- | training/conf/experiment/vqgan.yaml | 89 |
1 file changed, 51 insertions, 38 deletions
diff --git a/training/conf/experiment/vqgan.yaml b/training/conf/experiment/vqgan.yaml index 98f3346..726757f 100644 --- a/training/conf/experiment/vqgan.yaml +++ b/training/conf/experiment/vqgan.yaml @@ -1,18 +1,26 @@ +# @package _global_ + defaults: - override /network: vqvae - override /criterion: null - override /model: lit_vqgan - - override /callbacks: wandb_vae + - override /callbacks: vae - override /optimizers: null - override /lr_schedulers: null +epochs: &epochs 100 +ignore_index: &ignore_index 3 +num_classes: &num_classes 58 +max_output_len: &max_output_len 682 +summary: [[1, 1, 576, 640]] + criterion: - _target_: text_recognizer.criterions.vqgan_loss.VQGANLoss + _target_: text_recognizer.criterion.vqgan_loss.VQGANLoss reconstruction_loss: _target_: torch.nn.BCEWithLogitsLoss reduction: mean discriminator: - _target_: text_recognizer.criterions.n_layer_discriminator.NLayerDiscriminator + _target_: text_recognizer.criterion.n_layer_discriminator.NLayerDiscriminator in_channels: 1 num_channels: 64 num_layers: 3 @@ -21,39 +29,35 @@ criterion: discriminator_factor: 1.0 discriminator_iter_start: 8.0e4 +mapping: &mapping + mapping: + _target_: text_recognizer.data.mappings.emnist.EmnistMapping + extra_symbols: [ "\n" ] + datamodule: - batch_size: 12 - # resize: [288, 320] - augment: false + _target_: text_recognizer.data.iam_extended_paragraphs.IAMExtendedParagraphs + batch_size: 4 + num_workers: 12 + train_fraction: 0.9 + pin_memory: true + << : *mapping lr_schedulers: - generator: - _target_: torch.optim.lr_scheduler.OneCycleLR - max_lr: 3.0e-4 - total_steps: null - epochs: 64 - steps_per_epoch: 1685 - pct_start: 0.3 - anneal_strategy: cos - cycle_momentum: true - base_momentum: 0.85 - max_momentum: 0.95 - div_factor: 25.0 - final_div_factor: 10000.0 - three_phase: true + network: + _target_: torch.optim.lr_scheduler.CosineAnnealingLR + T_max: *epochs + eta_min: 1.0e-5 last_epoch: -1 - verbose: false - interval: step + interval: epoch monitor: val/loss -# 
discriminator: -# _target_: torch.optim.lr_scheduler.CosineAnnealingLR -# T_max: 64 -# eta_min: 0.0 -# last_epoch: -1 -# -# interval: epoch -# monitor: val/loss + discriminator: + _target_: torch.optim.lr_scheduler.CosineAnnealingLR + T_max: *epochs + eta_min: 1.0e-5 + last_epoch: -1 + interval: epoch + monitor: val/loss optimizers: generator: @@ -75,11 +79,20 @@ optimizers: parameters: loss_fn.discriminator trainer: - max_epochs: 64 - # limit_train_batches: 0.1 - # limit_val_batches: 0.1 - # gradient_clip_val: 100 - -# tune: false -# train: true -# test: false + _target_: pytorch_lightning.Trainer + stochastic_weight_avg: false + auto_scale_batch_size: binsearch + auto_lr_find: false + gradient_clip_val: 0 + fast_dev_run: false + gpus: 1 + precision: 16 + max_epochs: *epochs + terminate_on_nan: true + weights_summary: null + limit_train_batches: 1.0 + limit_val_batches: 1.0 + limit_test_batches: 1.0 + resume_from_checkpoint: null + accumulate_grad_batches: 2 + overfit_batches: 0 |