1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
|
"""Implements Early stopping for PyTorch model."""
from typing import Callable, Dict, Union

import numpy as np
import torch
from loguru import logger
from torch import Tensor

from training.trainer.callbacks import Callback
class EarlyStopping(Callback):
    """Stops training when a monitored metric stops improving.

    After ``patience`` consecutive epochs without an improvement of at least
    ``min_delta`` in the monitored metric, sets ``self.model.stop_training``
    to True so the training loop can terminate early.
    """

    # Maps mode name to the comparison deciding whether the current metric
    # value is an improvement over the best value seen so far.
    mode_dict = {
        "min": torch.lt,
        "max": torch.gt,
    }

    def __init__(
        self,
        monitor: str = "val_loss",
        min_delta: float = 0.0,
        patience: int = 3,
        mode: str = "auto",
    ) -> None:
        """Initializes the EarlyStopping callback.

        Args:
            monitor (str): Name of the metric (key in the epoch logs dict) to
                monitor. Defaults to "val_loss".
            min_delta (float): Minimum change in the monitored metric that
                counts as an improvement. Defaults to 0.0.
            patience (int): Number of epochs with no improvement after which
                training is stopped. Defaults to 3.
            mode (str): One of "min", "max", or "auto". In "auto" the
                direction is inferred from the monitor name. Defaults to
                "auto".
        """
        super().__init__()
        self.monitor = monitor
        self.patience = patience
        self.min_delta = torch.tensor(min_delta)
        self.mode = mode
        self.wait_count = 0
        self.stopped_epoch = 0

        if mode not in ["auto", "min", "max"]:
            logger.warning(
                f"EarlyStopping mode {mode} is unknown, fallback to auto mode."
            )
            self.mode = "auto"

        if self.mode == "auto":
            # Accuracy-like metrics improve upwards; everything else
            # (losses) is assumed to improve downwards.
            self.mode = "max" if "accuracy" in self.monitor else "min"
            logger.debug(
                f"EarlyStopping mode set to {self.mode} for monitoring {self.monitor}."
            )

        self.torch_inf = torch.tensor(np.inf)
        # Fold the comparison direction into min_delta so on_epoch_end can
        # always compute `current - min_delta` regardless of mode.
        self.min_delta *= 1 if self.monitor_op == torch.gt else -1
        self.best_score = (
            self.torch_inf if self.monitor_op == torch.lt else -self.torch_inf
        )

    @property
    def monitor_op(self) -> Callable[[Tensor, Tensor], Tensor]:
        """Returns the comparison function (torch.lt or torch.gt) for the mode."""
        return self.mode_dict[self.mode]

    def on_fit_begin(self) -> None:
        """Resets the early stopping state so the callback can be reused."""
        self.wait_count = 0
        self.stopped_epoch = 0
        self.best_score = (
            self.torch_inf if self.monitor_op == torch.lt else -self.torch_inf
        )

    def on_epoch_end(self, epoch: int, logs: Dict) -> None:
        """Computes the early stop criterion for the epoch that just ended.

        Args:
            epoch (int): Index of the epoch that just finished.
            logs (Dict): Mapping of metric names to values for this epoch.
        """
        current = self.get_monitor_value(logs)
        if current is None:
            # Monitored metric missing; a warning was already emitted.
            return
        if self.monitor_op(current - self.min_delta, self.best_score):
            # Improvement of at least min_delta: reset the patience counter.
            self.best_score = current
            self.wait_count = 0
        else:
            self.wait_count += 1
            if self.wait_count >= self.patience:
                self.stopped_epoch = epoch
                # The training loop is expected to poll this flag.
                self.model.stop_training = True

    def on_fit_end(self) -> None:
        """Logs if early stopping was used."""
        if self.stopped_epoch > 0:
            logger.info(
                f"Stopped training at epoch {self.stopped_epoch + 1} with early stopping."
            )

    def get_monitor_value(self, logs: Dict) -> Union[Tensor, None]:
        """Extracts the monitored metric from the logs as a tensor.

        Args:
            logs (Dict): Mapping of metric names to values.

        Returns:
            Union[Tensor, None]: The monitored value as a tensor, or None if
                the metric is not present in the logs.
        """
        monitor_value = logs.get(self.monitor)
        if monitor_value is None:
            logger.warning(
                f"Early stopping is conditioned on metric {self.monitor} which is not available. Available "
                + f"metrics are: {','.join(list(logs.keys()))}"
            )
            return None
        return torch.tensor(monitor_value)
|