summaryrefslogtreecommitdiff
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/conf/experiment/conv_transformer_paragraphs.yaml102
1 files changed, 54 insertions, 48 deletions
diff --git a/training/conf/experiment/conv_transformer_paragraphs.yaml b/training/conf/experiment/conv_transformer_paragraphs.yaml
index 7a72a1a..afa1785 100644
--- a/training/conf/experiment/conv_transformer_paragraphs.yaml
+++ b/training/conf/experiment/conv_transformer_paragraphs.yaml
@@ -1,3 +1,4 @@
+---
# @package _global_
defaults:
@@ -10,7 +11,7 @@ defaults:
- override /lr_schedulers: null
- override /optimizers: null
-epochs: &epochs 600
+epochs: &epochs 200
ignore_index: &ignore_index 3
num_classes: &num_classes 58
max_output_len: &max_output_len 682
@@ -18,12 +19,12 @@ summary: [[1, 1, 576, 640], [1, 682]]
criterion:
ignore_index: *ignore_index
- label_smoothing: 0.1
-
+ # label_smoothing: 0.1
+
mapping: &mapping
mapping:
_target_: text_recognizer.data.mappings.emnist.EmnistMapping
- extra_symbols: [ "\n" ]
+ extra_symbols: ["\n"]
callbacks:
stochastic_weight_averaging:
@@ -35,24 +36,24 @@ callbacks:
device: null
optimizers:
- madgrad:
- _target_: madgrad.MADGRAD
+ radam:
+ _target_: torch.optim.RAdam
lr: 1.5e-4
- momentum: 0.9
- weight_decay: 0.0
- eps: 1.0e-6
+ betas: [0.9, 0.999]
+ weight_decay: 0
+ eps: 1.0e-8
parameters: network
lr_schedulers:
network:
_target_: torch.optim.lr_scheduler.ReduceLROnPlateau
mode: min
- factor: 0.1
+ factor: 0.5
patience: 10
threshold: 1.0e-4
threshold_mode: rel
cooldown: 0
- min_lr: 1.0e-5
+ min_lr: 1.0e-6
eps: 1.0e-8
verbose: false
interval: epoch
@@ -66,33 +67,32 @@ datamodule:
pin_memory: true
<< : *mapping
+encoder: &encoder
+ _target_: text_recognizer.networks.efficientnet.efficientnet.EfficientNet
+ arch: b0
+ stochastic_dropout_rate: 0.2
+ bn_momentum: 0.99
+ bn_eps: 1.0e-3
+ depth: 7
+
rotary_embedding: &rotary_embedding
- rotary_embedding:
- _target_: text_recognizer.networks.transformer.embeddings.rotary.RotaryEmbedding
+ rotary_embedding:
+ _target_: >
+ text_recognizer.networks.transformer.embeddings.rotary.RotaryEmbedding
dim: 64
attn: &attn
- dim: &hidden_dim 256
+ dim: &hidden_dim 512
num_heads: 4
dim_head: 64
- dropout_rate: &dropout_rate 0.25
+ dropout_rate: &dropout_rate 0.4
-network:
- _target_: text_recognizer.networks.conv_transformer.ConvTransformer
- input_dims: [1, 576, 640]
- hidden_dim: *hidden_dim
- num_classes: *num_classes
- pad_index: *ignore_index
- encoder:
- _target_: text_recognizer.networks.efficientnet.efficientnet.EfficientNet
- arch: b0
- stochastic_dropout_rate: 0.2
- bn_momentum: 0.99
- bn_eps: 1.0e-3
- depth: 5
- decoder:
- depth: 6
- _target_: text_recognizer.networks.transformer.layers.Decoder
+decoder: &decoder
+ _target_: text_recognizer.networks.transformer.decoder.Decoder
+ depth: 6
+ has_pos_emb: true
+ block:
+ _target_: text_recognizer.networks.transformer.decoder.DecoderBlock
self_attn:
_target_: text_recognizer.networks.transformer.attention.Attention
<< : *attn
@@ -103,28 +103,34 @@ network:
<< : *attn
causal: false
norm:
- _target_: text_recognizer.networks.transformer.norm.ScaleNorm
- normalized_shape: *hidden_dim
- ff:
+ _target_: text_recognizer.networks.transformer.norm.RMSNorm
+ dim: *hidden_dim
+ ff:
_target_: text_recognizer.networks.transformer.mlp.FeedForward
dim: *hidden_dim
dim_out: null
- expansion_factor: 4
+ expansion_factor: 2
glu: true
dropout_rate: *dropout_rate
- pre_norm: true
+
+pixel_pos_embedding: &pixel_pos_embedding
+ _target_: >
+ text_recognizer.networks.transformer.embeddings.axial.AxialPositionalEmbedding
+ dim: *hidden_dim
+ shape: &shape [18, 20]
+
+network:
+ _target_: text_recognizer.networks.conv_transformer.ConvTransformer
+ input_dims: [1, 1, 576, 640]
+ hidden_dim: *hidden_dim
+ num_classes: *num_classes
+ pad_index: *ignore_index
+ encoder:
+ << : *encoder
+ decoder:
+ << : *decoder
pixel_pos_embedding:
- _target_: text_recognizer.networks.transformer.embeddings.axial.AxialPositionalEmbedding
- dim: *hidden_dim
- shape: &shape [36, 40]
- axial_encoder: null
- # _target_: text_recognizer.networks.transformer.axial_attention.encoder.AxialEncoder
- # dim: *hidden_dim
- # heads: 4
- # shape: *shape
- # depth: 2
- # dim_head: 64
- # dim_index: 1
+ << : *pixel_pos_embedding
model:
_target_: text_recognizer.models.transformer.TransformerLitModel
@@ -146,7 +152,7 @@ trainer:
max_epochs: *epochs
terminate_on_nan: true
weights_summary: null
- limit_train_batches: 1.0
+ limit_train_batches: 1.0
limit_val_batches: 1.0
limit_test_batches: 1.0
resume_from_checkpoint: null