-rw-r--r-- | training/conf/logger/wandb.yaml | 15
-rw-r--r-- | training/main.py                | 21
-rw-r--r-- | training/run.py                 | 95
-rw-r--r-- | training/utils.py               | 56
4 files changed, 187 insertions, 0 deletions
diff --git a/training/conf/logger/wandb.yaml b/training/conf/logger/wandb.yaml
new file mode 100644
index 0000000..552cf00
--- /dev/null
+++ b/training/conf/logger/wandb.yaml
@@ -0,0 +1,15 @@
+# https://wandb.ai
+
+wandb:
+  _target_: pytorch_lightning.loggers.wandb.WandbLogger
+  project: "text-recognizer"
+  name: null
+  save_dir: "."
+  offline: False # set True to store all logs only locally
+  id: null # pass correct id to resume experiment!
+  # entity: "" # set to name of your wandb team or just remove it
+  log_model: False
+  prefix: ""
+  job_type: "train"
+  group: ""
+  tags: []
diff --git a/training/main.py b/training/main.py
new file mode 100644
index 0000000..26d5aeb
--- /dev/null
+++ b/training/main.py
@@ -0,0 +1,21 @@
+"""Loads config with hydra and runs experiment."""
+import hydra
+from omegaconf import DictConfig
+
+
+@hydra.main(config_path="conf", config_name="config")
+def main(config: DictConfig) -> None:
+    """Loads config with hydra and runs the experiment."""
+    import utils
+    from run import run
+
+    utils.extras(config)
+
+    if config.get("print_config"):
+        utils.print_config(config)
+
+    return run(config)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/training/run.py b/training/run.py
new file mode 100644
index 0000000..ed1b372
--- /dev/null
+++ b/training/run.py
@@ -0,0 +1,95 @@
+"""Script to run experiments."""
+from typing import List, Optional
+
+import hydra
+from loguru import logger as log
+from omegaconf import DictConfig
+from pytorch_lightning import (
+    Callback,
+    LightningDataModule,
+    LightningModule,
+    Trainer,
+    seed_everything,
+)
+from pytorch_lightning.loggers import LightningLoggerBase
+from torch import nn
+
+from utils import configure_logging
+
+
+def configure_callbacks(
+    config: DictConfig,
+) -> List[Callback]:
+    """Configures lightning callbacks."""
+    callbacks = []
+    if config.get("callbacks"):
+        for callback_config in config.callbacks.values():
+            if callback_config.get("_target_"):
+                log.info(f"Instantiating callback <{callback_config._target_}>")
+                callbacks.append(hydra.utils.instantiate(callback_config))
+    return callbacks
+
+
+def configure_logger(config: DictConfig) -> List[LightningLoggerBase]:
+    logger = []
+    if config.get("logger"):
+        for logger_config in config.logger.values():
+            if logger_config.get("_target_"):
+                log.info(f"Instantiating logger <{logger_config._target_}>")
+                logger.append(hydra.utils.instantiate(logger_config))
+    return logger
+
+
+def run(config: DictConfig) -> Optional[float]:
+    """Runs experiment."""
+    configure_logging(config.logging)
+    log.info("Starting experiment...")
+
+    if config.get("seed"):
+        seed_everything(config.seed)
+
+    log.info(f"Instantiating datamodule <{config.datamodule._target_}>")
+    datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule)
+
+    log.info(f"Instantiating network <{config.network._target_}>")
+    network: nn.Module = hydra.utils.instantiate(config.network, **datamodule.config())
+
+    log.info(f"Instantiating model <{config.model._target_}>")
+    model: LightningModule = hydra.utils.instantiate(
+        config.model,
+        network=network,
+        criterion=config.criterion,
+        optimizer=config.optimizer,
+        lr_scheduler=config.lr_scheduler,
+        _recursive_=False,
+    )
+
+    # Load callbacks and loggers.
+    callbacks = configure_callbacks(config)
+    logger = configure_logger(config)
+
+    log.info(f"Instantiating trainer <{config.trainer._target_}>")
+    trainer: Trainer = hydra.utils.instantiate(
+        config.trainer, callbacks=callbacks, logger=logger, _convert_="partial"
+    )
+
+    # Log hyperparameters
+
+    if config.debug:
+        log.info("Fast development run...")
+        trainer.fit(model, datamodule=datamodule)
+        return None
+
+    if config.tune:
+        log.info("Tuning learning rate and batch size...")
+        trainer.tune(model, datamodule=datamodule)
+
+    if config.train:
+        log.info("Training network...")
+        trainer.fit(model, datamodule=datamodule)
+
+    if config.test:
+        log.info("Testing network...")
+        trainer.test(model, datamodule=datamodule)
+
+    # Make sure everything closes properly
diff --git a/training/utils.py b/training/utils.py
new file mode 100644
index 0000000..7717fc5
--- /dev/null
+++ b/training/utils.py
@@ -0,0 +1,56 @@
+"""Util functions for training hydra configs and pytorch lightning."""
+import warnings
+
+from omegaconf import DictConfig, OmegaConf
+from loguru import logger as log
+from pytorch_lightning.loggers.wandb import WandbLogger
+from pytorch_lightning.utilities import rank_zero_only
+from tqdm import tqdm
+
+
+@rank_zero_only
+def configure_logging(level: str) -> None:
+    """Configure the loguru logger for output to terminal and disk."""
+    # Remove default logger to get tqdm to work properly.
+    log.remove()
+    log.add(lambda msg: tqdm.write(msg, end=""), colorize=True, level=level)
+
+
+def extras(config: DictConfig) -> None:
+    """Sets optional utilities."""
+    # Enable adding new keys.
+    OmegaConf.set_struct(config, False)
+
+    if config.get("ignore_warnings"):
+        log.info("Disabling python warnings! <config.ignore_warnings=True>")
+        warnings.filterwarnings("ignore")
+
+    if config.get("debug"):
+        log.info("Running in debug mode! <config.debug=True>")
+        config.trainer.fast_dev_run = True
+
+    if config.trainer.get("fast_dev_run"):
+        log.info(
+            "Forcing debugger friendly configuration! <config.trainer.fast_dev_run=True>"
+        )
+        # Debuggers do not like GPUs and multiprocessing.
+        if config.trainer.get("gpus"):
+            config.trainer.gpus = 0
+        if config.datamodule.get("pin_memory"):
+            config.datamodule.pin_memory = False
+        if config.datamodule.get("num_workers"):
+            config.datamodule.num_workers = 0
+
+    # Force multi-gpu friendly config.
+    accelerator = config.trainer.get("accelerator")
+    if accelerator in ["ddp", "ddp_spawn", "dp", "ddp2"]:
+        log.info(
+            f"Forcing ddp friendly configuration! <config.trainer.accelerator={accelerator}>"
+        )
+        if config.datamodule.get("pin_memory"):
+            config.datamodule.pin_memory = False
+        if config.datamodule.get("num_workers"):
+            config.datamodule.num_workers = 0
+
+    # Disable adding new keys to config
+    OmegaConf.set_struct(config, True)
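
Note on the pattern above (outside the commit itself): run.py relies on Hydra's `_target_` convention, where hydra.utils.instantiate resolves the dotted path stored under `_target_` and calls it with the remaining config keys plus any keyword arguments supplied at call time, as with network=network for the model. A minimal sketch of that mechanism, assuming Hydra 1.1+ and PyTorch are available; the torch.optim.SGD target here is only an illustrative choice, not something this commit configures:

# Sketch of the `_target_` instantiation pattern used throughout run.py.
import hydra
from omegaconf import OmegaConf
from torch import nn

# A config node naming a callable via `_target_`, as in conf/logger/wandb.yaml.
cfg = OmegaConf.create({"_target_": "torch.optim.SGD", "lr": 0.01})

# Extra keyword arguments are merged in at call time, mirroring how run.py
# passes network=network, criterion=..., etc. to the model config.
params = nn.Linear(2, 2).parameters()
optimizer = hydra.utils.instantiate(cfg, params=params)
print(type(optimizer).__name__)  # SGD

With _recursive_=False, as used for the model in run.py, nested nodes such as config.criterion are handed over as plain DictConfig objects rather than being instantiated eagerly, leaving that to the LightningModule.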
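
Similarly, utils.configure_logging replaces loguru's default stderr sink with one that writes through tqdm.write, so log messages appear above an active progress bar instead of breaking it. A small self-contained sketch of that idea, assuming loguru and tqdm are installed; the demo loop is illustrative only:

# Route loguru output through tqdm.write so progress bars stay intact.
from loguru import logger as log
from tqdm import tqdm

log.remove()  # drop the default stderr handler, as configure_logging does
log.add(lambda msg: tqdm.write(msg, end=""), colorize=True, level="INFO")

for step in tqdm(range(5), desc="demo"):
    log.info(f"step {step}")  # rendered above the bar rather than through it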