diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-03-20 18:09:06 +0100 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-03-20 18:09:06 +0100 |
commit | 7e8e54e84c63171e748bbf09516fd517e6821ace (patch) | |
tree | 996093f75a5d488dddf7ea1f159ed343a561ef89 /training/trainer/callbacks/early_stopping.py | |
parent | b0719d84138b6bbe5f04a4982dfca673aea1a368 (diff) |
Inital commit for refactoring to lightning
Diffstat (limited to 'training/trainer/callbacks/early_stopping.py')
-rw-r--r-- | training/trainer/callbacks/early_stopping.py | 108 |
1 files changed, 108 insertions, 0 deletions
diff --git a/training/trainer/callbacks/early_stopping.py b/training/trainer/callbacks/early_stopping.py new file mode 100644 index 0000000..02b431f --- /dev/null +++ b/training/trainer/callbacks/early_stopping.py @@ -0,0 +1,108 @@ +"""Implements Early stopping for PyTorch model.""" +from typing import Dict, Union + +from loguru import logger +import numpy as np +import torch +from torch import Tensor +from training.trainer.callbacks import Callback + + +class EarlyStopping(Callback): + """Stops training when a monitored metric stops improving.""" + + mode_dict = { + "min": torch.lt, + "max": torch.gt, + } + + def __init__( + self, + monitor: str = "val_loss", + min_delta: float = 0.0, + patience: int = 3, + mode: str = "auto", + ) -> None: + """Initializes the EarlyStopping callback. + + Args: + monitor (str): Description of parameter `monitor`. Defaults to "val_loss". + min_delta (float): Description of parameter `min_delta`. Defaults to 0.0. + patience (int): Description of parameter `patience`. Defaults to 3. + mode (str): Description of parameter `mode`. Defaults to "auto". + + """ + super().__init__() + self.monitor = monitor + self.patience = patience + self.min_delta = torch.tensor(min_delta) + self.mode = mode + self.wait_count = 0 + self.stopped_epoch = 0 + + if mode not in ["auto", "min", "max"]: + logger.warning( + f"EarlyStopping mode {mode} is unkown, fallback to auto mode." + ) + + self.mode = "auto" + + if self.mode == "auto": + if "accuracy" in self.monitor: + self.mode = "max" + else: + self.mode = "min" + logger.debug( + f"EarlyStopping mode set to {self.mode} for monitoring {self.monitor}." + ) + + self.torch_inf = torch.tensor(np.inf) + self.min_delta *= 1 if self.monitor_op == torch.gt else -1 + self.best_score = ( + self.torch_inf if self.monitor_op == torch.lt else -self.torch_inf + ) + + @property + def monitor_op(self) -> float: + """Returns the comparison method.""" + return self.mode_dict[self.mode] + + def on_fit_begin(self) -> Union[torch.lt, torch.gt]: + """Reset the early stopping variables for reuse.""" + self.wait_count = 0 + self.stopped_epoch = 0 + self.best_score = ( + self.torch_inf if self.monitor_op == torch.lt else -self.torch_inf + ) + + def on_epoch_end(self, epoch: int, logs: Dict) -> None: + """Computes the early stop criterion.""" + current = self.get_monitor_value(logs) + if current is None: + return + if self.monitor_op(current - self.min_delta, self.best_score): + self.best_score = current + self.wait_count = 0 + else: + self.wait_count += 1 + if self.wait_count >= self.patience: + self.stopped_epoch = epoch + self.model.stop_training = True + + def on_fit_end(self) -> None: + """Logs if early stopping was used.""" + if self.stopped_epoch > 0: + logger.info( + f"Stopped training at epoch {self.stopped_epoch + 1} with early stopping." + ) + + def get_monitor_value(self, logs: Dict) -> Union[Tensor, None]: + """Extracts the monitor value.""" + monitor_value = logs.get(self.monitor) + if monitor_value is None: + logger.warning( + f"Early stopping is conditioned on metric {self.monitor} which is not available. Available" + + f"metrics are: {','.join(list(logs.keys()))}" + ) + return None + return torch.tensor(monitor_value) |