# @package _global_
defaults:
  - override /criterion: cross_entropy
  - override /callbacks: htr
  - override /datamodule: iam_lines
  - override /network: null
  - override /model: lit_transformer
  - override /lr_scheduler: null
  - override /optimizer: null
tags: [lines]
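
# Shared scalars, declared once as YAML anchors: *epochs feeds trainer.max_epochs
# below, and *ignore_index is reused by the criterion and as the text decoder's
# pad_index, so the padding token is excluded from the loss.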
epochs: &epochs 64
ignore_index: &ignore_index 3
# summary: [[1, 1, 56, 1024], [1, 89]]
logger:
  wandb:
    tags: ${tags}
criterion:
  ignore_index: *ignore_index
  # label_smoothing: 0.05
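
# SWA starts at 75% of training (epoch 48 of 64) and cosine-anneals the SWA
# learning rate down to 1e-5 over 10 epochs.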
callbacks:
  stochastic_weight_averaging:
    _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
    swa_epoch_start: 0.75
    swa_lrs: 1.0e-5
    annealing_epochs: 10
    annealing_strategy: cos
    device: null
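
# The small betas follow Adan's own convention (the paper's defaults are
# beta1=0.02, beta2=0.08, beta3=0.01, weighting the new gradient terms),
# the reverse of Adam's large-beta convention.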
optimizer:
  _target_: adan_pytorch.Adan
  lr: 3.0e-4
  betas: [0.02, 0.08, 0.01]
  weight_decay: 0.02
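
# Plateau scheduler: steps once per epoch on the validation character error
# rate; the lr shrinks by 0.8x after 10 epochs without relative improvement,
# bottoming out at 1e-5.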
lr_scheduler:
  _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
  mode: min
  factor: 0.8
  patience: 10
  threshold: 1.0e-4
  threshold_mode: rel
  cooldown: 0
  min_lr: 1.0e-5
  eps: 1.0e-8
  verbose: false
  interval: epoch
  monitor: val/cer
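
# IAM line images, 95/5 train/val split, no augmentation.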
datamodule:
  batch_size: 16
  train_fraction: 0.95
  transform:
    _target_: text_recognizer.data.stems.line.IamLinesStem
    augment: false
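
# Encoder geometry: three ConvNext stages, each downsampling by 2 (total
# stride 8), turn a 1x56x1024 line image (see the summary above) into a 7x128
# feature map, which is why axial_shape is [7, 128]. Channels grow as
# 16 x [2, 4, 32], so the final stage width is 512 (= &dim).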
network:
  _target_: text_recognizer.networks.ConvTransformer
  encoder:
    _target_: text_recognizer.networks.image_encoder.ImageEncoder
    encoder:
      _target_: text_recognizer.networks.convnext.ConvNext
      dim: 16
      dim_mults: [2, 4, 32]
      depths: [3, 3, 6]
      downsampling_factors: [[2, 2], [2, 2], [2, 2]]
      attn:
        _target_: text_recognizer.networks.convnext.TransformerBlock
        attn:
          _target_: text_recognizer.networks.convnext.Attention
          dim: &dim 512
          heads: 4
          dim_head: 64
          scale: 8
        ff:
          _target_: text_recognizer.networks.convnext.FeedForward
          dim: *dim
          mult: 2
    pixel_embedding:
      _target_: "text_recognizer.networks.transformer.embeddings.axial.\
        AxialPositionalEmbeddingImage"
      dim: *dim
      axial_shape: [7, 128]
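  # Autoregressive text decoder: causal self-attention over the token sequence
  # plus cross-attention into the 7x128 = 896 encoder positions.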
  decoder:
    _target_: text_recognizer.networks.text_decoder.TextDecoder
    dim: *dim
    num_classes: 58
    pad_index: *ignore_index
    decoder:
      _target_: text_recognizer.networks.transformer.Decoder
      dim: *dim
      depth: 6
      block:
        _target_: "text_recognizer.networks.transformer.decoder_block.\
          DecoderBlock"
        self_attn:
          _target_: text_recognizer.networks.transformer.Attention
          dim: *dim
          num_heads: 8
          dim_head: &dim_head 64
          dropout_rate: &dropout_rate 0.2
          causal: true
        cross_attn:
          _target_: text_recognizer.networks.transformer.Attention
          dim: *dim
          num_heads: 8
          dim_head: *dim_head
          dropout_rate: *dropout_rate
          causal: false
        norm:
          _target_: text_recognizer.networks.transformer.RMSNorm
          dim: *dim
        ff:
          _target_: text_recognizer.networks.transformer.FeedForward
          dim: *dim
          dim_out: null
          expansion_factor: 2
          glu: true
          dropout_rate: *dropout_rate
      rotary_embedding:
        _target_: text_recognizer.networks.transformer.RotaryEmbedding
        dim: *dim_head
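
# max_output_len matches the [1, 89] output shape in the summary comment above.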
model:
  max_output_len: 89
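
# limit_*_batches are fractions: 1.0 means a full pass over each split.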
trainer:
  gradient_clip_val: 1.0
  max_epochs: *epochs
  accumulate_grad_batches: 1
  limit_train_batches: 1.0
  limit_val_batches: 1.0
  limit_test_batches: 1.0
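
# Usage sketch, assuming this file lives in a Hydra `experiment` config group
# with a standard entrypoint (the names below are illustrative, not from this
# repo):
#   python main.py experiment=conv_transformer_lines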