From 43888e1b3eaa5902496ef1e191b58d94c224c220 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm
Date: Sun, 2 Oct 2022 02:56:07 +0200
Subject: Fix train script

---
 training/run.py   |  6 +++++-
 training/utils.py | 23 ++++++++---------------
 2 files changed, 13 insertions(+), 16 deletions(-)

(limited to 'training')

diff --git a/training/run.py b/training/run.py
index 99059d6..429b1a2 100644
--- a/training/run.py
+++ b/training/run.py
@@ -79,7 +79,11 @@ def run(config: DictConfig) -> Optional[float]:
 
     if config.test:
         log.info("Testing network...")
-        trainer.test(model, datamodule=datamodule)
+        ckpt_path = trainer.checkpoint_callback.best_model_path
+        if ckpt_path is None:
+            log.error("No best checkpoint path for model found")
+            return
+        trainer.test(model, datamodule=datamodule, ckpt_path=ckpt_path)
 
     log.info(f"Best checkpoint path:\n{trainer.checkpoint_callback.best_model_path}")
     utils.finish(logger)
diff --git a/training/utils.py b/training/utils.py
index f94a436..c8ea1be 100644
--- a/training/utils.py
+++ b/training/utils.py
@@ -1,5 +1,5 @@
 """Util functions for training with hydra and pytorch lightning."""
-from typing import Any, List, Type
+from typing import List, Type
 import warnings
 
 import hydra
@@ -105,10 +105,6 @@ def extras(config: DictConfig) -> None:
     OmegaConf.set_struct(config, True)
 
 
-def empty(*args: Any, **kwargs: Any) -> None:
-    pass
-
-
 @rank_zero_only
 def log_hyperparameters(
     config: DictConfig,
@@ -122,26 +118,23 @@ def log_hyperparameters(
     hparams["trainer"] = config["trainer"]
     hparams["model"] = config["model"]
     hparams["datamodule"] = config["datamodule"]
-    if "callbacks" in config:
-        hparams["callbacks"] = config["callbacks"]
 
     # save number of model parameters
-    hparams["model/params_total"] = sum(p.numel() for p in model.parameters())
-    hparams["model/params_trainable"] = sum(
+    hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
+    hparams["model/params/trainable"] = sum(
         p.numel() for p in model.parameters() if p.requires_grad
     )
-    hparams["model/params_not_trainable"] = sum(
+    hparams["model/params/non_trainable"] = sum(
        p.numel() for p in model.parameters() if not p.requires_grad
     )
+    hparams["callbacks"] = config.get("callbacks")
+    hparams["tags"] = config.get("tags")
+    hparams["ckpt_path"] = config.get("ckpt_path")
+    hparams["seed"] = config.get("seed")
 
     # send hparams to all loggers
     trainer.logger.log_hyperparams(hparams)
 
-    # disable logging any more hyperparameters for all loggers
-    # this is just a trick to prevent trainer from logging hparams of model,
-    # since we already did that above
-    trainer.logger.log_hyperparams = empty
-
 
 def finish(
     logger: List[Type[LightningLoggerBase]],
-- 
cgit v1.2.3-70-g09d2
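
A note on the run.py hunk above: passing ckpt_path to trainer.test() tells PyTorch Lightning to restore the weights saved by the ModelCheckpoint callback before testing, rather than testing whatever weights remain in memory after the last epoch. The sketch below is illustrative only and is not part of the commit; it guards with a falsiness check because, in the Lightning releases I am aware of, ModelCheckpoint.best_model_path defaults to an empty string (not None) before any checkpoint has been written.

# Illustrative sketch (not part of the commit) of the checkpoint-aware test step.
from typing import Dict, List, Optional

import pytorch_lightning as pl


def test_best_checkpoint(
    trainer: pl.Trainer,
    model: pl.LightningModule,
    datamodule: pl.LightningDataModule,
) -> Optional[List[Dict[str, float]]]:
    """Run trainer.test() against the best checkpoint tracked during training."""
    # trainer.checkpoint_callback is the first ModelCheckpoint attached to the
    # Trainer; its best_model_path is "" until at least one checkpoint is saved.
    ckpt_path = trainer.checkpoint_callback.best_model_path
    if not ckpt_path:
        return None
    # ckpt_path makes Lightning load the saved weights before running test_step.
    return trainer.test(model, datamodule=datamodule, ckpt_path=ckpt_path)

The design motivation is the same as in the commit: the final epoch is not necessarily the one that scored best on the monitored validation metric, so testing should load the best tracked checkpoint rather than the last in-memory weights.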