diff options
author | aktersnurra <gustaf.rydholm@gmail.com> | 2020-08-09 23:24:02 +0200 |
---|---|---|
committer | aktersnurra <gustaf.rydholm@gmail.com> | 2020-08-09 23:24:02 +0200 |
commit | 53677be4ec14854ea4881b0d78730e0414c8dedd (patch) | |
tree | 56eaace5e9906c7d408b6a251ca100b5c8b4e991 /src/training/run_experiment.py | |
parent | 125d5da5fb845d03bda91426e172bca7f537584a (diff) |
Working bash scripts etc.
Diffstat (limited to 'src/training/run_experiment.py')
-rw-r--r-- | src/training/run_experiment.py | 22 |
1 files changed, 14 insertions, 8 deletions
diff --git a/src/training/run_experiment.py b/src/training/run_experiment.py index c133ce5..d278dc2 100644 --- a/src/training/run_experiment.py +++ b/src/training/run_experiment.py @@ -58,12 +58,17 @@ def create_experiment_dir(model: Callable, experiment_config: Dict) -> Path: return experiment_dir +def check_args(args: Dict) -> Dict: + """Checks that the arguments are not None.""" + return args or {} + + def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict]: """Loads all modules and arguments.""" - # Import the data loader module and arguments. - datasets_module = importlib.import_module("text_recognizer.datasets") - data_loader_ = getattr(datasets_module, experiment_config["dataloader"]) + # Import the data loader arguments. data_loader_args = experiment_config.get("data_loader_args", {}) + data_loader_args["dataset"] = experiment_config["dataset"] + data_loader_args["dataset_args"] = experiment_config.get("dataset_args", {}) # Import the model module and model arguments. models_module = importlib.import_module("text_recognizer.models") @@ -90,10 +95,12 @@ def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict] # Callbacks callback_modules = importlib.import_module("training.callbacks") - callbacks = [] - for callback in experiment_config["callbacks"]: - args = experiment_config["callback_args"][callback] or {} - callbacks.append(getattr(callback_modules, callback)(**args)) + callbacks = [ + getattr(callback_modules, callback)( + **check_args(experiment_config["callback_args"][callback]) + ) + for callback in experiment_config["callbacks"] + ] # Learning rate scheduler if experiment_config["lr_scheduler"] is not None: @@ -106,7 +113,6 @@ def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict] lr_scheduler_args = None model_args = { - "data_loader": data_loader_, "data_loader_args": data_loader_args, "metrics": metric_fns_, "network_fn": network_fn_, |