summaryrefslogtreecommitdiff
path: root/training/conf/experiment/conv_transformer_lines.yaml
blob: 3f5da865ae1613ac16358df36f8f5c932996cfd7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
# @package _global_

# Experiment: ConvNext image encoder + Transformer text decoder on IAM lines.
# The criterion/callbacks/datamodule/model groups reuse shared configs; the
# network, lr_scheduler and optimizer groups are set to null here because
# they are defined in full further down in this file.
defaults:
  - override /criterion: cross_entropy
  - override /callbacks: htr
  - override /datamodule: iam_lines
  - override /network: null
  - override /model: lit_transformer
  - override /lr_scheduler: null
  - override /optimizer: null

# Shared top-level values; the anchors below are aliased later in the file.
tags:
  - lines
epochs: &epochs 64
# Reused as the loss ignore_index and as the decoder's pad_index.
ignore_index: &ignore_index 3
# Disabled model-summary input shapes. Presumably [image batch, target
# tokens]; 89 matches model.max_output_len. TODO confirm before enabling.
# summary: [[1, 1, 56, 1024], [1, 89]]

logger:
  wandb:
    # Reuse the experiment-level tags through OmegaConf interpolation.
    tags: "${tags}"

# Overrides merged into the cross_entropy criterion selected in defaults.
criterion:
  # label_smoothing: 0.05  (currently disabled)
  # Same value the decoder uses as pad_index, so padded targets are skipped
  # (assuming the criterion honours ignore_index like torch's CrossEntropyLoss).
  ignore_index: *ignore_index

callbacks:
  stochastic_weight_averaging:
    _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
    # A float start is read as a fraction of trainer.max_epochs (0.75 * 64).
    swa_epoch_start: 0.75
    swa_lrs: 1.0e-5
    # Move towards swa_lrs over 10 epochs using a cosine schedule.
    annealing_strategy: cos
    annealing_epochs: 10
    # null lets Lightning choose where the averaged model lives.
    device: null

optimizer:
  _target_: adan_pytorch.Adan
  lr: 3.0e-4
  weight_decay: 0.02
  # Adan is configured with three beta coefficients; see the adan_pytorch
  # docs for the convention these values follow.
  betas:
    - 0.02
    - 0.08
    - 0.01

# Scale the learning rate by `factor` when validation CER stops improving.
lr_scheduler:
  _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
  mode: min  # lower CER is better
  factor: 0.8
  patience: 10
  threshold: 1.0e-4
  threshold_mode: rel
  cooldown: 0
  min_lr: 1.0e-5
  eps: 1.0e-8
  # `verbose` is intentionally omitted: it already defaults to False, and
  # recent PyTorch releases deprecate it (passing it only emits a warning).
  # The two keys below are not ReduceLROnPlateau arguments; presumably the
  # Lightning model removes them to build its scheduler config. TODO confirm.
  interval: epoch
  monitor: val/cer

datamodule:
  train_fraction: 0.95
  batch_size: 16
  # Line-image preprocessing stem, with augmentation turned off.
  transform:
    _target_: text_recognizer.data.stems.line.IamLinesStem
    augment: false

# ConvNext image encoder feeding a Transformer text decoder.
network:
  _target_: text_recognizer.networks.ConvTransformer
  encoder:
    _target_: text_recognizer.networks.image_encoder.ImageEncoder
    encoder:
      _target_: text_recognizer.networks.convnext.ConvNext
      dim: 16
      # Presumably the last stage width is dim * dim_mults[-1] = 16 * 32 = 512,
      # matching model_dim below. TODO confirm against ConvNext.
      dim_mults: [2, 4, 32]
      depths: [3, 3, 6]
      # Three 2x2 reductions (8x per axis). Assuming 56x1024 inputs, this
      # yields the 7x128 grid given to axial_shape. TODO confirm input size.
      downsampling_factors:
        - [2, 2]
        - [2, 2]
        - [2, 2]
      attn:
        _target_: text_recognizer.networks.convnext.TransformerBlock
        attn:
          _target_: text_recognizer.networks.convnext.Attention
          # Model width, aliased throughout the rest of this network.
          dim: &model_dim 512
          heads: 4
          dim_head: 64
          scale: 8
        ff:
          _target_: text_recognizer.networks.convnext.FeedForward
          dim: *model_dim
          mult: 2
    pixel_embedding:
      _target_: "text_recognizer.networks.transformer.embeddings.axial.\
        AxialPositionalEmbeddingImage"
      dim: *model_dim
      axial_shape: [7, 128]
  decoder:
    _target_: text_recognizer.networks.text_decoder.TextDecoder
    dim: *model_dim
    num_classes: 58
    # Shares its value with the criterion's ignore_index.
    pad_index: *ignore_index
    decoder:
      _target_: text_recognizer.networks.transformer.Decoder
      dim: *model_dim
      depth: 6
      block:
        _target_: "text_recognizer.networks.transformer.decoder_block.\
          DecoderBlock"
        # Causal (masked) self-attention; 8 heads * 64 = 512 inner width.
        self_attn:
          _target_: text_recognizer.networks.transformer.Attention
          dim: *model_dim
          num_heads: 8
          dim_head: &head_dim 64
          dropout_rate: &decoder_dropout 0.2
          causal: true
        # Non-causal attention; presumably over the encoder output.
        cross_attn:
          _target_: text_recognizer.networks.transformer.Attention
          dim: *model_dim
          num_heads: 8
          dim_head: *head_dim
          dropout_rate: *decoder_dropout
          causal: false
        norm:
          _target_: text_recognizer.networks.transformer.RMSNorm
          dim: *model_dim
        ff:
          _target_: text_recognizer.networks.transformer.FeedForward
          dim: *model_dim
          dim_out: null
          expansion_factor: 2
          glu: true
          dropout_rate: *decoder_dropout
      # Rotary embedding sized to a single attention head.
      rotary_embedding:
        _target_: text_recognizer.networks.transformer.RotaryEmbedding
        dim: *head_dim

model:
  # Presumably the maximum number of tokens produced per line; it matches
  # the 89 in the commented-out summary shapes near the top of the file.
  max_output_len: 89

trainer:
  # Epoch budget comes from the top-level `epochs` anchor.
  max_epochs: *epochs
  gradient_clip_val: 1.0
  accumulate_grad_batches: 1
  # Fraction of each dataloader to run per epoch; 1.0 means every batch.
  limit_train_batches: 1.0
  limit_val_batches: 1.0
  limit_test_batches: 1.0