diff options
Diffstat (limited to 'training/conf')
-rw-r--r-- | training/conf/experiment/conv_transformer_lines.yaml | 108 |
1 files changed, 60 insertions, 48 deletions
diff --git a/training/conf/experiment/conv_transformer_lines.yaml b/training/conf/experiment/conv_transformer_lines.yaml index cca452d..259e4ea 100644 --- a/training/conf/experiment/conv_transformer_lines.yaml +++ b/training/conf/experiment/conv_transformer_lines.yaml @@ -1,3 +1,4 @@ +--- # @package _global_ defaults: @@ -18,8 +19,8 @@ summary: [[1, 1, 56, 1024], [1, 89]] criterion: ignore_index: *ignore_index - label_smoothing: 0.1 - + # label_smoothing: 0.1 + mapping: &mapping mapping: _target_: text_recognizer.data.mappings.emnist.EmnistMapping @@ -34,20 +35,26 @@ callbacks: device: null optimizers: - madgrad: - _target_: madgrad.MADGRAD - lr: 1.0e-4 - momentum: 0.9 + radam: + _target_: torch.optim.RAdam + lr: 3.0e-4 + betas: [0.9, 0.999] weight_decay: 0 - eps: 1.0e-6 + eps: 1.0e-8 parameters: network lr_schedulers: network: - _target_: torch.optim.lr_scheduler.CosineAnnealingLR - T_max: *epochs - eta_min: 1.0e-6 - last_epoch: -1 + _target_: torch.optim.lr_scheduler.ReduceLROnPlateau + mode: min + factor: 0.5 + patience: 10 + threshold: 1.0e-4 + threshold_mode: rel + cooldown: 0 + min_lr: 1.0e-5 + eps: 1.0e-8 + verbose: false interval: epoch monitor: val/loss @@ -58,33 +65,32 @@ datamodule: pin_memory: true << : *mapping +encoder: &encoder + _target_: text_recognizer.networks.efficientnet.efficientnet.EfficientNet + arch: b0 + stochastic_dropout_rate: 0.2 + bn_momentum: 0.99 + bn_eps: 1.0e-3 + depth: 5 + rotary_embedding: &rotary_embedding - rotary_embedding: - _target_: text_recognizer.networks.transformer.embeddings.rotary.RotaryEmbedding + rotary_embedding: + _target_: > + text_recognizer.networks.transformer.embeddings.rotary.RotaryEmbedding dim: 64 attn: &attn - dim: &hidden_dim 128 + dim: &hidden_dim 512 num_heads: 4 dim_head: 64 - dropout_rate: &dropout_rate 0.2 + dropout_rate: &dropout_rate 0.4 -network: - _target_: text_recognizer.networks.conv_transformer.ConvTransformer - input_dims: [1, 56, 1024] - hidden_dim: *hidden_dim - num_classes: *num_classes - pad_index: *ignore_index - encoder: - _target_: text_recognizer.networks.efficientnet.efficientnet.EfficientNet - arch: b0 - depth: 5 - stochastic_dropout_rate: 0.2 - bn_momentum: 0.99 - bn_eps: 1.0e-3 - decoder: - depth: 3 - _target_: text_recognizer.networks.transformer.layers.Decoder +decoder: &decoder + _target_: text_recognizer.networks.transformer.decoder.Decoder + depth: 6 + has_pos_emb: true + block: + _target_: text_recognizer.networks.transformer.decoder.DecoderBlock self_attn: _target_: text_recognizer.networks.transformer.attention.Attention << : *attn @@ -95,28 +101,34 @@ network: << : *attn causal: false norm: - _target_: text_recognizer.networks.transformer.norm.ScaleNorm - normalized_shape: *hidden_dim - ff: + _target_: text_recognizer.networks.transformer.norm.RMSNorm + dim: *hidden_dim + ff: _target_: text_recognizer.networks.transformer.mlp.FeedForward dim: *hidden_dim dim_out: null - expansion_factor: 4 + expansion_factor: 2 glu: true dropout_rate: *dropout_rate - pre_norm: true + +pixel_pos_embedding: &pixel_pos_embedding + _target_: > + text_recognizer.networks.transformer.embeddings.axial.AxialPositionalEmbedding + dim: *hidden_dim + shape: &shape [3, 64] + +network: + _target_: text_recognizer.networks.conv_transformer.ConvTransformer + input_dims: [1, 1, 56, 1024] + hidden_dim: *hidden_dim + num_classes: *num_classes + pad_index: *ignore_index + encoder: + << : *encoder + decoder: + << : *decoder pixel_pos_embedding: - _target_: text_recognizer.networks.transformer.embeddings.axial.AxialPositionalEmbedding - dim: *hidden_dim - shape: &shape [3, 64] - axial_encoder: null - # _target_: text_recognizer.networks.transformer.axial_attention.encoder.AxialEncoder - # dim: *hidden_dim - # heads: 4 - # shape: *shape - # depth: 2 - # dim_head: 64 - # dim_index: 1 + << : *pixel_pos_embedding model: _target_: text_recognizer.models.transformer.TransformerLitModel @@ -138,7 +150,7 @@ trainer: max_epochs: *epochs terminate_on_nan: true weights_summary: null - limit_train_batches: 1.0 + limit_train_batches: 1.0 limit_val_batches: 1.0 limit_test_batches: 1.0 resume_from_checkpoint: null |