"""Implements Early stopping for PyTorch model."""
from typing import Callable, Dict, Union

import numpy as np
import torch
from loguru import logger

from training.callbacks import Callback
class EarlyStopping(Callback):
    """Stops training when a monitored metric stops improving.

    Tracks the best value of ``monitor`` seen so far and, once it has failed
    to improve by more than ``min_delta`` for ``patience`` consecutive epochs,
    sets ``self.model.stop_training = True`` to halt the fit loop.
    """

    # Comparison op per mode: "min" -> lower is better, "max" -> higher is better.
    mode_dict = {
        "min": torch.lt,
        "max": torch.gt,
    }

    def __init__(
        self,
        monitor: str = "val_loss",
        min_delta: float = 0.0,
        patience: int = 3,
        mode: str = "auto",
    ) -> None:
        """Initializes the EarlyStopping callback.

        Args:
            monitor (str): Name of the metric to watch in the epoch logs.
                Defaults to "val_loss".
            min_delta (float): Minimum change in the monitored metric to
                qualify as an improvement. Defaults to 0.0.
            patience (int): Number of epochs with no improvement after which
                training is stopped. Defaults to 3.
            mode (str): One of "min", "max" or "auto". In "auto" mode the
                direction is inferred from the metric name. Defaults to "auto".
        """
        super().__init__()
        self.monitor = monitor
        self.patience = patience
        # abs() guards against a caller passing a negative threshold; the
        # correct sign is reapplied below once the mode is resolved.
        self.min_delta = torch.tensor(abs(min_delta))
        self.mode = mode
        self.wait_count = 0
        self.stopped_epoch = 0
        if mode not in ["auto", "min", "max"]:
            logger.warning(
                f"EarlyStopping mode {mode} is unknown, fallback to auto mode."
            )
            self.mode = "auto"
        if self.mode == "auto":
            # Heuristic: accuracy-like metrics should increase, losses decrease.
            if "accuracy" in self.monitor:
                self.mode = "max"
            else:
                self.mode = "min"
            logger.debug(
                f"EarlyStopping mode set to {self.mode} for monitoring {self.monitor}."
            )
        self.torch_inf = torch.tensor(np.inf)
        # Sign-adjust min_delta so `current - min_delta` always shifts the
        # candidate in the harder-to-improve direction for the chosen op.
        self.min_delta *= 1 if self.monitor_op == torch.gt else -1
        self.best_score = (
            self.torch_inf if self.monitor_op == torch.lt else -self.torch_inf
        )

    @property
    def monitor_op(self) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]:
        """Returns the comparison method for the resolved mode."""
        return self.mode_dict[self.mode]

    def on_fit_begin(self) -> None:
        """Reset the early stopping variables for reuse."""
        self.wait_count = 0
        self.stopped_epoch = 0
        self.best_score = (
            self.torch_inf if self.monitor_op == torch.lt else -self.torch_inf
        )

    def on_epoch_end(self, epoch: int, logs: Dict) -> None:
        """Computes the early stop criterion at the end of an epoch.

        Args:
            epoch (int): Index of the epoch that just finished.
            logs (Dict): Mapping of metric names to their latest values.
        """
        current = self.get_monitor_value(logs)
        if current is None:
            return
        if self.monitor_op(current - self.min_delta, self.best_score):
            # Improved by more than min_delta: record and reset the counter.
            self.best_score = current
            self.wait_count = 0
        else:
            self.wait_count += 1
            if self.wait_count >= self.patience:
                self.stopped_epoch = epoch
                # The training loop is expected to check this flag each epoch.
                self.model.stop_training = True

    def on_fit_end(self) -> None:
        """Logs if early stopping was used."""
        if self.stopped_epoch > 0:
            logger.info(
                f"Stopped training at epoch {self.stopped_epoch + 1} with early stopping."
            )

    def get_monitor_value(self, logs: Dict) -> Union[torch.Tensor, None]:
        """Extracts the monitored value from the logs as a tensor.

        Args:
            logs (Dict): Mapping of metric names to their latest values.

        Returns:
            The monitored value as a tensor, or None (with a warning) when the
            metric is missing from the logs.
        """
        monitor_value = logs.get(self.monitor)
        if monitor_value is None:
            logger.warning(
                f"Early stopping is conditioned on metric {self.monitor} which is not available. Available "
                + f"metrics are: {','.join(list(logs.keys()))}"
            )
            return None
        return torch.tensor(monitor_value)