From f473456c19558aaf8552df97a51d4e18cc69dfa8 Mon Sep 17 00:00:00 2001 From: aktersnurra Date: Wed, 22 Jul 2020 23:18:08 +0200 Subject: Working training loop and testing of trained CharacterModel. --- src/training/callbacks/__init__.py | 1 + src/training/callbacks/base.py | 101 +++++++++++++++++++++++++++++++ src/training/callbacks/early_stopping.py | 1 + 3 files changed, 103 insertions(+) create mode 100644 src/training/callbacks/__init__.py create mode 100644 src/training/callbacks/base.py create mode 100644 src/training/callbacks/early_stopping.py (limited to 'src/training/callbacks') diff --git a/src/training/callbacks/__init__.py b/src/training/callbacks/__init__.py new file mode 100644 index 0000000..868d739 --- /dev/null +++ b/src/training/callbacks/__init__.py @@ -0,0 +1 @@ +"""TBC.""" diff --git a/src/training/callbacks/base.py b/src/training/callbacks/base.py new file mode 100644 index 0000000..d80a1e5 --- /dev/null +++ b/src/training/callbacks/base.py @@ -0,0 +1,101 @@ +"""Metaclass for callback functions.""" + +from abc import ABC +from typing import Callable, List, Type + + +class Callback(ABC): + """Metaclass for callbacks used in training.""" + + def on_fit_begin(self) -> None: + """Called when fit begins.""" + pass + + def on_fit_end(self) -> None: + """Called when fit ends.""" + pass + + def on_train_epoch_begin(self) -> None: + """Called at the beginning of an epoch.""" + pass + + def on_train_epoch_end(self) -> None: + """Called at the end of an epoch.""" + pass + + def on_val_epoch_begin(self) -> None: + """Called at the beginning of an epoch.""" + pass + + def on_val_epoch_end(self) -> None: + """Called at the end of an epoch.""" + pass + + def on_train_batch_begin(self) -> None: + """Called at the beginning of an epoch.""" + pass + + def on_train_batch_end(self) -> None: + """Called at the end of an epoch.""" + pass + + def on_val_batch_begin(self) -> None: + """Called at the beginning of an epoch.""" + pass + + def on_val_batch_end(self) -> None: + """Called at the end of an epoch.""" + pass + + +class CallbackList: + """Container for abstracting away callback calls.""" + + def __init__(self, callbacks: List[Callable] = None) -> None: + """TBC.""" + self._callbacks = callbacks if callbacks is not None else [] + + def append(self, callback: Type[Callback]) -> None: + """Append new callback to callback list.""" + self.callbacks.append(callback) + + def on_fit_begin(self) -> None: + """Called when fit begins.""" + for _ in self._callbacks: + pass + + def on_fit_end(self) -> None: + """Called when fit ends.""" + pass + + def on_train_epoch_begin(self) -> None: + """Called at the beginning of an epoch.""" + pass + + def on_train_epoch_end(self) -> None: + """Called at the end of an epoch.""" + pass + + def on_val_epoch_begin(self) -> None: + """Called at the beginning of an epoch.""" + pass + + def on_val_epoch_end(self) -> None: + """Called at the end of an epoch.""" + pass + + def on_train_batch_begin(self) -> None: + """Called at the beginning of an epoch.""" + pass + + def on_train_batch_end(self) -> None: + """Called at the end of an epoch.""" + pass + + def on_val_batch_begin(self) -> None: + """Called at the beginning of an epoch.""" + pass + + def on_val_batch_end(self) -> None: + """Called at the end of an epoch.""" + pass diff --git a/src/training/callbacks/early_stopping.py b/src/training/callbacks/early_stopping.py new file mode 100644 index 0000000..4da0e85 --- /dev/null +++ b/src/training/callbacks/early_stopping.py @@ -0,0 +1 @@ +"""Implements Early stopping for PyTorch model.""" -- cgit v1.2.3-70-g09d2