From 6968572c1a21394b88a29f675b17b9698784a898 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Fri, 25 Aug 2023 23:19:39 +0200 Subject: Update training stuff --- training/callbacks/wandb_callbacks.py | 23 ++--- training/conf/callbacks/wandb/watch.yaml | 2 +- training/conf/config.yaml | 2 +- training/conf/decoder/greedy.yaml | 2 +- .../conf/experiment/conv_transformer_lines.yaml | 30 +++--- .../experiment/conv_transformer_paragraphs.yaml | 30 +++--- training/conf/experiment/vit_lines.yaml | 113 +++++++++++++++++++++ training/conf/logger/csv.yaml | 4 + training/conf/logger/wandb.yaml | 2 +- training/conf/model/lit_transformer.yaml | 2 +- training/conf/network/conv_transformer.yaml | 26 ++--- training/conf/network/convnext.yaml | 8 +- training/conf/network/vit_lines.yaml | 37 +++++++ training/conf/trainer/default.yaml | 6 +- training/main.py | 4 +- training/run.py | 4 +- training/utils.py | 8 +- 17 files changed, 227 insertions(+), 76 deletions(-) create mode 100644 training/conf/experiment/vit_lines.yaml create mode 100644 training/conf/logger/csv.yaml create mode 100644 training/conf/network/vit_lines.yaml diff --git a/training/callbacks/wandb_callbacks.py b/training/callbacks/wandb_callbacks.py index 7d4f48e..1c7955c 100644 --- a/training/callbacks/wandb_callbacks.py +++ b/training/callbacks/wandb_callbacks.py @@ -3,23 +3,17 @@ from pathlib import Path import wandb from pytorch_lightning import Callback, LightningModule, Trainer -from pytorch_lightning.loggers import LoggerCollection, WandbLogger +from pytorch_lightning.loggers import Logger, WandbLogger from pytorch_lightning.utilities import rank_zero_only -from torch import nn from torch.utils.data import DataLoader -from torchvision.utils import make_grid def get_wandb_logger(trainer: Trainer) -> WandbLogger: """Safely get W&B logger from Trainer.""" - if isinstance(trainer.logger, WandbLogger): - return trainer.logger - - if isinstance(trainer.logger, LoggerCollection): - for logger in trainer.logger: - if isinstance(logger, WandbLogger): - return logger + for logger in trainer.loggers: + if isinstance(logger, WandbLogger): + return logger raise Exception("Weight and Biases logger not found for some reason...") @@ -28,9 +22,12 @@ class WatchModel(Callback): """Make W&B watch the model at the beginning of the run.""" def __init__( - self, log: str = "gradients", log_freq: int = 100, log_graph: bool = False + self, + log_params: str = "gradients", + log_freq: int = 100, + log_graph: bool = False, ) -> None: - self.log = log + self.log_params = log_params self.log_freq = log_freq self.log_graph = log_graph @@ -40,7 +37,7 @@ class WatchModel(Callback): logger = get_wandb_logger(trainer) logger.watch( model=trainer.model, - log=self.log, + log=self.log_params, log_freq=self.log_freq, log_graph=self.log_graph, ) diff --git a/training/conf/callbacks/wandb/watch.yaml b/training/conf/callbacks/wandb/watch.yaml index bada03b..660ae47 100644 --- a/training/conf/callbacks/wandb/watch.yaml +++ b/training/conf/callbacks/wandb/watch.yaml @@ -1,5 +1,5 @@ watch_model: _target_: callbacks.wandb_callbacks.WatchModel - log: gradients + log_params: gradients log_freq: 100 log_graph: true diff --git a/training/conf/config.yaml b/training/conf/config.yaml index e57a8a8..8a1317c 100644 --- a/training/conf/config.yaml +++ b/training/conf/config.yaml @@ -13,7 +13,7 @@ defaults: - network: conv_transformer - optimizer: radam - trainer: default - - experiment: null + - experiment: vit_lines seed: 4711 tune: false diff --git a/training/conf/decoder/greedy.yaml b/training/conf/decoder/greedy.yaml index 1d1a131..a88b5a6 100644 --- a/training/conf/decoder/greedy.yaml +++ b/training/conf/decoder/greedy.yaml @@ -1,2 +1,2 @@ -_target_: text_recognizer.models.greedy_decoder.GreedyDecoder +_target_: text_recognizer.model.greedy_decoder.GreedyDecoder max_output_len: 682 diff --git a/training/conf/experiment/conv_transformer_lines.yaml b/training/conf/experiment/conv_transformer_lines.yaml index 948968a..12fe701 100644 --- a/training/conf/experiment/conv_transformer_lines.yaml +++ b/training/conf/experiment/conv_transformer_lines.yaml @@ -56,70 +56,70 @@ datamodule: train_fraction: 0.95 network: - _target_: text_recognizer.networks.ConvTransformer + _target_: text_recognizer.network.ConvTransformer encoder: - _target_: text_recognizer.networks.image_encoder.ImageEncoder + _target_: text_recognizer.network.image_encoder.ImageEncoder encoder: - _target_: text_recognizer.networks.convnext.ConvNext + _target_: text_recognizer.network.convnext.ConvNext dim: 16 dim_mults: [2, 4, 32] depths: [3, 3, 6] downsampling_factors: [[2, 2], [2, 2], [2, 2]] attn: - _target_: text_recognizer.networks.convnext.TransformerBlock + _target_: text_recognizer.network.convnext.TransformerBlock attn: - _target_: text_recognizer.networks.convnext.Attention + _target_: text_recognizer.network.convnext.Attention dim: &dim 512 heads: 4 dim_head: 64 scale: 8 ff: - _target_: text_recognizer.networks.convnext.FeedForward + _target_: text_recognizer.network.convnext.FeedForward dim: *dim mult: 2 pixel_embedding: - _target_: "text_recognizer.networks.transformer.embeddings.axial.\ + _target_: "text_recognizer.network.transformer.embeddings.axial.\ AxialPositionalEmbeddingImage" dim: *dim axial_shape: [7, 128] decoder: - _target_: text_recognizer.networks.text_decoder.TextDecoder + _target_: text_recognizer.network.text_decoder.TextDecoder dim: *dim num_classes: 58 pad_index: *ignore_index decoder: - _target_: text_recognizer.networks.transformer.Decoder + _target_: text_recognizer.network.transformer.Decoder dim: *dim depth: 6 block: - _target_: "text_recognizer.networks.transformer.decoder_block.\ + _target_: "text_recognizer.network.transformer.decoder_block.\ DecoderBlock" self_attn: - _target_: text_recognizer.networks.transformer.Attention + _target_: text_recognizer.network.transformer.Attention dim: *dim num_heads: 8 dim_head: &dim_head 64 dropout_rate: &dropout_rate 0.2 causal: true cross_attn: - _target_: text_recognizer.networks.transformer.Attention + _target_: text_recognizer.network.transformer.Attention dim: *dim num_heads: 8 dim_head: *dim_head dropout_rate: *dropout_rate causal: false norm: - _target_: text_recognizer.networks.transformer.RMSNorm + _target_: text_recognizer.network.transformer.RMSNorm dim: *dim ff: - _target_: text_recognizer.networks.transformer.FeedForward + _target_: text_recognizer.network.transformer.FeedForward dim: *dim dim_out: null expansion_factor: 2 glu: true dropout_rate: *dropout_rate rotary_embedding: - _target_: text_recognizer.networks.transformer.RotaryEmbedding + _target_: text_recognizer.network.transformer.RotaryEmbedding dim: *dim_head model: diff --git a/training/conf/experiment/conv_transformer_paragraphs.yaml b/training/conf/experiment/conv_transformer_paragraphs.yaml index ff931cc..9df2ea9 100644 --- a/training/conf/experiment/conv_transformer_paragraphs.yaml +++ b/training/conf/experiment/conv_transformer_paragraphs.yaml @@ -57,70 +57,70 @@ datamodule: train_fraction: 0.95 network: - _target_: text_recognizer.networks.ConvTransformer + _target_: text_recognizer.network.ConvTransformer encoder: - _target_: text_recognizer.networks.image_encoder.ImageEncoder + _target_: text_recognizer.network.image_encoder.ImageEncoder encoder: - _target_: text_recognizer.networks.convnext.ConvNext + _target_: text_recognizer.network.convnext.ConvNext dim: 16 dim_mults: [1, 2, 4, 8, 32] depths: [2, 3, 3, 3, 6] downsampling_factors: [[2, 2], [2, 2], [2, 2], [2, 1], [2, 1]] attn: - _target_: text_recognizer.networks.convnext.TransformerBlock + _target_: text_recognizer.network.convnext.TransformerBlock attn: - _target_: text_recognizer.networks.convnext.Attention + _target_: text_recognizer.network.convnext.Attention dim: &dim 512 heads: 4 dim_head: 64 scale: 8 ff: - _target_: text_recognizer.networks.convnext.FeedForward + _target_: text_recognizer.network.convnext.FeedForward dim: *dim mult: 2 pixel_embedding: - _target_: "text_recognizer.networks.transformer.embeddings.axial.\ + _target_: "text_recognizer.network.transformer.embeddings.axial.\ AxialPositionalEmbeddingImage" dim: *dim axial_shape: [18, 80] decoder: - _target_: text_recognizer.networks.text_decoder.TextDecoder + _target_: text_recognizer.network.text_decoder.TextDecoder dim: *dim num_classes: 58 pad_index: *ignore_index decoder: - _target_: text_recognizer.networks.transformer.Decoder + _target_: text_recognizer.network.transformer.Decoder dim: *dim depth: 6 block: - _target_: "text_recognizer.networks.transformer.decoder_block.\ + _target_: "text_recognizer.network.transformer.decoder_block.\ DecoderBlock" self_attn: - _target_: text_recognizer.networks.transformer.Attention + _target_: text_recognizer.network.transformer.Attention dim: *dim num_heads: 8 dim_head: &dim_head 64 dropout_rate: &dropout_rate 0.2 causal: true cross_attn: - _target_: text_recognizer.networks.transformer.Attention + _target_: text_recognizer.network.transformer.Attention dim: *dim num_heads: 8 dim_head: *dim_head dropout_rate: *dropout_rate causal: false norm: - _target_: text_recognizer.networks.transformer.RMSNorm + _target_: text_recognizer.network.transformer.RMSNorm dim: *dim ff: - _target_: text_recognizer.networks.transformer.FeedForward + _target_: text_recognizer.network.transformer.FeedForward dim: *dim dim_out: null expansion_factor: 2 glu: true dropout_rate: *dropout_rate rotary_embedding: - _target_: text_recognizer.networks.transformer.RotaryEmbedding + _target_: text_recognizer.network.transformer.RotaryEmbedding dim: *dim_head trainer: diff --git a/training/conf/experiment/vit_lines.yaml b/training/conf/experiment/vit_lines.yaml new file mode 100644 index 0000000..e2ddebf --- /dev/null +++ b/training/conf/experiment/vit_lines.yaml @@ -0,0 +1,113 @@ +# @package _global_ + +defaults: + - override /criterion: cross_entropy + - override /callbacks: htr + - override /datamodule: iam_lines + - override /network: null + - override /model: lit_transformer + - override /lr_scheduler: null + - override /optimizer: null + +tags: [lines, vit] +epochs: &epochs 64 +ignore_index: &ignore_index 3 +# summary: [[1, 1, 56, 1024], [1, 89]] + +logger: + wandb: + tags: ${tags} + +criterion: + ignore_index: *ignore_index + # label_smoothing: 0.05 + + +decoder: + max_output_len: 89 + +callbacks: + stochastic_weight_averaging: + _target_: pytorch_lightning.callbacks.StochasticWeightAveraging + swa_epoch_start: 0.75 + swa_lrs: 1.0e-5 + annealing_epochs: 10 + annealing_strategy: cos + device: null + +optimizer: + _target_: adan_pytorch.Adan + lr: 3.0e-4 + betas: [0.02, 0.08, 0.01] + weight_decay: 0.02 + +lr_scheduler: + _target_: torch.optim.lr_scheduler.ReduceLROnPlateau + mode: min + factor: 0.8 + patience: 10 + threshold: 1.0e-4 + threshold_mode: rel + cooldown: 0 + min_lr: 1.0e-5 + eps: 1.0e-8 + verbose: false + interval: epoch + monitor: val/cer + +datamodule: + batch_size: 8 + train_fraction: 0.95 + +network: + _target_: text_recognizer.network.vit.VisionTransformer + image_height: 56 + image_width: 1024 + patch_height: 28 + patch_width: 32 + dim: &dim 1024 + num_classes: &num_classes 58 + encoder: + _target_: text_recognizer.network.transformer.encoder.Encoder + dim: *dim + inner_dim: 2048 + heads: 16 + dim_head: 64 + depth: 4 + dropout_rate: 0.0 + decoder: + _target_: text_recognizer.network.transformer.decoder.Decoder + dim: *dim + inner_dim: 2048 + heads: 16 + dim_head: 64 + depth: 4 + dropout_rate: 0.0 + token_embedding: + _target_: "text_recognizer.network.transformer.embedding.token.\ + TokenEmbedding" + num_tokens: *num_classes + dim: *dim + use_l2: true + pos_embedding: + _target_: "text_recognizer.network.transformer.embedding.absolute.\ + AbsolutePositionalEmbedding" + dim: *dim + max_length: 89 + use_l2: true + tie_embeddings: false + pad_index: 3 + +model: + max_output_len: 89 + +trainer: + fast_dev_run: false + gradient_clip_val: 1.0 + max_epochs: *epochs + accumulate_grad_batches: 1 + limit_val_batches: .02 + limit_test_batches: .02 + limit_train_batches: 1.0 + # limit_val_batches: 1.0 + # limit_test_batches: 1.0 diff --git a/training/conf/logger/csv.yaml b/training/conf/logger/csv.yaml new file mode 100644 index 0000000..9fa6cad --- /dev/null +++ b/training/conf/logger/csv.yaml @@ -0,0 +1,4 @@ +csv: + _target_: pytorch_lightning.loggers.CSVLogger + name: null + save_dir: "." diff --git a/training/conf/logger/wandb.yaml b/training/conf/logger/wandb.yaml index 081ebeb..ba3218a 100644 --- a/training/conf/logger/wandb.yaml +++ b/training/conf/logger/wandb.yaml @@ -1,5 +1,5 @@ wandb: - _target_: pytorch_lightning.loggers.wandb.WandbLogger + _target_: pytorch_lightning.loggers.WandbLogger project: text-recognizer name: null save_dir: "." diff --git a/training/conf/model/lit_transformer.yaml b/training/conf/model/lit_transformer.yaml index e6af035..533f8f3 100644 --- a/training/conf/model/lit_transformer.yaml +++ b/training/conf/model/lit_transformer.yaml @@ -1,2 +1,2 @@ -_target_: text_recognizer.models.LitTransformer +_target_: text_recognizer.model.transformer.LitTransformer max_output_len: 682 diff --git a/training/conf/network/conv_transformer.yaml b/training/conf/network/conv_transformer.yaml index 016adbb..1e03946 100644 --- a/training/conf/network/conv_transformer.yaml +++ b/training/conf/network/conv_transformer.yaml @@ -1,58 +1,58 @@ -_target_: text_recognizer.networks.ConvTransformer +_target_: text_recognizer.network.ConvTransformer encoder: - _target_: text_recognizer.networks.image_encoder.ImageEncoder + _target_: text_recognizer.network.image_encoder.ImageEncoder encoder: - _target_: text_recognizer.networks.convnext.ConvNext + _target_: text_recognizer.network.convnext.ConvNext dim: 16 dim_mults: [2, 4, 8] depths: [3, 3, 6] downsampling_factors: [[2, 2], [2, 2], [2, 2]] pixel_embedding: - _target_: "text_recognizer.networks.transformer.embeddings.axial.\ + _target_: "text_recognizer.network.transformer.embeddings.axial.\ AxialPositionalEmbeddingImage" dim: &hidden_dim 128 axial_shape: [7, 128] axial_dims: [64, 64] decoder: - _target_: text_recognizer.networks.text_decoder.TextDecoder + _target_: text_recognizer.network.text_decoder.TextDecoder hidden_dim: *hidden_dim num_classes: 58 pad_index: 3 decoder: - _target_: text_recognizer.networks.transformer.Decoder + _target_: text_recognizer.network.transformer.Decoder dim: *hidden_dim depth: 10 block: - _target_: text_recognizer.networks.transformer.decoder_block.DecoderBlock + _target_: text_recognizer.network.transformer.decoder_block.DecoderBlock self_attn: - _target_: text_recognizer.networks.transformer.Attention + _target_: text_recognizer.network.transformer.Attention dim: *hidden_dim num_heads: 12 dim_head: 64 dropout_rate: &dropout_rate 0.2 causal: true cross_attn: - _target_: text_recognizer.networks.transformer.Attention + _target_: text_recognizer.network.transformer.Attention dim: *hidden_dim num_heads: 12 dim_head: 64 dropout_rate: *dropout_rate causal: false norm: - _target_: text_recognizer.networks.transformer.RMSNorm + _target_: text_recognizer.network.transformer.RMSNorm dim: *hidden_dim ff: - _target_: text_recognizer.networks.transformer.FeedForward + _target_: text_recognizer.network.transformer.FeedForward dim: *hidden_dim dim_out: null expansion_factor: 2 glu: true dropout_rate: *dropout_rate rotary_embedding: - _target_: text_recognizer.networks.transformer.RotaryEmbedding + _target_: text_recognizer.network.transformer.RotaryEmbedding dim: 64 token_pos_embedding: - _target_: "text_recognizer.networks.transformer.embeddings.fourier.\ + _target_: "text_recognizer.network.transformer.embeddings.fourier.\ PositionalEncoding" dim: *hidden_dim dropout_rate: 0.1 diff --git a/training/conf/network/convnext.yaml b/training/conf/network/convnext.yaml index 63ad424..904bd56 100644 --- a/training/conf/network/convnext.yaml +++ b/training/conf/network/convnext.yaml @@ -1,17 +1,17 @@ -_target_: text_recognizer.networks.convnext.ConvNext +_target_: text_recognizer.network.convnext.ConvNext dim: 16 dim_mults: [2, 4, 8] depths: [3, 3, 6] downsampling_factors: [[2, 2], [2, 2], [2, 2]] attn: - _target_: text_recognizer.networks.convnext.TransformerBlock + _target_: text_recognizer.network.convnext.TransformerBlock attn: - _target_: text_recognizer.networks.convnext.Attention + _target_: text_recognizer.network.convnext.Attention dim: 128 heads: 4 dim_head: 64 scale: 8 ff: - _target_: text_recognizer.networks.convnext.FeedForward + _target_: text_recognizer.network.convnext.FeedForward dim: 128 mult: 4 diff --git a/training/conf/network/vit_lines.yaml b/training/conf/network/vit_lines.yaml new file mode 100644 index 0000000..35f83c3 --- /dev/null +++ b/training/conf/network/vit_lines.yaml @@ -0,0 +1,37 @@ +_target_: text_recognizer.network.vit.VisionTransformer +image_height: 56 +image_width: 1024 +patch_height: 28 +patch_width: 32 +dim: &dim 256 +num_classes: &num_classes 57 +encoder: + _target_: text_recognizer.network.transformer.encoder.Encoder + dim: *dim + inner_dim: 1024 + heads: 8 + dim_head: 64 + depth: 6 + dropout_rate: 0.0 +decoder: + _target_: text_recognizer.network.transformer.decoder.Decoder + dim: *dim + inner_dim: 1024 + heads: 8 + dim_head: 64 + depth: 6 + dropout_rate: 0.0 +token_embedding: + _target_: "text_recognizer.network.transformer.embedding.token.\ + TokenEmbedding" + num_tokens: *num_classes + dim: *dim + use_l2: true +pos_embedding: + _target_: "text_recognizer.network.transformer.embedding.absolute.\ + AbsolutePositionalEmbedding" + dim: *dim + max_length: 89 + use_l2: true +tie_embeddings: true +pad_index: 3 diff --git a/training/conf/trainer/default.yaml b/training/conf/trainer/default.yaml index 6112cd8..2e593e8 100644 --- a/training/conf/trainer/default.yaml +++ b/training/conf/trainer/default.yaml @@ -1,14 +1,12 @@ _target_: pytorch_lightning.Trainer -auto_scale_batch_size: binsearch -auto_lr_find: false gradient_clip_val: 0.5 fast_dev_run: false -gpus: 1 +accelerator: gpu +devices: 1 precision: 16 max_epochs: 256 limit_train_batches: 1.0 limit_val_batches: 1.0 limit_test_batches: 1.0 -resume_from_checkpoint: null accumulate_grad_batches: 1 overfit_batches: 0 diff --git a/training/main.py b/training/main.py index 6e54a23..73adda0 100644 --- a/training/main.py +++ b/training/main.py @@ -4,7 +4,9 @@ from omegaconf import DictConfig from training.metadata import TRAINING_DIR -@hydra.main(version_base="1.2", config_path=TRAINING_DIR / "conf", config_name="config") +@hydra.main( + version_base="1.2", config_path=str(TRAINING_DIR / "conf"), config_name="config" +) def main(config: DictConfig) -> None: """Loads config with hydra and runs the experiment.""" import utils diff --git a/training/run.py b/training/run.py index 288a1ef..cffc3ae 100644 --- a/training/run.py +++ b/training/run.py @@ -11,7 +11,7 @@ from pytorch_lightning import ( seed_everything, Trainer, ) -from pytorch_lightning.loggers import LightningLoggerBase +from pytorch_lightning.loggers import Logger from torch import nn from torchinfo import summary import utils @@ -55,7 +55,7 @@ def run(config: DictConfig) -> Optional[float]: # Load callback and logger. callbacks: List[Type[Callback]] = utils.configure_callbacks(config) - logger: List[Type[LightningLoggerBase]] = utils.configure_logger(config) + logger: List[Type[Logger]] = utils.configure_logger(config) log.info(f"Instantiating trainer <{config.trainer._target_}>") trainer: Trainer = hydra.utils.instantiate( diff --git a/training/utils.py b/training/utils.py index c8ea1be..d1801a7 100644 --- a/training/utils.py +++ b/training/utils.py @@ -10,7 +10,7 @@ from pytorch_lightning import ( LightningModule, Trainer, ) -from pytorch_lightning.loggers import LightningLoggerBase +from pytorch_lightning.loggers import Logger from pytorch_lightning.loggers.wandb import WandbLogger from pytorch_lightning.utilities import rank_zero_only from tqdm import tqdm @@ -59,10 +59,10 @@ def configure_callbacks( return callbacks -def configure_logger(config: DictConfig) -> List[Type[LightningLoggerBase]]: +def configure_logger(config: DictConfig) -> List[Type[Logger]]: """Configures Lightning loggers.""" - def load_logger(logger_config: DictConfig) -> Type[LightningLoggerBase]: + def load_logger(logger_config: DictConfig) -> Type[Logger]: log.info(f"Instantiating logger <{logger_config._target_}>") return hydra.utils.instantiate(logger_config) @@ -137,7 +137,7 @@ def log_hyperparameters( def finish( - logger: List[Type[LightningLoggerBase]], + logger: List[Type[Logger]], ) -> None: """Makes sure everything closed properly.""" for lg in logger: -- cgit v1.2.3-70-g09d2