diff options
Diffstat (limited to 'src/training/callbacks')
-rw-r--r-- | src/training/callbacks/__init__.py | 20 | ||||
-rw-r--r-- | src/training/callbacks/base.py | 231 | ||||
-rw-r--r-- | src/training/callbacks/early_stopping.py | 106 | ||||
-rw-r--r-- | src/training/callbacks/lr_schedulers.py | 97 | ||||
-rw-r--r-- | src/training/callbacks/wandb_callbacks.py | 93 |
5 files changed, 500 insertions, 47 deletions
diff --git a/src/training/callbacks/__init__.py b/src/training/callbacks/__init__.py index 868d739..fbcc285 100644 --- a/src/training/callbacks/__init__.py +++ b/src/training/callbacks/__init__.py @@ -1 +1,19 @@ -"""TBC.""" +"""The callback modules used in the training script.""" +from .base import Callback, CallbackList, Checkpoint +from .early_stopping import EarlyStopping +from .lr_schedulers import CyclicLR, MultiStepLR, OneCycleLR, ReduceLROnPlateau, StepLR +from .wandb_callbacks import WandbCallback, WandbImageLogger + +__all__ = [ + "Callback", + "CallbackList", + "Checkpoint", + "EarlyStopping", + "WandbCallback", + "WandbImageLogger", + "CyclicLR", + "MultiStepLR", + "OneCycleLR", + "ReduceLROnPlateau", + "StepLR", +] diff --git a/src/training/callbacks/base.py b/src/training/callbacks/base.py index d80a1e5..e0d91e6 100644 --- a/src/training/callbacks/base.py +++ b/src/training/callbacks/base.py @@ -1,12 +1,33 @@ """Metaclass for callback functions.""" -from abc import ABC -from typing import Callable, List, Type +from enum import Enum +from typing import Callable, Dict, List, Type, Union +from loguru import logger +import numpy as np +import torch -class Callback(ABC): +from text_recognizer.models import Model + + +class ModeKeys: + """Mode keys for CallbackList.""" + + TRAIN = "train" + VALIDATION = "validation" + + +class Callback: """Metaclass for callbacks used in training.""" + def __init__(self) -> None: + """Initializes the Callback instance.""" + self.model = None + + def set_model(self, model: Type[Model]) -> None: + """Set the model.""" + self.model = model + def on_fit_begin(self) -> None: """Called when fit begins.""" pass @@ -15,35 +36,27 @@ class Callback(ABC): """Called when fit ends.""" pass - def on_train_epoch_begin(self) -> None: - """Called at the beginning of an epoch.""" - pass - - def on_train_epoch_end(self) -> None: - """Called at the end of an epoch.""" + def on_epoch_begin(self, epoch: int, logs: Dict = {}) -> None: + """Called at the beginning of an epoch. Only used in training mode.""" pass - def on_val_epoch_begin(self) -> None: - """Called at the beginning of an epoch.""" + def on_epoch_end(self, epoch: int, logs: Dict = {}) -> None: + """Called at the end of an epoch. Only used in training mode.""" pass - def on_val_epoch_end(self) -> None: - """Called at the end of an epoch.""" - pass - - def on_train_batch_begin(self) -> None: + def on_train_batch_begin(self, batch: int, logs: Dict = {}) -> None: """Called at the beginning of an epoch.""" pass - def on_train_batch_end(self) -> None: + def on_train_batch_end(self, batch: int, logs: Dict = {}) -> None: """Called at the end of an epoch.""" pass - def on_val_batch_begin(self) -> None: + def on_validation_batch_begin(self, batch: int, logs: Dict = {}) -> None: """Called at the beginning of an epoch.""" pass - def on_val_batch_end(self) -> None: + def on_validation_batch_end(self, batch: int, logs: Dict = {}) -> None: """Called at the end of an epoch.""" pass @@ -51,9 +64,29 @@ class Callback(ABC): class CallbackList: """Container for abstracting away callback calls.""" - def __init__(self, callbacks: List[Callable] = None) -> None: - """TBC.""" - self._callbacks = callbacks if callbacks is not None else [] + mode_keys = ModeKeys() + + def __init__(self, model: Type[Model], callbacks: List[Callback] = None) -> None: + """Container for `Callback` instances. + + This object wraps a list of `Callback` instances and allows them all to be + called via a single end point. + + Args: + model (Type[Model]): A `Model` instance. + callbacks (List[Callback]): List of `Callback` instances. Defaults to None. + + """ + + self._callbacks = callbacks or [] + if model: + self.set_model(model) + + def set_model(self, model: Type[Model]) -> None: + """Set the model for all callbacks.""" + self.model = model + for callback in self._callbacks: + callback.set_model(model=self.model) def append(self, callback: Type[Callback]) -> None: """Append new callback to callback list.""" @@ -61,41 +94,147 @@ class CallbackList: def on_fit_begin(self) -> None: """Called when fit begins.""" - for _ in self._callbacks: - pass + for callback in self._callbacks: + callback.on_fit_begin() def on_fit_end(self) -> None: """Called when fit ends.""" - pass + for callback in self._callbacks: + callback.on_fit_end() - def on_train_epoch_begin(self) -> None: + def on_epoch_begin(self, epoch: int, logs: Dict = {}) -> None: """Called at the beginning of an epoch.""" - pass + for callback in self._callbacks: + callback.on_epoch_begin(epoch, logs) - def on_train_epoch_end(self) -> None: + def on_epoch_end(self, epoch: int, logs: Dict = {}) -> None: """Called at the end of an epoch.""" - pass - - def on_val_epoch_begin(self) -> None: + for callback in self._callbacks: + callback.on_epoch_end(epoch, logs) + + def _call_batch_hook( + self, mode: str, hook: str, batch: int, logs: Dict = {} + ) -> None: + """Helper function for all batch_{begin | end} methods.""" + if hook == "begin": + self._call_batch_begin_hook(mode, batch, logs) + elif hook == "end": + self._call_batch_end_hook(mode, batch, logs) + else: + raise ValueError(f"Unrecognized hook {hook}.") + + def _call_batch_begin_hook(self, mode: str, batch: int, logs: Dict = {}) -> None: + """Helper function for all `on_*_batch_begin` methods.""" + hook_name = f"on_{mode}_batch_begin" + self._call_batch_hook_helper(hook_name, batch, logs) + + def _call_batch_end_hook(self, mode: str, batch: int, logs: Dict = {}) -> None: + """Helper function for all `on_*_batch_end` methods.""" + hook_name = f"on_{mode}_batch_end" + self._call_batch_hook_helper(hook_name, batch, logs) + + def _call_batch_hook_helper( + self, hook_name: str, batch: int, logs: Dict = {} + ) -> None: + """Helper function for `on_*_batch_begin` methods.""" + for callback in self._callbacks: + hook = getattr(callback, hook_name) + hook(batch, logs) + + def on_train_batch_begin(self, batch: int, logs: Dict = {}) -> None: """Called at the beginning of an epoch.""" - pass + self._call_batch_hook(self.mode_keys.TRAIN, "begin", batch) - def on_val_epoch_end(self) -> None: + def on_train_batch_end(self, batch: int, logs: Dict = {}) -> None: """Called at the end of an epoch.""" - pass + self._call_batch_hook(self.mode_keys.TRAIN, "end", batch) - def on_train_batch_begin(self) -> None: + def on_validation_batch_begin(self, batch: int, logs: Dict = {}) -> None: """Called at the beginning of an epoch.""" - pass - - def on_train_batch_end(self) -> None: - """Called at the end of an epoch.""" - pass + self._call_batch_hook(self.mode_keys.VALIDATION, "begin", batch) - def on_val_batch_begin(self) -> None: - """Called at the beginning of an epoch.""" - pass - - def on_val_batch_end(self) -> None: + def on_validation_batch_end(self, batch: int, logs: Dict = {}) -> None: """Called at the end of an epoch.""" - pass + self._call_batch_hook(self.mode_keys.VALIDATION, "end", batch) + + def __iter__(self) -> iter: + """Iter function for callback list.""" + return iter(self._callbacks) + + +class Checkpoint(Callback): + """Saving model parameters at the end of each epoch.""" + + mode_dict = { + "min": torch.lt, + "max": torch.gt, + } + + def __init__( + self, monitor: str = "accuracy", mode: str = "auto", min_delta: float = 0.0 + ) -> None: + """Monitors a quantity that will allow us to determine the best model weights. + + Args: + monitor (str): Name of the quantity to monitor. Defaults to "accuracy". + mode (str): Description of parameter `mode`. Defaults to "auto". + min_delta (float): Description of parameter `min_delta`. Defaults to 0.0. + + """ + super().__init__() + self.monitor = monitor + self.mode = mode + self.min_delta = torch.tensor(min_delta) + + if mode not in ["auto", "min", "max"]: + logger.warning(f"Checkpoint mode {mode} is unkown, fallback to auto mode.") + + self.mode = "auto" + + if self.mode == "auto": + if "accuracy" in self.monitor: + self.mode = "max" + else: + self.mode = "min" + logger.debug( + f"Checkpoint mode set to {self.mode} for monitoring {self.monitor}." + ) + + torch_inf = torch.tensor(np.inf) + self.min_delta *= 1 if self.monitor_op == torch.gt else -1 + self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf + + @property + def monitor_op(self) -> float: + """Returns the comparison method.""" + return self.mode_dict[self.mode] + + def on_epoch_end(self, epoch: int, logs: Dict) -> None: + """Saves a checkpoint for the network parameters. + + Args: + epoch (int): The current epoch. + logs (Dict): The log containing the monitored metrics. + + """ + current = self.get_monitor_value(logs) + if current is None: + return + if self.monitor_op(current - self.min_delta, self.best_score): + self.best_score = current + is_best = True + else: + is_best = False + + self.model.save_checkpoint(is_best, epoch, self.monitor) + + def get_monitor_value(self, logs: Dict) -> Union[float, None]: + """Extracts the monitored value.""" + monitor_value = logs.get(self.monitor) + if monitor_value is None: + logger.warning( + f"Checkpoint is conditioned on metric {self.monitor} which is not available. Available" + + f"metrics are: {','.join(list(logs.keys()))}" + ) + return None + return monitor_value diff --git a/src/training/callbacks/early_stopping.py b/src/training/callbacks/early_stopping.py index 4da0e85..c9b7907 100644 --- a/src/training/callbacks/early_stopping.py +++ b/src/training/callbacks/early_stopping.py @@ -1 +1,107 @@ """Implements Early stopping for PyTorch model.""" +from typing import Dict, Union + +from loguru import logger +import numpy as np +import torch +from training.callbacks import Callback + + +class EarlyStopping(Callback): + """Stops training when a monitored metric stops improving.""" + + mode_dict = { + "min": torch.lt, + "max": torch.gt, + } + + def __init__( + self, + monitor: str = "val_loss", + min_delta: float = 0.0, + patience: int = 3, + mode: str = "auto", + ) -> None: + """Initializes the EarlyStopping callback. + + Args: + monitor (str): Description of parameter `monitor`. Defaults to "val_loss". + min_delta (float): Description of parameter `min_delta`. Defaults to 0.0. + patience (int): Description of parameter `patience`. Defaults to 3. + mode (str): Description of parameter `mode`. Defaults to "auto". + + """ + super().__init__() + self.monitor = monitor + self.patience = patience + self.min_delta = torch.tensor(min_delta) + self.mode = mode + self.wait_count = 0 + self.stopped_epoch = 0 + + if mode not in ["auto", "min", "max"]: + logger.warning( + f"EarlyStopping mode {mode} is unkown, fallback to auto mode." + ) + + self.mode = "auto" + + if self.mode == "auto": + if "accuracy" in self.monitor: + self.mode = "max" + else: + self.mode = "min" + logger.debug( + f"EarlyStopping mode set to {self.mode} for monitoring {self.monitor}." + ) + + self.torch_inf = torch.tensor(np.inf) + self.min_delta *= 1 if self.monitor_op == torch.gt else -1 + self.best_score = ( + self.torch_inf if self.monitor_op == torch.lt else -self.torch_inf + ) + + @property + def monitor_op(self) -> float: + """Returns the comparison method.""" + return self.mode_dict[self.mode] + + def on_fit_begin(self) -> Union[torch.lt, torch.gt]: + """Reset the early stopping variables for reuse.""" + self.wait_count = 0 + self.stopped_epoch = 0 + self.best_score = ( + self.torch_inf if self.monitor_op == torch.lt else -self.torch_inf + ) + + def on_epoch_end(self, epoch: int, logs: Dict) -> None: + """Computes the early stop criterion.""" + current = self.get_monitor_value(logs) + if current is None: + return + if self.monitor_op(current - self.min_delta, self.best_score): + self.best_score = current + self.wait_count = 0 + else: + self.wait_count += 1 + if self.wait_count >= self.patience: + self.stopped_epoch = epoch + self.model.stop_training = True + + def on_fit_end(self) -> None: + """Logs if early stopping was used.""" + if self.stopped_epoch > 0: + logger.info( + f"Stopped training at epoch {self.stopped_epoch + 1} with early stopping." + ) + + def get_monitor_value(self, logs: Dict) -> Union[torch.Tensor, None]: + """Extracts the monitor value.""" + monitor_value = logs.get(self.monitor) + if monitor_value is None: + logger.warning( + f"Early stopping is conditioned on metric {self.monitor} which is not available. Available" + + f"metrics are: {','.join(list(logs.keys()))}" + ) + return None + return torch.tensor(monitor_value) diff --git a/src/training/callbacks/lr_schedulers.py b/src/training/callbacks/lr_schedulers.py new file mode 100644 index 0000000..00c7e9b --- /dev/null +++ b/src/training/callbacks/lr_schedulers.py @@ -0,0 +1,97 @@ +"""Callbacks for learning rate schedulers.""" +from typing import Callable, Dict, List, Optional, Type + +from training.callbacks import Callback + +from text_recognizer.models import Model + + +class StepLR(Callback): + """Callback for StepLR.""" + + def __init__(self) -> None: + """Initializes the callback.""" + super().__init__() + self.lr_scheduler = None + + def set_model(self, model: Type[Model]) -> None: + """Sets the model and lr scheduler.""" + self.model = model + self.lr_scheduler = self.model.lr_scheduler + + def on_epoch_end(self, epoch: int, logs: Dict = {}) -> None: + """Takes a step at the end of every epoch.""" + self.lr_scheduler.step() + + +class MultiStepLR(Callback): + """Callback for MultiStepLR.""" + + def __init__(self) -> None: + """Initializes the callback.""" + super().__init__() + self.lr_scheduler = None + + def set_model(self, model: Type[Model]) -> None: + """Sets the model and lr scheduler.""" + self.model = model + self.lr_scheduler = self.model.lr_scheduler + + def on_epoch_end(self, epoch: int, logs: Dict = {}) -> None: + """Takes a step at the end of every epoch.""" + self.lr_scheduler.step() + + +class ReduceLROnPlateau(Callback): + """Callback for ReduceLROnPlateau.""" + + def __init__(self) -> None: + """Initializes the callback.""" + super().__init__() + self.lr_scheduler = None + + def set_model(self, model: Type[Model]) -> None: + """Sets the model and lr scheduler.""" + self.model = model + self.lr_scheduler = self.model.lr_scheduler + + def on_epoch_end(self, epoch: int, logs: Dict = {}) -> None: + """Takes a step at the end of every epoch.""" + val_loss = logs["val_loss"] + self.lr_scheduler.step(val_loss) + + +class CyclicLR(Callback): + """Callback for CyclicLR.""" + + def __init__(self) -> None: + """Initializes the callback.""" + super().__init__() + self.lr_scheduler = None + + def set_model(self, model: Type[Model]) -> None: + """Sets the model and lr scheduler.""" + self.model = model + self.lr_scheduler = self.model.lr_scheduler + + def on_train_batch_end(self, batch: int, logs: Dict = {}) -> None: + """Takes a step at the end of every training batch.""" + self.lr_scheduler.step() + + +class OneCycleLR(Callback): + """Callback for OneCycleLR.""" + + def __init__(self) -> None: + """Initializes the callback.""" + super().__init__() + self.lr_scheduler = None + + def set_model(self, model: Type[Model]) -> None: + """Sets the model and lr scheduler.""" + self.model = model + self.lr_scheduler = self.model.lr_scheduler + + def on_train_batch_end(self, batch: int, logs: Dict = {}) -> None: + """Takes a step at the end of every training batch.""" + self.lr_scheduler.step() diff --git a/src/training/callbacks/wandb_callbacks.py b/src/training/callbacks/wandb_callbacks.py new file mode 100644 index 0000000..f64cbe1 --- /dev/null +++ b/src/training/callbacks/wandb_callbacks.py @@ -0,0 +1,93 @@ +"""Callbacks using wandb.""" +from typing import Callable, Dict, List, Optional, Type + +import numpy as np +from torchvision.transforms import Compose, ToTensor +from training.callbacks import Callback +import wandb + +from text_recognizer.datasets import Transpose +from text_recognizer.models.base import Model + + +class WandbCallback(Callback): + """A custom W&B metric logger for the trainer.""" + + def __init__(self, log_batch_frequency: int = None) -> None: + """Short summary. + + Args: + log_batch_frequency (int): If None, metrics will be logged every epoch. + If set to an integer, callback will log every metrics every log_batch_frequency. + + """ + super().__init__() + self.log_batch_frequency = log_batch_frequency + + def _on_batch_end(self, batch: int, logs: Dict) -> None: + if self.log_batch_frequency and batch % self.log_batch_frequency == 0: + wandb.log(logs, commit=True) + + def on_train_batch_end(self, batch: int, logs: Dict = {}) -> None: + """Logs training metrics.""" + if logs is not None: + self._on_batch_end(batch, logs) + + def on_validation_batch_end(self, batch: int, logs: Dict = {}) -> None: + """Logs validation metrics.""" + if logs is not None: + self._on_batch_end(batch, logs) + + def on_epoch_end(self, epoch: int, logs: Dict) -> None: + """Logs at epoch end.""" + wandb.log(logs, commit=True) + + +class WandbImageLogger(Callback): + """Custom W&B callback for image logging.""" + + def __init__( + self, + example_indices: Optional[List] = None, + num_examples: int = 4, + transfroms: Optional[Callable] = None, + ) -> None: + """Initializes the WandbImageLogger with the model to train. + + Args: + example_indices (Optional[List]): Indices for validation images. Defaults to None. + num_examples (int): Number of random samples to take if example_indices are not specified. Defaults to 4. + transfroms (Optional[Callable]): Transforms to use on the validation images, e.g. transpose. Defaults to + None. + + """ + + super().__init__() + self.example_indices = example_indices + self.num_examples = num_examples + self.transfroms = transfroms + if self.transfroms is None: + self.transforms = Compose([Transpose()]) + + def set_model(self, model: Type[Model]) -> None: + """Sets the model and extracts validation images from the dataset.""" + self.model = model + data_loader = self.model.data_loaders("val") + if self.example_indices is None: + self.example_indices = np.random.randint( + 0, len(data_loader.dataset.data), self.num_examples + ) + self.val_images = data_loader.dataset.data[self.example_indices] + self.val_targets = data_loader.dataset.targets[self.example_indices].numpy() + + def on_epoch_end(self, epoch: int, logs: Dict) -> None: + """Get network predictions on validation images.""" + images = [] + for i, image in enumerate(self.val_images): + image = self.transforms(image) + pred, conf = self.model.predict_on_image(image) + ground_truth = self.model._mapping[self.val_targets[i]] + caption = f"Prediction: {pred} Confidence: {conf:.3f} Ground Truth: {ground_truth}" + images.append(wandb.Image(image, caption=caption)) + + wandb.log({"examples": images}, commit=False) |