diff options
Diffstat (limited to 'src/training/callbacks')
-rw-r--r-- | src/training/callbacks/__init__.py | 19 | ||||
-rw-r--r-- | src/training/callbacks/base.py | 240 | ||||
-rw-r--r-- | src/training/callbacks/early_stopping.py | 107 | ||||
-rw-r--r-- | src/training/callbacks/lr_schedulers.py | 97 | ||||
-rw-r--r-- | src/training/callbacks/wandb_callbacks.py | 93 |
5 files changed, 0 insertions, 556 deletions
diff --git a/src/training/callbacks/__init__.py b/src/training/callbacks/__init__.py deleted file mode 100644 index fbcc285..0000000 --- a/src/training/callbacks/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -"""The callback modules used in the training script.""" -from .base import Callback, CallbackList, Checkpoint -from .early_stopping import EarlyStopping -from .lr_schedulers import CyclicLR, MultiStepLR, OneCycleLR, ReduceLROnPlateau, StepLR -from .wandb_callbacks import WandbCallback, WandbImageLogger - -__all__ = [ - "Callback", - "CallbackList", - "Checkpoint", - "EarlyStopping", - "WandbCallback", - "WandbImageLogger", - "CyclicLR", - "MultiStepLR", - "OneCycleLR", - "ReduceLROnPlateau", - "StepLR", -] diff --git a/src/training/callbacks/base.py b/src/training/callbacks/base.py deleted file mode 100644 index e0d91e6..0000000 --- a/src/training/callbacks/base.py +++ /dev/null @@ -1,240 +0,0 @@ -"""Metaclass for callback functions.""" - -from enum import Enum -from typing import Callable, Dict, List, Type, Union - -from loguru import logger -import numpy as np -import torch - -from text_recognizer.models import Model - - -class ModeKeys: - """Mode keys for CallbackList.""" - - TRAIN = "train" - VALIDATION = "validation" - - -class Callback: - """Metaclass for callbacks used in training.""" - - def __init__(self) -> None: - """Initializes the Callback instance.""" - self.model = None - - def set_model(self, model: Type[Model]) -> None: - """Set the model.""" - self.model = model - - def on_fit_begin(self) -> None: - """Called when fit begins.""" - pass - - def on_fit_end(self) -> None: - """Called when fit ends.""" - pass - - def on_epoch_begin(self, epoch: int, logs: Dict = {}) -> None: - """Called at the beginning of an epoch. Only used in training mode.""" - pass - - def on_epoch_end(self, epoch: int, logs: Dict = {}) -> None: - """Called at the end of an epoch. Only used in training mode.""" - pass - - def on_train_batch_begin(self, batch: int, logs: Dict = {}) -> None: - """Called at the beginning of an epoch.""" - pass - - def on_train_batch_end(self, batch: int, logs: Dict = {}) -> None: - """Called at the end of an epoch.""" - pass - - def on_validation_batch_begin(self, batch: int, logs: Dict = {}) -> None: - """Called at the beginning of an epoch.""" - pass - - def on_validation_batch_end(self, batch: int, logs: Dict = {}) -> None: - """Called at the end of an epoch.""" - pass - - -class CallbackList: - """Container for abstracting away callback calls.""" - - mode_keys = ModeKeys() - - def __init__(self, model: Type[Model], callbacks: List[Callback] = None) -> None: - """Container for `Callback` instances. - - This object wraps a list of `Callback` instances and allows them all to be - called via a single end point. - - Args: - model (Type[Model]): A `Model` instance. - callbacks (List[Callback]): List of `Callback` instances. Defaults to None. - - """ - - self._callbacks = callbacks or [] - if model: - self.set_model(model) - - def set_model(self, model: Type[Model]) -> None: - """Set the model for all callbacks.""" - self.model = model - for callback in self._callbacks: - callback.set_model(model=self.model) - - def append(self, callback: Type[Callback]) -> None: - """Append new callback to callback list.""" - self.callbacks.append(callback) - - def on_fit_begin(self) -> None: - """Called when fit begins.""" - for callback in self._callbacks: - callback.on_fit_begin() - - def on_fit_end(self) -> None: - """Called when fit ends.""" - for callback in self._callbacks: - callback.on_fit_end() - - def on_epoch_begin(self, epoch: int, logs: Dict = {}) -> None: - """Called at the beginning of an epoch.""" - for callback in self._callbacks: - callback.on_epoch_begin(epoch, logs) - - def on_epoch_end(self, epoch: int, logs: Dict = {}) -> None: - """Called at the end of an epoch.""" - for callback in self._callbacks: - callback.on_epoch_end(epoch, logs) - - def _call_batch_hook( - self, mode: str, hook: str, batch: int, logs: Dict = {} - ) -> None: - """Helper function for all batch_{begin | end} methods.""" - if hook == "begin": - self._call_batch_begin_hook(mode, batch, logs) - elif hook == "end": - self._call_batch_end_hook(mode, batch, logs) - else: - raise ValueError(f"Unrecognized hook {hook}.") - - def _call_batch_begin_hook(self, mode: str, batch: int, logs: Dict = {}) -> None: - """Helper function for all `on_*_batch_begin` methods.""" - hook_name = f"on_{mode}_batch_begin" - self._call_batch_hook_helper(hook_name, batch, logs) - - def _call_batch_end_hook(self, mode: str, batch: int, logs: Dict = {}) -> None: - """Helper function for all `on_*_batch_end` methods.""" - hook_name = f"on_{mode}_batch_end" - self._call_batch_hook_helper(hook_name, batch, logs) - - def _call_batch_hook_helper( - self, hook_name: str, batch: int, logs: Dict = {} - ) -> None: - """Helper function for `on_*_batch_begin` methods.""" - for callback in self._callbacks: - hook = getattr(callback, hook_name) - hook(batch, logs) - - def on_train_batch_begin(self, batch: int, logs: Dict = {}) -> None: - """Called at the beginning of an epoch.""" - self._call_batch_hook(self.mode_keys.TRAIN, "begin", batch) - - def on_train_batch_end(self, batch: int, logs: Dict = {}) -> None: - """Called at the end of an epoch.""" - self._call_batch_hook(self.mode_keys.TRAIN, "end", batch) - - def on_validation_batch_begin(self, batch: int, logs: Dict = {}) -> None: - """Called at the beginning of an epoch.""" - self._call_batch_hook(self.mode_keys.VALIDATION, "begin", batch) - - def on_validation_batch_end(self, batch: int, logs: Dict = {}) -> None: - """Called at the end of an epoch.""" - self._call_batch_hook(self.mode_keys.VALIDATION, "end", batch) - - def __iter__(self) -> iter: - """Iter function for callback list.""" - return iter(self._callbacks) - - -class Checkpoint(Callback): - """Saving model parameters at the end of each epoch.""" - - mode_dict = { - "min": torch.lt, - "max": torch.gt, - } - - def __init__( - self, monitor: str = "accuracy", mode: str = "auto", min_delta: float = 0.0 - ) -> None: - """Monitors a quantity that will allow us to determine the best model weights. - - Args: - monitor (str): Name of the quantity to monitor. Defaults to "accuracy". - mode (str): Description of parameter `mode`. Defaults to "auto". - min_delta (float): Description of parameter `min_delta`. Defaults to 0.0. - - """ - super().__init__() - self.monitor = monitor - self.mode = mode - self.min_delta = torch.tensor(min_delta) - - if mode not in ["auto", "min", "max"]: - logger.warning(f"Checkpoint mode {mode} is unkown, fallback to auto mode.") - - self.mode = "auto" - - if self.mode == "auto": - if "accuracy" in self.monitor: - self.mode = "max" - else: - self.mode = "min" - logger.debug( - f"Checkpoint mode set to {self.mode} for monitoring {self.monitor}." - ) - - torch_inf = torch.tensor(np.inf) - self.min_delta *= 1 if self.monitor_op == torch.gt else -1 - self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf - - @property - def monitor_op(self) -> float: - """Returns the comparison method.""" - return self.mode_dict[self.mode] - - def on_epoch_end(self, epoch: int, logs: Dict) -> None: - """Saves a checkpoint for the network parameters. - - Args: - epoch (int): The current epoch. - logs (Dict): The log containing the monitored metrics. - - """ - current = self.get_monitor_value(logs) - if current is None: - return - if self.monitor_op(current - self.min_delta, self.best_score): - self.best_score = current - is_best = True - else: - is_best = False - - self.model.save_checkpoint(is_best, epoch, self.monitor) - - def get_monitor_value(self, logs: Dict) -> Union[float, None]: - """Extracts the monitored value.""" - monitor_value = logs.get(self.monitor) - if monitor_value is None: - logger.warning( - f"Checkpoint is conditioned on metric {self.monitor} which is not available. Available" - + f"metrics are: {','.join(list(logs.keys()))}" - ) - return None - return monitor_value diff --git a/src/training/callbacks/early_stopping.py b/src/training/callbacks/early_stopping.py deleted file mode 100644 index c9b7907..0000000 --- a/src/training/callbacks/early_stopping.py +++ /dev/null @@ -1,107 +0,0 @@ -"""Implements Early stopping for PyTorch model.""" -from typing import Dict, Union - -from loguru import logger -import numpy as np -import torch -from training.callbacks import Callback - - -class EarlyStopping(Callback): - """Stops training when a monitored metric stops improving.""" - - mode_dict = { - "min": torch.lt, - "max": torch.gt, - } - - def __init__( - self, - monitor: str = "val_loss", - min_delta: float = 0.0, - patience: int = 3, - mode: str = "auto", - ) -> None: - """Initializes the EarlyStopping callback. - - Args: - monitor (str): Description of parameter `monitor`. Defaults to "val_loss". - min_delta (float): Description of parameter `min_delta`. Defaults to 0.0. - patience (int): Description of parameter `patience`. Defaults to 3. - mode (str): Description of parameter `mode`. Defaults to "auto". - - """ - super().__init__() - self.monitor = monitor - self.patience = patience - self.min_delta = torch.tensor(min_delta) - self.mode = mode - self.wait_count = 0 - self.stopped_epoch = 0 - - if mode not in ["auto", "min", "max"]: - logger.warning( - f"EarlyStopping mode {mode} is unkown, fallback to auto mode." - ) - - self.mode = "auto" - - if self.mode == "auto": - if "accuracy" in self.monitor: - self.mode = "max" - else: - self.mode = "min" - logger.debug( - f"EarlyStopping mode set to {self.mode} for monitoring {self.monitor}." - ) - - self.torch_inf = torch.tensor(np.inf) - self.min_delta *= 1 if self.monitor_op == torch.gt else -1 - self.best_score = ( - self.torch_inf if self.monitor_op == torch.lt else -self.torch_inf - ) - - @property - def monitor_op(self) -> float: - """Returns the comparison method.""" - return self.mode_dict[self.mode] - - def on_fit_begin(self) -> Union[torch.lt, torch.gt]: - """Reset the early stopping variables for reuse.""" - self.wait_count = 0 - self.stopped_epoch = 0 - self.best_score = ( - self.torch_inf if self.monitor_op == torch.lt else -self.torch_inf - ) - - def on_epoch_end(self, epoch: int, logs: Dict) -> None: - """Computes the early stop criterion.""" - current = self.get_monitor_value(logs) - if current is None: - return - if self.monitor_op(current - self.min_delta, self.best_score): - self.best_score = current - self.wait_count = 0 - else: - self.wait_count += 1 - if self.wait_count >= self.patience: - self.stopped_epoch = epoch - self.model.stop_training = True - - def on_fit_end(self) -> None: - """Logs if early stopping was used.""" - if self.stopped_epoch > 0: - logger.info( - f"Stopped training at epoch {self.stopped_epoch + 1} with early stopping." - ) - - def get_monitor_value(self, logs: Dict) -> Union[torch.Tensor, None]: - """Extracts the monitor value.""" - monitor_value = logs.get(self.monitor) - if monitor_value is None: - logger.warning( - f"Early stopping is conditioned on metric {self.monitor} which is not available. Available" - + f"metrics are: {','.join(list(logs.keys()))}" - ) - return None - return torch.tensor(monitor_value) diff --git a/src/training/callbacks/lr_schedulers.py b/src/training/callbacks/lr_schedulers.py deleted file mode 100644 index 00c7e9b..0000000 --- a/src/training/callbacks/lr_schedulers.py +++ /dev/null @@ -1,97 +0,0 @@ -"""Callbacks for learning rate schedulers.""" -from typing import Callable, Dict, List, Optional, Type - -from training.callbacks import Callback - -from text_recognizer.models import Model - - -class StepLR(Callback): - """Callback for StepLR.""" - - def __init__(self) -> None: - """Initializes the callback.""" - super().__init__() - self.lr_scheduler = None - - def set_model(self, model: Type[Model]) -> None: - """Sets the model and lr scheduler.""" - self.model = model - self.lr_scheduler = self.model.lr_scheduler - - def on_epoch_end(self, epoch: int, logs: Dict = {}) -> None: - """Takes a step at the end of every epoch.""" - self.lr_scheduler.step() - - -class MultiStepLR(Callback): - """Callback for MultiStepLR.""" - - def __init__(self) -> None: - """Initializes the callback.""" - super().__init__() - self.lr_scheduler = None - - def set_model(self, model: Type[Model]) -> None: - """Sets the model and lr scheduler.""" - self.model = model - self.lr_scheduler = self.model.lr_scheduler - - def on_epoch_end(self, epoch: int, logs: Dict = {}) -> None: - """Takes a step at the end of every epoch.""" - self.lr_scheduler.step() - - -class ReduceLROnPlateau(Callback): - """Callback for ReduceLROnPlateau.""" - - def __init__(self) -> None: - """Initializes the callback.""" - super().__init__() - self.lr_scheduler = None - - def set_model(self, model: Type[Model]) -> None: - """Sets the model and lr scheduler.""" - self.model = model - self.lr_scheduler = self.model.lr_scheduler - - def on_epoch_end(self, epoch: int, logs: Dict = {}) -> None: - """Takes a step at the end of every epoch.""" - val_loss = logs["val_loss"] - self.lr_scheduler.step(val_loss) - - -class CyclicLR(Callback): - """Callback for CyclicLR.""" - - def __init__(self) -> None: - """Initializes the callback.""" - super().__init__() - self.lr_scheduler = None - - def set_model(self, model: Type[Model]) -> None: - """Sets the model and lr scheduler.""" - self.model = model - self.lr_scheduler = self.model.lr_scheduler - - def on_train_batch_end(self, batch: int, logs: Dict = {}) -> None: - """Takes a step at the end of every training batch.""" - self.lr_scheduler.step() - - -class OneCycleLR(Callback): - """Callback for OneCycleLR.""" - - def __init__(self) -> None: - """Initializes the callback.""" - super().__init__() - self.lr_scheduler = None - - def set_model(self, model: Type[Model]) -> None: - """Sets the model and lr scheduler.""" - self.model = model - self.lr_scheduler = self.model.lr_scheduler - - def on_train_batch_end(self, batch: int, logs: Dict = {}) -> None: - """Takes a step at the end of every training batch.""" - self.lr_scheduler.step() diff --git a/src/training/callbacks/wandb_callbacks.py b/src/training/callbacks/wandb_callbacks.py deleted file mode 100644 index 6ada6df..0000000 --- a/src/training/callbacks/wandb_callbacks.py +++ /dev/null @@ -1,93 +0,0 @@ -"""Callbacks using wandb.""" -from typing import Callable, Dict, List, Optional, Type - -import numpy as np -from torchvision.transforms import Compose, ToTensor -from training.callbacks import Callback -import wandb - -from text_recognizer.datasets import Transpose -from text_recognizer.models.base import Model - - -class WandbCallback(Callback): - """A custom W&B metric logger for the trainer.""" - - def __init__(self, log_batch_frequency: int = None) -> None: - """Short summary. - - Args: - log_batch_frequency (int): If None, metrics will be logged every epoch. - If set to an integer, callback will log every metrics every log_batch_frequency. - - """ - super().__init__() - self.log_batch_frequency = log_batch_frequency - - def _on_batch_end(self, batch: int, logs: Dict) -> None: - if self.log_batch_frequency and batch % self.log_batch_frequency == 0: - wandb.log(logs, commit=True) - - def on_train_batch_end(self, batch: int, logs: Dict = {}) -> None: - """Logs training metrics.""" - if logs is not None: - self._on_batch_end(batch, logs) - - def on_validation_batch_end(self, batch: int, logs: Dict = {}) -> None: - """Logs validation metrics.""" - if logs is not None: - self._on_batch_end(batch, logs) - - def on_epoch_end(self, epoch: int, logs: Dict) -> None: - """Logs at epoch end.""" - wandb.log(logs, commit=True) - - -class WandbImageLogger(Callback): - """Custom W&B callback for image logging.""" - - def __init__( - self, - example_indices: Optional[List] = None, - num_examples: int = 4, - transfroms: Optional[Callable] = None, - ) -> None: - """Initializes the WandbImageLogger with the model to train. - - Args: - example_indices (Optional[List]): Indices for validation images. Defaults to None. - num_examples (int): Number of random samples to take if example_indices are not specified. Defaults to 4. - transfroms (Optional[Callable]): Transforms to use on the validation images, e.g. transpose. Defaults to - None. - - """ - - super().__init__() - self.example_indices = example_indices - self.num_examples = num_examples - self.transfroms = transfroms - if self.transfroms is None: - self.transforms = Compose([Transpose()]) - - def set_model(self, model: Type[Model]) -> None: - """Sets the model and extracts validation images from the dataset.""" - self.model = model - data_loader = self.model.data_loaders["val"] - if self.example_indices is None: - self.example_indices = np.random.randint( - 0, len(data_loader.dataset.data), self.num_examples - ) - self.val_images = data_loader.dataset.data[self.example_indices] - self.val_targets = data_loader.dataset.targets[self.example_indices].numpy() - - def on_epoch_end(self, epoch: int, logs: Dict) -> None: - """Get network predictions on validation images.""" - images = [] - for i, image in enumerate(self.val_images): - image = self.transforms(image) - pred, conf = self.model.predict_on_image(image) - ground_truth = self.model.mapper(int(self.val_targets[i])) - caption = f"Prediction: {pred} Confidence: {conf:.3f} Ground Truth: {ground_truth}" - images.append(wandb.Image(image, caption=caption)) - - wandb.log({"examples": images}, commit=False) |