summaryrefslogtreecommitdiff
path: root/src/training/trainer/callbacks/base.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/training/trainer/callbacks/base.py')
-rw-r--r--src/training/trainer/callbacks/base.py18
1 files changed, 18 insertions, 0 deletions
diff --git a/src/training/trainer/callbacks/base.py b/src/training/trainer/callbacks/base.py
index f81fc1f..500b642 100644
--- a/src/training/trainer/callbacks/base.py
+++ b/src/training/trainer/callbacks/base.py
@@ -62,6 +62,14 @@ class Callback:
"""Called at the end of an epoch."""
pass
+ def on_test_begin(self) -> None:
+ """Called at the beginning of test."""
+ pass
+
+ def on_test_end(self) -> None:
+ """Called at the end of test."""
+ pass
+
class CallbackList:
"""Container for abstracting away callback calls."""
@@ -104,6 +112,16 @@ class CallbackList:
for callback in self._callbacks:
callback.on_fit_end()
+ def on_test_begin(self) -> None:
+ """Called when test begins."""
+ for callback in self._callbacks:
+ callback.on_test_begin()
+
+ def on_test_end(self) -> None:
+ """Called when test ends."""
+ for callback in self._callbacks:
+ callback.on_test_end()
+
def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None) -> None:
"""Called at the beginning of an epoch."""
for callback in self._callbacks: