summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2022-10-02 02:56:07 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2022-10-02 02:56:07 +0200
commit43888e1b3eaa5902496ef1e191b58d94c224c220 (patch)
tree2722ddd475ea5ce8d01dbd6adcfe7c2d7ea47532
parent2f19e8b863c54d16c1eb855bc89391063def15ce (diff)
Fix train script
-rw-r--r--training/run.py6
-rw-r--r--training/utils.py23
2 files changed, 13 insertions, 16 deletions
diff --git a/training/run.py b/training/run.py
index 99059d6..429b1a2 100644
--- a/training/run.py
+++ b/training/run.py
@@ -79,7 +79,11 @@ def run(config: DictConfig) -> Optional[float]:
if config.test:
log.info("Testing network...")
- trainer.test(model, datamodule=datamodule)
+ ckpt_path = trainer.checkpoint_callback.best_model_path
+ if ckpt_path is None:
+ log.error("No best checkpoint path for model found")
+ return
+ trainer.test(model, datamodule=datamodule, ckpt_path=ckpt_path)
log.info(f"Best checkpoint path:\n{trainer.checkpoint_callback.best_model_path}")
utils.finish(logger)
diff --git a/training/utils.py b/training/utils.py
index f94a436..c8ea1be 100644
--- a/training/utils.py
+++ b/training/utils.py
@@ -1,5 +1,5 @@
"""Util functions for training with hydra and pytorch lightning."""
-from typing import Any, List, Type
+from typing import List, Type
import warnings
import hydra
@@ -105,10 +105,6 @@ def extras(config: DictConfig) -> None:
OmegaConf.set_struct(config, True)
-def empty(*args: Any, **kwargs: Any) -> None:
- pass
-
-
@rank_zero_only
def log_hyperparameters(
config: DictConfig,
@@ -122,26 +118,23 @@ def log_hyperparameters(
hparams["trainer"] = config["trainer"]
hparams["model"] = config["model"]
hparams["datamodule"] = config["datamodule"]
- if "callbacks" in config:
- hparams["callbacks"] = config["callbacks"]
# save number of model parameters
- hparams["model/params_total"] = sum(p.numel() for p in model.parameters())
- hparams["model/params_trainable"] = sum(
+ hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
+ hparams["model/params/trainable"] = sum(
p.numel() for p in model.parameters() if p.requires_grad
)
- hparams["model/params_not_trainable"] = sum(
+ hparams["model/params/non_trainable"] = sum(
p.numel() for p in model.parameters() if not p.requires_grad
)
+ hparams["callbacks"] = config.get("callbacks")
+ hparams["tags"] = config.get("tags")
+ hparams["ckpt_path"] = config.get("ckpt_path")
+ hparams["seed"] = config.get("seed")
# send hparams to all loggers
trainer.logger.log_hyperparams(hparams)
- # disable logging any more hyperparameters for all loggers
- # this is just a trick to prevent trainer from logging hparams of model,
- # since we already did that above
- trainer.logger.log_hyperparams = empty
-
def finish(
logger: List[Type[LightningLoggerBase]],