From 1f459ba19422593de325983040e176f97cf4ffc0 Mon Sep 17 00:00:00 2001 From: aktersnurra Date: Thu, 20 Aug 2020 22:18:35 +0200 Subject: A lot of stuff working :D. ResNet implemented! --- src/training/callbacks/__init__.py | 19 -- src/training/callbacks/base.py | 240 -------------------- src/training/callbacks/early_stopping.py | 107 --------- src/training/callbacks/lr_schedulers.py | 97 -------- src/training/callbacks/wandb_callbacks.py | 93 -------- src/training/experiments/sample_experiment.yml | 37 +-- src/training/population_based_training/__init__.py | 1 - .../population_based_training.py | 1 - src/training/prepare_experiments.py | 6 +- src/training/run_experiment.py | 19 +- src/training/train.py | 249 --------------------- src/training/trainer/__init__.py | 2 + src/training/trainer/callbacks/__init__.py | 21 ++ src/training/trainer/callbacks/base.py | 248 ++++++++++++++++++++ src/training/trainer/callbacks/early_stopping.py | 108 +++++++++ src/training/trainer/callbacks/lr_schedulers.py | 97 ++++++++ src/training/trainer/callbacks/progress_bar.py | 61 +++++ src/training/trainer/callbacks/wandb_callbacks.py | 93 ++++++++ .../trainer/population_based_training/__init__.py | 1 + .../population_based_training.py | 1 + src/training/trainer/train.py | 216 ++++++++++++++++++ src/training/trainer/util.py | 19 ++ src/training/util.py | 19 -- 23 files changed, 906 insertions(+), 849 deletions(-) delete mode 100644 src/training/callbacks/__init__.py delete mode 100644 src/training/callbacks/base.py delete mode 100644 src/training/callbacks/early_stopping.py delete mode 100644 src/training/callbacks/lr_schedulers.py delete mode 100644 src/training/callbacks/wandb_callbacks.py delete mode 100644 src/training/population_based_training/__init__.py delete mode 100644 src/training/population_based_training/population_based_training.py delete mode 100644 src/training/train.py create mode 100644 src/training/trainer/__init__.py create mode 100644 src/training/trainer/callbacks/__init__.py create mode 100644 src/training/trainer/callbacks/base.py create mode 100644 src/training/trainer/callbacks/early_stopping.py create mode 100644 src/training/trainer/callbacks/lr_schedulers.py create mode 100644 src/training/trainer/callbacks/progress_bar.py create mode 100644 src/training/trainer/callbacks/wandb_callbacks.py create mode 100644 src/training/trainer/population_based_training/__init__.py create mode 100644 src/training/trainer/population_based_training/population_based_training.py create mode 100644 src/training/trainer/train.py create mode 100644 src/training/trainer/util.py delete mode 100644 src/training/util.py (limited to 'src/training') diff --git a/src/training/callbacks/__init__.py b/src/training/callbacks/__init__.py deleted file mode 100644 index fbcc285..0000000 --- a/src/training/callbacks/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -"""The callback modules used in the training script.""" -from .base import Callback, CallbackList, Checkpoint -from .early_stopping import EarlyStopping -from .lr_schedulers import CyclicLR, MultiStepLR, OneCycleLR, ReduceLROnPlateau, StepLR -from .wandb_callbacks import WandbCallback, WandbImageLogger - -__all__ = [ - "Callback", - "CallbackList", - "Checkpoint", - "EarlyStopping", - "WandbCallback", - "WandbImageLogger", - "CyclicLR", - "MultiStepLR", - "OneCycleLR", - "ReduceLROnPlateau", - "StepLR", -] diff --git a/src/training/callbacks/base.py b/src/training/callbacks/base.py deleted file mode 100644 index e0d91e6..0000000 --- a/src/training/callbacks/base.py +++ /dev/null @@ -1,240 +0,0 @@ -"""Metaclass for callback functions.""" - -from enum import Enum -from typing import Callable, Dict, List, Type, Union - -from loguru import logger -import numpy as np -import torch - -from text_recognizer.models import Model - - -class ModeKeys: - """Mode keys for CallbackList.""" - - TRAIN = "train" - VALIDATION = "validation" - - -class Callback: - """Metaclass for callbacks used in training.""" - - def __init__(self) -> None: - """Initializes the Callback instance.""" - self.model = None - - def set_model(self, model: Type[Model]) -> None: - """Set the model.""" - self.model = model - - def on_fit_begin(self) -> None: - """Called when fit begins.""" - pass - - def on_fit_end(self) -> None: - """Called when fit ends.""" - pass - - def on_epoch_begin(self, epoch: int, logs: Dict = {}) -> None: - """Called at the beginning of an epoch. Only used in training mode.""" - pass - - def on_epoch_end(self, epoch: int, logs: Dict = {}) -> None: - """Called at the end of an epoch. Only used in training mode.""" - pass - - def on_train_batch_begin(self, batch: int, logs: Dict = {}) -> None: - """Called at the beginning of an epoch.""" - pass - - def on_train_batch_end(self, batch: int, logs: Dict = {}) -> None: - """Called at the end of an epoch.""" - pass - - def on_validation_batch_begin(self, batch: int, logs: Dict = {}) -> None: - """Called at the beginning of an epoch.""" - pass - - def on_validation_batch_end(self, batch: int, logs: Dict = {}) -> None: - """Called at the end of an epoch.""" - pass - - -class CallbackList: - """Container for abstracting away callback calls.""" - - mode_keys = ModeKeys() - - def __init__(self, model: Type[Model], callbacks: List[Callback] = None) -> None: - """Container for `Callback` instances. - - This object wraps a list of `Callback` instances and allows them all to be - called via a single end point. - - Args: - model (Type[Model]): A `Model` instance. - callbacks (List[Callback]): List of `Callback` instances. Defaults to None. - - """ - - self._callbacks = callbacks or [] - if model: - self.set_model(model) - - def set_model(self, model: Type[Model]) -> None: - """Set the model for all callbacks.""" - self.model = model - for callback in self._callbacks: - callback.set_model(model=self.model) - - def append(self, callback: Type[Callback]) -> None: - """Append new callback to callback list.""" - self.callbacks.append(callback) - - def on_fit_begin(self) -> None: - """Called when fit begins.""" - for callback in self._callbacks: - callback.on_fit_begin() - - def on_fit_end(self) -> None: - """Called when fit ends.""" - for callback in self._callbacks: - callback.on_fit_end() - - def on_epoch_begin(self, epoch: int, logs: Dict = {}) -> None: - """Called at the beginning of an epoch.""" - for callback in self._callbacks: - callback.on_epoch_begin(epoch, logs) - - def on_epoch_end(self, epoch: int, logs: Dict = {}) -> None: - """Called at the end of an epoch.""" - for callback in self._callbacks: - callback.on_epoch_end(epoch, logs) - - def _call_batch_hook( - self, mode: str, hook: str, batch: int, logs: Dict = {} - ) -> None: - """Helper function for all batch_{begin | end} methods.""" - if hook == "begin": - self._call_batch_begin_hook(mode, batch, logs) - elif hook == "end": - self._call_batch_end_hook(mode, batch, logs) - else: - raise ValueError(f"Unrecognized hook {hook}.") - - def _call_batch_begin_hook(self, mode: str, batch: int, logs: Dict = {}) -> None: - """Helper function for all `on_*_batch_begin` methods.""" - hook_name = f"on_{mode}_batch_begin" - self._call_batch_hook_helper(hook_name, batch, logs) - - def _call_batch_end_hook(self, mode: str, batch: int, logs: Dict = {}) -> None: - """Helper function for all `on_*_batch_end` methods.""" - hook_name = f"on_{mode}_batch_end" - self._call_batch_hook_helper(hook_name, batch, logs) - - def _call_batch_hook_helper( - self, hook_name: str, batch: int, logs: Dict = {} - ) -> None: - """Helper function for `on_*_batch_begin` methods.""" - for callback in self._callbacks: - hook = getattr(callback, hook_name) - hook(batch, logs) - - def on_train_batch_begin(self, batch: int, logs: Dict = {}) -> None: - """Called at the beginning of an epoch.""" - self._call_batch_hook(self.mode_keys.TRAIN, "begin", batch) - - def on_train_batch_end(self, batch: int, logs: Dict = {}) -> None: - """Called at the end of an epoch.""" - self._call_batch_hook(self.mode_keys.TRAIN, "end", batch) - - def on_validation_batch_begin(self, batch: int, logs: Dict = {}) -> None: - """Called at the beginning of an epoch.""" - self._call_batch_hook(self.mode_keys.VALIDATION, "begin", batch) - - def on_validation_batch_end(self, batch: int, logs: Dict = {}) -> None: - """Called at the end of an epoch.""" - self._call_batch_hook(self.mode_keys.VALIDATION, "end", batch) - - def __iter__(self) -> iter: - """Iter function for callback list.""" - return iter(self._callbacks) - - -class Checkpoint(Callback): - """Saving model parameters at the end of each epoch.""" - - mode_dict = { - "min": torch.lt, - "max": torch.gt, - } - - def __init__( - self, monitor: str = "accuracy", mode: str = "auto", min_delta: float = 0.0 - ) -> None: - """Monitors a quantity that will allow us to determine the best model weights. - - Args: - monitor (str): Name of the quantity to monitor. Defaults to "accuracy". - mode (str): Description of parameter `mode`. Defaults to "auto". - min_delta (float): Description of parameter `min_delta`. Defaults to 0.0. - - """ - super().__init__() - self.monitor = monitor - self.mode = mode - self.min_delta = torch.tensor(min_delta) - - if mode not in ["auto", "min", "max"]: - logger.warning(f"Checkpoint mode {mode} is unkown, fallback to auto mode.") - - self.mode = "auto" - - if self.mode == "auto": - if "accuracy" in self.monitor: - self.mode = "max" - else: - self.mode = "min" - logger.debug( - f"Checkpoint mode set to {self.mode} for monitoring {self.monitor}." - ) - - torch_inf = torch.tensor(np.inf) - self.min_delta *= 1 if self.monitor_op == torch.gt else -1 - self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf - - @property - def monitor_op(self) -> float: - """Returns the comparison method.""" - return self.mode_dict[self.mode] - - def on_epoch_end(self, epoch: int, logs: Dict) -> None: - """Saves a checkpoint for the network parameters. - - Args: - epoch (int): The current epoch. - logs (Dict): The log containing the monitored metrics. - - """ - current = self.get_monitor_value(logs) - if current is None: - return - if self.monitor_op(current - self.min_delta, self.best_score): - self.best_score = current - is_best = True - else: - is_best = False - - self.model.save_checkpoint(is_best, epoch, self.monitor) - - def get_monitor_value(self, logs: Dict) -> Union[float, None]: - """Extracts the monitored value.""" - monitor_value = logs.get(self.monitor) - if monitor_value is None: - logger.warning( - f"Checkpoint is conditioned on metric {self.monitor} which is not available. Available" - + f"metrics are: {','.join(list(logs.keys()))}" - ) - return None - return monitor_value diff --git a/src/training/callbacks/early_stopping.py b/src/training/callbacks/early_stopping.py deleted file mode 100644 index c9b7907..0000000 --- a/src/training/callbacks/early_stopping.py +++ /dev/null @@ -1,107 +0,0 @@ -"""Implements Early stopping for PyTorch model.""" -from typing import Dict, Union - -from loguru import logger -import numpy as np -import torch -from training.callbacks import Callback - - -class EarlyStopping(Callback): - """Stops training when a monitored metric stops improving.""" - - mode_dict = { - "min": torch.lt, - "max": torch.gt, - } - - def __init__( - self, - monitor: str = "val_loss", - min_delta: float = 0.0, - patience: int = 3, - mode: str = "auto", - ) -> None: - """Initializes the EarlyStopping callback. - - Args: - monitor (str): Description of parameter `monitor`. Defaults to "val_loss". - min_delta (float): Description of parameter `min_delta`. Defaults to 0.0. - patience (int): Description of parameter `patience`. Defaults to 3. - mode (str): Description of parameter `mode`. Defaults to "auto". - - """ - super().__init__() - self.monitor = monitor - self.patience = patience - self.min_delta = torch.tensor(min_delta) - self.mode = mode - self.wait_count = 0 - self.stopped_epoch = 0 - - if mode not in ["auto", "min", "max"]: - logger.warning( - f"EarlyStopping mode {mode} is unkown, fallback to auto mode." - ) - - self.mode = "auto" - - if self.mode == "auto": - if "accuracy" in self.monitor: - self.mode = "max" - else: - self.mode = "min" - logger.debug( - f"EarlyStopping mode set to {self.mode} for monitoring {self.monitor}." - ) - - self.torch_inf = torch.tensor(np.inf) - self.min_delta *= 1 if self.monitor_op == torch.gt else -1 - self.best_score = ( - self.torch_inf if self.monitor_op == torch.lt else -self.torch_inf - ) - - @property - def monitor_op(self) -> float: - """Returns the comparison method.""" - return self.mode_dict[self.mode] - - def on_fit_begin(self) -> Union[torch.lt, torch.gt]: - """Reset the early stopping variables for reuse.""" - self.wait_count = 0 - self.stopped_epoch = 0 - self.best_score = ( - self.torch_inf if self.monitor_op == torch.lt else -self.torch_inf - ) - - def on_epoch_end(self, epoch: int, logs: Dict) -> None: - """Computes the early stop criterion.""" - current = self.get_monitor_value(logs) - if current is None: - return - if self.monitor_op(current - self.min_delta, self.best_score): - self.best_score = current - self.wait_count = 0 - else: - self.wait_count += 1 - if self.wait_count >= self.patience: - self.stopped_epoch = epoch - self.model.stop_training = True - - def on_fit_end(self) -> None: - """Logs if early stopping was used.""" - if self.stopped_epoch > 0: - logger.info( - f"Stopped training at epoch {self.stopped_epoch + 1} with early stopping." - ) - - def get_monitor_value(self, logs: Dict) -> Union[torch.Tensor, None]: - """Extracts the monitor value.""" - monitor_value = logs.get(self.monitor) - if monitor_value is None: - logger.warning( - f"Early stopping is conditioned on metric {self.monitor} which is not available. Available" - + f"metrics are: {','.join(list(logs.keys()))}" - ) - return None - return torch.tensor(monitor_value) diff --git a/src/training/callbacks/lr_schedulers.py b/src/training/callbacks/lr_schedulers.py deleted file mode 100644 index 00c7e9b..0000000 --- a/src/training/callbacks/lr_schedulers.py +++ /dev/null @@ -1,97 +0,0 @@ -"""Callbacks for learning rate schedulers.""" -from typing import Callable, Dict, List, Optional, Type - -from training.callbacks import Callback - -from text_recognizer.models import Model - - -class StepLR(Callback): - """Callback for StepLR.""" - - def __init__(self) -> None: - """Initializes the callback.""" - super().__init__() - self.lr_scheduler = None - - def set_model(self, model: Type[Model]) -> None: - """Sets the model and lr scheduler.""" - self.model = model - self.lr_scheduler = self.model.lr_scheduler - - def on_epoch_end(self, epoch: int, logs: Dict = {}) -> None: - """Takes a step at the end of every epoch.""" - self.lr_scheduler.step() - - -class MultiStepLR(Callback): - """Callback for MultiStepLR.""" - - def __init__(self) -> None: - """Initializes the callback.""" - super().__init__() - self.lr_scheduler = None - - def set_model(self, model: Type[Model]) -> None: - """Sets the model and lr scheduler.""" - self.model = model - self.lr_scheduler = self.model.lr_scheduler - - def on_epoch_end(self, epoch: int, logs: Dict = {}) -> None: - """Takes a step at the end of every epoch.""" - self.lr_scheduler.step() - - -class ReduceLROnPlateau(Callback): - """Callback for ReduceLROnPlateau.""" - - def __init__(self) -> None: - """Initializes the callback.""" - super().__init__() - self.lr_scheduler = None - - def set_model(self, model: Type[Model]) -> None: - """Sets the model and lr scheduler.""" - self.model = model - self.lr_scheduler = self.model.lr_scheduler - - def on_epoch_end(self, epoch: int, logs: Dict = {}) -> None: - """Takes a step at the end of every epoch.""" - val_loss = logs["val_loss"] - self.lr_scheduler.step(val_loss) - - -class CyclicLR(Callback): - """Callback for CyclicLR.""" - - def __init__(self) -> None: - """Initializes the callback.""" - super().__init__() - self.lr_scheduler = None - - def set_model(self, model: Type[Model]) -> None: - """Sets the model and lr scheduler.""" - self.model = model - self.lr_scheduler = self.model.lr_scheduler - - def on_train_batch_end(self, batch: int, logs: Dict = {}) -> None: - """Takes a step at the end of every training batch.""" - self.lr_scheduler.step() - - -class OneCycleLR(Callback): - """Callback for OneCycleLR.""" - - def __init__(self) -> None: - """Initializes the callback.""" - super().__init__() - self.lr_scheduler = None - - def set_model(self, model: Type[Model]) -> None: - """Sets the model and lr scheduler.""" - self.model = model - self.lr_scheduler = self.model.lr_scheduler - - def on_train_batch_end(self, batch: int, logs: Dict = {}) -> None: - """Takes a step at the end of every training batch.""" - self.lr_scheduler.step() diff --git a/src/training/callbacks/wandb_callbacks.py b/src/training/callbacks/wandb_callbacks.py deleted file mode 100644 index 6ada6df..0000000 --- a/src/training/callbacks/wandb_callbacks.py +++ /dev/null @@ -1,93 +0,0 @@ -"""Callbacks using wandb.""" -from typing import Callable, Dict, List, Optional, Type - -import numpy as np -from torchvision.transforms import Compose, ToTensor -from training.callbacks import Callback -import wandb - -from text_recognizer.datasets import Transpose -from text_recognizer.models.base import Model - - -class WandbCallback(Callback): - """A custom W&B metric logger for the trainer.""" - - def __init__(self, log_batch_frequency: int = None) -> None: - """Short summary. - - Args: - log_batch_frequency (int): If None, metrics will be logged every epoch. - If set to an integer, callback will log every metrics every log_batch_frequency. - - """ - super().__init__() - self.log_batch_frequency = log_batch_frequency - - def _on_batch_end(self, batch: int, logs: Dict) -> None: - if self.log_batch_frequency and batch % self.log_batch_frequency == 0: - wandb.log(logs, commit=True) - - def on_train_batch_end(self, batch: int, logs: Dict = {}) -> None: - """Logs training metrics.""" - if logs is not None: - self._on_batch_end(batch, logs) - - def on_validation_batch_end(self, batch: int, logs: Dict = {}) -> None: - """Logs validation metrics.""" - if logs is not None: - self._on_batch_end(batch, logs) - - def on_epoch_end(self, epoch: int, logs: Dict) -> None: - """Logs at epoch end.""" - wandb.log(logs, commit=True) - - -class WandbImageLogger(Callback): - """Custom W&B callback for image logging.""" - - def __init__( - self, - example_indices: Optional[List] = None, - num_examples: int = 4, - transfroms: Optional[Callable] = None, - ) -> None: - """Initializes the WandbImageLogger with the model to train. - - Args: - example_indices (Optional[List]): Indices for validation images. Defaults to None. - num_examples (int): Number of random samples to take if example_indices are not specified. Defaults to 4. - transfroms (Optional[Callable]): Transforms to use on the validation images, e.g. transpose. Defaults to - None. - - """ - - super().__init__() - self.example_indices = example_indices - self.num_examples = num_examples - self.transfroms = transfroms - if self.transfroms is None: - self.transforms = Compose([Transpose()]) - - def set_model(self, model: Type[Model]) -> None: - """Sets the model and extracts validation images from the dataset.""" - self.model = model - data_loader = self.model.data_loaders["val"] - if self.example_indices is None: - self.example_indices = np.random.randint( - 0, len(data_loader.dataset.data), self.num_examples - ) - self.val_images = data_loader.dataset.data[self.example_indices] - self.val_targets = data_loader.dataset.targets[self.example_indices].numpy() - - def on_epoch_end(self, epoch: int, logs: Dict) -> None: - """Get network predictions on validation images.""" - images = [] - for i, image in enumerate(self.val_images): - image = self.transforms(image) - pred, conf = self.model.predict_on_image(image) - ground_truth = self.model.mapper(int(self.val_targets[i])) - caption = f"Prediction: {pred} Confidence: {conf:.3f} Ground Truth: {ground_truth}" - images.append(wandb.Image(image, caption=caption)) - - wandb.log({"examples": images}, commit=False) diff --git a/src/training/experiments/sample_experiment.yml b/src/training/experiments/sample_experiment.yml index 355305c..bae02ac 100644 --- a/src/training/experiments/sample_experiment.yml +++ b/src/training/experiments/sample_experiment.yml @@ -9,25 +9,32 @@ experiments: seed: 4711 data_loader_args: splits: [train, val] - batch_size: 256 shuffle: true num_workers: 8 cuda: true model: CharacterModel metrics: [accuracy] - network: MLP + # network: MLP + # network_args: + # input_size: 784 + # hidden_size: 512 + # output_size: 80 + # num_layers: 3 + # dropout_rate: 0 + # activation_fn: SELU + network: ResidualNetwork network_args: - input_size: 784 - output_size: 62 - num_layers: 3 - activation_fn: GELU + in_channels: 1 + num_classes: 80 + depths: [1, 1] + block_sizes: [128, 256] # network: LeNet # network_args: # output_size: 62 # activation_fn: GELU train_args: batch_size: 256 - epochs: 16 + epochs: 32 criterion: CrossEntropyLoss criterion_args: weight: null @@ -43,20 +50,24 @@ experiments: # centered: false optimizer: AdamW optimizer_args: - lr: 1.e-2 + lr: 1.e-03 betas: [0.9, 0.999] eps: 1.e-08 - weight_decay: 0 + # weight_decay: 5.e-4 amsgrad: false # lr_scheduler: null lr_scheduler: OneCycleLR lr_scheduler_args: - max_lr: 1.e-3 - epochs: 16 - callbacks: [Checkpoint, EarlyStopping, WandbCallback, WandbImageLogger, OneCycleLR] + max_lr: 1.e-03 + epochs: 32 + anneal_strategy: linear + callbacks: [Checkpoint, ProgressBar, EarlyStopping, WandbCallback, WandbImageLogger, OneCycleLR] callback_args: Checkpoint: monitor: val_accuracy + ProgressBar: + epochs: 32 + log_batch_frequency: 100 EarlyStopping: monitor: val_loss min_delta: 0.0 @@ -68,5 +79,5 @@ experiments: num_examples: 4 OneCycleLR: null - verbosity: 2 # 0, 1, 2 + verbosity: 1 # 0, 1, 2 resume_experiment: null diff --git a/src/training/population_based_training/__init__.py b/src/training/population_based_training/__init__.py deleted file mode 100644 index 868d739..0000000 --- a/src/training/population_based_training/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""TBC.""" diff --git a/src/training/population_based_training/population_based_training.py b/src/training/population_based_training/population_based_training.py deleted file mode 100644 index 868d739..0000000 --- a/src/training/population_based_training/population_based_training.py +++ /dev/null @@ -1 +0,0 @@ -"""TBC.""" diff --git a/src/training/prepare_experiments.py b/src/training/prepare_experiments.py index 97c0304..4c3f9ba 100644 --- a/src/training/prepare_experiments.py +++ b/src/training/prepare_experiments.py @@ -7,11 +7,11 @@ from loguru import logger import yaml -# flake8: noqa: S404,S607,S603 def run_experiments(experiments_filename: str) -> None: """Run experiment from file.""" with open(experiments_filename) as f: experiments_config = yaml.safe_load(f) + num_experiments = len(experiments_config["experiments"]) for index in range(num_experiments): experiment_config = experiments_config["experiments"][index] @@ -27,10 +27,10 @@ def run_experiments(experiments_filename: str) -> None: type=str, help="Filename of Yaml file of experiments to run.", ) -def main(experiments_filename: str) -> None: +def run_cli(experiments_filename: str) -> None: """Parse command-line arguments and run experiments from provided file.""" run_experiments(experiments_filename) if __name__ == "__main__": - main() + run_cli() diff --git a/src/training/run_experiment.py b/src/training/run_experiment.py index d278dc2..8c063ff 100644 --- a/src/training/run_experiment.py +++ b/src/training/run_experiment.py @@ -6,18 +6,20 @@ import json import os from pathlib import Path import re -from typing import Callable, Dict, Tuple +from typing import Callable, Dict, Tuple, Type import click from loguru import logger import torch from tqdm import tqdm -from training.callbacks import CallbackList from training.gpu_manager import GPUManager -from training.train import Trainer +from training.trainer.callbacks import CallbackList +from training.trainer.train import Trainer import wandb import yaml +from text_recognizer.models import Model + EXPERIMENTS_DIRNAME = Path(__file__).parents[0].resolve() / "experiments" @@ -35,7 +37,7 @@ def get_level(experiment_config: Dict) -> int: return 10 -def create_experiment_dir(model: Callable, experiment_config: Dict) -> Path: +def create_experiment_dir(model: Type[Model], experiment_config: Dict) -> Path: """Create new experiment.""" EXPERIMENTS_DIRNAME.mkdir(parents=True, exist_ok=True) experiment_dir = EXPERIMENTS_DIRNAME / model.__name__ @@ -67,6 +69,8 @@ def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict] """Loads all modules and arguments.""" # Import the data loader arguments. data_loader_args = experiment_config.get("data_loader_args", {}) + train_args = experiment_config.get("train_args", {}) + data_loader_args["batch_size"] = train_args["batch_size"] data_loader_args["dataset"] = experiment_config["dataset"] data_loader_args["dataset_args"] = experiment_config.get("dataset_args", {}) @@ -94,7 +98,7 @@ def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict] optimizer_args = experiment_config.get("optimizer_args", {}) # Callbacks - callback_modules = importlib.import_module("training.callbacks") + callback_modules = importlib.import_module("training.trainer.callbacks") callbacks = [ getattr(callback_modules, callback)( **check_args(experiment_config["callback_args"][callback]) @@ -208,6 +212,7 @@ def run_experiment( with open(str(config_path), "w") as f: yaml.dump(experiment_config, f) + # Train the model. trainer = Trainer( model=model, model_dir=model_dir, @@ -247,7 +252,7 @@ def run_experiment( @click.option( "--nowandb", is_flag=False, help="If true, do not use wandb for this run." ) -def main(experiment_config: str, gpu: int, save: bool, nowandb: bool) -> None: +def run_cli(experiment_config: str, gpu: int, save: bool, nowandb: bool) -> None: """Run experiment.""" if gpu < 0: gpu_manager = GPUManager(True) @@ -260,4 +265,4 @@ def main(experiment_config: str, gpu: int, save: bool, nowandb: bool) -> None: if __name__ == "__main__": - main() + run_cli() diff --git a/src/training/train.py b/src/training/train.py deleted file mode 100644 index aaa0430..0000000 --- a/src/training/train.py +++ /dev/null @@ -1,249 +0,0 @@ -"""Training script for PyTorch models.""" - -from pathlib import Path -import time -from typing import Dict, List, Optional, Tuple, Type - -from loguru import logger -import numpy as np -import torch -from tqdm import tqdm, trange -from training.callbacks import Callback, CallbackList -from training.util import RunningAverage -import wandb - -from text_recognizer.models import Model - - -torch.backends.cudnn.benchmark = True -np.random.seed(4711) -torch.manual_seed(4711) -torch.cuda.manual_seed(4711) - - -class Trainer: - """Trainer for training PyTorch models.""" - - def __init__( - self, - model: Type[Model], - model_dir: Path, - train_args: Dict, - callbacks: CallbackList, - checkpoint_path: Optional[Path] = None, - ) -> None: - """Initialization of the Trainer. - - Args: - model (Type[Model]): A model object. - model_dir (Path): Path to the model directory. - train_args (Dict): The training arguments. - callbacks (CallbackList): List of callbacks to be called. - checkpoint_path (Optional[Path]): The path to a previously trained model. Defaults to None. - - """ - self.model = model - self.model_dir = model_dir - self.checkpoint_path = checkpoint_path - self.start_epoch = 1 - self.epochs = train_args["epochs"] + self.start_epoch - self.callbacks = callbacks - - if self.checkpoint_path is not None: - self.start_epoch = self.model.load_checkpoint(self.checkpoint_path) + 1 - - # Parse the name of the experiment. - experiment_dir = str(self.model_dir.parents[1]).split("/") - self.experiment_name = experiment_dir[-2] + "/" + experiment_dir[-1] - - def training_step( - self, - batch: int, - samples: Tuple[torch.Tensor, torch.Tensor], - loss_avg: Type[RunningAverage], - ) -> Dict: - """Performs the training step.""" - # Pass the tensor to the device for computation. - data, targets = samples - data, targets = ( - data.to(self.model.device), - targets.to(self.model.device), - ) - - # Forward pass. - # Get the network prediction. - output = self.model.network(data) - - # Compute the loss. - loss = self.model.criterion(output, targets) - - # Backward pass. - # Clear the previous gradients. - self.model.optimizer.zero_grad() - - # Compute the gradients. - loss.backward() - - # Perform updates using calculated gradients. - self.model.optimizer.step() - - # Compute metrics. - loss_avg.update(loss.item()) - output = output.data.cpu() - targets = targets.data.cpu() - metrics = { - metric: self.model.metrics[metric](output, targets) - for metric in self.model.metrics - } - metrics["loss"] = loss_avg() - return metrics - - def train(self) -> None: - """Runs the training loop for one epoch.""" - # Set model to traning mode. - self.model.train() - - # Running average for the loss. - loss_avg = RunningAverage() - - data_loader = self.model.data_loaders["train"] - - with tqdm( - total=len(data_loader), - leave=False, - unit="step", - bar_format="{n_fmt}/{total_fmt} |{bar:30}| {remaining} {rate_inv_fmt}{postfix}", - ) as t: - for batch, samples in enumerate(data_loader): - self.callbacks.on_train_batch_begin(batch) - - metrics = self.training_step(batch, samples, loss_avg) - - self.callbacks.on_train_batch_end(batch, logs=metrics) - - # Update Tqdm progress bar. - t.set_postfix(**metrics) - t.update() - - def validation_step( - self, - batch: int, - samples: Tuple[torch.Tensor, torch.Tensor], - loss_avg: Type[RunningAverage], - ) -> Dict: - """Performs the validation step.""" - # Pass the tensor to the device for computation. - data, targets = samples - data, targets = ( - data.to(self.model.device), - targets.to(self.model.device), - ) - - # Forward pass. - # Get the network prediction. - output = self.model.network(data) - - # Compute the loss. - loss = self.model.criterion(output, targets) - - # Compute metrics. - loss_avg.update(loss.item()) - output = output.data.cpu() - targets = targets.data.cpu() - metrics = { - metric: self.model.metrics[metric](output, targets) - for metric in self.model.metrics - } - metrics["loss"] = loss.item() - - return metrics - - def validate(self, epoch: Optional[int] = None) -> Dict: - """Runs the validation loop for one epoch.""" - # Set model to eval mode. - self.model.eval() - - # Running average for the loss. - data_loader = self.model.data_loaders["val"] - - # Running average for the loss. - loss_avg = RunningAverage() - - # Summary for the current eval loop. - summary = [] - - with tqdm( - total=len(data_loader), - leave=False, - unit="step", - bar_format="{n_fmt}/{total_fmt} |{bar:30}| {remaining} {rate_inv_fmt}{postfix}", - ) as t: - with torch.no_grad(): - for batch, samples in enumerate(data_loader): - self.callbacks.on_validation_batch_begin(batch) - - metrics = self.validation_step(batch, samples, loss_avg) - - self.callbacks.on_validation_batch_end(batch, logs=metrics) - - summary.append(metrics) - - # Update Tqdm progress bar. - t.set_postfix(**metrics) - t.update() - - # Compute mean of all metrics. - metrics_mean = { - "val_" + metric: np.mean([x[metric] for x in summary]) - for metric in summary[0] - } - if epoch: - logger.debug( - f"Validation metrics at epoch {epoch} - " - + " - ".join(f"{k}: {v:.4f}" for k, v in metrics_mean.items()) - ) - else: - logger.debug( - "Validation metrics - " - + " - ".join(f"{k}: {v:.4f}" for k, v in metrics_mean.items()) - ) - - return metrics_mean - - def fit(self) -> None: - """Runs the training and evaluation loop.""" - - logger.debug(f"Running an experiment called {self.experiment_name}.") - t_start = time.time() - - self.callbacks.on_fit_begin() - - # TODO: fix progress bar as callback. - # Run the training loop. - for epoch in trange( - self.start_epoch, - self.epochs, - leave=False, - bar_format="{desc}: {n_fmt}/{total_fmt} |{bar:30}| {remaining}{postfix}", - desc="Epoch", - ): - self.callbacks.on_epoch_begin(epoch) - - # Perform one training pass over the training set. - self.train() - - # Evaluate the model on the validation set. - val_metrics = self.validate(epoch) - - self.callbacks.on_epoch_end(epoch, logs=val_metrics) - - if self.model.stop_training: - break - - # Calculate the total training time. - t_end = time.time() - t_training = t_end - t_start - - self.callbacks.on_fit_end() - - logger.info(f"Training took {t_training:.2f} s.") diff --git a/src/training/trainer/__init__.py b/src/training/trainer/__init__.py new file mode 100644 index 0000000..de41bfb --- /dev/null +++ b/src/training/trainer/__init__.py @@ -0,0 +1,2 @@ +"""Trainer modules.""" +from .train import Trainer diff --git a/src/training/trainer/callbacks/__init__.py b/src/training/trainer/callbacks/__init__.py new file mode 100644 index 0000000..5942276 --- /dev/null +++ b/src/training/trainer/callbacks/__init__.py @@ -0,0 +1,21 @@ +"""The callback modules used in the training script.""" +from .base import Callback, CallbackList, Checkpoint +from .early_stopping import EarlyStopping +from .lr_schedulers import CyclicLR, MultiStepLR, OneCycleLR, ReduceLROnPlateau, StepLR +from .progress_bar import ProgressBar +from .wandb_callbacks import WandbCallback, WandbImageLogger + +__all__ = [ + "Callback", + "CallbackList", + "Checkpoint", + "EarlyStopping", + "WandbCallback", + "WandbImageLogger", + "CyclicLR", + "MultiStepLR", + "OneCycleLR", + "ProgressBar", + "ReduceLROnPlateau", + "StepLR", +] diff --git a/src/training/trainer/callbacks/base.py b/src/training/trainer/callbacks/base.py new file mode 100644 index 0000000..8df94f3 --- /dev/null +++ b/src/training/trainer/callbacks/base.py @@ -0,0 +1,248 @@ +"""Metaclass for callback functions.""" + +from enum import Enum +from typing import Callable, Dict, List, Optional, Type, Union + +from loguru import logger +import numpy as np +import torch + +from text_recognizer.models import Model + + +class ModeKeys: + """Mode keys for CallbackList.""" + + TRAIN = "train" + VALIDATION = "validation" + + +class Callback: + """Metaclass for callbacks used in training.""" + + def __init__(self) -> None: + """Initializes the Callback instance.""" + self.model = None + + def set_model(self, model: Type[Model]) -> None: + """Set the model.""" + self.model = model + + def on_fit_begin(self) -> None: + """Called when fit begins.""" + pass + + def on_fit_end(self) -> None: + """Called when fit ends.""" + pass + + def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None) -> None: + """Called at the beginning of an epoch. Only used in training mode.""" + pass + + def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: + """Called at the end of an epoch. Only used in training mode.""" + pass + + def on_train_batch_begin(self, batch: int, logs: Optional[Dict] = None) -> None: + """Called at the beginning of an epoch.""" + pass + + def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: + """Called at the end of an epoch.""" + pass + + def on_validation_batch_begin( + self, batch: int, logs: Optional[Dict] = None + ) -> None: + """Called at the beginning of an epoch.""" + pass + + def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: + """Called at the end of an epoch.""" + pass + + +class CallbackList: + """Container for abstracting away callback calls.""" + + mode_keys = ModeKeys() + + def __init__(self, model: Type[Model], callbacks: List[Callback] = None) -> None: + """Container for `Callback` instances. + + This object wraps a list of `Callback` instances and allows them all to be + called via a single end point. + + Args: + model (Type[Model]): A `Model` instance. + callbacks (List[Callback]): List of `Callback` instances. Defaults to None. + + """ + + self._callbacks = callbacks or [] + if model: + self.set_model(model) + + def set_model(self, model: Type[Model]) -> None: + """Set the model for all callbacks.""" + self.model = model + for callback in self._callbacks: + callback.set_model(model=self.model) + + def append(self, callback: Type[Callback]) -> None: + """Append new callback to callback list.""" + self.callbacks.append(callback) + + def on_fit_begin(self) -> None: + """Called when fit begins.""" + for callback in self._callbacks: + callback.on_fit_begin() + + def on_fit_end(self) -> None: + """Called when fit ends.""" + for callback in self._callbacks: + callback.on_fit_end() + + def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None) -> None: + """Called at the beginning of an epoch.""" + for callback in self._callbacks: + callback.on_epoch_begin(epoch, logs) + + def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: + """Called at the end of an epoch.""" + for callback in self._callbacks: + callback.on_epoch_end(epoch, logs) + + def _call_batch_hook( + self, mode: str, hook: str, batch: int, logs: Optional[Dict] = None + ) -> None: + """Helper function for all batch_{begin | end} methods.""" + if hook == "begin": + self._call_batch_begin_hook(mode, batch, logs) + elif hook == "end": + self._call_batch_end_hook(mode, batch, logs) + else: + raise ValueError(f"Unrecognized hook {hook}.") + + def _call_batch_begin_hook( + self, mode: str, batch: int, logs: Optional[Dict] = None + ) -> None: + """Helper function for all `on_*_batch_begin` methods.""" + hook_name = f"on_{mode}_batch_begin" + self._call_batch_hook_helper(hook_name, batch, logs) + + def _call_batch_end_hook( + self, mode: str, batch: int, logs: Optional[Dict] = None + ) -> None: + """Helper function for all `on_*_batch_end` methods.""" + hook_name = f"on_{mode}_batch_end" + self._call_batch_hook_helper(hook_name, batch, logs) + + def _call_batch_hook_helper( + self, hook_name: str, batch: int, logs: Optional[Dict] = None + ) -> None: + """Helper function for `on_*_batch_begin` methods.""" + for callback in self._callbacks: + hook = getattr(callback, hook_name) + hook(batch, logs) + + def on_train_batch_begin(self, batch: int, logs: Optional[Dict] = None) -> None: + """Called at the beginning of an epoch.""" + self._call_batch_hook(self.mode_keys.TRAIN, "begin", batch, logs) + + def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: + """Called at the end of an epoch.""" + self._call_batch_hook(self.mode_keys.TRAIN, "end", batch, logs) + + def on_validation_batch_begin( + self, batch: int, logs: Optional[Dict] = None + ) -> None: + """Called at the beginning of an epoch.""" + self._call_batch_hook(self.mode_keys.VALIDATION, "begin", batch, logs) + + def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: + """Called at the end of an epoch.""" + self._call_batch_hook(self.mode_keys.VALIDATION, "end", batch, logs) + + def __iter__(self) -> iter: + """Iter function for callback list.""" + return iter(self._callbacks) + + +class Checkpoint(Callback): + """Saving model parameters at the end of each epoch.""" + + mode_dict = { + "min": torch.lt, + "max": torch.gt, + } + + def __init__( + self, monitor: str = "accuracy", mode: str = "auto", min_delta: float = 0.0 + ) -> None: + """Monitors a quantity that will allow us to determine the best model weights. + + Args: + monitor (str): Name of the quantity to monitor. Defaults to "accuracy". + mode (str): Description of parameter `mode`. Defaults to "auto". + min_delta (float): Description of parameter `min_delta`. Defaults to 0.0. + + """ + super().__init__() + self.monitor = monitor + self.mode = mode + self.min_delta = torch.tensor(min_delta) + + if mode not in ["auto", "min", "max"]: + logger.warning(f"Checkpoint mode {mode} is unkown, fallback to auto mode.") + + self.mode = "auto" + + if self.mode == "auto": + if "accuracy" in self.monitor: + self.mode = "max" + else: + self.mode = "min" + logger.debug( + f"Checkpoint mode set to {self.mode} for monitoring {self.monitor}." + ) + + torch_inf = torch.tensor(np.inf) + self.min_delta *= 1 if self.monitor_op == torch.gt else -1 + self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf + + @property + def monitor_op(self) -> float: + """Returns the comparison method.""" + return self.mode_dict[self.mode] + + def on_epoch_end(self, epoch: int, logs: Dict) -> None: + """Saves a checkpoint for the network parameters. + + Args: + epoch (int): The current epoch. + logs (Dict): The log containing the monitored metrics. + + """ + current = self.get_monitor_value(logs) + if current is None: + return + if self.monitor_op(current - self.min_delta, self.best_score): + self.best_score = current + is_best = True + else: + is_best = False + + self.model.save_checkpoint(is_best, epoch, self.monitor) + + def get_monitor_value(self, logs: Dict) -> Union[float, None]: + """Extracts the monitored value.""" + monitor_value = logs.get(self.monitor) + if monitor_value is None: + logger.warning( + f"Checkpoint is conditioned on metric {self.monitor} which is not available. Available" + + f"metrics are: {','.join(list(logs.keys()))}" + ) + return None + return monitor_value diff --git a/src/training/trainer/callbacks/early_stopping.py b/src/training/trainer/callbacks/early_stopping.py new file mode 100644 index 0000000..02b431f --- /dev/null +++ b/src/training/trainer/callbacks/early_stopping.py @@ -0,0 +1,108 @@ +"""Implements Early stopping for PyTorch model.""" +from typing import Dict, Union + +from loguru import logger +import numpy as np +import torch +from torch import Tensor +from training.trainer.callbacks import Callback + + +class EarlyStopping(Callback): + """Stops training when a monitored metric stops improving.""" + + mode_dict = { + "min": torch.lt, + "max": torch.gt, + } + + def __init__( + self, + monitor: str = "val_loss", + min_delta: float = 0.0, + patience: int = 3, + mode: str = "auto", + ) -> None: + """Initializes the EarlyStopping callback. + + Args: + monitor (str): Description of parameter `monitor`. Defaults to "val_loss". + min_delta (float): Description of parameter `min_delta`. Defaults to 0.0. + patience (int): Description of parameter `patience`. Defaults to 3. + mode (str): Description of parameter `mode`. Defaults to "auto". + + """ + super().__init__() + self.monitor = monitor + self.patience = patience + self.min_delta = torch.tensor(min_delta) + self.mode = mode + self.wait_count = 0 + self.stopped_epoch = 0 + + if mode not in ["auto", "min", "max"]: + logger.warning( + f"EarlyStopping mode {mode} is unkown, fallback to auto mode." + ) + + self.mode = "auto" + + if self.mode == "auto": + if "accuracy" in self.monitor: + self.mode = "max" + else: + self.mode = "min" + logger.debug( + f"EarlyStopping mode set to {self.mode} for monitoring {self.monitor}." + ) + + self.torch_inf = torch.tensor(np.inf) + self.min_delta *= 1 if self.monitor_op == torch.gt else -1 + self.best_score = ( + self.torch_inf if self.monitor_op == torch.lt else -self.torch_inf + ) + + @property + def monitor_op(self) -> float: + """Returns the comparison method.""" + return self.mode_dict[self.mode] + + def on_fit_begin(self) -> Union[torch.lt, torch.gt]: + """Reset the early stopping variables for reuse.""" + self.wait_count = 0 + self.stopped_epoch = 0 + self.best_score = ( + self.torch_inf if self.monitor_op == torch.lt else -self.torch_inf + ) + + def on_epoch_end(self, epoch: int, logs: Dict) -> None: + """Computes the early stop criterion.""" + current = self.get_monitor_value(logs) + if current is None: + return + if self.monitor_op(current - self.min_delta, self.best_score): + self.best_score = current + self.wait_count = 0 + else: + self.wait_count += 1 + if self.wait_count >= self.patience: + self.stopped_epoch = epoch + self.model.stop_training = True + + def on_fit_end(self) -> None: + """Logs if early stopping was used.""" + if self.stopped_epoch > 0: + logger.info( + f"Stopped training at epoch {self.stopped_epoch + 1} with early stopping." + ) + + def get_monitor_value(self, logs: Dict) -> Union[Tensor, None]: + """Extracts the monitor value.""" + monitor_value = logs.get(self.monitor) + if monitor_value is None: + logger.warning( + f"Early stopping is conditioned on metric {self.monitor} which is not available. Available" + + f"metrics are: {','.join(list(logs.keys()))}" + ) + return None + return torch.tensor(monitor_value) diff --git a/src/training/trainer/callbacks/lr_schedulers.py b/src/training/trainer/callbacks/lr_schedulers.py new file mode 100644 index 0000000..ba2226a --- /dev/null +++ b/src/training/trainer/callbacks/lr_schedulers.py @@ -0,0 +1,97 @@ +"""Callbacks for learning rate schedulers.""" +from typing import Callable, Dict, List, Optional, Type + +from training.trainer.callbacks import Callback + +from text_recognizer.models import Model + + +class StepLR(Callback): + """Callback for StepLR.""" + + def __init__(self) -> None: + """Initializes the callback.""" + super().__init__() + self.lr_scheduler = None + + def set_model(self, model: Type[Model]) -> None: + """Sets the model and lr scheduler.""" + self.model = model + self.lr_scheduler = self.model.lr_scheduler + + def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: + """Takes a step at the end of every epoch.""" + self.lr_scheduler.step() + + +class MultiStepLR(Callback): + """Callback for MultiStepLR.""" + + def __init__(self) -> None: + """Initializes the callback.""" + super().__init__() + self.lr_scheduler = None + + def set_model(self, model: Type[Model]) -> None: + """Sets the model and lr scheduler.""" + self.model = model + self.lr_scheduler = self.model.lr_scheduler + + def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: + """Takes a step at the end of every epoch.""" + self.lr_scheduler.step() + + +class ReduceLROnPlateau(Callback): + """Callback for ReduceLROnPlateau.""" + + def __init__(self) -> None: + """Initializes the callback.""" + super().__init__() + self.lr_scheduler = None + + def set_model(self, model: Type[Model]) -> None: + """Sets the model and lr scheduler.""" + self.model = model + self.lr_scheduler = self.model.lr_scheduler + + def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: + """Takes a step at the end of every epoch.""" + val_loss = logs["val_loss"] + self.lr_scheduler.step(val_loss) + + +class CyclicLR(Callback): + """Callback for CyclicLR.""" + + def __init__(self) -> None: + """Initializes the callback.""" + super().__init__() + self.lr_scheduler = None + + def set_model(self, model: Type[Model]) -> None: + """Sets the model and lr scheduler.""" + self.model = model + self.lr_scheduler = self.model.lr_scheduler + + def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: + """Takes a step at the end of every training batch.""" + self.lr_scheduler.step() + + +class OneCycleLR(Callback): + """Callback for OneCycleLR.""" + + def __init__(self) -> None: + """Initializes the callback.""" + super().__init__() + self.lr_scheduler = None + + def set_model(self, model: Type[Model]) -> None: + """Sets the model and lr scheduler.""" + self.model = model + self.lr_scheduler = self.model.lr_scheduler + + def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: + """Takes a step at the end of every training batch.""" + self.lr_scheduler.step() diff --git a/src/training/trainer/callbacks/progress_bar.py b/src/training/trainer/callbacks/progress_bar.py new file mode 100644 index 0000000..1970747 --- /dev/null +++ b/src/training/trainer/callbacks/progress_bar.py @@ -0,0 +1,61 @@ +"""Progress bar callback for the training loop.""" +from typing import Dict, Optional + +from tqdm import tqdm +from training.trainer.callbacks import Callback + + +class ProgressBar(Callback): + """A TQDM progress bar for the training loop.""" + + def __init__(self, epochs: int, log_batch_frequency: int = None) -> None: + """Initializes the tqdm callback.""" + self.epochs = epochs + self.log_batch_frequency = log_batch_frequency + self.progress_bar = None + self.val_metrics = {} + + def _configure_progress_bar(self) -> None: + """Configures the tqdm progress bar with custom bar format.""" + self.progress_bar = tqdm( + total=len(self.model.data_loaders["train"]), + leave=True, + unit="step", + mininterval=self.log_batch_frequency, + bar_format="{desc} |{bar:30}| {n_fmt}/{total_fmt} ETA: {remaining} {rate_fmt}{postfix}", + ) + + def _key_abbreviations(self, logs: Dict) -> Dict: + """Changes the length of keys, so that the progress bar fits better.""" + + def rename(key: str) -> str: + """Renames accuracy to acc.""" + return key.replace("accuracy", "acc") + + return {rename(key): value for key, value in logs.items()} + + def on_fit_begin(self) -> None: + """Creates a tqdm progress bar.""" + self._configure_progress_bar() + + def on_epoch_begin(self, epoch: int, logs: Optional[Dict]) -> None: + """Updates the description with the current epoch.""" + self.progress_bar.reset() + self.progress_bar.set_description(f"Epoch {epoch}/{self.epochs}") + + def on_epoch_end(self, epoch: int, logs: Dict) -> None: + """At the end of each epoch, the validation metrics are updated to the progress bar.""" + self.val_metrics = logs + self.progress_bar.set_postfix(**self._key_abbreviations(logs)) + self.progress_bar.update() + + def on_train_batch_end(self, batch: int, logs: Dict) -> None: + """Updates the progress bar for each training step.""" + if self.val_metrics: + logs.update(self.val_metrics) + self.progress_bar.set_postfix(**self._key_abbreviations(logs)) + self.progress_bar.update() + + def on_fit_end(self) -> None: + """Closes the tqdm progress bar.""" + self.progress_bar.close() diff --git a/src/training/trainer/callbacks/wandb_callbacks.py b/src/training/trainer/callbacks/wandb_callbacks.py new file mode 100644 index 0000000..e44c745 --- /dev/null +++ b/src/training/trainer/callbacks/wandb_callbacks.py @@ -0,0 +1,93 @@ +"""Callback for W&B.""" +from typing import Callable, Dict, List, Optional, Type + +import numpy as np +from torchvision.transforms import Compose, ToTensor +from training.trainer.callbacks import Callback +import wandb + +from text_recognizer.datasets import Transpose +from text_recognizer.models.base import Model + + +class WandbCallback(Callback): + """A custom W&B metric logger for the trainer.""" + + def __init__(self, log_batch_frequency: int = None) -> None: + """Short summary. + + Args: + log_batch_frequency (int): If None, metrics will be logged every epoch. + If set to an integer, callback will log every metrics every log_batch_frequency. + + """ + super().__init__() + self.log_batch_frequency = log_batch_frequency + + def _on_batch_end(self, batch: int, logs: Dict) -> None: + if self.log_batch_frequency and batch % self.log_batch_frequency == 0: + wandb.log(logs, commit=True) + + def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: + """Logs training metrics.""" + if logs is not None: + self._on_batch_end(batch, logs) + + def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: + """Logs validation metrics.""" + if logs is not None: + self._on_batch_end(batch, logs) + + def on_epoch_end(self, epoch: int, logs: Dict) -> None: + """Logs at epoch end.""" + wandb.log(logs, commit=True) + + +class WandbImageLogger(Callback): + """Custom W&B callback for image logging.""" + + def __init__( + self, + example_indices: Optional[List] = None, + num_examples: int = 4, + transfroms: Optional[Callable] = None, + ) -> None: + """Initializes the WandbImageLogger with the model to train. + + Args: + example_indices (Optional[List]): Indices for validation images. Defaults to None. + num_examples (int): Number of random samples to take if example_indices are not specified. Defaults to 4. + transfroms (Optional[Callable]): Transforms to use on the validation images, e.g. transpose. Defaults to + None. + + """ + + super().__init__() + self.example_indices = example_indices + self.num_examples = num_examples + self.transfroms = transfroms + if self.transfroms is None: + self.transforms = Compose([Transpose()]) + + def set_model(self, model: Type[Model]) -> None: + """Sets the model and extracts validation images from the dataset.""" + self.model = model + data_loader = self.model.data_loaders["val"] + if self.example_indices is None: + self.example_indices = np.random.randint( + 0, len(data_loader.dataset.data), self.num_examples + ) + self.val_images = data_loader.dataset.data[self.example_indices] + self.val_targets = data_loader.dataset.targets[self.example_indices].numpy() + + def on_epoch_end(self, epoch: int, logs: Dict) -> None: + """Get network predictions on validation images.""" + images = [] + for i, image in enumerate(self.val_images): + image = self.transforms(image) + pred, conf = self.model.predict_on_image(image) + ground_truth = self.model.mapper(int(self.val_targets[i])) + caption = f"Prediction: {pred} Confidence: {conf:.3f} Ground Truth: {ground_truth}" + images.append(wandb.Image(image, caption=caption)) + + wandb.log({"examples": images}, commit=False) diff --git a/src/training/trainer/population_based_training/__init__.py b/src/training/trainer/population_based_training/__init__.py new file mode 100644 index 0000000..868d739 --- /dev/null +++ b/src/training/trainer/population_based_training/__init__.py @@ -0,0 +1 @@ +"""TBC.""" diff --git a/src/training/trainer/population_based_training/population_based_training.py b/src/training/trainer/population_based_training/population_based_training.py new file mode 100644 index 0000000..868d739 --- /dev/null +++ b/src/training/trainer/population_based_training/population_based_training.py @@ -0,0 +1 @@ +"""TBC.""" diff --git a/src/training/trainer/train.py b/src/training/trainer/train.py new file mode 100644 index 0000000..a75ae8f --- /dev/null +++ b/src/training/trainer/train.py @@ -0,0 +1,216 @@ +"""Training script for PyTorch models.""" + +from pathlib import Path +import time +from typing import Dict, List, Optional, Tuple, Type + +from loguru import logger +import numpy as np +import torch +from torch import Tensor +from training.trainer.callbacks import Callback, CallbackList +from training.trainer.util import RunningAverage +import wandb + +from text_recognizer.models import Model + + +torch.backends.cudnn.benchmark = True +np.random.seed(4711) +torch.manual_seed(4711) +torch.cuda.manual_seed(4711) + + +class Trainer: + """Trainer for training PyTorch models.""" + + def __init__( + self, + model: Type[Model], + model_dir: Path, + train_args: Dict, + callbacks: CallbackList, + checkpoint_path: Optional[Path] = None, + ) -> None: + """Initialization of the Trainer. + + Args: + model (Type[Model]): A model object. + model_dir (Path): Path to the model directory. + train_args (Dict): The training arguments. + callbacks (CallbackList): List of callbacks to be called. + checkpoint_path (Optional[Path]): The path to a previously trained model. Defaults to None. + + """ + self.model = model + self.model_dir = model_dir + self.checkpoint_path = checkpoint_path + self.start_epoch = 1 + self.epochs = train_args["epochs"] + self.callbacks = callbacks + + if self.checkpoint_path is not None: + self.start_epoch = self.model.load_checkpoint(self.checkpoint_path) + + # Parse the name of the experiment. + experiment_dir = str(self.model_dir.parents[1]).split("/") + self.experiment_name = experiment_dir[-2] + "/" + experiment_dir[-1] + + def training_step( + self, + batch: int, + samples: Tuple[Tensor, Tensor], + loss_avg: Type[RunningAverage], + ) -> Dict: + """Performs the training step.""" + # Pass the tensor to the device for computation. + data, targets = samples + data, targets = ( + data.to(self.model.device), + targets.to(self.model.device), + ) + + # Forward pass. + # Get the network prediction. + output = self.model.network(data) + + # Compute the loss. + loss = self.model.criterion(output, targets) + + # Backward pass. + # Clear the previous gradients. + self.model.optimizer.zero_grad() + + # Compute the gradients. + loss.backward() + + # Perform updates using calculated gradients. + self.model.optimizer.step() + + # Compute metrics. + loss_avg.update(loss.item()) + output = output.data.cpu() + targets = targets.data.cpu() + metrics = { + metric: self.model.metrics[metric](output, targets) + for metric in self.model.metrics + } + metrics["loss"] = loss_avg() + return metrics + + def train(self) -> None: + """Runs the training loop for one epoch.""" + # Set model to traning mode. + self.model.train() + + # Running average for the loss. + loss_avg = RunningAverage() + + data_loader = self.model.data_loaders["train"] + + for batch, samples in enumerate(data_loader): + self.callbacks.on_train_batch_begin(batch) + metrics = self.training_step(batch, samples, loss_avg) + self.callbacks.on_train_batch_end(batch, logs=metrics) + + @torch.no_grad() + def validation_step( + self, + batch: int, + samples: Tuple[Tensor, Tensor], + loss_avg: Type[RunningAverage], + ) -> Dict: + """Performs the validation step.""" + # Pass the tensor to the device for computation. + data, targets = samples + data, targets = ( + data.to(self.model.device), + targets.to(self.model.device), + ) + + # Forward pass. + # Get the network prediction. + output = self.model.network(data) + + # Compute the loss. + loss = self.model.criterion(output, targets) + + # Compute metrics. + loss_avg.update(loss.item()) + output = output.data.cpu() + targets = targets.data.cpu() + metrics = { + metric: self.model.metrics[metric](output, targets) + for metric in self.model.metrics + } + metrics["loss"] = loss.item() + + return metrics + + def _log_val_metric(self, metrics_mean: Dict, epoch: Optional[int] = None) -> None: + log_str = "Validation metrics " + (f"at epoch {epoch} - " if epoch else " - ") + logger.debug( + log_str + " - ".join(f"{k}: {v:.4f}" for k, v in metrics_mean.items()) + ) + + def validate(self, epoch: Optional[int] = None) -> Dict: + """Runs the validation loop for one epoch.""" + # Set model to eval mode. + self.model.eval() + + # Running average for the loss. + data_loader = self.model.data_loaders["val"] + + # Running average for the loss. + loss_avg = RunningAverage() + + # Summary for the current eval loop. + summary = [] + + for batch, samples in enumerate(data_loader): + self.callbacks.on_validation_batch_begin(batch) + metrics = self.validation_step(batch, samples, loss_avg) + self.callbacks.on_validation_batch_end(batch, logs=metrics) + summary.append(metrics) + + # Compute mean of all metrics. + metrics_mean = { + "val_" + metric: np.mean([x[metric] for x in summary]) + for metric in summary[0] + } + self._log_val_metric(metrics_mean, epoch) + + return metrics_mean + + def fit(self) -> None: + """Runs the training and evaluation loop.""" + + logger.debug(f"Running an experiment called {self.experiment_name}.") + + # Set start time. + t_start = time.time() + + self.callbacks.on_fit_begin() + + # Run the training loop. + for epoch in range(self.start_epoch, self.epochs + 1): + self.callbacks.on_epoch_begin(epoch) + + # Perform one training pass over the training set. + self.train() + + # Evaluate the model on the validation set. + val_metrics = self.validate(epoch) + + self.callbacks.on_epoch_end(epoch, logs=val_metrics) + + if self.model.stop_training: + break + + # Calculate the total training time. + t_end = time.time() + t_training = t_end - t_start + + self.callbacks.on_fit_end() + + logger.info(f"Training took {t_training:.2f} s.") diff --git a/src/training/trainer/util.py b/src/training/trainer/util.py new file mode 100644 index 0000000..132b2dc --- /dev/null +++ b/src/training/trainer/util.py @@ -0,0 +1,19 @@ +"""Utility functions for training neural networks.""" + + +class RunningAverage: + """Maintains a running average.""" + + def __init__(self) -> None: + """Initializes the parameters.""" + self.steps = 0 + self.total = 0 + + def update(self, val: float) -> None: + """Updates the parameters.""" + self.total += val + self.steps += 1 + + def __call__(self) -> float: + """Computes the running average.""" + return self.total / float(self.steps) diff --git a/src/training/util.py b/src/training/util.py deleted file mode 100644 index 132b2dc..0000000 --- a/src/training/util.py +++ /dev/null @@ -1,19 +0,0 @@ -"""Utility functions for training neural networks.""" - - -class RunningAverage: - """Maintains a running average.""" - - def __init__(self) -> None: - """Initializes the parameters.""" - self.steps = 0 - self.total = 0 - - def update(self, val: float) -> None: - """Updates the parameters.""" - self.total += val - self.steps += 1 - - def __call__(self) -> float: - """Computes the running average.""" - return self.total / float(self.steps) -- cgit v1.2.3-70-g09d2