From 437ba4e22b2dad2ca309085a2e97f33dd03eb642 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Mon, 5 Jul 2021 00:21:43 +0200 Subject: Working on new hydra implementation --- training/run.py | 95 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 95 insertions(+) create mode 100644 training/run.py (limited to 'training/run.py') diff --git a/training/run.py b/training/run.py new file mode 100644 index 0000000..ed1b372 --- /dev/null +++ b/training/run.py @@ -0,0 +1,95 @@ +"""Script to run experiments.""" +from typing import List, Optional, Type + +import hydra +import loguru.logger as log +from omegaconf import DictConfig +from pytorch_lightning import ( + Callback, + LightningDataModule, + LightningModule, + Trainer, + seed_everything, +) +from pytorch_lightning.loggers import LightningLoggerBase +from torch import nn + +from utils import configure_logging + + +def configure_callbacks( + config: DictConfig, +) -> List[Type[Callback]]: + """Configures lightning callbacks.""" + callbacks = [] + if config.get("callbacks"): + for callback_config in config.callbacks.values(): + if config.get("_target_"): + log.info(f"Instantiating callback <{callback_config._target_}>") + callbacks.append(hydra.utils.instantiate(callback_config)) + return callbacks + + +def configure_logger(config: DictConfig) -> List[Type[LightningLoggerBase]]: + logger = [] + if config.get("logger"): + for logger_config in config.logger.values(): + if config.get("_target_"): + log.info(f"Instantiating callback <{logger_config._target_}>") + logger.append(hydra.utils.instantiate(logger_config)) + return logger + + +def run(config: DictConfig) -> Optional[float]: + """Runs experiment.""" + configure_logging(config.logging) + log.info("Starting experiment...") + + if config.get("seed"): + seed_everything(config.seed) + + log.info(f"Instantiating datamodule <{config.datamodule._target_}>") + datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule) + + log.info(f"Instantiating network <{config.network._target_}>") + network: nn.Module = hydra.utils.instantiate(config.network, **datamodule.config()) + + log.info(f"Instantiating model <{config.model._target_}>") + model: LightningModule = hydra.utils.instantiate( + config.model, + network=network, + criterion=config.criterion, + optimizer=config.optimizer, + lr_scheduler=config.lr_scheduler, + _recursive_=False, + ) + + # Load callback and logger. + callbacks = configure_callbacks(config) + logger = configure_logger(config) + + log.info(f"Instantiating trainer <{config.trainer._target_}>") + trainer: Trainer = hydra.utils.instantiate( + config.trainer, callbacks=callbacks, logger=logger, _convert_="partial" + ) + + # Log hyperparameters + + if config.debug: + log.info("Fast development run...") + trainer.fit(model, datamodule=datamodule) + return None + + if config.tune: + log.info("Tuning learning rate and batch size...") + trainer.tune(model, datamodule=datamodule) + + if config.train: + log.info("Training network...") + trainer.fit(model, datamodule=datamodule) + + if config.test: + log.info("Testing network...") + trainer.test(model, datamodule=datamodule) + + # Make sure everything closes properly -- cgit v1.2.3-70-g09d2