summaryrefslogtreecommitdiff
path: root/training/conf/experiment/conv_transformer_lines.yaml
blob: 3392cd6470bde08b9676810f18afb3810dbb40b3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
# @package _global_
# Experiment config: ConvTransformer (ConvNext image encoder + transformer
# text decoder) trained on the IAM lines datamodule.
# The package header above must remain the first line; it places every key
# of this file at the root of the composed config.

defaults:
  # Replace the base config-group selections for this experiment. Groups set
  # to null select no group file; network, lr_scheduler and optimizer are
  # instead defined inline further down in this file.
  - override /criterion: cross_entropy
  - override /callbacks: htr
  - override /datamodule: iam_lines
  - override /network: null
  - override /model: lit_transformer
  - override /lr_scheduler: null
  - override /optimizer: null

# Experiment-wide values. Each anchor (&name) is reused further down in this
# file through an alias (*name), so a value only has to change here.
tags:
  - lines
epochs: &epochs 260
dim: &dim 384
num_classes: &num_classes 58
max_output_len: &max_output_len 89
# Shared by the criterion (ignore_index) and the text decoder (pad_index).
ignore_index: &ignore_index 3
# summary: [[1, 1, 56, 1024], [1, 89]]

# Reuse the top-level `tags` list for the wandb logger via OmegaConf
# interpolation.
logger:
  wandb:
    tags: ${tags}

criterion:
  # Same index as the decoder's pad_index; presumably forwarded to the
  # cross-entropy loss so padded target positions are ignored — confirm
  # against the cross_entropy criterion config.
  ignore_index: *ignore_index
  # Disabled; uncomment to enable label smoothing.
  # label_smoothing: 0.05

callbacks:
  # Stochastic weight averaging for the tail end of training.
  # NOTE(review): a float swa_epoch_start is presumably read by Lightning as
  # a fraction of max_epochs (0.75 * 260 = epoch 195) — confirm.
  stochastic_weight_averaging:
    _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
    swa_epoch_start: 0.75
    swa_lrs: 1.0e-5
    # Cosine annealing towards swa_lrs over 10 epochs.
    annealing_strategy: cos
    annealing_epochs: 10
    device: null

optimizer:
  _target_: adan_pytorch.Adan
  lr: 3.0e-4
  weight_decay: 0.02
  # Adan takes three betas (order matters).
  # NOTE(review): adan_pytorch appears to express these as (1 - beta) of the
  # Adan paper, which is why they are small — confirm against the installed
  # package.
  betas:
    - 0.02
    - 0.08
    - 0.01

lr_scheduler:
  _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
  # Multiply the LR by `factor` once the monitored metric (val/cer, lower is
  # better, hence mode=min) has not improved by more than a relative 1e-4
  # for `patience` scheduler steps. Never go below min_lr.
  mode: min
  factor: 0.8
  patience: 10
  threshold: 1.0e-4
  threshold_mode: rel
  cooldown: 0
  min_lr: 1.0e-5
  eps: 1.0e-8
  # `verbose` is intentionally not set. It is deprecated in recent PyTorch
  # releases, and passing it explicitly emits a warning. Its default behaves
  # like the previous `verbose: false` on both old and new versions.
  # NOTE(review): `interval` and `monitor` are not ReduceLROnPlateau
  # arguments. Presumably the Lightning module moves them into its scheduler
  # config before instantiating the scheduler — confirm.
  interval: epoch
  monitor: val/cer

datamodule:
  batch_size: 16
  # Fraction of the training data used for training; the remainder is
  # presumably held out for validation — confirm in the iam_lines datamodule.
  train_fraction: 0.95

# ConvTransformer network: a ConvNext-based image encoder followed by a
# transformer text decoder. Model width is the shared *dim anchor (384).
network:
  _target_: text_recognizer.networks.ConvTransformer
  encoder:
    _target_: text_recognizer.networks.image_encoder.ImageEncoder
    encoder:
      _target_: text_recognizer.networks.convnext.ConvNext
      dim: 16
      # NOTE(review): presumably stage widths are dim * dim_mults, making the
      # last stage 16 * 24 = 384 channels, equal to *dim — confirm.
      dim_mults: [2, 4, 24]
      depths: [3, 3, 6]
      # Three [2, 2] downsamplings, presumably 8x per spatial axis overall.
      downsampling_factors: [[2, 2], [2, 2], [2, 2]]
      attn:
        _target_: text_recognizer.networks.convnext.TransformerBlock
        attn:
          _target_: text_recognizer.networks.convnext.Attention
          dim: *dim
          heads: 4
          dim_head: 64
          scale: 8
        ff:
          _target_: text_recognizer.networks.convnext.FeedForward
          dim: *dim
          mult: 2
    pixel_embedding:
      _target_: "text_recognizer.networks.transformer.embeddings.axial.\
        AxialPositionalEmbeddingImage"
      dim: *dim
      # Assumes a 56x1024 input image downsampled 8x to a 7x128 feature map
      # — TODO confirm if the input size or downsampling changes.
      axial_shape: [7, 128]
      # The two axial dims add up to *dim (192 + 192 = 384).
      axial_dims: [192, 192]
  decoder:
    _target_: text_recognizer.networks.text_decoder.TextDecoder
    hidden_dim: *dim
    num_classes: *num_classes
    # Same index as the criterion's ignore_index.
    pad_index: *ignore_index
    decoder:
      _target_: text_recognizer.networks.transformer.Decoder
      dim: *dim
      depth: 6
      block:
        _target_: "text_recognizer.networks.transformer.decoder_block.\
          DecoderBlock"
        # Causal self-attention. The &dropout_rate anchor defined here is
        # reused by cross_attn and ff below.
        self_attn:
          _target_: text_recognizer.networks.transformer.Attention
          dim: *dim
          num_heads: 8
          dim_head: 64
          dropout_rate: &dropout_rate 0.2
          causal: true
        # Non-causal cross-attention (presumably over the encoder output).
        cross_attn:
          _target_: text_recognizer.networks.transformer.Attention
          dim: *dim
          num_heads: 8
          dim_head: 64
          dropout_rate: *dropout_rate
          causal: false
        norm:
          _target_: text_recognizer.networks.transformer.RMSNorm
          dim: *dim
        ff:
          _target_: text_recognizer.networks.transformer.FeedForward
          dim: *dim
          dim_out: null
          expansion_factor: 2
          glu: true
          dropout_rate: *dropout_rate
      # NOTE(review): rotary dim equals the attention dim_head (64);
      # presumably the rotation is applied per head — confirm.
      rotary_embedding:
        _target_: text_recognizer.networks.transformer.RotaryEmbedding
        dim: 64

model:
  # Presumably the maximum number of tokens the model decodes per line
  # (shared *max_output_len anchor, 89) — confirm in lit_transformer.
  max_output_len: *max_output_len

trainer:
  max_epochs: *epochs
  accumulate_grad_batches: 1
  gradient_clip_val: 1.0
  # Keep these as floats (1.0 = use every batch). In Lightning, an integer 1
  # would instead mean "run a single batch".
  limit_train_batches: 1.0
  limit_val_batches: 1.0
  limit_test_batches: 1.0