summaryrefslogtreecommitdiff
path: root/training/run.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2022-10-02 02:56:07 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2022-10-02 02:56:07 +0200
commit43888e1b3eaa5902496ef1e191b58d94c224c220 (patch)
tree2722ddd475ea5ce8d01dbd6adcfe7c2d7ea47532 /training/run.py
parent2f19e8b863c54d16c1eb855bc89391063def15ce (diff)
Fix train script
Diffstat (limited to 'training/run.py')
-rw-r--r--training/run.py6
1 files changed, 5 insertions, 1 deletions
diff --git a/training/run.py b/training/run.py
index 99059d6..429b1a2 100644
--- a/training/run.py
+++ b/training/run.py
@@ -79,7 +79,11 @@ def run(config: DictConfig) -> Optional[float]:
if config.test:
log.info("Testing network...")
- trainer.test(model, datamodule=datamodule)
+ ckpt_path = trainer.checkpoint_callback.best_model_path
+ if ckpt_path is None:
+ log.error("No best checkpoint path for model found")
+ return
+ trainer.test(model, datamodule=datamodule, ckpt_path=ckpt_path)
log.info(f"Best checkpoint path:\n{trainer.checkpoint_callback.best_model_path}")
utils.finish(logger)