diff options
author | aktersnurra <gustaf.rydholm@gmail.com> | 2020-08-20 22:18:35 +0200 |
---|---|---|
committer | aktersnurra <gustaf.rydholm@gmail.com> | 2020-08-20 22:18:35 +0200 |
commit | 1f459ba19422593de325983040e176f97cf4ffc0 (patch) | |
tree | 89fef442d5dbe0c83253e9566d1762f0704f64e2 /src/training/trainer/callbacks/base.py | |
parent | 95cbdf5bc1cc9639febda23c28d8f464c998b214 (diff) |
A lot of stuff working :D. ResNet implemented!
Diffstat (limited to 'src/training/trainer/callbacks/base.py')
-rw-r--r-- | src/training/trainer/callbacks/base.py | 248 |
1 files changed, 248 insertions, 0 deletions
diff --git a/src/training/trainer/callbacks/base.py b/src/training/trainer/callbacks/base.py new file mode 100644 index 0000000..8df94f3 --- /dev/null +++ b/src/training/trainer/callbacks/base.py @@ -0,0 +1,248 @@ +"""Metaclass for callback functions.""" + +from enum import Enum +from typing import Callable, Dict, List, Optional, Type, Union + +from loguru import logger +import numpy as np +import torch + +from text_recognizer.models import Model + + +class ModeKeys: + """Mode keys for CallbackList.""" + + TRAIN = "train" + VALIDATION = "validation" + + +class Callback: + """Metaclass for callbacks used in training.""" + + def __init__(self) -> None: + """Initializes the Callback instance.""" + self.model = None + + def set_model(self, model: Type[Model]) -> None: + """Set the model.""" + self.model = model + + def on_fit_begin(self) -> None: + """Called when fit begins.""" + pass + + def on_fit_end(self) -> None: + """Called when fit ends.""" + pass + + def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None) -> None: + """Called at the beginning of an epoch. Only used in training mode.""" + pass + + def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: + """Called at the end of an epoch. Only used in training mode.""" + pass + + def on_train_batch_begin(self, batch: int, logs: Optional[Dict] = None) -> None: + """Called at the beginning of an epoch.""" + pass + + def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: + """Called at the end of an epoch.""" + pass + + def on_validation_batch_begin( + self, batch: int, logs: Optional[Dict] = None + ) -> None: + """Called at the beginning of an epoch.""" + pass + + def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: + """Called at the end of an epoch.""" + pass + + +class CallbackList: + """Container for abstracting away callback calls.""" + + mode_keys = ModeKeys() + + def __init__(self, model: Type[Model], callbacks: List[Callback] = None) -> None: + """Container for `Callback` instances. + + This object wraps a list of `Callback` instances and allows them all to be + called via a single end point. + + Args: + model (Type[Model]): A `Model` instance. + callbacks (List[Callback]): List of `Callback` instances. Defaults to None. + + """ + + self._callbacks = callbacks or [] + if model: + self.set_model(model) + + def set_model(self, model: Type[Model]) -> None: + """Set the model for all callbacks.""" + self.model = model + for callback in self._callbacks: + callback.set_model(model=self.model) + + def append(self, callback: Type[Callback]) -> None: + """Append new callback to callback list.""" + self.callbacks.append(callback) + + def on_fit_begin(self) -> None: + """Called when fit begins.""" + for callback in self._callbacks: + callback.on_fit_begin() + + def on_fit_end(self) -> None: + """Called when fit ends.""" + for callback in self._callbacks: + callback.on_fit_end() + + def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None) -> None: + """Called at the beginning of an epoch.""" + for callback in self._callbacks: + callback.on_epoch_begin(epoch, logs) + + def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: + """Called at the end of an epoch.""" + for callback in self._callbacks: + callback.on_epoch_end(epoch, logs) + + def _call_batch_hook( + self, mode: str, hook: str, batch: int, logs: Optional[Dict] = None + ) -> None: + """Helper function for all batch_{begin | end} methods.""" + if hook == "begin": + self._call_batch_begin_hook(mode, batch, logs) + elif hook == "end": + self._call_batch_end_hook(mode, batch, logs) + else: + raise ValueError(f"Unrecognized hook {hook}.") + + def _call_batch_begin_hook( + self, mode: str, batch: int, logs: Optional[Dict] = None + ) -> None: + """Helper function for all `on_*_batch_begin` methods.""" + hook_name = f"on_{mode}_batch_begin" + self._call_batch_hook_helper(hook_name, batch, logs) + + def _call_batch_end_hook( + self, mode: str, batch: int, logs: Optional[Dict] = None + ) -> None: + """Helper function for all `on_*_batch_end` methods.""" + hook_name = f"on_{mode}_batch_end" + self._call_batch_hook_helper(hook_name, batch, logs) + + def _call_batch_hook_helper( + self, hook_name: str, batch: int, logs: Optional[Dict] = None + ) -> None: + """Helper function for `on_*_batch_begin` methods.""" + for callback in self._callbacks: + hook = getattr(callback, hook_name) + hook(batch, logs) + + def on_train_batch_begin(self, batch: int, logs: Optional[Dict] = None) -> None: + """Called at the beginning of an epoch.""" + self._call_batch_hook(self.mode_keys.TRAIN, "begin", batch, logs) + + def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: + """Called at the end of an epoch.""" + self._call_batch_hook(self.mode_keys.TRAIN, "end", batch, logs) + + def on_validation_batch_begin( + self, batch: int, logs: Optional[Dict] = None + ) -> None: + """Called at the beginning of an epoch.""" + self._call_batch_hook(self.mode_keys.VALIDATION, "begin", batch, logs) + + def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: + """Called at the end of an epoch.""" + self._call_batch_hook(self.mode_keys.VALIDATION, "end", batch, logs) + + def __iter__(self) -> iter: + """Iter function for callback list.""" + return iter(self._callbacks) + + +class Checkpoint(Callback): + """Saving model parameters at the end of each epoch.""" + + mode_dict = { + "min": torch.lt, + "max": torch.gt, + } + + def __init__( + self, monitor: str = "accuracy", mode: str = "auto", min_delta: float = 0.0 + ) -> None: + """Monitors a quantity that will allow us to determine the best model weights. + + Args: + monitor (str): Name of the quantity to monitor. Defaults to "accuracy". + mode (str): Description of parameter `mode`. Defaults to "auto". + min_delta (float): Description of parameter `min_delta`. Defaults to 0.0. + + """ + super().__init__() + self.monitor = monitor + self.mode = mode + self.min_delta = torch.tensor(min_delta) + + if mode not in ["auto", "min", "max"]: + logger.warning(f"Checkpoint mode {mode} is unkown, fallback to auto mode.") + + self.mode = "auto" + + if self.mode == "auto": + if "accuracy" in self.monitor: + self.mode = "max" + else: + self.mode = "min" + logger.debug( + f"Checkpoint mode set to {self.mode} for monitoring {self.monitor}." + ) + + torch_inf = torch.tensor(np.inf) + self.min_delta *= 1 if self.monitor_op == torch.gt else -1 + self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf + + @property + def monitor_op(self) -> float: + """Returns the comparison method.""" + return self.mode_dict[self.mode] + + def on_epoch_end(self, epoch: int, logs: Dict) -> None: + """Saves a checkpoint for the network parameters. + + Args: + epoch (int): The current epoch. + logs (Dict): The log containing the monitored metrics. + + """ + current = self.get_monitor_value(logs) + if current is None: + return + if self.monitor_op(current - self.min_delta, self.best_score): + self.best_score = current + is_best = True + else: + is_best = False + + self.model.save_checkpoint(is_best, epoch, self.monitor) + + def get_monitor_value(self, logs: Dict) -> Union[float, None]: + """Extracts the monitored value.""" + monitor_value = logs.get(self.monitor) + if monitor_value is None: + logger.warning( + f"Checkpoint is conditioned on metric {self.monitor} which is not available. Available" + + f"metrics are: {','.join(list(logs.keys()))}" + ) + return None + return monitor_value |