"""Callback checkpoint for training models."""
from enum import Enum
from pathlib import Path
from typing import Callable, Dict, Union

from loguru import logger
import numpy as np
import torch
from training.trainer.callbacks import Callback


class Checkpoint(Callback):
    """Saving model parameters at the end of each epoch."""

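    # Maps each explicit mode to the torch comparison used to decide whether
    # a new value of the monitored quantity is better than the current best.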
    mode_dict = {
        "min": torch.lt,
        "max": torch.gt,
    }

    def __init__(
        self,
        checkpoint_path: Path,
        monitor: str = "accuracy",
        mode: str = "auto",
        min_delta: float = 0.0,
    ) -> None:
        """Monitors a quantity that will allow us to determine the best model weights.

        Args:
            checkpoint_path (Path): Path where the checkpoint is saved.
            monitor (str): Name of the quantity to monitor. Defaults to "accuracy".
            mode (str): Comparison mode, one of "min", "max", or "auto". With
                "auto", the mode is inferred from the name of the monitored
                quantity. Defaults to "auto".
            min_delta (float): Minimum change in the monitored quantity that
                counts as an improvement. Defaults to 0.0.

        """
        super().__init__()
        self.checkpoint_path = checkpoint_path
        self.monitor = monitor
        self.mode = mode
        self.min_delta = torch.tensor(min_delta)

        if mode not in ["auto", "min", "max"]:
            logger.warning(f"Checkpoint mode {mode} is unkown, fallback to auto mode.")

            self.mode = "auto"

        if self.mode == "auto":
            if "accuracy" in self.monitor:
                self.mode = "max"
            else:
                self.mode = "min"
            logger.debug(
                f"Checkpoint mode set to {self.mode} for monitoring {self.monitor}."
            )

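        # Fold the direction into min_delta so a single comparison handles both
        # modes, and seed best_score so that the first monitored value always
        # counts as an improvement.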
        torch_inf = torch.tensor(np.inf)
        self.min_delta *= 1 if self.monitor_op == torch.gt else -1
        self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf

    @property
    def monitor_op(self) -> Callable:
        """Returns the comparison method."""
        return self.mode_dict[self.mode]

    def on_epoch_end(self, epoch: int, logs: Dict) -> None:
        """Saves a checkpoint for the network parameters.

        Args:
            epoch (int): The current epoch.
            logs (Dict): The log containing the monitored metrics.

        """
        current = self.get_monitor_value(logs)
        if current is None:
            return
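        # The epoch is an improvement when the monitored value beats the best
        # score by at least min_delta, in the direction given by the mode.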
        if self.monitor_op(current - self.min_delta, self.best_score):
            self.best_score = current
            is_best = True
        else:
            is_best = False

        self.model.save_checkpoint(self.checkpoint_path, is_best, epoch, self.monitor)

    def get_monitor_value(self, logs: Dict) -> Union[float, None]:
        """Extracts the monitored value."""
        monitor_value = logs.get(self.monitor)
        if monitor_value is None:
            logger.warning(
                f"Checkpoint is conditioned on metric {self.monitor} which is not available."
                f" Available metrics are: {', '.join(logs.keys())}"
            )
            return None
        return monitor_value
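

# A minimal usage sketch (hypothetical wiring: the `model` object and the manual
# call below stand in for the trainer, which is assumed to attach the model to
# the callback and to invoke `on_epoch_end` with a dict of metric values):
#
#     callback = Checkpoint(
#         checkpoint_path=Path("experiments/checkpoints"),
#         monitor="val_accuracy",
#         mode="max",
#         min_delta=1e-4,
#     )
#     callback.model = model  # normally done by the trainer / callback runner
#     callback.on_epoch_end(epoch=1, logs={"val_accuracy": 0.91})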