# @package _global_
# Origin: root/training/conf/experiment/conv_transformer_paragraphs.yaml
# Blob: cdac387da39dfcad145040d9f1b7c4bdd318b480

# Hydra defaults list: selects which config-group option this experiment
# uses. Groups overridden with `null` are deselected here; `network`,
# `optimizer` and `lr_scheduler` are instead defined inline further down in
# this file.
defaults:
  - override /criterion: cross_entropy
  - override /callbacks: htr
  - override /datamodule: iam_extended_paragraphs
  - override /network: null
  - override /model: lit_transformer
  - override /lr_scheduler: null
  - override /optimizer: null

# Experiment-wide values. The anchors declared here (&epochs, &num_classes,
# &ignore_index) are aliased by the criterion, network and trainer sections,
# so they must stay defined before those sections.
tags:
  - paragraphs
epochs: &epochs 600
num_classes: &num_classes 58
ignore_index: &ignore_index 3
# Currently disabled settings, kept for reference:
# max_output_len: &max_output_len 682
# summary: [[1, 1, 576, 640], [1, 682]]

# Pass the top-level `tags` list on to the `wandb` logger entry through an
# interpolation, so the tags are only written once.
logger:
  wandb:
    tags: "${tags}"

# Settings layered on top of the `cross_entropy` criterion chosen in the
# defaults list.
criterion:
  # Reuses the shared &ignore_index anchor (3) declared at the top level.
  ignore_index: *ignore_index
  # Label smoothing is currently disabled:
  # label_smoothing: 0.05

# Additional callback defined in this experiment, next to whatever the `htr`
# callbacks group selected in the defaults list provides.
callbacks:
  stochastic_weight_averaging:
    _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
    device: null
    # When averaging starts and the learning rate it uses.
    swa_lrs: 1.0e-5
    swa_epoch_start: 0.75
    # Shape and length of the annealing phase.
    annealing_strategy: cos
    annealing_epochs: 10

# Optimizer defined inline (the optimizer config group is deselected in the
# defaults list).
optimizer:
  _target_: adan_pytorch.Adan
  weight_decay: 0.02
  betas:
    - 0.02
    - 0.08
    - 0.01
  lr: 3.0e-4

# Learning-rate schedule defined inline (the lr_scheduler config group is
# deselected in the defaults list).
lr_scheduler:
  _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
  # NOTE(review): `monitor` and `interval` are not constructor arguments of
  # ReduceLROnPlateau; presumably the training code reads them separately
  # before instantiating the scheduler — confirm.
  monitor: val/cer
  interval: epoch
  # Plateau detection on the monitored value.
  mode: min
  patience: 10
  cooldown: 0
  threshold: 1.0e-4
  threshold_mode: rel
  # Size and lower bound of each learning-rate reduction.
  factor: 0.8
  min_lr: 1.0e-5
  eps: 1.0e-8
  verbose: false

# Overrides applied to the `iam_extended_paragraphs` datamodule chosen in the
# defaults list.
datamodule:
  train_fraction: 0.95
  batch_size: 2

# Model definition, written inline because the network config group is
# deselected in the defaults list. The anchors &hidden_dim and &dropout_rate
# are declared on first use and aliased throughout this section.
network:
  _target_: text_recognizer.networks.ConvTransformer
  input_dims: [1, 1, 576, 640]
  hidden_dim: &hidden_dim 128
  num_classes: *num_classes
  # Kept in sync with criterion.ignore_index by reusing the shared anchor
  # (still resolves to 3) instead of repeating the literal.
  pad_index: *ignore_index
  encoder:
    _target_: text_recognizer.networks.convnext.ConvNext
    dim: 16
    dim_mults: [1, 2, 4, 8, 8]
    depths: [3, 3, 3, 3, 6]
    downsampling_factors: [[2, 2], [2, 2], [2, 1], [2, 1], [2, 1]]
    attn:
      _target_: text_recognizer.networks.convnext.TransformerBlock
      attn:
        _target_: text_recognizer.networks.convnext.Attention
        # NOTE(review): 128 equals dim (16) times the last dim_mults entry
        # (8), and also equals hidden_dim — presumably intentional; confirm
        # before changing any of the three.
        dim: 128
        heads: 4
        dim_head: 64
        scale: 8
      ff:
        _target_: text_recognizer.networks.convnext.FeedForward
        dim: 128
        mult: 4
  decoder:
    _target_: text_recognizer.networks.transformer.Decoder
    depth: 6
    dim: *hidden_dim
    block:
      _target_: text_recognizer.networks.transformer.decoder_block.DecoderBlock
      self_attn:
        _target_: text_recognizer.networks.transformer.Attention
        dim: *hidden_dim
        num_heads: 12
        dim_head: 64
        dropout_rate: &dropout_rate 0.2
        causal: true
        rotary_embedding:
          _target_: text_recognizer.networks.transformer.RotaryEmbedding
          # Same value as the dim_head of the attention it belongs to.
          dim: 64
      cross_attn:
        _target_: text_recognizer.networks.transformer.Attention
        dim: *hidden_dim
        num_heads: 12
        dim_head: 64
        dropout_rate: *dropout_rate
        causal: false
      norm:
        _target_: text_recognizer.networks.transformer.RMSNorm
        dim: *hidden_dim
      ff:
        _target_: text_recognizer.networks.transformer.FeedForward
        dim: *hidden_dim
        dim_out: null
        expansion_factor: 2
        glu: true
        dropout_rate: *dropout_rate
  pixel_embedding:
    # The escaped line break joins both parts into one dotted path without
    # inserting a space.
    _target_: "text_recognizer.networks.transformer.embeddings.axial.\
      AxialPositionalEmbeddingImage"
    dim: *hidden_dim
    # NOTE(review): 18 x 160 matches a 576 x 640 input reduced by the
    # encoder's downsampling_factors (2^5 and 2^2), assuming the factors are
    # ordered (height, width) — confirm.
    axial_shape: [18, 160]
    # The two axial dims add up to hidden_dim (128).
    axial_dims: [64, 64]
  token_pos_embedding:
    _target_: "text_recognizer.networks.transformer.embeddings.fourier.\
      PositionalEncoding"
    dim: *hidden_dim
    dropout_rate: 0.1
    # Raised from 89: the disabled `max_output_len` / `summary` hints near the
    # top of this file give a 682-token target for paragraphs, so the
    # position table has to cover at least that many positions.
    # NOTE(review): confirm 682 against the datamodule's maximum label length.
    max_len: 682

# Trainer settings for this experiment.
trainer:
  # Run length comes from the shared &epochs anchor (600).
  max_epochs: *epochs
  accumulate_grad_batches: 8
  gradient_clip_val: 1.0
  # Keep these as floats: Lightning reads a float as a fraction of the
  # dataset, while an integer 1 would mean a single batch.
  limit_test_batches: 1.0
  limit_val_batches: 1.0
  limit_train_batches: 1.0