summaryrefslogtreecommitdiff
path: root/src/training/trainer/callbacks/base.py
diff options
context:
space:
mode:
authoraktersnurra <gustaf.rydholm@gmail.com>2020-11-08 14:54:44 +0100
committeraktersnurra <gustaf.rydholm@gmail.com>2020-11-08 14:54:44 +0100
commitdc28cbe2b4ed77be92ee8b2b69a20689c3bf02a4 (patch)
tree1b5fc0d06952e13727e85c4f973a26d277068453 /src/training/trainer/callbacks/base.py
parente181195a699d7fa237f256d90ab4dedffc03d405 (diff)
new updates
Diffstat (limited to 'src/training/trainer/callbacks/base.py')
-rw-r--r--src/training/trainer/callbacks/base.py20
1 files changed, 19 insertions, 1 deletions
diff --git a/src/training/trainer/callbacks/base.py b/src/training/trainer/callbacks/base.py
index 8c7b085..500b642 100644
--- a/src/training/trainer/callbacks/base.py
+++ b/src/training/trainer/callbacks/base.py
@@ -62,6 +62,14 @@ class Callback:
"""Called at the end of an epoch."""
pass
+ def on_test_begin(self) -> None:
+ """Called at the beginning of test."""
+ pass
+
+ def on_test_end(self) -> None:
+ """Called at the end of test."""
+ pass
+
class CallbackList:
"""Container for abstracting away callback calls."""
@@ -92,7 +100,7 @@ class CallbackList:
def append(self, callback: Type[Callback]) -> None:
"""Append new callback to callback list."""
- self.callbacks.append(callback)
+ self._callbacks.append(callback)
def on_fit_begin(self) -> None:
"""Called when fit begins."""
@@ -104,6 +112,16 @@ class CallbackList:
for callback in self._callbacks:
callback.on_fit_end()
+ def on_test_begin(self) -> None:
+ """Called when test begins."""
+ for callback in self._callbacks:
+ callback.on_test_begin()
+
+ def on_test_end(self) -> None:
+ """Called when test ends."""
+ for callback in self._callbacks:
+ callback.on_test_end()
+
def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None) -> None:
"""Called at the beginning of an epoch."""
for callback in self._callbacks: