summaryrefslogtreecommitdiff
path: root/training/conf/experiment/conv_transformer_lines.yaml
blob: 2631e81bff58c177d3817fbd6df7415c08e4cf25 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
# @package _global_

# Hydra defaults list. Each `override /group: option` entry replaces the
# option the primary config selected for that config group.
# `null` deselects a group entirely; the corresponding section is then
# defined inline further down in this file instead.
defaults:
  - override /criterion: cross_entropy
  - override /callbacks: htr
  - override /datamodule: iam_lines
  # The network is defined inline under `network:` rather than loaded from
  # the network group.
  - override /network: null
  # - override /network: conv_transformer
  - override /model: lit_transformer
  # Scheduler and optimizer are defined inline under `lr_scheduler:` and
  # `optimizer:`.
  - override /lr_scheduler: null
  - override /optimizer: null

# Experiment-wide values. Anchors (&name) defined here are reused as
# aliases (*name) in the sections that follow.
tags: [lines]
epochs: &epochs 260
# Padding token index; the criterion ignores this index in the loss.
ignore_index: &ignore_index 3
# NOTE(review): this anchor is not aliased anywhere in this file, and the
# network section builds its model with `num_classes: 58` instead. Confirm
# which count matches the IAM lines token mapping.
num_classes: &num_classes 57
max_output_len: &max_output_len 89
# Presumably the model-summary input shapes (image, target tokens):
# summary: [[1, 1, 56, 1024], [1, 89]]

# Run tags for the wandb logger, interpolated (OmegaConf `${...}`) from the
# top-level `tags` list.
logger:
  wandb:
    tags: "${tags}"

# Settings layered on top of the cross_entropy criterion group.
criterion:
  # Loss skips targets equal to the shared padding index.
  ignore_index: *ignore_index
  # Label smoothing is disabled; uncomment to enable it.
  # label_smoothing: 0.1

# Extra callbacks added to the htr callbacks group.
callbacks:
  # Stochastic Weight Averaging (PyTorch Lightning callback).
  stochastic_weight_averaging:
    _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
    # A float in (0, 1) is read as a fraction of max_epochs.
    swa_epoch_start: 0.75
    # Learning rate used during the SWA phase.
    swa_lrs: 1.0e-5
    # Anneal toward swa_lrs over 10 epochs with a cosine schedule.
    annealing_epochs: 10
    annealing_strategy: cos
    # null: the averaged model's device is inferred from the module.
    device: null

# Adan optimizer from the adan_pytorch package, defined inline because the
# optimizer config group is deselected in the defaults list.
optimizer:
  _target_: adan_pytorch.Adan
  lr: 3.0e-4
  # Adan's three beta coefficients; these appear to match the package
  # defaults — confirm against the installed adan_pytorch version.
  betas: [0.02, 0.08, 0.01]
  weight_decay: 0.02

# Reduce the learning rate when the monitored validation metric plateaus.
lr_scheduler:
  _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
  # The monitored metric (val/cer) is an error rate, so lower is better.
  mode: min
  # New lr = lr * factor after `patience` epochs without improvement.
  factor: 0.8
  patience: 10
  threshold: 1.0e-4
  threshold_mode: rel
  cooldown: 0
  min_lr: 1.0e-5
  eps: 1.0e-8
  # `verbose` is intentionally omitted. It is deprecated in
  # torch.optim.lr_scheduler (torch >= 2.2), where passing any explicit
  # value only emits a warning. False was already its default, so omitting
  # it changes nothing.
  # NOTE(review): `interval` and `monitor` are not ReduceLROnPlateau
  # arguments; presumably the Lightning module removes them before
  # instantiating the scheduler — confirm.
  interval: epoch
  monitor: val/cer

# Overrides applied on top of the iam_lines datamodule group.
datamodule:
  batch_size: 8
  # Presumably the share of the training data kept for training, with the
  # remainder used for validation — confirm against the datamodule.
  train_fraction: 0.95

# Convolutional encoder + transformer decoder, defined inline because the
# network config group is deselected in the defaults list. Values that must
# agree with the globals above reuse their anchors instead of repeating
# the literals.
network:
  _target_: text_recognizer.networks.ConvTransformer
  input_dims: [1, 1, 56, 1024]
  hidden_dim: &hidden_dim 128
  # NOTE(review): the top-level `num_classes` anchor is 57, but the network
  # is built with 58. Confirm which matches the token mapping, then alias
  # `*num_classes` here so the two cannot drift apart.
  num_classes: 58
  # Padding token; must equal the index the criterion ignores.
  pad_index: *ignore_index
  encoder:
    _target_: text_recognizer.networks.convnext.ConvNext
    dim: 16
    # Final stage width is presumably dim * dim_mults[-1] = 16 * 8 = 128,
    # which must equal hidden_dim — TODO confirm against ConvNext.
    dim_mults: [2, 4, 8]
    depths: [3, 3, 6]
    # Three 2x2 reductions: 56 x 1024 -> 7 x 128, which
    # pixel_embedding.axial_shape below must match.
    downsampling_factors: [[2, 2], [2, 2], [2, 2]]
    attn:
      _target_: text_recognizer.networks.convnext.TransformerBlock
      attn:
        _target_: text_recognizer.networks.convnext.Attention
        dim: *hidden_dim
        heads: 4
        dim_head: 64
        scale: 8
      ff:
        _target_: text_recognizer.networks.convnext.FeedForward
        dim: *hidden_dim
        mult: 4
  decoder:
    _target_: text_recognizer.networks.transformer.Decoder
    depth: 6
    block:
      _target_: text_recognizer.networks.transformer.DecoderBlock
      # NOTE(review): 12 heads x 64 dim_head = 768 inner width against a
      # 128-wide model; presumably projected back by Attention — confirm
      # this is intended.
      self_attn:
        _target_: text_recognizer.networks.transformer.Attention
        dim: *hidden_dim
        num_heads: 12
        dim_head: 64
        dropout_rate: &dropout_rate 0.2
        causal: true
        rotary_embedding:
          _target_: text_recognizer.networks.transformer.RotaryEmbedding
          dim: 64
      cross_attn:
        _target_: text_recognizer.networks.transformer.Attention
        dim: *hidden_dim
        num_heads: 12
        dim_head: 64
        dropout_rate: *dropout_rate
        causal: false
      norm:
        _target_: text_recognizer.networks.transformer.RMSNorm
        dim: *hidden_dim
      ff:
        _target_: text_recognizer.networks.transformer.FeedForward
        dim: *hidden_dim
        dim_out: null
        expansion_factor: 2
        glu: true
        dropout_rate: *dropout_rate
  pixel_embedding:
    _target_: "text_recognizer.networks.transformer.embeddings.axial.\
      AxialPositionalEmbeddingImage"
    dim: *hidden_dim
    # Encoder output grid size (7 x 128, see downsampling_factors).
    axial_shape: [7, 128]
    # Per-axis embedding widths; presumably combined to hidden_dim
    # (64 + 64 = 128) — confirm if either is changed.
    axial_dims: [64, 64]
  token_pos_embedding:
    _target_: "text_recognizer.networks.transformer.embeddings.fourier.\
      PositionalEncoding"
    dim: *hidden_dim
    dropout_rate: 0.1
    # Must cover the longest decoded sequence.
    max_len: *max_output_len

# Overrides for the lit_transformer model group: cap on decoded sequence
# length, shared through the global anchor.
model:
  max_output_len: *max_output_len

# Trainer overrides.
trainer:
  # Clip gradients at 1.0.
  gradient_clip_val: 1.0
  # Epoch budget comes from the global `epochs` anchor.
  max_epochs: *epochs
  # 1 = optimizer step every batch (no gradient accumulation).
  accumulate_grad_batches: 1