author     Gustaf Rydholm <gustaf.rydholm@gmail.com>    2021-04-22 08:15:58 +0200
committer  Gustaf Rydholm <gustaf.rydholm@gmail.com>    2021-04-22 08:15:58 +0200
commit     1ca8b0b9e0613c1e02f6a5d8b49e20c4d6916412 (patch)
tree       5e610ac459c9b254f8826e92372346f01f8e2412 /training/configs
parent     ffa4be4bf4e3758e01d52a9c1f354a05a90b93de (diff)
Fixed training script, able to train vqvae
Diffstat (limited to 'training/configs')
-rw-r--r--  training/configs/image_transformer.yaml   2
-rw-r--r--  training/configs/vqvae.yaml               89
2 files changed, 91 insertions, 0 deletions
diff --git a/training/configs/image_transformer.yaml b/training/configs/image_transformer.yaml
index 228e53f..e6637f2 100644
--- a/training/configs/image_transformer.yaml
+++ b/training/configs/image_transformer.yaml
@@ -85,3 +85,5 @@ trainer:
     max_epochs: 512
     terminate_on_nan: true
     weights_summary: true
+
+load_checkpoint: null
diff --git a/training/configs/vqvae.yaml b/training/configs/vqvae.yaml
new file mode 100644
index 0000000..90082f7
--- /dev/null
+++ b/training/configs/vqvae.yaml
@@ -0,0 +1,89 @@
+seed: 4711
+
+network:
+  desc: Configuration of the PyTorch neural network.
+  type: VQVAE
+  args:
+    in_channels: 1
+    channels: [32, 64, 96]
+    kernel_sizes: [4, 4, 4]
+    strides: [2, 2, 2]
+    num_residual_layers: 2
+    embedding_dim: 64
+    num_embeddings: 1024
+    upsampling: null
+    beta: 0.25
+    activation: leaky_relu
+    dropout_rate: 0.1
+
+model:
+  desc: Configuration of the PyTorch Lightning model.
+  type: LitVQVAEModel
+  args:
+    optimizer:
+      type: MADGRAD
+      args:
+        lr: 1.0e-3
+        momentum: 0.9
+        weight_decay: 0
+        eps: 1.0e-6
+    lr_scheduler:
+      type: OneCycleLR
+      args:
+        interval: &interval step
+        max_lr: 1.0e-3
+        three_phase: true
+        epochs: 512
+        steps_per_epoch: 317 # num_samples / batch_size
+    criterion:
+      type: MSELoss
+      args:
+        reduction: mean
+    monitor: val_loss
+    mapping: sentence_piece
+
+data:
+  desc: Configuration of the training/test data.
+  type: IAMExtendedParagraphs
+  args:
+    batch_size: 64
+    num_workers: 12
+    train_fraction: 0.8
+    augment: true
+
+callbacks:
+  - type: ModelCheckpoint
+    args:
+      monitor: val_loss
+      mode: min
+      save_last: true
+  # - type: StochasticWeightAveraging
+  #   args:
+  #     swa_epoch_start: 0.8
+  #     swa_lrs: 0.05
+  #     annealing_epochs: 10
+  #     annealing_strategy: cos
+  #     device: null
+  - type: LearningRateMonitor
+    args:
+      logging_interval: *interval
+  - type: EarlyStopping
+    args:
+      monitor: val_loss
+      mode: min
+      patience: 10
+
+trainer:
+  desc: Configuration of the PyTorch Lightning Trainer.
+  args:
+    stochastic_weight_avg: false # true
+    auto_scale_batch_size: binsearch
+    gradient_clip_val: 0
+    fast_dev_run: false
+    gpus: 1
+    precision: 16
+    max_epochs: 512
+    terminate_on_nan: true
+    weights_summary: full
+
+load_checkpoint: null
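
A note on the new vqvae.yaml: the scheduler's "interval: &interval step" defines a YAML anchor that the LearningRateMonitor callback aliases via "logging_interval: *interval", so the logging cadence always matches the scheduler's stepping interval. A minimal sketch of reading the file and checking that the alias resolves, assuming the path from this diff and plain PyYAML (the actual training script may use a different loader):

    # Sketch: load vqvae.yaml and confirm the &interval anchor and the *interval
    # alias resolve to the same scalar ("step"). Path and loader are assumptions.
    import yaml

    with open("training/configs/vqvae.yaml") as f:
        config = yaml.safe_load(f)

    scheduler_interval = config["model"]["args"]["lr_scheduler"]["args"]["interval"]
    lr_monitor_interval = config["callbacks"][1]["args"]["logging_interval"]
    assert scheduler_interval == lr_monitor_interval == "step"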
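
The optimizer and lr_scheduler blocks correspond to the MADGRAD optimizer (from the madgrad package) and PyTorch's OneCycleLR. A hedged sketch of instantiating them with the values above; the placeholder module and the wiring into LitVQVAEModel's configure_optimizers are assumptions, and the "interval: step" key belongs to Lightning's scheduler dict rather than to OneCycleLR itself:

    # Sketch only: build the optimizer/scheduler described in vqvae.yaml.
    # Assumes the madgrad package and a PyTorch version with three_phase
    # support in OneCycleLR (>= 1.8). The Linear module stands in for the VQVAE.
    import torch
    from madgrad import MADGRAD

    network = torch.nn.Linear(8, 8)  # placeholder for the VQVAE network

    optimizer = MADGRAD(
        network.parameters(), lr=1.0e-3, momentum=0.9, weight_decay=0, eps=1.0e-6
    )

    # steps_per_epoch: 317 is documented as num_samples / batch_size, which with
    # batch_size 64 implies roughly 317 * 64 = 20288 training samples.
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=1.0e-3, three_phase=True, epochs=512, steps_per_epoch=317
    )

    # In Lightning, the "step" interval would be returned from configure_optimizers
    # as {"scheduler": scheduler, "interval": "step"}, not passed to OneCycleLR.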
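
Likewise, the callbacks and trainer sections map onto standard PyTorch Lightning objects from the 1.2/1.3 era (terminate_on_nan, weights_summary, and stochastic_weight_avg were Trainer flags then; later releases removed them). A sketch of the equivalent direct construction, with the fit call left as a comment since the model and datamodule classes are not part of this diff:

    # Sketch only: the callbacks/trainer blocks of vqvae.yaml as direct
    # PyTorch Lightning (~1.2/1.3) calls; not the repo's actual factory code.
    import pytorch_lightning as pl
    from pytorch_lightning.callbacks import (
        EarlyStopping,
        LearningRateMonitor,
        ModelCheckpoint,
    )

    callbacks = [
        ModelCheckpoint(monitor="val_loss", mode="min", save_last=True),
        LearningRateMonitor(logging_interval="step"),  # *interval resolves to "step"
        EarlyStopping(monitor="val_loss", mode="min", patience=10),
    ]

    trainer = pl.Trainer(
        callbacks=callbacks,
        stochastic_weight_avg=False,
        auto_scale_batch_size="binsearch",
        gradient_clip_val=0,
        fast_dev_run=False,
        gpus=1,
        precision=16,
        max_epochs=512,
        terminate_on_nan=True,
        weights_summary="full",
    )
    # trainer.fit(lit_model, datamodule=data_module)  # both names are hypothetical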