summaryrefslogtreecommitdiff
path: root/src/training/run_experiment.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/training/run_experiment.py')
-rw-r--r--src/training/run_experiment.py22
1 files changed, 14 insertions, 8 deletions
diff --git a/src/training/run_experiment.py b/src/training/run_experiment.py
index c133ce5..d278dc2 100644
--- a/src/training/run_experiment.py
+++ b/src/training/run_experiment.py
@@ -58,12 +58,17 @@ def create_experiment_dir(model: Callable, experiment_config: Dict) -> Path:
return experiment_dir
+def check_args(args: Dict) -> Dict:
+ """Checks that the arguments are not None."""
+ return args or {}
+
+
def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict]:
"""Loads all modules and arguments."""
- # Import the data loader module and arguments.
- datasets_module = importlib.import_module("text_recognizer.datasets")
- data_loader_ = getattr(datasets_module, experiment_config["dataloader"])
+ # Import the data loader arguments.
data_loader_args = experiment_config.get("data_loader_args", {})
+ data_loader_args["dataset"] = experiment_config["dataset"]
+ data_loader_args["dataset_args"] = experiment_config.get("dataset_args", {})
# Import the model module and model arguments.
models_module = importlib.import_module("text_recognizer.models")
@@ -90,10 +95,12 @@ def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict]
# Callbacks
callback_modules = importlib.import_module("training.callbacks")
- callbacks = []
- for callback in experiment_config["callbacks"]:
- args = experiment_config["callback_args"][callback] or {}
- callbacks.append(getattr(callback_modules, callback)(**args))
+ callbacks = [
+ getattr(callback_modules, callback)(
+ **check_args(experiment_config["callback_args"][callback])
+ )
+ for callback in experiment_config["callbacks"]
+ ]
# Learning rate scheduler
if experiment_config["lr_scheduler"] is not None:
@@ -106,7 +113,6 @@ def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict]
lr_scheduler_args = None
model_args = {
- "data_loader": data_loader_,
"data_loader_args": data_loader_args,
"metrics": metric_fns_,
"network_fn": network_fn_,