1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
|
"""Callback checkpoint for training models."""
from enum import Enum
from pathlib import Path
from typing import Callable, Dict, List, Optional, Type, Union
from loguru import logger
import numpy as np
import torch
from training.trainer.callbacks import Callback
from text_recognizer.models import Model
class Checkpoint(Callback):
    """Saves model parameters at the end of each epoch, tracking the best score."""

    # Maps a resolved mode to the comparison that decides whether a new value
    # is an improvement over the best score seen so far.
    mode_dict: Dict[str, Callable[..., torch.Tensor]] = {
        "min": torch.lt,
        "max": torch.gt,
    }

    def __init__(
        self,
        checkpoint_path: Path,
        monitor: str = "accuracy",
        mode: str = "auto",
        min_delta: float = 0.0,
    ) -> None:
        """Monitors a quantity that will allow us to determine the best model weights.

        Args:
            checkpoint_path (Path): Path to the experiment with the checkpoint. It is
                forwarded unchanged to the model's ``save_checkpoint``.
            monitor (str): Key of the quantity to monitor in the ``logs`` dict passed
                to ``on_epoch_end``. Defaults to "accuracy".
            mode (str): One of "min", "max" or "auto". In "min" mode a lower value is
                better; in "max" mode a higher value is better. "auto" resolves to
                "max" when ``monitor`` contains "accuracy" and to "min" otherwise.
                Unknown values log a warning and fall back to "auto". Defaults to
                "auto".
            min_delta (float): Minimum change in the monitored quantity, relative to
                the best score so far, that counts as an improvement. Defaults to 0.0.

        """
        super().__init__()
        self.checkpoint_path = checkpoint_path
        self.monitor = monitor
        self.mode = mode
        self.min_delta = torch.tensor(min_delta)

        if mode not in ["auto", "min", "max"]:
            logger.warning(f"Checkpoint mode {mode} is unknown, fallback to auto mode.")
            self.mode = "auto"

        if self.mode == "auto":
            self.mode = "max" if "accuracy" in self.monitor else "min"
            logger.debug(
                f"Checkpoint mode set to {self.mode} for monitoring {self.monitor}."
            )

        # Sign the delta so the single comparison in ``on_epoch_end`` requires
        # ``current > best + delta`` in "max" mode and ``current < best - delta``
        # in "min" mode.
        self.min_delta *= 1 if self.monitor_op == torch.gt else -1

        # Start from the worst possible score so the first observed value wins.
        torch_inf = torch.tensor(np.inf)
        self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf

    @property
    def monitor_op(self) -> Callable[..., torch.Tensor]:
        """Returns the comparison method.

        ``torch.lt`` in "min" mode and ``torch.gt`` in "max" mode.
        """
        return self.mode_dict[self.mode]

    def on_epoch_end(self, epoch: int, logs: Dict) -> None:
        """Saves a checkpoint for the network parameters.

        The checkpoint is flagged as the best one (and the best score updated) when
        the monitored value beats the best score by more than ``min_delta``. If the
        monitored value is missing from ``logs``, no checkpoint is saved.

        Args:
            epoch (int): The current epoch.
            logs (Dict): The log containing the monitored metrics.

        """
        current = self.get_monitor_value(logs)
        if current is None:
            return

        is_best = bool(self.monitor_op(current - self.min_delta, self.best_score))
        if is_best:
            self.best_score = current

        # NOTE(review): ``self.model`` is presumably attached by the ``Callback``
        # base class before training starts - confirm.
        self.model.save_checkpoint(self.checkpoint_path, is_best, epoch, self.monitor)

    def get_monitor_value(self, logs: Dict) -> Union[float, None]:
        """Extracts the monitored value.

        Args:
            logs (Dict): The log containing the monitored metrics.

        Returns:
            Union[float, None]: The value stored under ``self.monitor``, or None
            (after logging a warning that lists the available keys) if absent.

        """
        monitor_value = logs.get(self.monitor)
        if monitor_value is None:
            logger.warning(
                f"Checkpoint is conditioned on metric {self.monitor} which is not available. Available"
                + f" metrics are: {','.join(list(logs.keys()))}"
            )
            return None
        return monitor_value
|