diff options
| author | aktersnurra <gustaf.rydholm@gmail.com> | 2020-08-20 22:18:35 +0200 | 
|---|---|---|
| committer | aktersnurra <gustaf.rydholm@gmail.com> | 2020-08-20 22:18:35 +0200 | 
| commit | 1f459ba19422593de325983040e176f97cf4ffc0 (patch) | |
| tree | 89fef442d5dbe0c83253e9566d1762f0704f64e2 /src/training/trainer/callbacks | |
| parent | 95cbdf5bc1cc9639febda23c28d8f464c998b214 (diff) | |
A lot of stuff working :D. ResNet implemented!
Diffstat (limited to 'src/training/trainer/callbacks')
| -rw-r--r-- | src/training/trainer/callbacks/__init__.py | 21 | ||||
| -rw-r--r-- | src/training/trainer/callbacks/base.py | 248 | ||||
| -rw-r--r-- | src/training/trainer/callbacks/early_stopping.py | 108 | ||||
| -rw-r--r-- | src/training/trainer/callbacks/lr_schedulers.py | 97 | ||||
| -rw-r--r-- | src/training/trainer/callbacks/progress_bar.py | 61 | ||||
| -rw-r--r-- | src/training/trainer/callbacks/wandb_callbacks.py | 93 | 
6 files changed, 628 insertions, 0 deletions
diff --git a/src/training/trainer/callbacks/__init__.py b/src/training/trainer/callbacks/__init__.py new file mode 100644 index 0000000..5942276 --- /dev/null +++ b/src/training/trainer/callbacks/__init__.py @@ -0,0 +1,21 @@ +"""The callback modules used in the training script.""" +from .base import Callback, CallbackList, Checkpoint +from .early_stopping import EarlyStopping +from .lr_schedulers import CyclicLR, MultiStepLR, OneCycleLR, ReduceLROnPlateau, StepLR +from .progress_bar import ProgressBar +from .wandb_callbacks import WandbCallback, WandbImageLogger + +__all__ = [ +    "Callback", +    "CallbackList", +    "Checkpoint", +    "EarlyStopping", +    "WandbCallback", +    "WandbImageLogger", +    "CyclicLR", +    "MultiStepLR", +    "OneCycleLR", +    "ProgressBar", +    "ReduceLROnPlateau", +    "StepLR", +] diff --git a/src/training/trainer/callbacks/base.py b/src/training/trainer/callbacks/base.py new file mode 100644 index 0000000..8df94f3 --- /dev/null +++ b/src/training/trainer/callbacks/base.py @@ -0,0 +1,248 @@ +"""Metaclass for callback functions.""" + +from enum import Enum +from typing import Callable, Dict, List, Optional, Type, Union + +from loguru import logger +import numpy as np +import torch + +from text_recognizer.models import Model + + +class ModeKeys: +    """Mode keys for CallbackList.""" + +    TRAIN = "train" +    VALIDATION = "validation" + + +class Callback: +    """Metaclass for callbacks used in training.""" + +    def __init__(self) -> None: +        """Initializes the Callback instance.""" +        self.model = None + +    def set_model(self, model: Type[Model]) -> None: +        """Set the model.""" +        self.model = model + +    def on_fit_begin(self) -> None: +        """Called when fit begins.""" +        pass + +    def on_fit_end(self) -> None: +        """Called when fit ends.""" +        pass + +    def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None) -> None: +        """Called at the beginning of an epoch. Only used in training mode.""" +        pass + +    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: +        """Called at the end of an epoch. Only used in training mode.""" +        pass + +    def on_train_batch_begin(self, batch: int, logs: Optional[Dict] = None) -> None: +        """Called at the beginning of an epoch.""" +        pass + +    def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: +        """Called at the end of an epoch.""" +        pass + +    def on_validation_batch_begin( +        self, batch: int, logs: Optional[Dict] = None +    ) -> None: +        """Called at the beginning of an epoch.""" +        pass + +    def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: +        """Called at the end of an epoch.""" +        pass + + +class CallbackList: +    """Container for abstracting away callback calls.""" + +    mode_keys = ModeKeys() + +    def __init__(self, model: Type[Model], callbacks: List[Callback] = None) -> None: +        """Container for `Callback` instances. + +        This object wraps a list of `Callback` instances and allows them all to be +        called via a single end point. + +        Args: +            model (Type[Model]): A `Model` instance. +            callbacks (List[Callback]): List of `Callback` instances. Defaults to None. + +        """ + +        self._callbacks = callbacks or [] +        if model: +            self.set_model(model) + +    def set_model(self, model: Type[Model]) -> None: +        """Set the model for all callbacks.""" +        self.model = model +        for callback in self._callbacks: +            callback.set_model(model=self.model) + +    def append(self, callback: Type[Callback]) -> None: +        """Append new callback to callback list.""" +        self.callbacks.append(callback) + +    def on_fit_begin(self) -> None: +        """Called when fit begins.""" +        for callback in self._callbacks: +            callback.on_fit_begin() + +    def on_fit_end(self) -> None: +        """Called when fit ends.""" +        for callback in self._callbacks: +            callback.on_fit_end() + +    def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None) -> None: +        """Called at the beginning of an epoch.""" +        for callback in self._callbacks: +            callback.on_epoch_begin(epoch, logs) + +    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: +        """Called at the end of an epoch.""" +        for callback in self._callbacks: +            callback.on_epoch_end(epoch, logs) + +    def _call_batch_hook( +        self, mode: str, hook: str, batch: int, logs: Optional[Dict] = None +    ) -> None: +        """Helper function for all batch_{begin | end} methods.""" +        if hook == "begin": +            self._call_batch_begin_hook(mode, batch, logs) +        elif hook == "end": +            self._call_batch_end_hook(mode, batch, logs) +        else: +            raise ValueError(f"Unrecognized hook {hook}.") + +    def _call_batch_begin_hook( +        self, mode: str, batch: int, logs: Optional[Dict] = None +    ) -> None: +        """Helper function for all `on_*_batch_begin` methods.""" +        hook_name = f"on_{mode}_batch_begin" +        self._call_batch_hook_helper(hook_name, batch, logs) + +    def _call_batch_end_hook( +        self, mode: str, batch: int, logs: Optional[Dict] = None +    ) -> None: +        """Helper function for all `on_*_batch_end` methods.""" +        hook_name = f"on_{mode}_batch_end" +        self._call_batch_hook_helper(hook_name, batch, logs) + +    def _call_batch_hook_helper( +        self, hook_name: str, batch: int, logs: Optional[Dict] = None +    ) -> None: +        """Helper function for `on_*_batch_begin` methods.""" +        for callback in self._callbacks: +            hook = getattr(callback, hook_name) +            hook(batch, logs) + +    def on_train_batch_begin(self, batch: int, logs: Optional[Dict] = None) -> None: +        """Called at the beginning of an epoch.""" +        self._call_batch_hook(self.mode_keys.TRAIN, "begin", batch, logs) + +    def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: +        """Called at the end of an epoch.""" +        self._call_batch_hook(self.mode_keys.TRAIN, "end", batch, logs) + +    def on_validation_batch_begin( +        self, batch: int, logs: Optional[Dict] = None +    ) -> None: +        """Called at the beginning of an epoch.""" +        self._call_batch_hook(self.mode_keys.VALIDATION, "begin", batch, logs) + +    def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: +        """Called at the end of an epoch.""" +        self._call_batch_hook(self.mode_keys.VALIDATION, "end", batch, logs) + +    def __iter__(self) -> iter: +        """Iter function for callback list.""" +        return iter(self._callbacks) + + +class Checkpoint(Callback): +    """Saving model parameters at the end of each epoch.""" + +    mode_dict = { +        "min": torch.lt, +        "max": torch.gt, +    } + +    def __init__( +        self, monitor: str = "accuracy", mode: str = "auto", min_delta: float = 0.0 +    ) -> None: +        """Monitors a quantity that will allow us to determine the best model weights. + +        Args: +            monitor (str): Name of the quantity to monitor. Defaults to "accuracy". +            mode (str): Description of parameter `mode`. Defaults to "auto". +            min_delta (float): Description of parameter `min_delta`. Defaults to 0.0. + +        """ +        super().__init__() +        self.monitor = monitor +        self.mode = mode +        self.min_delta = torch.tensor(min_delta) + +        if mode not in ["auto", "min", "max"]: +            logger.warning(f"Checkpoint mode {mode} is unkown, fallback to auto mode.") + +            self.mode = "auto" + +        if self.mode == "auto": +            if "accuracy" in self.monitor: +                self.mode = "max" +            else: +                self.mode = "min" +            logger.debug( +                f"Checkpoint mode set to {self.mode} for monitoring {self.monitor}." +            ) + +        torch_inf = torch.tensor(np.inf) +        self.min_delta *= 1 if self.monitor_op == torch.gt else -1 +        self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf + +    @property +    def monitor_op(self) -> float: +        """Returns the comparison method.""" +        return self.mode_dict[self.mode] + +    def on_epoch_end(self, epoch: int, logs: Dict) -> None: +        """Saves a checkpoint for the network parameters. + +        Args: +            epoch (int): The current epoch. +            logs (Dict): The log containing the monitored metrics. + +        """ +        current = self.get_monitor_value(logs) +        if current is None: +            return +        if self.monitor_op(current - self.min_delta, self.best_score): +            self.best_score = current +            is_best = True +        else: +            is_best = False + +        self.model.save_checkpoint(is_best, epoch, self.monitor) + +    def get_monitor_value(self, logs: Dict) -> Union[float, None]: +        """Extracts the monitored value.""" +        monitor_value = logs.get(self.monitor) +        if monitor_value is None: +            logger.warning( +                f"Checkpoint is conditioned on metric {self.monitor} which is not available. Available" +                + f"metrics are: {','.join(list(logs.keys()))}" +            ) +            return None +        return monitor_value diff --git a/src/training/trainer/callbacks/early_stopping.py b/src/training/trainer/callbacks/early_stopping.py new file mode 100644 index 0000000..02b431f --- /dev/null +++ b/src/training/trainer/callbacks/early_stopping.py @@ -0,0 +1,108 @@ +"""Implements Early stopping for PyTorch model.""" +from typing import Dict, Union + +from loguru import logger +import numpy as np +import torch +from torch import Tensor +from training.trainer.callbacks import Callback + + +class EarlyStopping(Callback): +    """Stops training when a monitored metric stops improving.""" + +    mode_dict = { +        "min": torch.lt, +        "max": torch.gt, +    } + +    def __init__( +        self, +        monitor: str = "val_loss", +        min_delta: float = 0.0, +        patience: int = 3, +        mode: str = "auto", +    ) -> None: +        """Initializes the EarlyStopping callback. + +        Args: +            monitor (str): Description of parameter `monitor`. Defaults to "val_loss". +            min_delta (float): Description of parameter `min_delta`. Defaults to 0.0. +            patience (int): Description of parameter `patience`. Defaults to 3. +            mode (str): Description of parameter `mode`. Defaults to "auto". + +        """ +        super().__init__() +        self.monitor = monitor +        self.patience = patience +        self.min_delta = torch.tensor(min_delta) +        self.mode = mode +        self.wait_count = 0 +        self.stopped_epoch = 0 + +        if mode not in ["auto", "min", "max"]: +            logger.warning( +                f"EarlyStopping mode {mode} is unkown, fallback to auto mode." +            ) + +            self.mode = "auto" + +        if self.mode == "auto": +            if "accuracy" in self.monitor: +                self.mode = "max" +            else: +                self.mode = "min" +            logger.debug( +                f"EarlyStopping mode set to {self.mode} for monitoring {self.monitor}." +            ) + +        self.torch_inf = torch.tensor(np.inf) +        self.min_delta *= 1 if self.monitor_op == torch.gt else -1 +        self.best_score = ( +            self.torch_inf if self.monitor_op == torch.lt else -self.torch_inf +        ) + +    @property +    def monitor_op(self) -> float: +        """Returns the comparison method.""" +        return self.mode_dict[self.mode] + +    def on_fit_begin(self) -> Union[torch.lt, torch.gt]: +        """Reset the early stopping variables for reuse.""" +        self.wait_count = 0 +        self.stopped_epoch = 0 +        self.best_score = ( +            self.torch_inf if self.monitor_op == torch.lt else -self.torch_inf +        ) + +    def on_epoch_end(self, epoch: int, logs: Dict) -> None: +        """Computes the early stop criterion.""" +        current = self.get_monitor_value(logs) +        if current is None: +            return +        if self.monitor_op(current - self.min_delta, self.best_score): +            self.best_score = current +            self.wait_count = 0 +        else: +            self.wait_count += 1 +            if self.wait_count >= self.patience: +                self.stopped_epoch = epoch +                self.model.stop_training = True + +    def on_fit_end(self) -> None: +        """Logs if early stopping was used.""" +        if self.stopped_epoch > 0: +            logger.info( +                f"Stopped training at epoch {self.stopped_epoch + 1} with early stopping." +            ) + +    def get_monitor_value(self, logs: Dict) -> Union[Tensor, None]: +        """Extracts the monitor value.""" +        monitor_value = logs.get(self.monitor) +        if monitor_value is None: +            logger.warning( +                f"Early stopping is conditioned on metric {self.monitor} which is not available. Available" +                + f"metrics are: {','.join(list(logs.keys()))}" +            ) +            return None +        return torch.tensor(monitor_value) diff --git a/src/training/trainer/callbacks/lr_schedulers.py b/src/training/trainer/callbacks/lr_schedulers.py new file mode 100644 index 0000000..ba2226a --- /dev/null +++ b/src/training/trainer/callbacks/lr_schedulers.py @@ -0,0 +1,97 @@ +"""Callbacks for learning rate schedulers.""" +from typing import Callable, Dict, List, Optional, Type + +from training.trainer.callbacks import Callback + +from text_recognizer.models import Model + + +class StepLR(Callback): +    """Callback for StepLR.""" + +    def __init__(self) -> None: +        """Initializes the callback.""" +        super().__init__() +        self.lr_scheduler = None + +    def set_model(self, model: Type[Model]) -> None: +        """Sets the model and lr scheduler.""" +        self.model = model +        self.lr_scheduler = self.model.lr_scheduler + +    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: +        """Takes a step at the end of every epoch.""" +        self.lr_scheduler.step() + + +class MultiStepLR(Callback): +    """Callback for MultiStepLR.""" + +    def __init__(self) -> None: +        """Initializes the callback.""" +        super().__init__() +        self.lr_scheduler = None + +    def set_model(self, model: Type[Model]) -> None: +        """Sets the model and lr scheduler.""" +        self.model = model +        self.lr_scheduler = self.model.lr_scheduler + +    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: +        """Takes a step at the end of every epoch.""" +        self.lr_scheduler.step() + + +class ReduceLROnPlateau(Callback): +    """Callback for ReduceLROnPlateau.""" + +    def __init__(self) -> None: +        """Initializes the callback.""" +        super().__init__() +        self.lr_scheduler = None + +    def set_model(self, model: Type[Model]) -> None: +        """Sets the model and lr scheduler.""" +        self.model = model +        self.lr_scheduler = self.model.lr_scheduler + +    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: +        """Takes a step at the end of every epoch.""" +        val_loss = logs["val_loss"] +        self.lr_scheduler.step(val_loss) + + +class CyclicLR(Callback): +    """Callback for CyclicLR.""" + +    def __init__(self) -> None: +        """Initializes the callback.""" +        super().__init__() +        self.lr_scheduler = None + +    def set_model(self, model: Type[Model]) -> None: +        """Sets the model and lr scheduler.""" +        self.model = model +        self.lr_scheduler = self.model.lr_scheduler + +    def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: +        """Takes a step at the end of every training batch.""" +        self.lr_scheduler.step() + + +class OneCycleLR(Callback): +    """Callback for OneCycleLR.""" + +    def __init__(self) -> None: +        """Initializes the callback.""" +        super().__init__() +        self.lr_scheduler = None + +    def set_model(self, model: Type[Model]) -> None: +        """Sets the model and lr scheduler.""" +        self.model = model +        self.lr_scheduler = self.model.lr_scheduler + +    def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: +        """Takes a step at the end of every training batch.""" +        self.lr_scheduler.step() diff --git a/src/training/trainer/callbacks/progress_bar.py b/src/training/trainer/callbacks/progress_bar.py new file mode 100644 index 0000000..1970747 --- /dev/null +++ b/src/training/trainer/callbacks/progress_bar.py @@ -0,0 +1,61 @@ +"""Progress bar callback for the training loop.""" +from typing import Dict, Optional + +from tqdm import tqdm +from training.trainer.callbacks import Callback + + +class ProgressBar(Callback): +    """A TQDM progress bar for the training loop.""" + +    def __init__(self, epochs: int, log_batch_frequency: int = None) -> None: +        """Initializes the tqdm callback.""" +        self.epochs = epochs +        self.log_batch_frequency = log_batch_frequency +        self.progress_bar = None +        self.val_metrics = {} + +    def _configure_progress_bar(self) -> None: +        """Configures the tqdm progress bar with custom bar format.""" +        self.progress_bar = tqdm( +            total=len(self.model.data_loaders["train"]), +            leave=True, +            unit="step", +            mininterval=self.log_batch_frequency, +            bar_format="{desc} |{bar:30}| {n_fmt}/{total_fmt} ETA: {remaining} {rate_fmt}{postfix}", +        ) + +    def _key_abbreviations(self, logs: Dict) -> Dict: +        """Changes the length of keys, so that the progress bar fits better.""" + +        def rename(key: str) -> str: +            """Renames accuracy to acc.""" +            return key.replace("accuracy", "acc") + +        return {rename(key): value for key, value in logs.items()} + +    def on_fit_begin(self) -> None: +        """Creates a tqdm progress bar.""" +        self._configure_progress_bar() + +    def on_epoch_begin(self, epoch: int, logs: Optional[Dict]) -> None: +        """Updates the description with the current epoch.""" +        self.progress_bar.reset() +        self.progress_bar.set_description(f"Epoch {epoch}/{self.epochs}") + +    def on_epoch_end(self, epoch: int, logs: Dict) -> None: +        """At the end of each epoch, the validation metrics are updated to the progress bar.""" +        self.val_metrics = logs +        self.progress_bar.set_postfix(**self._key_abbreviations(logs)) +        self.progress_bar.update() + +    def on_train_batch_end(self, batch: int, logs: Dict) -> None: +        """Updates the progress bar for each training step.""" +        if self.val_metrics: +            logs.update(self.val_metrics) +        self.progress_bar.set_postfix(**self._key_abbreviations(logs)) +        self.progress_bar.update() + +    def on_fit_end(self) -> None: +        """Closes the tqdm progress bar.""" +        self.progress_bar.close() diff --git a/src/training/trainer/callbacks/wandb_callbacks.py b/src/training/trainer/callbacks/wandb_callbacks.py new file mode 100644 index 0000000..e44c745 --- /dev/null +++ b/src/training/trainer/callbacks/wandb_callbacks.py @@ -0,0 +1,93 @@ +"""Callback for W&B.""" +from typing import Callable, Dict, List, Optional, Type + +import numpy as np +from torchvision.transforms import Compose, ToTensor +from training.trainer.callbacks import Callback +import wandb + +from text_recognizer.datasets import Transpose +from text_recognizer.models.base import Model + + +class WandbCallback(Callback): +    """A custom W&B metric logger for the trainer.""" + +    def __init__(self, log_batch_frequency: int = None) -> None: +        """Short summary. + +        Args: +            log_batch_frequency (int): If None, metrics will be logged every epoch. +                If set to an integer, callback will log every metrics every log_batch_frequency. + +        """ +        super().__init__() +        self.log_batch_frequency = log_batch_frequency + +    def _on_batch_end(self, batch: int, logs: Dict) -> None: +        if self.log_batch_frequency and batch % self.log_batch_frequency == 0: +            wandb.log(logs, commit=True) + +    def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: +        """Logs training metrics.""" +        if logs is not None: +            self._on_batch_end(batch, logs) + +    def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: +        """Logs validation metrics.""" +        if logs is not None: +            self._on_batch_end(batch, logs) + +    def on_epoch_end(self, epoch: int, logs: Dict) -> None: +        """Logs at epoch end.""" +        wandb.log(logs, commit=True) + + +class WandbImageLogger(Callback): +    """Custom W&B callback for image logging.""" + +    def __init__( +        self, +        example_indices: Optional[List] = None, +        num_examples: int = 4, +        transfroms: Optional[Callable] = None, +    ) -> None: +        """Initializes the WandbImageLogger with the model to train. + +        Args: +            example_indices (Optional[List]): Indices for validation images. Defaults to None. +            num_examples (int): Number of random samples to take if example_indices are not specified. Defaults to 4. +            transfroms (Optional[Callable]): Transforms to use on the validation images, e.g. transpose. Defaults to +                None. + +        """ + +        super().__init__() +        self.example_indices = example_indices +        self.num_examples = num_examples +        self.transfroms = transfroms +        if self.transfroms is None: +            self.transforms = Compose([Transpose()]) + +    def set_model(self, model: Type[Model]) -> None: +        """Sets the model and extracts validation images from the dataset.""" +        self.model = model +        data_loader = self.model.data_loaders["val"] +        if self.example_indices is None: +            self.example_indices = np.random.randint( +                0, len(data_loader.dataset.data), self.num_examples +            ) +        self.val_images = data_loader.dataset.data[self.example_indices] +        self.val_targets = data_loader.dataset.targets[self.example_indices].numpy() + +    def on_epoch_end(self, epoch: int, logs: Dict) -> None: +        """Get network predictions on validation images.""" +        images = [] +        for i, image in enumerate(self.val_images): +            image = self.transforms(image) +            pred, conf = self.model.predict_on_image(image) +            ground_truth = self.model.mapper(int(self.val_targets[i])) +            caption = f"Prediction: {pred} Confidence: {conf:.3f} Ground Truth: {ground_truth}" +            images.append(wandb.Image(image, caption=caption)) + +        wandb.log({"examples": images}, commit=False)  |