1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
|
"""Metaclass for callback functions."""
from enum import Enum
from typing import Callable, Dict, Iterator, List, Optional, Type, Union

from loguru import logger
import numpy as np
import torch

from text_recognizer.models import Model
class ModeKeys:
    """String constants naming the phases that `CallbackList` dispatches.

    Each value is spliced into a batch-hook name, e.g. ``"train"`` yields
    ``on_train_batch_begin`` / ``on_train_batch_end``.
    """

    VALIDATION = "validation"
    TRAIN = "train"
class Callback:
    """Base class for callbacks used during training.

    Every hook defined here does nothing; subclasses override only the hooks
    they care about. The model is attached through `set_model` and is then
    available to hooks as `self.model`.
    """

    def __init__(self) -> None:
        """Create the callback with no model attached yet."""
        self.model = None

    def set_model(self, model: Type[Model]) -> None:
        """Attach `model` so that hooks can reach it via `self.model`."""
        self.model = model

    def on_fit_begin(self) -> None:
        """Hook invoked when fitting starts. No-op by default."""

    def on_fit_end(self) -> None:
        """Hook invoked when fitting finishes. No-op by default."""

    def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None) -> None:
        """Hook invoked at the start of an epoch. Only used in training mode."""

    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
        """Hook invoked at the end of an epoch. Only used in training mode."""

    def on_train_batch_begin(self, batch: int, logs: Optional[Dict] = None) -> None:
        """Hook invoked at the start of a training batch. No-op by default."""

    def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
        """Hook invoked at the end of a training batch. No-op by default."""

    def on_validation_batch_begin(
        self, batch: int, logs: Optional[Dict] = None
    ) -> None:
        """Hook invoked at the start of a validation batch. No-op by default."""

    def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
        """Hook invoked at the end of a validation batch. No-op by default."""
class CallbackList:
    """Container that forwards each hook call to every wrapped `Callback`."""

    # Shared mode-key constants used to build batch-hook method names.
    mode_keys = ModeKeys()

    def __init__(
        self, model: Model, callbacks: Optional[List[Callback]] = None
    ) -> None:
        """Container for `Callback` instances.

        This object wraps a list of `Callback` instances and allows them all to be
        called via a single end point.

        Args:
            model (Model): A `Model` instance to attach to every callback. If
                None, no model is attached until `set_model` is called.
            callbacks (Optional[List[Callback]]): List of `Callback` instances.
                Defaults to None, meaning an empty list. A non-empty list is
                used as-is (not copied).
        """
        self._callbacks = callbacks or []
        # Always define the attribute so `append` can consult it even when
        # no model was supplied at construction time.
        self.model = None
        if model is not None:
            self.set_model(model)

    def set_model(self, model: Model) -> None:
        """Store `model` and attach it to every contained callback."""
        self.model = model
        for callback in self._callbacks:
            callback.set_model(model=self.model)

    def append(self, callback: Callback) -> None:
        """Append a callback, attaching the current model to it if one is set.

        Args:
            callback (Callback): The callback to add; it runs after all
                previously added callbacks.
        """
        self._callbacks.append(callback)
        # Otherwise a callback appended after `set_model` ran would keep
        # `model=None` and its hooks could not reach the model.
        if self.model is not None:
            callback.set_model(model=self.model)

    def on_fit_begin(self) -> None:
        """Called when fit begins."""
        for callback in self._callbacks:
            callback.on_fit_begin()

    def on_fit_end(self) -> None:
        """Called when fit ends."""
        for callback in self._callbacks:
            callback.on_fit_end()

    def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None) -> None:
        """Called at the beginning of an epoch."""
        for callback in self._callbacks:
            callback.on_epoch_begin(epoch, logs)

    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
        """Called at the end of an epoch."""
        for callback in self._callbacks:
            callback.on_epoch_end(epoch, logs)

    def _call_batch_hook(
        self, mode: str, hook: str, batch: int, logs: Optional[Dict] = None
    ) -> None:
        """Dispatch a batch-level hook for the given mode.

        Args:
            mode (str): Mode key, e.g. `ModeKeys.TRAIN` or `ModeKeys.VALIDATION`.
            hook (str): Either "begin" or "end".
            batch (int): Batch number forwarded to each callback.
            logs (Optional[Dict]): Forwarded to each callback unchanged.

        Raises:
            ValueError: If `hook` is neither "begin" nor "end".
        """
        if hook == "begin":
            self._call_batch_begin_hook(mode, batch, logs)
        elif hook == "end":
            self._call_batch_end_hook(mode, batch, logs)
        else:
            raise ValueError(f"Unrecognized hook {hook}.")

    def _call_batch_begin_hook(
        self, mode: str, batch: int, logs: Optional[Dict] = None
    ) -> None:
        """Helper function for all `on_*_batch_begin` methods."""
        hook_name = f"on_{mode}_batch_begin"
        self._call_batch_hook_helper(hook_name, batch, logs)

    def _call_batch_end_hook(
        self, mode: str, batch: int, logs: Optional[Dict] = None
    ) -> None:
        """Helper function for all `on_*_batch_end` methods."""
        hook_name = f"on_{mode}_batch_end"
        self._call_batch_hook_helper(hook_name, batch, logs)

    def _call_batch_hook_helper(
        self, hook_name: str, batch: int, logs: Optional[Dict] = None
    ) -> None:
        """Call the method named `hook_name` on every callback, in order.

        Used by both the `on_*_batch_begin` and `on_*_batch_end` helpers.
        """
        for callback in self._callbacks:
            hook = getattr(callback, hook_name)
            hook(batch, logs)

    def on_train_batch_begin(self, batch: int, logs: Optional[Dict] = None) -> None:
        """Called at the beginning of a training batch."""
        self._call_batch_hook(self.mode_keys.TRAIN, "begin", batch, logs)

    def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
        """Called at the end of a training batch."""
        self._call_batch_hook(self.mode_keys.TRAIN, "end", batch, logs)

    def on_validation_batch_begin(
        self, batch: int, logs: Optional[Dict] = None
    ) -> None:
        """Called at the beginning of a validation batch."""
        self._call_batch_hook(self.mode_keys.VALIDATION, "begin", batch, logs)

    def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
        """Called at the end of a validation batch."""
        self._call_batch_hook(self.mode_keys.VALIDATION, "end", batch, logs)

    def __iter__(self) -> Iterator[Callback]:
        """Iterate over the contained callbacks in the order they are called."""
        return iter(self._callbacks)
|