diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-03-20 18:09:06 +0100 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-03-20 18:09:06 +0100 |
commit | 7e8e54e84c63171e748bbf09516fd517e6821ace (patch) | |
tree | 996093f75a5d488dddf7ea1f159ed343a561ef89 /training | |
parent | b0719d84138b6bbe5f04a4982dfca673aea1a368 (diff) |
Inital commit for refactoring to lightning
Diffstat (limited to 'training')
-rw-r--r-- | training/experiments/default_config_emnist.yml | 70 | ||||
-rw-r--r-- | training/experiments/embedding_experiment.yml | 64 | ||||
-rw-r--r-- | training/experiments/sample_experiment.yml | 99 | ||||
-rw-r--r-- | training/gpu_manager.py | 62 | ||||
-rw-r--r-- | training/prepare_experiments.py | 34 | ||||
-rw-r--r-- | training/run_experiment.py | 382 | ||||
-rw-r--r-- | training/run_sweep.py | 92 | ||||
-rw-r--r-- | training/sweep_emnist.yml | 26 | ||||
-rw-r--r-- | training/sweep_emnist_resnet.yml | 50 | ||||
-rw-r--r-- | training/trainer/__init__.py | 2 | ||||
-rw-r--r-- | training/trainer/callbacks/__init__.py | 29 | ||||
-rw-r--r-- | training/trainer/callbacks/base.py | 188 | ||||
-rw-r--r-- | training/trainer/callbacks/checkpoint.py | 95 | ||||
-rw-r--r-- | training/trainer/callbacks/early_stopping.py | 108 | ||||
-rw-r--r-- | training/trainer/callbacks/lr_schedulers.py | 77 | ||||
-rw-r--r-- | training/trainer/callbacks/progress_bar.py | 65 | ||||
-rw-r--r-- | training/trainer/callbacks/wandb_callbacks.py | 261 | ||||
-rw-r--r-- | training/trainer/train.py | 325 | ||||
-rw-r--r-- | training/trainer/util.py | 28 |
19 files changed, 2057 insertions, 0 deletions
diff --git a/training/experiments/default_config_emnist.yml b/training/experiments/default_config_emnist.yml new file mode 100644 index 0000000..bf2ed0a --- /dev/null +++ b/training/experiments/default_config_emnist.yml @@ -0,0 +1,70 @@ +dataset: EmnistDataset +dataset_args: + sample_to_balance: true + subsample_fraction: 0.33 + transform: null + target_transform: null + seed: 4711 + +data_loader_args: + splits: [train, val] + shuffle: true + num_workers: 8 + cuda: true + +model: CharacterModel +metrics: [accuracy] + +network_args: + in_channels: 1 + num_classes: 80 + depths: [2] + block_sizes: [256] + +train_args: + batch_size: 256 + epochs: 5 + +criterion: CrossEntropyLoss +criterion_args: + weight: null + ignore_index: -100 + reduction: mean + +optimizer: AdamW +optimizer_args: + lr: 1.e-03 + betas: [0.9, 0.999] + eps: 1.e-08 + # weight_decay: 5.e-4 + amsgrad: false + +lr_scheduler: OneCycleLR +lr_scheduler_args: + max_lr: 1.e-03 + epochs: 5 + anneal_strategy: linear + + +callbacks: [Checkpoint, ProgressBar, EarlyStopping, WandbCallback, WandbImageLogger, OneCycleLR] +callback_args: + Checkpoint: + monitor: val_accuracy + ProgressBar: + epochs: 5 + log_batch_frequency: 100 + EarlyStopping: + monitor: val_loss + min_delta: 0.0 + patience: 3 + mode: min + WandbCallback: + log_batch_frequency: 10 + WandbImageLogger: + num_examples: 4 + OneCycleLR: + null +verbosity: 1 # 0, 1, 2 +resume_experiment: null +train: true +validation_metric: val_accuracy diff --git a/training/experiments/embedding_experiment.yml b/training/experiments/embedding_experiment.yml new file mode 100644 index 0000000..1e5f941 --- /dev/null +++ b/training/experiments/embedding_experiment.yml @@ -0,0 +1,64 @@ +experiment_group: Embedding Experiments +experiments: + - train_args: + transformer_model: false + batch_size: &batch_size 256 + max_epochs: &max_epochs 32 + input_shape: [[1, 28, 28]] + dataset: + type: EmnistDataset + args: + sample_to_balance: true + subsample_fraction: null + transform: null + target_transform: null + seed: 4711 + train_args: + num_workers: 8 + train_fraction: 0.85 + batch_size: *batch_size + model: CharacterModel + metrics: [] + network: + type: DenseNet + args: + growth_rate: 4 + block_config: [4, 4] + in_channels: 1 + base_channels: 24 + num_classes: 128 + bn_size: 4 + dropout_rate: 0.1 + classifier: true + activation: elu + criterion: + type: EmbeddingLoss + args: + margin: 0.2 + type_of_triplets: semihard + optimizer: + type: AdamW + args: + lr: 1.e-02 + betas: [0.9, 0.999] + eps: 1.e-08 + weight_decay: 5.e-4 + amsgrad: false + lr_scheduler: + type: CosineAnnealingLR + args: + T_max: *max_epochs + callbacks: [Checkpoint, ProgressBar, WandbCallback] + callback_args: + Checkpoint: + monitor: val_loss + mode: min + ProgressBar: + epochs: *max_epochs + WandbCallback: + log_batch_frequency: 10 + verbosity: 1 # 0, 1, 2 + resume_experiment: null + train: true + test: true + test_metric: mean_average_precision_at_r diff --git a/training/experiments/sample_experiment.yml b/training/experiments/sample_experiment.yml new file mode 100644 index 0000000..8f94475 --- /dev/null +++ b/training/experiments/sample_experiment.yml @@ -0,0 +1,99 @@ +experiment_group: Sample Experiments +experiments: + - train_args: + batch_size: 256 + max_epochs: &max_epochs 32 + dataset: + type: EmnistDataset + args: + sample_to_balance: true + subsample_fraction: null + transform: null + target_transform: null + seed: 4711 + train_args: + num_workers: 6 + train_fraction: 0.8 + + model: CharacterModel + metrics: [accuracy] + # network: MLP + # network_args: + # input_size: 784 + # hidden_size: 512 + # output_size: 80 + # num_layers: 5 + # dropout_rate: 0.2 + # activation_fn: SELU + network: + type: ResidualNetwork + args: + in_channels: 1 + num_classes: 80 + depths: [2, 2] + block_sizes: [64, 64] + activation: leaky_relu + # network: + # type: WideResidualNetwork + # args: + # in_channels: 1 + # num_classes: 80 + # depth: 10 + # num_layers: 3 + # width_factor: 4 + # dropout_rate: 0.2 + # activation: SELU + # network: LeNet + # network_args: + # output_size: 62 + # activation_fn: GELU + criterion: + type: CrossEntropyLoss + args: + weight: null + ignore_index: -100 + reduction: mean + optimizer: + type: AdamW + args: + lr: 1.e-02 + betas: [0.9, 0.999] + eps: 1.e-08 + # weight_decay: 5.e-4 + amsgrad: false + # lr_scheduler: + # type: OneCycleLR + # args: + # max_lr: 1.e-03 + # epochs: *max_epochs + # anneal_strategy: linear + lr_scheduler: + type: CosineAnnealingLR + args: + T_max: *max_epochs + interval: epoch + swa_args: + start: 2 + lr: 5.e-2 + callbacks: [Checkpoint, ProgressBar, WandbCallback, WandbImageLogger, EarlyStopping] + callback_args: + Checkpoint: + monitor: val_accuracy + ProgressBar: + epochs: null + log_batch_frequency: 100 + EarlyStopping: + monitor: val_loss + min_delta: 0.0 + patience: 5 + mode: min + WandbCallback: + log_batch_frequency: 10 + WandbImageLogger: + num_examples: 4 + use_transpose: true + verbosity: 0 # 0, 1, 2 + resume_experiment: null + train: true + test: true + test_metric: test_accuracy diff --git a/training/gpu_manager.py b/training/gpu_manager.py new file mode 100644 index 0000000..ce1b3dd --- /dev/null +++ b/training/gpu_manager.py @@ -0,0 +1,62 @@ +"""GPUManager class.""" +import os +import time +from typing import Optional + +import gpustat +from loguru import logger +import numpy as np +from redlock import Redlock + + +GPU_LOCK_TIMEOUT = 5000 # ms + + +class GPUManager: + """Class for allocating GPUs.""" + + def __init__(self, verbose: bool = False) -> None: + """Initializes Redlock manager.""" + self.lock_manager = Redlock([{"host": "localhost", "port": 6379, "db": 0}]) + self.verbose = verbose + + def get_free_gpu(self) -> int: + """Gets a free GPU. + + If some GPUs are available, try reserving one by checking out an exclusive redis lock. + If none available or can not get lock, sleep and check again. + + Returns: + int: The gpu index. + + """ + while True: + gpu_index = self._get_free_gpu() + if gpu_index is not None: + return gpu_index + + if self.verbose: + logger.debug(f"pid {os.getpid()} sleeping") + time.sleep(GPU_LOCK_TIMEOUT / 1000) + + def _get_free_gpu(self) -> Optional[int]: + """Fetches an available GPU index.""" + try: + available_gpu_indices = [ + gpu.index + for gpu in gpustat.GPUStatCollection.new_query() + if gpu.memory_used < 0.5 * gpu.memory_total + ] + except Exception as e: + logger.debug(f"Got the following exception: {e}") + return None + + if available_gpu_indices: + gpu_index = np.random.choice(available_gpu_indices) + if self.verbose: + logger.debug(f"pid {os.getpid()} picking gpu {gpu_index}") + if self.lock_manager.lock(f"gpu_{gpu_index}", GPU_LOCK_TIMEOUT): + return int(gpu_index) + if self.verbose: + logger.debug(f"pid {os.getpid()} could not get lock.") + return None diff --git a/training/prepare_experiments.py b/training/prepare_experiments.py new file mode 100644 index 0000000..21997af --- /dev/null +++ b/training/prepare_experiments.py @@ -0,0 +1,34 @@ +"""Run a experiment from a config file.""" +import json + +import click +import yaml + + +def run_experiments(experiments_filename: str) -> None: + """Run experiment from file.""" + with open(experiments_filename, "r") as f: + experiments_config = yaml.safe_load(f) + + num_experiments = len(experiments_config["experiments"]) + for index in range(num_experiments): + experiment_config = experiments_config["experiments"][index] + experiment_config["experiment_group"] = experiments_config["experiment_group"] + cmd = f"poetry run run-experiment --gpu=-1 --save '{json.dumps(experiment_config)}'" + print(cmd) + + +@click.command() +@click.option( + "--experiments_filename", + required=True, + type=str, + help="Filename of Yaml file of experiments to run.", +) +def run_cli(experiments_filename: str) -> None: + """Parse command-line arguments and run experiments from provided file.""" + run_experiments(experiments_filename) + + +if __name__ == "__main__": + run_cli() diff --git a/training/run_experiment.py b/training/run_experiment.py new file mode 100644 index 0000000..faafea6 --- /dev/null +++ b/training/run_experiment.py @@ -0,0 +1,382 @@ +"""Script to run experiments.""" +from datetime import datetime +from glob import glob +import importlib +import json +import os +from pathlib import Path +import re +from typing import Callable, Dict, List, Optional, Tuple, Type +import warnings + +import click +from loguru import logger +import numpy as np +import torch +from torchsummary import summary +from tqdm import tqdm +from training.gpu_manager import GPUManager +from training.trainer.callbacks import CallbackList +from training.trainer.train import Trainer +import wandb +import yaml + +import text_recognizer.models +from text_recognizer.models import Model +import text_recognizer.networks +from text_recognizer.networks.loss import loss as custom_loss_module + +EXPERIMENTS_DIRNAME = Path(__file__).parents[0].resolve() / "experiments" + + +def _get_level(verbose: int) -> int: + """Sets the logger level.""" + levels = {0: 40, 1: 20, 2: 10} + verbose = verbose if verbose <= 2 else 2 + return levels[verbose] + + +def _create_experiment_dir( + experiment_config: Dict, checkpoint: Optional[str] = None +) -> Path: + """Create new experiment.""" + EXPERIMENTS_DIRNAME.mkdir(parents=True, exist_ok=True) + experiment_dir = EXPERIMENTS_DIRNAME / ( + f"{experiment_config['model']}_" + + f"{experiment_config['dataset']['type']}_" + + f"{experiment_config['network']['type']}" + ) + + if checkpoint is None: + experiment = datetime.now().strftime("%m%d_%H%M%S") + logger.debug(f"Creating a new experiment called {experiment}") + else: + available_experiments = glob(str(experiment_dir) + "/*") + available_experiments.sort() + if checkpoint == "last": + experiment = available_experiments[-1] + logger.debug(f"Resuming the latest experiment {experiment}") + else: + experiment = checkpoint + if not str(experiment_dir / experiment) in available_experiments: + raise FileNotFoundError("Experiment does not exist.") + logger.debug(f"Resuming the from experiment {checkpoint}") + + experiment_dir = experiment_dir / experiment + + # Create log and model directories. + log_dir = experiment_dir / "log" + model_dir = experiment_dir / "model" + + return experiment_dir, log_dir, model_dir + + +def _load_modules_and_arguments(experiment_config: Dict,) -> Tuple[Callable, Dict]: + """Loads all modules and arguments.""" + # Load the dataset module. + dataset_args = experiment_config.get("dataset", {}) + dataset_ = dataset_args["type"] + + # Import the model module and model arguments. + model_class_ = getattr(text_recognizer.models, experiment_config["model"]) + + # Import metrics. + metric_fns_ = ( + { + metric: getattr(text_recognizer.networks, metric) + for metric in experiment_config["metrics"] + } + if experiment_config["metrics"] is not None + else None + ) + + # Import network module and arguments. + network_fn_ = experiment_config["network"]["type"] + network_args = experiment_config["network"].get("args", {}) + + # Criterion + if experiment_config["criterion"]["type"] in custom_loss_module.__all__: + criterion_ = getattr(custom_loss_module, experiment_config["criterion"]["type"]) + else: + criterion_ = getattr(torch.nn, experiment_config["criterion"]["type"]) + criterion_args = experiment_config["criterion"].get("args", {}) or {} + + # Optimizers + optimizer_ = getattr(torch.optim, experiment_config["optimizer"]["type"]) + optimizer_args = experiment_config["optimizer"].get("args", {}) + + # Learning rate scheduler + lr_scheduler_ = None + lr_scheduler_args = None + if "lr_scheduler" in experiment_config: + lr_scheduler_ = getattr( + torch.optim.lr_scheduler, experiment_config["lr_scheduler"]["type"] + ) + lr_scheduler_args = experiment_config["lr_scheduler"].get("args", {}) or {} + + # SWA scheduler. + if "swa_args" in experiment_config: + swa_args = experiment_config.get("swa_args", {}) or {} + else: + swa_args = None + + model_args = { + "dataset": dataset_, + "dataset_args": dataset_args, + "metrics": metric_fns_, + "network_fn": network_fn_, + "network_args": network_args, + "criterion": criterion_, + "criterion_args": criterion_args, + "optimizer": optimizer_, + "optimizer_args": optimizer_args, + "lr_scheduler": lr_scheduler_, + "lr_scheduler_args": lr_scheduler_args, + "swa_args": swa_args, + } + + return model_class_, model_args + + +def _configure_callbacks(experiment_config: Dict, model_dir: Path) -> CallbackList: + """Configure a callback list for trainer.""" + if "Checkpoint" in experiment_config["callback_args"]: + experiment_config["callback_args"]["Checkpoint"]["checkpoint_path"] = str( + model_dir + ) + + # Initializes callbacks. + callback_modules = importlib.import_module("training.trainer.callbacks") + callbacks = [] + for callback in experiment_config["callbacks"]: + args = experiment_config["callback_args"][callback] or {} + callbacks.append(getattr(callback_modules, callback)(**args)) + + return callbacks + + +def _configure_logger(log_dir: Path, verbose: int = 0) -> None: + """Configure the loguru logger for output to terminal and disk.""" + # Have to remove default logger to get tqdm to work properly. + logger.remove() + + # Fetch verbosity level. + level = _get_level(verbose) + + logger.add(lambda msg: tqdm.write(msg, end=""), colorize=True, level=level) + logger.add( + str(log_dir / "train.log"), + format="{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}", + ) + + +def _save_config(experiment_dir: Path, experiment_config: Dict) -> None: + """Copy config to experiment directory.""" + config_path = experiment_dir / "config.yml" + with open(str(config_path), "w") as f: + yaml.dump(experiment_config, f) + + +def _load_from_checkpoint( + model: Type[Model], model_dir: Path, pretrained_weights: str = None, +) -> None: + """If checkpoint exists, load model weights and optimizers from checkpoint.""" + # Get checkpoint path. + if pretrained_weights is not None: + logger.info(f"Loading weights from {pretrained_weights}.") + checkpoint_path = ( + EXPERIMENTS_DIRNAME / Path(pretrained_weights) / "model" / "best.pt" + ) + else: + logger.info(f"Loading weights from {model_dir}.") + checkpoint_path = model_dir / "last.pt" + if checkpoint_path.exists(): + logger.info("Loading and resuming training from checkpoint.") + model.load_from_checkpoint(checkpoint_path) + + +def evaluate_embedding(model: Type[Model]) -> Dict: + """Evaluates the embedding space.""" + from pytorch_metric_learning import testers + from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator + + accuracy_calculator = AccuracyCalculator( + include=("mean_average_precision_at_r",), k=10 + ) + + def get_all_embeddings(model: Type[Model]) -> Tuple: + tester = testers.BaseTester() + return tester.get_all_embeddings(model.test_dataset, model.network) + + embeddings, labels = get_all_embeddings(model) + logger.info("Computing embedding accuracy") + accuracies = accuracy_calculator.get_accuracy( + embeddings, embeddings, np.squeeze(labels), np.squeeze(labels), True + ) + logger.info( + f"Test set accuracy (MAP@10) = {accuracies['mean_average_precision_at_r']}" + ) + return accuracies + + +def run_experiment( + experiment_config: Dict, + save_weights: bool, + device: str, + use_wandb: bool, + train: bool, + test: bool, + verbose: int = 0, + checkpoint: Optional[str] = None, + pretrained_weights: Optional[str] = None, +) -> None: + """Runs an experiment.""" + logger.info(f"Experiment config: {json.dumps(experiment_config)}") + + # Create new experiment. + experiment_dir, log_dir, model_dir = _create_experiment_dir( + experiment_config, checkpoint + ) + + # Make sure the log/model directory exists. + log_dir.mkdir(parents=True, exist_ok=True) + model_dir.mkdir(parents=True, exist_ok=True) + + # Load the modules and model arguments. + model_class_, model_args = _load_modules_and_arguments(experiment_config) + + # Initializes the model with experiment config. + model = model_class_(**model_args, device=device) + + callbacks = _configure_callbacks(experiment_config, model_dir) + + # Setup logger. + _configure_logger(log_dir, verbose) + + # Load from checkpoint if resuming an experiment. + resume = False + if checkpoint is not None or pretrained_weights is not None: + # resume = True + _load_from_checkpoint(model, model_dir, pretrained_weights) + + logger.info(f"The class mapping is {model.mapping}") + + # Initializes Weights & Biases + if use_wandb: + wandb.init(project="text-recognizer", config=experiment_config, resume=resume) + + # Lets W&B save the model and track the gradients and optional parameters. + wandb.watch(model.network) + + experiment_config["experiment_group"] = experiment_config.get( + "experiment_group", None + ) + + experiment_config["device"] = device + + # Save the config used in the experiment folder. + _save_config(experiment_dir, experiment_config) + + # Prints a summary of the network in terminal. + model.summary(experiment_config["train_args"]["input_shape"]) + + # Load trainer. + trainer = Trainer( + max_epochs=experiment_config["train_args"]["max_epochs"], + callbacks=callbacks, + transformer_model=experiment_config["train_args"]["transformer_model"], + max_norm=experiment_config["train_args"]["max_norm"], + freeze_backbone=experiment_config["train_args"]["freeze_backbone"], + ) + + # Train the model. + if train: + trainer.fit(model) + + # Run inference over test set. + if test: + logger.info("Loading checkpoint with the best weights.") + if "checkpoint" in experiment_config["train_args"]: + model.load_from_checkpoint( + model_dir / experiment_config["train_args"]["checkpoint"] + ) + else: + model.load_from_checkpoint(model_dir / "best.pt") + + logger.info("Running inference on test set.") + if experiment_config["criterion"]["type"] == "EmbeddingLoss": + logger.info("Evaluating embedding.") + score = evaluate_embedding(model) + else: + score = trainer.test(model) + + logger.info(f"Test set evaluation: {score}") + + if use_wandb: + wandb.log( + { + experiment_config["test_metric"]: score[ + experiment_config["test_metric"] + ] + } + ) + + if save_weights: + model.save_weights(model_dir) + + +@click.command() +@click.argument("experiment_config",) +@click.option("--gpu", type=int, default=0, help="Provide the index of the GPU to use.") +@click.option( + "--save", + is_flag=True, + help="If set, the final weights will be saved to a canonical, version-controlled location.", +) +@click.option( + "--nowandb", is_flag=False, help="If true, do not use wandb for this run." +) +@click.option("--test", is_flag=True, help="If true, test the model.") +@click.option("-v", "--verbose", count=True) +@click.option("--checkpoint", type=str, help="Path to the experiment.") +@click.option( + "--pretrained_weights", type=str, help="Path to pretrained model weights." +) +@click.option( + "--notrain", is_flag=False, help="Do not train the model.", +) +def run_cli( + experiment_config: str, + gpu: int, + save: bool, + nowandb: bool, + notrain: bool, + test: bool, + verbose: int, + checkpoint: Optional[str] = None, + pretrained_weights: Optional[str] = None, +) -> None: + """Run experiment.""" + if gpu < 0: + gpu_manager = GPUManager(True) + gpu = gpu_manager.get_free_gpu() + device = "cuda:" + str(gpu) + + experiment_config = json.loads(experiment_config) + os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu}" + + run_experiment( + experiment_config, + save, + device, + use_wandb=not nowandb, + train=not notrain, + test=test, + verbose=verbose, + checkpoint=checkpoint, + pretrained_weights=pretrained_weights, + ) + + +if __name__ == "__main__": + run_cli() diff --git a/training/run_sweep.py b/training/run_sweep.py new file mode 100644 index 0000000..a578592 --- /dev/null +++ b/training/run_sweep.py @@ -0,0 +1,92 @@ +"""W&B Sweep Functionality.""" +from ast import literal_eval +import json +import os +from pathlib import Path +import signal +import subprocess # nosec +import sys +from typing import Dict, List, Tuple + +import click +import yaml + +EXPERIMENTS_DIRNAME = Path(__file__).parents[0].resolve() / "experiments" + + +def load_config() -> Dict: + """Load base hyperparameter config.""" + with open(str(EXPERIMENTS_DIRNAME / "default_config_emnist.yml"), "r") as f: + default_config = yaml.safe_load(f) + return default_config + + +def args_to_json( + default_config: dict, preserve_args: tuple = ("gpu", "save") +) -> Tuple[dict, list]: + """Convert command line arguments to nested config values. + + i.e. run_sweep.py --dataset_args.foo=1.7 + { + "dataset_args": { + "foo": 1.7 + } + } + + Args: + default_config (dict): The base config used for every experiment. + preserve_args (tuple): Arguments preserved for all runs. Defaults to ("gpu", "save"). + + Returns: + Tuple[dict, list]: Tuple of config dictionary and list of arguments. + + """ + + args = [] + config = default_config.copy() + key, val = None, None + for arg in sys.argv[1:]: + if "=" in arg: + key, val = arg.split("=") + elif key: + val = arg + else: + key = arg + if key and val: + parsed_key = key.lstrip("-").split(".") + if parsed_key[0] in preserve_args: + args.append("--{}={}".format(parsed_key[0], val)) + else: + nested = config + for level in parsed_key[:-1]: + nested[level] = config.get(level, {}) + nested = nested[level] + try: + # Convert numerics to floats / ints + val = literal_eval(val) + except ValueError: + pass + nested[parsed_key[-1]] = val + key, val = None, None + return config, args + + +def main() -> None: + """Runs a W&B sweep.""" + default_config = load_config() + config, args = args_to_json(default_config) + env = { + k: v for k, v in os.environ.items() if k not in ("WANDB_PROGRAM", "WANDB_ARGS") + } + # pylint: disable=subprocess-popen-preexec-fn + run = subprocess.Popen( + ["python", "training/run_experiment.py", *args, json.dumps(config)], + env=env, + preexec_fn=os.setsid, + ) # nosec + signal.signal(signal.SIGTERM, lambda *args: run.terminate()) + run.wait() + + +if __name__ == "__main__": + main() diff --git a/training/sweep_emnist.yml b/training/sweep_emnist.yml new file mode 100644 index 0000000..48d7261 --- /dev/null +++ b/training/sweep_emnist.yml @@ -0,0 +1,26 @@ +program: training/run_sweep.py +method: bayes +metric: + name: val_loss + goal: minimize +parameters: + dataset: + value: EmnistDataset + model: + value: CharacterModel + network: + value: MLP + network_args.hidden_size: + values: [128, 256] + network_args.dropout_rate: + values: [0.2, 0.4] + network_args.num_layers: + values: [3, 6] + optimizer_args.lr: + values: [1.e-1, 1.e-5] + lr_scheduler_args.max_lr: + values: [1.0e-1, 1.0e-5] + train_args.batch_size: + values: [64, 128] + train_args.epochs: + value: 5 diff --git a/training/sweep_emnist_resnet.yml b/training/sweep_emnist_resnet.yml new file mode 100644 index 0000000..19a3040 --- /dev/null +++ b/training/sweep_emnist_resnet.yml @@ -0,0 +1,50 @@ +program: training/run_sweep.py +method: bayes +metric: + name: val_accuracy + goal: maximize +parameters: + dataset: + value: EmnistDataset + model: + value: CharacterModel + network: + value: ResidualNetwork + network_args.block_sizes: + distribution: q_uniform + min: 16 + max: 256 + q: 8 + network_args.depths: + distribution: int_uniform + min: 1 + max: 3 + network_args.levels: + distribution: int_uniform + min: 1 + max: 2 + network_args.activation: + distribution: categorical + values: + - gelu + - leaky_relu + - relu + - selu + optimizer_args.lr: + distribution: uniform + min: 1.e-5 + max: 1.e-1 + lr_scheduler_args.max_lr: + distribution: uniform + min: 1.e-5 + max: 1.e-1 + train_args.batch_size: + distribution: q_uniform + min: 32 + max: 256 + q: 8 + train_args.epochs: + value: 5 +early_terminate: + type: hyperband + min_iter: 2 diff --git a/training/trainer/__init__.py b/training/trainer/__init__.py new file mode 100644 index 0000000..de41bfb --- /dev/null +++ b/training/trainer/__init__.py @@ -0,0 +1,2 @@ +"""Trainer modules.""" +from .train import Trainer diff --git a/training/trainer/callbacks/__init__.py b/training/trainer/callbacks/__init__.py new file mode 100644 index 0000000..80c4177 --- /dev/null +++ b/training/trainer/callbacks/__init__.py @@ -0,0 +1,29 @@ +"""The callback modules used in the training script.""" +from .base import Callback, CallbackList +from .checkpoint import Checkpoint +from .early_stopping import EarlyStopping +from .lr_schedulers import ( + LRScheduler, + SWA, +) +from .progress_bar import ProgressBar +from .wandb_callbacks import ( + WandbCallback, + WandbImageLogger, + WandbReconstructionLogger, + WandbSegmentationLogger, +) + +__all__ = [ + "Callback", + "CallbackList", + "Checkpoint", + "EarlyStopping", + "LRScheduler", + "WandbCallback", + "WandbImageLogger", + "WandbReconstructionLogger", + "WandbSegmentationLogger", + "ProgressBar", + "SWA", +] diff --git a/training/trainer/callbacks/base.py b/training/trainer/callbacks/base.py new file mode 100644 index 0000000..500b642 --- /dev/null +++ b/training/trainer/callbacks/base.py @@ -0,0 +1,188 @@ +"""Metaclass for callback functions.""" + +from enum import Enum +from typing import Callable, Dict, List, Optional, Type, Union + +from loguru import logger +import numpy as np +import torch + +from text_recognizer.models import Model + + +class ModeKeys: + """Mode keys for CallbackList.""" + + TRAIN = "train" + VALIDATION = "validation" + + +class Callback: + """Metaclass for callbacks used in training.""" + + def __init__(self) -> None: + """Initializes the Callback instance.""" + self.model = None + + def set_model(self, model: Type[Model]) -> None: + """Set the model.""" + self.model = model + + def on_fit_begin(self) -> None: + """Called when fit begins.""" + pass + + def on_fit_end(self) -> None: + """Called when fit ends.""" + pass + + def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None) -> None: + """Called at the beginning of an epoch. Only used in training mode.""" + pass + + def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: + """Called at the end of an epoch. Only used in training mode.""" + pass + + def on_train_batch_begin(self, batch: int, logs: Optional[Dict] = None) -> None: + """Called at the beginning of an epoch.""" + pass + + def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: + """Called at the end of an epoch.""" + pass + + def on_validation_batch_begin( + self, batch: int, logs: Optional[Dict] = None + ) -> None: + """Called at the beginning of an epoch.""" + pass + + def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: + """Called at the end of an epoch.""" + pass + + def on_test_begin(self) -> None: + """Called at the beginning of test.""" + pass + + def on_test_end(self) -> None: + """Called at the end of test.""" + pass + + +class CallbackList: + """Container for abstracting away callback calls.""" + + mode_keys = ModeKeys() + + def __init__(self, model: Type[Model], callbacks: List[Callback] = None) -> None: + """Container for `Callback` instances. + + This object wraps a list of `Callback` instances and allows them all to be + called via a single end point. + + Args: + model (Type[Model]): A `Model` instance. + callbacks (List[Callback]): List of `Callback` instances. Defaults to None. + + """ + + self._callbacks = callbacks or [] + if model: + self.set_model(model) + + def set_model(self, model: Type[Model]) -> None: + """Set the model for all callbacks.""" + self.model = model + for callback in self._callbacks: + callback.set_model(model=self.model) + + def append(self, callback: Type[Callback]) -> None: + """Append new callback to callback list.""" + self._callbacks.append(callback) + + def on_fit_begin(self) -> None: + """Called when fit begins.""" + for callback in self._callbacks: + callback.on_fit_begin() + + def on_fit_end(self) -> None: + """Called when fit ends.""" + for callback in self._callbacks: + callback.on_fit_end() + + def on_test_begin(self) -> None: + """Called when test begins.""" + for callback in self._callbacks: + callback.on_test_begin() + + def on_test_end(self) -> None: + """Called when test ends.""" + for callback in self._callbacks: + callback.on_test_end() + + def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None) -> None: + """Called at the beginning of an epoch.""" + for callback in self._callbacks: + callback.on_epoch_begin(epoch, logs) + + def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: + """Called at the end of an epoch.""" + for callback in self._callbacks: + callback.on_epoch_end(epoch, logs) + + def _call_batch_hook( + self, mode: str, hook: str, batch: int, logs: Optional[Dict] = None + ) -> None: + """Helper function for all batch_{begin | end} methods.""" + if hook == "begin": + self._call_batch_begin_hook(mode, batch, logs) + elif hook == "end": + self._call_batch_end_hook(mode, batch, logs) + else: + raise ValueError(f"Unrecognized hook {hook}.") + + def _call_batch_begin_hook( + self, mode: str, batch: int, logs: Optional[Dict] = None + ) -> None: + """Helper function for all `on_*_batch_begin` methods.""" + hook_name = f"on_{mode}_batch_begin" + self._call_batch_hook_helper(hook_name, batch, logs) + + def _call_batch_end_hook( + self, mode: str, batch: int, logs: Optional[Dict] = None + ) -> None: + """Helper function for all `on_*_batch_end` methods.""" + hook_name = f"on_{mode}_batch_end" + self._call_batch_hook_helper(hook_name, batch, logs) + + def _call_batch_hook_helper( + self, hook_name: str, batch: int, logs: Optional[Dict] = None + ) -> None: + """Helper function for `on_*_batch_begin` methods.""" + for callback in self._callbacks: + hook = getattr(callback, hook_name) + hook(batch, logs) + + def on_train_batch_begin(self, batch: int, logs: Optional[Dict] = None) -> None: + """Called at the beginning of an epoch.""" + self._call_batch_hook(self.mode_keys.TRAIN, "begin", batch, logs) + + def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: + """Called at the end of an epoch.""" + self._call_batch_hook(self.mode_keys.TRAIN, "end", batch, logs) + + def on_validation_batch_begin( + self, batch: int, logs: Optional[Dict] = None + ) -> None: + """Called at the beginning of an epoch.""" + self._call_batch_hook(self.mode_keys.VALIDATION, "begin", batch, logs) + + def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: + """Called at the end of an epoch.""" + self._call_batch_hook(self.mode_keys.VALIDATION, "end", batch, logs) + + def __iter__(self) -> iter: + """Iter function for callback list.""" + return iter(self._callbacks) diff --git a/training/trainer/callbacks/checkpoint.py b/training/trainer/callbacks/checkpoint.py new file mode 100644 index 0000000..a54e0a9 --- /dev/null +++ b/training/trainer/callbacks/checkpoint.py @@ -0,0 +1,95 @@ +"""Callback checkpoint for training models.""" +from enum import Enum +from pathlib import Path +from typing import Callable, Dict, List, Optional, Type, Union + +from loguru import logger +import numpy as np +import torch +from training.trainer.callbacks import Callback + +from text_recognizer.models import Model + + +class Checkpoint(Callback): + """Saving model parameters at the end of each epoch.""" + + mode_dict = { + "min": torch.lt, + "max": torch.gt, + } + + def __init__( + self, + checkpoint_path: Union[str, Path], + monitor: str = "accuracy", + mode: str = "auto", + min_delta: float = 0.0, + ) -> None: + """Monitors a quantity that will allow us to determine the best model weights. + + Args: + checkpoint_path (Union[str, Path]): Path to the experiment with the checkpoint. + monitor (str): Name of the quantity to monitor. Defaults to "accuracy". + mode (str): Description of parameter `mode`. Defaults to "auto". + min_delta (float): Description of parameter `min_delta`. Defaults to 0.0. + + """ + super().__init__() + self.checkpoint_path = Path(checkpoint_path) + self.monitor = monitor + self.mode = mode + self.min_delta = torch.tensor(min_delta) + + if mode not in ["auto", "min", "max"]: + logger.warning(f"Checkpoint mode {mode} is unkown, fallback to auto mode.") + + self.mode = "auto" + + if self.mode == "auto": + if "accuracy" in self.monitor: + self.mode = "max" + else: + self.mode = "min" + logger.debug( + f"Checkpoint mode set to {self.mode} for monitoring {self.monitor}." + ) + + torch_inf = torch.tensor(np.inf) + self.min_delta *= 1 if self.monitor_op == torch.gt else -1 + self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf + + @property + def monitor_op(self) -> float: + """Returns the comparison method.""" + return self.mode_dict[self.mode] + + def on_epoch_end(self, epoch: int, logs: Dict) -> None: + """Saves a checkpoint for the network parameters. + + Args: + epoch (int): The current epoch. + logs (Dict): The log containing the monitored metrics. + + """ + current = self.get_monitor_value(logs) + if current is None: + return + if self.monitor_op(current - self.min_delta, self.best_score): + self.best_score = current + is_best = True + else: + is_best = False + + self.model.save_checkpoint(self.checkpoint_path, is_best, epoch, self.monitor) + + def get_monitor_value(self, logs: Dict) -> Union[float, None]: + """Extracts the monitored value.""" + monitor_value = logs.get(self.monitor) + if monitor_value is None: + logger.warning( + f"Checkpoint is conditioned on metric {self.monitor} which is not available. Available" + + f" metrics are: {','.join(list(logs.keys()))}" + ) + return None + return monitor_value diff --git a/training/trainer/callbacks/early_stopping.py b/training/trainer/callbacks/early_stopping.py new file mode 100644 index 0000000..02b431f --- /dev/null +++ b/training/trainer/callbacks/early_stopping.py @@ -0,0 +1,108 @@ +"""Implements Early stopping for PyTorch model.""" +from typing import Dict, Union + +from loguru import logger +import numpy as np +import torch +from torch import Tensor +from training.trainer.callbacks import Callback + + +class EarlyStopping(Callback): + """Stops training when a monitored metric stops improving.""" + + mode_dict = { + "min": torch.lt, + "max": torch.gt, + } + + def __init__( + self, + monitor: str = "val_loss", + min_delta: float = 0.0, + patience: int = 3, + mode: str = "auto", + ) -> None: + """Initializes the EarlyStopping callback. + + Args: + monitor (str): Description of parameter `monitor`. Defaults to "val_loss". + min_delta (float): Description of parameter `min_delta`. Defaults to 0.0. + patience (int): Description of parameter `patience`. Defaults to 3. + mode (str): Description of parameter `mode`. Defaults to "auto". + + """ + super().__init__() + self.monitor = monitor + self.patience = patience + self.min_delta = torch.tensor(min_delta) + self.mode = mode + self.wait_count = 0 + self.stopped_epoch = 0 + + if mode not in ["auto", "min", "max"]: + logger.warning( + f"EarlyStopping mode {mode} is unkown, fallback to auto mode." + ) + + self.mode = "auto" + + if self.mode == "auto": + if "accuracy" in self.monitor: + self.mode = "max" + else: + self.mode = "min" + logger.debug( + f"EarlyStopping mode set to {self.mode} for monitoring {self.monitor}." + ) + + self.torch_inf = torch.tensor(np.inf) + self.min_delta *= 1 if self.monitor_op == torch.gt else -1 + self.best_score = ( + self.torch_inf if self.monitor_op == torch.lt else -self.torch_inf + ) + + @property + def monitor_op(self) -> float: + """Returns the comparison method.""" + return self.mode_dict[self.mode] + + def on_fit_begin(self) -> Union[torch.lt, torch.gt]: + """Reset the early stopping variables for reuse.""" + self.wait_count = 0 + self.stopped_epoch = 0 + self.best_score = ( + self.torch_inf if self.monitor_op == torch.lt else -self.torch_inf + ) + + def on_epoch_end(self, epoch: int, logs: Dict) -> None: + """Computes the early stop criterion.""" + current = self.get_monitor_value(logs) + if current is None: + return + if self.monitor_op(current - self.min_delta, self.best_score): + self.best_score = current + self.wait_count = 0 + else: + self.wait_count += 1 + if self.wait_count >= self.patience: + self.stopped_epoch = epoch + self.model.stop_training = True + + def on_fit_end(self) -> None: + """Logs if early stopping was used.""" + if self.stopped_epoch > 0: + logger.info( + f"Stopped training at epoch {self.stopped_epoch + 1} with early stopping." + ) + + def get_monitor_value(self, logs: Dict) -> Union[Tensor, None]: + """Extracts the monitor value.""" + monitor_value = logs.get(self.monitor) + if monitor_value is None: + logger.warning( + f"Early stopping is conditioned on metric {self.monitor} which is not available. Available" + + f"metrics are: {','.join(list(logs.keys()))}" + ) + return None + return torch.tensor(monitor_value) diff --git a/training/trainer/callbacks/lr_schedulers.py b/training/trainer/callbacks/lr_schedulers.py new file mode 100644 index 0000000..630c434 --- /dev/null +++ b/training/trainer/callbacks/lr_schedulers.py @@ -0,0 +1,77 @@ +"""Callbacks for learning rate schedulers.""" +from typing import Callable, Dict, List, Optional, Type + +from torch.optim.swa_utils import update_bn +from training.trainer.callbacks import Callback + +from text_recognizer.models import Model + + +class LRScheduler(Callback): + """Generic learning rate scheduler callback.""" + + def __init__(self) -> None: + super().__init__() + + def set_model(self, model: Type[Model]) -> None: + """Sets the model and lr scheduler.""" + self.model = model + self.lr_scheduler = self.model.lr_scheduler["lr_scheduler"] + self.interval = self.model.lr_scheduler["interval"] + + def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: + """Takes a step at the end of every epoch.""" + if self.interval == "epoch": + if "ReduceLROnPlateau" in self.lr_scheduler.__class__.__name__: + self.lr_scheduler.step(logs["val_loss"]) + else: + self.lr_scheduler.step() + + def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: + """Takes a step at the end of every training batch.""" + if self.interval == "step": + self.lr_scheduler.step() + + +class SWA(Callback): + """Stochastic Weight Averaging callback.""" + + def __init__(self) -> None: + """Initializes the callback.""" + super().__init__() + self.lr_scheduler = None + self.interval = None + self.swa_scheduler = None + self.swa_start = None + self.current_epoch = 1 + + def set_model(self, model: Type[Model]) -> None: + """Sets the model and lr scheduler.""" + self.model = model + self.lr_scheduler = self.model.lr_scheduler["lr_scheduler"] + self.interval = self.model.lr_scheduler["interval"] + self.swa_scheduler = self.model.swa_scheduler["swa_scheduler"] + self.swa_start = self.model.swa_scheduler["swa_start"] + + def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: + """Takes a step at the end of every training batch.""" + if epoch > self.swa_start: + self.model.swa_network.update_parameters(self.model.network) + self.swa_scheduler.step() + elif self.interval == "epoch": + self.lr_scheduler.step() + self.current_epoch = epoch + + def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: + """Takes a step at the end of every training batch.""" + if self.current_epoch < self.swa_start and self.interval == "step": + self.lr_scheduler.step() + + def on_fit_end(self) -> None: + """Update batch norm statistics for the swa model at the end of training.""" + if self.model.swa_network: + update_bn( + self.model.val_dataloader(), + self.model.swa_network, + device=self.model.device, + ) diff --git a/training/trainer/callbacks/progress_bar.py b/training/trainer/callbacks/progress_bar.py new file mode 100644 index 0000000..6c4305a --- /dev/null +++ b/training/trainer/callbacks/progress_bar.py @@ -0,0 +1,65 @@ +"""Progress bar callback for the training loop.""" +from typing import Dict, Optional + +from tqdm import tqdm +from training.trainer.callbacks import Callback + + +class ProgressBar(Callback): + """A TQDM progress bar for the training loop.""" + + def __init__(self, epochs: int, log_batch_frequency: int = None) -> None: + """Initializes the tqdm callback.""" + self.epochs = epochs + print(epochs, type(epochs)) + self.log_batch_frequency = log_batch_frequency + self.progress_bar = None + self.val_metrics = {} + + def _configure_progress_bar(self) -> None: + """Configures the tqdm progress bar with custom bar format.""" + self.progress_bar = tqdm( + total=len(self.model.train_dataloader()), + leave=False, + unit="steps", + mininterval=self.log_batch_frequency, + bar_format="{desc} |{bar:32}| {n_fmt}/{total_fmt} ETA: {remaining} {rate_fmt}{postfix}", + ) + + def _key_abbreviations(self, logs: Dict) -> Dict: + """Changes the length of keys, so that the progress bar fits better.""" + + def rename(key: str) -> str: + """Renames accuracy to acc.""" + return key.replace("accuracy", "acc") + + return {rename(key): value for key, value in logs.items()} + + # def on_fit_begin(self) -> None: + # """Creates a tqdm progress bar.""" + # self._configure_progress_bar() + + def on_epoch_begin(self, epoch: int, logs: Optional[Dict]) -> None: + """Updates the description with the current epoch.""" + if epoch == 1: + self._configure_progress_bar() + else: + self.progress_bar.reset() + self.progress_bar.set_description(f"Epoch {epoch}/{self.epochs}") + + def on_epoch_end(self, epoch: int, logs: Dict) -> None: + """At the end of each epoch, the validation metrics are updated to the progress bar.""" + self.val_metrics = logs + self.progress_bar.set_postfix(**self._key_abbreviations(logs)) + self.progress_bar.update() + + def on_train_batch_end(self, batch: int, logs: Dict) -> None: + """Updates the progress bar for each training step.""" + if self.val_metrics: + logs.update(self.val_metrics) + self.progress_bar.set_postfix(**self._key_abbreviations(logs)) + self.progress_bar.update() + + def on_fit_end(self) -> None: + """Closes the tqdm progress bar.""" + self.progress_bar.close() diff --git a/training/trainer/callbacks/wandb_callbacks.py b/training/trainer/callbacks/wandb_callbacks.py new file mode 100644 index 0000000..552a4f4 --- /dev/null +++ b/training/trainer/callbacks/wandb_callbacks.py @@ -0,0 +1,261 @@ +"""Callback for W&B.""" +from typing import Callable, Dict, List, Optional, Type + +import numpy as np +from training.trainer.callbacks import Callback +import wandb + +import text_recognizer.datasets.transforms as transforms +from text_recognizer.models.base import Model + + +class WandbCallback(Callback): + """A custom W&B metric logger for the trainer.""" + + def __init__(self, log_batch_frequency: int = None) -> None: + """Short summary. + + Args: + log_batch_frequency (int): If None, metrics will be logged every epoch. + If set to an integer, callback will log every metrics every log_batch_frequency. + + """ + super().__init__() + self.log_batch_frequency = log_batch_frequency + + def _on_batch_end(self, batch: int, logs: Dict) -> None: + if self.log_batch_frequency and batch % self.log_batch_frequency == 0: + wandb.log(logs, commit=True) + + def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: + """Logs training metrics.""" + if logs is not None: + logs["lr"] = self.model.optimizer.param_groups[0]["lr"] + self._on_batch_end(batch, logs) + + def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: + """Logs validation metrics.""" + if logs is not None: + self._on_batch_end(batch, logs) + + def on_epoch_end(self, epoch: int, logs: Dict) -> None: + """Logs at epoch end.""" + wandb.log(logs, commit=True) + + +class WandbImageLogger(Callback): + """Custom W&B callback for image logging.""" + + def __init__( + self, + example_indices: Optional[List] = None, + num_examples: int = 4, + transform: Optional[bool] = None, + ) -> None: + """Initializes the WandbImageLogger with the model to train. + + Args: + example_indices (Optional[List]): Indices for validation images. Defaults to None. + num_examples (int): Number of random samples to take if example_indices are not specified. Defaults to 4. + transform (Optional[Dict]): Use transform on image or not. Defaults to None. + + """ + + super().__init__() + self.caption = None + self.example_indices = example_indices + self.test_sample_indices = None + self.num_examples = num_examples + self.transform = ( + self._configure_transform(transform) if transform is not None else None + ) + + def _configure_transform(self, transform: Dict) -> Callable: + args = transform["args"] or {} + return getattr(transforms, transform["type"])(**args) + + def set_model(self, model: Type[Model]) -> None: + """Sets the model and extracts validation images from the dataset.""" + self.model = model + self.caption = "Validation Examples" + if self.example_indices is None: + self.example_indices = np.random.randint( + 0, len(self.model.val_dataset), self.num_examples + ) + self.images = self.model.val_dataset.dataset.data[self.example_indices] + self.targets = self.model.val_dataset.dataset.targets[self.example_indices] + self.targets = self.targets.tolist() + + def on_test_begin(self) -> None: + """Get samples from test dataset.""" + self.caption = "Test Examples" + if self.test_sample_indices is None: + self.test_sample_indices = np.random.randint( + 0, len(self.model.test_dataset), self.num_examples + ) + self.images = self.model.test_dataset.data[self.test_sample_indices] + self.targets = self.model.test_dataset.targets[self.test_sample_indices] + self.targets = self.targets.tolist() + + def on_test_end(self) -> None: + """Log test images.""" + self.on_epoch_end(0, {}) + + def on_epoch_end(self, epoch: int, logs: Dict) -> None: + """Get network predictions on validation images.""" + images = [] + for i, image in enumerate(self.images): + image = self.transform(image) if self.transform is not None else image + pred, conf = self.model.predict_on_image(image) + if isinstance(self.targets[i], list): + ground_truth = "".join( + [ + self.model.mapper(int(target_index) - 26) + if target_index > 35 + else self.model.mapper(int(target_index)) + for target_index in self.targets[i] + ] + ).rstrip("_") + else: + ground_truth = self.model.mapper(int(self.targets[i])) + caption = f"Prediction: {pred} Confidence: {conf:.3f} Ground Truth: {ground_truth}" + images.append(wandb.Image(image, caption=caption)) + + wandb.log({f"{self.caption}": images}, commit=False) + + +class WandbSegmentationLogger(Callback): + """Custom W&B callback for image logging.""" + + def __init__( + self, + class_labels: Dict, + example_indices: Optional[List] = None, + num_examples: int = 4, + ) -> None: + """Initializes the WandbImageLogger with the model to train. + + Args: + class_labels (Dict): A dict with int as key and class string as value. + example_indices (Optional[List]): Indices for validation images. Defaults to None. + num_examples (int): Number of random samples to take if example_indices are not specified. Defaults to 4. + + """ + + super().__init__() + self.caption = None + self.class_labels = {int(k): v for k, v in class_labels.items()} + self.example_indices = example_indices + self.test_sample_indices = None + self.num_examples = num_examples + + def set_model(self, model: Type[Model]) -> None: + """Sets the model and extracts validation images from the dataset.""" + self.model = model + self.caption = "Validation Segmentation Examples" + if self.example_indices is None: + self.example_indices = np.random.randint( + 0, len(self.model.val_dataset), self.num_examples + ) + self.images = self.model.val_dataset.dataset.data[self.example_indices] + self.targets = self.model.val_dataset.dataset.targets[self.example_indices] + self.targets = self.targets.tolist() + + def on_test_begin(self) -> None: + """Get samples from test dataset.""" + self.caption = "Test Segmentation Examples" + if self.test_sample_indices is None: + self.test_sample_indices = np.random.randint( + 0, len(self.model.test_dataset), self.num_examples + ) + self.images = self.model.test_dataset.data[self.test_sample_indices] + self.targets = self.model.test_dataset.targets[self.test_sample_indices] + self.targets = self.targets.tolist() + + def on_test_end(self) -> None: + """Log test images.""" + self.on_epoch_end(0, {}) + + def on_epoch_end(self, epoch: int, logs: Dict) -> None: + """Get network predictions on validation images.""" + images = [] + for i, image in enumerate(self.images): + pred_mask = ( + self.model.predict_on_image(image).detach().squeeze(0).cpu().numpy() + ) + gt_mask = np.array(self.targets[i]) + images.append( + wandb.Image( + image, + masks={ + "predictions": { + "mask_data": pred_mask, + "class_labels": self.class_labels, + }, + "ground_truth": { + "mask_data": gt_mask, + "class_labels": self.class_labels, + }, + }, + ) + ) + + wandb.log({f"{self.caption}": images}, commit=False) + + +class WandbReconstructionLogger(Callback): + """Custom W&B callback for image reconstructions logging.""" + + def __init__( + self, example_indices: Optional[List] = None, num_examples: int = 4, + ) -> None: + """Initializes the WandbImageLogger with the model to train. + + Args: + example_indices (Optional[List]): Indices for validation images. Defaults to None. + num_examples (int): Number of random samples to take if example_indices are not specified. Defaults to 4. + + """ + + super().__init__() + self.caption = None + self.example_indices = example_indices + self.test_sample_indices = None + self.num_examples = num_examples + + def set_model(self, model: Type[Model]) -> None: + """Sets the model and extracts validation images from the dataset.""" + self.model = model + self.caption = "Validation Reconstructions Examples" + if self.example_indices is None: + self.example_indices = np.random.randint( + 0, len(self.model.val_dataset), self.num_examples + ) + self.images = self.model.val_dataset.dataset.data[self.example_indices] + + def on_test_begin(self) -> None: + """Get samples from test dataset.""" + self.caption = "Test Reconstructions Examples" + if self.test_sample_indices is None: + self.test_sample_indices = np.random.randint( + 0, len(self.model.test_dataset), self.num_examples + ) + self.images = self.model.test_dataset.data[self.test_sample_indices] + + def on_test_end(self) -> None: + """Log test images.""" + self.on_epoch_end(0, {}) + + def on_epoch_end(self, epoch: int, logs: Dict) -> None: + """Get network predictions on validation images.""" + images = [] + for image in self.images: + reconstructed_image = ( + self.model.predict_on_image(image).detach().squeeze(0).cpu().numpy() + ) + images.append(image) + images.append(reconstructed_image) + + wandb.log( + {f"{self.caption}": [wandb.Image(image) for image in images]}, commit=False, + ) diff --git a/training/trainer/train.py b/training/trainer/train.py new file mode 100644 index 0000000..b770c94 --- /dev/null +++ b/training/trainer/train.py @@ -0,0 +1,325 @@ +"""Training script for PyTorch models.""" + +from pathlib import Path +import time +from typing import Dict, List, Optional, Tuple, Type +import warnings + +from einops import rearrange +from loguru import logger +import numpy as np +import torch +from torch import Tensor +from torch.optim.swa_utils import update_bn +from training.trainer.callbacks import Callback, CallbackList, LRScheduler, SWA +from training.trainer.util import log_val_metric +import wandb + +from text_recognizer.models import Model + + +torch.backends.cudnn.benchmark = True +np.random.seed(4711) +torch.manual_seed(4711) +torch.cuda.manual_seed(4711) + + +warnings.filterwarnings("ignore") + + +class Trainer: + """Trainer for training PyTorch models.""" + + def __init__( + self, + max_epochs: int, + callbacks: List[Type[Callback]], + transformer_model: bool = False, + max_norm: float = 0.0, + freeze_backbone: Optional[int] = None, + ) -> None: + """Initialization of the Trainer. + + Args: + max_epochs (int): The maximum number of epochs in the training loop. + callbacks (CallbackList): List of callbacks to be called. + transformer_model (bool): Transformer model flag, modifies the input to the model. Default is False. + max_norm (float): Max norm for gradient cl:ipping. Defaults to 0.0. + freeze_backbone (Optional[int]): How many epochs to freeze the backbone for. Used when training + Transformers. Default is None. + + """ + # Training arguments. + self.start_epoch = 1 + self.max_epochs = max_epochs + self.callbacks = callbacks + self.freeze_backbone = freeze_backbone + + # Flag for setting callbacks. + self.callbacks_configured = False + + self.transformer_model = transformer_model + + self.max_norm = max_norm + + # Model placeholders + self.model = None + + def _configure_callbacks(self) -> None: + """Instantiate the CallbackList.""" + if not self.callbacks_configured: + # If learning rate schedulers are present, they need to be added to the callbacks. + if self.model.swa_scheduler is not None: + self.callbacks.append(SWA()) + elif self.model.lr_scheduler is not None: + self.callbacks.append(LRScheduler()) + + self.callbacks = CallbackList(self.model, self.callbacks) + + def compute_metrics( + self, output: Tensor, targets: Tensor, loss: Tensor, batch_size: int + ) -> Dict: + """Computes metrics for output and target pairs.""" + # Compute metrics. + loss = loss.detach().float().item() + output = output.detach() + targets = targets.detach() + if self.model.metrics is not None: + metrics = {} + for metric in self.model.metrics: + if metric == "cer" or metric == "wer": + metrics[metric] = self.model.metrics[metric]( + output, + targets, + batch_size, + self.model.mapper(self.model.pad_token), + ) + else: + metrics[metric] = self.model.metrics[metric](output, targets) + else: + metrics = {} + metrics["loss"] = loss + + return metrics + + def training_step(self, batch: int, samples: Tuple[Tensor, Tensor],) -> Dict: + """Performs the training step.""" + # Pass the tensor to the device for computation. + data, targets = samples + data, targets = ( + data.to(self.model.device), + targets.to(self.model.device), + ) + + batch_size = data.shape[0] + + # Placeholder for uxiliary loss. + aux_loss = None + + # Forward pass. + # Get the network prediction. + if self.transformer_model: + if self.freeze_backbone is not None and batch < self.freeze_backbone: + with torch.no_grad(): + image_features = self.model.network.extract_image_features(data) + + if isinstance(image_features, Tuple): + image_features, _ = image_features + + output = self.model.network.decode_image_features( + image_features, targets[:, :-1] + ) + else: + output = self.model.network.forward(data, targets[:, :-1]) + if isinstance(output, Tuple): + output, aux_loss = output + output = rearrange(output, "b t v -> (b t) v") + targets = rearrange(targets[:, 1:], "b t -> (b t)").long() + else: + output = self.model.forward(data) + + if isinstance(output, Tuple): + output, aux_loss = output + targets = data + + # Compute the loss. + loss = self.model.criterion(output, targets) + + if aux_loss is not None: + loss += aux_loss + + # Backward pass. + # Clear the previous gradients. + for p in self.model.network.parameters(): + p.grad = None + + # Compute the gradients. + loss.backward() + + if self.max_norm > 0: + torch.nn.utils.clip_grad_norm_( + self.model.network.parameters(), self.max_norm + ) + + # Perform updates using calculated gradients. + self.model.optimizer.step() + + metrics = self.compute_metrics(output, targets, loss, batch_size) + + return metrics + + def train(self) -> None: + """Runs the training loop for one epoch.""" + # Set model to traning mode. + self.model.train() + + for batch, samples in enumerate(self.model.train_dataloader()): + self.callbacks.on_train_batch_begin(batch) + metrics = self.training_step(batch, samples) + self.callbacks.on_train_batch_end(batch, logs=metrics) + + @torch.no_grad() + def validation_step(self, batch: int, samples: Tuple[Tensor, Tensor],) -> Dict: + """Performs the validation step.""" + + # Pass the tensor to the device for computation. + data, targets = samples + data, targets = ( + data.to(self.model.device), + targets.to(self.model.device), + ) + + batch_size = data.shape[0] + + # Placeholder for uxiliary loss. + aux_loss = None + + # Forward pass. + # Get the network prediction. + # Use SWA if available and using test dataset. + if self.transformer_model: + output = self.model.network.forward(data, targets[:, :-1]) + if isinstance(output, Tuple): + output, aux_loss = output + output = rearrange(output, "b t v -> (b t) v") + targets = rearrange(targets[:, 1:], "b t -> (b t)").long() + else: + output = self.model.forward(data) + + if isinstance(output, Tuple): + output, aux_loss = output + targets = data + + # Compute the loss. + loss = self.model.criterion(output, targets) + + if aux_loss is not None: + loss += aux_loss + + # Compute metrics. + metrics = self.compute_metrics(output, targets, loss, batch_size) + + return metrics + + def validate(self) -> Dict: + """Runs the validation loop for one epoch.""" + # Set model to eval mode. + self.model.eval() + + # Summary for the current eval loop. + summary = [] + + for batch, samples in enumerate(self.model.val_dataloader()): + self.callbacks.on_validation_batch_begin(batch) + metrics = self.validation_step(batch, samples) + self.callbacks.on_validation_batch_end(batch, logs=metrics) + summary.append(metrics) + + # Compute mean of all metrics. + metrics_mean = { + "val_" + metric: np.mean([x[metric] for x in summary]) + for metric in summary[0] + } + + return metrics_mean + + def fit(self, model: Type[Model]) -> None: + """Runs the training and evaluation loop.""" + + # Sets model, loads the data, criterion, and optimizers. + self.model = model + self.model.prepare_data() + self.model.configure_model() + + # Configure callbacks. + self._configure_callbacks() + + # Set start time. + t_start = time.time() + + self.callbacks.on_fit_begin() + + # Run the training loop. + for epoch in range(self.start_epoch, self.max_epochs + 1): + self.callbacks.on_epoch_begin(epoch) + + # Perform one training pass over the training set. + self.train() + + # Evaluate the model on the validation set. + val_metrics = self.validate() + log_val_metric(val_metrics, epoch) + + self.callbacks.on_epoch_end(epoch, logs=val_metrics) + + if self.model.stop_training: + break + + # Calculate the total training time. + t_end = time.time() + t_training = t_end - t_start + + self.callbacks.on_fit_end() + + logger.info(f"Training took {t_training:.2f} s.") + + # "Teardown". + self.model = None + + def test(self, model: Type[Model]) -> Dict: + """Run inference on test data.""" + + # Sets model, loads the data, criterion, and optimizers. + self.model = model + self.model.prepare_data() + self.model.configure_model() + + # Configure callbacks. + self._configure_callbacks() + + self.callbacks.on_test_begin() + + self.model.eval() + + # Check if SWA network is available. + self.model.use_swa_model() + + # Summary for the current test loop. + summary = [] + + for batch, samples in enumerate(self.model.test_dataloader()): + metrics = self.validation_step(batch, samples) + summary.append(metrics) + + self.callbacks.on_test_end() + + # Compute mean of all test metrics. + metrics_mean = { + "test_" + metric: np.mean([x[metric] for x in summary]) + for metric in summary[0] + } + + # "Teardown". + self.model = None + + return metrics_mean diff --git a/training/trainer/util.py b/training/trainer/util.py new file mode 100644 index 0000000..7cf1b45 --- /dev/null +++ b/training/trainer/util.py @@ -0,0 +1,28 @@ +"""Utility functions for training neural networks.""" +from typing import Dict, Optional + +from loguru import logger + + +def log_val_metric(metrics_mean: Dict, epoch: Optional[int] = None) -> None: + """Logging of val metrics to file/terminal.""" + log_str = "Validation metrics " + (f"at epoch {epoch} - " if epoch else " - ") + logger.debug(log_str + " - ".join(f"{k}: {v:.4f}" for k, v in metrics_mean.items())) + + +class RunningAverage: + """Maintains a running average.""" + + def __init__(self) -> None: + """Initializes the parameters.""" + self.steps = 0 + self.total = 0 + + def update(self, val: float) -> None: + """Updates the parameters.""" + self.total += val + self.steps += 1 + + def __call__(self) -> float: + """Computes the running average.""" + return self.total / float(self.steps) |