From 532286b516b17d279c321358bf03dddc8adc8029 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Mon, 5 Apr 2021 23:14:24 +0200
Subject: Completed first draft for training loop with PyTorch Lightning

---
 training/experiments/image_transformer.yaml | 70 +++++++++++++++++++++++++++++
 1 file changed, 70 insertions(+)
 create mode 100644 training/experiments/image_transformer.yaml

(limited to 'training/experiments')

diff --git a/training/experiments/image_transformer.yaml b/training/experiments/image_transformer.yaml
new file mode 100644
index 0000000..7f0bbb7
--- /dev/null
+++ b/training/experiments/image_transformer.yaml
@@ -0,0 +1,70 @@
+network:
+        type: ImageTransformer
+        args: 
+                input_shape: None
+                output_shape: None
+                encoder:
+                        type: None
+                        args: None
+                mapping: sentence_piece
+                num_decoder_layers: 4
+                hidden_dim: 256
+                num_heads: 4
+                expansion_dim: 1024
+                dropout_rate: 0.1
+                transformer_activation: glu
+
+model:
+        type: LitTransformerModel
+        args:
+                optimizer: 
+                        type: MADGRAD
+                        args:
+                                lr: 1.0e-2
+                                momentum: 0.9
+                                weight_decay: 0
+                                eps: 1.0e-6
+                lr_scheduler: 
+                        type: CosineAnnealingLR
+                        args: 
+                                T_max: 512
+                criterion:
+                        type: CrossEntropyLoss
+                        args: 
+                                weight: None
+                                ignore_index: -100
+                                reduction: mean
+
+                monitor: val_loss
+                mapping: sentence_piece
+
+data:
+        type: IAMExtendedParagraphs
+        args: 
+                batch_size: 16
+                num_workers: 12
+                train_fraction: 0.8
+                augment: true
+
+callbacks:
+        - type: ModelCheckpoint
+          args:
+                  monitor: val_loss
+                  mode: min
+        - type: EarlyStopping
+          args:
+                  monitor: val_loss
+                  mode: min
+                  patience: 10
+
+trainer:
+        args:
+                stochastic_weight_avg: true
+                auto_scale_batch_size: power
+                gradient_clip_val: 0
+                fast_dev_run: false
+                gpus: 1
+                precision: 16
+                max_epocs: 512
+                terminate_on_nan: true
+                weights_summary: true
-- 
cgit v1.2.3-70-g09d2