summaryrefslogtreecommitdiff
path: root/training/conf/experiment/cnn_transformer_paragraphs.yaml
blob: b415c29d59549f0ae06ac47ed308de50a95c081e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
# @package _global_

# Experiment config: EfficientNet encoder + Transformer decoder
# (ConvTransformer) trained on IAM extended paragraphs. The package header
# above places every key of this file at the root of the composed config.
defaults:
  # Null out the default selection of each config group listed here; every
  # one of these groups is instead defined inline further down this file.
  # `callbacks` and `trainer` are not listed, so presumably their values
  # below merge with whatever the base config selects — confirm.
  - override /mapping: null
  - override /criterion: null
  - override /datamodule: null
  - override /network: null
  - override /model: null
  - override /lr_schedulers: null
  - override /optimizers: null


# Shared hyper-parameters; the anchors are re-used by the sections below.
epochs: &epochs 512                   # trainer max_epochs and OneCycleLR epochs
ignore_index: &ignore_index 3         # loss ignore_index and network pad_index
num_classes: &num_classes 58
max_output_len: &max_output_len 682   # longest target token sequence
# Presumably the input shapes used for a model summary:
# (batch, channels, height, width) image and (batch, max_output_len) tokens.
summary:
  - [1, 1, 576, 640]
  - [1, 682]

criterion:
  # Token-level classification loss. Targets equal to the shared padding
  # index do not contribute to the loss.
  _target_: torch.nn.CrossEntropyLoss
  ignore_index: *ignore_index

mapping:
  _target_: text_recognizer.data.emnist_mapping.EmnistMapping
  # Additional symbols beyond the base mapping. Double quotes are required
  # so that "\n" is read as a real newline character, presumably used to
  # separate lines within a paragraph.
  extra_symbols:
    - "\n"

callbacks:
  # Stochastic weight averaging, starting three quarters of the way through
  # training (a float swa_epoch_start is a fraction of max_epochs).
  stochastic_weight_averaging:
    _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
    swa_epoch_start: 0.75
    swa_lrs: 1.0e-5
    annealing_strategy: cos
    annealing_epochs: 10
    device: null

optimizers:
  madgrad:
    _target_: madgrad.MADGRAD
    lr: 3.0e-4
    momentum: 0.9
    eps: 1.0e-6
    weight_decay: 0
    # Not a MADGRAD argument: names the module whose parameters this
    # optimizer receives (presumably resolved by the training code — confirm).
    parameters: network

lr_schedulers:
  network:
    _target_: torch.optim.lr_scheduler.OneCycleLR
    max_lr: 3.0e-4
    total_steps: null
    epochs: *epochs
    # With total_steps null, OneCycleLR schedules epochs * steps_per_epoch
    # steps and raises if stepped more often than that. Because interval is
    # `step`, this should equal the number of optimizer steps per epoch
    # (presumably train batches / accumulate_grad_batches) — re-check it
    # whenever the dataset, batch_size or accumulation changes.
    steps_per_epoch: 52
    pct_start: 0.1
    anneal_strategy: cos
    cycle_momentum: true
    base_momentum: 0.85
    max_momentum: 0.95
    div_factor: 25
    # Signed exponent on purpose: YAML 1.1 loaders (e.g. plain PyYAML) only
    # resolve floats with a signed exponent, so `1.0e4` would load as the
    # string "1.0e4"; `1.0e+4` is a float under every loader.
    final_div_factor: 1.0e+4
    three_phase: false
    last_epoch: -1
    verbose: false
    # Non-class arguments
    interval: step
    monitor: val/loss

datamodule:
  _target_: text_recognizer.data.iam_extended_paragraphs.IAMExtendedParagraphs
  # Data-loader settings.
  batch_size: 4
  num_workers: 12
  pin_memory: false
  # Dataset settings (train_fraction is presumably the train/val split).
  train_fraction: 0.8
  augment: true
  word_pieces: false
  resize: null

network:
  _target_: text_recognizer.networks.conv_transformer.ConvTransformer
  # (channels, height, width) of the input paragraph images. Must agree with
  # the image shape in `summary` ([1, 1, 576, 640]) and with the 18 x 20 grid
  # pixel_pos_embedding is sized for (576 / 32, 640 / 32, assuming the
  # EfficientNet encoder downsamples by 32 — confirm).
  input_dims: [1, 576, 640]
  hidden_dim: &hidden_dim 128
  # Channel count of the encoder output; tied to encoder.out_channels below.
  encoder_dim: &encoder_dim 1280
  dropout_rate: 0.2
  num_classes: *num_classes
  # Padding token index; the same value the loss ignores.
  pad_index: *ignore_index
  encoder:
    _target_: text_recognizer.networks.encoders.efficientnet.EfficientNet
    arch: b0
    out_channels: *encoder_dim
    stochastic_dropout_rate: 0.2
    bn_momentum: 0.99
    bn_eps: 1.0e-3
  decoder:
    _target_: text_recognizer.networks.transformer.Decoder
    dim: *hidden_dim
    depth: 3
    num_heads: 4
    attn_fn: text_recognizer.networks.transformer.attention.Attention
    attn_kwargs:
      dim_head: 32
      dropout_rate: 0.2
    norm_fn: text_recognizer.networks.transformer.norm.ScaleNorm
    ff_fn: text_recognizer.networks.transformer.mlp.FeedForward
    ff_kwargs:
      dim_out: null
      expansion_factor: 4
      glu: true
      dropout_rate: 0.2
    cross_attend: true
    pre_norm: true
    rotary_emb:
      _target_: text_recognizer.networks.transformer.positional_encodings.rotary_embedding.RotaryEmbedding
      # Same value as attn_kwargs.dim_head.
      dim: 32
  pixel_pos_embedding:
    _target_: text_recognizer.networks.transformer.positional_encodings.PositionalEncoding2D
    hidden_dim: *hidden_dim
    # Presumably the encoder feature-map grid for 576 x 640 inputs.
    max_h: 18
    max_w: 20
  token_pos_embedding:
    _target_: text_recognizer.networks.transformer.positional_encodings.PositionalEncoding
    hidden_dim: *hidden_dim
    dropout_rate: 0.2
    max_len: *max_output_len

model:
  _target_: text_recognizer.models.transformer.TransformerLitModel
  # Upper bound on the decoded token sequence length (shared anchor).
  max_output_len: *max_output_len
  # Special tokens: sequence start, sequence end, padding.
  start_token: '<s>'
  end_token: '<e>'
  pad_token: '<p>'

trainer:
  _target_: pytorch_lightning.Trainer
  # Hardware and numerics.
  gpus: 1
  precision: 16
  terminate_on_nan: true
  # Optimisation.
  max_epochs: *epochs
  accumulate_grad_batches: 32
  gradient_clip_val: 0.5
  # Overlaps with the explicit StochasticWeightAveraging callback configured
  # under `callbacks`; presumably a no-op when that callback is passed in —
  # confirm for the installed Lightning version.
  stochastic_weight_avg: true
  # Tuning.
  auto_scale_batch_size: binsearch
  auto_lr_find: false
  # Run limits, debugging and resumption.
  fast_dev_run: false
  overfit_batches: 0
  limit_train_batches: 1.0
  limit_val_batches: 1.0
  limit_test_batches: 1.0
  resume_from_checkpoint: null
  weights_summary: null