summaryrefslogtreecommitdiff
path: root/training/run_experiment.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-04-05 23:14:24 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-04-05 23:14:24 +0200
commit532286b516b17d279c321358bf03dddc8adc8029 (patch)
tree9409789d2a3c093b499253f2fd101a74be1da452 /training/run_experiment.py
parent370c3eb35f47e64eab7926a5f995947c63c6b208 (diff)
Completed first draft for training loop with PyTorch Lightning
Diffstat (limited to 'training/run_experiment.py')
-rw-r--r--training/run_experiment.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/training/run_experiment.py b/training/run_experiment.py
index ff8b886..8a29555 100644
--- a/training/run_experiment.py
+++ b/training/run_experiment.py
@@ -2,7 +2,7 @@
from datetime import datetime
import importlib
from pathlib import Path
-from typing import Dict, List, Optional, Type
+from typing import Dict, List, NamedTuple, Optional, Union, Type
import click
from loguru import logger
@@ -50,10 +50,10 @@ def _import_class(module_and_class_name: str) -> type:
return getattr(module, class_name)
-def _configure_pl_callbacks(args: List[Dict]) -> List[Type[pl.callbacks.Callback]]:
+def _configure_pl_callbacks(args: List[Union[OmegaConf, NamedTuple]]) -> List[Type[pl.callbacks.Callback]]:
"""Configures PyTorch Lightning callbacks."""
pl_callbacks = [
- getattr(pl.callbacks, callback["type"])(**callback["args"]) for callback in args
+ getattr(pl.callbacks, callback.type)(**callback.args) for callback in args
]
return pl_callbacks