summaryrefslogtreecommitdiff
path: root/src/training/trainer/callbacks/base.py
blob: 8df94f34d880addb13cd5669a74103c422f669ab (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
"""Callback base class, callback container, and checkpoint callback for training."""

from enum import Enum
from typing import Callable, Dict, Iterator, List, Optional, Type, Union

from loguru import logger
import numpy as np
import torch

from text_recognizer.models import Model


class ModeKeys:
    """String constants naming the modes that `CallbackList` dispatches on."""

    # These values are interpolated into batch hook names of the form
    # `on_{mode}_batch_begin` / `on_{mode}_batch_end`.
    TRAIN, VALIDATION = "train", "validation"


class Callback:
    """Base class for callbacks used in training.

    Every hook defined here is a no-op that returns None; subclasses override
    only the hooks they are interested in.
    """

    def __init__(self) -> None:
        """Creates a callback that is not yet attached to a model."""
        self.model = None

    def set_model(self, model: Type[Model]) -> None:
        """Stores `model` on the callback as `self.model`."""
        self.model = model

    def on_fit_begin(self) -> None:
        """Hook for the start of fit. Does nothing by default."""

    def on_fit_end(self) -> None:
        """Hook for the end of fit. Does nothing by default."""

    def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None) -> None:
        """Hook for the beginning of an epoch. Only used in training mode.

        Args:
            epoch (int): Index of the epoch.
            logs (Optional[Dict]): Optional log dictionary. Defaults to None.

        """

    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
        """Hook for the end of an epoch. Only used in training mode.

        Args:
            epoch (int): Index of the epoch.
            logs (Optional[Dict]): Optional log dictionary. Defaults to None.

        """

    def on_train_batch_begin(self, batch: int, logs: Optional[Dict] = None) -> None:
        """Hook for the beginning of a training batch.

        Args:
            batch (int): Index of the batch.
            logs (Optional[Dict]): Optional log dictionary. Defaults to None.

        """

    def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
        """Hook for the end of a training batch.

        Args:
            batch (int): Index of the batch.
            logs (Optional[Dict]): Optional log dictionary. Defaults to None.

        """

    def on_validation_batch_begin(
        self, batch: int, logs: Optional[Dict] = None
    ) -> None:
        """Hook for the beginning of a validation batch.

        Args:
            batch (int): Index of the batch.
            logs (Optional[Dict]): Optional log dictionary. Defaults to None.

        """

    def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
        """Hook for the end of a validation batch.

        Args:
            batch (int): Index of the batch.
            logs (Optional[Dict]): Optional log dictionary. Defaults to None.

        """


class CallbackList:
    """Container that forwards every training event to a list of callbacks."""

    # Mode strings used to build the names of the per-batch hooks.
    mode_keys = ModeKeys()

    def __init__(
        self, model: Model, callbacks: Optional[List[Callback]] = None
    ) -> None:
        """Container for `Callback` instances.

        This object wraps a list of `Callback` instances and allows them all to be
        called via a single end point.

        Args:
            model (Model): A `Model` instance, or None to attach one later with
                `set_model`.
            callbacks (Optional[List[Callback]]): List of `Callback` instances.
                A non-empty list is stored as-is (not copied). Defaults to None.

        """
        # Always define the attribute so `append` can inspect it even when no
        # model was supplied.
        self.model = None
        self._callbacks = callbacks or []
        if model is not None:
            self.set_model(model)

    def set_model(self, model: Model) -> None:
        """Set the model on the container and on every registered callback."""
        self.model = model
        for callback in self._callbacks:
            callback.set_model(model=self.model)

    def append(self, callback: Callback) -> None:
        """Append a new callback to the callback list.

        If a model has already been set, it is handed to the new callback too, so
        that callbacks appended late behave like those passed to the constructor.

        Args:
            callback (Callback): The callback to register.

        """
        if self.model is not None:
            callback.set_model(model=self.model)
        self._callbacks.append(callback)

    def on_fit_begin(self) -> None:
        """Calls `on_fit_begin` on every callback."""
        for callback in self._callbacks:
            callback.on_fit_begin()

    def on_fit_end(self) -> None:
        """Calls `on_fit_end` on every callback."""
        for callback in self._callbacks:
            callback.on_fit_end()

    def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None) -> None:
        """Calls `on_epoch_begin(epoch, logs)` on every callback."""
        for callback in self._callbacks:
            callback.on_epoch_begin(epoch, logs)

    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
        """Calls `on_epoch_end(epoch, logs)` on every callback."""
        for callback in self._callbacks:
            callback.on_epoch_end(epoch, logs)

    def _call_batch_hook(
        self, mode: str, hook: str, batch: int, logs: Optional[Dict] = None
    ) -> None:
        """Helper function for all batch_{begin | end} methods.

        Args:
            mode (str): Mode string used in the hook name, e.g. "train".
            hook (str): Either "begin" or "end".
            batch (int): Index of the batch.
            logs (Optional[Dict]): Optional log dictionary. Defaults to None.

        Raises:
            ValueError: If `hook` is neither "begin" nor "end".

        """
        if hook == "begin":
            self._call_batch_begin_hook(mode, batch, logs)
        elif hook == "end":
            self._call_batch_end_hook(mode, batch, logs)
        else:
            raise ValueError(f"Unrecognized hook {hook}.")

    def _call_batch_begin_hook(
        self, mode: str, batch: int, logs: Optional[Dict] = None
    ) -> None:
        """Helper function for all `on_*_batch_begin` methods."""
        hook_name = f"on_{mode}_batch_begin"
        self._call_batch_hook_helper(hook_name, batch, logs)

    def _call_batch_end_hook(
        self, mode: str, batch: int, logs: Optional[Dict] = None
    ) -> None:
        """Helper function for all `on_*_batch_end` methods."""
        hook_name = f"on_{mode}_batch_end"
        self._call_batch_hook_helper(hook_name, batch, logs)

    def _call_batch_hook_helper(
        self, hook_name: str, batch: int, logs: Optional[Dict] = None
    ) -> None:
        """Looks up `hook_name` on every callback and calls it with (batch, logs)."""
        for callback in self._callbacks:
            hook = getattr(callback, hook_name)
            hook(batch, logs)

    def on_train_batch_begin(self, batch: int, logs: Optional[Dict] = None) -> None:
        """Calls `on_train_batch_begin(batch, logs)` on every callback."""
        self._call_batch_hook(self.mode_keys.TRAIN, "begin", batch, logs)

    def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
        """Calls `on_train_batch_end(batch, logs)` on every callback."""
        self._call_batch_hook(self.mode_keys.TRAIN, "end", batch, logs)

    def on_validation_batch_begin(
        self, batch: int, logs: Optional[Dict] = None
    ) -> None:
        """Calls `on_validation_batch_begin(batch, logs)` on every callback."""
        self._call_batch_hook(self.mode_keys.VALIDATION, "begin", batch, logs)

    def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
        """Calls `on_validation_batch_end(batch, logs)` on every callback."""
        self._call_batch_hook(self.mode_keys.VALIDATION, "end", batch, logs)

    def __iter__(self) -> Iterator[Callback]:
        """Iterates over the registered callbacks in insertion order."""
        return iter(self._callbacks)


class Checkpoint(Callback):
    """Saves a model checkpoint at the end of each epoch and tracks the best score."""

    # Maps a mode name to the comparison that decides whether a new score beats
    # the best score seen so far.
    mode_dict = {
        "min": torch.lt,
        "max": torch.gt,
    }

    def __init__(
        self, monitor: str = "accuracy", mode: str = "auto", min_delta: float = 0.0
    ) -> None:
        """Monitors a quantity that will allow us to determine the best model weights.

        Args:
            monitor (str): Name of the quantity to monitor. Defaults to "accuracy".
            mode (str): One of "auto", "min" or "max". In "min" mode a lower value
                of the monitored quantity counts as better, in "max" mode a higher
                one. "auto" resolves to "max" when the monitor name contains
                "accuracy" and to "min" otherwise. Unknown values fall back to
                "auto" with a warning. Defaults to "auto".
            min_delta (float): Margin by which the monitored quantity has to beat
                the best score to count as an improvement. Defaults to 0.0.

        """
        super().__init__()
        self.monitor = monitor
        self.mode = mode
        self.min_delta = torch.tensor(min_delta)

        if mode not in ["auto", "min", "max"]:
            logger.warning(f"Checkpoint mode {mode} is unkown, fallback to auto mode.")

            self.mode = "auto"

        if self.mode == "auto":
            if "accuracy" in self.monitor:
                self.mode = "max"
            else:
                self.mode = "min"
            logger.debug(
                f"Checkpoint mode set to {self.mode} for monitoring {self.monitor}."
            )

        torch_inf = torch.tensor(np.inf)
        # Flip the sign of the margin in "min" mode so that `current - min_delta`
        # in `on_epoch_end` always makes the comparison stricter.
        self.min_delta *= 1 if self.monitor_op == torch.gt else -1
        # Start from the worst possible score so the first observed value wins.
        self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf

    @property
    def monitor_op(self) -> Callable:
        """Returns the comparison function (`torch.lt` or `torch.gt`) for the mode."""
        return self.mode_dict[self.mode]

    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
        """Saves a checkpoint for the network parameters.

        Nothing is saved when the monitored metric is missing from `logs`.

        Args:
            epoch (int): The current epoch.
            logs (Optional[Dict]): The log containing the monitored metrics.
                Defaults to None, which is treated like an empty log.

        """
        current = self.get_monitor_value(logs)
        if current is None:
            return

        is_best = bool(self.monitor_op(current - self.min_delta, self.best_score))
        if is_best:
            self.best_score = current

        self.model.save_checkpoint(is_best, epoch, self.monitor)

    def get_monitor_value(self, logs: Optional[Dict] = None) -> Union[float, None]:
        """Extracts the monitored value.

        Args:
            logs (Optional[Dict]): The log containing the monitored metrics. None
                is treated like an empty log.

        Returns:
            Union[float, None]: The value stored under `self.monitor`, or None (after
                logging a warning) when it is not available.

        """
        # The callback container may pass `logs=None`; avoid calling `.get` on it.
        logs = logs or {}
        monitor_value = logs.get(self.monitor)
        if monitor_value is None:
            logger.warning(
                f"Checkpoint is conditioned on metric {self.monitor} which is not "
                f"available. Available metrics are: {','.join(map(str, logs))}"
            )
            return None
        return monitor_value