summaryrefslogtreecommitdiff
path: root/src/training/run_experiment.py
blob: 8296e5902a08f55b30b00ea34ee2f24f2acee902 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
"""Script to run experiments."""
import importlib
import os
from typing import Dict

import click
import torch
from training.train import Trainer


def run_experiment(
    experiment_config: Dict, save_weights: bool, gpu_index: int, use_wandb: bool = False
) -> None:
    """Build a model from an experiment configuration and train it.

    Components are resolved by name:

    * data loader from ``text_recognizer.datasets``;
    * model class and metric from ``text_recognizer.models``;
    * network from ``text_recognizer.networks``;
    * criterion from ``torch.nn``, optimizer from ``torch.optim``;
    * optional learning-rate scheduler from ``torch.optim.lr_scheduler``.

    They are passed to the model class, and the model is trained with
    ``Trainer(...).fit()``.

    Args:
        experiment_config: Experiment description.
            Required keys: ``"dataloader"``, ``"model"``, ``"metric"``,
            ``"network"``, ``"criterion"``, ``"optimizer"``, ``"epochs"``.
            Optional keys: ``"data_loader_args"``, ``"network_args"``,
            ``"criterion_args"``, ``"optimizer_args"``, ``"lr_scheduler"``,
            ``"lr_scheduler_args"``. Missing or ``None`` ``*_args`` entries are
            treated as empty dicts. A missing or ``None`` ``"lr_scheduler"``
            means no scheduler is used.
        save_weights: Not used yet (checkpointing is still a TODO).
        gpu_index: Not used yet (device selection is still a TODO).
        use_wandb: Not used yet (wandb integration is still a TODO).

    Raises:
        KeyError: If a required key is missing from ``experiment_config``.
        AttributeError: If a configured name does not exist in the module it
            is looked up in.
    """

    def _optional_args(key: str) -> Dict:
        # ``.get(key, {})`` still returns None when the key is present but has
        # no value (e.g. ``network_args:`` left empty in a YAML file), so fall
        # back to an empty dict in that case as well.
        return experiment_config.get(key) or {}

    # Data loader class and its arguments.
    datasets_module = importlib.import_module("text_recognizer.datasets")
    data_loader_ = getattr(datasets_module, experiment_config["dataloader"])
    data_loader_args = _optional_args("data_loader_args")

    # Model class; the metric is looked up in the same models module.
    models_module = importlib.import_module("text_recognizer.models")
    model_class_ = getattr(models_module, experiment_config["model"])
    metric_fn_ = getattr(models_module, experiment_config["metric"])

    # Network and its arguments.
    network_module = importlib.import_module("text_recognizer.networks")
    network_fn_ = getattr(network_module, experiment_config["network"])
    network_args = _optional_args("network_args")

    # Criterion, resolved by class name from torch.nn.
    criterion_ = getattr(torch.nn, experiment_config["criterion"])
    criterion_args = _optional_args("criterion_args")

    # Optimizer, resolved by class name from torch.optim.
    optimizer_ = getattr(torch.optim, experiment_config["optimizer"])
    optimizer_args = _optional_args("optimizer_args")

    # Learning-rate scheduler is optional. Use .get so that omitting the key
    # behaves the same as setting it to None, instead of raising KeyError.
    lr_scheduler_ = None
    lr_scheduler_args = None
    lr_scheduler_name = experiment_config.get("lr_scheduler")
    if lr_scheduler_name is not None:
        lr_scheduler_ = getattr(torch.optim.lr_scheduler, lr_scheduler_name)
        lr_scheduler_args = _optional_args("lr_scheduler_args")

    # Device
    # TODO fix gpu manager (gpu_index is not used yet, so device stays None).
    device = None

    model = model_class_(
        network_fn=network_fn_,
        network_args=network_args,
        data_loader=data_loader_,
        data_loader_args=data_loader_args,
        metrics=metric_fn_,
        criterion=criterion_,
        criterion_args=criterion_args,
        optimizer=optimizer_,
        optimizer_args=optimizer_args,
        lr_scheduler=lr_scheduler_,
        lr_scheduler_args=lr_scheduler_args,
        device=device,
    )

    # TODO: Fix checkpoint path and wandb (save_weights / use_wandb unused yet).
    trainer = Trainer(
        model=model,
        epochs=experiment_config["epochs"],
        val_metric=experiment_config["metric"],
    )

    trainer.fit()