diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-05 20:47:55 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-05 20:47:55 +0200 |
commit | 9ae5fa1a88899180f88ddb14d4cef457ceb847e5 (patch) | |
tree | 4fe2bcd82553c8062eb0908ae6442c123addf55d /training/run_experiment.py | |
parent | 9e54591b7e342edc93b0bb04809a0f54045c6a15 (diff) |
Add new training loop with PyTorch Lightning, remove stale files
Diffstat (limited to 'training/run_experiment.py')
-rw-r--r-- | training/run_experiment.py | 419 |
1 files changed, 105 insertions, 314 deletions
diff --git a/training/run_experiment.py b/training/run_experiment.py index faafea6..ff8b886 100644 --- a/training/run_experiment.py +++ b/training/run_experiment.py @@ -1,162 +1,34 @@ """Script to run experiments.""" from datetime import datetime -from glob import glob import importlib -import json -import os from pathlib import Path -import re -from typing import Callable, Dict, List, Optional, Tuple, Type -import warnings +from typing import Dict, List, Optional, Type import click from loguru import logger import numpy as np +from omegaconf import OmegaConf +import pytorch_lightning as pl import torch +from torch import nn from torchsummary import summary from tqdm import tqdm -from training.gpu_manager import GPUManager -from training.trainer.callbacks import CallbackList -from training.trainer.train import Trainer import wandb -import yaml -import text_recognizer.models -from text_recognizer.models import Model -import text_recognizer.networks -from text_recognizer.networks.loss import loss as custom_loss_module +SEED = 4711 EXPERIMENTS_DIRNAME = Path(__file__).parents[0].resolve() / "experiments" -def _get_level(verbose: int) -> int: - """Sets the logger level.""" - levels = {0: 40, 1: 20, 2: 10} - verbose = verbose if verbose <= 2 else 2 - return levels[verbose] - - -def _create_experiment_dir( - experiment_config: Dict, checkpoint: Optional[str] = None -) -> Path: - """Create new experiment.""" - EXPERIMENTS_DIRNAME.mkdir(parents=True, exist_ok=True) - experiment_dir = EXPERIMENTS_DIRNAME / ( - f"{experiment_config['model']}_" - + f"{experiment_config['dataset']['type']}_" - + f"{experiment_config['network']['type']}" - ) - - if checkpoint is None: - experiment = datetime.now().strftime("%m%d_%H%M%S") - logger.debug(f"Creating a new experiment called {experiment}") - else: - available_experiments = glob(str(experiment_dir) + "/*") - available_experiments.sort() - if checkpoint == "last": - experiment = available_experiments[-1] - logger.debug(f"Resuming the latest experiment {experiment}") - else: - experiment = checkpoint - if not str(experiment_dir / experiment) in available_experiments: - raise FileNotFoundError("Experiment does not exist.") - logger.debug(f"Resuming the from experiment {checkpoint}") - - experiment_dir = experiment_dir / experiment - - # Create log and model directories. - log_dir = experiment_dir / "log" - model_dir = experiment_dir / "model" - - return experiment_dir, log_dir, model_dir - - -def _load_modules_and_arguments(experiment_config: Dict,) -> Tuple[Callable, Dict]: - """Loads all modules and arguments.""" - # Load the dataset module. - dataset_args = experiment_config.get("dataset", {}) - dataset_ = dataset_args["type"] - - # Import the model module and model arguments. - model_class_ = getattr(text_recognizer.models, experiment_config["model"]) - - # Import metrics. - metric_fns_ = ( - { - metric: getattr(text_recognizer.networks, metric) - for metric in experiment_config["metrics"] - } - if experiment_config["metrics"] is not None - else None - ) - - # Import network module and arguments. - network_fn_ = experiment_config["network"]["type"] - network_args = experiment_config["network"].get("args", {}) - - # Criterion - if experiment_config["criterion"]["type"] in custom_loss_module.__all__: - criterion_ = getattr(custom_loss_module, experiment_config["criterion"]["type"]) - else: - criterion_ = getattr(torch.nn, experiment_config["criterion"]["type"]) - criterion_args = experiment_config["criterion"].get("args", {}) or {} - - # Optimizers - optimizer_ = getattr(torch.optim, experiment_config["optimizer"]["type"]) - optimizer_args = experiment_config["optimizer"].get("args", {}) - - # Learning rate scheduler - lr_scheduler_ = None - lr_scheduler_args = None - if "lr_scheduler" in experiment_config: - lr_scheduler_ = getattr( - torch.optim.lr_scheduler, experiment_config["lr_scheduler"]["type"] - ) - lr_scheduler_args = experiment_config["lr_scheduler"].get("args", {}) or {} - - # SWA scheduler. - if "swa_args" in experiment_config: - swa_args = experiment_config.get("swa_args", {}) or {} - else: - swa_args = None - - model_args = { - "dataset": dataset_, - "dataset_args": dataset_args, - "metrics": metric_fns_, - "network_fn": network_fn_, - "network_args": network_args, - "criterion": criterion_, - "criterion_args": criterion_args, - "optimizer": optimizer_, - "optimizer_args": optimizer_args, - "lr_scheduler": lr_scheduler_, - "lr_scheduler_args": lr_scheduler_args, - "swa_args": swa_args, - } - - return model_class_, model_args - - -def _configure_callbacks(experiment_config: Dict, model_dir: Path) -> CallbackList: - """Configure a callback list for trainer.""" - if "Checkpoint" in experiment_config["callback_args"]: - experiment_config["callback_args"]["Checkpoint"]["checkpoint_path"] = str( - model_dir - ) - - # Initializes callbacks. - callback_modules = importlib.import_module("training.trainer.callbacks") - callbacks = [] - for callback in experiment_config["callbacks"]: - args = experiment_config["callback_args"][callback] or {} - callbacks.append(getattr(callback_modules, callback)(**args)) - - return callbacks +def _configure_logging(log_dir: Optional[Path], verbose: int = 0) -> None: + """Configure the loguru logger for output to terminal and disk.""" + def _get_level(verbose: int) -> int: + """Sets the logger level.""" + levels = {0: 40, 1: 20, 2: 10} + verbose = verbose if verbose <= 2 else 2 + return levels[verbose] -def _configure_logger(log_dir: Path, verbose: int = 0) -> None: - """Configure the loguru logger for output to terminal and disk.""" # Have to remove default logger to get tqdm to work properly. logger.remove() @@ -164,219 +36,138 @@ def _configure_logger(log_dir: Path, verbose: int = 0) -> None: level = _get_level(verbose) logger.add(lambda msg: tqdm.write(msg, end=""), colorize=True, level=level) - logger.add( - str(log_dir / "train.log"), - format="{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}", - ) - - -def _save_config(experiment_dir: Path, experiment_config: Dict) -> None: - """Copy config to experiment directory.""" - config_path = experiment_dir / "config.yml" - with open(str(config_path), "w") as f: - yaml.dump(experiment_config, f) - - -def _load_from_checkpoint( - model: Type[Model], model_dir: Path, pretrained_weights: str = None, -) -> None: - """If checkpoint exists, load model weights and optimizers from checkpoint.""" - # Get checkpoint path. - if pretrained_weights is not None: - logger.info(f"Loading weights from {pretrained_weights}.") - checkpoint_path = ( - EXPERIMENTS_DIRNAME / Path(pretrained_weights) / "model" / "best.pt" + if log_dir is not None: + logger.add( + str(log_dir / "train.log"), + format="{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}", ) - else: - logger.info(f"Loading weights from {model_dir}.") - checkpoint_path = model_dir / "last.pt" - if checkpoint_path.exists(): - logger.info("Loading and resuming training from checkpoint.") - model.load_from_checkpoint(checkpoint_path) -def evaluate_embedding(model: Type[Model]) -> Dict: - """Evaluates the embedding space.""" - from pytorch_metric_learning import testers - from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator +def _import_class(module_and_class_name: str) -> type: + """Import class from module.""" + module_name, class_name = module_and_class_name.rsplit(".", 1) + module = importlib.import_module(module_name) + return getattr(module, class_name) - accuracy_calculator = AccuracyCalculator( - include=("mean_average_precision_at_r",), k=10 - ) - def get_all_embeddings(model: Type[Model]) -> Tuple: - tester = testers.BaseTester() - return tester.get_all_embeddings(model.test_dataset, model.network) +def _configure_pl_callbacks(args: List[Dict]) -> List[Type[pl.callbacks.Callback]]: + """Configures PyTorch Lightning callbacks.""" + pl_callbacks = [ + getattr(pl.callbacks, callback["type"])(**callback["args"]) for callback in args + ] + return pl_callbacks - embeddings, labels = get_all_embeddings(model) - logger.info("Computing embedding accuracy") - accuracies = accuracy_calculator.get_accuracy( - embeddings, embeddings, np.squeeze(labels), np.squeeze(labels), True - ) - logger.info( - f"Test set accuracy (MAP@10) = {accuracies['mean_average_precision_at_r']}" - ) - return accuracies +def _configure_wandb_callback( + network: Type[nn.Module], args: Dict +) -> pl.loggers.WandbLogger: + """Configures wandb logger.""" + pl_logger = pl.loggers.WandbLogger() + pl_logger.watch(network) + pl_logger.log_hyperparams(vars(args)) + return pl_logger -def run_experiment( - experiment_config: Dict, - save_weights: bool, - device: str, - use_wandb: bool, - train: bool, - test: bool, - verbose: int = 0, - checkpoint: Optional[str] = None, - pretrained_weights: Optional[str] = None, -) -> None: - """Runs an experiment.""" - logger.info(f"Experiment config: {json.dumps(experiment_config)}") - # Create new experiment. - experiment_dir, log_dir, model_dir = _create_experiment_dir( - experiment_config, checkpoint +def _save_best_weights( + callbacks: List[Type[pl.callbacks.Callback]], use_wandb: bool +) -> None: + """Saves the best model.""" + model_checkpoint_callback = next( + callback + for callback in callbacks + if isinstance(callback, pl.callbacks.ModelCheckpoint) ) + best_model_path = model_checkpoint_callback.best_model_path + if best_model_path: + logger.info(f"Best model saved at: {best_model_path}") + if use_wandb: + logger.info("Uploading model to W&B...") + wandb.save(best_model_path) - # Make sure the log/model directory exists. - log_dir.mkdir(parents=True, exist_ok=True) - model_dir.mkdir(parents=True, exist_ok=True) - - # Load the modules and model arguments. - model_class_, model_args = _load_modules_and_arguments(experiment_config) - - # Initializes the model with experiment config. - model = model_class_(**model_args, device=device) - - callbacks = _configure_callbacks(experiment_config, model_dir) - # Setup logger. - _configure_logger(log_dir, verbose) +def run(path: str, train: bool, test: bool, tune: bool, use_wandb: bool) -> None: + """Runs experiment.""" + logger.info("Starting experiment...") - # Load from checkpoint if resuming an experiment. - resume = False - if checkpoint is not None or pretrained_weights is not None: - # resume = True - _load_from_checkpoint(model, model_dir, pretrained_weights) + # Seed everything in the experiment + logger.info(f"Seeding everthing with seed={SEED}") + pl.utilities.seed.seed_everything(SEED) - logger.info(f"The class mapping is {model.mapping}") + # Load config. + logger.info(f"Loading config from: {path}") + config = OmegaConf.load(path) - # Initializes Weights & Biases - if use_wandb: - wandb.init(project="text-recognizer", config=experiment_config, resume=resume) + # Load classes + data_module_class = _import_class(f"text_recognizer.data.{config.data.type}") + network_class = _import_class(f"text_recognizer.networks.{config.network.type}") + lit_model_class = _import_class(f"text_recognizer.models.{config.model.type}") - # Lets W&B save the model and track the gradients and optional parameters. - wandb.watch(model.network) + # Initialize data object and network. + data_module = data_module_class(**config.data.args) + network = network_class(**config.network.args) - experiment_config["experiment_group"] = experiment_config.get( - "experiment_group", None + # Load callback and logger + callbacks = _configure_pl_callbacks(config.callbacks) + pl_logger = ( + _configure_wandb_callback(network, config.network.args) + if use_wandb + else pl.logger.TensorBoardLogger("training/logs") ) - experiment_config["device"] = device - - # Save the config used in the experiment folder. - _save_config(experiment_dir, experiment_config) - - # Prints a summary of the network in terminal. - model.summary(experiment_config["train_args"]["input_shape"]) + # Checkpoint + if config.load_checkpoint is not None: + logger.info( + f"Loading network weights from checkpoint: {config.load_checkpoint}" + ) + lit_model = lit_model_class.load_from_checkpoint( + config.load_checkpoint, network=network, **config.model.args + ) + else: + lit_model = lit_model_class(**config.model.args) - # Load trainer. - trainer = Trainer( - max_epochs=experiment_config["train_args"]["max_epochs"], + trainer = pl.Trainer( + **config.trainer, callbacks=callbacks, - transformer_model=experiment_config["train_args"]["transformer_model"], - max_norm=experiment_config["train_args"]["max_norm"], - freeze_backbone=experiment_config["train_args"]["freeze_backbone"], + logger=pl_logger, + weigths_save_path="training/logs", ) - # Train the model. + if tune: + logger.info(f"Tuning learning rate and batch size...") + trainer.tune(lit_model, datamodule=data_module) + if train: - trainer.fit(model) + logger.info(f"Training network...") + trainer.fit(lit_model, datamodule=data_module) - # Run inference over test set. if test: - logger.info("Loading checkpoint with the best weights.") - if "checkpoint" in experiment_config["train_args"]: - model.load_from_checkpoint( - model_dir / experiment_config["train_args"]["checkpoint"] - ) - else: - model.load_from_checkpoint(model_dir / "best.pt") - - logger.info("Running inference on test set.") - if experiment_config["criterion"]["type"] == "EmbeddingLoss": - logger.info("Evaluating embedding.") - score = evaluate_embedding(model) - else: - score = trainer.test(model) - - logger.info(f"Test set evaluation: {score}") - - if use_wandb: - wandb.log( - { - experiment_config["test_metric"]: score[ - experiment_config["test_metric"] - ] - } - ) + logger.info(f"Testing network...") + trainer.test(lit_model, datamodule=data_module) - if save_weights: - model.save_weights(model_dir) + _save_best_weights(callbacks, use_wandb) @click.command() -@click.argument("experiment_config",) -@click.option("--gpu", type=int, default=0, help="Provide the index of the GPU to use.") +@click.option("-f", "--experiment_config", type=str, help="Path to experiment config.") +@click.option("--use_wandb", is_flag=True, help="If true, do use wandb for logging.") @click.option( - "--save", - is_flag=True, - help="If set, the final weights will be saved to a canonical, version-controlled location.", -) -@click.option( - "--nowandb", is_flag=False, help="If true, do not use wandb for this run." + "--tune", is_flag=True, help="If true, tune hyperparameters for training." ) +@click.option("--train", is_flag=True, help="If true, train the model.") @click.option("--test", is_flag=True, help="If true, test the model.") @click.option("-v", "--verbose", count=True) -@click.option("--checkpoint", type=str, help="Path to the experiment.") -@click.option( - "--pretrained_weights", type=str, help="Path to pretrained model weights." -) -@click.option( - "--notrain", is_flag=False, help="Do not train the model.", -) -def run_cli( +def cli( experiment_config: str, - gpu: int, - save: bool, - nowandb: bool, - notrain: bool, + use_wandb: bool, + tune: bool, + train: bool, test: bool, verbose: int, - checkpoint: Optional[str] = None, - pretrained_weights: Optional[str] = None, ) -> None: """Run experiment.""" - if gpu < 0: - gpu_manager = GPUManager(True) - gpu = gpu_manager.get_free_gpu() - device = "cuda:" + str(gpu) - - experiment_config = json.loads(experiment_config) - os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu}" - - run_experiment( - experiment_config, - save, - device, - use_wandb=not nowandb, - train=not notrain, - test=test, - verbose=verbose, - checkpoint=checkpoint, - pretrained_weights=pretrained_weights, - ) + _configure_logging(None, verbose=verbose) + run(path=experiment_config, train=train, test=test, tune=tune, use_wandb=use_wandb) if __name__ == "__main__": - run_cli() + cli() |