Diffstat (limited to 'training/trainer/callbacks/checkpoint.py')
-rw-r--r--	training/trainer/callbacks/checkpoint.py	95
1 file changed, 95 insertions, 0 deletions
diff --git a/training/trainer/callbacks/checkpoint.py b/training/trainer/callbacks/checkpoint.py
new file mode 100644
index 0000000..a54e0a9
--- /dev/null
+++ b/training/trainer/callbacks/checkpoint.py
@@ -0,0 +1,95 @@
+"""Callback checkpoint for training models."""
+from enum import Enum
+from pathlib import Path
+from typing import Callable, Dict, List, Optional, Type, Union
+
+from loguru import logger
+import numpy as np
+import torch
+from training.trainer.callbacks import Callback
+
+from text_recognizer.models import Model
+
+
+class Checkpoint(Callback):
+    """Saves the model parameters at the end of each epoch."""
+
+    mode_dict = {
+        "min": torch.lt,
+        "max": torch.gt,
+    }
+
+    def __init__(
+        self,
+        checkpoint_path: Union[str, Path],
+        monitor: str = "accuracy",
+        mode: str = "auto",
+        min_delta: float = 0.0,
+    ) -> None:
+        """Monitors a quantity that determines the best model weights.
+
+        Args:
+            checkpoint_path (Union[str, Path]): Path to the experiment with the checkpoint.
+            monitor (str): Name of the quantity to monitor. Defaults to "accuracy".
+            mode (str): Comparison mode, one of "min", "max", or "auto". Defaults to "auto".
+            min_delta (float): Minimum change in the monitored quantity that counts as an
+                improvement. Defaults to 0.0.
+
+        """
+        super().__init__()
+        self.checkpoint_path = Path(checkpoint_path)
+        self.monitor = monitor
+        self.mode = mode
+        self.min_delta = torch.tensor(min_delta)
+
+        if mode not in ["auto", "min", "max"]:
+            logger.warning(f"Checkpoint mode {mode} is unknown, falling back to auto mode.")
+            self.mode = "auto"
+
+        if self.mode == "auto":
+            if "accuracy" in self.monitor:
+                self.mode = "max"
+            else:
+                self.mode = "min"
+            logger.debug(
+                f"Checkpoint mode set to {self.mode} for monitoring {self.monitor}."
+            )
+
+        torch_inf = torch.tensor(np.inf)
+        self.min_delta *= 1 if self.monitor_op == torch.gt else -1
+        self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf
+
+    @property
+    def monitor_op(self) -> Callable:
+        """Returns the comparison method."""
+        return self.mode_dict[self.mode]
+
+    def on_epoch_end(self, epoch: int, logs: Dict) -> None:
+        """Saves a checkpoint for the network parameters.
+
+        Args:
+            epoch (int): The current epoch.
+            logs (Dict): The log containing the monitored metrics.
+
+        """
+        current = self.get_monitor_value(logs)
+        if current is None:
+            return
+        if self.monitor_op(current - self.min_delta, self.best_score):
+            self.best_score = current
+            is_best = True
+        else:
+            is_best = False
+
+        self.model.save_checkpoint(self.checkpoint_path, is_best, epoch, self.monitor)
+
+    def get_monitor_value(self, logs: Dict) -> Union[float, None]:
+        """Extracts the monitored value."""
+        monitor_value = logs.get(self.monitor)
+        if monitor_value is None:
+            logger.warning(
+                f"Checkpoint is conditioned on metric {self.monitor} which is not available."
+                f" Available metrics are: {','.join(list(logs.keys()))}"
+            )
+            return None
+        return monitor_value
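The subtle part of this commit is the signed min_delta bookkeeping: in max mode the improvement test is effectively current - min_delta > best_score, while in min mode the sign flip on min_delta turns the same expression into current + min_delta < best_score. Below is a minimal, self-contained sketch of that logic for a "max" metric; the values are illustrative, and only the torch.gt/torch.lt comparison and the best-score bookkeeping mirror the code above.

import numpy as np
import torch

# "max" mode, as auto-resolved for a monitor name containing "accuracy".
monitor_op = torch.gt
min_delta = torch.tensor(0.0)        # multiplied by +1 in max mode
best_score = -torch.tensor(np.inf)   # so the first epoch always improves

for epoch, current in enumerate([0.80, 0.85, 0.84]):
    current = torch.tensor(current)
    if monitor_op(current - min_delta, best_score):
        best_score = current         # epochs 0 and 1 set a new best
        is_best = True
    else:
        is_best = False              # epoch 2: 0.84 is not an improvement

In the callback itself this decision feeds self.model.save_checkpoint(...); the self.model attribute and the save_checkpoint signature come from the Callback and Model machinery outside this diff.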