summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--training/conf/experiment/conv_transformer_lines.yaml108
1 files changed, 60 insertions, 48 deletions
diff --git a/training/conf/experiment/conv_transformer_lines.yaml b/training/conf/experiment/conv_transformer_lines.yaml
index cca452d..259e4ea 100644
--- a/training/conf/experiment/conv_transformer_lines.yaml
+++ b/training/conf/experiment/conv_transformer_lines.yaml
@@ -1,3 +1,4 @@
+---
# @package _global_
defaults:
@@ -18,8 +19,8 @@ summary: [[1, 1, 56, 1024], [1, 89]]
criterion:
ignore_index: *ignore_index
- label_smoothing: 0.1
-
+ # label_smoothing: 0.1
+
mapping: &mapping
mapping:
_target_: text_recognizer.data.mappings.emnist.EmnistMapping
@@ -34,20 +35,26 @@ callbacks:
device: null
optimizers:
- madgrad:
- _target_: madgrad.MADGRAD
- lr: 1.0e-4
- momentum: 0.9
+ radam:
+ _target_: torch.optim.RAdam
+ lr: 3.0e-4
+ betas: [0.9, 0.999]
weight_decay: 0
- eps: 1.0e-6
+ eps: 1.0e-8
parameters: network
lr_schedulers:
network:
- _target_: torch.optim.lr_scheduler.CosineAnnealingLR
- T_max: *epochs
- eta_min: 1.0e-6
- last_epoch: -1
+ _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
+ mode: min
+ factor: 0.5
+ patience: 10
+ threshold: 1.0e-4
+ threshold_mode: rel
+ cooldown: 0
+ min_lr: 1.0e-5
+ eps: 1.0e-8
+ verbose: false
interval: epoch
monitor: val/loss
@@ -58,33 +65,32 @@ datamodule:
pin_memory: true
<< : *mapping
+encoder: &encoder
+ _target_: text_recognizer.networks.efficientnet.efficientnet.EfficientNet
+ arch: b0
+ stochastic_dropout_rate: 0.2
+ bn_momentum: 0.99
+ bn_eps: 1.0e-3
+ depth: 5
+
rotary_embedding: &rotary_embedding
- rotary_embedding:
- _target_: text_recognizer.networks.transformer.embeddings.rotary.RotaryEmbedding
+ rotary_embedding:
+ _target_: >
+ text_recognizer.networks.transformer.embeddings.rotary.RotaryEmbedding
dim: 64
attn: &attn
- dim: &hidden_dim 128
+ dim: &hidden_dim 512
num_heads: 4
dim_head: 64
- dropout_rate: &dropout_rate 0.2
+ dropout_rate: &dropout_rate 0.4
-network:
- _target_: text_recognizer.networks.conv_transformer.ConvTransformer
- input_dims: [1, 56, 1024]
- hidden_dim: *hidden_dim
- num_classes: *num_classes
- pad_index: *ignore_index
- encoder:
- _target_: text_recognizer.networks.efficientnet.efficientnet.EfficientNet
- arch: b0
- depth: 5
- stochastic_dropout_rate: 0.2
- bn_momentum: 0.99
- bn_eps: 1.0e-3
- decoder:
- depth: 3
- _target_: text_recognizer.networks.transformer.layers.Decoder
+decoder: &decoder
+ _target_: text_recognizer.networks.transformer.decoder.Decoder
+ depth: 6
+ has_pos_emb: true
+ block:
+ _target_: text_recognizer.networks.transformer.decoder.DecoderBlock
self_attn:
_target_: text_recognizer.networks.transformer.attention.Attention
<< : *attn
@@ -95,28 +101,34 @@ network:
<< : *attn
causal: false
norm:
- _target_: text_recognizer.networks.transformer.norm.ScaleNorm
- normalized_shape: *hidden_dim
- ff:
+ _target_: text_recognizer.networks.transformer.norm.RMSNorm
+ dim: *hidden_dim
+ ff:
_target_: text_recognizer.networks.transformer.mlp.FeedForward
dim: *hidden_dim
dim_out: null
- expansion_factor: 4
+ expansion_factor: 2
glu: true
dropout_rate: *dropout_rate
- pre_norm: true
+
+pixel_pos_embedding: &pixel_pos_embedding
+ _target_: >
+ text_recognizer.networks.transformer.embeddings.axial.AxialPositionalEmbedding
+ dim: *hidden_dim
+ shape: &shape [3, 64]
+
+network:
+ _target_: text_recognizer.networks.conv_transformer.ConvTransformer
+ input_dims: [1, 1, 56, 1024]
+ hidden_dim: *hidden_dim
+ num_classes: *num_classes
+ pad_index: *ignore_index
+ encoder:
+ << : *encoder
+ decoder:
+ << : *decoder
pixel_pos_embedding:
- _target_: text_recognizer.networks.transformer.embeddings.axial.AxialPositionalEmbedding
- dim: *hidden_dim
- shape: &shape [3, 64]
- axial_encoder: null
- # _target_: text_recognizer.networks.transformer.axial_attention.encoder.AxialEncoder
- # dim: *hidden_dim
- # heads: 4
- # shape: *shape
- # depth: 2
- # dim_head: 64
- # dim_index: 1
+ << : *pixel_pos_embedding
model:
_target_: text_recognizer.models.transformer.TransformerLitModel
@@ -138,7 +150,7 @@ trainer:
max_epochs: *epochs
terminate_on_nan: true
weights_summary: null
- limit_train_batches: 1.0
+ limit_train_batches: 1.0
limit_val_batches: 1.0
limit_test_batches: 1.0
resume_from_checkpoint: null