Diffstat (limited to 'src/training/trainer/callbacks')
 src/training/trainer/callbacks/__init__.py        | 15
 src/training/trainer/callbacks/base.py            | 78
 src/training/trainer/callbacks/checkpoint.py      | 95
 src/training/trainer/callbacks/lr_schedulers.py   | 52
 src/training/trainer/callbacks/progress_bar.py    | 19
 src/training/trainer/callbacks/wandb_callbacks.py | 32
 6 files changed, 190 insertions(+), 101 deletions(-)
diff --git a/src/training/trainer/callbacks/__init__.py b/src/training/trainer/callbacks/__init__.py
index 5942276..c81e4bf 100644
--- a/src/training/trainer/callbacks/__init__.py
+++ b/src/training/trainer/callbacks/__init__.py
@@ -1,7 +1,16 @@
 """The callback modules used in the training script."""
-from .base import Callback, CallbackList, Checkpoint
+from .base import Callback, CallbackList
+from .checkpoint import Checkpoint
 from .early_stopping import EarlyStopping
-from .lr_schedulers import CyclicLR, MultiStepLR, OneCycleLR, ReduceLROnPlateau, StepLR
+from .lr_schedulers import (
+    CosineAnnealingLR,
+    CyclicLR,
+    MultiStepLR,
+    OneCycleLR,
+    ReduceLROnPlateau,
+    StepLR,
+    SWA,
+)
 from .progress_bar import ProgressBar
 from .wandb_callbacks import WandbCallback, WandbImageLogger
 
@@ -9,6 +18,7 @@ __all__ = [
     "Callback",
     "CallbackList",
     "Checkpoint",
+    "CosineAnnealingLR",
     "EarlyStopping",
     "WandbCallback",
     "WandbImageLogger",
@@ -18,4 +28,5 @@ __all__ = [
     "ProgressBar",
     "ReduceLROnPlateau",
     "StepLR",
+    "SWA",
 ]
diff --git a/src/training/trainer/callbacks/base.py b/src/training/trainer/callbacks/base.py
index 8df94f3..8c7b085 100644
--- a/src/training/trainer/callbacks/base.py
+++ b/src/training/trainer/callbacks/base.py
@@ -168,81 +168,3 @@ class CallbackList:
     def __iter__(self) -> iter:
         """Iter function for callback list."""
         return iter(self._callbacks)
-
-
-class Checkpoint(Callback):
-    """Saving model parameters at the end of each epoch."""
-
-    mode_dict = {
-        "min": torch.lt,
-        "max": torch.gt,
-    }
-
-    def __init__(
-        self, monitor: str = "accuracy", mode: str = "auto", min_delta: float = 0.0
-    ) -> None:
-        """Monitors a quantity that will allow us to determine the best model weights.
-
-        Args:
-            monitor (str): Name of the quantity to monitor. Defaults to "accuracy".
-            mode (str): Description of parameter `mode`. Defaults to "auto".
-            min_delta (float): Description of parameter `min_delta`. Defaults to 0.0.
-
-        """
-        super().__init__()
-        self.monitor = monitor
-        self.mode = mode
-        self.min_delta = torch.tensor(min_delta)
-
-        if mode not in ["auto", "min", "max"]:
-            logger.warning(f"Checkpoint mode {mode} is unkown, fallback to auto mode.")
-
-            self.mode = "auto"
-
-        if self.mode == "auto":
-            if "accuracy" in self.monitor:
-                self.mode = "max"
-            else:
-                self.mode = "min"
-            logger.debug(
-                f"Checkpoint mode set to {self.mode} for monitoring {self.monitor}."
-            )
-
-        torch_inf = torch.tensor(np.inf)
-        self.min_delta *= 1 if self.monitor_op == torch.gt else -1
-        self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf
-
-    @property
-    def monitor_op(self) -> float:
-        """Returns the comparison method."""
-        return self.mode_dict[self.mode]
-
-    def on_epoch_end(self, epoch: int, logs: Dict) -> None:
-        """Saves a checkpoint for the network parameters.
-
-        Args:
-            epoch (int): The current epoch.
-            logs (Dict): The log containing the monitored metrics.
-
-        """
-        current = self.get_monitor_value(logs)
-        if current is None:
-            return
-        if self.monitor_op(current - self.min_delta, self.best_score):
-            self.best_score = current
-            is_best = True
-        else:
-            is_best = False
-
-        self.model.save_checkpoint(is_best, epoch, self.monitor)
-
-    def get_monitor_value(self, logs: Dict) -> Union[float, None]:
-        """Extracts the monitored value."""
-        monitor_value = logs.get(self.monitor)
-        if monitor_value is None:
-            logger.warning(
-                f"Checkpoint is conditioned on metric {self.monitor} which is not available. Available"
-                + f"metrics are: {','.join(list(logs.keys()))}"
-            )
-            return None
-        return monitor_value
diff --git a/src/training/trainer/callbacks/checkpoint.py b/src/training/trainer/callbacks/checkpoint.py
new file mode 100644
index 0000000..6fe06d3
--- /dev/null
+++ b/src/training/trainer/callbacks/checkpoint.py
@@ -0,0 +1,95 @@
+"""Callback checkpoint for training models."""
+from enum import Enum
+from pathlib import Path
+from typing import Callable, Dict, List, Optional, Type, Union
+
+from loguru import logger
+import numpy as np
+import torch
+from training.trainer.callbacks import Callback
+
+from text_recognizer.models import Model
+
+
+class Checkpoint(Callback):
+    """Saving model parameters at the end of each epoch."""
+
+    mode_dict = {
+        "min": torch.lt,
+        "max": torch.gt,
+    }
+
+    def __init__(
+        self,
+        checkpoint_path: Path,
+        monitor: str = "accuracy",
+        mode: str = "auto",
+        min_delta: float = 0.0,
+    ) -> None:
+        """Monitors a quantity that will allow us to determine the best model weights.
+
+        Args:
+            checkpoint_path (Path): Path to the experiment with the checkpoint.
+            monitor (str): Name of the quantity to monitor. Defaults to "accuracy".
+            mode (str): Description of parameter `mode`. Defaults to "auto".
+            min_delta (float): Description of parameter `min_delta`. Defaults to 0.0.
+
+        """
+        super().__init__()
+        self.checkpoint_path = checkpoint_path
+        self.monitor = monitor
+        self.mode = mode
+        self.min_delta = torch.tensor(min_delta)
+
+        if mode not in ["auto", "min", "max"]:
+            logger.warning(f"Checkpoint mode {mode} is unkown, fallback to auto mode.")
+
+            self.mode = "auto"
+
+        if self.mode == "auto":
+            if "accuracy" in self.monitor:
+                self.mode = "max"
+            else:
+                self.mode = "min"
+            logger.debug(
+                f"Checkpoint mode set to {self.mode} for monitoring {self.monitor}."
+            )
+
+        torch_inf = torch.tensor(np.inf)
+        self.min_delta *= 1 if self.monitor_op == torch.gt else -1
+        self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf
+
+    @property
+    def monitor_op(self) -> float:
+        """Returns the comparison method."""
+        return self.mode_dict[self.mode]
+
+    def on_epoch_end(self, epoch: int, logs: Dict) -> None:
+        """Saves a checkpoint for the network parameters.
+
+        Args:
+            epoch (int): The current epoch.
+            logs (Dict): The log containing the monitored metrics.
+
+        """
+        current = self.get_monitor_value(logs)
+        if current is None:
+            return
+        if self.monitor_op(current - self.min_delta, self.best_score):
+            self.best_score = current
+            is_best = True
+        else:
+            is_best = False
+
+        self.model.save_checkpoint(self.checkpoint_path, is_best, epoch, self.monitor)
+
+    def get_monitor_value(self, logs: Dict) -> Union[float, None]:
+        """Extracts the monitored value."""
+        monitor_value = logs.get(self.monitor)
+        if monitor_value is None:
+            logger.warning(
+                f"Checkpoint is conditioned on metric {self.monitor} which is not available. Available"
+                + f" metrics are: {','.join(list(logs.keys()))}"
+            )
+            return None
+        return monitor_value
diff --git a/src/training/trainer/callbacks/lr_schedulers.py b/src/training/trainer/callbacks/lr_schedulers.py
index ba2226a..bb41d2d 100644
--- a/src/training/trainer/callbacks/lr_schedulers.py
+++ b/src/training/trainer/callbacks/lr_schedulers.py
@@ -1,6 +1,7 @@
 """Callbacks for learning rate schedulers."""
 from typing import Callable, Dict, List, Optional, Type
 
+from torch.optim.swa_utils import update_bn
 from training.trainer.callbacks import Callback
 
 from text_recognizer.models import Model
@@ -95,3 +96,54 @@ class OneCycleLR(Callback):
     def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
         """Takes a step at the end of every training batch."""
         self.lr_scheduler.step()
+
+
+class CosineAnnealingLR(Callback):
+    """Callback for Cosine Annealing."""
+
+    def __init__(self) -> None:
+        """Initializes the callback."""
+        super().__init__()
+        self.lr_scheduler = None
+
+    def set_model(self, model: Type[Model]) -> None:
+        """Sets the model and lr scheduler."""
+        self.model = model
+        self.lr_scheduler = self.model.lr_scheduler
+
+    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
+        """Takes a step at the end of every epoch."""
+        self.lr_scheduler.step()
+
+
+class SWA(Callback):
+    """Stochastic Weight Averaging callback."""
+
+    def __init__(self) -> None:
+        """Initializes the callback."""
+        super().__init__()
+        self.swa_scheduler = None
+
+    def set_model(self, model: Type[Model]) -> None:
+        """Sets the model and lr scheduler."""
+        self.model = model
+        self.swa_start = self.model.swa_start
+        self.swa_scheduler = self.model.lr_scheduler
+        self.lr_scheduler = self.model.lr_scheduler
+
+    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
+        """Takes a step at the end of every training batch."""
+        if epoch > self.swa_start:
+            self.model.swa_network.update_parameters(self.model.network)
+            self.swa_scheduler.step()
+        else:
+            self.lr_scheduler.step()
+
+    def on_fit_end(self) -> None:
+        """Update batch norm statistics for the swa model at the end of training."""
+        if self.model.swa_network:
+            update_bn(
+                self.model.val_dataloader(),
+                self.model.swa_network,
+                device=self.model.device,
+            )
diff --git a/src/training/trainer/callbacks/progress_bar.py b/src/training/trainer/callbacks/progress_bar.py
index 1970747..7829fa0 100644
--- a/src/training/trainer/callbacks/progress_bar.py
+++ b/src/training/trainer/callbacks/progress_bar.py
@@ -18,11 +18,11 @@ class ProgressBar(Callback):
     def _configure_progress_bar(self) -> None:
         """Configures the tqdm progress bar with custom bar format."""
         self.progress_bar = tqdm(
-            total=len(self.model.data_loaders["train"]),
-            leave=True,
-            unit="step",
+            total=len(self.model.train_dataloader()),
+            leave=False,
+            unit="steps",
             mininterval=self.log_batch_frequency,
-            bar_format="{desc} |{bar:30}| {n_fmt}/{total_fmt} ETA: {remaining} {rate_fmt}{postfix}",
+            bar_format="{desc} |{bar:32}| {n_fmt}/{total_fmt} ETA: {remaining} {rate_fmt}{postfix}",
         )
 
     def _key_abbreviations(self, logs: Dict) -> Dict:
@@ -34,13 +34,16 @@ class ProgressBar(Callback):
 
         return {rename(key): value for key, value in logs.items()}
 
-    def on_fit_begin(self) -> None:
-        """Creates a tqdm progress bar."""
-        self._configure_progress_bar()
+    # def on_fit_begin(self) -> None:
+    #     """Creates a tqdm progress bar."""
+    #     self._configure_progress_bar()
 
     def on_epoch_begin(self, epoch: int, logs: Optional[Dict]) -> None:
         """Updates the description with the current epoch."""
-        self.progress_bar.reset()
+        if epoch == 1:
+            self._configure_progress_bar()
+        else:
+            self.progress_bar.reset()
         self.progress_bar.set_description(f"Epoch {epoch}/{self.epochs}")
 
     def on_epoch_end(self, epoch: int, logs: Dict) -> None:
diff --git a/src/training/trainer/callbacks/wandb_callbacks.py b/src/training/trainer/callbacks/wandb_callbacks.py
index e44c745..6643a44 100644
--- a/src/training/trainer/callbacks/wandb_callbacks.py
+++ b/src/training/trainer/callbacks/wandb_callbacks.py
@@ -2,7 +2,8 @@
 from typing import Callable, Dict, List, Optional, Type
 
 import numpy as np
-from torchvision.transforms import Compose, ToTensor
+import torch
+from torchvision.transforms import ToTensor
 from training.trainer.callbacks import Callback
 
 import wandb
@@ -50,43 +51,48 @@ class WandbImageLogger(Callback):
         self,
         example_indices: Optional[List] = None,
         num_examples: int = 4,
-        transfroms: Optional[Callable] = None,
+        use_transpose: Optional[bool] = False,
     ) -> None:
         """Initializes the WandbImageLogger with the model to train.
 
         Args:
             example_indices (Optional[List]): Indices for validation images. Defaults to None.
             num_examples (int): Number of random samples to take if example_indices are not specified. Defaults to 4.
-            transfroms (Optional[Callable]): Transforms to use on the validation images, e.g. transpose. Defaults to
-                None.
+            use_transpose (Optional[bool]): Use transpose on image or not. Defaults to False.
 
         """
         super().__init__()
         self.example_indices = example_indices
         self.num_examples = num_examples
-        self.transfroms = transfroms
-        if self.transfroms is None:
-            self.transforms = Compose([Transpose()])
+        self.transpose = Transpose() if use_transpose else None
 
     def set_model(self, model: Type[Model]) -> None:
         """Sets the model and extracts validation images from the dataset."""
         self.model = model
-        data_loader = self.model.data_loaders["val"]
         if self.example_indices is None:
             self.example_indices = np.random.randint(
-                0, len(data_loader.dataset.data), self.num_examples
+                0, len(self.model.val_dataset), self.num_examples
             )
-        self.val_images = data_loader.dataset.data[self.example_indices]
-        self.val_targets = data_loader.dataset.targets[self.example_indices].numpy()
+        self.val_images = self.model.val_dataset.dataset.data[self.example_indices]
+        self.val_targets = self.model.val_dataset.dataset.targets[self.example_indices]
+        self.val_targets = self.val_targets.tolist()
 
     def on_epoch_end(self, epoch: int, logs: Dict) -> None:
         """Get network predictions on validation images."""
         images = []
         for i, image in enumerate(self.val_images):
-            image = self.transforms(image)
+            image = self.transpose(image) if self.transpose is not None else image
             pred, conf = self.model.predict_on_image(image)
-            ground_truth = self.model.mapper(int(self.val_targets[i]))
+            if isinstance(self.val_targets[i], list):
+                ground_truth = "".join(
+                    [
+                        self.model.mapper(int(target_index))
+                        for target_index in self.val_targets[i]
+                    ]
+                ).rstrip("_")
+            else:
+                ground_truth = self.val_targets[i]
             caption = f"Prediction: {pred} Confidence: {conf:.3f} Ground Truth: {ground_truth}"
             images.append(wandb.Image(image, caption=caption))