summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--training/conf/callbacks/htr.yaml2
-rw-r--r--training/conf/callbacks/wandb/captions.yaml4
-rw-r--r--training/conf/callbacks/wandb/checkpoints.yaml2
-rw-r--r--training/conf/callbacks/wandb/config.yaml2
-rw-r--r--training/conf/callbacks/wandb/predictions.yaml3
-rw-r--r--training/conf/callbacks/wandb/watch.yaml2
-rw-r--r--training/conf/config.yaml4
-rw-r--r--training/conf/experiment/vit_lines.yaml14
-rw-r--r--training/conf/network/conv_transformer.yaml59
-rw-r--r--training/conf/network/convnext.yaml17
10 files changed, 16 insertions, 93 deletions
diff --git a/training/conf/callbacks/htr.yaml b/training/conf/callbacks/htr.yaml
index ebcb0fa..5bd9b12 100644
--- a/training/conf/callbacks/htr.yaml
+++ b/training/conf/callbacks/htr.yaml
@@ -1,3 +1,3 @@
defaults:
- default
- - wandb/predictions
+ - wandb/captions
diff --git a/training/conf/callbacks/wandb/captions.yaml b/training/conf/callbacks/wandb/captions.yaml
new file mode 100644
index 0000000..3215a90
--- /dev/null
+++ b/training/conf/callbacks/wandb/captions.yaml
@@ -0,0 +1,4 @@
+log_text_predictions:
+ _target_: callbacks.wandb.ImageToCaptionLogger
+ num_samples: 8
+ on_train: true
diff --git a/training/conf/callbacks/wandb/checkpoints.yaml b/training/conf/callbacks/wandb/checkpoints.yaml
index a4a16ff..b9a3fd7 100644
--- a/training/conf/callbacks/wandb/checkpoints.yaml
+++ b/training/conf/callbacks/wandb/checkpoints.yaml
@@ -1,4 +1,4 @@
upload_ckpts_as_artifact:
- _target_: callbacks.wandb_callbacks.UploadCheckpointsAsArtifact
+ _target_: callbacks.wandb.UploadCheckpointsAsArtifact
ckpt_dir: checkpoints/
upload_best_only: true
diff --git a/training/conf/callbacks/wandb/config.yaml b/training/conf/callbacks/wandb/config.yaml
index 747a7c6..d841e94 100644
--- a/training/conf/callbacks/wandb/config.yaml
+++ b/training/conf/callbacks/wandb/config.yaml
@@ -1,2 +1,2 @@
upload_code_as_artifact:
- _target_: callbacks.wandb_callbacks.UploadConfigAsArtifact
+ _target_: callbacks.wandb.UploadConfigAsArtifact
diff --git a/training/conf/callbacks/wandb/predictions.yaml b/training/conf/callbacks/wandb/predictions.yaml
deleted file mode 100644
index 573fa96..0000000
--- a/training/conf/callbacks/wandb/predictions.yaml
+++ /dev/null
@@ -1,3 +0,0 @@
-log_text_predictions:
- _target_: callbacks.wandb_callbacks.LogTextPredictions
- num_samples: 8
diff --git a/training/conf/callbacks/wandb/watch.yaml b/training/conf/callbacks/wandb/watch.yaml
index 660ae47..1f60978 100644
--- a/training/conf/callbacks/wandb/watch.yaml
+++ b/training/conf/callbacks/wandb/watch.yaml
@@ -1,5 +1,5 @@
watch_model:
- _target_: callbacks.wandb_callbacks.WatchModel
+ _target_: callbacks.wandb.WatchModel
log_params: gradients
log_freq: 100
log_graph: true
diff --git a/training/conf/config.yaml b/training/conf/config.yaml
index 8a1317c..6f5a15d 100644
--- a/training/conf/config.yaml
+++ b/training/conf/config.yaml
@@ -10,10 +10,10 @@ defaults:
- logger: wandb
- lr_scheduler: cosine_annealing
- model: lit_transformer
- - network: conv_transformer
+ - network: vit_lines
- optimizer: radam
- trainer: default
- - experiment: vit_lines
+ - experiment: null
seed: 4711
tune: false
diff --git a/training/conf/experiment/vit_lines.yaml b/training/conf/experiment/vit_lines.yaml
index e2ddebf..f57eead 100644
--- a/training/conf/experiment/vit_lines.yaml
+++ b/training/conf/experiment/vit_lines.yaml
@@ -10,7 +10,7 @@ defaults:
- override /optimizer: null
tags: [lines, vit]
-epochs: &epochs 64
+epochs: &epochs 256
ignore_index: &ignore_index 3
# summary: [[1, 1, 56, 1024], [1, 89]]
@@ -56,7 +56,7 @@ lr_scheduler:
monitor: val/cer
datamodule:
- batch_size: 8
+ batch_size: 16
train_fraction: 0.95
network:
@@ -95,7 +95,7 @@ network:
dim: *dim
max_length: 89
use_l2: true
- tie_embeddings: false
+ tie_embeddings: true
pad_index: 3
model:
@@ -105,9 +105,7 @@ trainer:
fast_dev_run: false
gradient_clip_val: 1.0
max_epochs: *epochs
- accumulate_grad_batches: 1
- limit_val_batches: .02
- limit_test_batches: .02
+ accumulate_grad_batches: 4
limit_train_batches: 1.0
- # limit_val_batches: 1.0
- # limit_test_batches: 1.0
+ limit_val_batches: 1.0
+ limit_test_batches: 1.0
diff --git a/training/conf/network/conv_transformer.yaml b/training/conf/network/conv_transformer.yaml
deleted file mode 100644
index 1e03946..0000000
--- a/training/conf/network/conv_transformer.yaml
+++ /dev/null
@@ -1,59 +0,0 @@
-_target_: text_recognizer.network.ConvTransformer
-encoder:
- _target_: text_recognizer.network.image_encoder.ImageEncoder
- encoder:
- _target_: text_recognizer.network.convnext.ConvNext
- dim: 16
- dim_mults: [2, 4, 8]
- depths: [3, 3, 6]
- downsampling_factors: [[2, 2], [2, 2], [2, 2]]
- pixel_embedding:
- _target_: "text_recognizer.network.transformer.embeddings.axial.\
- AxialPositionalEmbeddingImage"
- dim: &hidden_dim 128
- axial_shape: [7, 128]
- axial_dims: [64, 64]
-decoder:
- _target_: text_recognizer.network.text_decoder.TextDecoder
- hidden_dim: *hidden_dim
- num_classes: 58
- pad_index: 3
- decoder:
- _target_: text_recognizer.network.transformer.Decoder
- dim: *hidden_dim
- depth: 10
- block:
- _target_: text_recognizer.network.transformer.decoder_block.DecoderBlock
- self_attn:
- _target_: text_recognizer.network.transformer.Attention
- dim: *hidden_dim
- num_heads: 12
- dim_head: 64
- dropout_rate: &dropout_rate 0.2
- causal: true
- cross_attn:
- _target_: text_recognizer.network.transformer.Attention
- dim: *hidden_dim
- num_heads: 12
- dim_head: 64
- dropout_rate: *dropout_rate
- causal: false
- norm:
- _target_: text_recognizer.network.transformer.RMSNorm
- dim: *hidden_dim
- ff:
- _target_: text_recognizer.network.transformer.FeedForward
- dim: *hidden_dim
- dim_out: null
- expansion_factor: 2
- glu: true
- dropout_rate: *dropout_rate
- rotary_embedding:
- _target_: text_recognizer.network.transformer.RotaryEmbedding
- dim: 64
- token_pos_embedding:
- _target_: "text_recognizer.network.transformer.embeddings.fourier.\
- PositionalEncoding"
- dim: *hidden_dim
- dropout_rate: 0.1
- max_len: 89
diff --git a/training/conf/network/convnext.yaml b/training/conf/network/convnext.yaml
deleted file mode 100644
index 904bd56..0000000
--- a/training/conf/network/convnext.yaml
+++ /dev/null
@@ -1,17 +0,0 @@
-_target_: text_recognizer.network.convnext.ConvNext
-dim: 16
-dim_mults: [2, 4, 8]
-depths: [3, 3, 6]
-downsampling_factors: [[2, 2], [2, 2], [2, 2]]
-attn:
- _target_: text_recognizer.network.convnext.TransformerBlock
- attn:
- _target_: text_recognizer.network.convnext.Attention
- dim: 128
- heads: 4
- dim_head: 64
- scale: 8
- ff:
- _target_: text_recognizer.network.convnext.FeedForward
- dim: 128
- mult: 4