diff options
author | aktersnurra <gustaf.rydholm@gmail.com> | 2020-08-20 22:18:35 +0200 |
---|---|---|
committer | aktersnurra <gustaf.rydholm@gmail.com> | 2020-08-20 22:18:35 +0200 |
commit | 1f459ba19422593de325983040e176f97cf4ffc0 (patch) | |
tree | 89fef442d5dbe0c83253e9566d1762f0704f64e2 /src/training/run_experiment.py | |
parent | 95cbdf5bc1cc9639febda23c28d8f464c998b214 (diff) |
A lot of stuff working :D. ResNet implemented!
Diffstat (limited to 'src/training/run_experiment.py')
-rw-r--r-- | src/training/run_experiment.py | 19 |
1 files changed, 12 insertions, 7 deletions
diff --git a/src/training/run_experiment.py b/src/training/run_experiment.py index d278dc2..8c063ff 100644 --- a/src/training/run_experiment.py +++ b/src/training/run_experiment.py @@ -6,18 +6,20 @@ import json import os from pathlib import Path import re -from typing import Callable, Dict, Tuple +from typing import Callable, Dict, Tuple, Type import click from loguru import logger import torch from tqdm import tqdm -from training.callbacks import CallbackList from training.gpu_manager import GPUManager -from training.train import Trainer +from training.trainer.callbacks import CallbackList +from training.trainer.train import Trainer import wandb import yaml +from text_recognizer.models import Model + EXPERIMENTS_DIRNAME = Path(__file__).parents[0].resolve() / "experiments" @@ -35,7 +37,7 @@ def get_level(experiment_config: Dict) -> int: return 10 -def create_experiment_dir(model: Callable, experiment_config: Dict) -> Path: +def create_experiment_dir(model: Type[Model], experiment_config: Dict) -> Path: """Create new experiment.""" EXPERIMENTS_DIRNAME.mkdir(parents=True, exist_ok=True) experiment_dir = EXPERIMENTS_DIRNAME / model.__name__ @@ -67,6 +69,8 @@ def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict] """Loads all modules and arguments.""" # Import the data loader arguments. data_loader_args = experiment_config.get("data_loader_args", {}) + train_args = experiment_config.get("train_args", {}) + data_loader_args["batch_size"] = train_args["batch_size"] data_loader_args["dataset"] = experiment_config["dataset"] data_loader_args["dataset_args"] = experiment_config.get("dataset_args", {}) @@ -94,7 +98,7 @@ def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict] optimizer_args = experiment_config.get("optimizer_args", {}) # Callbacks - callback_modules = importlib.import_module("training.callbacks") + callback_modules = importlib.import_module("training.trainer.callbacks") callbacks = [ getattr(callback_modules, callback)( **check_args(experiment_config["callback_args"][callback]) @@ -208,6 +212,7 @@ def run_experiment( with open(str(config_path), "w") as f: yaml.dump(experiment_config, f) + # Train the model. trainer = Trainer( model=model, model_dir=model_dir, @@ -247,7 +252,7 @@ def run_experiment( @click.option( "--nowandb", is_flag=False, help="If true, do not use wandb for this run." ) -def main(experiment_config: str, gpu: int, save: bool, nowandb: bool) -> None: +def run_cli(experiment_config: str, gpu: int, save: bool, nowandb: bool) -> None: """Run experiment.""" if gpu < 0: gpu_manager = GPUManager(True) @@ -260,4 +265,4 @@ def main(experiment_config: str, gpu: int, save: bool, nowandb: bool) -> None: if __name__ == "__main__": - main() + run_cli() |