summaryrefslogtreecommitdiff
path: root/training/configs
diff options
context:
space:
mode:
Diffstat (limited to 'training/configs')
-rw-r--r--training/configs/vqvae.yaml44
1 files changed, 22 insertions, 22 deletions
diff --git a/training/configs/vqvae.yaml b/training/configs/vqvae.yaml
index 90082f7..a7acb3a 100644
--- a/training/configs/vqvae.yaml
+++ b/training/configs/vqvae.yaml
@@ -5,16 +5,16 @@ network:
type: VQVAE
args:
in_channels: 1
- channels: [32, 64, 96]
+ channels: [32, 64, 64]
kernel_sizes: [4, 4, 4]
strides: [2, 2, 2]
num_residual_layers: 2
- embedding_dim: 64
- num_embeddings: 1024
+ embedding_dim: 128
+ num_embeddings: 512
upsampling: null
beta: 0.25
activation: leaky_relu
- dropout_rate: 0.1
+ dropout_rate: 0.2
model:
desc: Configuration of the PyTorch Lightning model.
@@ -33,8 +33,8 @@ model:
interval: &interval step
max_lr: 1.0e-3
three_phase: true
- epochs: 512
- steps_per_epoch: 317 # num_samples / batch_size
+ epochs: 64
+ steps_per_epoch: 633 # num_samples / batch_size
criterion:
type: MSELoss
args:
@@ -46,7 +46,7 @@ data:
desc: Configuration of the training/test data.
type: IAMExtendedParagraphs
args:
- batch_size: 64
+ batch_size: 32
num_workers: 12
train_fraction: 0.8
augment: true
@@ -57,33 +57,33 @@ callbacks:
monitor: val_loss
mode: min
save_last: true
- # - type: StochasticWeightAveraging
- # args:
- # swa_epoch_start: 0.8
- # swa_lrs: 0.05
- # annealing_epochs: 10
- # annealing_strategy: cos
- # device: null
+ - type: StochasticWeightAveraging
+ args:
+ swa_epoch_start: 0.8
+ swa_lrs: 0.05
+ annealing_epochs: 10
+ annealing_strategy: cos
+ device: null
- type: LearningRateMonitor
args:
logging_interval: *interval
- - type: EarlyStopping
- args:
- monitor: val_loss
- mode: min
- patience: 10
+ # - type: EarlyStopping
+ # args:
+ # monitor: val_loss
+ # mode: min
+ # patience: 10
trainer:
desc: Configuration of the PyTorch Lightning Trainer.
args:
- stochastic_weight_avg: false # true
+ stochastic_weight_avg: true
auto_scale_batch_size: binsearch
gradient_clip_val: 0
fast_dev_run: false
gpus: 1
precision: 16
- max_epochs: 512
+ max_epochs: 64
terminate_on_nan: true
- weights_summary: full
+ weights_summary: top
load_checkpoint: null