diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-05 20:47:55 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-05 20:47:55 +0200 |
commit | 9ae5fa1a88899180f88ddb14d4cef457ceb847e5 (patch) | |
tree | 4fe2bcd82553c8062eb0908ae6442c123addf55d /training/trainer/callbacks/base.py | |
parent | 9e54591b7e342edc93b0bb04809a0f54045c6a15 (diff) |
Add new training loop with PyTorch Lightning, remove stale files
Diffstat (limited to 'training/trainer/callbacks/base.py')
-rw-r--r-- | training/trainer/callbacks/base.py | 188 |
1 files changed, 0 insertions, 188 deletions
diff --git a/training/trainer/callbacks/base.py b/training/trainer/callbacks/base.py deleted file mode 100644 index 500b642..0000000 --- a/training/trainer/callbacks/base.py +++ /dev/null @@ -1,188 +0,0 @@ -"""Metaclass for callback functions.""" - -from enum import Enum -from typing import Callable, Dict, List, Optional, Type, Union - -from loguru import logger -import numpy as np -import torch - -from text_recognizer.models import Model - - -class ModeKeys: - """Mode keys for CallbackList.""" - - TRAIN = "train" - VALIDATION = "validation" - - -class Callback: - """Metaclass for callbacks used in training.""" - - def __init__(self) -> None: - """Initializes the Callback instance.""" - self.model = None - - def set_model(self, model: Type[Model]) -> None: - """Set the model.""" - self.model = model - - def on_fit_begin(self) -> None: - """Called when fit begins.""" - pass - - def on_fit_end(self) -> None: - """Called when fit ends.""" - pass - - def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None) -> None: - """Called at the beginning of an epoch. Only used in training mode.""" - pass - - def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: - """Called at the end of an epoch. Only used in training mode.""" - pass - - def on_train_batch_begin(self, batch: int, logs: Optional[Dict] = None) -> None: - """Called at the beginning of an epoch.""" - pass - - def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: - """Called at the end of an epoch.""" - pass - - def on_validation_batch_begin( - self, batch: int, logs: Optional[Dict] = None - ) -> None: - """Called at the beginning of an epoch.""" - pass - - def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: - """Called at the end of an epoch.""" - pass - - def on_test_begin(self) -> None: - """Called at the beginning of test.""" - pass - - def on_test_end(self) -> None: - """Called at the end of test.""" - pass - - -class CallbackList: - """Container for abstracting away callback calls.""" - - mode_keys = ModeKeys() - - def __init__(self, model: Type[Model], callbacks: List[Callback] = None) -> None: - """Container for `Callback` instances. - - This object wraps a list of `Callback` instances and allows them all to be - called via a single end point. - - Args: - model (Type[Model]): A `Model` instance. - callbacks (List[Callback]): List of `Callback` instances. Defaults to None. - - """ - - self._callbacks = callbacks or [] - if model: - self.set_model(model) - - def set_model(self, model: Type[Model]) -> None: - """Set the model for all callbacks.""" - self.model = model - for callback in self._callbacks: - callback.set_model(model=self.model) - - def append(self, callback: Type[Callback]) -> None: - """Append new callback to callback list.""" - self._callbacks.append(callback) - - def on_fit_begin(self) -> None: - """Called when fit begins.""" - for callback in self._callbacks: - callback.on_fit_begin() - - def on_fit_end(self) -> None: - """Called when fit ends.""" - for callback in self._callbacks: - callback.on_fit_end() - - def on_test_begin(self) -> None: - """Called when test begins.""" - for callback in self._callbacks: - callback.on_test_begin() - - def on_test_end(self) -> None: - """Called when test ends.""" - for callback in self._callbacks: - callback.on_test_end() - - def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None) -> None: - """Called at the beginning of an epoch.""" - for callback in self._callbacks: - callback.on_epoch_begin(epoch, logs) - - def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: - """Called at the end of an epoch.""" - for callback in self._callbacks: - callback.on_epoch_end(epoch, logs) - - def _call_batch_hook( - self, mode: str, hook: str, batch: int, logs: Optional[Dict] = None - ) -> None: - """Helper function for all batch_{begin | end} methods.""" - if hook == "begin": - self._call_batch_begin_hook(mode, batch, logs) - elif hook == "end": - self._call_batch_end_hook(mode, batch, logs) - else: - raise ValueError(f"Unrecognized hook {hook}.") - - def _call_batch_begin_hook( - self, mode: str, batch: int, logs: Optional[Dict] = None - ) -> None: - """Helper function for all `on_*_batch_begin` methods.""" - hook_name = f"on_{mode}_batch_begin" - self._call_batch_hook_helper(hook_name, batch, logs) - - def _call_batch_end_hook( - self, mode: str, batch: int, logs: Optional[Dict] = None - ) -> None: - """Helper function for all `on_*_batch_end` methods.""" - hook_name = f"on_{mode}_batch_end" - self._call_batch_hook_helper(hook_name, batch, logs) - - def _call_batch_hook_helper( - self, hook_name: str, batch: int, logs: Optional[Dict] = None - ) -> None: - """Helper function for `on_*_batch_begin` methods.""" - for callback in self._callbacks: - hook = getattr(callback, hook_name) - hook(batch, logs) - - def on_train_batch_begin(self, batch: int, logs: Optional[Dict] = None) -> None: - """Called at the beginning of an epoch.""" - self._call_batch_hook(self.mode_keys.TRAIN, "begin", batch, logs) - - def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: - """Called at the end of an epoch.""" - self._call_batch_hook(self.mode_keys.TRAIN, "end", batch, logs) - - def on_validation_batch_begin( - self, batch: int, logs: Optional[Dict] = None - ) -> None: - """Called at the beginning of an epoch.""" - self._call_batch_hook(self.mode_keys.VALIDATION, "begin", batch, logs) - - def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: - """Called at the end of an epoch.""" - self._call_batch_hook(self.mode_keys.VALIDATION, "end", batch, logs) - - def __iter__(self) -> iter: - """Iter function for callback list.""" - return iter(self._callbacks) |