diff options
Diffstat (limited to 'training/run.py')
-rw-r--r-- | training/run.py | 6 |
1 files changed, 5 insertions, 1 deletions
diff --git a/training/run.py b/training/run.py index 99059d6..429b1a2 100644 --- a/training/run.py +++ b/training/run.py @@ -79,7 +79,11 @@ def run(config: DictConfig) -> Optional[float]: if config.test: log.info("Testing network...") - trainer.test(model, datamodule=datamodule) + ckpt_path = trainer.checkpoint_callback.best_model_path + if ckpt_path is None: + log.error("No best checkpoint path for model found") + return + trainer.test(model, datamodule=datamodule, ckpt_path=ckpt_path) log.info(f"Best checkpoint path:\n{trainer.checkpoint_callback.best_model_path}") utils.finish(logger) |