summaryrefslogtreecommitdiff
path: root/training/trainer/callbacks/early_stopping.py
blob: 02b431fdd45e0e1ac4d5954e156826bb8d8b12a8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
"""Implements Early stopping for PyTorch model."""
from typing import Callable, Dict, Union

from loguru import logger
import numpy as np
import torch
from torch import Tensor

from training.trainer.callbacks import Callback


class EarlyStopping(Callback):
    """Stops training when a monitored metric stops improving.

    At the end of every epoch the monitored value is compared with the best
    value seen so far. If it fails to improve by more than ``min_delta`` for
    ``patience`` consecutive epochs, ``self.model.stop_training`` is set to
    ``True``.
    """

    # Comparison that returns True when the first argument is an improvement
    # over the second.
    mode_dict = {
        "min": torch.lt,
        "max": torch.gt,
    }

    def __init__(
        self,
        monitor: str = "val_loss",
        min_delta: float = 0.0,
        patience: int = 3,
        mode: str = "auto",
    ) -> None:
        """Initializes the EarlyStopping callback.

        Args:
            monitor (str): Key of the metric in the epoch ``logs`` to monitor.
                Defaults to "val_loss".
            min_delta (float): Amount by which the monitored value must beat
                the best value so far to count as an improvement.
                Defaults to 0.0.
            patience (int): Number of consecutive epochs without improvement
                after which training is stopped. Defaults to 3.
            mode (str): One of "min", "max" or "auto". In "min" mode a lower
                value is an improvement, in "max" mode a higher one. "auto"
                selects "max" when the monitor name contains "accuracy" and
                "min" otherwise. Unknown values log a warning and fall back
                to "auto". Defaults to "auto".

        """
        super().__init__()
        self.monitor = monitor
        self.patience = patience
        self.min_delta = torch.tensor(min_delta)
        self.mode = mode
        self.wait_count = 0
        self.stopped_epoch = 0

        if mode not in ["auto", "min", "max"]:
            logger.warning(
                f"EarlyStopping mode {mode} is unknown, fallback to auto mode."
            )
            self.mode = "auto"

        if self.mode == "auto":
            self.mode = "max" if "accuracy" in self.monitor else "min"
            logger.debug(
                f"EarlyStopping mode set to {self.mode} for monitoring {self.monitor}."
            )

        self.torch_inf = torch.tensor(np.inf)
        # Flip the sign in "min" mode so that the single test
        # `monitor_op(current - min_delta, best)` requires beating the best
        # score by more than min_delta in the improving direction.
        self.min_delta *= 1 if self.monitor_op == torch.gt else -1
        self.best_score = self._initial_best_score()

    @property
    def monitor_op(self) -> Callable[[Tensor, Tensor], Tensor]:
        """Returns the comparison method.

        ``torch.lt`` in "min" mode and ``torch.gt`` in "max" mode.
        """
        return self.mode_dict[self.mode]

    def _initial_best_score(self) -> Tensor:
        """Returns the worst possible score for the current mode (+inf or -inf)."""
        return self.torch_inf if self.monitor_op == torch.lt else -self.torch_inf

    def on_fit_begin(self) -> None:
        """Reset the early stopping variables for reuse."""
        self.wait_count = 0
        self.stopped_epoch = 0
        self.best_score = self._initial_best_score()

    def on_epoch_end(self, epoch: int, logs: Dict) -> None:
        """Computes the early stop criterion.

        Args:
            epoch (int): The epoch that just finished (presumably zero-based,
                since ``on_fit_end`` reports ``stopped_epoch + 1``).
            logs (Dict): Epoch metrics; must contain ``self.monitor`` for the
                criterion to be evaluated.

        """
        current = self.get_monitor_value(logs)
        if current is None:
            # Metric unavailable; get_monitor_value already logged a warning.
            return
        if self.monitor_op(current - self.min_delta, self.best_score):
            self.best_score = current
            self.wait_count = 0
        else:
            self.wait_count += 1
            if self.wait_count >= self.patience:
                self.stopped_epoch = epoch
                self.model.stop_training = True

    def on_fit_end(self) -> None:
        """Logs if early stopping was used."""
        if self.stopped_epoch > 0:
            logger.info(
                f"Stopped training at epoch {self.stopped_epoch + 1} with early stopping."
            )

    def get_monitor_value(self, logs: Dict) -> Union[Tensor, None]:
        """Extracts the monitor value.

        Args:
            logs (Dict): Epoch metrics.

        Returns:
            Union[Tensor, None]: A detached tensor copy of
            ``logs[self.monitor]``, or None (after logging a warning) when the
            metric is missing.

        """
        monitor_value = logs.get(self.monitor)
        if monitor_value is None:
            logger.warning(
                f"Early stopping is conditioned on metric {self.monitor} which is not available. Available "
                + f"metrics are: {','.join(map(str, logs))}"
            )
            return None
        # torch.as_tensor accepts Python/numpy scalars and tensors alike. Unlike
        # torch.tensor, it does not warn when handed a tensor. detach().clone()
        # keeps best_score from aliasing memory owned by `logs`.
        return torch.as_tensor(monitor_value).detach().clone()