diff options
Diffstat (limited to 'src/training/trainer/callbacks/checkpoint.py')
-rw-r--r-- | src/training/trainer/callbacks/checkpoint.py | 95 |
1 files changed, 0 insertions, 95 deletions
diff --git a/src/training/trainer/callbacks/checkpoint.py b/src/training/trainer/callbacks/checkpoint.py deleted file mode 100644 index a54e0a9..0000000 --- a/src/training/trainer/callbacks/checkpoint.py +++ /dev/null @@ -1,95 +0,0 @@ -"""Callback checkpoint for training models.""" -from enum import Enum -from pathlib import Path -from typing import Callable, Dict, List, Optional, Type, Union - -from loguru import logger -import numpy as np -import torch -from training.trainer.callbacks import Callback - -from text_recognizer.models import Model - - -class Checkpoint(Callback): - """Saving model parameters at the end of each epoch.""" - - mode_dict = { - "min": torch.lt, - "max": torch.gt, - } - - def __init__( - self, - checkpoint_path: Union[str, Path], - monitor: str = "accuracy", - mode: str = "auto", - min_delta: float = 0.0, - ) -> None: - """Monitors a quantity that will allow us to determine the best model weights. - - Args: - checkpoint_path (Union[str, Path]): Path to the experiment with the checkpoint. - monitor (str): Name of the quantity to monitor. Defaults to "accuracy". - mode (str): Description of parameter `mode`. Defaults to "auto". - min_delta (float): Description of parameter `min_delta`. Defaults to 0.0. - - """ - super().__init__() - self.checkpoint_path = Path(checkpoint_path) - self.monitor = monitor - self.mode = mode - self.min_delta = torch.tensor(min_delta) - - if mode not in ["auto", "min", "max"]: - logger.warning(f"Checkpoint mode {mode} is unkown, fallback to auto mode.") - - self.mode = "auto" - - if self.mode == "auto": - if "accuracy" in self.monitor: - self.mode = "max" - else: - self.mode = "min" - logger.debug( - f"Checkpoint mode set to {self.mode} for monitoring {self.monitor}." - ) - - torch_inf = torch.tensor(np.inf) - self.min_delta *= 1 if self.monitor_op == torch.gt else -1 - self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf - - @property - def monitor_op(self) -> float: - """Returns the comparison method.""" - return self.mode_dict[self.mode] - - def on_epoch_end(self, epoch: int, logs: Dict) -> None: - """Saves a checkpoint for the network parameters. - - Args: - epoch (int): The current epoch. - logs (Dict): The log containing the monitored metrics. - - """ - current = self.get_monitor_value(logs) - if current is None: - return - if self.monitor_op(current - self.min_delta, self.best_score): - self.best_score = current - is_best = True - else: - is_best = False - - self.model.save_checkpoint(self.checkpoint_path, is_best, epoch, self.monitor) - - def get_monitor_value(self, logs: Dict) -> Union[float, None]: - """Extracts the monitored value.""" - monitor_value = logs.get(self.monitor) - if monitor_value is None: - logger.warning( - f"Checkpoint is conditioned on metric {self.monitor} which is not available. Available" - + f" metrics are: {','.join(list(logs.keys()))}" - ) - return None - return monitor_value |