summaryrefslogtreecommitdiff
path: root/src/training/callbacks/base.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/training/callbacks/base.py')
-rw-r--r--src/training/callbacks/base.py101
1 files changed, 101 insertions, 0 deletions
diff --git a/src/training/callbacks/base.py b/src/training/callbacks/base.py
new file mode 100644
index 0000000..d80a1e5
--- /dev/null
+++ b/src/training/callbacks/base.py
@@ -0,0 +1,101 @@
+"""Metaclass for callback functions."""
+
+from abc import ABC
+from typing import Callable, List, Type
+
+
+class Callback(ABC):
+ """Metaclass for callbacks used in training."""
+
+ def on_fit_begin(self) -> None:
+ """Called when fit begins."""
+ pass
+
+ def on_fit_end(self) -> None:
+ """Called when fit ends."""
+ pass
+
+ def on_train_epoch_begin(self) -> None:
+ """Called at the beginning of an epoch."""
+ pass
+
+ def on_train_epoch_end(self) -> None:
+ """Called at the end of an epoch."""
+ pass
+
+ def on_val_epoch_begin(self) -> None:
+ """Called at the beginning of an epoch."""
+ pass
+
+ def on_val_epoch_end(self) -> None:
+ """Called at the end of an epoch."""
+ pass
+
+ def on_train_batch_begin(self) -> None:
+ """Called at the beginning of an epoch."""
+ pass
+
+ def on_train_batch_end(self) -> None:
+ """Called at the end of an epoch."""
+ pass
+
+ def on_val_batch_begin(self) -> None:
+ """Called at the beginning of an epoch."""
+ pass
+
+ def on_val_batch_end(self) -> None:
+ """Called at the end of an epoch."""
+ pass
+
+
+class CallbackList:
+ """Container for abstracting away callback calls."""
+
+ def __init__(self, callbacks: List[Callable] = None) -> None:
+ """TBC."""
+ self._callbacks = callbacks if callbacks is not None else []
+
+ def append(self, callback: Type[Callback]) -> None:
+ """Append new callback to callback list."""
+ self.callbacks.append(callback)
+
+ def on_fit_begin(self) -> None:
+ """Called when fit begins."""
+ for _ in self._callbacks:
+ pass
+
+ def on_fit_end(self) -> None:
+ """Called when fit ends."""
+ pass
+
+ def on_train_epoch_begin(self) -> None:
+ """Called at the beginning of an epoch."""
+ pass
+
+ def on_train_epoch_end(self) -> None:
+ """Called at the end of an epoch."""
+ pass
+
+ def on_val_epoch_begin(self) -> None:
+ """Called at the beginning of an epoch."""
+ pass
+
+ def on_val_epoch_end(self) -> None:
+ """Called at the end of an epoch."""
+ pass
+
+ def on_train_batch_begin(self) -> None:
+ """Called at the beginning of an epoch."""
+ pass
+
+ def on_train_batch_end(self) -> None:
+ """Called at the end of an epoch."""
+ pass
+
+ def on_val_batch_begin(self) -> None:
+ """Called at the beginning of an epoch."""
+ pass
+
+ def on_val_batch_end(self) -> None:
+ """Called at the end of an epoch."""
+ pass