summaryrefslogtreecommitdiff
path: root/training/conf
diff options
context:
space:
mode:
Diffstat (limited to 'training/conf')
-rw-r--r--training/conf/experiment/cnn_transformer_paragraphs.yaml25
1 files changed, 14 insertions, 11 deletions
diff --git a/training/conf/experiment/cnn_transformer_paragraphs.yaml b/training/conf/experiment/cnn_transformer_paragraphs.yaml
index 910d408..5ee5536 100644
--- a/training/conf/experiment/cnn_transformer_paragraphs.yaml
+++ b/training/conf/experiment/cnn_transformer_paragraphs.yaml
@@ -1,7 +1,10 @@
+# @package _global_
+
defaults:
- override /mapping: null
- override /criterion: null
- - override /datamodule: null
+ - override /callbacks: htr
+ - override /datamodule: iam_extended_paragraphs
- override /network: null
- override /model: null
- override /lr_schedulers: null
@@ -18,9 +21,10 @@ criterion:
_target_: torch.nn.CrossEntropyLoss
ignore_index: *ignore_index
-mapping:
- _target_: text_recognizer.data.emnist_mapping.EmnistMapping
- extra_symbols: [ "\n" ]
+mapping: &mapping
+ mapping:
+ _target_: text_recognizer.data.mappings.emnist_mapping.EmnistMapping
+ extra_symbols: [ "\n" ]
callbacks:
stochastic_weight_averaging:
@@ -34,9 +38,9 @@ callbacks:
optimizers:
madgrad:
_target_: madgrad.MADGRAD
- lr: 2.0e-4
+ lr: 3.0e-4
momentum: 0.9
- weight_decay: 0
+ weight_decay: 5.0e-6
eps: 1.0e-6
parameters: network
@@ -44,11 +48,11 @@ optimizers:
lr_schedulers:
network:
_target_: torch.optim.lr_scheduler.OneCycleLR
- max_lr: 2.0e-4
+ max_lr: 3.0e-4
total_steps: null
epochs: *epochs
steps_per_epoch: 632
- pct_start: 0.3
+ pct_start: 0.03
anneal_strategy: cos
cycle_momentum: true
base_momentum: 0.85
@@ -67,10 +71,8 @@ datamodule:
batch_size: 4
num_workers: 12
train_fraction: 0.8
- augment: true
pin_memory: true
- word_pieces: false
- resize: null
+ << : *mapping
network:
_target_: text_recognizer.networks.conv_transformer.ConvTransformer
@@ -121,6 +123,7 @@ network:
model:
_target_: text_recognizer.models.transformer.TransformerLitModel
+ << : *mapping
max_output_len: *max_output_len
start_token: <s>
end_token: <e>