author     Gustaf Rydholm <gustaf.rydholm@gmail.com>    2023-09-02 01:53:55 +0200
committer  Gustaf Rydholm <gustaf.rydholm@gmail.com>    2023-09-02 01:53:55 +0200
commit     cf42f987a34d4de0db10f733fd532f949dd9c278 (patch)
tree       1420cc62d2792a3921272f4a5aa39e40ef7f7ce9 /training/conf
parent     abc2d60d69d115cdb34615d8bcb6c03ab6357141 (diff)
Update configs
Diffstat (limited to 'training/conf')
-rw-r--r--  training/conf/callbacks/htr.yaml                 2
-rw-r--r--  training/conf/callbacks/wandb/captions.yaml      4
-rw-r--r--  training/conf/callbacks/wandb/checkpoints.yaml   2
-rw-r--r--  training/conf/callbacks/wandb/config.yaml        2
-rw-r--r--  training/conf/callbacks/wandb/predictions.yaml   3
-rw-r--r--  training/conf/callbacks/wandb/watch.yaml         2
-rw-r--r--  training/conf/config.yaml                        4
-rw-r--r--  training/conf/experiment/vit_lines.yaml         14
-rw-r--r--  training/conf/network/conv_transformer.yaml     59
-rw-r--r--  training/conf/network/convnext.yaml              17
10 files changed, 16 insertions, 93 deletions
diff --git a/training/conf/callbacks/htr.yaml b/training/conf/callbacks/htr.yaml
index ebcb0fa..5bd9b12 100644
--- a/training/conf/callbacks/htr.yaml
+++ b/training/conf/callbacks/htr.yaml
@@ -1,3 +1,3 @@
 defaults:
   - default
-  - wandb/predictions
+  - wandb/captions
diff --git a/training/conf/callbacks/wandb/captions.yaml b/training/conf/callbacks/wandb/captions.yaml
new file mode 100644
index 0000000..3215a90
--- /dev/null
+++ b/training/conf/callbacks/wandb/captions.yaml
@@ -0,0 +1,4 @@
+log_text_predictions:
+  _target_: callbacks.wandb.ImageToCaptionLogger
+  num_samples: 8
+  on_train: true
diff --git a/training/conf/callbacks/wandb/checkpoints.yaml b/training/conf/callbacks/wandb/checkpoints.yaml
index a4a16ff..b9a3fd7 100644
--- a/training/conf/callbacks/wandb/checkpoints.yaml
+++ b/training/conf/callbacks/wandb/checkpoints.yaml
@@ -1,4 +1,4 @@
 upload_ckpts_as_artifact:
-  _target_: callbacks.wandb_callbacks.UploadCheckpointsAsArtifact
+  _target_: callbacks.wandb.UploadCheckpointsAsArtifact
   ckpt_dir: checkpoints/
   upload_best_only: true
diff --git a/training/conf/callbacks/wandb/config.yaml b/training/conf/callbacks/wandb/config.yaml
index 747a7c6..d841e94 100644
--- a/training/conf/callbacks/wandb/config.yaml
+++ b/training/conf/callbacks/wandb/config.yaml
@@ -1,2 +1,2 @@
 upload_code_as_artifact:
-  _target_: callbacks.wandb_callbacks.UploadConfigAsArtifact
+  _target_: callbacks.wandb.UploadConfigAsArtifact
diff --git a/training/conf/callbacks/wandb/predictions.yaml b/training/conf/callbacks/wandb/predictions.yaml
deleted file mode 100644
index 573fa96..0000000
--- a/training/conf/callbacks/wandb/predictions.yaml
+++ /dev/null
@@ -1,3 +0,0 @@
-log_text_predictions:
-  _target_: callbacks.wandb_callbacks.LogTextPredictions
-  num_samples: 8
diff --git a/training/conf/callbacks/wandb/watch.yaml b/training/conf/callbacks/wandb/watch.yaml
index 660ae47..1f60978 100644
--- a/training/conf/callbacks/wandb/watch.yaml
+++ b/training/conf/callbacks/wandb/watch.yaml
@@ -1,5 +1,5 @@
 watch_model:
-  _target_: callbacks.wandb_callbacks.WatchModel
+  _target_: callbacks.wandb.WatchModel
   log_params: gradients
   log_freq: 100
   log_graph: true
diff --git a/training/conf/config.yaml b/training/conf/config.yaml
index 8a1317c..6f5a15d 100644
--- a/training/conf/config.yaml
+++ b/training/conf/config.yaml
@@ -10,10 +10,10 @@ defaults:
   - logger: wandb
   - lr_scheduler: cosine_annealing
   - model: lit_transformer
-  - network: conv_transformer
+  - network: vit_lines
   - optimizer: radam
   - trainer: default
-  - experiment: vit_lines
+  - experiment: null
 
 seed: 4711
 tune: false
diff --git a/training/conf/experiment/vit_lines.yaml b/training/conf/experiment/vit_lines.yaml
index e2ddebf..f57eead 100644
--- a/training/conf/experiment/vit_lines.yaml
+++ b/training/conf/experiment/vit_lines.yaml
@@ -10,7 +10,7 @@ defaults:
   - override /optimizer: null
 
 tags: [lines, vit]
-epochs: &epochs 64
+epochs: &epochs 256
 ignore_index: &ignore_index 3
 
 # summary: [[1, 1, 56, 1024], [1, 89]]
@@ -56,7 +56,7 @@ lr_scheduler:
   monitor: val/cer
 
 datamodule:
-  batch_size: 8
+  batch_size: 16
   train_fraction: 0.95
 
 network:
@@ -95,7 +95,7 @@ network:
     dim: *dim
     max_length: 89
     use_l2: true
-    tie_embeddings: false
+    tie_embeddings: true
     pad_index: 3
 
 model:
@@ -105,9 +105,7 @@ trainer:
   fast_dev_run: false
   gradient_clip_val: 1.0
   max_epochs: *epochs
-  accumulate_grad_batches: 1
-  limit_val_batches: .02
-  limit_test_batches: .02
+  accumulate_grad_batches: 4
   limit_train_batches: 1.0
-  # limit_val_batches: 1.0
-  # limit_test_batches: 1.0
+  limit_val_batches: 1.0
+  limit_test_batches: 1.0
diff --git a/training/conf/network/conv_transformer.yaml b/training/conf/network/conv_transformer.yaml
deleted file mode 100644
index 1e03946..0000000
--- a/training/conf/network/conv_transformer.yaml
+++ /dev/null
@@ -1,59 +0,0 @@
-_target_: text_recognizer.network.ConvTransformer
-encoder:
-  _target_: text_recognizer.network.image_encoder.ImageEncoder
-  encoder:
-    _target_: text_recognizer.network.convnext.ConvNext
-    dim: 16
-    dim_mults: [2, 4, 8]
-    depths: [3, 3, 6]
-    downsampling_factors: [[2, 2], [2, 2], [2, 2]]
-  pixel_embedding:
-    _target_: "text_recognizer.network.transformer.embeddings.axial.\
-      AxialPositionalEmbeddingImage"
-    dim: &hidden_dim 128
-    axial_shape: [7, 128]
-    axial_dims: [64, 64]
-decoder:
-  _target_: text_recognizer.network.text_decoder.TextDecoder
-  hidden_dim: *hidden_dim
-  num_classes: 58
-  pad_index: 3
-  decoder:
-    _target_: text_recognizer.network.transformer.Decoder
-    dim: *hidden_dim
-    depth: 10
-    block:
-      _target_: text_recognizer.network.transformer.decoder_block.DecoderBlock
-      self_attn:
-        _target_: text_recognizer.network.transformer.Attention
-        dim: *hidden_dim
-        num_heads: 12
-        dim_head: 64
-        dropout_rate: &dropout_rate 0.2
-        causal: true
-      cross_attn:
-        _target_: text_recognizer.network.transformer.Attention
-        dim: *hidden_dim
-        num_heads: 12
-        dim_head: 64
-        dropout_rate: *dropout_rate
-        causal: false
-      norm:
-        _target_: text_recognizer.network.transformer.RMSNorm
-        dim: *hidden_dim
-      ff:
-        _target_: text_recognizer.network.transformer.FeedForward
-        dim: *hidden_dim
-        dim_out: null
-        expansion_factor: 2
-        glu: true
-        dropout_rate: *dropout_rate
-      rotary_embedding:
-        _target_: text_recognizer.network.transformer.RotaryEmbedding
-        dim: 64
-  token_pos_embedding:
-    _target_: "text_recognizer.network.transformer.embeddings.fourier.\
-      PositionalEncoding"
-    dim: *hidden_dim
-    dropout_rate: 0.1
-    max_len: 89
diff --git a/training/conf/network/convnext.yaml b/training/conf/network/convnext.yaml
deleted file mode 100644
index 904bd56..0000000
--- a/training/conf/network/convnext.yaml
+++ /dev/null
@@ -1,17 +0,0 @@
-_target_: text_recognizer.network.convnext.ConvNext
-dim: 16
-dim_mults: [2, 4, 8]
-depths: [3, 3, 6]
-downsampling_factors: [[2, 2], [2, 2], [2, 2]]
-attn:
-  _target_: text_recognizer.network.convnext.TransformerBlock
-  attn:
-    _target_: text_recognizer.network.convnext.Attention
-    dim: 128
-    heads: 4
-    dim_head: 64
-    scale: 8
-  ff:
-    _target_: text_recognizer.network.convnext.FeedForward
-    dim: 128
-    mult: 4
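
The changed callback configs all follow Hydra's `_target_` convention, so each YAML mapping is turned into a live object at runtime. A minimal sketch of how the new captions.yaml would be resolved, assuming the repository's `callbacks.wandb.ImageToCaptionLogger` is importable and accepts the keyword arguments listed in the config (not verified here):

# Hedged sketch: instantiating the new captions callback via Hydra.
# Assumes callbacks.wandb.ImageToCaptionLogger exists in this repo and
# takes `num_samples` and `on_train`, as the config above implies.
from hydra.utils import instantiate
from omegaconf import OmegaConf

# Load the callback group config added in this commit.
cfg = OmegaConf.load("training/conf/callbacks/wandb/captions.yaml")

# cfg.log_text_predictions holds `_target_` plus its keyword arguments;
# instantiate() imports the class and calls it with those arguments.
caption_logger = instantiate(cfg.log_text_predictions)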