From 2862fdf77a19c4afa5e00c900af4877df31a3ea6 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Tue, 13 Sep 2022 19:09:25 +0200 Subject: Update configs --- .../conf/experiment/conv_transformer_lines.yaml | 101 +++++++++++++++------ .../experiment/conv_transformer_paragraphs.yaml | 7 +- 2 files changed, 75 insertions(+), 33 deletions(-) (limited to 'training/conf/experiment') diff --git a/training/conf/experiment/conv_transformer_lines.yaml b/training/conf/experiment/conv_transformer_lines.yaml index eb9bc9e..d4478cc 100644 --- a/training/conf/experiment/conv_transformer_lines.yaml +++ b/training/conf/experiment/conv_transformer_lines.yaml @@ -4,20 +4,26 @@ defaults: - override /criterion: cross_entropy - override /callbacks: htr - override /datamodule: iam_lines - - override /network: conv_transformer + - override /network: null + # - override /network: conv_transformer - override /model: lit_transformer - override /lr_scheduler: null - override /optimizer: null -epochs: &epochs 512 +tags: [lines] +epochs: &epochs 260 ignore_index: &ignore_index 3 num_classes: &num_classes 57 max_output_len: &max_output_len 89 summary: [[1, 1, 56, 1024], [1, 89]] +logger: + wandb: + tags: ${tags} + criterion: ignore_index: *ignore_index - label_smoothing: 0.05 + # label_smoothing: 0.1 callbacks: stochastic_weight_averaging: @@ -29,30 +35,23 @@ callbacks: device: null optimizer: - _target_: torch.optim.RAdam + _target_: adan_pytorch.Adan lr: 3.0e-4 - betas: [0.9, 0.999] - weight_decay: 0 - eps: 1.0e-8 - parameters: network + betas: [0.02, 0.08, 0.01] + weight_decay: 0.02 lr_scheduler: - _target_: torch.optim.lr_scheduler.OneCycleLR - max_lr: 3.0e-4 - total_steps: null - epochs: *epochs - steps_per_epoch: 1354 - pct_start: 0.3 - anneal_strategy: cos - cycle_momentum: true - base_momentum: 0.85 - max_momentum: 0.95 - div_factor: 25.0 - final_div_factor: 10000.0 - three_phase: true - last_epoch: -1 + _target_: torch.optim.lr_scheduler.ReduceLROnPlateau + mode: min + factor: 0.8 + patience: 10 + threshold: 1.0e-4 + threshold_mode: rel + cooldown: 0 + min_lr: 1.0e-5 + eps: 1.0e-8 verbose: false - interval: step + interval: epoch monitor: val/cer datamodule: @@ -60,20 +59,66 @@ datamodule: train_fraction: 0.95 network: + _target_: text_recognizer.networks.ConvTransformer input_dims: [1, 1, 56, 1024] - num_classes: *num_classes - pad_index: *ignore_index + hidden_dim: &hidden_dim 128 + num_classes: 58 + pad_index: 3 encoder: - depth: 5 + _target_: text_recognizer.networks.convnext.ConvNext + dim: 16 + dim_mults: [2, 4, 8] + depths: [3, 3, 6] + downsampling_factors: [[2, 2], [2, 2], [2, 2]] decoder: + _target_: text_recognizer.networks.transformer.Decoder depth: 6 + block: + _target_: text_recognizer.networks.transformer.DecoderBlock + self_attn: + _target_: text_recognizer.networks.transformer.Attention + dim: *hidden_dim + num_heads: 12 + dim_head: 64 + dropout_rate: &dropout_rate 0.2 + causal: true + rotary_embedding: + _target_: text_recognizer.networks.transformer.RotaryEmbedding + dim: 64 + cross_attn: + _target_: text_recognizer.networks.transformer.Attention + dim: *hidden_dim + num_heads: 12 + dim_head: 64 + dropout_rate: *dropout_rate + causal: false + norm: + _target_: text_recognizer.networks.transformer.RMSNorm + dim: *hidden_dim + ff: + _target_: text_recognizer.networks.transformer.FeedForward + dim: *hidden_dim + dim_out: null + expansion_factor: 2 + glu: true + dropout_rate: *dropout_rate pixel_embedding: - shape: [3, 64] + _target_: "text_recognizer.networks.transformer.embeddings.axial.\ + AxialPositionalEmbeddingImage" + dim: *hidden_dim + axial_shape: [7, 128] + axial_dims: [64, 64] + token_pos_embedding: + _target_: "text_recognizer.networks.transformer.embeddings.fourier.\ + PositionalEncoding" + dim: *hidden_dim + dropout_rate: 0.1 + max_len: 89 model: max_output_len: *max_output_len trainer: - gradient_clip_val: 0.5 + gradient_clip_val: 1.0 max_epochs: *epochs accumulate_grad_batches: 1 diff --git a/training/conf/experiment/conv_transformer_paragraphs.yaml b/training/conf/experiment/conv_transformer_paragraphs.yaml index 41c236d..4bd3b45 100644 --- a/training/conf/experiment/conv_transformer_paragraphs.yaml +++ b/training/conf/experiment/conv_transformer_paragraphs.yaml @@ -19,10 +19,11 @@ summary: [[1, 1, 576, 640], [1, 682]] logger: wandb: tags: ${tags} + id: 8je5lxmx criterion: ignore_index: *ignore_index - label_smoothing: 0.05 + # label_smoothing: 0.05 callbacks: stochastic_weight_averaging: @@ -62,10 +63,6 @@ network: input_dims: [1, 1, 576, 640] num_classes: *num_classes pad_index: *ignore_index - encoder: - depth: 4 - decoder: - depth: 6 pixel_embedding: shape: [18, 79] -- cgit v1.2.3-70-g09d2