Diffstat (limited to 'training/conf/experiment')
-rw-r--r--  training/conf/experiment/conv_transformer_lines.yaml      |  30
-rw-r--r--  training/conf/experiment/conv_transformer_paragraphs.yaml |  30
-rw-r--r--  training/conf/experiment/vit_lines.yaml                   | 113
3 files changed, 143 insertions(+), 30 deletions(-)
diff --git a/training/conf/experiment/conv_transformer_lines.yaml b/training/conf/experiment/conv_transformer_lines.yaml
index 948968a..12fe701 100644
--- a/training/conf/experiment/conv_transformer_lines.yaml
+++ b/training/conf/experiment/conv_transformer_lines.yaml
@@ -56,70 +56,70 @@ datamodule:
train_fraction: 0.95
network:
- _target_: text_recognizer.networks.ConvTransformer
+ _target_: text_recognizer.network.ConvTransformer
encoder:
- _target_: text_recognizer.networks.image_encoder.ImageEncoder
+ _target_: text_recognizer.network.image_encoder.ImageEncoder
encoder:
- _target_: text_recognizer.networks.convnext.ConvNext
+ _target_: text_recognizer.network.convnext.ConvNext
dim: 16
dim_mults: [2, 4, 32]
depths: [3, 3, 6]
downsampling_factors: [[2, 2], [2, 2], [2, 2]]
attn:
- _target_: text_recognizer.networks.convnext.TransformerBlock
+ _target_: text_recognizer.network.convnext.TransformerBlock
attn:
- _target_: text_recognizer.networks.convnext.Attention
+ _target_: text_recognizer.network.convnext.Attention
dim: &dim 512
heads: 4
dim_head: 64
scale: 8
ff:
- _target_: text_recognizer.networks.convnext.FeedForward
+ _target_: text_recognizer.network.convnext.FeedForward
dim: *dim
mult: 2
pixel_embedding:
- _target_: "text_recognizer.networks.transformer.embeddings.axial.\
+ _target_: "text_recognizer.network.transformer.embeddings.axial.\
AxialPositionalEmbeddingImage"
dim: *dim
axial_shape: [7, 128]
decoder:
- _target_: text_recognizer.networks.text_decoder.TextDecoder
+ _target_: text_recognizer.network.text_decoder.TextDecoder
dim: *dim
num_classes: 58
pad_index: *ignore_index
decoder:
- _target_: text_recognizer.networks.transformer.Decoder
+ _target_: text_recognizer.network.transformer.Decoder
dim: *dim
depth: 6
block:
- _target_: "text_recognizer.networks.transformer.decoder_block.\
+ _target_: "text_recognizer.network.transformer.decoder_block.\
DecoderBlock"
self_attn:
- _target_: text_recognizer.networks.transformer.Attention
+ _target_: text_recognizer.network.transformer.Attention
dim: *dim
num_heads: 8
dim_head: &dim_head 64
dropout_rate: &dropout_rate 0.2
causal: true
cross_attn:
- _target_: text_recognizer.networks.transformer.Attention
+ _target_: text_recognizer.network.transformer.Attention
dim: *dim
num_heads: 8
dim_head: *dim_head
dropout_rate: *dropout_rate
causal: false
norm:
- _target_: text_recognizer.networks.transformer.RMSNorm
+ _target_: text_recognizer.network.transformer.RMSNorm
dim: *dim
ff:
- _target_: text_recognizer.networks.transformer.FeedForward
+ _target_: text_recognizer.network.transformer.FeedForward
dim: *dim
dim_out: null
expansion_factor: 2
glu: true
dropout_rate: *dropout_rate
rotary_embedding:
- _target_: text_recognizer.networks.transformer.RotaryEmbedding
+ _target_: text_recognizer.network.transformer.RotaryEmbedding
dim: *dim_head
model:
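
Context for the rename above: Hydra resolves each `_target_` string with `hydra.utils.instantiate` only when the experiment actually starts, so a stale `text_recognizer.networks.*` path fails at run time rather than at config-load time. That is why every `_target_` in these configs must move to `text_recognizer.network.*` in lockstep with the package rename. A minimal, self-contained sketch of that behavior, assuming `hydra-core` and `omegaconf` are installed (the stdlib target below is purely illustrative):

```python
from omegaconf import OmegaConf
from hydra.utils import instantiate

# A valid dotted path resolves, and the class is called with the
# remaining keys as keyword arguments.
ok = OmegaConf.create({"_target_": "collections.Counter", "red": 2})
print(instantiate(ok))  # Counter({'red': 2})

# A stale path (e.g. the old 'networks' package) surfaces as an error
# only at instantiation time, not when the YAML is parsed.
stale = OmegaConf.create({"_target_": "text_recognizer.networks.ConvTransformer"})
try:
    instantiate(stale)
except Exception as exc:
    # recent Hydra versions wrap the import failure in an
    # InstantiationException
    print(type(exc).__name__)
```

Because instantiation is deferred and recursive, the nested `_target_` nodes (encoder, decoder, attention blocks, embeddings) all have to be updated together, which is exactly what this commit does.
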
diff --git a/training/conf/experiment/conv_transformer_paragraphs.yaml b/training/conf/experiment/conv_transformer_paragraphs.yaml
index ff931cc..9df2ea9 100644
--- a/training/conf/experiment/conv_transformer_paragraphs.yaml
+++ b/training/conf/experiment/conv_transformer_paragraphs.yaml
@@ -57,70 +57,70 @@ datamodule:
train_fraction: 0.95
network:
- _target_: text_recognizer.networks.ConvTransformer
+ _target_: text_recognizer.network.ConvTransformer
encoder:
- _target_: text_recognizer.networks.image_encoder.ImageEncoder
+ _target_: text_recognizer.network.image_encoder.ImageEncoder
encoder:
- _target_: text_recognizer.networks.convnext.ConvNext
+ _target_: text_recognizer.network.convnext.ConvNext
dim: 16
dim_mults: [1, 2, 4, 8, 32]
depths: [2, 3, 3, 3, 6]
downsampling_factors: [[2, 2], [2, 2], [2, 2], [2, 1], [2, 1]]
attn:
- _target_: text_recognizer.networks.convnext.TransformerBlock
+ _target_: text_recognizer.network.convnext.TransformerBlock
attn:
- _target_: text_recognizer.networks.convnext.Attention
+ _target_: text_recognizer.network.convnext.Attention
dim: &dim 512
heads: 4
dim_head: 64
scale: 8
ff:
- _target_: text_recognizer.networks.convnext.FeedForward
+ _target_: text_recognizer.network.convnext.FeedForward
dim: *dim
mult: 2
pixel_embedding:
- _target_: "text_recognizer.networks.transformer.embeddings.axial.\
+ _target_: "text_recognizer.network.transformer.embeddings.axial.\
AxialPositionalEmbeddingImage"
dim: *dim
axial_shape: [18, 80]
decoder:
- _target_: text_recognizer.networks.text_decoder.TextDecoder
+ _target_: text_recognizer.network.text_decoder.TextDecoder
dim: *dim
num_classes: 58
pad_index: *ignore_index
decoder:
- _target_: text_recognizer.networks.transformer.Decoder
+ _target_: text_recognizer.network.transformer.Decoder
dim: *dim
depth: 6
block:
- _target_: "text_recognizer.networks.transformer.decoder_block.\
+ _target_: "text_recognizer.network.transformer.decoder_block.\
DecoderBlock"
self_attn:
- _target_: text_recognizer.networks.transformer.Attention
+ _target_: text_recognizer.network.transformer.Attention
dim: *dim
num_heads: 8
dim_head: &dim_head 64
dropout_rate: &dropout_rate 0.2
causal: true
cross_attn:
- _target_: text_recognizer.networks.transformer.Attention
+ _target_: text_recognizer.network.transformer.Attention
dim: *dim
num_heads: 8
dim_head: *dim_head
dropout_rate: *dropout_rate
causal: false
norm:
- _target_: text_recognizer.networks.transformer.RMSNorm
+ _target_: text_recognizer.network.transformer.RMSNorm
dim: *dim
ff:
- _target_: text_recognizer.networks.transformer.FeedForward
+ _target_: text_recognizer.network.transformer.FeedForward
dim: *dim
dim_out: null
expansion_factor: 2
glu: true
dropout_rate: *dropout_rate
rotary_embedding:
- _target_: text_recognizer.networks.transformer.RotaryEmbedding
+ _target_: text_recognizer.network.transformer.RotaryEmbedding
dim: *dim_head
trainer:
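
In both configs the `axial_shape` passed to `AxialPositionalEmbeddingImage` is the encoder's output grid: the input height and width divided by the product of the per-stage `downsampling_factors` (assuming each entry is a (height, width) stride pair). For the lines config, [[2, 2], [2, 2], [2, 2]] downsamples both axes by 8, so `axial_shape: [7, 128]` corresponds to a 56×1024 input; for the paragraphs config, [[2, 2], [2, 2], [2, 2], [2, 1], [2, 1]] gives height/32 and width/8, so `axial_shape: [18, 80]` implies a 576×640 input (the input size itself is not shown in this diff). A quick check of that arithmetic:

```python
from math import prod

def encoder_grid(image_hw, factors):
    """Output grid after per-stage (height, width) strides -- order assumed."""
    h, w = image_hw
    return h // prod(f[0] for f in factors), w // prod(f[1] for f in factors)

# lines config: /8 on both axes
print(encoder_grid((56, 1024), [(2, 2), (2, 2), (2, 2)]))                  # (7, 128)
# paragraphs config: /32 height, /8 width; 576x640 inferred from axial_shape
print(encoder_grid((576, 640), [(2, 2), (2, 2), (2, 2), (2, 1), (2, 1)]))  # (18, 80)
```
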
diff --git a/training/conf/experiment/vit_lines.yaml b/training/conf/experiment/vit_lines.yaml
new file mode 100644
index 0000000..e2ddebf
--- /dev/null
+++ b/training/conf/experiment/vit_lines.yaml
@@ -0,0 +1,113 @@
+# @package _global_
+
+defaults:
+ - override /criterion: cross_entropy
+ - override /callbacks: htr
+ - override /datamodule: iam_lines
+ - override /network: null
+ - override /model: lit_transformer
+ - override /lr_scheduler: null
+ - override /optimizer: null
+
+tags: [lines, vit]
+epochs: &epochs 64
+ignore_index: &ignore_index 3
+# summary: [[1, 1, 56, 1024], [1, 89]]
+
+logger:
+ wandb:
+ tags: ${tags}
+
+criterion:
+ ignore_index: *ignore_index
+ # label_smoothing: 0.05
+
+
+decoder:
+ max_output_len: 89
+
+callbacks:
+ stochastic_weight_averaging:
+ _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
+ swa_epoch_start: 0.75
+ swa_lrs: 1.0e-5
+ annealing_epochs: 10
+ annealing_strategy: cos
+ device: null
+
+optimizer:
+ _target_: adan_pytorch.Adan
+ lr: 3.0e-4
+ betas: [0.02, 0.08, 0.01]
+ weight_decay: 0.02
+
+lr_scheduler:
+ _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
+ mode: min
+ factor: 0.8
+ patience: 10
+ threshold: 1.0e-4
+ threshold_mode: rel
+ cooldown: 0
+ min_lr: 1.0e-5
+ eps: 1.0e-8
+ verbose: false
+ interval: epoch
+ monitor: val/cer
+
+datamodule:
+ batch_size: 8
+ train_fraction: 0.95
+
+network:
+ _target_: text_recognizer.network.vit.VisionTransformer
+ image_height: 56
+ image_width: 1024
+ patch_height: 28
+ patch_width: 32
+ dim: &dim 1024
+ num_classes: &num_classes 58
+ encoder:
+ _target_: text_recognizer.network.transformer.encoder.Encoder
+ dim: *dim
+ inner_dim: 2048
+ heads: 16
+ dim_head: 64
+ depth: 4
+ dropout_rate: 0.0
+ decoder:
+ _target_: text_recognizer.network.transformer.decoder.Decoder
+ dim: *dim
+ inner_dim: 2048
+ heads: 16
+ dim_head: 64
+ depth: 4
+ dropout_rate: 0.0
+ token_embedding:
+ _target_: "text_recognizer.network.transformer.embedding.token.\
+ TokenEmbedding"
+ num_tokens: *num_classes
+ dim: *dim
+ use_l2: true
+ pos_embedding:
+ _target_: "text_recognizer.network.transformer.embedding.absolute.\
+ AbsolutePositionalEmbedding"
+ dim: *dim
+ max_length: 89
+ use_l2: true
+ tie_embeddings: false
+ pad_index: 3
+
+model:
+ max_output_len: 89
+
+trainer:
+ fast_dev_run: false
+ gradient_clip_val: 1.0
+ max_epochs: *epochs
+ accumulate_grad_batches: 1
+ limit_val_batches: .02
+ limit_test_batches: .02
+ limit_train_batches: 1.0
+ # limit_val_batches: 1.0
+ # limit_test_batches: 1.0
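
The geometry in the new `vit_lines.yaml` is internally consistent: a 56×1024 single-channel input (the commented `summary` shape) cut into 28×32 patches yields a 2×32 grid of 64 tokens, each flattening 1×28×32 = 896 pixels before projection to `dim: 1024`, while `max_output_len: 89` matches both `pos_embedding.max_length` and the `[1, 89]` target shape in the summary. A sketch of that patchification, for illustration only (the actual `text_recognizer.network.vit.VisionTransformer` may patchify differently):

```python
import torch

B, C, H, W = 1, 1, 56, 1024   # summary input shape from the config comment
ph, pw, dim = 28, 32, 1024    # patch_height, patch_width, dim

x = torch.randn(B, C, H, W)
# (B, C, H, W) -> (B, grid_h, grid_w, C, ph, pw) -> (B, num_patches, C*ph*pw)
patches = (
    x.unfold(2, ph, ph)       # split H into 56/28 = 2 patch rows
     .unfold(3, pw, pw)       # split W into 1024/32 = 32 patch columns
     .permute(0, 2, 3, 1, 4, 5)
     .reshape(B, (H // ph) * (W // pw), C * ph * pw)
)
tokens = torch.nn.Linear(C * ph * pw, dim)(patches)  # 896 -> 1024 per patch
print(tokens.shape)           # torch.Size([1, 64, 1024])
```
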