summaryrefslogtreecommitdiff
path: root/training/conf/experiment/conv_transformer_paragraphs.yaml
blob: 859117fcfee1ffba9c74f3ea6bc4e5a6a86d73ca (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
# @package _global_

# Experiment: ConvTransformer trained on IAM extended paragraphs.
# mapping, network, model, lr_schedulers and optimizers are overridden to null
# here; their contents are instead written out in full further down this file.
# NOTE(review): Hydra composes the defaults list in order, so the order below
# is presumably significant — keep it as-is.
defaults:
  - override /mapping: null
  - override /criterion: cross_entropy
  - override /callbacks: htr
  - override /datamodule: iam_extended_paragraphs
  - override /network: null
  - override /model: null
  - override /lr_schedulers: null
  - override /optimizers: null

# Shared hyperparameters; their anchors are aliased by the blocks below.
epochs: &epochs 600
ignore_index: &ignore_index 3
num_classes: &num_classes 58
max_output_len: &max_output_len 682
# Input shapes for the model summary. Presumably (batch, channel, height, width)
# for the image and (batch, length) for the target — confirm. The target length
# aliases max_output_len so the two values cannot drift apart.
summary: [[1, 1, 576, 640], [1, *max_output_len]]

criterion:
  # Same index as network.pad_index; presumably forwarded to the cross-entropy
  # criterion selected in `defaults` so padded targets are ignored.
  ignore_index: *ignore_index
  label_smoothing: 0.1

# Character mapping, merged into both `datamodule` and `model` below.
mapping: &mapping
  mapping:
    _target_: text_recognizer.data.mappings.emnist.EmnistMapping
    # Extra newline symbol, presumably so paragraph line breaks get a token.
    extra_symbols:
      - "\n"

callbacks:
  stochastic_weight_averaging:
    _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
    # Start point of averaging; presumably a fraction of max_epochs — confirm
    # against the Lightning docs.
    swa_epoch_start: 0.75
    swa_lrs: 1.0e-5
    # Annealing phase: strategy and length in epochs.
    annealing_strategy: cos
    annealing_epochs: 10
    device: null

optimizers:
  madgrad:
    _target_: madgrad.MADGRAD
    # NOTE(review): presumably consumed by the project to choose which module's
    # parameters this optimizer receives, not passed to MADGRAD — confirm.
    parameters: network
    lr: 1.5e-4
    momentum: 0.9
    eps: 1.0e-6
    weight_decay: 0.0

lr_schedulers:
  network:
    _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
    # Scale the LR by `factor` once the monitored metric stops decreasing.
    mode: min
    factor: 0.1
    patience: 10
    cooldown: 0
    threshold: 1.0e-4
    threshold_mode: rel
    min_lr: 1.0e-5
    eps: 1.0e-8
    verbose: false
    # NOTE(review): presumably read by the Lightning module when building the
    # scheduler config rather than passed to ReduceLROnPlateau — confirm.
    monitor: val/loss
    interval: epoch

datamodule:
  _target_: text_recognizer.data.iam_extended_paragraphs.IAMExtendedParagraphs
  # Adds the shared `mapping` entry defined above.
  <<: *mapping
  batch_size: 6
  num_workers: 12
  pin_memory: true
  train_fraction: 0.8

# Rotary positional embedding; merged into the decoder's self-attention.
rotary_embedding: &rotary_embedding
  rotary_embedding:
    _target_: text_recognizer.networks.transformer.embeddings.rotary.RotaryEmbedding
    # Same value as attn.dim_head; presumably the two must agree — confirm.
    dim: 64

# Attention settings merged into the decoder's self- and cross-attention.
# Also defines the hidden_dim and dropout_rate anchors used elsewhere.
attn: &attn
  num_heads: 4
  dim_head: 64
  dim: &hidden_dim 256
  dropout_rate: &dropout_rate 0.25

network:
  _target_: text_recognizer.networks.conv_transformer.ConvTransformer
  input_dims: [1, 576, 640]
  hidden_dim: *hidden_dim
  num_classes: *num_classes
  # Padding uses the same index that the criterion ignores.
  pad_index: *ignore_index
  encoder:
    _target_: text_recognizer.networks.efficientnet.efficientnet.EfficientNet
    arch: b0
    stochastic_dropout_rate: 0.2
    bn_momentum: 0.99
    bn_eps: 1.0e-3
    depth: 7
  decoder:
    _target_: text_recognizer.networks.transformer.layers.Decoder
    depth: 6
    self_attn:
      _target_: text_recognizer.networks.transformer.attention.Attention
      # A single merge key with a sequence: repeating `<<` in one mapping is a
      # duplicate key, which the YAML spec forbids and strict loaders reject.
      <<: [*attn, *rotary_embedding]
      causal: true
    cross_attn:
      _target_: text_recognizer.networks.transformer.attention.Attention
      <<: *attn
      causal: false
    norm:
      _target_: text_recognizer.networks.transformer.norm.ScaleNorm
      normalized_shape: *hidden_dim
    ff:
      _target_: text_recognizer.networks.transformer.mlp.FeedForward
      dim: *hidden_dim
      dim_out: null
      expansion_factor: 4
      glu: true
      dropout_rate: *dropout_rate
    pre_norm: true
  pixel_pos_embedding:
    _target_: text_recognizer.networks.transformer.embeddings.axial.AxialPositionalEmbedding
    dim: *hidden_dim
    # Equals (576 / 32, 640 / 32); presumably the encoder's output grid — confirm.
    shape: &shape [18, 20]
  axial_encoder:
    _target_: text_recognizer.networks.transformer.axial_attention.encoder.AxialEncoder
    dim: *hidden_dim
    heads: 4
    shape: *shape
    depth: 2
    dim_head: 64
    dim_index: 1

model:
  _target_: text_recognizer.models.transformer.TransformerLitModel
  # Adds the shared `mapping` entry defined above.
  <<: *mapping
  max_output_len: *max_output_len
  # Special token strings, quoted so they read unambiguously as strings.
  start_token: '<s>'
  end_token: '<e>'
  pad_token: '<p>'

trainer:
  _target_: pytorch_lightning.Trainer
  # NOTE(review): likely redundant with the explicit StochasticWeightAveraging
  # entry under `callbacks`, assuming that callback is passed to the Trainer.
  stochastic_weight_avg: true
  auto_scale_batch_size: binsearch
  auto_lr_find: false
  gradient_clip_val: 0.5
  fast_dev_run: false
  gpus: 1
  precision: 16
  max_epochs: *epochs
  terminate_on_nan: true
  weights_summary: null
  limit_train_batches: 1.0
  limit_val_batches: 1.0
  limit_test_batches: 1.0
  resume_from_checkpoint: null
  accumulate_grad_batches: 2
  overfit_batches: 0