summaryrefslogtreecommitdiff
path: root/training/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'training/utils.py')
-rw-r--r--training/utils.py12
1 files changed, 9 insertions, 3 deletions
diff --git a/training/utils.py b/training/utils.py
index 1996f0a..f94a436 100644
--- a/training/utils.py
+++ b/training/utils.py
@@ -36,7 +36,9 @@ def configure_logging(config: DictConfig) -> None:
log.add(lambda msg: tqdm.write(msg, end=""), colorize=True, level=config.logging)
-def configure_callbacks(config: DictConfig,) -> List[Type[Callback]]:
+def configure_callbacks(
+ config: DictConfig,
+) -> List[Type[Callback]]:
"""Configures Lightning callbacks."""
def load_callback(callback_config: DictConfig) -> Type[Callback]:
@@ -109,7 +111,9 @@ def empty(*args: Any, **kwargs: Any) -> None:
@rank_zero_only
def log_hyperparameters(
- config: DictConfig, model: LightningModule, trainer: Trainer,
+ config: DictConfig,
+ model: LightningModule,
+ trainer: Trainer,
) -> None:
"""This method saves hyperparameters with the logger."""
hparams = {}
@@ -139,7 +143,9 @@ def log_hyperparameters(
trainer.logger.log_hyperparams = empty
-def finish(logger: List[Type[LightningLoggerBase]],) -> None:
+def finish(
+ logger: List[Type[LightningLoggerBase]],
+) -> None:
"""Makes sure everything closed properly."""
for lg in logger:
if isinstance(lg, WandbLogger):