Diffstat (limited to 'training')
-rw-r--r-- | training/conf/experiment/barlow_twins.yaml | 84
1 file changed, 27 insertions, 57 deletions
diff --git a/training/conf/experiment/barlow_twins.yaml b/training/conf/experiment/barlow_twins.yaml
index 4901e18..cb5035b 100644
--- a/training/conf/experiment/barlow_twins.yaml
+++ b/training/conf/experiment/barlow_twins.yaml
@@ -1,3 +1,5 @@
+# @package _global_
+
 defaults:
   - override /criterion: null
   - override /datamodule: null
@@ -7,9 +9,15 @@ defaults:
   - override /optimizers: null
 
 
+print_config: true
 epochs: &epochs 1000
 summary: [[1, 1, 56, 1024]]
 
+criterion:
+  _target_: text_recognizer.criterions.barlow_twins.BarlowTwinsLoss
+  dim: 512
+  lambda_: 5.1e-3
+
 callbacks:
   stochastic_weight_averaging:
     _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
@@ -22,21 +30,20 @@ callbacks:
 optimizers:
   madgrad:
     _target_: madgrad.MADGRAD
-    lr: 2.0e-4
+    lr: 3.0e-4
     momentum: 0.9
-    weight_decay: 0
+    weight_decay: 1.0e-6
     eps: 1.0e-6
-
     parameters: network
 
 lr_schedulers:
   network:
     _target_: torch.optim.lr_scheduler.OneCycleLR
-    max_lr: 2.0e-4
+    max_lr: 3.0e-4
     total_steps: null
     epochs: *epochs
-    steps_per_epoch: 632
-    pct_start: 0.3
+    steps_per_epoch: 45
+    pct_start: 0.03
     anneal_strategy: cos
     cycle_momentum: true
     base_momentum: 0.85
@@ -51,23 +58,18 @@ lr_schedulers:
     monitor: val/loss
 
 datamodule:
-  _target_: text_recognizer.data.iam_extended_paragraphs.IAMExtendedParagraphs
-  batch_size: 4
-  num_workers: 12
-  train_fraction: 0.8
-  augment: true
-  pin_memory: true
-  word_pieces: false
-  resize: null
+  _target_: text_recognizer.data.iam_lines.IAMLines
+  batch_size: 16
+  num_workers: 12
+  train_fraction: 0.8
+  pin_memory: false
+  transform: transform/iam_lines_barlow.yaml
+  test_transform: transform/iam_lines_barlow.yaml
+  mapping:
+    _target_: text_recognizer.data.mappings.emnist_mapping.EmnistMapping
 
 network:
-  _target_: text_recognizer.networks.conv_transformer.ConvTransformer
-  input_dims: [1, 576, 640]
-  hidden_dim: &hidden_dim 128
-  encoder_dim: 1280
-  dropout_rate: 0.2
-  num_classes: *num_classes
-  pad_index: *ignore_index
+  _target_: text_recognizer.networks.barlow_twins.network.BarlowTwins
   encoder:
     _target_: text_recognizer.networks.encoders.efficientnet.EfficientNet
     arch: b0
@@ -75,44 +77,12 @@ network:
     stochastic_dropout_rate: 0.2
     bn_momentum: 0.99
     bn_eps: 1.0e-3
-  decoder:
-    _target_: text_recognizer.networks.transformer.Decoder
-    dim: *hidden_dim
-    depth: 3
-    num_heads: 4
-    attn_fn: text_recognizer.networks.transformer.attention.Attention
-    attn_kwargs:
-      dim_head: 32
-      dropout_rate: 0.2
-    norm_fn: text_recognizer.networks.transformer.norm.ScaleNorm
-    ff_fn: text_recognizer.networks.transformer.mlp.FeedForward
-    ff_kwargs:
-      dim_out: null
-      expansion_factor: 4
-      glu: true
-      dropout_rate: 0.2
-    cross_attend: true
-    pre_norm: true
-    rotary_emb:
-      _target_: text_recognizer.networks.transformer.positional_encodings.rotary_embedding.RotaryEmbedding
-      dim: 32
-    pixel_pos_embedding:
-      _target_: text_recognizer.networks.transformer.positional_encodings.PositionalEncoding2D
-      hidden_dim: *hidden_dim
-      max_h: 18
-      max_w: 20
-    token_pos_embedding:
-      _target_: text_recognizer.networks.transformer.positional_encodings.PositionalEncoding
-      hidden_dim: *hidden_dim
-      dropout_rate: 0.2
-      max_len: *max_output_len
+  projector:
+    _target_: text_recognizer.networks.barlow_twins.projector.Projector
+    dims: [1280, 512, 512, 512]
 
 model:
-  _target_: text_recognizer.models.transformer.TransformerLitModel
-  max_output_len: *max_output_len
-  start_token: <s>
-  end_token: <e>
-  pad_token: <p>
+  _target_: text_recognizer.models.barlow_twins.BarlowTwinsLitModel
 
 trainer:
   _target_: pytorch_lightning.Trainer
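
The new criterion block points at text_recognizer.criterions.barlow_twins.BarlowTwinsLoss, whose source is not part of this diff. For context, here is a minimal sketch of the standard Barlow Twins objective (Zbontar et al., 2021) that is consistent with the dim: 512 and lambda_: 5.1e-3 hyperparameters above, assuming the criterion receives the two projector embeddings as input:

import torch
from torch import nn


class BarlowTwinsLoss(nn.Module):
    """Sketch of the Barlow Twins cross-correlation loss (not the repo's actual code)."""

    def __init__(self, dim: int = 512, lambda_: float = 5.1e-3) -> None:
        super().__init__()
        # Normalize each embedding dimension to zero mean / unit variance over the batch.
        self.bn = nn.BatchNorm1d(dim, affine=False)
        self.lambda_ = lambda_

    def forward(self, z1: torch.Tensor, z2: torch.Tensor) -> torch.Tensor:
        n = z1.shape[0]
        # Empirical cross-correlation matrix between the two views, shape (dim, dim).
        c = self.bn(z1).T @ self.bn(z2) / n
        on_diag = (torch.diagonal(c) - 1).pow(2).sum()  # pull the diagonal towards 1
        off_diag = (c - torch.diag(torch.diagonal(c))).pow(2).sum()  # push off-diagonal towards 0
        return on_diag + self.lambda_ * off_diag

With dim: 512 this matches the last entry of the projector's dims below, and lambda_ weights the redundancy-reduction (off-diagonal) term against the invariance (diagonal) term.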
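
Likewise, the Projector with dims: [1280, 512, 512, 512] is referenced but not shown; 1280 is EfficientNet-B0's final feature width. In the original Barlow Twins recipe the projection head is an MLP whose hidden layers are followed by BatchNorm and ReLU, with a plain linear output layer, so a hypothetical sketch matching that dims list could look like:

import torch
from torch import nn


class Projector(nn.Module):
    """Sketch of an MLP projection head (assumed structure, not the repo's actual code)."""

    def __init__(self, dims: list[int]) -> None:
        super().__init__()
        layers: list[nn.Module] = []
        # Hidden layers: Linear -> BatchNorm -> ReLU.
        for in_dim, out_dim in zip(dims[:-2], dims[1:-1]):
            layers += [
                nn.Linear(in_dim, out_dim, bias=False),
                nn.BatchNorm1d(out_dim),
                nn.ReLU(inplace=True),
            ]
        # Output layer: plain linear, no norm or activation.
        layers.append(nn.Linear(dims[-2], dims[-1], bias=False))
        self.model = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)

Projector(dims=[1280, 512, 512, 512]) then maps pooled encoder features to the 512-d embeddings that the loss above operates on.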
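
The optimizer and scheduler blocks resolve to real third-party classes (madgrad.MADGRAD and torch.optim.lr_scheduler.OneCycleLR); the parameters: network line tells the training harness which module's parameters to optimize. Wired up by hand with the new hyperparameters, and with a stand-in module in place of the instantiated BarlowTwins network:

import madgrad
import torch

network = torch.nn.Linear(8, 8)  # stand-in for the BarlowTwins network

optimizer = madgrad.MADGRAD(
    network.parameters(),
    lr=3.0e-4,
    momentum=0.9,
    weight_decay=1.0e-6,
    eps=1.0e-6,
)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=3.0e-4,
    epochs=1000,
    steps_per_epoch=45,  # must match the number of optimizer steps per epoch
    pct_start=0.03,      # ~3% of the cycle spent warming up
    anneal_strategy="cos",
    cycle_momentum=True,
    base_momentum=0.85,
)

The steps_per_epoch drop from 632 to 45 presumably tracks the datamodule switch (IAMExtendedParagraphs at batch size 4 to IAMLines at batch size 16); if it does not match the actual dataloader length, OneCycleLR will either finish its schedule early or raise once it is stepped past its total step count.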