From dc28cbe2b4ed77be92ee8b2b69a20689c3bf02a4 Mon Sep 17 00:00:00 2001 From: aktersnurra Date: Sun, 8 Nov 2020 14:54:44 +0100 Subject: new updates --- src/training/trainer/callbacks/base.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) (limited to 'src/training/trainer/callbacks/base.py') diff --git a/src/training/trainer/callbacks/base.py b/src/training/trainer/callbacks/base.py index 8c7b085..500b642 100644 --- a/src/training/trainer/callbacks/base.py +++ b/src/training/trainer/callbacks/base.py @@ -62,6 +62,14 @@ class Callback: """Called at the end of an epoch.""" pass + def on_test_begin(self) -> None: + """Called at the beginning of test.""" + pass + + def on_test_end(self) -> None: + """Called at the end of test.""" + pass + class CallbackList: """Container for abstracting away callback calls.""" @@ -92,7 +100,7 @@ class CallbackList: def append(self, callback: Type[Callback]) -> None: """Append new callback to callback list.""" - self.callbacks.append(callback) + self._callbacks.append(callback) def on_fit_begin(self) -> None: """Called when fit begins.""" @@ -104,6 +112,16 @@ class CallbackList: for callback in self._callbacks: callback.on_fit_end() + def on_test_begin(self) -> None: + """Called when test begins.""" + for callback in self._callbacks: + callback.on_test_begin() + + def on_test_end(self) -> None: + """Called when test ends.""" + for callback in self._callbacks: + callback.on_test_end() + def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None) -> None: """Called at the beginning of an epoch.""" for callback in self._callbacks: -- cgit v1.2.3-70-g09d2