diff options
Diffstat (limited to 'src/training/trainer/callbacks')
| -rw-r--r-- | src/training/trainer/callbacks/__init__.py | 15 | ||||
| -rw-r--r-- | src/training/trainer/callbacks/base.py | 78 | ||||
| -rw-r--r-- | src/training/trainer/callbacks/checkpoint.py | 95 | ||||
| -rw-r--r-- | src/training/trainer/callbacks/lr_schedulers.py | 52 | ||||
| -rw-r--r-- | src/training/trainer/callbacks/progress_bar.py | 19 | ||||
| -rw-r--r-- | src/training/trainer/callbacks/wandb_callbacks.py | 32 | 
6 files changed, 190 insertions, 101 deletions
| diff --git a/src/training/trainer/callbacks/__init__.py b/src/training/trainer/callbacks/__init__.py index 5942276..c81e4bf 100644 --- a/src/training/trainer/callbacks/__init__.py +++ b/src/training/trainer/callbacks/__init__.py @@ -1,7 +1,16 @@  """The callback modules used in the training script.""" -from .base import Callback, CallbackList, Checkpoint +from .base import Callback, CallbackList +from .checkpoint import Checkpoint  from .early_stopping import EarlyStopping -from .lr_schedulers import CyclicLR, MultiStepLR, OneCycleLR, ReduceLROnPlateau, StepLR +from .lr_schedulers import ( +    CosineAnnealingLR, +    CyclicLR, +    MultiStepLR, +    OneCycleLR, +    ReduceLROnPlateau, +    StepLR, +    SWA, +)  from .progress_bar import ProgressBar  from .wandb_callbacks import WandbCallback, WandbImageLogger @@ -9,6 +18,7 @@ __all__ = [      "Callback",      "CallbackList",      "Checkpoint", +    "CosineAnnealingLR",      "EarlyStopping",      "WandbCallback",      "WandbImageLogger", @@ -18,4 +28,5 @@ __all__ = [      "ProgressBar",      "ReduceLROnPlateau",      "StepLR", +    "SWA",  ] diff --git a/src/training/trainer/callbacks/base.py b/src/training/trainer/callbacks/base.py index 8df94f3..8c7b085 100644 --- a/src/training/trainer/callbacks/base.py +++ b/src/training/trainer/callbacks/base.py @@ -168,81 +168,3 @@ class CallbackList:      def __iter__(self) -> iter:          """Iter function for callback list."""          return iter(self._callbacks) - - -class Checkpoint(Callback): -    """Saving model parameters at the end of each epoch.""" - -    mode_dict = { -        "min": torch.lt, -        "max": torch.gt, -    } - -    def __init__( -        self, monitor: str = "accuracy", mode: str = "auto", min_delta: float = 0.0 -    ) -> None: -        """Monitors a quantity that will allow us to determine the best model weights. - -        Args: -            monitor (str): Name of the quantity to monitor. Defaults to "accuracy". -            mode (str): Description of parameter `mode`. Defaults to "auto". -            min_delta (float): Description of parameter `min_delta`. Defaults to 0.0. - -        """ -        super().__init__() -        self.monitor = monitor -        self.mode = mode -        self.min_delta = torch.tensor(min_delta) - -        if mode not in ["auto", "min", "max"]: -            logger.warning(f"Checkpoint mode {mode} is unkown, fallback to auto mode.") - -            self.mode = "auto" - -        if self.mode == "auto": -            if "accuracy" in self.monitor: -                self.mode = "max" -            else: -                self.mode = "min" -            logger.debug( -                f"Checkpoint mode set to {self.mode} for monitoring {self.monitor}." -            ) - -        torch_inf = torch.tensor(np.inf) -        self.min_delta *= 1 if self.monitor_op == torch.gt else -1 -        self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf - -    @property -    def monitor_op(self) -> float: -        """Returns the comparison method.""" -        return self.mode_dict[self.mode] - -    def on_epoch_end(self, epoch: int, logs: Dict) -> None: -        """Saves a checkpoint for the network parameters. - -        Args: -            epoch (int): The current epoch. -            logs (Dict): The log containing the monitored metrics. - -        """ -        current = self.get_monitor_value(logs) -        if current is None: -            return -        if self.monitor_op(current - self.min_delta, self.best_score): -            self.best_score = current -            is_best = True -        else: -            is_best = False - -        self.model.save_checkpoint(is_best, epoch, self.monitor) - -    def get_monitor_value(self, logs: Dict) -> Union[float, None]: -        """Extracts the monitored value.""" -        monitor_value = logs.get(self.monitor) -        if monitor_value is None: -            logger.warning( -                f"Checkpoint is conditioned on metric {self.monitor} which is not available. Available" -                + f"metrics are: {','.join(list(logs.keys()))}" -            ) -            return None -        return monitor_value diff --git a/src/training/trainer/callbacks/checkpoint.py b/src/training/trainer/callbacks/checkpoint.py new file mode 100644 index 0000000..6fe06d3 --- /dev/null +++ b/src/training/trainer/callbacks/checkpoint.py @@ -0,0 +1,95 @@ +"""Callback checkpoint for training models.""" +from enum import Enum +from pathlib import Path +from typing import Callable, Dict, List, Optional, Type, Union + +from loguru import logger +import numpy as np +import torch +from training.trainer.callbacks import Callback + +from text_recognizer.models import Model + + +class Checkpoint(Callback): +    """Saving model parameters at the end of each epoch.""" + +    mode_dict = { +        "min": torch.lt, +        "max": torch.gt, +    } + +    def __init__( +        self, +        checkpoint_path: Path, +        monitor: str = "accuracy", +        mode: str = "auto", +        min_delta: float = 0.0, +    ) -> None: +        """Monitors a quantity that will allow us to determine the best model weights. + +        Args: +            checkpoint_path (Path): Path to the experiment with the checkpoint. +            monitor (str): Name of the quantity to monitor. Defaults to "accuracy". +            mode (str): Description of parameter `mode`. Defaults to "auto". +            min_delta (float): Description of parameter `min_delta`. Defaults to 0.0. + +        """ +        super().__init__() +        self.checkpoint_path = checkpoint_path +        self.monitor = monitor +        self.mode = mode +        self.min_delta = torch.tensor(min_delta) + +        if mode not in ["auto", "min", "max"]: +            logger.warning(f"Checkpoint mode {mode} is unkown, fallback to auto mode.") + +            self.mode = "auto" + +        if self.mode == "auto": +            if "accuracy" in self.monitor: +                self.mode = "max" +            else: +                self.mode = "min" +            logger.debug( +                f"Checkpoint mode set to {self.mode} for monitoring {self.monitor}." +            ) + +        torch_inf = torch.tensor(np.inf) +        self.min_delta *= 1 if self.monitor_op == torch.gt else -1 +        self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf + +    @property +    def monitor_op(self) -> float: +        """Returns the comparison method.""" +        return self.mode_dict[self.mode] + +    def on_epoch_end(self, epoch: int, logs: Dict) -> None: +        """Saves a checkpoint for the network parameters. + +        Args: +            epoch (int): The current epoch. +            logs (Dict): The log containing the monitored metrics. + +        """ +        current = self.get_monitor_value(logs) +        if current is None: +            return +        if self.monitor_op(current - self.min_delta, self.best_score): +            self.best_score = current +            is_best = True +        else: +            is_best = False + +        self.model.save_checkpoint(self.checkpoint_path, is_best, epoch, self.monitor) + +    def get_monitor_value(self, logs: Dict) -> Union[float, None]: +        """Extracts the monitored value.""" +        monitor_value = logs.get(self.monitor) +        if monitor_value is None: +            logger.warning( +                f"Checkpoint is conditioned on metric {self.monitor} which is not available. Available" +                + f" metrics are: {','.join(list(logs.keys()))}" +            ) +            return None +        return monitor_value diff --git a/src/training/trainer/callbacks/lr_schedulers.py b/src/training/trainer/callbacks/lr_schedulers.py index ba2226a..bb41d2d 100644 --- a/src/training/trainer/callbacks/lr_schedulers.py +++ b/src/training/trainer/callbacks/lr_schedulers.py @@ -1,6 +1,7 @@  """Callbacks for learning rate schedulers."""  from typing import Callable, Dict, List, Optional, Type +from torch.optim.swa_utils import update_bn  from training.trainer.callbacks import Callback  from text_recognizer.models import Model @@ -95,3 +96,54 @@ class OneCycleLR(Callback):      def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:          """Takes a step at the end of every training batch."""          self.lr_scheduler.step() + + +class CosineAnnealingLR(Callback): +    """Callback for Cosine Annealing.""" + +    def __init__(self) -> None: +        """Initializes the callback.""" +        super().__init__() +        self.lr_scheduler = None + +    def set_model(self, model: Type[Model]) -> None: +        """Sets the model and lr scheduler.""" +        self.model = model +        self.lr_scheduler = self.model.lr_scheduler + +    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: +        """Takes a step at the end of every epoch.""" +        self.lr_scheduler.step() + + +class SWA(Callback): +    """Stochastic Weight Averaging callback.""" + +    def __init__(self) -> None: +        """Initializes the callback.""" +        super().__init__() +        self.swa_scheduler = None + +    def set_model(self, model: Type[Model]) -> None: +        """Sets the model and lr scheduler.""" +        self.model = model +        self.swa_start = self.model.swa_start +        self.swa_scheduler = self.model.lr_scheduler +        self.lr_scheduler = self.model.lr_scheduler + +    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: +        """Takes a step at the end of every training batch.""" +        if epoch > self.swa_start: +            self.model.swa_network.update_parameters(self.model.network) +            self.swa_scheduler.step() +        else: +            self.lr_scheduler.step() + +    def on_fit_end(self) -> None: +        """Update batch norm statistics for the swa model at the end of training.""" +        if self.model.swa_network: +            update_bn( +                self.model.val_dataloader(), +                self.model.swa_network, +                device=self.model.device, +            ) diff --git a/src/training/trainer/callbacks/progress_bar.py b/src/training/trainer/callbacks/progress_bar.py index 1970747..7829fa0 100644 --- a/src/training/trainer/callbacks/progress_bar.py +++ b/src/training/trainer/callbacks/progress_bar.py @@ -18,11 +18,11 @@ class ProgressBar(Callback):      def _configure_progress_bar(self) -> None:          """Configures the tqdm progress bar with custom bar format."""          self.progress_bar = tqdm( -            total=len(self.model.data_loaders["train"]), -            leave=True, -            unit="step", +            total=len(self.model.train_dataloader()), +            leave=False, +            unit="steps",              mininterval=self.log_batch_frequency, -            bar_format="{desc} |{bar:30}| {n_fmt}/{total_fmt} ETA: {remaining} {rate_fmt}{postfix}", +            bar_format="{desc} |{bar:32}| {n_fmt}/{total_fmt} ETA: {remaining} {rate_fmt}{postfix}",          )      def _key_abbreviations(self, logs: Dict) -> Dict: @@ -34,13 +34,16 @@ class ProgressBar(Callback):          return {rename(key): value for key, value in logs.items()} -    def on_fit_begin(self) -> None: -        """Creates a tqdm progress bar.""" -        self._configure_progress_bar() +    # def on_fit_begin(self) -> None: +    #     """Creates a tqdm progress bar.""" +    #     self._configure_progress_bar()      def on_epoch_begin(self, epoch: int, logs: Optional[Dict]) -> None:          """Updates the description with the current epoch.""" -        self.progress_bar.reset() +        if epoch == 1: +            self._configure_progress_bar() +        else: +            self.progress_bar.reset()          self.progress_bar.set_description(f"Epoch {epoch}/{self.epochs}")      def on_epoch_end(self, epoch: int, logs: Dict) -> None: diff --git a/src/training/trainer/callbacks/wandb_callbacks.py b/src/training/trainer/callbacks/wandb_callbacks.py index e44c745..6643a44 100644 --- a/src/training/trainer/callbacks/wandb_callbacks.py +++ b/src/training/trainer/callbacks/wandb_callbacks.py @@ -2,7 +2,8 @@  from typing import Callable, Dict, List, Optional, Type  import numpy as np -from torchvision.transforms import Compose, ToTensor +import torch +from torchvision.transforms import ToTensor  from training.trainer.callbacks import Callback  import wandb @@ -50,43 +51,48 @@ class WandbImageLogger(Callback):          self,          example_indices: Optional[List] = None,          num_examples: int = 4, -        transfroms: Optional[Callable] = None, +        use_transpose: Optional[bool] = False,      ) -> None:          """Initializes the WandbImageLogger with the model to train.          Args:              example_indices (Optional[List]): Indices for validation images. Defaults to None.              num_examples (int): Number of random samples to take if example_indices are not specified. Defaults to 4. -            transfroms (Optional[Callable]): Transforms to use on the validation images, e.g. transpose. Defaults to -                None. +            use_transpose (Optional[bool]): Use transpose on image or not. Defaults to False.          """          super().__init__()          self.example_indices = example_indices          self.num_examples = num_examples -        self.transfroms = transfroms -        if self.transfroms is None: -            self.transforms = Compose([Transpose()]) +        self.transpose = Transpose() if use_transpose else None      def set_model(self, model: Type[Model]) -> None:          """Sets the model and extracts validation images from the dataset."""          self.model = model -        data_loader = self.model.data_loaders["val"]          if self.example_indices is None:              self.example_indices = np.random.randint( -                0, len(data_loader.dataset.data), self.num_examples +                0, len(self.model.val_dataset), self.num_examples              ) -        self.val_images = data_loader.dataset.data[self.example_indices] -        self.val_targets = data_loader.dataset.targets[self.example_indices].numpy() +        self.val_images = self.model.val_dataset.dataset.data[self.example_indices] +        self.val_targets = self.model.val_dataset.dataset.targets[self.example_indices] +        self.val_targets = self.val_targets.tolist()      def on_epoch_end(self, epoch: int, logs: Dict) -> None:          """Get network predictions on validation images."""          images = []          for i, image in enumerate(self.val_images): -            image = self.transforms(image) +            image = self.transpose(image) if self.transpose is not None else image              pred, conf = self.model.predict_on_image(image) -            ground_truth = self.model.mapper(int(self.val_targets[i])) +            if isinstance(self.val_targets[i], list): +                ground_truth = "".join( +                    [ +                        self.model.mapper(int(target_index)) +                        for target_index in self.val_targets[i] +                    ] +                ).rstrip("_") +            else: +                ground_truth = self.val_targets[i]              caption = f"Prediction: {pred} Confidence: {conf:.3f} Ground Truth: {ground_truth}"              images.append(wandb.Image(image, caption=caption)) |