summaryrefslogtreecommitdiff
path: root/training/utils.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2022-10-02 02:56:07 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2022-10-02 02:56:07 +0200
commit43888e1b3eaa5902496ef1e191b58d94c224c220 (patch)
tree2722ddd475ea5ce8d01dbd6adcfe7c2d7ea47532 /training/utils.py
parent2f19e8b863c54d16c1eb855bc89391063def15ce (diff)
Fix train script
Diffstat (limited to 'training/utils.py')
-rw-r--r--training/utils.py23
1 files changed, 8 insertions, 15 deletions
diff --git a/training/utils.py b/training/utils.py
index f94a436..c8ea1be 100644
--- a/training/utils.py
+++ b/training/utils.py
@@ -1,5 +1,5 @@
"""Util functions for training with hydra and pytorch lightning."""
-from typing import Any, List, Type
+from typing import List, Type
import warnings
import hydra
@@ -105,10 +105,6 @@ def extras(config: DictConfig) -> None:
OmegaConf.set_struct(config, True)
-def empty(*args: Any, **kwargs: Any) -> None:
- pass
-
-
@rank_zero_only
def log_hyperparameters(
config: DictConfig,
@@ -122,26 +118,23 @@ def log_hyperparameters(
hparams["trainer"] = config["trainer"]
hparams["model"] = config["model"]
hparams["datamodule"] = config["datamodule"]
- if "callbacks" in config:
- hparams["callbacks"] = config["callbacks"]
# save number of model parameters
- hparams["model/params_total"] = sum(p.numel() for p in model.parameters())
- hparams["model/params_trainable"] = sum(
+ hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
+ hparams["model/params/trainable"] = sum(
p.numel() for p in model.parameters() if p.requires_grad
)
- hparams["model/params_not_trainable"] = sum(
+ hparams["model/params/non_trainable"] = sum(
p.numel() for p in model.parameters() if not p.requires_grad
)
+ hparams["callbacks"] = config.get("callbacks")
+ hparams["tags"] = config.get("tags")
+ hparams["ckpt_path"] = config.get("ckpt_path")
+ hparams["seed"] = config.get("seed")
# send hparams to all loggers
trainer.logger.log_hyperparams(hparams)
- # disable logging any more hyperparameters for all loggers
- # this is just a trick to prevent trainer from logging hparams of model,
- # since we already did that above
- trainer.logger.log_hyperparams = empty
-
def finish(
logger: List[Type[LightningLoggerBase]],