diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-05 23:14:24 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-05 23:14:24 +0200 |
commit | 532286b516b17d279c321358bf03dddc8adc8029 (patch) | |
tree | 9409789d2a3c093b499253f2fd101a74be1da452 /training/run_experiment.py | |
parent | 370c3eb35f47e64eab7926a5f995947c63c6b208 (diff) |
Completed first draft for training loop with PyTorch Lightning
Diffstat (limited to 'training/run_experiment.py')
-rw-r--r-- | training/run_experiment.py | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/training/run_experiment.py b/training/run_experiment.py index ff8b886..8a29555 100644 --- a/training/run_experiment.py +++ b/training/run_experiment.py @@ -2,7 +2,7 @@ from datetime import datetime import importlib from pathlib import Path -from typing import Dict, List, Optional, Type +from typing import Dict, List, NamedTuple, Optional, Union, Type import click from loguru import logger @@ -50,10 +50,10 @@ def _import_class(module_and_class_name: str) -> type: return getattr(module, class_name) -def _configure_pl_callbacks(args: List[Dict]) -> List[Type[pl.callbacks.Callback]]: +def _configure_pl_callbacks(args: List[Union[OmegaConf, NamedTuple]]) -> List[Type[pl.callbacks.Callback]]: """Configures PyTorch Lightning callbacks.""" pl_callbacks = [ - getattr(pl.callbacks, callback["type"])(**callback["args"]) for callback in args + getattr(pl.callbacks, callback.type)(**callback.args) for callback in args ] return pl_callbacks |