# Source path: root/training/conf/experiment/cnn_htr_char_lines.yaml
# Original blob: 0f28ff946da5663cecdc80d16516b544dbfedbd3
# @package _global_
# Everything in this experiment file is merged at the root of the composed
# config. Each `override /<group>: null` entry below drops that group's
# default selection; the matching sections (mapping, criterion, datamodule,
# network, model, lr_schedulers, optimizers) are defined inline further down
# this file instead.

defaults:
  - override /mapping: null
  - override /criterion: null
  - override /datamodule: null
  - override /network: null
  - override /model: null
  - override /lr_schedulers: null
  - override /optimizers: null


criterion:
  # Label-smoothed loss. ignore_index uses the same value as network.pad_index
  # (1000), presumably so padded target positions are excluded from the loss.
  _target_: text_recognizer.criterions.label_smoothing.LabelSmoothingLoss
  smoothing: 0.1
  ignore_index: 1000

mapping:
  # Word-piece output mapping built from the token and lexicon files below.
  # NOTE(review): this experiment is named "char_lines" but uses a word-piece
  # mapping; the character-level alternative is kept commented out at the
  # bottom of this section. Confirm which one is intended.
  _target_: text_recognizer.data.word_piece_mapping.WordPieceMapping
  tokens: iamdb_1kwp_tokens_1000.txt
  lexicon: iamdb_1kwp_lex_1000.txt
  # Presumably the word-piece vocabulary size (1000, as in the file names).
  num_features: 1000
  data_dir: null
  use_words: false
  prepend_wordsep: false
  # Start, end and padding tokens, in this order.
  special_tokens:
    - <s>
    - <e>
    - <p>
  # Character-level alternative:
  # _target_: text_recognizer.data.emnist_mapping.EmnistMapping
  # extra_symbols: [ "\n" ]

callbacks:
  # Lightning stochastic weight averaging callback.
  stochastic_weight_averaging:
    _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
    # A float start value is read by Lightning as a fraction of max_epochs.
    swa_epoch_start: 0.8
    # The LR is annealed towards swa_lrs over annealing_epochs (cosine).
    annealing_strategy: cos
    annealing_epochs: 10
    # NOTE(review): 0.05 is 500x the MADGRAD lr (1.0e-4) used below;
    # confirm this SWA learning rate is intended.
    swa_lrs: 0.05
    # null lets Lightning pick the device for the averaged model.
    device: null

optimizers:
  madgrad:
    _target_: madgrad.MADGRAD
    # Not a MADGRAD argument; presumably read by the training code to select
    # which module's parameters this optimizer updates. TODO confirm.
    parameters: network
    lr: 1.0e-4
    momentum: 0.9
    eps: 1.0e-6
    weight_decay: 0

lr_schedulers:
  network:
    _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
    # Lightning scheduling keys (step each epoch, watch validation loss);
    # presumably stripped before the scheduler is constructed. TODO confirm.
    interval: epoch
    monitor: val/loss
    # Scale the LR by `factor` after `patience` epochs without a relative
    # improvement of at least `threshold`, never going below min_lr.
    mode: min
    factor: 0.1
    patience: 10
    threshold: 1.0e-4
    threshold_mode: rel
    cooldown: 0
    min_lr: 1.0e-7
    eps: 1.0e-8

datamodule:
  _target_: text_recognizer.data.iam_lines.IAMLines
  # Loader settings.
  batch_size: 16
  num_workers: 12
  pin_memory: false
  # Dataset settings; train_fraction presumably controls the train/validation
  # split. TODO confirm against IAMLines.
  train_fraction: 0.8
  augment: true

network:
  _target_: text_recognizer.networks.conv_transformer.ConvTransformer
  # Presumably (channels, height, width) of one line image; the same numbers
  # appear behind a batch dim of 1 in `summary`.
  input_dims: [1, 56, 1024]
  hidden_dim: 128  # same value as decoder.dim
  encoder_dim: 1280  # same value as encoder.out_channels
  dropout_rate: 0.2
  # NOTE(review): num_classes and pad_index presumably have to agree with the
  # mapping's vocabulary; pad_index equals criterion.ignore_index (1000).
  # TODO confirm 1006 against the mapping's actual size.
  num_classes: 1006
  pad_index: 1000
  encoder:
    _target_: text_recognizer.networks.encoders.efficientnet.EfficientNet
    arch: b0
    out_channels: 1280
    stochastic_dropout_rate: 0.2
    # NOTE(review): 0.99 is the TensorFlow-style momentum; torch BatchNorm
    # expects the complement (0.01). Confirm EfficientNet converts it.
    bn_momentum: 0.99
    bn_eps: 1.0e-3
  decoder:
    _target_: text_recognizer.networks.transformer.Decoder
    dim: 128
    depth: 3
    # num_heads * attn_kwargs.dim_head = 4 * 32 = 128, the same as dim.
    num_heads: 4
    attn_fn: text_recognizer.networks.transformer.attention.Attention
    attn_kwargs:
      dim_head: 32
      dropout_rate: 0.2
    norm_fn: torch.nn.LayerNorm
    ff_fn: text_recognizer.networks.transformer.mlp.FeedForward
    ff_kwargs:
      dim_out: null
      expansion_factor: 4
      glu: true
      dropout_rate: 0.2
    cross_attend: true
    pre_norm: true
    rotary_emb: null

model:
  _target_: text_recognizer.models.transformer.TransformerLitModel
  # Token strings; the same three appear in mapping.special_tokens.
  start_token: "<s>"
  end_token: "<e>"
  pad_token: "<p>"
  # Presumably the maximum decoded sequence length; it matches the second
  # shape in `summary` ([1, 89]).
  max_output_len: 89

trainer:
  _target_: pytorch_lightning.Trainer
  # NOTE(review): a StochasticWeightAveraging callback is already configured
  # under `callbacks`. In Lightning 1.x this flag only adds a default SWA
  # callback when none was passed to the Trainer, and it is deprecated since
  # v1.5. Confirm it is still needed for the Lightning version in use.
  stochastic_weight_avg: true
  # Batch-size search only runs when Lightning's tuner is invoked.
  auto_scale_batch_size: binsearch
  auto_lr_find: false
  gradient_clip_val: 0
  fast_dev_run: false
  gpus: 1
  precision: 16
  max_epochs: 1024
  terminate_on_nan: true
  weights_summary: null
  limit_train_batches: 1.0
  limit_val_batches: 1.0
  limit_test_batches: 1.0
  resume_from_checkpoint: null
  # Gradients are accumulated over 4 batches per optimizer step
  # (4 x datamodule.batch_size samples when the batch size is not tuned).
  accumulate_grad_batches: 4
  overfit_batches: 0.0

# Shapes (batch size 1) presumably fed to a model summary: the input image
# batch (1 + network.input_dims) and the target sequence
# (1 x model.max_output_len).
summary:
  - [1, 1, 56, 1024]
  - [1, 89]