diff options
author | aktersnurra <gustaf.rydholm@gmail.com> | 2020-07-05 22:27:08 +0200 |
---|---|---|
committer | aktersnurra <gustaf.rydholm@gmail.com> | 2020-07-05 22:27:08 +0200 |
commit | 5a78fc2e33c28968a69d033cb10d638f4f63fed1 (patch) | |
tree | e11cea8366c848e5500f85968ee5369ff8d96b00 /src/training/run_experiment.py | |
parent | 7c4de6d88664d2ea1b084f316a11896dde3e1150 (diff) |
Working on getting experiment loop.
Diffstat (limited to 'src/training/run_experiment.py')
-rw-r--r-- | src/training/run_experiment.py | 74 |
1 files changed, 74 insertions, 0 deletions
diff --git a/src/training/run_experiment.py b/src/training/run_experiment.py index 8033f47..8296e59 100644 --- a/src/training/run_experiment.py +++ b/src/training/run_experiment.py @@ -1 +1,75 @@ """Script to run experiments.""" +import importlib +import os +from typing import Dict + +import click +import torch +from training.train import Trainer + + +def run_experiment( + experiment_config: Dict, save_weights: bool, gpu_index: int, use_wandb: bool = False +) -> None: + """Short summary.""" + # Import the data loader module and arguments. + datasets_module = importlib.import_module("text_recognizer.datasets") + data_loader_ = getattr(datasets_module, experiment_config["dataloader"]) + data_loader_args = experiment_config.get("data_loader_args", {}) + + # Import the model module and model arguments. + models_module = importlib.import_module("text_recognizer.models") + model_class_ = getattr(models_module, experiment_config["model"]) + + # Import metric. + metric_fn_ = getattr(models_module, experiment_config["metric"]) + + # Import network module and arguments. + network_module = importlib.import_module("text_recognizer.networks") + network_fn_ = getattr(network_module, experiment_config["network"]) + network_args = experiment_config.get("network_args", {}) + + # Criterion + criterion_ = getattr(torch.nn, experiment_config["criterion"]) + criterion_args = experiment_config.get("criterion_args", {}) + + # Optimizer + optimizer_ = getattr(torch.optim, experiment_config["optimizer"]) + optimizer_args = experiment_config.get("optimizer_args", {}) + + # Learning rate scheduler + lr_scheduler_ = None + lr_scheduler_args = None + if experiment_config["lr_scheduler"] is not None: + lr_scheduler_ = getattr( + torch.optim.lr_scheduler, experiment_config["lr_scheduler"] + ) + lr_scheduler_args = experiment_config.get("lr_scheduler_args", {}) + + # Device + # TODO fix gpu manager + device = None + + model = model_class_( + network_fn=network_fn_, + network_args=network_args, + data_loader=data_loader_, + data_loader_args=data_loader_args, + metrics=metric_fn_, + criterion=criterion_, + criterion_args=criterion_args, + optimizer=optimizer_, + optimizer_args=optimizer_args, + lr_scheduler=lr_scheduler_, + lr_scheduler_args=lr_scheduler_args, + device=device, + ) + + # TODO: Fix checkpoint path and wandb + trainer = Trainer( + model=model, + epochs=experiment_config["epochs"], + val_metric=experiment_config["metric"], + ) + + trainer.fit() |