summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2023-09-03 01:13:37 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2023-09-03 01:13:37 +0200
commita65d3ec18a5541cec5297769f1027422975a62bc (patch)
tree08e5e22f76db2449d265476f5fb42c5ea64a2776
parente4d618443808f0931bbef0b9e10a2c2a215281a5 (diff)
Update confs and callbacks
-rw-r--r--training/callbacks/wandb.py34
-rw-r--r--training/conf/callbacks/default.yaml1
-rw-r--r--training/conf/callbacks/lightning/swa.yaml (renamed from training/conf/callbacks/wandb/swa.yaml)0
-rw-r--r--training/conf/callbacks/wandb/captions.yaml2
-rw-r--r--training/conf/callbacks/wandb/checkpoints.yaml4
-rw-r--r--training/conf/experiment/conv_transformer_lines.yaml134
-rw-r--r--training/conf/experiment/conv_transformer_paragraphs.yaml132
-rw-r--r--training/conf/experiment/vit_lines.yaml45
-rw-r--r--training/conf/logger/wandb.yaml2
-rw-r--r--training/conf/network/vit_lines.yaml14
-rw-r--r--training/conf/sweep/conv_transformer.yaml29
11 files changed, 15 insertions, 382 deletions
diff --git a/training/callbacks/wandb.py b/training/callbacks/wandb.py
index d9bb9b8..6adbebe 100644
--- a/training/callbacks/wandb.py
+++ b/training/callbacks/wandb.py
@@ -62,32 +62,7 @@ class UploadConfigAsArtifact(Callback):
experiment.use_artifact(artifact)
-class UploadCheckpointsAsArtifact(Callback):
- """Upload checkpoint to wandb as an artifact, at the end of a run."""
-
- def __init__(
- self, ckpt_dir: str = "checkpoints/", upload_best_only: bool = False
- ) -> None:
- self.ckpt_dir = Path(ckpt_dir)
- self.upload_best_only = upload_best_only
-
- @rank_zero_only
- def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
- """Uploads model checkpoint to W&B."""
- logger = get_wandb_logger(trainer)
- experiment = logger.experiment
- ckpts = wandb.Artifact("experiment-ckpts", type="checkpoints")
-
- if self.upload_best_only:
- ckpts.add_file(trainer.checkpoint_callback.best_model_path)
- else:
- for ckpt in (self.ckpt_dir).rglob("*.ckpt"):
- ckpts.add_file(ckpt)
-
- experiment.use_artifact(ckpts)
-
-
-class ImageToCaptionLogger(Callback):
+class ImageToCaption(Callback):
"""Logs the image and output caption."""
def __init__(self, num_samples: int = 8, on_train: bool = True) -> None:
@@ -114,7 +89,7 @@ class ImageToCaptionLogger(Callback):
pl_module: LightningModule,
outputs: dict,
batch: Tuple[Tensor, Tensor],
- batch_idx: int,
+ *args,
) -> None:
"""Logs predictions on validation batch end."""
if self.has_metrics(outputs):
@@ -127,9 +102,7 @@ class ImageToCaptionLogger(Callback):
pl_module: LightningModule,
outputs: dict,
batch: Tuple[Tensor, Tensor],
- batch_idx: int,
*args,
- # dataloader_idx: int,
) -> None:
"""Logs predictions on validation batch end."""
if self.has_metrics(outputs):
@@ -142,8 +115,7 @@ class ImageToCaptionLogger(Callback):
pl_module: LightningModule,
outputs: dict,
batch: Tuple[Tensor, Tensor],
- batch_idx: int,
- dataloader_idx: int,
+ *args,
) -> None:
"""Logs predictions on train batch end."""
if self.has_metrics(outputs):
diff --git a/training/conf/callbacks/default.yaml b/training/conf/callbacks/default.yaml
index 4d8e399..9b1347b 100644
--- a/training/conf/callbacks/default.yaml
+++ b/training/conf/callbacks/default.yaml
@@ -3,4 +3,3 @@ defaults:
- lightning/learning_rate_monitor
- wandb/watch
- wandb/config
- - wandb/checkpoints
diff --git a/training/conf/callbacks/wandb/swa.yaml b/training/conf/callbacks/lightning/swa.yaml
index 73f8c66..73f8c66 100644
--- a/training/conf/callbacks/wandb/swa.yaml
+++ b/training/conf/callbacks/lightning/swa.yaml
diff --git a/training/conf/callbacks/wandb/captions.yaml b/training/conf/callbacks/wandb/captions.yaml
index 3215a90..64f44c2 100644
--- a/training/conf/callbacks/wandb/captions.yaml
+++ b/training/conf/callbacks/wandb/captions.yaml
@@ -1,4 +1,4 @@
log_text_predictions:
- _target_: callbacks.wandb.ImageToCaptionLogger
+ _target_: callbacks.wandb.ImageToCaption
num_samples: 8
on_train: true
diff --git a/training/conf/callbacks/wandb/checkpoints.yaml b/training/conf/callbacks/wandb/checkpoints.yaml
deleted file mode 100644
index b9a3fd7..0000000
--- a/training/conf/callbacks/wandb/checkpoints.yaml
+++ /dev/null
@@ -1,4 +0,0 @@
-upload_ckpts_as_artifact:
- _target_: callbacks.wandb.UploadCheckpointsAsArtifact
- ckpt_dir: checkpoints/
- upload_best_only: true
diff --git a/training/conf/experiment/conv_transformer_lines.yaml b/training/conf/experiment/conv_transformer_lines.yaml
deleted file mode 100644
index 12fe701..0000000
--- a/training/conf/experiment/conv_transformer_lines.yaml
+++ /dev/null
@@ -1,134 +0,0 @@
-# @package _global_
-
-defaults:
- - override /criterion: cross_entropy
- - override /callbacks: htr
- - override /datamodule: iam_lines
- - override /network: null
- - override /model: lit_transformer
- - override /lr_scheduler: null
- - override /optimizer: null
-
-tags: [lines]
-epochs: &epochs 64
-ignore_index: &ignore_index 3
-# summary: [[1, 1, 56, 1024], [1, 89]]
-
-logger:
- wandb:
- tags: ${tags}
-
-criterion:
- ignore_index: *ignore_index
- # label_smoothing: 0.05
-
-callbacks:
- stochastic_weight_averaging:
- _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
- swa_epoch_start: 0.75
- swa_lrs: 1.0e-5
- annealing_epochs: 10
- annealing_strategy: cos
- device: null
-
-optimizer:
- _target_: adan_pytorch.Adan
- lr: 3.0e-4
- betas: [0.02, 0.08, 0.01]
- weight_decay: 0.02
-
-lr_scheduler:
- _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
- mode: min
- factor: 0.8
- patience: 10
- threshold: 1.0e-4
- threshold_mode: rel
- cooldown: 0
- min_lr: 1.0e-5
- eps: 1.0e-8
- verbose: false
- interval: epoch
- monitor: val/cer
-
-datamodule:
- batch_size: 16
- train_fraction: 0.95
-
-network:
- _target_: text_recognizer.network.ConvTransformer
- encoder:
- _target_: text_recognizer.network.image_encoder.ImageEncoder
- encoder:
- _target_: text_recognizer.network.convnext.ConvNext
- dim: 16
- dim_mults: [2, 4, 32]
- depths: [3, 3, 6]
- downsampling_factors: [[2, 2], [2, 2], [2, 2]]
- attn:
- _target_: text_recognizer.network.convnext.TransformerBlock
- attn:
- _target_: text_recognizer.network.convnext.Attention
- dim: &dim 512
- heads: 4
- dim_head: 64
- scale: 8
- ff:
- _target_: text_recognizer.network.convnext.FeedForward
- dim: *dim
- mult: 2
- pixel_embedding:
- _target_: "text_recognizer.network.transformer.embeddings.axial.\
- AxialPositionalEmbeddingImage"
- dim: *dim
- axial_shape: [7, 128]
- decoder:
- _target_: text_recognizer.network.text_decoder.TextDecoder
- dim: *dim
- num_classes: 58
- pad_index: *ignore_index
- decoder:
- _target_: text_recognizer.network.transformer.Decoder
- dim: *dim
- depth: 6
- block:
- _target_: "text_recognizer.network.transformer.decoder_block.\
- DecoderBlock"
- self_attn:
- _target_: text_recognizer.network.transformer.Attention
- dim: *dim
- num_heads: 8
- dim_head: &dim_head 64
- dropout_rate: &dropout_rate 0.2
- causal: true
- cross_attn:
- _target_: text_recognizer.network.transformer.Attention
- dim: *dim
- num_heads: 8
- dim_head: *dim_head
- dropout_rate: *dropout_rate
- causal: false
- norm:
- _target_: text_recognizer.network.transformer.RMSNorm
- dim: *dim
- ff:
- _target_: text_recognizer.network.transformer.FeedForward
- dim: *dim
- dim_out: null
- expansion_factor: 2
- glu: true
- dropout_rate: *dropout_rate
- rotary_embedding:
- _target_: text_recognizer.network.transformer.RotaryEmbedding
- dim: *dim_head
-
-model:
- max_output_len: 89
-
-trainer:
- gradient_clip_val: 1.0
- max_epochs: *epochs
- accumulate_grad_batches: 1
- limit_train_batches: 1.0
- limit_val_batches: 1.0
- limit_test_batches: 1.0
diff --git a/training/conf/experiment/conv_transformer_paragraphs.yaml b/training/conf/experiment/conv_transformer_paragraphs.yaml
deleted file mode 100644
index 9df2ea9..0000000
--- a/training/conf/experiment/conv_transformer_paragraphs.yaml
+++ /dev/null
@@ -1,132 +0,0 @@
-# @package _global_
-
-defaults:
- - override /criterion: cross_entropy
- - override /callbacks: htr
- - override /datamodule: iam_extended_paragraphs
- - override /network: null
- - override /model: lit_transformer
- - override /lr_scheduler: null
- - override /optimizer: null
-
-tags: [paragraphs]
-epochs: &epochs 256
-ignore_index: &ignore_index 3
-# max_output_len: &max_output_len 682
-# summary: [[1, 1, 576, 640], [1, 682]]
-
-logger:
- wandb:
- tags: ${tags}
-
-criterion:
- ignore_index: *ignore_index
- # label_smoothing: 0.05
-
-callbacks:
- stochastic_weight_averaging:
- _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
- swa_epoch_start: 0.75
- swa_lrs: 1.0e-5
- annealing_epochs: 10
- annealing_strategy: cos
- device: null
-
-optimizer:
- _target_: adan_pytorch.Adan
- lr: 3.0e-4
- betas: [0.02, 0.08, 0.01]
- weight_decay: 0.02
-
-lr_scheduler:
- _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
- mode: min
- factor: 0.8
- patience: 10
- threshold: 1.0e-4
- threshold_mode: rel
- cooldown: 0
- min_lr: 1.0e-5
- eps: 1.0e-8
- verbose: false
- interval: epoch
- monitor: val/cer
-
-datamodule:
- batch_size: 4
- train_fraction: 0.95
-
-network:
- _target_: text_recognizer.network.ConvTransformer
- encoder:
- _target_: text_recognizer.network.image_encoder.ImageEncoder
- encoder:
- _target_: text_recognizer.network.convnext.ConvNext
- dim: 16
- dim_mults: [1, 2, 4, 8, 32]
- depths: [2, 3, 3, 3, 6]
- downsampling_factors: [[2, 2], [2, 2], [2, 2], [2, 1], [2, 1]]
- attn:
- _target_: text_recognizer.network.convnext.TransformerBlock
- attn:
- _target_: text_recognizer.network.convnext.Attention
- dim: &dim 512
- heads: 4
- dim_head: 64
- scale: 8
- ff:
- _target_: text_recognizer.network.convnext.FeedForward
- dim: *dim
- mult: 2
- pixel_embedding:
- _target_: "text_recognizer.network.transformer.embeddings.axial.\
- AxialPositionalEmbeddingImage"
- dim: *dim
- axial_shape: [18, 80]
- decoder:
- _target_: text_recognizer.network.text_decoder.TextDecoder
- dim: *dim
- num_classes: 58
- pad_index: *ignore_index
- decoder:
- _target_: text_recognizer.network.transformer.Decoder
- dim: *dim
- depth: 6
- block:
- _target_: "text_recognizer.network.transformer.decoder_block.\
- DecoderBlock"
- self_attn:
- _target_: text_recognizer.network.transformer.Attention
- dim: *dim
- num_heads: 8
- dim_head: &dim_head 64
- dropout_rate: &dropout_rate 0.2
- causal: true
- cross_attn:
- _target_: text_recognizer.network.transformer.Attention
- dim: *dim
- num_heads: 8
- dim_head: *dim_head
- dropout_rate: *dropout_rate
- causal: false
- norm:
- _target_: text_recognizer.network.transformer.RMSNorm
- dim: *dim
- ff:
- _target_: text_recognizer.network.transformer.FeedForward
- dim: *dim
- dim_out: null
- expansion_factor: 2
- glu: true
- dropout_rate: *dropout_rate
- rotary_embedding:
- _target_: text_recognizer.network.transformer.RotaryEmbedding
- dim: *dim_head
-
-trainer:
- gradient_clip_val: 1.0
- max_epochs: *epochs
- accumulate_grad_batches: 2
- limit_train_batches: 1.0
- limit_val_batches: 1.0
- limit_test_batches: 1.0
diff --git a/training/conf/experiment/vit_lines.yaml b/training/conf/experiment/vit_lines.yaml
index f57eead..f3049ea 100644
--- a/training/conf/experiment/vit_lines.yaml
+++ b/training/conf/experiment/vit_lines.yaml
@@ -4,13 +4,13 @@ defaults:
- override /criterion: cross_entropy
- override /callbacks: htr
- override /datamodule: iam_lines
- - override /network: null
+ - override /network: vit_lines
- override /model: lit_transformer
- override /lr_scheduler: null
- override /optimizer: null
tags: [lines, vit]
-epochs: &epochs 256
+epochs: &epochs 128
ignore_index: &ignore_index 3
# summary: [[1, 1, 56, 1024], [1, 89]]
@@ -59,45 +59,6 @@ datamodule:
batch_size: 16
train_fraction: 0.95
-network:
- _target_: text_recognizer.network.vit.VisionTransformer
- image_height: 56
- image_width: 1024
- patch_height: 28
- patch_width: 32
- dim: &dim 1024
- num_classes: &num_classes 58
- encoder:
- _target_: text_recognizer.network.transformer.encoder.Encoder
- dim: *dim
- inner_dim: 2048
- heads: 16
- dim_head: 64
- depth: 4
- dropout_rate: 0.0
- decoder:
- _target_: text_recognizer.network.transformer.decoder.Decoder
- dim: *dim
- inner_dim: 2048
- heads: 16
- dim_head: 64
- depth: 4
- dropout_rate: 0.0
- token_embedding:
- _target_: "text_recognizer.network.transformer.embedding.token.\
- TokenEmbedding"
- num_tokens: *num_classes
- dim: *dim
- use_l2: true
- pos_embedding:
- _target_: "text_recognizer.network.transformer.embedding.absolute.\
- AbsolutePositionalEmbedding"
- dim: *dim
- max_length: 89
- use_l2: true
- tie_embeddings: true
- pad_index: 3
-
model:
max_output_len: 89
@@ -105,7 +66,7 @@ trainer:
fast_dev_run: false
gradient_clip_val: 1.0
max_epochs: *epochs
- accumulate_grad_batches: 4
+ accumulate_grad_batches: 1
limit_train_batches: 1.0
limit_val_batches: 1.0
limit_test_batches: 1.0
diff --git a/training/conf/logger/wandb.yaml b/training/conf/logger/wandb.yaml
index ba3218a..09dd3f7 100644
--- a/training/conf/logger/wandb.yaml
+++ b/training/conf/logger/wandb.yaml
@@ -6,7 +6,7 @@ wandb:
offline: false # set True to store all logs only locally
id: null # pass correct id to resume experiment!
# entity: "" # set to name of your wandb team or just remove it
- log_model: false
+ log_model: true
prefix: ""
job_type: train
group: ""
diff --git a/training/conf/network/vit_lines.yaml b/training/conf/network/vit_lines.yaml
index 35f83c3..f32cb83 100644
--- a/training/conf/network/vit_lines.yaml
+++ b/training/conf/network/vit_lines.yaml
@@ -3,21 +3,21 @@ image_height: 56
image_width: 1024
patch_height: 28
patch_width: 32
-dim: &dim 256
-num_classes: &num_classes 57
+dim: &dim 1024
+num_classes: &num_classes 58
encoder:
_target_: text_recognizer.network.transformer.encoder.Encoder
dim: *dim
- inner_dim: 1024
- heads: 8
+ inner_dim: 2048
+ heads: 16
dim_head: 64
depth: 6
dropout_rate: 0.0
decoder:
_target_: text_recognizer.network.transformer.decoder.Decoder
dim: *dim
- inner_dim: 1024
- heads: 8
+ inner_dim: 2048
+ heads: 16
dim_head: 64
depth: 6
dropout_rate: 0.0
@@ -33,5 +33,5 @@ pos_embedding:
dim: *dim
max_length: 89
use_l2: true
-tie_embeddings: true
+tie_embeddings: false
pad_index: 3
diff --git a/training/conf/sweep/conv_transformer.yaml b/training/conf/sweep/conv_transformer.yaml
deleted file mode 100644
index 70e0a56..0000000
--- a/training/conf/sweep/conv_transformer.yaml
+++ /dev/null
@@ -1,29 +0,0 @@
-# @package _global_
-
-# Example:
-# python main.py -m sweep=lines experiment=conv_transformer_lines
-
-defaults:
- - override /hydra/sweeper: optuna
-
-optimized_metric: train/loss
-
-hydra:
- mode: "MULTIRUN"
- sweeper:
- _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper
- storage: null
- study_name: null
- n_jobs: 1
- direction: minimize
- n_trials: 20
- sampler:
- _target_: optuna.samplers.TPESampler
- seed: 4711
- n_startup_trials: 5
-
- params:
- # optimizer: interval(0.0003, 0.001)
- network.decoder.decoder.depth: choice(4, 8, 10)
- network.decoder.decoder.block.self_attn.num_heads: choice(4, 6, 8)
- network.decoder.decoder.block.cross_attn.num_heads: choice(6, 8, 12)