diff options
Diffstat (limited to 'training')
| -rw-r--r-- | training/conf/experiment/vqgan.yaml | 37 | ||||
| -rw-r--r-- | training/conf/network/decoder/vae_decoder.yaml | 4 | ||||
| -rw-r--r-- | training/conf/network/encoder/vae_encoder.yaml | 4 | ||||
| -rw-r--r-- | training/conf/network/vqvae.yaml | 2 | ||||
| -rw-r--r-- | training/conf/optimizers/madgrad.yaml | 2 | ||||
| -rw-r--r-- | training/conf/trainer/default.yaml | 1 | 
6 files changed, 35 insertions, 15 deletions
diff --git a/training/conf/experiment/vqgan.yaml b/training/conf/experiment/vqgan.yaml
index 34886ec..572c320 100644
--- a/training/conf/experiment/vqgan.yaml
+++ b/training/conf/experiment/vqgan.yaml
@@ -11,30 +11,41 @@ defaults:
 criterion:
   _target_: text_recognizer.criterions.vqgan_loss.VQGANLoss
   reconstruction_loss:
-    _target_: torch.nn.MSELoss
+    _target_: torch.nn.BCEWithLogitsLoss
     reduction: mean
   discriminator:
     _target_: text_recognizer.criterions.n_layer_discriminator.NLayerDiscriminator
     in_channels: 1
     num_channels: 64
     num_layers: 3
-  vq_loss_weight: 1.0
+  commitment_weight: 0.25
   discriminator_weight: 0.8
   discriminator_factor: 1.0
-  discriminator_iter_start: 7e4
+  discriminator_iter_start: 8.0e4
 datamodule:
-  batch_size: 8
+  batch_size: 12
   # resize: [288, 320]
+  augment: false
 lr_schedulers:
   generator:
-    _target_: torch.optim.lr_scheduler.CosineAnnealingLR
-    T_max: 128
-    eta_min: 4.5e-6
+    _target_: torch.optim.lr_scheduler.OneCycleLR
+    max_lr: 3.0e-4
+    total_steps: null
+    epochs: 64
+    steps_per_epoch: 1685
+    pct_start: 0.3
+    anneal_strategy: cos
+    cycle_momentum: true
+    base_momentum: 0.85
+    max_momentum: 0.95
+    div_factor: 25.0
+    final_div_factor: 10000.0
+    three_phase: true
     last_epoch: -1
-
-    interval: epoch
+    verbose: false
+    interval: step
     monitor: val/loss
 #   discriminator:
@@ -66,7 +77,11 @@ optimizers:
     parameters: loss_fn.discriminator
 trainer:
-  max_epochs: 128
+  max_epochs: 64
   # limit_train_batches: 0.1
   # limit_val_batches: 0.1
-  gradient_clip_val: 100
+  # gradient_clip_val: 100
+
+# tune: false
+# train: true
+# test: false
diff --git a/training/conf/network/decoder/vae_decoder.yaml b/training/conf/network/decoder/vae_decoder.yaml
index 8b5502d..aed5733 100644
--- a/training/conf/network/decoder/vae_decoder.yaml
+++ b/training/conf/network/decoder/vae_decoder.yaml
@@ -1,7 +1,9 @@
 _target_: text_recognizer.networks.vqvae.decoder.Decoder
 out_channels: 1
 hidden_dim: 32
-channels_multipliers: [4, 4, 2, 1]
+channels_multipliers: [4, 2, 1]
 dropout_rate: 0.0
 activation: mish
 use_norm: true
+num_residuals: 4
+residual_channels: 32
diff --git a/training/conf/network/encoder/vae_encoder.yaml b/training/conf/network/encoder/vae_encoder.yaml
index 33ae0b9..5d39bf7 100644
--- a/training/conf/network/encoder/vae_encoder.yaml
+++ b/training/conf/network/encoder/vae_encoder.yaml
@@ -1,7 +1,9 @@
 _target_: text_recognizer.networks.vqvae.encoder.Encoder
 in_channels: 1
 hidden_dim: 32
-channels_multipliers: [1, 2, 4, 4]
+channels_multipliers: [1, 2, 4]
 dropout_rate: 0.0
 activation: mish
 use_norm: true
+num_residuals: 4
+residual_channels: 32
diff --git a/training/conf/network/vqvae.yaml b/training/conf/network/vqvae.yaml
index d97e9b6..8210f04 100644
--- a/training/conf/network/vqvae.yaml
+++ b/training/conf/network/vqvae.yaml
@@ -5,5 +5,5 @@ defaults:
 _target_: text_recognizer.networks.vqvae.vqvae.VQVAE
 hidden_dim: 128
 embedding_dim: 32
-num_embeddings: 1024
+num_embeddings: 8192
 decay: 0.99
diff --git a/training/conf/optimizers/madgrad.yaml b/training/conf/optimizers/madgrad.yaml
index d97bf7e..b6507b9 100644
--- a/training/conf/optimizers/madgrad.yaml
+++ b/training/conf/optimizers/madgrad.yaml
@@ -1,6 +1,6 @@
 madgrad:
   _target_: madgrad.MADGRAD
-  lr: 3.0e-4
+  lr: 1.0e-4
   momentum: 0.9
   weight_decay: 0
   eps: 1.0e-6
diff --git a/training/conf/trainer/default.yaml b/training/conf/trainer/default.yaml
index c665adc..ef5b018 100644
--- a/training/conf/trainer/default.yaml
+++ b/training/conf/trainer/default.yaml
@@ -13,3 +13,4 @@ limit_train_batches: 1.0
 limit_val_batches: 1.0
 limit_test_batches: 1.0
 resume_from_checkpoint: null
+accumulate_grad_batches: 1