# @package _global_

# Hydra experiment config: ConvTransformer network trained on the
# `iam_extended_paragraphs` datamodule (tagged `paragraphs`).
#
# NOTE(review): the nesting below was reconstructed from key order and the
# `_target_` boundaries after the original indentation was lost -- confirm
# against the constructor signatures of the referenced classes, in particular
# the placement of `rotary_embedding` (under `self_attn`), of
# `interval`/`monitor` (under `lr_scheduler`), and of
# `pixel_embedding`/`token_pos_embedding` (directly under `network`).

defaults:
  - override /criterion: cross_entropy
  - override /callbacks: htr
  - override /datamodule: iam_extended_paragraphs
  - override /network: null
  - override /model: lit_transformer
  - override /lr_scheduler: null
  - override /optimizer: null

# Shared scalars; the anchors are aliased further down in this file.
tags: [paragraphs]
epochs: &epochs 600
num_classes: &num_classes 58
ignore_index: &ignore_index 3
# max_output_len: &max_output_len 682
# summary: [[1, 1, 576, 640], [1, 682]]

logger:
  wandb:
    tags: ${tags}

criterion:
  ignore_index: *ignore_index
  # label_smoothing: 0.05

callbacks:
  stochastic_weight_averaging:
    _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
    swa_epoch_start: 0.75
    swa_lrs: 1.0e-5
    annealing_epochs: 10
    annealing_strategy: cos
    device: null

# `/optimizer` and `/lr_scheduler` are set to null in the defaults list, so
# the two stanzas below are defined from scratch in this file.
optimizer:
  _target_: adan_pytorch.Adan
  lr: 3.0e-4
  betas: [0.02, 0.08, 0.01]
  weight_decay: 0.02

lr_scheduler:
  _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
  mode: min
  factor: 0.8
  patience: 10
  threshold: 1.0e-4
  threshold_mode: rel
  cooldown: 0
  min_lr: 1.0e-5
  eps: 1.0e-8
  verbose: false
  # NOTE(review): `interval` and `monitor` are not ReduceLROnPlateau
  # arguments; presumably the training code strips them before
  # instantiation -- confirm.
  interval: epoch
  monitor: val/cer

datamodule:
  # With `trainer.accumulate_grad_batches: 8` this gives 2 * 8 = 16 samples
  # per optimizer step (per device).
  batch_size: 2
  train_fraction: 0.95

# `/network` is set to null in the defaults list, so the full architecture
# is spelled out here.
network:
  _target_: text_recognizer.networks.ConvTransformer
  input_dims: [1, 1, 576, 640]
  hidden_dim: &hidden_dim 128
  num_classes: *num_classes
  # Same value as `ignore_index` above (3).
  pad_index: 3
  encoder:
    _target_: text_recognizer.networks.convnext.ConvNext
    dim: 16
    dim_mults: [1, 2, 4, 8, 8]
    depths: [3, 3, 3, 3, 6]
    downsampling_factors: [[2, 2], [2, 2], [2, 1], [2, 1], [2, 1]]
    attn:
      _target_: text_recognizer.networks.convnext.TransformerBlock
      # dim 128 equals encoder dim (16) * last dim_mult (8) and `hidden_dim`.
      attn:
        _target_: text_recognizer.networks.convnext.Attention
        dim: 128
        heads: 4
        dim_head: 64
        scale: 8
      ff:
        _target_: text_recognizer.networks.convnext.FeedForward
        dim: 128
        mult: 4
  decoder:
    _target_: text_recognizer.networks.transformer.Decoder
    depth: 6
    block:
      _target_: text_recognizer.networks.transformer.DecoderBlock
      self_attn:
        _target_: text_recognizer.networks.transformer.Attention
        dim: *hidden_dim
        num_heads: 12
        dim_head: 64
        dropout_rate: &dropout_rate 0.2
        causal: true
        rotary_embedding:
          _target_: text_recognizer.networks.transformer.RotaryEmbedding
          dim: 64
      cross_attn:
        _target_: text_recognizer.networks.transformer.Attention
        dim: *hidden_dim
        num_heads: 12
        dim_head: 64
        dropout_rate: *dropout_rate
        causal: false
      norm:
        _target_: text_recognizer.networks.transformer.RMSNorm
        dim: *hidden_dim
      ff:
        _target_: text_recognizer.networks.transformer.FeedForward
        dim: *hidden_dim
        dim_out: null
        expansion_factor: 2
        glu: true
        dropout_rate: *dropout_rate
  pixel_embedding:
    # The trailing backslash joins the two lines into one dotted path.
    _target_: "text_recognizer.networks.transformer.embeddings.axial.\
      AxialPositionalEmbeddingImage"
    dim: *hidden_dim
    axial_shape: [18, 160]
    # 64 + 64 = 128, i.e. the two axial dims sum to `hidden_dim`.
    axial_dims: [64, 64]
  token_pos_embedding:
    # The trailing backslash joins the two lines into one dotted path.
    _target_: "text_recognizer.networks.transformer.embeddings.fourier.\
      PositionalEncoding"
    dim: *hidden_dim
    dropout_rate: 0.1
    max_len: 89

trainer:
  gradient_clip_val: 1.0
  max_epochs: *epochs
  accumulate_grad_batches: 8
  limit_train_batches: 1.0
  limit_val_batches: 1.0
  limit_test_batches: 1.0