summaryrefslogtreecommitdiff
path: root/training/utils.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-07-28 15:14:55 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-07-28 15:14:55 +0200
commitc032ffb05a7ed86f8fe5d596f94e8997c558cae8 (patch)
treebf890ffd4c815db7d510cfb281d253b5728f70c6 /training/utils.py
parent524bf4351ac295bd4ff9914bb1f32eda7f7ff855 (diff)
Reformatting with attrs, config for encoder and decoder
Diffstat (limited to 'training/utils.py')
-rw-r--r--training/utils.py12
1 files changed, 3 insertions, 9 deletions
diff --git a/training/utils.py b/training/utils.py
index 88b72b7..ef74f61 100644
--- a/training/utils.py
+++ b/training/utils.py
@@ -25,9 +25,7 @@ def configure_logging(config: DictConfig) -> None:
log.add(lambda msg: tqdm.write(msg, end=""), colorize=True, level=config.logging)
-def configure_callbacks(
- config: DictConfig,
-) -> List[Type[Callback]]:
+def configure_callbacks(config: DictConfig,) -> List[Type[Callback]]:
"""Configures Lightning callbacks."""
callbacks = []
if config.get("callbacks"):
@@ -95,9 +93,7 @@ def empty(*args: Any, **kwargs: Any) -> None:
@rank_zero_only
def log_hyperparameters(
- config: DictConfig,
- model: LightningModule,
- trainer: Trainer,
+ config: DictConfig, model: LightningModule, trainer: Trainer,
) -> None:
"""This method saves hyperparameters with the logger."""
hparams = {}
@@ -127,9 +123,7 @@ def log_hyperparameters(
trainer.logger.log_hyperparams = empty
-def finish(
- logger: List[Type[LightningLoggerBase]],
-) -> None:
+def finish(logger: List[Type[LightningLoggerBase]],) -> None:
"""Makes sure everything closed properly."""
for lg in logger:
if isinstance(lg, WandbLogger):