diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-03-20 18:09:06 +0100 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-03-20 18:09:06 +0100 |
commit | 7e8e54e84c63171e748bbf09516fd517e6821ace (patch) | |
tree | 996093f75a5d488dddf7ea1f159ed343a561ef89 /training/run_experiment.py | |
parent | b0719d84138b6bbe5f04a4982dfca673aea1a368 (diff) |
Inital commit for refactoring to lightning
Diffstat (limited to 'training/run_experiment.py')
-rw-r--r-- | training/run_experiment.py | 382 |
1 files changed, 382 insertions, 0 deletions
diff --git a/training/run_experiment.py b/training/run_experiment.py new file mode 100644 index 0000000..faafea6 --- /dev/null +++ b/training/run_experiment.py @@ -0,0 +1,382 @@ +"""Script to run experiments.""" +from datetime import datetime +from glob import glob +import importlib +import json +import os +from pathlib import Path +import re +from typing import Callable, Dict, List, Optional, Tuple, Type +import warnings + +import click +from loguru import logger +import numpy as np +import torch +from torchsummary import summary +from tqdm import tqdm +from training.gpu_manager import GPUManager +from training.trainer.callbacks import CallbackList +from training.trainer.train import Trainer +import wandb +import yaml + +import text_recognizer.models +from text_recognizer.models import Model +import text_recognizer.networks +from text_recognizer.networks.loss import loss as custom_loss_module + +EXPERIMENTS_DIRNAME = Path(__file__).parents[0].resolve() / "experiments" + + +def _get_level(verbose: int) -> int: + """Sets the logger level.""" + levels = {0: 40, 1: 20, 2: 10} + verbose = verbose if verbose <= 2 else 2 + return levels[verbose] + + +def _create_experiment_dir( + experiment_config: Dict, checkpoint: Optional[str] = None +) -> Path: + """Create new experiment.""" + EXPERIMENTS_DIRNAME.mkdir(parents=True, exist_ok=True) + experiment_dir = EXPERIMENTS_DIRNAME / ( + f"{experiment_config['model']}_" + + f"{experiment_config['dataset']['type']}_" + + f"{experiment_config['network']['type']}" + ) + + if checkpoint is None: + experiment = datetime.now().strftime("%m%d_%H%M%S") + logger.debug(f"Creating a new experiment called {experiment}") + else: + available_experiments = glob(str(experiment_dir) + "/*") + available_experiments.sort() + if checkpoint == "last": + experiment = available_experiments[-1] + logger.debug(f"Resuming the latest experiment {experiment}") + else: + experiment = checkpoint + if not str(experiment_dir / experiment) in available_experiments: + raise FileNotFoundError("Experiment does not exist.") + logger.debug(f"Resuming the from experiment {checkpoint}") + + experiment_dir = experiment_dir / experiment + + # Create log and model directories. + log_dir = experiment_dir / "log" + model_dir = experiment_dir / "model" + + return experiment_dir, log_dir, model_dir + + +def _load_modules_and_arguments(experiment_config: Dict,) -> Tuple[Callable, Dict]: + """Loads all modules and arguments.""" + # Load the dataset module. + dataset_args = experiment_config.get("dataset", {}) + dataset_ = dataset_args["type"] + + # Import the model module and model arguments. + model_class_ = getattr(text_recognizer.models, experiment_config["model"]) + + # Import metrics. + metric_fns_ = ( + { + metric: getattr(text_recognizer.networks, metric) + for metric in experiment_config["metrics"] + } + if experiment_config["metrics"] is not None + else None + ) + + # Import network module and arguments. + network_fn_ = experiment_config["network"]["type"] + network_args = experiment_config["network"].get("args", {}) + + # Criterion + if experiment_config["criterion"]["type"] in custom_loss_module.__all__: + criterion_ = getattr(custom_loss_module, experiment_config["criterion"]["type"]) + else: + criterion_ = getattr(torch.nn, experiment_config["criterion"]["type"]) + criterion_args = experiment_config["criterion"].get("args", {}) or {} + + # Optimizers + optimizer_ = getattr(torch.optim, experiment_config["optimizer"]["type"]) + optimizer_args = experiment_config["optimizer"].get("args", {}) + + # Learning rate scheduler + lr_scheduler_ = None + lr_scheduler_args = None + if "lr_scheduler" in experiment_config: + lr_scheduler_ = getattr( + torch.optim.lr_scheduler, experiment_config["lr_scheduler"]["type"] + ) + lr_scheduler_args = experiment_config["lr_scheduler"].get("args", {}) or {} + + # SWA scheduler. + if "swa_args" in experiment_config: + swa_args = experiment_config.get("swa_args", {}) or {} + else: + swa_args = None + + model_args = { + "dataset": dataset_, + "dataset_args": dataset_args, + "metrics": metric_fns_, + "network_fn": network_fn_, + "network_args": network_args, + "criterion": criterion_, + "criterion_args": criterion_args, + "optimizer": optimizer_, + "optimizer_args": optimizer_args, + "lr_scheduler": lr_scheduler_, + "lr_scheduler_args": lr_scheduler_args, + "swa_args": swa_args, + } + + return model_class_, model_args + + +def _configure_callbacks(experiment_config: Dict, model_dir: Path) -> CallbackList: + """Configure a callback list for trainer.""" + if "Checkpoint" in experiment_config["callback_args"]: + experiment_config["callback_args"]["Checkpoint"]["checkpoint_path"] = str( + model_dir + ) + + # Initializes callbacks. + callback_modules = importlib.import_module("training.trainer.callbacks") + callbacks = [] + for callback in experiment_config["callbacks"]: + args = experiment_config["callback_args"][callback] or {} + callbacks.append(getattr(callback_modules, callback)(**args)) + + return callbacks + + +def _configure_logger(log_dir: Path, verbose: int = 0) -> None: + """Configure the loguru logger for output to terminal and disk.""" + # Have to remove default logger to get tqdm to work properly. + logger.remove() + + # Fetch verbosity level. + level = _get_level(verbose) + + logger.add(lambda msg: tqdm.write(msg, end=""), colorize=True, level=level) + logger.add( + str(log_dir / "train.log"), + format="{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}", + ) + + +def _save_config(experiment_dir: Path, experiment_config: Dict) -> None: + """Copy config to experiment directory.""" + config_path = experiment_dir / "config.yml" + with open(str(config_path), "w") as f: + yaml.dump(experiment_config, f) + + +def _load_from_checkpoint( + model: Type[Model], model_dir: Path, pretrained_weights: str = None, +) -> None: + """If checkpoint exists, load model weights and optimizers from checkpoint.""" + # Get checkpoint path. + if pretrained_weights is not None: + logger.info(f"Loading weights from {pretrained_weights}.") + checkpoint_path = ( + EXPERIMENTS_DIRNAME / Path(pretrained_weights) / "model" / "best.pt" + ) + else: + logger.info(f"Loading weights from {model_dir}.") + checkpoint_path = model_dir / "last.pt" + if checkpoint_path.exists(): + logger.info("Loading and resuming training from checkpoint.") + model.load_from_checkpoint(checkpoint_path) + + +def evaluate_embedding(model: Type[Model]) -> Dict: + """Evaluates the embedding space.""" + from pytorch_metric_learning import testers + from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator + + accuracy_calculator = AccuracyCalculator( + include=("mean_average_precision_at_r",), k=10 + ) + + def get_all_embeddings(model: Type[Model]) -> Tuple: + tester = testers.BaseTester() + return tester.get_all_embeddings(model.test_dataset, model.network) + + embeddings, labels = get_all_embeddings(model) + logger.info("Computing embedding accuracy") + accuracies = accuracy_calculator.get_accuracy( + embeddings, embeddings, np.squeeze(labels), np.squeeze(labels), True + ) + logger.info( + f"Test set accuracy (MAP@10) = {accuracies['mean_average_precision_at_r']}" + ) + return accuracies + + +def run_experiment( + experiment_config: Dict, + save_weights: bool, + device: str, + use_wandb: bool, + train: bool, + test: bool, + verbose: int = 0, + checkpoint: Optional[str] = None, + pretrained_weights: Optional[str] = None, +) -> None: + """Runs an experiment.""" + logger.info(f"Experiment config: {json.dumps(experiment_config)}") + + # Create new experiment. + experiment_dir, log_dir, model_dir = _create_experiment_dir( + experiment_config, checkpoint + ) + + # Make sure the log/model directory exists. + log_dir.mkdir(parents=True, exist_ok=True) + model_dir.mkdir(parents=True, exist_ok=True) + + # Load the modules and model arguments. + model_class_, model_args = _load_modules_and_arguments(experiment_config) + + # Initializes the model with experiment config. + model = model_class_(**model_args, device=device) + + callbacks = _configure_callbacks(experiment_config, model_dir) + + # Setup logger. + _configure_logger(log_dir, verbose) + + # Load from checkpoint if resuming an experiment. + resume = False + if checkpoint is not None or pretrained_weights is not None: + # resume = True + _load_from_checkpoint(model, model_dir, pretrained_weights) + + logger.info(f"The class mapping is {model.mapping}") + + # Initializes Weights & Biases + if use_wandb: + wandb.init(project="text-recognizer", config=experiment_config, resume=resume) + + # Lets W&B save the model and track the gradients and optional parameters. + wandb.watch(model.network) + + experiment_config["experiment_group"] = experiment_config.get( + "experiment_group", None + ) + + experiment_config["device"] = device + + # Save the config used in the experiment folder. + _save_config(experiment_dir, experiment_config) + + # Prints a summary of the network in terminal. + model.summary(experiment_config["train_args"]["input_shape"]) + + # Load trainer. + trainer = Trainer( + max_epochs=experiment_config["train_args"]["max_epochs"], + callbacks=callbacks, + transformer_model=experiment_config["train_args"]["transformer_model"], + max_norm=experiment_config["train_args"]["max_norm"], + freeze_backbone=experiment_config["train_args"]["freeze_backbone"], + ) + + # Train the model. + if train: + trainer.fit(model) + + # Run inference over test set. + if test: + logger.info("Loading checkpoint with the best weights.") + if "checkpoint" in experiment_config["train_args"]: + model.load_from_checkpoint( + model_dir / experiment_config["train_args"]["checkpoint"] + ) + else: + model.load_from_checkpoint(model_dir / "best.pt") + + logger.info("Running inference on test set.") + if experiment_config["criterion"]["type"] == "EmbeddingLoss": + logger.info("Evaluating embedding.") + score = evaluate_embedding(model) + else: + score = trainer.test(model) + + logger.info(f"Test set evaluation: {score}") + + if use_wandb: + wandb.log( + { + experiment_config["test_metric"]: score[ + experiment_config["test_metric"] + ] + } + ) + + if save_weights: + model.save_weights(model_dir) + + +@click.command() +@click.argument("experiment_config",) +@click.option("--gpu", type=int, default=0, help="Provide the index of the GPU to use.") +@click.option( + "--save", + is_flag=True, + help="If set, the final weights will be saved to a canonical, version-controlled location.", +) +@click.option( + "--nowandb", is_flag=False, help="If true, do not use wandb for this run." +) +@click.option("--test", is_flag=True, help="If true, test the model.") +@click.option("-v", "--verbose", count=True) +@click.option("--checkpoint", type=str, help="Path to the experiment.") +@click.option( + "--pretrained_weights", type=str, help="Path to pretrained model weights." +) +@click.option( + "--notrain", is_flag=False, help="Do not train the model.", +) +def run_cli( + experiment_config: str, + gpu: int, + save: bool, + nowandb: bool, + notrain: bool, + test: bool, + verbose: int, + checkpoint: Optional[str] = None, + pretrained_weights: Optional[str] = None, +) -> None: + """Run experiment.""" + if gpu < 0: + gpu_manager = GPUManager(True) + gpu = gpu_manager.get_free_gpu() + device = "cuda:" + str(gpu) + + experiment_config = json.loads(experiment_config) + os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu}" + + run_experiment( + experiment_config, + save, + device, + use_wandb=not nowandb, + train=not notrain, + test=test, + verbose=verbose, + checkpoint=checkpoint, + pretrained_weights=pretrained_weights, + ) + + +if __name__ == "__main__": + run_cli() |