summaryrefslogtreecommitdiff
path: root/training/conf/experiment
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-10-11 22:11:24 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-10-11 22:11:24 +0200
commitc9e50644ba14ae09aec5a44c12f8116bada26bab (patch)
treeba637de07ef6843050ce278ad79dee78a1a8f975 /training/conf/experiment
parent216fb147afbf6d38ef60e243a30f29d0f3993771 (diff)
Update Barlow Twins config
Diffstat (limited to 'training/conf/experiment')
-rw-r--r--training/conf/experiment/barlow_twins.yaml84
1 files changed, 27 insertions, 57 deletions
diff --git a/training/conf/experiment/barlow_twins.yaml b/training/conf/experiment/barlow_twins.yaml
index 4901e18..cb5035b 100644
--- a/training/conf/experiment/barlow_twins.yaml
+++ b/training/conf/experiment/barlow_twins.yaml
@@ -1,3 +1,5 @@
+# @package _global_
+
defaults:
- override /criterion: null
- override /datamodule: null
@@ -7,9 +9,15 @@ defaults:
- override /optimizers: null
+print_config: true
epochs: &epochs 1000
summary: [[1, 1, 56, 1024]]
+criterion:
+ _target_: text_recognizer.criterions.barlow_twins.BarlowTwinsLoss
+ dim: 512
+ lambda_: 5.1e-3
+
callbacks:
stochastic_weight_averaging:
_target_: pytorch_lightning.callbacks.StochasticWeightAveraging
@@ -22,21 +30,20 @@ callbacks:
optimizers:
madgrad:
_target_: madgrad.MADGRAD
- lr: 2.0e-4
+ lr: 3.0e-4
momentum: 0.9
- weight_decay: 0
+ weight_decay: 1.0e-6
eps: 1.0e-6
-
parameters: network
lr_schedulers:
network:
_target_: torch.optim.lr_scheduler.OneCycleLR
- max_lr: 2.0e-4
+ max_lr: 3.0e-4
total_steps: null
epochs: *epochs
- steps_per_epoch: 632
- pct_start: 0.3
+ steps_per_epoch: 45
+ pct_start: 0.03
anneal_strategy: cos
cycle_momentum: true
base_momentum: 0.85
@@ -51,23 +58,18 @@ lr_schedulers:
monitor: val/loss
datamodule:
- _target_: text_recognizer.data.iam_extended_paragraphs.IAMExtendedParagraphs
- batch_size: 4
- num_workers: 12
- train_fraction: 0.8
- augment: true
- pin_memory: true
- word_pieces: false
- resize: null
+ _target_: text_recognizer.data.iam_lines.IAMLines
+ batch_size: 16
+ num_workers: 12
+ train_fraction: 0.8
+ pin_memory: false
+ transform: transform/iam_lines_barlow.yaml
+ test_transform: transform/iam_lines_barlow.yaml
+ mapping:
+ _target_: text_recognizer.data.mappings.emnist_mapping.EmnistMapping
network:
- _target_: text_recognizer.networks.conv_transformer.ConvTransformer
- input_dims: [1, 576, 640]
- hidden_dim: &hidden_dim 128
- encoder_dim: 1280
- dropout_rate: 0.2
- num_classes: *num_classes
- pad_index: *ignore_index
+ _target_: text_recognizer.networks.barlow_twins.network.BarlowTwins
encoder:
_target_: text_recognizer.networks.encoders.efficientnet.EfficientNet
arch: b0
@@ -75,44 +77,12 @@ network:
stochastic_dropout_rate: 0.2
bn_momentum: 0.99
bn_eps: 1.0e-3
- decoder:
- _target_: text_recognizer.networks.transformer.Decoder
- dim: *hidden_dim
- depth: 3
- num_heads: 4
- attn_fn: text_recognizer.networks.transformer.attention.Attention
- attn_kwargs:
- dim_head: 32
- dropout_rate: 0.2
- norm_fn: text_recognizer.networks.transformer.norm.ScaleNorm
- ff_fn: text_recognizer.networks.transformer.mlp.FeedForward
- ff_kwargs:
- dim_out: null
- expansion_factor: 4
- glu: true
- dropout_rate: 0.2
- cross_attend: true
- pre_norm: true
- rotary_emb:
- _target_: text_recognizer.networks.transformer.positional_encodings.rotary_embedding.RotaryEmbedding
- dim: 32
- pixel_pos_embedding:
- _target_: text_recognizer.networks.transformer.positional_encodings.PositionalEncoding2D
- hidden_dim: *hidden_dim
- max_h: 18
- max_w: 20
- token_pos_embedding:
- _target_: text_recognizer.networks.transformer.positional_encodings.PositionalEncoding
- hidden_dim: *hidden_dim
- dropout_rate: 0.2
- max_len: *max_output_len
+ projector:
+ _target_: text_recognizer.networks.barlow_twins.projector.Projector
+ dims: [1280, 512, 512, 512]
model:
- _target_: text_recognizer.models.transformer.TransformerLitModel
- max_output_len: *max_output_len
- start_token: <s>
- end_token: <e>
- pad_token: <p>
+ _target_: text_recognizer.models.barlow_twins.BarlowTwinsLitModel
trainer:
_target_: pytorch_lightning.Trainer