From e1b504bca41a9793ed7e88ef14f2e2cbd85724f2 Mon Sep 17 00:00:00 2001 From: aktersnurra Date: Tue, 8 Sep 2020 23:14:23 +0200 Subject: IAM datasets implemented. --- src/training/experiments/sample_experiment.yml | 127 +++++++----- src/training/prepare_experiments.py | 4 +- src/training/run_experiment.py | 238 +++++++++++++--------- src/training/run_sweep.py | 86 +++++++- src/training/sweep_emnist.yml | 26 +++ src/training/sweep_emnist_resnet.yml | 50 +++++ src/training/trainer/callbacks/__init__.py | 15 +- src/training/trainer/callbacks/base.py | 78 ------- src/training/trainer/callbacks/checkpoint.py | 95 +++++++++ src/training/trainer/callbacks/lr_schedulers.py | 52 +++++ src/training/trainer/callbacks/progress_bar.py | 19 +- src/training/trainer/callbacks/wandb_callbacks.py | 32 +-- src/training/trainer/train.py | 170 ++++++++++------ src/training/trainer/util.py | 9 + 14 files changed, 686 insertions(+), 315 deletions(-) create mode 100644 src/training/sweep_emnist.yml create mode 100644 src/training/sweep_emnist_resnet.yml create mode 100644 src/training/trainer/callbacks/checkpoint.py (limited to 'src/training') diff --git a/src/training/experiments/sample_experiment.yml b/src/training/experiments/sample_experiment.yml index b00bd5a..17e220e 100644 --- a/src/training/experiments/sample_experiment.yml +++ b/src/training/experiments/sample_experiment.yml @@ -1,17 +1,20 @@ experiment_group: Sample Experiments experiments: - - dataset: EmnistDataset - dataset_args: - sample_to_balance: true - subsample_fraction: null - transform: null - target_transform: null - seed: 4711 - data_loader_args: - splits: [train, val] - shuffle: true - num_workers: 8 - cuda: true + - train_args: + batch_size: 256 + max_epochs: 32 + dataset: + type: EmnistDataset + args: + sample_to_balance: true + subsample_fraction: null + transform: null + target_transform: null + seed: 4711 + train_args: + num_workers: 6 + train_fraction: 0.8 + model: CharacterModel metrics: [accuracy] # network: MLP @@ -19,65 +22,81 @@ experiments: # input_size: 784 # hidden_size: 512 # output_size: 80 - # num_layers: 3 - # dropout_rate: 0 + # num_layers: 5 + # dropout_rate: 0.2 # activation_fn: SELU - network: ResidualNetwork - network_args: - in_channels: 1 - num_classes: 80 - depths: [2, 1] - block_sizes: [96, 32] + network: + type: ResidualNetwork + args: + in_channels: 1 + num_classes: 80 + depths: [2, 2] + block_sizes: [64, 64] + activation: leaky_relu + stn: true + # network: + # type: WideResidualNetwork + # args: + # in_channels: 1 + # num_classes: 80 + # depth: 10 + # num_layers: 3 + # width_factor: 4 + # dropout_rate: 0.2 + # activation: SELU # network: LeNet # network_args: # output_size: 62 # activation_fn: GELU - train_args: - batch_size: 256 - epochs: 32 - criterion: CrossEntropyLoss - criterion_args: - weight: null - ignore_index: -100 - reduction: mean - # optimizer: RMSprop - # optimizer_args: - # lr: 1.e-3 - # alpha: 0.9 - # eps: 1.e-7 - # momentum: 0 - # weight_decay: 0 - # centered: false - optimizer: AdamW - optimizer_args: - lr: 1.e-03 - betas: [0.9, 0.999] - eps: 1.e-08 - # weight_decay: 5.e-4 - amsgrad: false - # lr_scheduler: null - lr_scheduler: OneCycleLR - lr_scheduler_args: - max_lr: 1.e-03 - epochs: 32 - anneal_strategy: linear - callbacks: [Checkpoint, ProgressBar, EarlyStopping, WandbCallback, WandbImageLogger, OneCycleLR] + criterion: + type: CrossEntropyLoss + args: + weight: null + ignore_index: -100 + reduction: mean + optimizer: + type: AdamW + args: + lr: 1.e-02 + betas: [0.9, 0.999] + eps: 1.e-08 + # weight_decay: 5.e-4 + amsgrad: false + # lr_scheduler: + # type: OneCycleLR + # args: + # max_lr: 1.e-03 + # epochs: null + # anneal_strategy: linear + lr_scheduler: + type: CosineAnnealingLR + args: + T_max: null + swa_args: + start: 2 + lr: 5.e-2 + callbacks: [Checkpoint, ProgressBar, WandbCallback, WandbImageLogger, EarlyStopping, SWA] # OneCycleLR] callback_args: Checkpoint: monitor: val_accuracy ProgressBar: - epochs: 32 + epochs: null log_batch_frequency: 100 EarlyStopping: monitor: val_loss min_delta: 0.0 - patience: 3 + patience: 5 mode: min WandbCallback: log_batch_frequency: 10 WandbImageLogger: num_examples: 4 - OneCycleLR: + use_transpose: true + # OneCycleLR: + # null + SWA: null - verbosity: 1 # 0, 1, 2 + verbosity: 0 # 0, 1, 2 resume_experiment: null + test: true + test_metric: test_accuracy diff --git a/src/training/prepare_experiments.py b/src/training/prepare_experiments.py index 4c3f9ba..e00540c 100644 --- a/src/training/prepare_experiments.py +++ b/src/training/prepare_experiments.py @@ -9,14 +9,14 @@ import yaml def run_experiments(experiments_filename: str) -> None: """Run experiment from file.""" - with open(experiments_filename) as f: + with open(experiments_filename, "r") as f: experiments_config = yaml.safe_load(f) num_experiments = len(experiments_config["experiments"]) for index in range(num_experiments): experiment_config = experiments_config["experiments"][index] experiment_config["experiment_group"] = experiments_config["experiment_group"] - cmd = f"python training/run_experiment.py --gpu=-1 --save --experiment_config='{json.dumps(experiment_config)}'" + cmd = f"python training/run_experiment.py --gpu=-1 --save '{json.dumps(experiment_config)}'" print(cmd) diff --git a/src/training/run_experiment.py b/src/training/run_experiment.py index 8c063ff..4317d66 100644 --- a/src/training/run_experiment.py +++ b/src/training/run_experiment.py @@ -6,18 +6,19 @@ import json import os from pathlib import Path import re -from typing import Callable, Dict, Tuple, Type +from typing import Callable, Dict, List, Tuple, Type import click from loguru import logger import torch from tqdm import tqdm from training.gpu_manager import GPUManager -from training.trainer.callbacks import CallbackList +from training.trainer.callbacks import Callback, CallbackList from training.trainer.train import Trainer import wandb import yaml + from text_recognizer.models import Model @@ -37,10 +38,14 @@ def get_level(experiment_config: Dict) -> int: return 10 -def create_experiment_dir(model: Type[Model], experiment_config: Dict) -> Path: +def create_experiment_dir(experiment_config: Dict) -> Path: """Create new experiment.""" EXPERIMENTS_DIRNAME.mkdir(parents=True, exist_ok=True) - experiment_dir = EXPERIMENTS_DIRNAME / model.__name__ + experiment_dir = EXPERIMENTS_DIRNAME / ( + f"{experiment_config['model']}_" + + f"{experiment_config['dataset']['type']}_" + + f"{experiment_config['network']['type']}" + ) if experiment_config["resume_experiment"] is None: experiment = datetime.now().strftime("%m%d_%H%M%S") logger.debug(f"Creating a new experiment called {experiment}") @@ -54,70 +59,89 @@ def create_experiment_dir(model: Type[Model], experiment_config: Dict) -> Path: experiment = experiment_config["resume_experiment"] if not str(experiment_dir / experiment) in available_experiments: raise FileNotFoundError("Experiment does not exist.") - logger.debug(f"Resuming the experiment {experiment}") experiment_dir = experiment_dir / experiment - return experiment_dir + # Create log and model directories. + log_dir = experiment_dir / "log" + model_dir = experiment_dir / "model" + + return experiment_dir, log_dir, model_dir -def check_args(args: Dict) -> Dict: + +def check_args(args: Dict, train_args: Dict) -> Dict: """Checks that the arguments are not None.""" + args = args or {} + + # I just want to set total epochs in train args, instead of changing all parameter. + if "epochs" in args and args["epochs"] is None: + args["epochs"] = train_args["max_epochs"] + + # For CosineAnnealingLR. + if "T_max" in args and args["T_max"] is None: + args["T_max"] = train_args["max_epochs"] + return args or {} def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict]: """Loads all modules and arguments.""" # Import the data loader arguments. - data_loader_args = experiment_config.get("data_loader_args", {}) train_args = experiment_config.get("train_args", {}) - data_loader_args["batch_size"] = train_args["batch_size"] - data_loader_args["dataset"] = experiment_config["dataset"] - data_loader_args["dataset_args"] = experiment_config.get("dataset_args", {}) + + # Load the dataset module. + dataset_args = experiment_config.get("dataset", {}) + dataset_args["train_args"]["batch_size"] = train_args["batch_size"] + datasets_module = importlib.import_module("text_recognizer.datasets") + dataset_ = getattr(datasets_module, dataset_args["type"]) # Import the model module and model arguments. models_module = importlib.import_module("text_recognizer.models") model_class_ = getattr(models_module, experiment_config["model"]) # Import metrics. - metric_fns_ = { - metric: getattr(models_module, metric) - for metric in experiment_config["metrics"] - } + metric_fns_ = ( + { + metric: getattr(models_module, metric) + for metric in experiment_config["metrics"] + } + if experiment_config["metrics"] is not None + else None + ) # Import network module and arguments. network_module = importlib.import_module("text_recognizer.networks") - network_fn_ = getattr(network_module, experiment_config["network"]) - network_args = experiment_config.get("network_args", {}) + network_fn_ = getattr(network_module, experiment_config["network"]["type"]) + network_args = experiment_config["network"].get("args", {}) # Criterion - criterion_ = getattr(torch.nn, experiment_config["criterion"]) - criterion_args = experiment_config.get("criterion_args", {}) + criterion_ = getattr(torch.nn, experiment_config["criterion"]["type"]) + criterion_args = experiment_config["criterion"].get("args", {}) - # Optimizer - optimizer_ = getattr(torch.optim, experiment_config["optimizer"]) - optimizer_args = experiment_config.get("optimizer_args", {}) - - # Callbacks - callback_modules = importlib.import_module("training.trainer.callbacks") - callbacks = [ - getattr(callback_modules, callback)( - **check_args(experiment_config["callback_args"][callback]) - ) - for callback in experiment_config["callbacks"] - ] + # Optimizers + optimizer_ = getattr(torch.optim, experiment_config["optimizer"]["type"]) + optimizer_args = experiment_config["optimizer"].get("args", {}) # Learning rate scheduler + lr_scheduler_ = None + lr_scheduler_args = None if experiment_config["lr_scheduler"] is not None: lr_scheduler_ = getattr( - torch.optim.lr_scheduler, experiment_config["lr_scheduler"] + torch.optim.lr_scheduler, experiment_config["lr_scheduler"]["type"] + ) + lr_scheduler_args = check_args( + experiment_config["lr_scheduler"].get("args", {}), train_args ) - lr_scheduler_args = experiment_config.get("lr_scheduler_args", {}) + + # SWA scheduler. + if "swa_args" in experiment_config: + swa_args = check_args(experiment_config.get("swa_args", {}), train_args) else: - lr_scheduler_ = None - lr_scheduler_args = None + swa_args = None model_args = { - "data_loader_args": data_loader_args, + "dataset": dataset_, + "dataset_args": dataset_args, "metrics": metric_fns_, "network_fn": network_fn_, "network_args": network_args, @@ -127,43 +151,33 @@ def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict] "optimizer_args": optimizer_args, "lr_scheduler": lr_scheduler_, "lr_scheduler_args": lr_scheduler_args, + "swa_args": swa_args, } - return model_class_, model_args, callbacks - + return model_class_, model_args -def run_experiment( - experiment_config: Dict, save_weights: bool, device: str, use_wandb: bool = False -) -> None: - """Runs an experiment.""" - - # Load the modules and model arguments. - model_class_, model_args, callbacks = load_modules_and_arguments(experiment_config) - - # Initializes the model with experiment config. - model = model_class_(**model_args, device=device) - # Instantiate a CallbackList. - callbacks = CallbackList(model, callbacks) - - # Create new experiment. - experiment_dir = create_experiment_dir(model, experiment_config) +def configure_callbacks(experiment_config: Dict, model_dir: Dict) -> CallbackList: + """Configure a callback list for trainer.""" + train_args = experiment_config.get("train_args", {}) - # Create log and model directories. - log_dir = experiment_dir / "log" - model_dir = experiment_dir / "model" + if "Checkpoint" in experiment_config["callback_args"]: + experiment_config["callback_args"]["Checkpoint"]["checkpoint_path"] = model_dir - # Set the model dir to be able to save checkpoints. - model.model_dir = model_dir + # Callbacks + callback_modules = importlib.import_module("training.trainer.callbacks") + callbacks = [ + getattr(callback_modules, callback)( + **check_args(experiment_config["callback_args"][callback], train_args) + ) + for callback in experiment_config["callbacks"] + ] - # Get checkpoint path. - checkpoint_path = model_dir / "last.pt" - if not checkpoint_path.exists(): - checkpoint_path = None + return callbacks - # Make sure the log directory exists. - log_dir.mkdir(parents=True, exist_ok=True) +def configure_logger(experiment_config: Dict, log_dir: Path) -> None: + """Configure the loguru logger for output to terminal and disk.""" # Have to remove default logger to get tqdm to work properly. logger.remove() @@ -176,13 +190,50 @@ def run_experiment( format="{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}", ) - if "cuda" in device: - gpu_index = re.sub("[^0-9]+", "", device) - logger.info( - f"Running experiment with config {experiment_config} on GPU {gpu_index}" - ) - else: - logger.info(f"Running experiment with config {experiment_config} on CPU") + +def save_config(experiment_dir: Path, experiment_config: Dict) -> None: + """Copy config to experiment directory.""" + config_path = experiment_dir / "config.yml" + with open(str(config_path), "w") as f: + yaml.dump(experiment_config, f) + + +def load_from_checkpoint(model: Type[Model], log_dir: Path, model_dir: Path) -> None: + """If checkpoint exists, load model weights and optimizers from checkpoint.""" + # Get checkpoint path. + checkpoint_path = model_dir / "last.pt" + if checkpoint_path.exists(): + logger.info("Loading and resuming training from last checkpoint.") + model.load_checkpoint(checkpoint_path) + + +def run_experiment( + experiment_config: Dict, save_weights: bool, device: str, use_wandb: bool = False +) -> None: + """Runs an experiment.""" + logger.info(f"Experiment config: {json.dumps(experiment_config)}") + + # Create new experiment. + experiment_dir, log_dir, model_dir = create_experiment_dir(experiment_config) + + # Make sure the log/model directory exists. + log_dir.mkdir(parents=True, exist_ok=True) + model_dir.mkdir(parents=True, exist_ok=True) + + # Load the modules and model arguments. + model_class_, model_args = load_modules_and_arguments(experiment_config) + + # Initializes the model with experiment config. + model = model_class_(**model_args, device=device) + + callbacks = configure_callbacks(experiment_config, model_dir) + + # Setup logger. + configure_logger(experiment_config, log_dir) + + # Load from checkpoint if resuming an experiment. + if experiment_config["resume_experiment"] is not None: + load_from_checkpoint(model, log_dir, model_dir) logger.info(f"The class mapping is {model.mapping}") @@ -193,9 +244,6 @@ def run_experiment( # Lets W&B save the model and track the gradients and optional parameters. wandb.watch(model.network) - # PÅ•ints a summary of the network in terminal. - model.summary() - experiment_config["train_args"] = { **DEFAULT_TRAIN_ARGS, **experiment_config.get("train_args", {}), @@ -208,41 +256,41 @@ def run_experiment( experiment_config["device"] = device # Save the config used in the experiment folder. - config_path = experiment_dir / "config.yml" - with open(str(config_path), "w") as f: - yaml.dump(experiment_config, f) + save_config(experiment_dir, experiment_config) - # Train the model. + # Load trainer. trainer = Trainer( - model=model, - model_dir=model_dir, - train_args=experiment_config["train_args"], - callbacks=callbacks, - checkpoint_path=checkpoint_path, + max_epochs=experiment_config["train_args"]["max_epochs"], callbacks=callbacks, ) - trainer.fit() + # Train the model. + trainer.fit(model) - logger.info("Loading checkpoint with the best weights.") - model.load_checkpoint(model_dir / "best.pt") + # Run inference over test set. + if experiment_config["test"]: + logger.info("Loading checkpoint with the best weights.") + model.load_from_checkpoint(model_dir / "best.pt") - score = trainer.validate() + logger.info("Running inference on test set.") + score = trainer.test(model) - logger.info(f"Validation set evaluation: {score}") + logger.info(f"Test set evaluation: {score}") - if use_wandb: - wandb.log({"validation_metric": score["val_accuracy"]}) + if use_wandb: + wandb.log( + { + experiment_config["test_metric"]: score[ + experiment_config["test_metric"] + ] + } + ) if save_weights: model.save_weights(model_dir) @click.command() -@click.option( - "--experiment_config", - type=str, - help='Experiment JSON, e.g. \'{"dataloader": "EmnistDataLoader", "model": "CharacterModel", "network": "mlp"}\'', -) +@click.argument("experiment_config",) @click.option("--gpu", type=int, default=0, help="Provide the index of the GPU to use.") @click.option( "--save", diff --git a/src/training/run_sweep.py b/src/training/run_sweep.py index 5c5322a..a578592 100644 --- a/src/training/run_sweep.py +++ b/src/training/run_sweep.py @@ -2,7 +2,91 @@ from ast import literal_eval import json import os +from pathlib import Path import signal import subprocess # nosec import sys -from typing import Tuple +from typing import Dict, List, Tuple + +import click +import yaml + +EXPERIMENTS_DIRNAME = Path(__file__).parents[0].resolve() / "experiments" + + +def load_config() -> Dict: + """Load base hyperparameter config.""" + with open(str(EXPERIMENTS_DIRNAME / "default_config_emnist.yml"), "r") as f: + default_config = yaml.safe_load(f) + return default_config + + +def args_to_json( + default_config: dict, preserve_args: tuple = ("gpu", "save") +) -> Tuple[dict, list]: + """Convert command line arguments to nested config values. + + i.e. run_sweep.py --dataset_args.foo=1.7 + { + "dataset_args": { + "foo": 1.7 + } + } + + Args: + default_config (dict): The base config used for every experiment. + preserve_args (tuple): Arguments preserved for all runs. Defaults to ("gpu", "save"). + + Returns: + Tuple[dict, list]: Tuple of config dictionary and list of arguments. + + """ + + args = [] + config = default_config.copy() + key, val = None, None + for arg in sys.argv[1:]: + if "=" in arg: + key, val = arg.split("=") + elif key: + val = arg + else: + key = arg + if key and val: + parsed_key = key.lstrip("-").split(".") + if parsed_key[0] in preserve_args: + args.append("--{}={}".format(parsed_key[0], val)) + else: + nested = config + for level in parsed_key[:-1]: + nested[level] = config.get(level, {}) + nested = nested[level] + try: + # Convert numerics to floats / ints + val = literal_eval(val) + except ValueError: + pass + nested[parsed_key[-1]] = val + key, val = None, None + return config, args + + +def main() -> None: + """Runs a W&B sweep.""" + default_config = load_config() + config, args = args_to_json(default_config) + env = { + k: v for k, v in os.environ.items() if k not in ("WANDB_PROGRAM", "WANDB_ARGS") + } + # pylint: disable=subprocess-popen-preexec-fn + run = subprocess.Popen( + ["python", "training/run_experiment.py", *args, json.dumps(config)], + env=env, + preexec_fn=os.setsid, + ) # nosec + signal.signal(signal.SIGTERM, lambda *args: run.terminate()) + run.wait() + + +if __name__ == "__main__": + main() diff --git a/src/training/sweep_emnist.yml b/src/training/sweep_emnist.yml new file mode 100644 index 0000000..48d7261 --- /dev/null +++ b/src/training/sweep_emnist.yml @@ -0,0 +1,26 @@ +program: training/run_sweep.py +method: bayes +metric: + name: val_loss + goal: minimize +parameters: + dataset: + value: EmnistDataset + model: + value: CharacterModel + network: + value: MLP + network_args.hidden_size: + values: [128, 256] + network_args.dropout_rate: + values: [0.2, 0.4] + network_args.num_layers: + values: [3, 6] + optimizer_args.lr: + values: [1.e-1, 1.e-5] + lr_scheduler_args.max_lr: + values: [1.0e-1, 1.0e-5] + train_args.batch_size: + values: [64, 128] + train_args.epochs: + value: 5 diff --git a/src/training/sweep_emnist_resnet.yml b/src/training/sweep_emnist_resnet.yml new file mode 100644 index 0000000..19a3040 --- /dev/null +++ b/src/training/sweep_emnist_resnet.yml @@ -0,0 +1,50 @@ +program: training/run_sweep.py +method: bayes +metric: + name: val_accuracy + goal: maximize +parameters: + dataset: + value: EmnistDataset + model: + value: CharacterModel + network: + value: ResidualNetwork + network_args.block_sizes: + distribution: q_uniform + min: 16 + max: 256 + q: 8 + network_args.depths: + distribution: int_uniform + min: 1 + max: 3 + network_args.levels: + distribution: int_uniform + min: 1 + max: 2 + network_args.activation: + distribution: categorical + values: + - gelu + - leaky_relu + - relu + - selu + optimizer_args.lr: + distribution: uniform + min: 1.e-5 + max: 1.e-1 + lr_scheduler_args.max_lr: + distribution: uniform + min: 1.e-5 + max: 1.e-1 + train_args.batch_size: + distribution: q_uniform + min: 32 + max: 256 + q: 8 + train_args.epochs: + value: 5 +early_terminate: + type: hyperband + min_iter: 2 diff --git a/src/training/trainer/callbacks/__init__.py b/src/training/trainer/callbacks/__init__.py index 5942276..c81e4bf 100644 --- a/src/training/trainer/callbacks/__init__.py +++ b/src/training/trainer/callbacks/__init__.py @@ -1,7 +1,16 @@ """The callback modules used in the training script.""" -from .base import Callback, CallbackList, Checkpoint +from .base import Callback, CallbackList +from .checkpoint import Checkpoint from .early_stopping import EarlyStopping -from .lr_schedulers import CyclicLR, MultiStepLR, OneCycleLR, ReduceLROnPlateau, StepLR +from .lr_schedulers import ( + CosineAnnealingLR, + CyclicLR, + MultiStepLR, + OneCycleLR, + ReduceLROnPlateau, + StepLR, + SWA, +) from .progress_bar import ProgressBar from .wandb_callbacks import WandbCallback, WandbImageLogger @@ -9,6 +18,7 @@ __all__ = [ "Callback", "CallbackList", "Checkpoint", + "CosineAnnealingLR", "EarlyStopping", "WandbCallback", "WandbImageLogger", @@ -18,4 +28,5 @@ __all__ = [ "ProgressBar", "ReduceLROnPlateau", "StepLR", + "SWA", ] diff --git a/src/training/trainer/callbacks/base.py b/src/training/trainer/callbacks/base.py index 8df94f3..8c7b085 100644 --- a/src/training/trainer/callbacks/base.py +++ b/src/training/trainer/callbacks/base.py @@ -168,81 +168,3 @@ class CallbackList: def __iter__(self) -> iter: """Iter function for callback list.""" return iter(self._callbacks) - - -class Checkpoint(Callback): - """Saving model parameters at the end of each epoch.""" - - mode_dict = { - "min": torch.lt, - "max": torch.gt, - } - - def __init__( - self, monitor: str = "accuracy", mode: str = "auto", min_delta: float = 0.0 - ) -> None: - """Monitors a quantity that will allow us to determine the best model weights. - - Args: - monitor (str): Name of the quantity to monitor. Defaults to "accuracy". - mode (str): Description of parameter `mode`. Defaults to "auto". - min_delta (float): Description of parameter `min_delta`. Defaults to 0.0. - - """ - super().__init__() - self.monitor = monitor - self.mode = mode - self.min_delta = torch.tensor(min_delta) - - if mode not in ["auto", "min", "max"]: - logger.warning(f"Checkpoint mode {mode} is unkown, fallback to auto mode.") - - self.mode = "auto" - - if self.mode == "auto": - if "accuracy" in self.monitor: - self.mode = "max" - else: - self.mode = "min" - logger.debug( - f"Checkpoint mode set to {self.mode} for monitoring {self.monitor}." - ) - - torch_inf = torch.tensor(np.inf) - self.min_delta *= 1 if self.monitor_op == torch.gt else -1 - self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf - - @property - def monitor_op(self) -> float: - """Returns the comparison method.""" - return self.mode_dict[self.mode] - - def on_epoch_end(self, epoch: int, logs: Dict) -> None: - """Saves a checkpoint for the network parameters. - - Args: - epoch (int): The current epoch. - logs (Dict): The log containing the monitored metrics. - - """ - current = self.get_monitor_value(logs) - if current is None: - return - if self.monitor_op(current - self.min_delta, self.best_score): - self.best_score = current - is_best = True - else: - is_best = False - - self.model.save_checkpoint(is_best, epoch, self.monitor) - - def get_monitor_value(self, logs: Dict) -> Union[float, None]: - """Extracts the monitored value.""" - monitor_value = logs.get(self.monitor) - if monitor_value is None: - logger.warning( - f"Checkpoint is conditioned on metric {self.monitor} which is not available. Available" - + f"metrics are: {','.join(list(logs.keys()))}" - ) - return None - return monitor_value diff --git a/src/training/trainer/callbacks/checkpoint.py b/src/training/trainer/callbacks/checkpoint.py new file mode 100644 index 0000000..6fe06d3 --- /dev/null +++ b/src/training/trainer/callbacks/checkpoint.py @@ -0,0 +1,95 @@ +"""Callback checkpoint for training models.""" +from enum import Enum +from pathlib import Path +from typing import Callable, Dict, List, Optional, Type, Union + +from loguru import logger +import numpy as np +import torch +from training.trainer.callbacks import Callback + +from text_recognizer.models import Model + + +class Checkpoint(Callback): + """Saving model parameters at the end of each epoch.""" + + mode_dict = { + "min": torch.lt, + "max": torch.gt, + } + + def __init__( + self, + checkpoint_path: Path, + monitor: str = "accuracy", + mode: str = "auto", + min_delta: float = 0.0, + ) -> None: + """Monitors a quantity that will allow us to determine the best model weights. + + Args: + checkpoint_path (Path): Path to the experiment with the checkpoint. + monitor (str): Name of the quantity to monitor. Defaults to "accuracy". + mode (str): Description of parameter `mode`. Defaults to "auto". + min_delta (float): Description of parameter `min_delta`. Defaults to 0.0. + + """ + super().__init__() + self.checkpoint_path = checkpoint_path + self.monitor = monitor + self.mode = mode + self.min_delta = torch.tensor(min_delta) + + if mode not in ["auto", "min", "max"]: + logger.warning(f"Checkpoint mode {mode} is unkown, fallback to auto mode.") + + self.mode = "auto" + + if self.mode == "auto": + if "accuracy" in self.monitor: + self.mode = "max" + else: + self.mode = "min" + logger.debug( + f"Checkpoint mode set to {self.mode} for monitoring {self.monitor}." + ) + + torch_inf = torch.tensor(np.inf) + self.min_delta *= 1 if self.monitor_op == torch.gt else -1 + self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf + + @property + def monitor_op(self) -> float: + """Returns the comparison method.""" + return self.mode_dict[self.mode] + + def on_epoch_end(self, epoch: int, logs: Dict) -> None: + """Saves a checkpoint for the network parameters. + + Args: + epoch (int): The current epoch. + logs (Dict): The log containing the monitored metrics. + + """ + current = self.get_monitor_value(logs) + if current is None: + return + if self.monitor_op(current - self.min_delta, self.best_score): + self.best_score = current + is_best = True + else: + is_best = False + + self.model.save_checkpoint(self.checkpoint_path, is_best, epoch, self.monitor) + + def get_monitor_value(self, logs: Dict) -> Union[float, None]: + """Extracts the monitored value.""" + monitor_value = logs.get(self.monitor) + if monitor_value is None: + logger.warning( + f"Checkpoint is conditioned on metric {self.monitor} which is not available. Available" + + f" metrics are: {','.join(list(logs.keys()))}" + ) + return None + return monitor_value diff --git a/src/training/trainer/callbacks/lr_schedulers.py b/src/training/trainer/callbacks/lr_schedulers.py index ba2226a..bb41d2d 100644 --- a/src/training/trainer/callbacks/lr_schedulers.py +++ b/src/training/trainer/callbacks/lr_schedulers.py @@ -1,6 +1,7 @@ """Callbacks for learning rate schedulers.""" from typing import Callable, Dict, List, Optional, Type +from torch.optim.swa_utils import update_bn from training.trainer.callbacks import Callback from text_recognizer.models import Model @@ -95,3 +96,54 @@ class OneCycleLR(Callback): def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: """Takes a step at the end of every training batch.""" self.lr_scheduler.step() + + +class CosineAnnealingLR(Callback): + """Callback for Cosine Annealing.""" + + def __init__(self) -> None: + """Initializes the callback.""" + super().__init__() + self.lr_scheduler = None + + def set_model(self, model: Type[Model]) -> None: + """Sets the model and lr scheduler.""" + self.model = model + self.lr_scheduler = self.model.lr_scheduler + + def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: + """Takes a step at the end of every epoch.""" + self.lr_scheduler.step() + + +class SWA(Callback): + """Stochastic Weight Averaging callback.""" + + def __init__(self) -> None: + """Initializes the callback.""" + super().__init__() + self.swa_scheduler = None + + def set_model(self, model: Type[Model]) -> None: + """Sets the model and lr scheduler.""" + self.model = model + self.swa_start = self.model.swa_start + self.swa_scheduler = self.model.lr_scheduler + self.lr_scheduler = self.model.lr_scheduler + + def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: + """Takes a step at the end of every training batch.""" + if epoch > self.swa_start: + self.model.swa_network.update_parameters(self.model.network) + self.swa_scheduler.step() + else: + self.lr_scheduler.step() + + def on_fit_end(self) -> None: + """Update batch norm statistics for the swa model at the end of training.""" + if self.model.swa_network: + update_bn( + self.model.val_dataloader(), + self.model.swa_network, + device=self.model.device, + ) diff --git a/src/training/trainer/callbacks/progress_bar.py b/src/training/trainer/callbacks/progress_bar.py index 1970747..7829fa0 100644 --- a/src/training/trainer/callbacks/progress_bar.py +++ b/src/training/trainer/callbacks/progress_bar.py @@ -18,11 +18,11 @@ class ProgressBar(Callback): def _configure_progress_bar(self) -> None: """Configures the tqdm progress bar with custom bar format.""" self.progress_bar = tqdm( - total=len(self.model.data_loaders["train"]), - leave=True, - unit="step", + total=len(self.model.train_dataloader()), + leave=False, + unit="steps", mininterval=self.log_batch_frequency, - bar_format="{desc} |{bar:30}| {n_fmt}/{total_fmt} ETA: {remaining} {rate_fmt}{postfix}", + bar_format="{desc} |{bar:32}| {n_fmt}/{total_fmt} ETA: {remaining} {rate_fmt}{postfix}", ) def _key_abbreviations(self, logs: Dict) -> Dict: @@ -34,13 +34,16 @@ class ProgressBar(Callback): return {rename(key): value for key, value in logs.items()} - def on_fit_begin(self) -> None: - """Creates a tqdm progress bar.""" - self._configure_progress_bar() + # def on_fit_begin(self) -> None: + # """Creates a tqdm progress bar.""" + # self._configure_progress_bar() def on_epoch_begin(self, epoch: int, logs: Optional[Dict]) -> None: """Updates the description with the current epoch.""" - self.progress_bar.reset() + if epoch == 1: + self._configure_progress_bar() + else: + self.progress_bar.reset() self.progress_bar.set_description(f"Epoch {epoch}/{self.epochs}") def on_epoch_end(self, epoch: int, logs: Dict) -> None: diff --git a/src/training/trainer/callbacks/wandb_callbacks.py b/src/training/trainer/callbacks/wandb_callbacks.py index e44c745..6643a44 100644 --- a/src/training/trainer/callbacks/wandb_callbacks.py +++ b/src/training/trainer/callbacks/wandb_callbacks.py @@ -2,7 +2,8 @@ from typing import Callable, Dict, List, Optional, Type import numpy as np -from torchvision.transforms import Compose, ToTensor +import torch +from torchvision.transforms import ToTensor from training.trainer.callbacks import Callback import wandb @@ -50,43 +51,48 @@ class WandbImageLogger(Callback): self, example_indices: Optional[List] = None, num_examples: int = 4, - transfroms: Optional[Callable] = None, + use_transpose: Optional[bool] = False, ) -> None: """Initializes the WandbImageLogger with the model to train. Args: example_indices (Optional[List]): Indices for validation images. Defaults to None. num_examples (int): Number of random samples to take if example_indices are not specified. Defaults to 4. - transfroms (Optional[Callable]): Transforms to use on the validation images, e.g. transpose. Defaults to - None. + use_transpose (Optional[bool]): Use transpose on image or not. Defaults to False. """ super().__init__() self.example_indices = example_indices self.num_examples = num_examples - self.transfroms = transfroms - if self.transfroms is None: - self.transforms = Compose([Transpose()]) + self.transpose = Transpose() if use_transpose else None def set_model(self, model: Type[Model]) -> None: """Sets the model and extracts validation images from the dataset.""" self.model = model - data_loader = self.model.data_loaders["val"] if self.example_indices is None: self.example_indices = np.random.randint( - 0, len(data_loader.dataset.data), self.num_examples + 0, len(self.model.val_dataset), self.num_examples ) - self.val_images = data_loader.dataset.data[self.example_indices] - self.val_targets = data_loader.dataset.targets[self.example_indices].numpy() + self.val_images = self.model.val_dataset.dataset.data[self.example_indices] + self.val_targets = self.model.val_dataset.dataset.targets[self.example_indices] + self.val_targets = self.val_targets.tolist() def on_epoch_end(self, epoch: int, logs: Dict) -> None: """Get network predictions on validation images.""" images = [] for i, image in enumerate(self.val_images): - image = self.transforms(image) + image = self.transpose(image) if self.transpose is not None else image pred, conf = self.model.predict_on_image(image) - ground_truth = self.model.mapper(int(self.val_targets[i])) + if isinstance(self.val_targets[i], list): + ground_truth = "".join( + [ + self.model.mapper(int(target_index)) + for target_index in self.val_targets[i] + ] + ).rstrip("_") + else: + ground_truth = self.val_targets[i] caption = f"Prediction: {pred} Confidence: {conf:.3f} Ground Truth: {ground_truth}" images.append(wandb.Image(image, caption=caption)) diff --git a/src/training/trainer/train.py b/src/training/trainer/train.py index a75ae8f..b240157 100644 --- a/src/training/trainer/train.py +++ b/src/training/trainer/train.py @@ -8,8 +8,9 @@ from loguru import logger import numpy as np import torch from torch import Tensor +from torch.optim.swa_utils import update_bn from training.trainer.callbacks import Callback, CallbackList -from training.trainer.util import RunningAverage +from training.trainer.util import log_val_metric, RunningAverage import wandb from text_recognizer.models import Model @@ -24,37 +25,55 @@ torch.cuda.manual_seed(4711) class Trainer: """Trainer for training PyTorch models.""" - def __init__( - self, - model: Type[Model], - model_dir: Path, - train_args: Dict, - callbacks: CallbackList, - checkpoint_path: Optional[Path] = None, - ) -> None: + # TODO: proper add teardown? + + def __init__(self, max_epochs: int, callbacks: List[Type[Callback]],) -> None: """Initialization of the Trainer. Args: - model (Type[Model]): A model object. - model_dir (Path): Path to the model directory. - train_args (Dict): The training arguments. + max_epochs (int): The maximum number of epochs in the training loop. callbacks (CallbackList): List of callbacks to be called. - checkpoint_path (Optional[Path]): The path to a previously trained model. Defaults to None. """ - self.model = model - self.model_dir = model_dir - self.checkpoint_path = checkpoint_path + # Training arguments. self.start_epoch = 1 - self.epochs = train_args["epochs"] + self.max_epochs = max_epochs self.callbacks = callbacks - if self.checkpoint_path is not None: - self.start_epoch = self.model.load_checkpoint(self.checkpoint_path) + # Flag for setting callbacks. + self.callbacks_configured = False + + # Model placeholders + self.model = None + + def _configure_callbacks(self) -> None: + if not self.callbacks_configured: + # Instantiate a CallbackList. + self.callbacks = CallbackList(self.model, self.callbacks) + + def compute_metrics( + self, + output: Tensor, + targets: Tensor, + loss: Tensor, + loss_avg: Type[RunningAverage], + ) -> Dict: + """Computes metrics for output and target pairs.""" + # Compute metrics. + loss = loss.detach().float().item() + loss_avg.update(loss) + output = output.detach() + targets = targets.detach() + if self.model.metrics is not None: + metrics = { + metric: self.model.metrics[metric](output, targets) + for metric in self.model.metrics + } + else: + metrics = {} + metrics["loss"] = loss - # Parse the name of the experiment. - experiment_dir = str(self.model_dir.parents[1]).split("/") - self.experiment_name = experiment_dir[-2] + "/" + experiment_dir[-1] + return metrics def training_step( self, @@ -75,11 +94,12 @@ class Trainer: output = self.model.network(data) # Compute the loss. - loss = self.model.criterion(output, targets) + loss = self.model.loss_fn(output, targets) # Backward pass. # Clear the previous gradients. - self.model.optimizer.zero_grad() + for p in self.model.network.parameters(): + p.grad = None # Compute the gradients. loss.backward() @@ -87,15 +107,8 @@ class Trainer: # Perform updates using calculated gradients. self.model.optimizer.step() - # Compute metrics. - loss_avg.update(loss.item()) - output = output.data.cpu() - targets = targets.data.cpu() - metrics = { - metric: self.model.metrics[metric](output, targets) - for metric in self.model.metrics - } - metrics["loss"] = loss_avg() + metrics = self.compute_metrics(output, targets, loss, loss_avg) + return metrics def train(self) -> None: @@ -106,9 +119,7 @@ class Trainer: # Running average for the loss. loss_avg = RunningAverage() - data_loader = self.model.data_loaders["train"] - - for batch, samples in enumerate(data_loader): + for batch, samples in enumerate(self.model.train_dataloader()): self.callbacks.on_train_batch_begin(batch) metrics = self.training_step(batch, samples, loss_avg) self.callbacks.on_train_batch_end(batch, logs=metrics) @@ -119,6 +130,7 @@ class Trainer: batch: int, samples: Tuple[Tensor, Tensor], loss_avg: Type[RunningAverage], + use_swa: bool = False, ) -> Dict: """Performs the validation step.""" # Pass the tensor to the device for computation. @@ -130,44 +142,32 @@ class Trainer: # Forward pass. # Get the network prediction. - output = self.model.network(data) + # Use SWA if available and using test dataset. + if use_swa and self.model.swa_network is None: + output = self.model.swa_network(data) + else: + output = self.model.network(data) # Compute the loss. - loss = self.model.criterion(output, targets) + loss = self.model.loss_fn(output, targets) # Compute metrics. - loss_avg.update(loss.item()) - output = output.data.cpu() - targets = targets.data.cpu() - metrics = { - metric: self.model.metrics[metric](output, targets) - for metric in self.model.metrics - } - metrics["loss"] = loss.item() + metrics = self.compute_metrics(output, targets, loss, loss_avg) return metrics - def _log_val_metric(self, metrics_mean: Dict, epoch: Optional[int] = None) -> None: - log_str = "Validation metrics " + (f"at epoch {epoch} - " if epoch else " - ") - logger.debug( - log_str + " - ".join(f"{k}: {v:.4f}" for k, v in metrics_mean.items()) - ) - - def validate(self, epoch: Optional[int] = None) -> Dict: + def validate(self) -> Dict: """Runs the validation loop for one epoch.""" # Set model to eval mode. self.model.eval() - # Running average for the loss. - data_loader = self.model.data_loaders["val"] - # Running average for the loss. loss_avg = RunningAverage() # Summary for the current eval loop. summary = [] - for batch, samples in enumerate(data_loader): + for batch, samples in enumerate(self.model.val_dataloader()): self.callbacks.on_validation_batch_begin(batch) metrics = self.validation_step(batch, samples, loss_avg) self.callbacks.on_validation_batch_end(batch, logs=metrics) @@ -178,14 +178,19 @@ class Trainer: "val_" + metric: np.mean([x[metric] for x in summary]) for metric in summary[0] } - self._log_val_metric(metrics_mean, epoch) return metrics_mean - def fit(self) -> None: + def fit(self, model: Type[Model]) -> None: """Runs the training and evaluation loop.""" - logger.debug(f"Running an experiment called {self.experiment_name}.") + # Sets model, loads the data, criterion, and optimizers. + self.model = model + self.model.prepare_data() + self.model.configure_model() + + # Configure callbacks. + self._configure_callbacks() # Set start time. t_start = time.time() @@ -193,14 +198,15 @@ class Trainer: self.callbacks.on_fit_begin() # Run the training loop. - for epoch in range(self.start_epoch, self.epochs + 1): + for epoch in range(self.start_epoch, self.max_epochs + 1): self.callbacks.on_epoch_begin(epoch) # Perform one training pass over the training set. self.train() # Evaluate the model on the validation set. - val_metrics = self.validate(epoch) + val_metrics = self.validate() + log_val_metric(val_metrics, epoch) self.callbacks.on_epoch_end(epoch, logs=val_metrics) @@ -214,3 +220,43 @@ class Trainer: self.callbacks.on_fit_end() logger.info(f"Training took {t_training:.2f} s.") + + # "Teardown". + self.model = None + + def test(self, model: Type[Model]) -> Dict: + """Run inference on test data.""" + + # Sets model, loads the data, criterion, and optimizers. + self.model = model + self.model.prepare_data() + self.model.configure_model() + + # Configure callbacks. + self._configure_callbacks() + + self.model.eval() + + # Check if SWA network is available. + use_swa = True if self.model.swa_network is not None else False + + # Running average for the loss. + loss_avg = RunningAverage() + + # Summary for the current test loop. + summary = [] + + for batch, samples in enumerate(self.model.test_dataloader()): + metrics = self.validation_step(batch, samples, loss_avg, use_swa) + summary.append(metrics) + + # Compute mean of all test metrics. + metrics_mean = { + "test_" + metric: np.mean([x[metric] for x in summary]) + for metric in summary[0] + } + + # "Teardown". + self.model = None + + return metrics_mean diff --git a/src/training/trainer/util.py b/src/training/trainer/util.py index 132b2dc..7cf1b45 100644 --- a/src/training/trainer/util.py +++ b/src/training/trainer/util.py @@ -1,4 +1,13 @@ """Utility functions for training neural networks.""" +from typing import Dict, Optional + +from loguru import logger + + +def log_val_metric(metrics_mean: Dict, epoch: Optional[int] = None) -> None: + """Logging of val metrics to file/terminal.""" + log_str = "Validation metrics " + (f"at epoch {epoch} - " if epoch else " - ") + logger.debug(log_str + " - ".join(f"{k}: {v:.4f}" for k, v in metrics_mean.items())) class RunningAverage: -- cgit v1.2.3-70-g09d2