author    | aktersnurra <gustaf.rydholm@gmail.com> | 2020-09-08 23:14:23 +0200
committer | aktersnurra <gustaf.rydholm@gmail.com> | 2020-09-08 23:14:23 +0200
commit    | e1b504bca41a9793ed7e88ef14f2e2cbd85724f2 (patch)
tree      | 70b482f890c9ad2be104f0bff8f2172e8411a2be /src/training/run_experiment.py
parent    | fe23001b6588e6e6e9e2c5a99b72f3445cf5206f (diff)
IAM datasets implemented.
Diffstat (limited to 'src/training/run_experiment.py')
-rw-r--r-- | src/training/run_experiment.py | 238
1 file changed, 143 insertions(+), 95 deletions(-)
diff --git a/src/training/run_experiment.py b/src/training/run_experiment.py
index 8c063ff..4317d66 100644
--- a/src/training/run_experiment.py
+++ b/src/training/run_experiment.py
@@ -6,18 +6,19 @@
 import json
 import os
 from pathlib import Path
 import re
-from typing import Callable, Dict, Tuple, Type
+from typing import Callable, Dict, List, Tuple, Type
 
 import click
 from loguru import logger
 import torch
 from tqdm import tqdm
 from training.gpu_manager import GPUManager
-from training.trainer.callbacks import CallbackList
+from training.trainer.callbacks import Callback, CallbackList
 from training.trainer.train import Trainer
 import wandb
 import yaml
+
 from text_recognizer.models import Model
@@ -37,10 +38,14 @@ def get_level(experiment_config: Dict) -> int:
     return 10
 
 
-def create_experiment_dir(model: Type[Model], experiment_config: Dict) -> Path:
+def create_experiment_dir(experiment_config: Dict) -> Path:
     """Create new experiment."""
     EXPERIMENTS_DIRNAME.mkdir(parents=True, exist_ok=True)
-    experiment_dir = EXPERIMENTS_DIRNAME / model.__name__
+    experiment_dir = EXPERIMENTS_DIRNAME / (
+        f"{experiment_config['model']}_"
+        + f"{experiment_config['dataset']['type']}_"
+        + f"{experiment_config['network']['type']}"
+    )
     if experiment_config["resume_experiment"] is None:
         experiment = datetime.now().strftime("%m%d_%H%M%S")
         logger.debug(f"Creating a new experiment called {experiment}")
@@ -54,70 +59,89 @@ def create_experiment_dir(model: Type[Model], experiment_config: Dict) -> Path:
         experiment = experiment_config["resume_experiment"]
         if not str(experiment_dir / experiment) in available_experiments:
             raise FileNotFoundError("Experiment does not exist.")
-
         logger.debug(f"Resuming the experiment {experiment}")
         experiment_dir = experiment_dir / experiment
-    return experiment_dir
+
+    # Create log and model directories.
+    log_dir = experiment_dir / "log"
+    model_dir = experiment_dir / "model"
+
+    return experiment_dir, log_dir, model_dir
 
-def check_args(args: Dict) -> Dict:
+
+def check_args(args: Dict, train_args: Dict) -> Dict:
     """Checks that the arguments are not None."""
+    args = args or {}
+
+    # I just want to set total epochs in train args, instead of changing all parameter.
+    if "epochs" in args and args["epochs"] is None:
+        args["epochs"] = train_args["max_epochs"]
+
+    # For CosineAnnealingLR.
+    if "T_max" in args and args["T_max"] is None:
+        args["T_max"] = train_args["max_epochs"]
+
     return args or {}
 
 
 def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict]:
     """Loads all modules and arguments."""
     # Import the data loader arguments.
-    data_loader_args = experiment_config.get("data_loader_args", {})
     train_args = experiment_config.get("train_args", {})
-    data_loader_args["batch_size"] = train_args["batch_size"]
-    data_loader_args["dataset"] = experiment_config["dataset"]
-    data_loader_args["dataset_args"] = experiment_config.get("dataset_args", {})
+
+    # Load the dataset module.
+    dataset_args = experiment_config.get("dataset", {})
+    dataset_args["train_args"]["batch_size"] = train_args["batch_size"]
+    datasets_module = importlib.import_module("text_recognizer.datasets")
+    dataset_ = getattr(datasets_module, dataset_args["type"])
 
     # Import the model module and model arguments.
     models_module = importlib.import_module("text_recognizer.models")
     model_class_ = getattr(models_module, experiment_config["model"])
 
     # Import metrics.
-    metric_fns_ = {
-        metric: getattr(models_module, metric)
-        for metric in experiment_config["metrics"]
-    }
+    metric_fns_ = (
+        {
+            metric: getattr(models_module, metric)
+            for metric in experiment_config["metrics"]
+        }
+        if experiment_config["metrics"] is not None
+        else None
+    )
 
     # Import network module and arguments.
     network_module = importlib.import_module("text_recognizer.networks")
-    network_fn_ = getattr(network_module, experiment_config["network"])
-    network_args = experiment_config.get("network_args", {})
+    network_fn_ = getattr(network_module, experiment_config["network"]["type"])
+    network_args = experiment_config["network"].get("args", {})
 
     # Criterion
-    criterion_ = getattr(torch.nn, experiment_config["criterion"])
-    criterion_args = experiment_config.get("criterion_args", {})
+    criterion_ = getattr(torch.nn, experiment_config["criterion"]["type"])
+    criterion_args = experiment_config["criterion"].get("args", {})
 
-    # Optimizer
-    optimizer_ = getattr(torch.optim, experiment_config["optimizer"])
-    optimizer_args = experiment_config.get("optimizer_args", {})
-
-    # Callbacks
-    callback_modules = importlib.import_module("training.trainer.callbacks")
-    callbacks = [
-        getattr(callback_modules, callback)(
-            **check_args(experiment_config["callback_args"][callback])
-        )
-        for callback in experiment_config["callbacks"]
-    ]
+    # Optimizers
+    optimizer_ = getattr(torch.optim, experiment_config["optimizer"]["type"])
+    optimizer_args = experiment_config["optimizer"].get("args", {})
 
     # Learning rate scheduler
+    lr_scheduler_ = None
+    lr_scheduler_args = None
     if experiment_config["lr_scheduler"] is not None:
         lr_scheduler_ = getattr(
-            torch.optim.lr_scheduler, experiment_config["lr_scheduler"]
+            torch.optim.lr_scheduler, experiment_config["lr_scheduler"]["type"]
+        )
+        lr_scheduler_args = check_args(
+            experiment_config["lr_scheduler"].get("args", {}), train_args
         )
-        lr_scheduler_args = experiment_config.get("lr_scheduler_args", {})
+
+    # SWA scheduler.
+    if "swa_args" in experiment_config:
+        swa_args = check_args(experiment_config.get("swa_args", {}), train_args)
     else:
-        lr_scheduler_ = None
-        lr_scheduler_args = None
+        swa_args = None
 
     model_args = {
-        "data_loader_args": data_loader_args,
+        "dataset": dataset_,
+        "dataset_args": dataset_args,
        "metrics": metric_fns_,
        "network_fn": network_fn_,
        "network_args": network_args,
@@ -127,43 +151,33 @@ def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict]
         "optimizer_args": optimizer_args,
         "lr_scheduler": lr_scheduler_,
         "lr_scheduler_args": lr_scheduler_args,
+        "swa_args": swa_args,
     }
 
-    return model_class_, model_args, callbacks
-
+    return model_class_, model_args
 
-def run_experiment(
-    experiment_config: Dict, save_weights: bool, device: str, use_wandb: bool = False
-) -> None:
-    """Runs an experiment."""
-
-    # Load the modules and model arguments.
-    model_class_, model_args, callbacks = load_modules_and_arguments(experiment_config)
-
-    # Initializes the model with experiment config.
-    model = model_class_(**model_args, device=device)
 
-    # Instantiate a CallbackList.
-    callbacks = CallbackList(model, callbacks)
-
-    # Create new experiment.
-    experiment_dir = create_experiment_dir(model, experiment_config)
+def configure_callbacks(experiment_config: Dict, model_dir: Dict) -> CallbackList:
+    """Configure a callback list for trainer."""
+    train_args = experiment_config.get("train_args", {})
 
-    # Create log and model directories.
-    log_dir = experiment_dir / "log"
-    model_dir = experiment_dir / "model"
+    if "Checkpoint" in experiment_config["callback_args"]:
+        experiment_config["callback_args"]["Checkpoint"]["checkpoint_path"] = model_dir
 
-    # Set the model dir to be able to save checkpoints.
-    model.model_dir = model_dir
+    # Callbacks
+    callback_modules = importlib.import_module("training.trainer.callbacks")
+    callbacks = [
+        getattr(callback_modules, callback)(
+            **check_args(experiment_config["callback_args"][callback], train_args)
+        )
+        for callback in experiment_config["callbacks"]
+    ]
 
-    # Get checkpoint path.
-    checkpoint_path = model_dir / "last.pt"
-    if not checkpoint_path.exists():
-        checkpoint_path = None
+    return callbacks
 
-    # Make sure the log directory exists.
-    log_dir.mkdir(parents=True, exist_ok=True)
 
+def configure_logger(experiment_config: Dict, log_dir: Path) -> None:
+    """Configure the loguru logger for output to terminal and disk."""
     # Have to remove default logger to get tqdm to work properly.
     logger.remove()
 
@@ -176,13 +190,50 @@ def run_experiment(
         format="{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}",
     )
 
-    if "cuda" in device:
-        gpu_index = re.sub("[^0-9]+", "", device)
-        logger.info(
-            f"Running experiment with config {experiment_config} on GPU {gpu_index}"
-        )
-    else:
-        logger.info(f"Running experiment with config {experiment_config} on CPU")
+
+def save_config(experiment_dir: Path, experiment_config: Dict) -> None:
+    """Copy config to experiment directory."""
+    config_path = experiment_dir / "config.yml"
+    with open(str(config_path), "w") as f:
+        yaml.dump(experiment_config, f)
+
+
+def load_from_checkpoint(model: Type[Model], log_dir: Path, model_dir: Path) -> None:
+    """If checkpoint exists, load model weights and optimizers from checkpoint."""
+    # Get checkpoint path.
+    checkpoint_path = model_dir / "last.pt"
+    if checkpoint_path.exists():
+        logger.info("Loading and resuming training from last checkpoint.")
+        model.load_checkpoint(checkpoint_path)
+
+
+def run_experiment(
+    experiment_config: Dict, save_weights: bool, device: str, use_wandb: bool = False
+) -> None:
+    """Runs an experiment."""
+    logger.info(f"Experiment config: {json.dumps(experiment_config)}")
+
+    # Create new experiment.
+    experiment_dir, log_dir, model_dir = create_experiment_dir(experiment_config)
+
+    # Make sure the log/model directory exists.
+    log_dir.mkdir(parents=True, exist_ok=True)
+    model_dir.mkdir(parents=True, exist_ok=True)
+
+    # Load the modules and model arguments.
+    model_class_, model_args = load_modules_and_arguments(experiment_config)
+
+    # Initializes the model with experiment config.
+    model = model_class_(**model_args, device=device)
+
+    callbacks = configure_callbacks(experiment_config, model_dir)
+
+    # Setup logger.
+    configure_logger(experiment_config, log_dir)
+
+    # Load from checkpoint if resuming an experiment.
+    if experiment_config["resume_experiment"] is not None:
+        load_from_checkpoint(model, log_dir, model_dir)
 
     logger.info(f"The class mapping is {model.mapping}")
 
@@ -193,9 +244,6 @@ def run_experiment(
 
     # Lets W&B save the model and track the gradients and optional parameters.
     wandb.watch(model.network)
 
-    # Pŕints a summary of the network in terminal.
-    model.summary()
-
     experiment_config["train_args"] = {
         **DEFAULT_TRAIN_ARGS,
         **experiment_config.get("train_args", {}),
@@ -208,41 +256,41 @@ def run_experiment(
     experiment_config["device"] = device
 
     # Save the config used in the experiment folder.
-    config_path = experiment_dir / "config.yml"
-    with open(str(config_path), "w") as f:
-        yaml.dump(experiment_config, f)
+    save_config(experiment_dir, experiment_config)
 
-    # Train the model.
+    # Load trainer.
     trainer = Trainer(
-        model=model,
-        model_dir=model_dir,
-        train_args=experiment_config["train_args"],
-        callbacks=callbacks,
-        checkpoint_path=checkpoint_path,
+        max_epochs=experiment_config["train_args"]["max_epochs"],
+        callbacks=callbacks,
     )
-    trainer.fit()
+    # Train the model.
+    trainer.fit(model)
 
-    logger.info("Loading checkpoint with the best weights.")
-    model.load_checkpoint(model_dir / "best.pt")
+    # Run inference over test set.
+    if experiment_config["test"]:
+        logger.info("Loading checkpoint with the best weights.")
+        model.load_from_checkpoint(model_dir / "best.pt")
 
-    score = trainer.validate()
+        logger.info("Running inference on test set.")
+        score = trainer.test(model)
 
-    logger.info(f"Validation set evaluation: {score}")
+        logger.info(f"Test set evaluation: {score}")
 
-    if use_wandb:
-        wandb.log({"validation_metric": score["val_accuracy"]})
+        if use_wandb:
+            wandb.log(
+                {
+                    experiment_config["test_metric"]: score[
+                        experiment_config["test_metric"]
+                    ]
+                }
+            )
 
     if save_weights:
         model.save_weights(model_dir)
 
 
 @click.command()
-@click.option(
-    "--experiment_config",
-    type=str,
-    help='Experiment JSON, e.g. \'{"dataloader": "EmnistDataLoader", "model": "CharacterModel", "network": "mlp"}\'',
-)
+@click.argument("experiment_config",)
 @click.option("--gpu", type=int, default=0, help="Provide the index of the GPU to use.")
 @click.option(
     "--save",
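
Two schema changes recur throughout this patch and are easier to see in one place. Component settings move from flat sibling keys (network plus network_args, criterion plus criterion_args, and so on) to nested objects carrying a "type" name and an "args" dict, and the dataset class is now resolved dynamically from text_recognizer.datasets just like the model and network. Below is a hypothetical config fragment matching the keys the new code reads; the project stores this as YAML, and every concrete value here is illustrative rather than taken from the repository:

# Hypothetical experiment config; the keys mirror the lookups in
# load_modules_and_arguments(), but every value is made up for illustration.
experiment_config = {
    "model": "CharacterModel",
    "dataset": {"type": "EmnistDataset", "train_args": {"batch_size": None}},
    "network": {"type": "MLP", "args": {"num_layers": 3}},
    "criterion": {"type": "CrossEntropyLoss", "args": {}},
    "optimizer": {"type": "AdamW", "args": {"lr": 3e-4}},
    "lr_scheduler": {"type": "CosineAnnealingLR", "args": {"T_max": None}},
    "swa_args": {"epochs": None},
    "train_args": {"batch_size": 128, "max_epochs": 64},
    "metrics": ["accuracy"],
    "callbacks": ["Checkpoint"],
    "callback_args": {"Checkpoint": {"monitor": "val_loss"}},
    "resume_experiment": None,
    "test": True,
    "test_metric": "test_accuracy",
}

The command line changes in the same spirit: the config becomes a positional click.argument instead of the old --experiment_config JSON string option.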
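
The new two-argument check_args is what makes that schema ergonomic: any "epochs" or "T_max" entry left as None (the latter is what torch's CosineAnnealingLR expects) is back-filled from train_args["max_epochs"], so the total epoch count is configured in exactly one place. A minimal, self-contained sketch of that behavior, with made-up values:

from typing import Dict


def check_args(args: Dict, train_args: Dict) -> Dict:
    """Back-fill None-valued entries from the shared training arguments."""
    args = args or {}
    # The total number of epochs is set once in train_args rather than
    # repeated for every scheduler and callback.
    if "epochs" in args and args["epochs"] is None:
        args["epochs"] = train_args["max_epochs"]
    # CosineAnnealingLR takes T_max, which here equals the epoch count.
    if "T_max" in args and args["T_max"] is None:
        args["T_max"] = train_args["max_epochs"]
    return args or {}


train_args = {"max_epochs": 64, "batch_size": 128}  # hypothetical values
print(check_args({"T_max": None}, train_args))  # {'T_max': 64}
print(check_args(None, train_args))             # {}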
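
Finally, create_experiment_dir now keys the experiment directory on the model, dataset type, and network type instead of the model class name alone, and returns the log and model directories along with it; run_experiment creates both up front. Assuming the hypothetical config above, and assuming EXPERIMENTS_DIRNAME points at an experiments/ folder (its value is not shown in this diff), the resulting layout would look roughly like:

experiments/
    CharacterModel_EmnistDataset_MLP/
        0908_231423/        <- datetime.now().strftime("%m%d_%H%M%S")
            config.yml      <- written by save_config()
            log/            <- loguru sink added by configure_logger()
            model/          <- last.pt / best.pt checkpoints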