diff options
-rw-r--r-- | training/callbacks/wandb.py | 34 | ||||
-rw-r--r-- | training/conf/callbacks/default.yaml | 1 | ||||
-rw-r--r-- | training/conf/callbacks/lightning/swa.yaml (renamed from training/conf/callbacks/wandb/swa.yaml) | 0 | ||||
-rw-r--r-- | training/conf/callbacks/wandb/captions.yaml | 2 | ||||
-rw-r--r-- | training/conf/callbacks/wandb/checkpoints.yaml | 4 | ||||
-rw-r--r-- | training/conf/experiment/conv_transformer_lines.yaml | 134 | ||||
-rw-r--r-- | training/conf/experiment/conv_transformer_paragraphs.yaml | 132 | ||||
-rw-r--r-- | training/conf/experiment/vit_lines.yaml | 45 | ||||
-rw-r--r-- | training/conf/logger/wandb.yaml | 2 | ||||
-rw-r--r-- | training/conf/network/vit_lines.yaml | 14 | ||||
-rw-r--r-- | training/conf/sweep/conv_transformer.yaml | 29 |
11 files changed, 15 insertions, 382 deletions
diff --git a/training/callbacks/wandb.py b/training/callbacks/wandb.py index d9bb9b8..6adbebe 100644 --- a/training/callbacks/wandb.py +++ b/training/callbacks/wandb.py @@ -62,32 +62,7 @@ class UploadConfigAsArtifact(Callback): experiment.use_artifact(artifact) -class UploadCheckpointsAsArtifact(Callback): - """Upload checkpoint to wandb as an artifact, at the end of a run.""" - - def __init__( - self, ckpt_dir: str = "checkpoints/", upload_best_only: bool = False - ) -> None: - self.ckpt_dir = Path(ckpt_dir) - self.upload_best_only = upload_best_only - - @rank_zero_only - def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None: - """Uploads model checkpoint to W&B.""" - logger = get_wandb_logger(trainer) - experiment = logger.experiment - ckpts = wandb.Artifact("experiment-ckpts", type="checkpoints") - - if self.upload_best_only: - ckpts.add_file(trainer.checkpoint_callback.best_model_path) - else: - for ckpt in (self.ckpt_dir).rglob("*.ckpt"): - ckpts.add_file(ckpt) - - experiment.use_artifact(ckpts) - - -class ImageToCaptionLogger(Callback): +class ImageToCaption(Callback): """Logs the image and output caption.""" def __init__(self, num_samples: int = 8, on_train: bool = True) -> None: @@ -114,7 +89,7 @@ class ImageToCaptionLogger(Callback): pl_module: LightningModule, outputs: dict, batch: Tuple[Tensor, Tensor], - batch_idx: int, + *args, ) -> None: """Logs predictions on validation batch end.""" if self.has_metrics(outputs): @@ -127,9 +102,7 @@ class ImageToCaptionLogger(Callback): pl_module: LightningModule, outputs: dict, batch: Tuple[Tensor, Tensor], - batch_idx: int, *args, - # dataloader_idx: int, ) -> None: """Logs predictions on validation batch end.""" if self.has_metrics(outputs): @@ -142,8 +115,7 @@ class ImageToCaptionLogger(Callback): pl_module: LightningModule, outputs: dict, batch: Tuple[Tensor, Tensor], - batch_idx: int, - dataloader_idx: int, + *args, ) -> None: """Logs predictions on train batch end.""" if self.has_metrics(outputs): diff --git a/training/conf/callbacks/default.yaml b/training/conf/callbacks/default.yaml index 4d8e399..9b1347b 100644 --- a/training/conf/callbacks/default.yaml +++ b/training/conf/callbacks/default.yaml @@ -3,4 +3,3 @@ defaults: - lightning/learning_rate_monitor - wandb/watch - wandb/config - - wandb/checkpoints diff --git a/training/conf/callbacks/wandb/swa.yaml b/training/conf/callbacks/lightning/swa.yaml index 73f8c66..73f8c66 100644 --- a/training/conf/callbacks/wandb/swa.yaml +++ b/training/conf/callbacks/lightning/swa.yaml diff --git a/training/conf/callbacks/wandb/captions.yaml b/training/conf/callbacks/wandb/captions.yaml index 3215a90..64f44c2 100644 --- a/training/conf/callbacks/wandb/captions.yaml +++ b/training/conf/callbacks/wandb/captions.yaml @@ -1,4 +1,4 @@ log_text_predictions: - _target_: callbacks.wandb.ImageToCaptionLogger + _target_: callbacks.wandb.ImageToCaption num_samples: 8 on_train: true diff --git a/training/conf/callbacks/wandb/checkpoints.yaml b/training/conf/callbacks/wandb/checkpoints.yaml deleted file mode 100644 index b9a3fd7..0000000 --- a/training/conf/callbacks/wandb/checkpoints.yaml +++ /dev/null @@ -1,4 +0,0 @@ -upload_ckpts_as_artifact: - _target_: callbacks.wandb.UploadCheckpointsAsArtifact - ckpt_dir: checkpoints/ - upload_best_only: true diff --git a/training/conf/experiment/conv_transformer_lines.yaml b/training/conf/experiment/conv_transformer_lines.yaml deleted file mode 100644 index 12fe701..0000000 --- a/training/conf/experiment/conv_transformer_lines.yaml +++ /dev/null @@ -1,134 +0,0 @@ -# @package _global_ - -defaults: - - override /criterion: cross_entropy - - override /callbacks: htr - - override /datamodule: iam_lines - - override /network: null - - override /model: lit_transformer - - override /lr_scheduler: null - - override /optimizer: null - -tags: [lines] -epochs: &epochs 64 -ignore_index: &ignore_index 3 -# summary: [[1, 1, 56, 1024], [1, 89]] - -logger: - wandb: - tags: ${tags} - -criterion: - ignore_index: *ignore_index - # label_smoothing: 0.05 - -callbacks: - stochastic_weight_averaging: - _target_: pytorch_lightning.callbacks.StochasticWeightAveraging - swa_epoch_start: 0.75 - swa_lrs: 1.0e-5 - annealing_epochs: 10 - annealing_strategy: cos - device: null - -optimizer: - _target_: adan_pytorch.Adan - lr: 3.0e-4 - betas: [0.02, 0.08, 0.01] - weight_decay: 0.02 - -lr_scheduler: - _target_: torch.optim.lr_scheduler.ReduceLROnPlateau - mode: min - factor: 0.8 - patience: 10 - threshold: 1.0e-4 - threshold_mode: rel - cooldown: 0 - min_lr: 1.0e-5 - eps: 1.0e-8 - verbose: false - interval: epoch - monitor: val/cer - -datamodule: - batch_size: 16 - train_fraction: 0.95 - -network: - _target_: text_recognizer.network.ConvTransformer - encoder: - _target_: text_recognizer.network.image_encoder.ImageEncoder - encoder: - _target_: text_recognizer.network.convnext.ConvNext - dim: 16 - dim_mults: [2, 4, 32] - depths: [3, 3, 6] - downsampling_factors: [[2, 2], [2, 2], [2, 2]] - attn: - _target_: text_recognizer.network.convnext.TransformerBlock - attn: - _target_: text_recognizer.network.convnext.Attention - dim: &dim 512 - heads: 4 - dim_head: 64 - scale: 8 - ff: - _target_: text_recognizer.network.convnext.FeedForward - dim: *dim - mult: 2 - pixel_embedding: - _target_: "text_recognizer.network.transformer.embeddings.axial.\ - AxialPositionalEmbeddingImage" - dim: *dim - axial_shape: [7, 128] - decoder: - _target_: text_recognizer.network.text_decoder.TextDecoder - dim: *dim - num_classes: 58 - pad_index: *ignore_index - decoder: - _target_: text_recognizer.network.transformer.Decoder - dim: *dim - depth: 6 - block: - _target_: "text_recognizer.network.transformer.decoder_block.\ - DecoderBlock" - self_attn: - _target_: text_recognizer.network.transformer.Attention - dim: *dim - num_heads: 8 - dim_head: &dim_head 64 - dropout_rate: &dropout_rate 0.2 - causal: true - cross_attn: - _target_: text_recognizer.network.transformer.Attention - dim: *dim - num_heads: 8 - dim_head: *dim_head - dropout_rate: *dropout_rate - causal: false - norm: - _target_: text_recognizer.network.transformer.RMSNorm - dim: *dim - ff: - _target_: text_recognizer.network.transformer.FeedForward - dim: *dim - dim_out: null - expansion_factor: 2 - glu: true - dropout_rate: *dropout_rate - rotary_embedding: - _target_: text_recognizer.network.transformer.RotaryEmbedding - dim: *dim_head - -model: - max_output_len: 89 - -trainer: - gradient_clip_val: 1.0 - max_epochs: *epochs - accumulate_grad_batches: 1 - limit_train_batches: 1.0 - limit_val_batches: 1.0 - limit_test_batches: 1.0 diff --git a/training/conf/experiment/conv_transformer_paragraphs.yaml b/training/conf/experiment/conv_transformer_paragraphs.yaml deleted file mode 100644 index 9df2ea9..0000000 --- a/training/conf/experiment/conv_transformer_paragraphs.yaml +++ /dev/null @@ -1,132 +0,0 @@ -# @package _global_ - -defaults: - - override /criterion: cross_entropy - - override /callbacks: htr - - override /datamodule: iam_extended_paragraphs - - override /network: null - - override /model: lit_transformer - - override /lr_scheduler: null - - override /optimizer: null - -tags: [paragraphs] -epochs: &epochs 256 -ignore_index: &ignore_index 3 -# max_output_len: &max_output_len 682 -# summary: [[1, 1, 576, 640], [1, 682]] - -logger: - wandb: - tags: ${tags} - -criterion: - ignore_index: *ignore_index - # label_smoothing: 0.05 - -callbacks: - stochastic_weight_averaging: - _target_: pytorch_lightning.callbacks.StochasticWeightAveraging - swa_epoch_start: 0.75 - swa_lrs: 1.0e-5 - annealing_epochs: 10 - annealing_strategy: cos - device: null - -optimizer: - _target_: adan_pytorch.Adan - lr: 3.0e-4 - betas: [0.02, 0.08, 0.01] - weight_decay: 0.02 - -lr_scheduler: - _target_: torch.optim.lr_scheduler.ReduceLROnPlateau - mode: min - factor: 0.8 - patience: 10 - threshold: 1.0e-4 - threshold_mode: rel - cooldown: 0 - min_lr: 1.0e-5 - eps: 1.0e-8 - verbose: false - interval: epoch - monitor: val/cer - -datamodule: - batch_size: 4 - train_fraction: 0.95 - -network: - _target_: text_recognizer.network.ConvTransformer - encoder: - _target_: text_recognizer.network.image_encoder.ImageEncoder - encoder: - _target_: text_recognizer.network.convnext.ConvNext - dim: 16 - dim_mults: [1, 2, 4, 8, 32] - depths: [2, 3, 3, 3, 6] - downsampling_factors: [[2, 2], [2, 2], [2, 2], [2, 1], [2, 1]] - attn: - _target_: text_recognizer.network.convnext.TransformerBlock - attn: - _target_: text_recognizer.network.convnext.Attention - dim: &dim 512 - heads: 4 - dim_head: 64 - scale: 8 - ff: - _target_: text_recognizer.network.convnext.FeedForward - dim: *dim - mult: 2 - pixel_embedding: - _target_: "text_recognizer.network.transformer.embeddings.axial.\ - AxialPositionalEmbeddingImage" - dim: *dim - axial_shape: [18, 80] - decoder: - _target_: text_recognizer.network.text_decoder.TextDecoder - dim: *dim - num_classes: 58 - pad_index: *ignore_index - decoder: - _target_: text_recognizer.network.transformer.Decoder - dim: *dim - depth: 6 - block: - _target_: "text_recognizer.network.transformer.decoder_block.\ - DecoderBlock" - self_attn: - _target_: text_recognizer.network.transformer.Attention - dim: *dim - num_heads: 8 - dim_head: &dim_head 64 - dropout_rate: &dropout_rate 0.2 - causal: true - cross_attn: - _target_: text_recognizer.network.transformer.Attention - dim: *dim - num_heads: 8 - dim_head: *dim_head - dropout_rate: *dropout_rate - causal: false - norm: - _target_: text_recognizer.network.transformer.RMSNorm - dim: *dim - ff: - _target_: text_recognizer.network.transformer.FeedForward - dim: *dim - dim_out: null - expansion_factor: 2 - glu: true - dropout_rate: *dropout_rate - rotary_embedding: - _target_: text_recognizer.network.transformer.RotaryEmbedding - dim: *dim_head - -trainer: - gradient_clip_val: 1.0 - max_epochs: *epochs - accumulate_grad_batches: 2 - limit_train_batches: 1.0 - limit_val_batches: 1.0 - limit_test_batches: 1.0 diff --git a/training/conf/experiment/vit_lines.yaml b/training/conf/experiment/vit_lines.yaml index f57eead..f3049ea 100644 --- a/training/conf/experiment/vit_lines.yaml +++ b/training/conf/experiment/vit_lines.yaml @@ -4,13 +4,13 @@ defaults: - override /criterion: cross_entropy - override /callbacks: htr - override /datamodule: iam_lines - - override /network: null + - override /network: vit_lines - override /model: lit_transformer - override /lr_scheduler: null - override /optimizer: null tags: [lines, vit] -epochs: &epochs 256 +epochs: &epochs 128 ignore_index: &ignore_index 3 # summary: [[1, 1, 56, 1024], [1, 89]] @@ -59,45 +59,6 @@ datamodule: batch_size: 16 train_fraction: 0.95 -network: - _target_: text_recognizer.network.vit.VisionTransformer - image_height: 56 - image_width: 1024 - patch_height: 28 - patch_width: 32 - dim: &dim 1024 - num_classes: &num_classes 58 - encoder: - _target_: text_recognizer.network.transformer.encoder.Encoder - dim: *dim - inner_dim: 2048 - heads: 16 - dim_head: 64 - depth: 4 - dropout_rate: 0.0 - decoder: - _target_: text_recognizer.network.transformer.decoder.Decoder - dim: *dim - inner_dim: 2048 - heads: 16 - dim_head: 64 - depth: 4 - dropout_rate: 0.0 - token_embedding: - _target_: "text_recognizer.network.transformer.embedding.token.\ - TokenEmbedding" - num_tokens: *num_classes - dim: *dim - use_l2: true - pos_embedding: - _target_: "text_recognizer.network.transformer.embedding.absolute.\ - AbsolutePositionalEmbedding" - dim: *dim - max_length: 89 - use_l2: true - tie_embeddings: true - pad_index: 3 - model: max_output_len: 89 @@ -105,7 +66,7 @@ trainer: fast_dev_run: false gradient_clip_val: 1.0 max_epochs: *epochs - accumulate_grad_batches: 4 + accumulate_grad_batches: 1 limit_train_batches: 1.0 limit_val_batches: 1.0 limit_test_batches: 1.0 diff --git a/training/conf/logger/wandb.yaml b/training/conf/logger/wandb.yaml index ba3218a..09dd3f7 100644 --- a/training/conf/logger/wandb.yaml +++ b/training/conf/logger/wandb.yaml @@ -6,7 +6,7 @@ wandb: offline: false # set True to store all logs only locally id: null # pass correct id to resume experiment! # entity: "" # set to name of your wandb team or just remove it - log_model: false + log_model: true prefix: "" job_type: train group: "" diff --git a/training/conf/network/vit_lines.yaml b/training/conf/network/vit_lines.yaml index 35f83c3..f32cb83 100644 --- a/training/conf/network/vit_lines.yaml +++ b/training/conf/network/vit_lines.yaml @@ -3,21 +3,21 @@ image_height: 56 image_width: 1024 patch_height: 28 patch_width: 32 -dim: &dim 256 -num_classes: &num_classes 57 +dim: &dim 1024 +num_classes: &num_classes 58 encoder: _target_: text_recognizer.network.transformer.encoder.Encoder dim: *dim - inner_dim: 1024 - heads: 8 + inner_dim: 2048 + heads: 16 dim_head: 64 depth: 6 dropout_rate: 0.0 decoder: _target_: text_recognizer.network.transformer.decoder.Decoder dim: *dim - inner_dim: 1024 - heads: 8 + inner_dim: 2048 + heads: 16 dim_head: 64 depth: 6 dropout_rate: 0.0 @@ -33,5 +33,5 @@ pos_embedding: dim: *dim max_length: 89 use_l2: true -tie_embeddings: true +tie_embeddings: false pad_index: 3 diff --git a/training/conf/sweep/conv_transformer.yaml b/training/conf/sweep/conv_transformer.yaml deleted file mode 100644 index 70e0a56..0000000 --- a/training/conf/sweep/conv_transformer.yaml +++ /dev/null @@ -1,29 +0,0 @@ -# @package _global_ - -# Example: -# python main.py -m sweep=lines experiment=conv_transformer_lines - -defaults: - - override /hydra/sweeper: optuna - -optimized_metric: train/loss - -hydra: - mode: "MULTIRUN" - sweeper: - _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper - storage: null - study_name: null - n_jobs: 1 - direction: minimize - n_trials: 20 - sampler: - _target_: optuna.samplers.TPESampler - seed: 4711 - n_startup_trials: 5 - - params: - # optimizer: interval(0.0003, 0.001) - network.decoder.decoder.depth: choice(4, 8, 10) - network.decoder.decoder.block.self_attn.num_heads: choice(4, 6, 8) - network.decoder.decoder.block.cross_attn.num_heads: choice(6, 8, 12) |