summaryrefslogtreecommitdiff
path: root/training/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'training/utils.py')
-rw-r--r--training/utils.py12
1 files changed, 3 insertions, 9 deletions
diff --git a/training/utils.py b/training/utils.py
index 88b72b7..ef74f61 100644
--- a/training/utils.py
+++ b/training/utils.py
@@ -25,9 +25,7 @@ def configure_logging(config: DictConfig) -> None:
log.add(lambda msg: tqdm.write(msg, end=""), colorize=True, level=config.logging)
-def configure_callbacks(
- config: DictConfig,
-) -> List[Type[Callback]]:
+def configure_callbacks(config: DictConfig,) -> List[Type[Callback]]:
"""Configures Lightning callbacks."""
callbacks = []
if config.get("callbacks"):
@@ -95,9 +93,7 @@ def empty(*args: Any, **kwargs: Any) -> None:
@rank_zero_only
def log_hyperparameters(
- config: DictConfig,
- model: LightningModule,
- trainer: Trainer,
+ config: DictConfig, model: LightningModule, trainer: Trainer,
) -> None:
"""This method saves hyperparameters with the logger."""
hparams = {}
@@ -127,9 +123,7 @@ def log_hyperparameters(
trainer.logger.log_hyperparams = empty
-def finish(
- logger: List[Type[LightningLoggerBase]],
-) -> None:
+def finish(logger: List[Type[LightningLoggerBase]],) -> None:
"""Makes sure everything closed properly."""
for lg in logger:
if isinstance(lg, WandbLogger):