From 70540bf897df1d60375ea220cfab838cbd28c47f Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Fri, 5 Nov 2021 19:27:50 +0100 Subject: Update lines config --- .../conf/experiment/conv_transformer_lines.yaml | 51 ++++++++-------------- 1 file changed, 18 insertions(+), 33 deletions(-) (limited to 'training') diff --git a/training/conf/experiment/conv_transformer_lines.yaml b/training/conf/experiment/conv_transformer_lines.yaml index d2a666f..6ba4535 100644 --- a/training/conf/experiment/conv_transformer_lines.yaml +++ b/training/conf/experiment/conv_transformer_lines.yaml @@ -2,7 +2,7 @@ defaults: - override /mapping: null - - override /criterion: null + - override /criterion: cross_entropy - override /callbacks: htr - override /datamodule: iam_lines - override /network: null @@ -10,21 +10,18 @@ defaults: - override /lr_schedulers: null - override /optimizers: null -epochs: &epochs 512 +epochs: &epochs 256 ignore_index: &ignore_index 3 -num_classes: &num_classes 58 +num_classes: &num_classes 57 max_output_len: &max_output_len 89 summary: [[1, 1, 56, 1024], [1, 89]] criterion: - _target_: text_recognizer.criterion.label_smoothing.LabelSmoothingLoss - smoothing: 0.1 ignore_index: *ignore_index mapping: &mapping mapping: _target_: text_recognizer.data.mappings.emnist.EmnistMapping - # extra_symbols: [ "\n" ] callbacks: stochastic_weight_averaging: @@ -38,31 +35,20 @@ callbacks: optimizers: madgrad: _target_: madgrad.MADGRAD - lr: 1.0e-4 + lr: 3.0e-4 momentum: 0.9 - weight_decay: 5.0e-6 + weight_decay: 0 eps: 1.0e-6 parameters: network lr_schedulers: network: - _target_: torch.optim.lr_scheduler.OneCycleLR - max_lr: 1.0e-4 - total_steps: null - epochs: *epochs - steps_per_epoch: 722 - pct_start: 0.01 - anneal_strategy: cos - cycle_momentum: true - base_momentum: 0.85 - max_momentum: 0.95 - div_factor: 25 - final_div_factor: 1.0e2 - three_phase: false - last_epoch: -1 - verbose: false - interval: step - monitor: val/loss + _target_: torch.optim.lr_scheduler.CosineAnnealingLR + T_max: 256 + eta_min: 1.0e-5 + last_epoch: -1 + interval: epoch + monitor: val/loss datamodule: batch_size: 32 @@ -77,26 +63,25 @@ rotary_embedding: &rotary_embedding dim: 64 attn: &attn - dim: 192 - num_heads: 4 + dim: &hidden_dim 256 + num_heads: 8 dim_head: 64 - dropout_rate: 0.05 + dropout_rate: &dropout_rate 0.5 network: _target_: text_recognizer.networks.conv_transformer.ConvTransformer input_dims: [1, 56, 1024] - hidden_dim: &hidden_dim 192 + hidden_dim: *hidden_dim num_classes: *num_classes pad_index: *ignore_index encoder: _target_: text_recognizer.networks.encoders.efficientnet.EfficientNet - arch: b0 - out_channels: 1280 + arch: b2 stochastic_dropout_rate: 0.2 bn_momentum: 0.99 bn_eps: 1.0e-3 decoder: - depth: 4 + depth: 6 _target_: text_recognizer.networks.transformer.layers.Decoder self_attn: _target_: text_recognizer.networks.transformer.attention.Attention @@ -116,7 +101,7 @@ network: dim_out: null expansion_factor: 4 glu: true - dropout_rate: 0.05 + dropout_rate: *dropout_rate pre_norm: true pixel_pos_embedding: _target_: text_recognizer.networks.transformer.embeddings.axial.AxialPositionalEmbedding -- cgit v1.2.3-70-g09d2