summaryrefslogtreecommitdiff
path: root/src/training/run_experiment.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/training/run_experiment.py')
-rw-r--r--src/training/run_experiment.py74
1 files changed, 74 insertions, 0 deletions
diff --git a/src/training/run_experiment.py b/src/training/run_experiment.py
index 8033f47..8296e59 100644
--- a/src/training/run_experiment.py
+++ b/src/training/run_experiment.py
@@ -1 +1,75 @@
"""Script to run experiments."""
+import importlib
+import os
+from typing import Dict
+
+import click
+import torch
+from training.train import Trainer
+
+
+def run_experiment(
+ experiment_config: Dict, save_weights: bool, gpu_index: int, use_wandb: bool = False
+) -> None:
+ """Short summary."""
+ # Import the data loader module and arguments.
+ datasets_module = importlib.import_module("text_recognizer.datasets")
+ data_loader_ = getattr(datasets_module, experiment_config["dataloader"])
+ data_loader_args = experiment_config.get("data_loader_args", {})
+
+ # Import the model module and model arguments.
+ models_module = importlib.import_module("text_recognizer.models")
+ model_class_ = getattr(models_module, experiment_config["model"])
+
+ # Import metric.
+ metric_fn_ = getattr(models_module, experiment_config["metric"])
+
+ # Import network module and arguments.
+ network_module = importlib.import_module("text_recognizer.networks")
+ network_fn_ = getattr(network_module, experiment_config["network"])
+ network_args = experiment_config.get("network_args", {})
+
+ # Criterion
+ criterion_ = getattr(torch.nn, experiment_config["criterion"])
+ criterion_args = experiment_config.get("criterion_args", {})
+
+ # Optimizer
+ optimizer_ = getattr(torch.optim, experiment_config["optimizer"])
+ optimizer_args = experiment_config.get("optimizer_args", {})
+
+ # Learning rate scheduler
+ lr_scheduler_ = None
+ lr_scheduler_args = None
+ if experiment_config["lr_scheduler"] is not None:
+ lr_scheduler_ = getattr(
+ torch.optim.lr_scheduler, experiment_config["lr_scheduler"]
+ )
+ lr_scheduler_args = experiment_config.get("lr_scheduler_args", {})
+
+ # Device
+ # TODO fix gpu manager
+ device = None
+
+ model = model_class_(
+ network_fn=network_fn_,
+ network_args=network_args,
+ data_loader=data_loader_,
+ data_loader_args=data_loader_args,
+ metrics=metric_fn_,
+ criterion=criterion_,
+ criterion_args=criterion_args,
+ optimizer=optimizer_,
+ optimizer_args=optimizer_args,
+ lr_scheduler=lr_scheduler_,
+ lr_scheduler_args=lr_scheduler_args,
+ device=device,
+ )
+
+ # TODO: Fix checkpoint path and wandb
+ trainer = Trainer(
+ model=model,
+ epochs=experiment_config["epochs"],
+ val_metric=experiment_config["metric"],
+ )
+
+ trainer.fit()