1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
|
# @package _global_
# Config groups this experiment builds on. network, optimizer and
# lr_scheduler are nulled here and defined explicitly further down.
defaults:
  - override /criterion: cross_entropy
  - override /callbacks: htr
  - override /datamodule: iam_lines
  - override /network: null
  - override /model: lit_transformer
  - override /lr_scheduler: null
  - override /optimizer: null
# Shared experiment values; the anchors are reused via aliases elsewhere in
# this file so each number is defined once.
tags:
  - lines
epochs: &epochs 260
ignore_index: &ignore_index 3  # also used as the decoder's pad_index
num_classes: &num_classes 58
max_output_len: &max_output_len 89
dim: &dim 384  # width shared by encoder and decoder modules
# summary: [[1, 1, 56, 1024], [1, 89]]
# Forward the experiment-level tags to the wandb logger.
logger:
  wandb:
    tags: ${tags}
criterion:
  # Same index the decoder uses as pad_index (network.decoder.pad_index).
  ignore_index: *ignore_index
  # label_smoothing: 0.05
callbacks:
  stochastic_weight_averaging:
    _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
    swa_epoch_start: 0.75  # float => fraction of trainer.max_epochs
    swa_lrs: 1.0e-5
    annealing_epochs: 10
    annealing_strategy: cos
    device: null  # null => device inferred from the LightningModule
optimizer:
  _target_: adan_pytorch.Adan
  lr: 3.0e-4
  # Adan takes three beta coefficients.
  betas: [0.02, 0.08, 0.01]
  weight_decay: 0.02
lr_scheduler:
  _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
  mode: min
  factor: 0.8
  patience: 10
  threshold: 1.0e-4
  threshold_mode: rel
  cooldown: 0
  min_lr: 1.0e-5
  eps: 1.0e-8
  verbose: false
  # NOTE(review): `interval` and `monitor` are not ReduceLROnPlateau
  # arguments; presumably the Lightning module strips them before
  # instantiation and uses them for its scheduler config — confirm.
  interval: epoch
  monitor: val/cer  # mode is min, so a lower CER counts as improvement
datamodule:
  batch_size: 16
  train_fraction: 0.95
# Conv image encoder + transformer text decoder. Nesting follows the Hydra
# `_target_` convention: sibling keys are the constructor's arguments.
network:
  _target_: text_recognizer.networks.ConvTransformer
  encoder:
    _target_: text_recognizer.networks.image_encoder.ImageEncoder
    encoder:
      _target_: text_recognizer.networks.convnext.ConvNext
      dim: 16
      # NOTE(review): presumably the last stage width is
      # dim * dim_mults[-1] = 16 * 24 = 384, matching *dim — confirm.
      dim_mults: [2, 4, 24]
      depths: [3, 3, 6]
      downsampling_factors: [[2, 2], [2, 2], [2, 2]]
      attn:
        _target_: text_recognizer.networks.convnext.TransformerBlock
        attn:
          _target_: text_recognizer.networks.convnext.Attention
          dim: *dim
          heads: 4
          dim_head: 64
          scale: 8
        ff:
          _target_: text_recognizer.networks.convnext.FeedForward
          dim: *dim
          mult: 2
    pixel_embedding:
      # Backslash-newline inside double quotes joins the two lines with no
      # space, keeping the dotted import path intact under 80 columns.
      _target_: "text_recognizer.networks.transformer.embeddings.axial.\
        AxialPositionalEmbeddingImage"
      dim: *dim
      # NOTE(review): assumed to be the feature-map size after three 2x2
      # downsamplings of a 56x1024 input (56/8, 1024/8) — confirm.
      axial_shape: [7, 128]
      axial_dims: [192, 192]  # sums to dim (384)
  decoder:
    _target_: text_recognizer.networks.text_decoder.TextDecoder
    hidden_dim: *dim
    num_classes: *num_classes
    pad_index: *ignore_index  # same index the criterion ignores
    decoder:
      _target_: text_recognizer.networks.transformer.Decoder
      dim: *dim
      depth: 6
      block:
        _target_: "text_recognizer.networks.transformer.decoder_block.\
          DecoderBlock"
        self_attn:
          _target_: text_recognizer.networks.transformer.Attention
          dim: *dim
          num_heads: 8
          dim_head: 64
          dropout_rate: &dropout_rate 0.2  # anchor reused below
          causal: true
        cross_attn:
          _target_: text_recognizer.networks.transformer.Attention
          dim: *dim
          num_heads: 8
          dim_head: 64
          dropout_rate: *dropout_rate
          causal: false
        norm:
          _target_: text_recognizer.networks.transformer.RMSNorm
          dim: *dim
        ff:
          _target_: text_recognizer.networks.transformer.FeedForward
          dim: *dim
          dim_out: null
          expansion_factor: 2
          glu: true
          dropout_rate: *dropout_rate
      # NOTE(review): placed as an argument of transformer.Decoder (sibling
      # of `block`); the flattened source did not show its nesting — confirm.
      # Its dim (64) matches the attention dim_head above.
      rotary_embedding:
        _target_: text_recognizer.networks.transformer.RotaryEmbedding
        dim: 64
model:
  max_output_len: *max_output_len
trainer:
  gradient_clip_val: 1.0
  max_epochs: *epochs
  accumulate_grad_batches: 1
  # Float limits are fractions of each dataloader; 1.0 uses all batches.
  limit_train_batches: 1.0
  limit_val_batches: 1.0
  limit_test_batches: 1.0
|