summaryrefslogtreecommitdiff
path: root/src/training/trainer/callbacks/base.py
blob: f81fc1f5a0f12570b51792a08567b822341eb1f2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
"""Base classes for training callbacks and a container that dispatches to them."""

from enum import Enum
from typing import Callable, Dict, Iterator, List, Optional, Type, Union

from loguru import logger
import numpy as np
import torch

from text_recognizer.models import Model


class ModeKeys:
    """String identifiers for the phases a `CallbackList` dispatches batch hooks to.

    Each value is interpolated into a hook name of the form
    ``on_{mode}_batch_{begin|end}``, so it must match the hook method names.
    """

    # Training phase and validation phase, respectively.
    TRAIN, VALIDATION = "train", "validation"


class Callback:
    """Base class for callbacks used in training.

    Every hook defined here is a no-op, so subclasses only need to override
    the hooks they care about.
    """

    def __init__(self) -> None:
        """Initializes the Callback instance with no model attached."""
        # Populated later through `set_model`.
        self.model: Optional[Model] = None

    def set_model(self, model: Model) -> None:
        """Attach the model this callback operates on.

        Args:
            model (Model): A `Model` instance (not the `Model` class itself).

        """
        self.model = model

    def on_fit_begin(self) -> None:
        """Called when fit begins."""

    def on_fit_end(self) -> None:
        """Called when fit ends."""

    def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None) -> None:
        """Called at the beginning of an epoch. Only used in training mode."""

    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
        """Called at the end of an epoch. Only used in training mode."""

    def on_train_batch_begin(self, batch: int, logs: Optional[Dict] = None) -> None:
        """Called at the beginning of a training batch."""

    def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
        """Called at the end of a training batch."""

    def on_validation_batch_begin(
        self, batch: int, logs: Optional[Dict] = None
    ) -> None:
        """Called at the beginning of a validation batch."""

    def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
        """Called at the end of a validation batch."""


class CallbackList:
    """Container that forwards each hook call to every wrapped `Callback`."""

    mode_keys = ModeKeys()

    def __init__(
        self, model: Optional[Model], callbacks: Optional[List[Callback]] = None
    ) -> None:
        """Container for `Callback` instances.

        This object wraps a list of `Callback` instances and allows them all to be
        called via a single end point.

        Args:
            model (Optional[Model]): A `Model` instance. When not None it is
                attached to every callback via `set_model`.
            callbacks (Optional[List[Callback]]): List of `Callback` instances.
                The list is copied, so later `append` calls never modify the
                caller's list. Defaults to None (no callbacks).

        """
        # Copy the input so `append` does not mutate a list owned by the caller.
        self._callbacks: List[Callback] = list(callbacks) if callbacks else []
        # Always define `self.model` so it can be read safely (e.g. in `append`)
        # even when no model was supplied.
        self.model: Optional[Model] = None
        if model is not None:
            self.set_model(model)

    def set_model(self, model: Model) -> None:
        """Set the model on this container and on every wrapped callback."""
        self.model = model
        for callback in self._callbacks:
            callback.set_model(model=self.model)

    def append(self, callback: Callback) -> None:
        """Append a new callback to the callback list.

        If a model has already been set, it is attached to the new callback as
        well, so callbacks added after construction see the same model.
        """
        if self.model is not None:
            callback.set_model(model=self.model)
        self._callbacks.append(callback)

    def on_fit_begin(self) -> None:
        """Called when fit begins."""
        for callback in self._callbacks:
            callback.on_fit_begin()

    def on_fit_end(self) -> None:
        """Called when fit ends."""
        for callback in self._callbacks:
            callback.on_fit_end()

    def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None) -> None:
        """Called at the beginning of an epoch."""
        for callback in self._callbacks:
            callback.on_epoch_begin(epoch, logs)

    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
        """Called at the end of an epoch."""
        for callback in self._callbacks:
            callback.on_epoch_end(epoch, logs)

    def _call_batch_hook(
        self, mode: str, hook: str, batch: int, logs: Optional[Dict] = None
    ) -> None:
        """Dispatch a batch hook for `mode` ("train"/"validation").

        Args:
            mode (str): One of the `ModeKeys` values.
            hook (str): Either "begin" or "end".
            batch (int): Batch index passed through to the callbacks.
            logs (Optional[Dict]): Logs passed through to the callbacks.

        Raises:
            ValueError: If `hook` is neither "begin" nor "end".

        """
        if hook == "begin":
            self._call_batch_begin_hook(mode, batch, logs)
        elif hook == "end":
            self._call_batch_end_hook(mode, batch, logs)
        else:
            raise ValueError(f"Unrecognized hook {hook}.")

    def _call_batch_begin_hook(
        self, mode: str, batch: int, logs: Optional[Dict] = None
    ) -> None:
        """Helper function for all `on_*_batch_begin` methods."""
        hook_name = f"on_{mode}_batch_begin"
        self._call_batch_hook_helper(hook_name, batch, logs)

    def _call_batch_end_hook(
        self, mode: str, batch: int, logs: Optional[Dict] = None
    ) -> None:
        """Helper function for all `on_*_batch_end` methods."""
        hook_name = f"on_{mode}_batch_end"
        self._call_batch_hook_helper(hook_name, batch, logs)

    def _call_batch_hook_helper(
        self, hook_name: str, batch: int, logs: Optional[Dict] = None
    ) -> None:
        """Call the method named `hook_name` on every callback (begin or end hooks)."""
        for callback in self._callbacks:
            hook = getattr(callback, hook_name)
            hook(batch, logs)

    def on_train_batch_begin(self, batch: int, logs: Optional[Dict] = None) -> None:
        """Called at the beginning of a training batch."""
        self._call_batch_hook(self.mode_keys.TRAIN, "begin", batch, logs)

    def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
        """Called at the end of a training batch."""
        self._call_batch_hook(self.mode_keys.TRAIN, "end", batch, logs)

    def on_validation_batch_begin(
        self, batch: int, logs: Optional[Dict] = None
    ) -> None:
        """Called at the beginning of a validation batch."""
        self._call_batch_hook(self.mode_keys.VALIDATION, "begin", batch, logs)

    def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
        """Called at the end of a validation batch."""
        self._call_batch_hook(self.mode_keys.VALIDATION, "end", batch, logs)

    def __iter__(self) -> Iterator[Callback]:
        """Iterate over the wrapped callbacks in insertion order."""
        return iter(self._callbacks)