summaryrefslogtreecommitdiff
path: root/src/training/run_experiment.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/training/run_experiment.py')
-rw-r--r--src/training/run_experiment.py19
1 files changed, 12 insertions, 7 deletions
diff --git a/src/training/run_experiment.py b/src/training/run_experiment.py
index d278dc2..8c063ff 100644
--- a/src/training/run_experiment.py
+++ b/src/training/run_experiment.py
@@ -6,18 +6,20 @@ import json
import os
from pathlib import Path
import re
-from typing import Callable, Dict, Tuple
+from typing import Callable, Dict, Tuple, Type
import click
from loguru import logger
import torch
from tqdm import tqdm
-from training.callbacks import CallbackList
from training.gpu_manager import GPUManager
-from training.train import Trainer
+from training.trainer.callbacks import CallbackList
+from training.trainer.train import Trainer
import wandb
import yaml
+from text_recognizer.models import Model
+
EXPERIMENTS_DIRNAME = Path(__file__).parents[0].resolve() / "experiments"
@@ -35,7 +37,7 @@ def get_level(experiment_config: Dict) -> int:
return 10
-def create_experiment_dir(model: Callable, experiment_config: Dict) -> Path:
+def create_experiment_dir(model: Type[Model], experiment_config: Dict) -> Path:
"""Create new experiment."""
EXPERIMENTS_DIRNAME.mkdir(parents=True, exist_ok=True)
experiment_dir = EXPERIMENTS_DIRNAME / model.__name__
@@ -67,6 +69,8 @@ def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict]
"""Loads all modules and arguments."""
# Import the data loader arguments.
data_loader_args = experiment_config.get("data_loader_args", {})
+ train_args = experiment_config.get("train_args", {})
+ data_loader_args["batch_size"] = train_args["batch_size"]
data_loader_args["dataset"] = experiment_config["dataset"]
data_loader_args["dataset_args"] = experiment_config.get("dataset_args", {})
@@ -94,7 +98,7 @@ def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict]
optimizer_args = experiment_config.get("optimizer_args", {})
# Callbacks
- callback_modules = importlib.import_module("training.callbacks")
+ callback_modules = importlib.import_module("training.trainer.callbacks")
callbacks = [
getattr(callback_modules, callback)(
**check_args(experiment_config["callback_args"][callback])
@@ -208,6 +212,7 @@ def run_experiment(
with open(str(config_path), "w") as f:
yaml.dump(experiment_config, f)
+ # Train the model.
trainer = Trainer(
model=model,
model_dir=model_dir,
@@ -247,7 +252,7 @@ def run_experiment(
@click.option(
"--nowandb", is_flag=False, help="If true, do not use wandb for this run."
)
-def main(experiment_config: str, gpu: int, save: bool, nowandb: bool) -> None:
+def run_cli(experiment_config: str, gpu: int, save: bool, nowandb: bool) -> None:
"""Run experiment."""
if gpu < 0:
gpu_manager = GPUManager(True)
@@ -260,4 +265,4 @@ def main(experiment_config: str, gpu: int, save: bool, nowandb: bool) -> None:
if __name__ == "__main__":
- main()
+ run_cli()