From 41c3e99fe57874ba1855c893bf47087d474ec6b8 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Mon, 25 Oct 2021 22:32:10 +0200
Subject: Update training configs
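
Refactor the training configs:

- add a transducer criterion config
- replace the word_pieces target transform with word_piece (paragraphs,
  max_len 451) and word_piece_iam (lines, max_len 72) variants
- add a Barlow Twins transform and experiment for paragraphs
- add a CTC character-level experiment for IAM lines
- move the word piece mapping under text_recognizer.data.mappings and
  drop its data_dir argument
- retune learning rates, schedulers, batch sizes, and hidden dims in the
  existing experiments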

---
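Notes: the experiment files are plain Hydra configs; the new target
transforms are referenced from the datamodules as, e.g.,

    target_transform: target_transform/word_piece_iam.yaml

and an experiment is selected through a Hydra override, presumably along
the lines of `python run.py +experiment=cnn_htr_wp_lines` (the entry
point name is an assumption, not part of this patch).
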
 training/conf/criterion/transducer.yaml            |  14 ++
 training/conf/criterion/vqgan_loss.yaml            |   1 -
 .../datamodule/target_transform/word_piece.yaml    |   3 +
 .../target_transform/word_piece_iam.yaml           |   3 +
 .../datamodule/target_transform/word_pieces.yaml   |   3 -
 .../datamodule/transform/barlow_paragraphs.yaml    |  46 +++++++
 training/conf/experiment/barlow_twins.yaml         |  12 +-
 .../conf/experiment/barlow_twins_paragraphs.yaml   | 103 ++++++++++++++
 training/conf/experiment/cnn_htr_char_lines.yaml   |  19 +--
 .../conf/experiment/cnn_htr_ctc_char_lines.yaml    | 149 +++++++++++++++++++++
 training/conf/experiment/cnn_htr_wp_lines.yaml     |  52 +++----
 .../experiment/conv_transformer_paragraphs.yaml    |   6 +-
 .../experiment/conv_transformer_paragraphs_wp.yaml |  39 +++---
 training/conf/mapping/word_piece.yaml              |   1 -
 14 files changed, 384 insertions(+), 67 deletions(-)
 create mode 100644 training/conf/criterion/transducer.yaml
 create mode 100644 training/conf/datamodule/target_transform/word_piece.yaml
 create mode 100644 training/conf/datamodule/target_transform/word_piece_iam.yaml
 delete mode 100644 training/conf/datamodule/target_transform/word_pieces.yaml
 create mode 100644 training/conf/datamodule/transform/barlow_paragraphs.yaml
 create mode 100644 training/conf/experiment/barlow_twins_paragraphs.yaml
 create mode 100644 training/conf/experiment/cnn_htr_ctc_char_lines.yaml

diff --git a/training/conf/criterion/transducer.yaml b/training/conf/criterion/transducer.yaml
new file mode 100644
index 0000000..7e661a7
--- /dev/null
+++ b/training/conf/criterion/transducer.yaml
@@ -0,0 +1,14 @@
+_target_: text_recognizer.criterions.transducer.Transducer
+preprocessor:
+  _target_: text_recognizer.data.utils.iam_preprocessor.Preprocessor
+  num_features: 1000
+  tokens: iamdb_1kwp_tokens_1000.txt
+  lexicon: iamdb_1kwp_lex_1000.txt
+  use_words: false
+  prepend_wordsep: false
+  special_tokens: [ "<s>", "<e>", "<p>", "\n" ]
+ngram: 0
+# transitions: 1kwp_prune_0_10_optblank.bin
+blank: "optional"
+allow_repeats: true
+reduction: "none"
diff --git a/training/conf/criterion/vqgan_loss.yaml b/training/conf/criterion/vqgan_loss.yaml
index f983f6f..43e05cd 100644
--- a/training/conf/criterion/vqgan_loss.yaml
+++ b/training/conf/criterion/vqgan_loss.yaml
@@ -9,4 +9,3 @@ discriminator:
   num_layers: 3
 vq_loss_weight: 1.0
 discriminator_weight: 1.0
-
diff --git a/training/conf/datamodule/target_transform/word_piece.yaml b/training/conf/datamodule/target_transform/word_piece.yaml
new file mode 100644
index 0000000..bf284fb
--- /dev/null
+++ b/training/conf/datamodule/target_transform/word_piece.yaml
@@ -0,0 +1,3 @@
+word_piece:
+  _target_: text_recognizer.data.transforms.word_piece.WordPiece
+  max_len: 451
diff --git a/training/conf/datamodule/target_transform/word_piece_iam.yaml b/training/conf/datamodule/target_transform/word_piece_iam.yaml
new file mode 100644
index 0000000..478987c
--- /dev/null
+++ b/training/conf/datamodule/target_transform/word_piece_iam.yaml
@@ -0,0 +1,3 @@
+word_piece:
+  _target_: text_recognizer.data.transforms.word_piece.WordPiece
+  max_len: 72
diff --git a/training/conf/datamodule/target_transform/word_pieces.yaml b/training/conf/datamodule/target_transform/word_pieces.yaml
deleted file mode 100644
index 8ace2af..0000000
--- a/training/conf/datamodule/target_transform/word_pieces.yaml
+++ /dev/null
@@ -1,3 +0,0 @@
-word_pieces:
-  _target_: text_recognizer.data.transforms.word_pieces.WordPieces
-  max_len: 451
diff --git a/training/conf/datamodule/transform/barlow_paragraphs.yaml b/training/conf/datamodule/transform/barlow_paragraphs.yaml
new file mode 100644
index 0000000..5eefce5
--- /dev/null
+++ b/training/conf/datamodule/transform/barlow_paragraphs.yaml
@@ -0,0 +1,46 @@
+
+barlow:
+  _target_: text_recognizer.data.transforms.barlow.BarlowTransform
+  prim:
+    random_crop:
+      _target_: torchvision.transforms.RandomCrop
+      size: [576, 640]
+      padding: null
+      pad_if_needed: true
+      fill: 0
+      padding_mode: constant
+
+    color_jitter:
+      _target_: torchvision.transforms.ColorJitter
+      brightness: [0.8, 1.6]
+
+    random_affine:
+      _target_: torchvision.transforms.RandomAffine
+      degrees: 1
+      shear: [-10, 10]
+      interpolation: BILINEAR
+
+    to_tensor:
+      _target_: torchvision.transforms.ToTensor
+
+  bis:
+    random_crop:
+      _target_: torchvision.transforms.RandomCrop
+      size: [576, 640]
+      padding: null
+      pad_if_needed: true
+      fill: 0
+      padding_mode: constant
+
+    color_jitter:
+      _target_: torchvision.transforms.ColorJitter
+      brightness: [0.8, 2.0]
+
+    random_affine:
+      _target_: torchvision.transforms.RandomAffine
+      degrees: 1
+      shear: [-5, 5]
+      interpolation: BILINEAR
+
+    to_tensor:
+      _target_: torchvision.transforms.ToTensor
diff --git a/training/conf/experiment/barlow_twins.yaml b/training/conf/experiment/barlow_twins.yaml
index e3586bf..cc1295d 100644
--- a/training/conf/experiment/barlow_twins.yaml
+++ b/training/conf/experiment/barlow_twins.yaml
@@ -8,14 +8,12 @@ defaults:
   - override /lr_schedulers: null
   - override /optimizers: null
 
-
-print_config: true
 epochs: &epochs 1000
 summary: [[1, 1, 56, 1024]]
 
 criterion:
   _target_: text_recognizer.criterions.barlow_twins.BarlowTwinsLoss
-  dim: 2048
+  dim: 512
   lambda_: 3.9e-3
 
 callbacks:
@@ -30,7 +28,7 @@ callbacks:
 optimizers:
   madgrad:
     _target_: madgrad.MADGRAD
-    lr: 3.0e-4
+    lr: 1.0e-3
     momentum: 0.9
     weight_decay: 1.0e-6
     eps: 1.0e-6
@@ -61,7 +59,7 @@ datamodule:
     _target_: text_recognizer.data.iam_lines.IAMLines
     batch_size: 16
     num_workers: 12
-    train_fraction: 0.8
+    train_fraction: 0.9
     pin_memory: false
     transform: transform/iam_lines_barlow.yaml
     test_transform: transform/iam_lines_barlow.yaml
@@ -79,7 +77,7 @@ network:
     bn_eps: 1.0e-3
   projector:
     _target_: text_recognizer.networks.barlow_twins.projector.Projector
-    dims: [1280, 2048, 2048]
+    dims: [1280, 512, 512, 512]
 
 model:
   _target_: text_recognizer.models.barlow_twins.BarlowTwinsLitModel
@@ -100,5 +98,5 @@ trainer:
   limit_val_batches: 1.0
   limit_test_batches: 1.0
   resume_from_checkpoint: null
-  accumulate_grad_batches: 64
+  accumulate_grad_batches: 32
   overfit_batches: 0
diff --git a/training/conf/experiment/barlow_twins_paragraphs.yaml b/training/conf/experiment/barlow_twins_paragraphs.yaml
new file mode 100644
index 0000000..caefb47
--- /dev/null
+++ b/training/conf/experiment/barlow_twins_paragraphs.yaml
@@ -0,0 +1,103 @@
+# @package _global_
+
+defaults:
+  - override /criterion: null
+  - override /datamodule: null
+  - override /network: null
+  - override /model: null
+  - override /lr_schedulers: null
+  - override /optimizers: null
+
+epochs: &epochs 1000
+summary: [[1, 1, 576, 640]]
+
+criterion:
+  _target_: text_recognizer.criterions.barlow_twins.BarlowTwinsLoss
+  dim: 512
+  lambda_: 3.9e-3
+
+# callbacks:
+#   stochastic_weight_averaging:
+#     _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
+#     swa_epoch_start: 0.75
+#     swa_lrs: 1.0e-5
+#     annealing_epochs: 10
+#     annealing_strategy: cos
+#     device: null
+
+optimizers:
+  madgrad:
+    _target_: madgrad.MADGRAD
+    lr: 1.0e-3
+    momentum: 0.9
+    weight_decay: 1.0e-6
+    eps: 1.0e-6
+    parameters: network
+
+lr_schedulers:
+  network:
+    _target_: torch.optim.lr_scheduler.OneCycleLR
+    max_lr: 1.0e-1
+    total_steps: null
+    epochs: *epochs
+    steps_per_epoch: 5053
+    pct_start: 0.03
+    anneal_strategy: cos
+    cycle_momentum: true
+    base_momentum: 0.85
+    max_momentum: 0.95
+    div_factor: 25
+    final_div_factor: 1.0e4
+    three_phase: false
+    last_epoch: -1
+    verbose: false
+    # Non-class arguments
+    interval: step
+    monitor: val/loss
+
+datamodule:
+  _target_: text_recognizer.data.iam_extended_paragraphs.IAMExtendedParagraphs
+  batch_size: 4
+  num_workers: 12
+  train_fraction: 0.9
+  pin_memory: true
+  transform: transform/barlow_paragraphs.yaml
+  test_transform: transform/barlow_paragraphs.yaml
+  mapping:
+    _target_: text_recognizer.data.mappings.emnist_mapping.EmnistMapping
+    extra_symbols: [ "\n" ]
+
+network:
+  _target_: text_recognizer.networks.barlow_twins.network.BarlowTwins
+  encoder:
+    _target_: text_recognizer.networks.encoders.efficientnet.EfficientNet
+    arch: b0
+    out_channels: 1280
+    stochastic_dropout_rate: 0.2
+    bn_momentum: 0.99
+    bn_eps: 1.0e-3
+  projector:
+    _target_: text_recognizer.networks.barlow_twins.projector.Projector
+    dims: [1280, 512, 512, 512]
+
+model:
+  _target_: text_recognizer.models.barlow_twins.BarlowTwinsLitModel
+
+trainer:
+  _target_: pytorch_lightning.Trainer
+  stochastic_weight_avg: true
+  auto_scale_batch_size: binsearch
+  auto_lr_find: false
+  gradient_clip_val: 0.0
+  fast_dev_run: false
+  gpus: 1
+  precision: 16
+  max_epochs: *epochs
+  terminate_on_nan: true
+  weights_summary: null
+  limit_train_batches: 1.0
+  limit_val_batches: 1.0
+  limit_test_batches: 1.0
+  resume_from_checkpoint: null
+  accumulate_grad_batches: 128
+  overfit_batches: 0
diff --git a/training/conf/experiment/cnn_htr_char_lines.yaml b/training/conf/experiment/cnn_htr_char_lines.yaml
index 682f138..759161c 100644
--- a/training/conf/experiment/cnn_htr_char_lines.yaml
+++ b/training/conf/experiment/cnn_htr_char_lines.yaml
@@ -1,3 +1,5 @@
+# @package _global_
+
 defaults:
   - override /mapping: null
   - override /criterion: null
@@ -34,7 +36,7 @@ callbacks:
 optimizers:
   madgrad:
     _target_: madgrad.MADGRAD
-    lr: 3.0e-4
+    lr: 1.0e-4
     momentum: 0.9
     weight_decay: 0
     eps: 1.0e-6
@@ -63,21 +65,20 @@ lr_schedulers:
 
 datamodule:
   _target_: text_recognizer.data.iam_lines.IAMLines
-  batch_size: 32
   num_workers: 12
-  train_fraction: 0.8
-  augment: true
+  train_fraction: 0.9
   pin_memory: true
-  word_pieces: false
+  transform: transform/iam_lines.yaml
+  test_transform: transform/iam_lines.yaml
+  target_transform: null
   <<: *mapping
 
 network:
   _target_: text_recognizer.networks.conv_transformer.ConvTransformer
   input_dims: [1, 56, 1024]
-  hidden_dim: &hidden_dim 128
+  hidden_dim: &hidden_dim 256
   encoder_dim: 1280
   dropout_rate: 0.2
-  <<: *mapping
   num_classes: *num_classes
   pad_index: *ignore_index
   encoder:
@@ -111,8 +112,8 @@ network:
   pixel_pos_embedding:
     _target_: text_recognizer.networks.transformer.positional_encodings.PositionalEncoding2D
     hidden_dim: *hidden_dim 
-    max_h: 1
-    max_w: 32
+    max_h: 1
+    max_w: 32
   token_pos_embedding:
     _target_: text_recognizer.networks.transformer.positional_encodings.PositionalEncoding
     hidden_dim: *hidden_dim 
diff --git a/training/conf/experiment/cnn_htr_ctc_char_lines.yaml b/training/conf/experiment/cnn_htr_ctc_char_lines.yaml
new file mode 100644
index 0000000..965d35f
--- /dev/null
+++ b/training/conf/experiment/cnn_htr_ctc_char_lines.yaml
@@ -0,0 +1,149 @@
+# @package _global_
+
+defaults:
+  - override /mapping: null
+  - override /callbacks: htr
+  - override /criterion: null
+  - override /datamodule: null
+  - override /network: null
+  - override /model: null
+  - override /lr_schedulers: null
+  - override /optimizers: null
+
+
+epochs: &epochs 200
+ignore_index: &ignore_index 3
+num_classes: &num_classes 58
+max_output_len: &max_output_len 89
+
+criterion:
+  _target_: text_recognizer.criterions.ctc.CTCLoss
+  blank: *ignore_index
+
+mapping: &mapping
+  mapping:
+    _target_: text_recognizer.data.mappings.emnist_mapping.EmnistMapping
+
+callbacks:
+  stochastic_weight_averaging:
+    _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
+    swa_epoch_start: 0.75
+    swa_lrs: 5.0e-5
+    annealing_epochs: 10
+    annealing_strategy: cos
+    device: null
+
+optimizers:
+  madgrad:
+    _target_: madgrad.MADGRAD
+    lr: 1.0e-4
+    momentum: 0.9
+    weight_decay: 0
+    eps: 1.0e-6
+    parameters: network
+
+lr_schedulers:
+  network:
+    _target_: torch.optim.lr_scheduler.OneCycleLR
+    max_lr: 3.0e-4
+    total_steps: null
+    epochs: *epochs
+    steps_per_epoch: 90
+    pct_start: 0.1
+    anneal_strategy: cos
+    cycle_momentum: true
+    base_momentum: 0.85
+    max_momentum: 0.95
+    div_factor: 25
+    final_div_factor: 1.0e4
+    three_phase: false
+    last_epoch: -1
+    verbose: false
+    # Non-class arguments
+    interval: step
+    monitor: val/loss
+
+datamodule:
+  _target_: text_recognizer.data.iam_lines.IAMLines
+  num_workers: 12
+  batch_size: 16
+  train_fraction: 0.9
+  pin_memory: true
+  transform: transform/iam_lines.yaml
+  test_transform: transform/iam_lines.yaml
+  <<: *mapping
+
+network:
+  _target_: text_recognizer.networks.conv_transformer.ConvTransformer
+  input_dims: [1, 56, 1024]
+  hidden_dim: &hidden_dim 256
+  encoder_dim: 1280
+  dropout_rate: 0.2
+  num_classes: *num_classes
+  pad_index: *ignore_index
+  encoder:
+    _target_: text_recognizer.networks.encoders.efficientnet.EfficientNet
+    arch: b0
+    out_channels: 1280
+    stochastic_dropout_rate: 0.2
+    bn_momentum: 0.99
+    bn_eps: 1.0e-3
+  decoder:
+    _target_: text_recognizer.networks.transformer.Decoder
+    dim: *hidden_dim
+    depth: 3
+    num_heads: 4
+    attn_fn: text_recognizer.networks.transformer.attention.Attention
+    attn_kwargs:
+      dim_head: 32
+      dropout_rate: 0.2
+    norm_fn: text_recognizer.networks.transformer.norm.ScaleNorm
+    ff_fn: text_recognizer.networks.transformer.mlp.FeedForward
+    ff_kwargs:
+      dim_out: null
+      expansion_factor: 4
+      glu: true
+      dropout_rate: 0.2
+    cross_attend: true
+    pre_norm: true
+    rotary_emb:
+      _target_: text_recognizer.networks.transformer.positional_encodings.rotary_embedding.RotaryEmbedding
+      dim: 32
+  pixel_pos_embedding:
+    _target_: text_recognizer.networks.transformer.positional_encodings.PositionalEncoding2D
+    hidden_dim: *hidden_dim
+    max_h: 1
+    max_w: 32
+  token_pos_embedding:
+    _target_: text_recognizer.networks.transformer.positional_encodings.PositionalEncoding
+    hidden_dim: *hidden_dim
+    dropout_rate: 0.2
+    max_len: *max_output_len
+
+model:
+  _target_: text_recognizer.models.transformer.TransformerLitModel
+  max_output_len: *max_output_len
+  start_token: <s>
+  end_token: <e>
+  pad_token: <p>
+
+trainer:
+  _target_: pytorch_lightning.Trainer
+  stochastic_weight_avg: true
+  auto_scale_batch_size: binsearch
+  auto_lr_find: false
+  gradient_clip_val: 0.5
+  fast_dev_run: false
+  gpus: 1
+  precision: 16
+  max_epochs: *epochs
+  terminate_on_nan: true
+  weights_summary: null
+  limit_train_batches: 1.0
+  limit_val_batches: 1.0
+  limit_test_batches: 1.0
+  resume_from_checkpoint: null
+  accumulate_grad_batches: 4
+  overfit_batches: 0
+
+summary: [[1, 1, 56, 1024], [1, 89]]
diff --git a/training/conf/experiment/cnn_htr_wp_lines.yaml b/training/conf/experiment/cnn_htr_wp_lines.yaml
index f467b74..6cdd023 100644
--- a/training/conf/experiment/cnn_htr_wp_lines.yaml
+++ b/training/conf/experiment/cnn_htr_wp_lines.yaml
@@ -1,5 +1,8 @@
+# @package _global_
+
 defaults:
   - override /mapping: null
+  - override /callbacks: htr
   - override /criterion: null
   - override /datamodule: null
   - override /network: null
@@ -7,28 +10,27 @@ defaults:
   - override /lr_schedulers: null
   - override /optimizers: null
 
-epochs: &epochs 256
+epochs: &epochs 512
 ignore_index: &ignore_index 1000
 num_classes: &num_classes 1006
 max_output_len: &max_output_len 72
 
 criterion:
-  _target_: torch.nn.CrossEntropyLoss
-  ignore_index: *ignore_index
-  # _target_: text_recognizer.criterions.label_smoothing.LabelSmoothingLoss
-  # smoothing: 0.1 
+  # _target_: torch.nn.CrossEntropyLoss
   # ignore_index: *ignore_index
+  _target_: text_recognizer.criterions.label_smoothing.LabelSmoothingLoss
+  smoothing: 0.1
+  ignore_index: *ignore_index
     
-mapping:
-  _target_: text_recognizer.data.word_piece_mapping.WordPieceMapping
-  num_features: 1000
-  tokens: iamdb_1kwp_tokens_1000.txt
-  lexicon: iamdb_1kwp_lex_1000.txt
-  data_dir: null
-  use_words: false
-  prepend_wordsep: false
-  special_tokens: [ <s>, <e>, <p> ]
-  # extra_symbols: [ "\n" ]
+mapping: &mapping
+  mapping:
+    _target_: text_recognizer.data.mappings.word_piece_mapping.WordPieceMapping
+    num_features: 1000
+    tokens: iamdb_1kwp_tokens_1000.txt
+    lexicon: iamdb_1kwp_lex_1000.txt
+    use_words: false
+    prepend_wordsep: false
+    special_tokens: [ <s>, <e>, <p> ]
 
 callbacks:
   stochastic_weight_averaging:
@@ -42,7 +44,7 @@ callbacks:
 optimizers:
   madgrad:
     _target_: madgrad.MADGRAD
-    lr: 3.0e-4
+    lr: 1.0e-4
     momentum: 0.9
     weight_decay: 0
     eps: 1.0e-6
@@ -52,11 +54,11 @@ optimizers:
 lr_schedulers:
   network:
     _target_: torch.optim.lr_scheduler.OneCycleLR
-    max_lr: 3.0e-4
+    max_lr: 1.0e-4
     total_steps: null
     epochs: *epochs
-    steps_per_epoch: 90
-    pct_start: 0.1
+    steps_per_epoch: 179
+    pct_start: 0.03
     anneal_strategy: cos
     cycle_momentum: true
     base_momentum: 0.85
@@ -72,17 +74,19 @@ lr_schedulers:
 
 datamodule:
   _target_: text_recognizer.data.iam_lines.IAMLines
-  batch_size: 32
   num_workers: 12
-  train_fraction: 0.8
-  augment: true
+  batch_size: 16
+  train_fraction: 0.9
   pin_memory: true
-  word_pieces: true
+  transform: transform/iam_lines.yaml
+  test_transform: transform/iam_lines.yaml
+  target_transform: target_transform/word_piece_iam.yaml
+  <<: *mapping
 
 network:
   _target_: text_recognizer.networks.conv_transformer.ConvTransformer
   input_dims: [1, 56, 1024]
-  hidden_dim: &hidden_dim 128
+  hidden_dim: &hidden_dim 256
   encoder_dim: 1280
   dropout_rate: 0.2
   num_classes: *num_classes
diff --git a/training/conf/experiment/conv_transformer_paragraphs.yaml b/training/conf/experiment/conv_transformer_paragraphs.yaml
index ebbd6ef..9e8bc50 100644
--- a/training/conf/experiment/conv_transformer_paragraphs.yaml
+++ b/training/conf/experiment/conv_transformer_paragraphs.yaml
@@ -1,4 +1,4 @@
-# @package _global_
+# @package _global_
 
 defaults:
   - override /mapping: null
@@ -48,7 +48,7 @@ optimizers:
 lr_schedulers:
   network:
     _target_: torch.optim.lr_scheduler.OneCycleLR
-    max_lr: 2.0e-4
+    max_lr: 1.0e-4
     total_steps: null
     epochs: *epochs
     steps_per_epoch: 632
@@ -77,7 +77,7 @@ datamodule:
 network:
   _target_: text_recognizer.networks.conv_transformer.ConvTransformer
   input_dims: [1, 576, 640]
-  hidden_dim: &hidden_dim 128
+  hidden_dim: &hidden_dim 256
   encoder_dim: 1280
   dropout_rate: 0.2
   num_classes: *num_classes
diff --git a/training/conf/experiment/conv_transformer_paragraphs_wp.yaml b/training/conf/experiment/conv_transformer_paragraphs_wp.yaml
index 499a609..ebaa17a 100644
--- a/training/conf/experiment/conv_transformer_paragraphs_wp.yaml
+++ b/training/conf/experiment/conv_transformer_paragraphs_wp.yaml
@@ -1,3 +1,5 @@
+# @package _global_
+
 defaults:
   - override /mapping: null
   - override /criterion: null
@@ -7,7 +9,6 @@ defaults:
   - override /lr_schedulers: null
   - override /optimizers: null
 
-
 epochs: &epochs 1000
 ignore_index: &ignore_index 1000
 num_classes: &num_classes 1006
@@ -17,17 +18,17 @@ summary: [[1, 1, 576, 640], [1, 451]]
 criterion:
   _target_: torch.nn.CrossEntropyLoss
   ignore_index: *ignore_index
-    
-mapping:
-  _target_: text_recognizer.data.word_piece_mapping.WordPieceMapping
-  num_features: 1000
-  tokens: iamdb_1kwp_tokens_1000.txt
-  lexicon: iamdb_1kwp_lex_1000.txt
-  data_dir: null
-  use_words: false
-  prepend_wordsep: false
-  special_tokens: [ <s>, <e>, <p> ]
-  extra_symbols: [ "\n" ]
+
+mapping: &mapping
+  mapping:
+    _target_: text_recognizer.data.mappings.word_piece_mapping.WordPieceMapping
+    num_features: 1000
+    tokens: iamdb_1kwp_tokens_1000.txt
+    lexicon: iamdb_1kwp_lex_1000.txt
+    use_words: false
+    prepend_wordsep: false
+    special_tokens: [ <s>, <e>, <p> ]
+    extra_symbols: [ "\n" ]
 
 callbacks:
   stochastic_weight_averaging:
@@ -41,7 +42,7 @@ callbacks:
 optimizers:
   madgrad:
     _target_: madgrad.MADGRAD
-    lr: 3.0e-4
+    lr: 1.0e-4
     momentum: 0.9
     weight_decay: 0
     eps: 1.0e-6
@@ -71,13 +72,13 @@ lr_schedulers:
 
 datamodule:
   _target_: text_recognizer.data.iam_extended_paragraphs.IAMExtendedParagraphs
-  batch_size: 4
   num_workers: 12
-  train_fraction: 0.8
-  augment: true
+  train_fraction: 0.9
   pin_memory: true
-  word_pieces: true
-  resize: null
+  transform: transform/paragraphs.yaml
+  test_transform: transform/paragraphs.yaml
+  target_transform: target_transform/word_piece.yaml
+  <<: *mapping
 
 network:
   _target_: text_recognizer.networks.conv_transformer.ConvTransformer
@@ -138,7 +139,7 @@ trainer:
   stochastic_weight_avg: true
   auto_scale_batch_size: binsearch
   auto_lr_find: false
-  gradient_clip_val: 0.0
+  gradient_clip_val: 0.5
   fast_dev_run: false
   gpus: 1
   precision: 16
diff --git a/training/conf/mapping/word_piece.yaml b/training/conf/mapping/word_piece.yaml
index ca8dd9c..c005cc4 100644
--- a/training/conf/mapping/word_piece.yaml
+++ b/training/conf/mapping/word_piece.yaml
@@ -2,7 +2,6 @@ _target_: text_recognizer.data.mappings.word_piece_mapping.WordPieceMapping
 num_features: 1000
 tokens: iamdb_1kwp_tokens_1000.txt
 lexicon: iamdb_1kwp_lex_1000.txt
-data_dir: null
 use_words: false
 prepend_wordsep: false
 special_tokens: [ <s>, <e>, <p> ]