summaryrefslogtreecommitdiff
path: root/src/training/callbacks
diff options
context:
space:
mode:
Diffstat (limited to 'src/training/callbacks')
-rw-r--r--src/training/callbacks/__init__.py19
-rw-r--r--src/training/callbacks/base.py240
-rw-r--r--src/training/callbacks/early_stopping.py107
-rw-r--r--src/training/callbacks/lr_schedulers.py97
-rw-r--r--src/training/callbacks/wandb_callbacks.py93
5 files changed, 0 insertions, 556 deletions
diff --git a/src/training/callbacks/__init__.py b/src/training/callbacks/__init__.py
deleted file mode 100644
index fbcc285..0000000
--- a/src/training/callbacks/__init__.py
+++ /dev/null
@@ -1,19 +0,0 @@
-"""The callback modules used in the training script."""
-from .base import Callback, CallbackList, Checkpoint
-from .early_stopping import EarlyStopping
-from .lr_schedulers import CyclicLR, MultiStepLR, OneCycleLR, ReduceLROnPlateau, StepLR
-from .wandb_callbacks import WandbCallback, WandbImageLogger
-
-__all__ = [
- "Callback",
- "CallbackList",
- "Checkpoint",
- "EarlyStopping",
- "WandbCallback",
- "WandbImageLogger",
- "CyclicLR",
- "MultiStepLR",
- "OneCycleLR",
- "ReduceLROnPlateau",
- "StepLR",
-]
diff --git a/src/training/callbacks/base.py b/src/training/callbacks/base.py
deleted file mode 100644
index e0d91e6..0000000
--- a/src/training/callbacks/base.py
+++ /dev/null
@@ -1,240 +0,0 @@
-"""Metaclass for callback functions."""
-
-from enum import Enum
-from typing import Callable, Dict, List, Type, Union
-
-from loguru import logger
-import numpy as np
-import torch
-
-from text_recognizer.models import Model
-
-
-class ModeKeys:
- """Mode keys for CallbackList."""
-
- TRAIN = "train"
- VALIDATION = "validation"
-
-
-class Callback:
- """Metaclass for callbacks used in training."""
-
- def __init__(self) -> None:
- """Initializes the Callback instance."""
- self.model = None
-
- def set_model(self, model: Type[Model]) -> None:
- """Set the model."""
- self.model = model
-
- def on_fit_begin(self) -> None:
- """Called when fit begins."""
- pass
-
- def on_fit_end(self) -> None:
- """Called when fit ends."""
- pass
-
- def on_epoch_begin(self, epoch: int, logs: Dict = {}) -> None:
- """Called at the beginning of an epoch. Only used in training mode."""
- pass
-
- def on_epoch_end(self, epoch: int, logs: Dict = {}) -> None:
- """Called at the end of an epoch. Only used in training mode."""
- pass
-
- def on_train_batch_begin(self, batch: int, logs: Dict = {}) -> None:
- """Called at the beginning of an epoch."""
- pass
-
- def on_train_batch_end(self, batch: int, logs: Dict = {}) -> None:
- """Called at the end of an epoch."""
- pass
-
- def on_validation_batch_begin(self, batch: int, logs: Dict = {}) -> None:
- """Called at the beginning of an epoch."""
- pass
-
- def on_validation_batch_end(self, batch: int, logs: Dict = {}) -> None:
- """Called at the end of an epoch."""
- pass
-
-
-class CallbackList:
- """Container for abstracting away callback calls."""
-
- mode_keys = ModeKeys()
-
- def __init__(self, model: Type[Model], callbacks: List[Callback] = None) -> None:
- """Container for `Callback` instances.
-
- This object wraps a list of `Callback` instances and allows them all to be
- called via a single end point.
-
- Args:
- model (Type[Model]): A `Model` instance.
- callbacks (List[Callback]): List of `Callback` instances. Defaults to None.
-
- """
-
- self._callbacks = callbacks or []
- if model:
- self.set_model(model)
-
- def set_model(self, model: Type[Model]) -> None:
- """Set the model for all callbacks."""
- self.model = model
- for callback in self._callbacks:
- callback.set_model(model=self.model)
-
- def append(self, callback: Type[Callback]) -> None:
- """Append new callback to callback list."""
- self.callbacks.append(callback)
-
- def on_fit_begin(self) -> None:
- """Called when fit begins."""
- for callback in self._callbacks:
- callback.on_fit_begin()
-
- def on_fit_end(self) -> None:
- """Called when fit ends."""
- for callback in self._callbacks:
- callback.on_fit_end()
-
- def on_epoch_begin(self, epoch: int, logs: Dict = {}) -> None:
- """Called at the beginning of an epoch."""
- for callback in self._callbacks:
- callback.on_epoch_begin(epoch, logs)
-
- def on_epoch_end(self, epoch: int, logs: Dict = {}) -> None:
- """Called at the end of an epoch."""
- for callback in self._callbacks:
- callback.on_epoch_end(epoch, logs)
-
- def _call_batch_hook(
- self, mode: str, hook: str, batch: int, logs: Dict = {}
- ) -> None:
- """Helper function for all batch_{begin | end} methods."""
- if hook == "begin":
- self._call_batch_begin_hook(mode, batch, logs)
- elif hook == "end":
- self._call_batch_end_hook(mode, batch, logs)
- else:
- raise ValueError(f"Unrecognized hook {hook}.")
-
- def _call_batch_begin_hook(self, mode: str, batch: int, logs: Dict = {}) -> None:
- """Helper function for all `on_*_batch_begin` methods."""
- hook_name = f"on_{mode}_batch_begin"
- self._call_batch_hook_helper(hook_name, batch, logs)
-
- def _call_batch_end_hook(self, mode: str, batch: int, logs: Dict = {}) -> None:
- """Helper function for all `on_*_batch_end` methods."""
- hook_name = f"on_{mode}_batch_end"
- self._call_batch_hook_helper(hook_name, batch, logs)
-
- def _call_batch_hook_helper(
- self, hook_name: str, batch: int, logs: Dict = {}
- ) -> None:
- """Helper function for `on_*_batch_begin` methods."""
- for callback in self._callbacks:
- hook = getattr(callback, hook_name)
- hook(batch, logs)
-
- def on_train_batch_begin(self, batch: int, logs: Dict = {}) -> None:
- """Called at the beginning of an epoch."""
- self._call_batch_hook(self.mode_keys.TRAIN, "begin", batch)
-
- def on_train_batch_end(self, batch: int, logs: Dict = {}) -> None:
- """Called at the end of an epoch."""
- self._call_batch_hook(self.mode_keys.TRAIN, "end", batch)
-
- def on_validation_batch_begin(self, batch: int, logs: Dict = {}) -> None:
- """Called at the beginning of an epoch."""
- self._call_batch_hook(self.mode_keys.VALIDATION, "begin", batch)
-
- def on_validation_batch_end(self, batch: int, logs: Dict = {}) -> None:
- """Called at the end of an epoch."""
- self._call_batch_hook(self.mode_keys.VALIDATION, "end", batch)
-
- def __iter__(self) -> iter:
- """Iter function for callback list."""
- return iter(self._callbacks)
-
-
-class Checkpoint(Callback):
- """Saving model parameters at the end of each epoch."""
-
- mode_dict = {
- "min": torch.lt,
- "max": torch.gt,
- }
-
- def __init__(
- self, monitor: str = "accuracy", mode: str = "auto", min_delta: float = 0.0
- ) -> None:
- """Monitors a quantity that will allow us to determine the best model weights.
-
- Args:
- monitor (str): Name of the quantity to monitor. Defaults to "accuracy".
- mode (str): Description of parameter `mode`. Defaults to "auto".
- min_delta (float): Description of parameter `min_delta`. Defaults to 0.0.
-
- """
- super().__init__()
- self.monitor = monitor
- self.mode = mode
- self.min_delta = torch.tensor(min_delta)
-
- if mode not in ["auto", "min", "max"]:
- logger.warning(f"Checkpoint mode {mode} is unkown, fallback to auto mode.")
-
- self.mode = "auto"
-
- if self.mode == "auto":
- if "accuracy" in self.monitor:
- self.mode = "max"
- else:
- self.mode = "min"
- logger.debug(
- f"Checkpoint mode set to {self.mode} for monitoring {self.monitor}."
- )
-
- torch_inf = torch.tensor(np.inf)
- self.min_delta *= 1 if self.monitor_op == torch.gt else -1
- self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf
-
- @property
- def monitor_op(self) -> float:
- """Returns the comparison method."""
- return self.mode_dict[self.mode]
-
- def on_epoch_end(self, epoch: int, logs: Dict) -> None:
- """Saves a checkpoint for the network parameters.
-
- Args:
- epoch (int): The current epoch.
- logs (Dict): The log containing the monitored metrics.
-
- """
- current = self.get_monitor_value(logs)
- if current is None:
- return
- if self.monitor_op(current - self.min_delta, self.best_score):
- self.best_score = current
- is_best = True
- else:
- is_best = False
-
- self.model.save_checkpoint(is_best, epoch, self.monitor)
-
- def get_monitor_value(self, logs: Dict) -> Union[float, None]:
- """Extracts the monitored value."""
- monitor_value = logs.get(self.monitor)
- if monitor_value is None:
- logger.warning(
- f"Checkpoint is conditioned on metric {self.monitor} which is not available. Available"
- + f"metrics are: {','.join(list(logs.keys()))}"
- )
- return None
- return monitor_value
diff --git a/src/training/callbacks/early_stopping.py b/src/training/callbacks/early_stopping.py
deleted file mode 100644
index c9b7907..0000000
--- a/src/training/callbacks/early_stopping.py
+++ /dev/null
@@ -1,107 +0,0 @@
-"""Implements Early stopping for PyTorch model."""
-from typing import Dict, Union
-
-from loguru import logger
-import numpy as np
-import torch
-from training.callbacks import Callback
-
-
-class EarlyStopping(Callback):
- """Stops training when a monitored metric stops improving."""
-
- mode_dict = {
- "min": torch.lt,
- "max": torch.gt,
- }
-
- def __init__(
- self,
- monitor: str = "val_loss",
- min_delta: float = 0.0,
- patience: int = 3,
- mode: str = "auto",
- ) -> None:
- """Initializes the EarlyStopping callback.
-
- Args:
- monitor (str): Description of parameter `monitor`. Defaults to "val_loss".
- min_delta (float): Description of parameter `min_delta`. Defaults to 0.0.
- patience (int): Description of parameter `patience`. Defaults to 3.
- mode (str): Description of parameter `mode`. Defaults to "auto".
-
- """
- super().__init__()
- self.monitor = monitor
- self.patience = patience
- self.min_delta = torch.tensor(min_delta)
- self.mode = mode
- self.wait_count = 0
- self.stopped_epoch = 0
-
- if mode not in ["auto", "min", "max"]:
- logger.warning(
- f"EarlyStopping mode {mode} is unkown, fallback to auto mode."
- )
-
- self.mode = "auto"
-
- if self.mode == "auto":
- if "accuracy" in self.monitor:
- self.mode = "max"
- else:
- self.mode = "min"
- logger.debug(
- f"EarlyStopping mode set to {self.mode} for monitoring {self.monitor}."
- )
-
- self.torch_inf = torch.tensor(np.inf)
- self.min_delta *= 1 if self.monitor_op == torch.gt else -1
- self.best_score = (
- self.torch_inf if self.monitor_op == torch.lt else -self.torch_inf
- )
-
- @property
- def monitor_op(self) -> float:
- """Returns the comparison method."""
- return self.mode_dict[self.mode]
-
- def on_fit_begin(self) -> Union[torch.lt, torch.gt]:
- """Reset the early stopping variables for reuse."""
- self.wait_count = 0
- self.stopped_epoch = 0
- self.best_score = (
- self.torch_inf if self.monitor_op == torch.lt else -self.torch_inf
- )
-
- def on_epoch_end(self, epoch: int, logs: Dict) -> None:
- """Computes the early stop criterion."""
- current = self.get_monitor_value(logs)
- if current is None:
- return
- if self.monitor_op(current - self.min_delta, self.best_score):
- self.best_score = current
- self.wait_count = 0
- else:
- self.wait_count += 1
- if self.wait_count >= self.patience:
- self.stopped_epoch = epoch
- self.model.stop_training = True
-
- def on_fit_end(self) -> None:
- """Logs if early stopping was used."""
- if self.stopped_epoch > 0:
- logger.info(
- f"Stopped training at epoch {self.stopped_epoch + 1} with early stopping."
- )
-
- def get_monitor_value(self, logs: Dict) -> Union[torch.Tensor, None]:
- """Extracts the monitor value."""
- monitor_value = logs.get(self.monitor)
- if monitor_value is None:
- logger.warning(
- f"Early stopping is conditioned on metric {self.monitor} which is not available. Available"
- + f"metrics are: {','.join(list(logs.keys()))}"
- )
- return None
- return torch.tensor(monitor_value)
diff --git a/src/training/callbacks/lr_schedulers.py b/src/training/callbacks/lr_schedulers.py
deleted file mode 100644
index 00c7e9b..0000000
--- a/src/training/callbacks/lr_schedulers.py
+++ /dev/null
@@ -1,97 +0,0 @@
-"""Callbacks for learning rate schedulers."""
-from typing import Callable, Dict, List, Optional, Type
-
-from training.callbacks import Callback
-
-from text_recognizer.models import Model
-
-
-class StepLR(Callback):
- """Callback for StepLR."""
-
- def __init__(self) -> None:
- """Initializes the callback."""
- super().__init__()
- self.lr_scheduler = None
-
- def set_model(self, model: Type[Model]) -> None:
- """Sets the model and lr scheduler."""
- self.model = model
- self.lr_scheduler = self.model.lr_scheduler
-
- def on_epoch_end(self, epoch: int, logs: Dict = {}) -> None:
- """Takes a step at the end of every epoch."""
- self.lr_scheduler.step()
-
-
-class MultiStepLR(Callback):
- """Callback for MultiStepLR."""
-
- def __init__(self) -> None:
- """Initializes the callback."""
- super().__init__()
- self.lr_scheduler = None
-
- def set_model(self, model: Type[Model]) -> None:
- """Sets the model and lr scheduler."""
- self.model = model
- self.lr_scheduler = self.model.lr_scheduler
-
- def on_epoch_end(self, epoch: int, logs: Dict = {}) -> None:
- """Takes a step at the end of every epoch."""
- self.lr_scheduler.step()
-
-
-class ReduceLROnPlateau(Callback):
- """Callback for ReduceLROnPlateau."""
-
- def __init__(self) -> None:
- """Initializes the callback."""
- super().__init__()
- self.lr_scheduler = None
-
- def set_model(self, model: Type[Model]) -> None:
- """Sets the model and lr scheduler."""
- self.model = model
- self.lr_scheduler = self.model.lr_scheduler
-
- def on_epoch_end(self, epoch: int, logs: Dict = {}) -> None:
- """Takes a step at the end of every epoch."""
- val_loss = logs["val_loss"]
- self.lr_scheduler.step(val_loss)
-
-
-class CyclicLR(Callback):
- """Callback for CyclicLR."""
-
- def __init__(self) -> None:
- """Initializes the callback."""
- super().__init__()
- self.lr_scheduler = None
-
- def set_model(self, model: Type[Model]) -> None:
- """Sets the model and lr scheduler."""
- self.model = model
- self.lr_scheduler = self.model.lr_scheduler
-
- def on_train_batch_end(self, batch: int, logs: Dict = {}) -> None:
- """Takes a step at the end of every training batch."""
- self.lr_scheduler.step()
-
-
-class OneCycleLR(Callback):
- """Callback for OneCycleLR."""
-
- def __init__(self) -> None:
- """Initializes the callback."""
- super().__init__()
- self.lr_scheduler = None
-
- def set_model(self, model: Type[Model]) -> None:
- """Sets the model and lr scheduler."""
- self.model = model
- self.lr_scheduler = self.model.lr_scheduler
-
- def on_train_batch_end(self, batch: int, logs: Dict = {}) -> None:
- """Takes a step at the end of every training batch."""
- self.lr_scheduler.step()
diff --git a/src/training/callbacks/wandb_callbacks.py b/src/training/callbacks/wandb_callbacks.py
deleted file mode 100644
index 6ada6df..0000000
--- a/src/training/callbacks/wandb_callbacks.py
+++ /dev/null
@@ -1,93 +0,0 @@
-"""Callbacks using wandb."""
-from typing import Callable, Dict, List, Optional, Type
-
-import numpy as np
-from torchvision.transforms import Compose, ToTensor
-from training.callbacks import Callback
-import wandb
-
-from text_recognizer.datasets import Transpose
-from text_recognizer.models.base import Model
-
-
-class WandbCallback(Callback):
- """A custom W&B metric logger for the trainer."""
-
- def __init__(self, log_batch_frequency: int = None) -> None:
- """Short summary.
-
- Args:
- log_batch_frequency (int): If None, metrics will be logged every epoch.
- If set to an integer, callback will log every metrics every log_batch_frequency.
-
- """
- super().__init__()
- self.log_batch_frequency = log_batch_frequency
-
- def _on_batch_end(self, batch: int, logs: Dict) -> None:
- if self.log_batch_frequency and batch % self.log_batch_frequency == 0:
- wandb.log(logs, commit=True)
-
- def on_train_batch_end(self, batch: int, logs: Dict = {}) -> None:
- """Logs training metrics."""
- if logs is not None:
- self._on_batch_end(batch, logs)
-
- def on_validation_batch_end(self, batch: int, logs: Dict = {}) -> None:
- """Logs validation metrics."""
- if logs is not None:
- self._on_batch_end(batch, logs)
-
- def on_epoch_end(self, epoch: int, logs: Dict) -> None:
- """Logs at epoch end."""
- wandb.log(logs, commit=True)
-
-
-class WandbImageLogger(Callback):
- """Custom W&B callback for image logging."""
-
- def __init__(
- self,
- example_indices: Optional[List] = None,
- num_examples: int = 4,
- transfroms: Optional[Callable] = None,
- ) -> None:
- """Initializes the WandbImageLogger with the model to train.
-
- Args:
- example_indices (Optional[List]): Indices for validation images. Defaults to None.
- num_examples (int): Number of random samples to take if example_indices are not specified. Defaults to 4.
- transfroms (Optional[Callable]): Transforms to use on the validation images, e.g. transpose. Defaults to
- None.
-
- """
-
- super().__init__()
- self.example_indices = example_indices
- self.num_examples = num_examples
- self.transfroms = transfroms
- if self.transfroms is None:
- self.transforms = Compose([Transpose()])
-
- def set_model(self, model: Type[Model]) -> None:
- """Sets the model and extracts validation images from the dataset."""
- self.model = model
- data_loader = self.model.data_loaders["val"]
- if self.example_indices is None:
- self.example_indices = np.random.randint(
- 0, len(data_loader.dataset.data), self.num_examples
- )
- self.val_images = data_loader.dataset.data[self.example_indices]
- self.val_targets = data_loader.dataset.targets[self.example_indices].numpy()
-
- def on_epoch_end(self, epoch: int, logs: Dict) -> None:
- """Get network predictions on validation images."""
- images = []
- for i, image in enumerate(self.val_images):
- image = self.transforms(image)
- pred, conf = self.model.predict_on_image(image)
- ground_truth = self.model.mapper(int(self.val_targets[i]))
- caption = f"Prediction: {pred} Confidence: {conf:.3f} Ground Truth: {ground_truth}"
- images.append(wandb.Image(image, caption=caption))
-
- wandb.log({"examples": images}, commit=False)