diff options
-rw-r--r-- | training/conf/experiment/conv_transformer_paragraphs.yaml | 102 |
1 files changed, 54 insertions, 48 deletions
diff --git a/training/conf/experiment/conv_transformer_paragraphs.yaml b/training/conf/experiment/conv_transformer_paragraphs.yaml index 7a72a1a..afa1785 100644 --- a/training/conf/experiment/conv_transformer_paragraphs.yaml +++ b/training/conf/experiment/conv_transformer_paragraphs.yaml @@ -1,3 +1,4 @@ +--- # @package _global_ defaults: @@ -10,7 +11,7 @@ defaults: - override /lr_schedulers: null - override /optimizers: null -epochs: &epochs 600 +epochs: &epochs 200 ignore_index: &ignore_index 3 num_classes: &num_classes 58 max_output_len: &max_output_len 682 @@ -18,12 +19,12 @@ summary: [[1, 1, 576, 640], [1, 682]] criterion: ignore_index: *ignore_index - label_smoothing: 0.1 - + # label_smoothing: 0.1 + mapping: &mapping mapping: _target_: text_recognizer.data.mappings.emnist.EmnistMapping - extra_symbols: [ "\n" ] + extra_symbols: ["\n"] callbacks: stochastic_weight_averaging: @@ -35,24 +36,24 @@ callbacks: device: null optimizers: - madgrad: - _target_: madgrad.MADGRAD + radam: + _target_: torch.optim.RAdam lr: 1.5e-4 - momentum: 0.9 - weight_decay: 0.0 - eps: 1.0e-6 + betas: [0.9, 0.999] + weight_decay: 0 + eps: 1.0e-8 parameters: network lr_schedulers: network: _target_: torch.optim.lr_scheduler.ReduceLROnPlateau mode: min - factor: 0.1 + factor: 0.5 patience: 10 threshold: 1.0e-4 threshold_mode: rel cooldown: 0 - min_lr: 1.0e-5 + min_lr: 1.0e-6 eps: 1.0e-8 verbose: false interval: epoch @@ -66,33 +67,32 @@ datamodule: pin_memory: true << : *mapping +encoder: &encoder + _target_: text_recognizer.networks.efficientnet.efficientnet.EfficientNet + arch: b0 + stochastic_dropout_rate: 0.2 + bn_momentum: 0.99 + bn_eps: 1.0e-3 + depth: 7 + rotary_embedding: &rotary_embedding - rotary_embedding: - _target_: text_recognizer.networks.transformer.embeddings.rotary.RotaryEmbedding + rotary_embedding: + _target_: > + text_recognizer.networks.transformer.embeddings.rotary.RotaryEmbedding dim: 64 attn: &attn - dim: &hidden_dim 256 + dim: &hidden_dim 512 num_heads: 4 dim_head: 64 - dropout_rate: &dropout_rate 0.25 + dropout_rate: &dropout_rate 0.4 -network: - _target_: text_recognizer.networks.conv_transformer.ConvTransformer - input_dims: [1, 576, 640] - hidden_dim: *hidden_dim - num_classes: *num_classes - pad_index: *ignore_index - encoder: - _target_: text_recognizer.networks.efficientnet.efficientnet.EfficientNet - arch: b0 - stochastic_dropout_rate: 0.2 - bn_momentum: 0.99 - bn_eps: 1.0e-3 - depth: 5 - decoder: - depth: 6 - _target_: text_recognizer.networks.transformer.layers.Decoder +decoder: &decoder + _target_: text_recognizer.networks.transformer.decoder.Decoder + depth: 6 + has_pos_emb: true + block: + _target_: text_recognizer.networks.transformer.decoder.DecoderBlock self_attn: _target_: text_recognizer.networks.transformer.attention.Attention << : *attn @@ -103,28 +103,34 @@ network: << : *attn causal: false norm: - _target_: text_recognizer.networks.transformer.norm.ScaleNorm - normalized_shape: *hidden_dim - ff: + _target_: text_recognizer.networks.transformer.norm.RMSNorm + dim: *hidden_dim + ff: _target_: text_recognizer.networks.transformer.mlp.FeedForward dim: *hidden_dim dim_out: null - expansion_factor: 4 + expansion_factor: 2 glu: true dropout_rate: *dropout_rate - pre_norm: true + +pixel_pos_embedding: &pixel_pos_embedding + _target_: > + text_recognizer.networks.transformer.embeddings.axial.AxialPositionalEmbedding + dim: *hidden_dim + shape: &shape [18, 20] + +network: + _target_: text_recognizer.networks.conv_transformer.ConvTransformer + input_dims: [1, 1, 576, 640] + hidden_dim: *hidden_dim + num_classes: *num_classes + pad_index: *ignore_index + encoder: + << : *encoder + decoder: + << : *decoder pixel_pos_embedding: - _target_: text_recognizer.networks.transformer.embeddings.axial.AxialPositionalEmbedding - dim: *hidden_dim - shape: &shape [36, 40] - axial_encoder: null - # _target_: text_recognizer.networks.transformer.axial_attention.encoder.AxialEncoder - # dim: *hidden_dim - # heads: 4 - # shape: *shape - # depth: 2 - # dim_head: 64 - # dim_index: 1 + << : *pixel_pos_embedding model: _target_: text_recognizer.models.transformer.TransformerLitModel @@ -146,7 +152,7 @@ trainer: max_epochs: *epochs terminate_on_nan: true weights_summary: null - limit_train_batches: 1.0 + limit_train_batches: 1.0 limit_val_batches: 1.0 limit_test_batches: 1.0 resume_from_checkpoint: null |