From 4da7a2c812221d56a430b35139ac40b23fa76f77 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Tue, 29 Jun 2021 22:54:52 +0200 Subject: Refactor of config, more granular --- training/run_experiment.py | 54 ++++++++++++++++++++++++++++++---------------- 1 file changed, 36 insertions(+), 18 deletions(-) (limited to 'training/run_experiment.py') diff --git a/training/run_experiment.py b/training/run_experiment.py index def1e77..b3c9552 100644 --- a/training/run_experiment.py +++ b/training/run_experiment.py @@ -3,6 +3,9 @@ from datetime import datetime import importlib from pathlib import Path from typing import List, Optional, Type +import warnings + +warnings.filterwarnings("ignore") import hydra from loguru import logger @@ -29,7 +32,7 @@ def _create_experiment_dir(config: DictConfig) -> Path: def _save_config(config: DictConfig, log_dir: Path) -> None: """Saves config to log directory.""" - with (log_dir / "config.yaml").open("r") as f: + with (log_dir / "config.yaml").open("w") as f: OmegaConf.save(config=config, f=f) @@ -52,12 +55,11 @@ def _import_class(module_and_class_name: str) -> type: return getattr(module, class_name) -def _configure_callbacks( - callbacks: List[DictConfig], -) -> List[Type[pl.callbacks.Callback]]: +def _configure_callbacks(callbacks: DictConfig,) -> List[Type[pl.callbacks.Callback]]: """Configures lightning callbacks.""" pl_callbacks = [ - getattr(pl.callbacks, callback.type)(**callback.args) for callback in callbacks + getattr(pl.callbacks, callback.type)(**callback.args) + for callback in callbacks.values() ] return pl_callbacks @@ -77,12 +79,12 @@ def _configure_logger( def _save_best_weights( - callbacks: List[Type[pl.callbacks.Callback]], use_wandb: bool + pl_callbacks: List[Type[pl.callbacks.Callback]], use_wandb: bool ) -> None: """Saves the best model.""" model_checkpoint_callback = next( callback - for callback in callbacks + for callback in pl_callbacks if isinstance(callback, pl.callbacks.ModelCheckpoint) ) best_model_path = model_checkpoint_callback.best_model_path @@ -97,20 +99,31 @@ def _load_lit_model( lit_model_class: type, network: Type[nn.Module], config: DictConfig ) -> Type[pl.LightningModule]: """Load lightning model.""" - if config.trainer.load_checkpoint is not None: + if config.load_checkpoint is not None: logger.info( - f"Loading network weights from checkpoint: {config.trainer.load_checkpoint}" + f"Loading network weights from checkpoint: {config.load_checkpoint}" ) return lit_model_class.load_from_checkpoint( - config.trainer.load_checkpoint, network=network, **config.model.args + config.load_checkpoint, + network=network, + optimizer=config.optimizer, + criterion=config.criterion, + lr_scheduler=config.lr_scheduler, + **config.model.args, ) - return lit_model_class(network=network, **config.model.args) + return lit_model_class( + network=network, + optimizer=config.optimizer, + criterion=config.criterion, + lr_scheduler=config.lr_scheduler, + **config.model.args, + ) def run(config: DictConfig) -> None: """Runs experiment.""" log_dir = _create_experiment_dir(config) - _configure_logging(log_dir, level=config.trainer.logging) + _configure_logging(log_dir, level=config.logging) logger.info("Starting experiment...") pl.utilities.seed.seed_everything(config.trainer.seed) @@ -125,7 +138,7 @@ def run(config: DictConfig) -> None: network = network_class(**data_module.config(), **config.network.args) # Load callback and logger. - callbacks = _configure_callbacks(config.callbacks) + pl_callbacks = _configure_callbacks(config.callbacks) pl_logger = _configure_logger(network, config, log_dir) # Load ligtning model. @@ -136,12 +149,17 @@ def run(config: DictConfig) -> None: trainer = pl.Trainer( **config.trainer.args, - callbacks=callbacks, + callbacks=pl_callbacks, logger=pl_logger, weights_save_path=str(log_dir), ) - if config.trainer.tune and not config.trainer.args.fast_dev_run: + if config.trainer.args.fast_dev_run: + logger.info("Fast development run...") + trainer.fit(lit_model, datamodule=data_module) + return None + + if config.trainer.tune: logger.info("Tuning learning rate and batch size...") trainer.tune(lit_model, datamodule=data_module) @@ -149,17 +167,17 @@ def run(config: DictConfig) -> None: logger.info("Training network...") trainer.fit(lit_model, datamodule=data_module) - if config.trainer.test and not config.trainer.args.fast_dev_run: + if config.trainer.test: logger.info("Testing network...") trainer.test(lit_model, datamodule=data_module) - if not config.trainer.args.fast_dev_run: - _save_best_weights(callbacks, config.trainer.wandb) + _save_best_weights(pl_callbacks, config.trainer.wandb) @hydra.main(config_path="conf", config_name="config") def main(config: DictConfig) -> None: """Loads config with hydra.""" + print(OmegaConf.to_yaml(config)) run(config) -- cgit v1.2.3-70-g09d2