summaryrefslogtreecommitdiff
path: root/training/utils.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2022-09-14 00:54:28 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2022-09-14 00:54:28 +0200
commit2a1580869d4b520291d660ca662c374e5046329a (patch)
tree61a922b0dd0f997121691665687182711bb2e279 /training/utils.py
parent026514d6565cbb3f96afd7f308cc4f22d3f7e88a (diff)
Format
Diffstat (limited to 'training/utils.py')
-rw-r--r--training/utils.py12
1 files changed, 9 insertions, 3 deletions
diff --git a/training/utils.py b/training/utils.py
index 1996f0a..f94a436 100644
--- a/training/utils.py
+++ b/training/utils.py
@@ -36,7 +36,9 @@ def configure_logging(config: DictConfig) -> None:
log.add(lambda msg: tqdm.write(msg, end=""), colorize=True, level=config.logging)
-def configure_callbacks(config: DictConfig,) -> List[Type[Callback]]:
+def configure_callbacks(
+ config: DictConfig,
+) -> List[Type[Callback]]:
"""Configures Lightning callbacks."""
def load_callback(callback_config: DictConfig) -> Type[Callback]:
@@ -109,7 +111,9 @@ def empty(*args: Any, **kwargs: Any) -> None:
@rank_zero_only
def log_hyperparameters(
- config: DictConfig, model: LightningModule, trainer: Trainer,
+ config: DictConfig,
+ model: LightningModule,
+ trainer: Trainer,
) -> None:
"""This method saves hyperparameters with the logger."""
hparams = {}
@@ -139,7 +143,9 @@ def log_hyperparameters(
trainer.logger.log_hyperparams = empty
-def finish(logger: List[Type[LightningLoggerBase]],) -> None:
+def finish(
+ logger: List[Type[LightningLoggerBase]],
+) -> None:
"""Makes sure everything closed properly."""
for lg in logger:
if isinstance(lg, WandbLogger):