From 4da7a2c812221d56a430b35139ac40b23fa76f77 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Tue, 29 Jun 2021 22:54:52 +0200 Subject: Refactor of config, more granular --- training/conf/callbacks/checkpoint.yaml | 5 +- training/conf/callbacks/early_stopping.yaml | 5 +- training/conf/callbacks/learning_rate_monitor.yaml | 5 +- training/conf/callbacks/swa.yaml | 5 +- training/conf/config.yaml | 6 +++ training/conf/criterion/mse.yaml | 3 ++ training/conf/dataset/iam_extended_paragraphs.yaml | 9 ++-- training/conf/lr_scheduler/one_cycle.yaml | 8 ++++ training/conf/model/lit_vqvae.yaml | 23 +-------- training/conf/network/vqvae.yaml | 23 +++++---- training/conf/optimizer/madgrad.yaml | 6 +++ training/conf/trainer/default.yaml | 23 ++++----- training/run_experiment.py | 54 ++++++++++++++-------- 13 files changed, 97 insertions(+), 78 deletions(-) create mode 100644 training/conf/criterion/mse.yaml create mode 100644 training/conf/lr_scheduler/one_cycle.yaml create mode 100644 training/conf/optimizer/madgrad.yaml diff --git a/training/conf/callbacks/checkpoint.yaml b/training/conf/callbacks/checkpoint.yaml index afc536f..f3beb1b 100644 --- a/training/conf/callbacks/checkpoint.yaml +++ b/training/conf/callbacks/checkpoint.yaml @@ -1,5 +1,6 @@ -type: ModelCheckpoint -args: +checkpoint: + type: ModelCheckpoint + args: monitor: val_loss mode: min save_last: true diff --git a/training/conf/callbacks/early_stopping.yaml b/training/conf/callbacks/early_stopping.yaml index caab824..ec671fd 100644 --- a/training/conf/callbacks/early_stopping.yaml +++ b/training/conf/callbacks/early_stopping.yaml @@ -1,5 +1,6 @@ -type: EarlyStopping -args: +early_stopping: + type: EarlyStopping + args: monitor: val_loss mode: min patience: 10 diff --git a/training/conf/callbacks/learning_rate_monitor.yaml b/training/conf/callbacks/learning_rate_monitor.yaml index 003ab7a..11a5ecf 100644 --- a/training/conf/callbacks/learning_rate_monitor.yaml +++ b/training/conf/callbacks/learning_rate_monitor.yaml @@ -1,3 +1,4 @@ -type: LearningRateMonitor -args: +learning_rate_monitor: + type: LearningRateMonitor + args: logging_interval: step diff --git a/training/conf/callbacks/swa.yaml b/training/conf/callbacks/swa.yaml index 279ca69..92d9e6b 100644 --- a/training/conf/callbacks/swa.yaml +++ b/training/conf/callbacks/swa.yaml @@ -1,5 +1,6 @@ -type: StochasticWeightAveraging -args: +stochastic_weight_averaging: + type: StochasticWeightAveraging + args: swa_epoch_start: 0.8 swa_lrs: 0.05 annealing_epochs: 10 diff --git a/training/conf/config.yaml b/training/conf/config.yaml index c413a1a..b43e375 100644 --- a/training/conf/config.yaml +++ b/training/conf/config.yaml @@ -1,8 +1,14 @@ defaults: - network: vqvae + - criterion: mse + - optimizer: madgrad + - lr_scheduler: one_cycle - model: lit_vqvae - dataset: iam_extended_paragraphs - trainer: default - callbacks: - checkpoint - learning_rate_monitor + +load_checkpoint: null +logging: INFO diff --git a/training/conf/criterion/mse.yaml b/training/conf/criterion/mse.yaml new file mode 100644 index 0000000..4d89cbc --- /dev/null +++ b/training/conf/criterion/mse.yaml @@ -0,0 +1,3 @@ +type: MSELoss +args: + reduction: mean diff --git a/training/conf/dataset/iam_extended_paragraphs.yaml b/training/conf/dataset/iam_extended_paragraphs.yaml index 6bd7fc9..6439a15 100644 --- a/training/conf/dataset/iam_extended_paragraphs.yaml +++ b/training/conf/dataset/iam_extended_paragraphs.yaml @@ -1,7 +1,6 @@ -# @package _group_ type: IAMExtendedParagraphs args: - batch_size: 32 - num_workers: 12 - train_fraction: 0.8 - augment: true + batch_size: 32 + num_workers: 12 + train_fraction: 0.8 + augment: true diff --git a/training/conf/lr_scheduler/one_cycle.yaml b/training/conf/lr_scheduler/one_cycle.yaml new file mode 100644 index 0000000..60a6f27 --- /dev/null +++ b/training/conf/lr_scheduler/one_cycle.yaml @@ -0,0 +1,8 @@ +type: OneCycleLR +args: + interval: step + max_lr: 1.0e-3 + three_phase: true + epochs: 64 + steps_per_epoch: 633 # num_samples / batch_size +monitor: val_loss diff --git a/training/conf/model/lit_vqvae.yaml b/training/conf/model/lit_vqvae.yaml index 90780b7..7136dbd 100644 --- a/training/conf/model/lit_vqvae.yaml +++ b/training/conf/model/lit_vqvae.yaml @@ -1,24 +1,3 @@ -# @package _group_ type: LitVQVAEModel args: - optimizer: - type: MADGRAD - args: - lr: 1.0e-3 - momentum: 0.9 - weight_decay: 0 - eps: 1.0e-6 - lr_scheduler: - type: OneCycleLR - args: - interval: step - max_lr: 1.0e-3 - three_phase: true - epochs: 64 - steps_per_epoch: 633 # num_samples / batch_size - criterion: - type: MSELoss - args: - reduction: mean - monitor: val_loss - mapping: sentence_piece + mapping: sentence_piece diff --git a/training/conf/network/vqvae.yaml b/training/conf/network/vqvae.yaml index 288d2aa..22eebf8 100644 --- a/training/conf/network/vqvae.yaml +++ b/training/conf/network/vqvae.yaml @@ -1,14 +1,13 @@ -# @package _group_ type: VQVAE args: - in_channels: 1 - channels: [64, 96] - kernel_sizes: [4, 4] - strides: [2, 2] - num_residual_layers: 2 - embedding_dim: 64 - num_embeddings: 256 - upsampling: null - beta: 0.25 - activation: leaky_relu - dropout_rate: 0.2 + in_channels: 1 + channels: [64, 96] + kernel_sizes: [4, 4] + strides: [2, 2] + num_residual_layers: 2 + embedding_dim: 64 + num_embeddings: 256 + upsampling: null + beta: 0.25 + activation: leaky_relu + dropout_rate: 0.2 diff --git a/training/conf/optimizer/madgrad.yaml b/training/conf/optimizer/madgrad.yaml new file mode 100644 index 0000000..2f2cff9 --- /dev/null +++ b/training/conf/optimizer/madgrad.yaml @@ -0,0 +1,6 @@ +type: MADGRAD +args: + lr: 1.0e-3 + momentum: 0.9 + weight_decay: 0 + eps: 1.0e-6 diff --git a/training/conf/trainer/default.yaml b/training/conf/trainer/default.yaml index 3a88c6a..5797741 100644 --- a/training/conf/trainer/default.yaml +++ b/training/conf/trainer/default.yaml @@ -1,19 +1,16 @@ -# @package _group_ seed: 4711 -load_checkpoint: null wandb: false tune: false train: true test: true -logging: INFO args: - stochastic_weight_avg: false - auto_scale_batch_size: binsearch - auto_lr_find: false - gradient_clip_val: 0 - fast_dev_run: false - gpus: 1 - precision: 16 - max_epochs: 64 - terminate_on_nan: true - weights_summary: top + stochastic_weight_avg: false + auto_scale_batch_size: binsearch + auto_lr_find: false + gradient_clip_val: 0 + fast_dev_run: false + gpus: 1 + precision: 16 + max_epochs: 64 + terminate_on_nan: true + weights_summary: top diff --git a/training/run_experiment.py b/training/run_experiment.py index def1e77..b3c9552 100644 --- a/training/run_experiment.py +++ b/training/run_experiment.py @@ -3,6 +3,9 @@ from datetime import datetime import importlib from pathlib import Path from typing import List, Optional, Type +import warnings + +warnings.filterwarnings("ignore") import hydra from loguru import logger @@ -29,7 +32,7 @@ def _create_experiment_dir(config: DictConfig) -> Path: def _save_config(config: DictConfig, log_dir: Path) -> None: """Saves config to log directory.""" - with (log_dir / "config.yaml").open("r") as f: + with (log_dir / "config.yaml").open("w") as f: OmegaConf.save(config=config, f=f) @@ -52,12 +55,11 @@ def _import_class(module_and_class_name: str) -> type: return getattr(module, class_name) -def _configure_callbacks( - callbacks: List[DictConfig], -) -> List[Type[pl.callbacks.Callback]]: +def _configure_callbacks(callbacks: DictConfig,) -> List[Type[pl.callbacks.Callback]]: """Configures lightning callbacks.""" pl_callbacks = [ - getattr(pl.callbacks, callback.type)(**callback.args) for callback in callbacks + getattr(pl.callbacks, callback.type)(**callback.args) + for callback in callbacks.values() ] return pl_callbacks @@ -77,12 +79,12 @@ def _configure_logger( def _save_best_weights( - callbacks: List[Type[pl.callbacks.Callback]], use_wandb: bool + pl_callbacks: List[Type[pl.callbacks.Callback]], use_wandb: bool ) -> None: """Saves the best model.""" model_checkpoint_callback = next( callback - for callback in callbacks + for callback in pl_callbacks if isinstance(callback, pl.callbacks.ModelCheckpoint) ) best_model_path = model_checkpoint_callback.best_model_path @@ -97,20 +99,31 @@ def _load_lit_model( lit_model_class: type, network: Type[nn.Module], config: DictConfig ) -> Type[pl.LightningModule]: """Load lightning model.""" - if config.trainer.load_checkpoint is not None: + if config.load_checkpoint is not None: logger.info( - f"Loading network weights from checkpoint: {config.trainer.load_checkpoint}" + f"Loading network weights from checkpoint: {config.load_checkpoint}" ) return lit_model_class.load_from_checkpoint( - config.trainer.load_checkpoint, network=network, **config.model.args + config.load_checkpoint, + network=network, + optimizer=config.optimizer, + criterion=config.criterion, + lr_scheduler=config.lr_scheduler, + **config.model.args, ) - return lit_model_class(network=network, **config.model.args) + return lit_model_class( + network=network, + optimizer=config.optimizer, + criterion=config.criterion, + lr_scheduler=config.lr_scheduler, + **config.model.args, + ) def run(config: DictConfig) -> None: """Runs experiment.""" log_dir = _create_experiment_dir(config) - _configure_logging(log_dir, level=config.trainer.logging) + _configure_logging(log_dir, level=config.logging) logger.info("Starting experiment...") pl.utilities.seed.seed_everything(config.trainer.seed) @@ -125,7 +138,7 @@ def run(config: DictConfig) -> None: network = network_class(**data_module.config(), **config.network.args) # Load callback and logger. - callbacks = _configure_callbacks(config.callbacks) + pl_callbacks = _configure_callbacks(config.callbacks) pl_logger = _configure_logger(network, config, log_dir) # Load ligtning model. @@ -136,12 +149,17 @@ def run(config: DictConfig) -> None: trainer = pl.Trainer( **config.trainer.args, - callbacks=callbacks, + callbacks=pl_callbacks, logger=pl_logger, weights_save_path=str(log_dir), ) - if config.trainer.tune and not config.trainer.args.fast_dev_run: + if config.trainer.args.fast_dev_run: + logger.info("Fast development run...") + trainer.fit(lit_model, datamodule=data_module) + return None + + if config.trainer.tune: logger.info("Tuning learning rate and batch size...") trainer.tune(lit_model, datamodule=data_module) @@ -149,17 +167,17 @@ def run(config: DictConfig) -> None: logger.info("Training network...") trainer.fit(lit_model, datamodule=data_module) - if config.trainer.test and not config.trainer.args.fast_dev_run: + if config.trainer.test: logger.info("Testing network...") trainer.test(lit_model, datamodule=data_module) - if not config.trainer.args.fast_dev_run: - _save_best_weights(callbacks, config.trainer.wandb) + _save_best_weights(pl_callbacks, config.trainer.wandb) @hydra.main(config_path="conf", config_name="config") def main(config: DictConfig) -> None: """Loads config with hydra.""" + print(OmegaConf.to_yaml(config)) run(config) -- cgit v1.2.3-70-g09d2