summaryrefslogtreecommitdiff
path: root/training/configs/image_transformer.yaml
blob: bedcbb59406005f303a1c873695b2ce7cf57e5fe (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
# Integer seed for the run. Presumably used to seed the random number
# generators. NOTE(review): confirm against the training entry point.
seed: 4711

network:
        desc: null
        type: ImageTransformer
        args:
                # No encoder is specified: both its type and args are null.
                encoder: {type: null, args: null}
                # Transformer hyperparameters (presumably forwarded as keyword
                # arguments to `type`; confirm against the network loader).
                num_decoder_layers: 4
                hidden_dim: 256
                num_heads: 4
                expansion_dim: 1024
                dropout_rate: 0.1
                transformer_activation: glu

model:
        desc: null
        type: LitTransformerModel
        args:
                optimizer:
                        type: MADGRAD
                        args:
                                # Exponent signs are written explicitly so that
                                # YAML 1.1 loaders (e.g. PyYAML) parse these as floats.
                                lr: 1.0e-2
                                momentum: 0.9
                                weight_decay: 0
                                eps: 1.0e-6
                lr_scheduler:
                        type: CosineAnnealingLR
                        args:
                                # Same value as trainer.args.max_epochs. If the
                                # scheduler steps once per epoch, one cosine cycle
                                # spans the whole run. NOTE(review): confirm the
                                # step interval.
                                T_max: 512
                criterion:
                        type: CrossEntropyLoss
                        args:
                                # Must be YAML `null` (Python None). A bare `None`
                                # is loaded as the string "None", which
                                # CrossEntropyLoss rejects because `weight` must be
                                # a Tensor or None.
                                weight: null
                                ignore_index: -100
                                reduction: mean
                monitor: val_loss
                mapping: sentence_piece

data:
        desc: null
        type: IAMExtendedParagraphs
        args:
                # Starting batch size. The trainer sets
                # auto_scale_batch_size: binsearch, which may replace this value
                # if the tuner is run.
                batch_size: 16
                num_workers: 12
                # Share of the dataset used for training (presumably the rest
                # is held out for validation; confirm in the data module).
                train_fraction: 0.8
                augment: true

callbacks:
        # Both callbacks track val_loss, where lower is better. This is the
        # same metric named in model.args.monitor.
        - type: ModelCheckpoint
          args: {monitor: val_loss, mode: min}
        # Stops training once val_loss has not improved for `patience`
        # consecutive validation checks.
        - type: EarlyStopping
          args: {monitor: val_loss, mode: min, patience: 10}

trainer:
        desc: null
        args:
                stochastic_weight_avg: true
                # Only takes effect when the Lightning tuner is run.
                auto_scale_batch_size: binsearch
                # 0 leaves gradient clipping disabled.
                gradient_clip_val: 0
                fast_dev_run: false
                gpus: 1
                # 16-bit mixed precision.
                precision: 16
                max_epochs: 512
                terminate_on_nan: true
                # Lightning accepts "top", "full" or null here. A boolean
                # `true` is rejected with a MisconfigurationException.
                # "top" is Lightning's default summary mode.
                weights_summary: top