diff options
Diffstat (limited to 'training/utils.py')
-rw-r--r-- | training/utils.py | 23 |
1 files changed, 8 insertions, 15 deletions
diff --git a/training/utils.py b/training/utils.py index f94a436..c8ea1be 100644 --- a/training/utils.py +++ b/training/utils.py @@ -1,5 +1,5 @@ """Util functions for training with hydra and pytorch lightning.""" -from typing import Any, List, Type +from typing import List, Type import warnings import hydra @@ -105,10 +105,6 @@ def extras(config: DictConfig) -> None: OmegaConf.set_struct(config, True) -def empty(*args: Any, **kwargs: Any) -> None: - pass - - @rank_zero_only def log_hyperparameters( config: DictConfig, @@ -122,26 +118,23 @@ def log_hyperparameters( hparams["trainer"] = config["trainer"] hparams["model"] = config["model"] hparams["datamodule"] = config["datamodule"] - if "callbacks" in config: - hparams["callbacks"] = config["callbacks"] # save number of model parameters - hparams["model/params_total"] = sum(p.numel() for p in model.parameters()) - hparams["model/params_trainable"] = sum( + hparams["model/params/total"] = sum(p.numel() for p in model.parameters()) + hparams["model/params/trainable"] = sum( p.numel() for p in model.parameters() if p.requires_grad ) - hparams["model/params_not_trainable"] = sum( + hparams["model/params/non_trainable"] = sum( p.numel() for p in model.parameters() if not p.requires_grad ) + hparams["callbacks"] = config.get("callbacks") + hparams["tags"] = config.get("tags") + hparams["ckpt_path"] = config.get("ckpt_path") + hparams["seed"] = config.get("seed") # send hparams to all loggers trainer.logger.log_hyperparams(hparams) - # disable logging any more hyperparameters for all loggers - # this is just a trick to prevent trainer from logging hparams of model, - # since we already did that above - trainer.logger.log_hyperparams = empty - def finish( logger: List[Type[LightningLoggerBase]], |