summary refs log tree commit diff
path: root/training/conf/experiment/vit_lines.yaml
diff options
context:
space:
mode:
Diffstat (limited to 'training/conf/experiment/vit_lines.yaml')
-rw-r--r-- training/conf/experiment/vit_lines.yaml 45
1 files changed, 3 insertions, 42 deletions
diff --git a/training/conf/experiment/vit_lines.yaml b/training/conf/experiment/vit_lines.yaml
index f57eead..f3049ea 100644
--- a/training/conf/experiment/vit_lines.yaml
+++ b/training/conf/experiment/vit_lines.yaml
@@ -4,13 +4,13 @@ defaults:
- override /criterion: cross_entropy
- override /callbacks: htr
- override /datamodule: iam_lines
- - override /network: null
+ - override /network: vit_lines
- override /model: lit_transformer
- override /lr_scheduler: null
- override /optimizer: null
tags: [lines, vit]
-epochs: &epochs 256
+epochs: &epochs 128
ignore_index: &ignore_index 3
# summary: [[1, 1, 56, 1024], [1, 89]]
@@ -59,45 +59,6 @@ datamodule:
batch_size: 16
train_fraction: 0.95
-network:
- _target_: text_recognizer.network.vit.VisionTransformer
- image_height: 56
- image_width: 1024
- patch_height: 28
- patch_width: 32
- dim: &dim 1024
- num_classes: &num_classes 58
- encoder:
- _target_: text_recognizer.network.transformer.encoder.Encoder
- dim: *dim
- inner_dim: 2048
- heads: 16
- dim_head: 64
- depth: 4
- dropout_rate: 0.0
- decoder:
- _target_: text_recognizer.network.transformer.decoder.Decoder
- dim: *dim
- inner_dim: 2048
- heads: 16
- dim_head: 64
- depth: 4
- dropout_rate: 0.0
- token_embedding:
- _target_: "text_recognizer.network.transformer.embedding.token.\
- TokenEmbedding"
- num_tokens: *num_classes
- dim: *dim
- use_l2: true
- pos_embedding:
- _target_: "text_recognizer.network.transformer.embedding.absolute.\
- AbsolutePositionalEmbedding"
- dim: *dim
- max_length: 89
- use_l2: true
- tie_embeddings: true
- pad_index: 3
-
model:
max_output_len: 89
@@ -105,7 +66,7 @@ trainer:
fast_dev_run: false
gradient_clip_val: 1.0
max_epochs: *epochs
- accumulate_grad_batches: 4
+ accumulate_grad_batches: 1
limit_train_batches: 1.0
limit_val_batches: 1.0
limit_test_batches: 1.0