diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-03-20 18:09:06 +0100 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-03-20 18:09:06 +0100 |
commit | 7e8e54e84c63171e748bbf09516fd517e6821ace (patch) | |
tree | 996093f75a5d488dddf7ea1f159ed343a561ef89 /src/training/trainer | |
parent | b0719d84138b6bbe5f04a4982dfca673aea1a368 (diff) |
Inital commit for refactoring to lightning
Diffstat (limited to 'src/training/trainer')
-rw-r--r-- | src/training/trainer/__init__.py | 2 | ||||
-rw-r--r-- | src/training/trainer/callbacks/__init__.py | 29 | ||||
-rw-r--r-- | src/training/trainer/callbacks/base.py | 188 | ||||
-rw-r--r-- | src/training/trainer/callbacks/checkpoint.py | 95 | ||||
-rw-r--r-- | src/training/trainer/callbacks/early_stopping.py | 108 | ||||
-rw-r--r-- | src/training/trainer/callbacks/lr_schedulers.py | 77 | ||||
-rw-r--r-- | src/training/trainer/callbacks/progress_bar.py | 65 | ||||
-rw-r--r-- | src/training/trainer/callbacks/wandb_callbacks.py | 261 | ||||
-rw-r--r-- | src/training/trainer/train.py | 325 | ||||
-rw-r--r-- | src/training/trainer/util.py | 28 |
10 files changed, 0 insertions, 1178 deletions
diff --git a/src/training/trainer/__init__.py b/src/training/trainer/__init__.py deleted file mode 100644 index de41bfb..0000000 --- a/src/training/trainer/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -"""Trainer modules.""" -from .train import Trainer diff --git a/src/training/trainer/callbacks/__init__.py b/src/training/trainer/callbacks/__init__.py deleted file mode 100644 index 80c4177..0000000 --- a/src/training/trainer/callbacks/__init__.py +++ /dev/null @@ -1,29 +0,0 @@ -"""The callback modules used in the training script.""" -from .base import Callback, CallbackList -from .checkpoint import Checkpoint -from .early_stopping import EarlyStopping -from .lr_schedulers import ( - LRScheduler, - SWA, -) -from .progress_bar import ProgressBar -from .wandb_callbacks import ( - WandbCallback, - WandbImageLogger, - WandbReconstructionLogger, - WandbSegmentationLogger, -) - -__all__ = [ - "Callback", - "CallbackList", - "Checkpoint", - "EarlyStopping", - "LRScheduler", - "WandbCallback", - "WandbImageLogger", - "WandbReconstructionLogger", - "WandbSegmentationLogger", - "ProgressBar", - "SWA", -] diff --git a/src/training/trainer/callbacks/base.py b/src/training/trainer/callbacks/base.py deleted file mode 100644 index 500b642..0000000 --- a/src/training/trainer/callbacks/base.py +++ /dev/null @@ -1,188 +0,0 @@ -"""Metaclass for callback functions.""" - -from enum import Enum -from typing import Callable, Dict, List, Optional, Type, Union - -from loguru import logger -import numpy as np -import torch - -from text_recognizer.models import Model - - -class ModeKeys: - """Mode keys for CallbackList.""" - - TRAIN = "train" - VALIDATION = "validation" - - -class Callback: - """Metaclass for callbacks used in training.""" - - def __init__(self) -> None: - """Initializes the Callback instance.""" - self.model = None - - def set_model(self, model: Type[Model]) -> None: - """Set the model.""" - self.model = model - - def on_fit_begin(self) -> None: - """Called when fit begins.""" - pass - - def on_fit_end(self) -> None: - """Called when fit ends.""" - pass - - def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None) -> None: - """Called at the beginning of an epoch. Only used in training mode.""" - pass - - def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: - """Called at the end of an epoch. Only used in training mode.""" - pass - - def on_train_batch_begin(self, batch: int, logs: Optional[Dict] = None) -> None: - """Called at the beginning of an epoch.""" - pass - - def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: - """Called at the end of an epoch.""" - pass - - def on_validation_batch_begin( - self, batch: int, logs: Optional[Dict] = None - ) -> None: - """Called at the beginning of an epoch.""" - pass - - def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: - """Called at the end of an epoch.""" - pass - - def on_test_begin(self) -> None: - """Called at the beginning of test.""" - pass - - def on_test_end(self) -> None: - """Called at the end of test.""" - pass - - -class CallbackList: - """Container for abstracting away callback calls.""" - - mode_keys = ModeKeys() - - def __init__(self, model: Type[Model], callbacks: List[Callback] = None) -> None: - """Container for `Callback` instances. - - This object wraps a list of `Callback` instances and allows them all to be - called via a single end point. - - Args: - model (Type[Model]): A `Model` instance. - callbacks (List[Callback]): List of `Callback` instances. Defaults to None. - - """ - - self._callbacks = callbacks or [] - if model: - self.set_model(model) - - def set_model(self, model: Type[Model]) -> None: - """Set the model for all callbacks.""" - self.model = model - for callback in self._callbacks: - callback.set_model(model=self.model) - - def append(self, callback: Type[Callback]) -> None: - """Append new callback to callback list.""" - self._callbacks.append(callback) - - def on_fit_begin(self) -> None: - """Called when fit begins.""" - for callback in self._callbacks: - callback.on_fit_begin() - - def on_fit_end(self) -> None: - """Called when fit ends.""" - for callback in self._callbacks: - callback.on_fit_end() - - def on_test_begin(self) -> None: - """Called when test begins.""" - for callback in self._callbacks: - callback.on_test_begin() - - def on_test_end(self) -> None: - """Called when test ends.""" - for callback in self._callbacks: - callback.on_test_end() - - def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None) -> None: - """Called at the beginning of an epoch.""" - for callback in self._callbacks: - callback.on_epoch_begin(epoch, logs) - - def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: - """Called at the end of an epoch.""" - for callback in self._callbacks: - callback.on_epoch_end(epoch, logs) - - def _call_batch_hook( - self, mode: str, hook: str, batch: int, logs: Optional[Dict] = None - ) -> None: - """Helper function for all batch_{begin | end} methods.""" - if hook == "begin": - self._call_batch_begin_hook(mode, batch, logs) - elif hook == "end": - self._call_batch_end_hook(mode, batch, logs) - else: - raise ValueError(f"Unrecognized hook {hook}.") - - def _call_batch_begin_hook( - self, mode: str, batch: int, logs: Optional[Dict] = None - ) -> None: - """Helper function for all `on_*_batch_begin` methods.""" - hook_name = f"on_{mode}_batch_begin" - self._call_batch_hook_helper(hook_name, batch, logs) - - def _call_batch_end_hook( - self, mode: str, batch: int, logs: Optional[Dict] = None - ) -> None: - """Helper function for all `on_*_batch_end` methods.""" - hook_name = f"on_{mode}_batch_end" - self._call_batch_hook_helper(hook_name, batch, logs) - - def _call_batch_hook_helper( - self, hook_name: str, batch: int, logs: Optional[Dict] = None - ) -> None: - """Helper function for `on_*_batch_begin` methods.""" - for callback in self._callbacks: - hook = getattr(callback, hook_name) - hook(batch, logs) - - def on_train_batch_begin(self, batch: int, logs: Optional[Dict] = None) -> None: - """Called at the beginning of an epoch.""" - self._call_batch_hook(self.mode_keys.TRAIN, "begin", batch, logs) - - def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: - """Called at the end of an epoch.""" - self._call_batch_hook(self.mode_keys.TRAIN, "end", batch, logs) - - def on_validation_batch_begin( - self, batch: int, logs: Optional[Dict] = None - ) -> None: - """Called at the beginning of an epoch.""" - self._call_batch_hook(self.mode_keys.VALIDATION, "begin", batch, logs) - - def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: - """Called at the end of an epoch.""" - self._call_batch_hook(self.mode_keys.VALIDATION, "end", batch, logs) - - def __iter__(self) -> iter: - """Iter function for callback list.""" - return iter(self._callbacks) diff --git a/src/training/trainer/callbacks/checkpoint.py b/src/training/trainer/callbacks/checkpoint.py deleted file mode 100644 index a54e0a9..0000000 --- a/src/training/trainer/callbacks/checkpoint.py +++ /dev/null @@ -1,95 +0,0 @@ -"""Callback checkpoint for training models.""" -from enum import Enum -from pathlib import Path -from typing import Callable, Dict, List, Optional, Type, Union - -from loguru import logger -import numpy as np -import torch -from training.trainer.callbacks import Callback - -from text_recognizer.models import Model - - -class Checkpoint(Callback): - """Saving model parameters at the end of each epoch.""" - - mode_dict = { - "min": torch.lt, - "max": torch.gt, - } - - def __init__( - self, - checkpoint_path: Union[str, Path], - monitor: str = "accuracy", - mode: str = "auto", - min_delta: float = 0.0, - ) -> None: - """Monitors a quantity that will allow us to determine the best model weights. - - Args: - checkpoint_path (Union[str, Path]): Path to the experiment with the checkpoint. - monitor (str): Name of the quantity to monitor. Defaults to "accuracy". - mode (str): Description of parameter `mode`. Defaults to "auto". - min_delta (float): Description of parameter `min_delta`. Defaults to 0.0. - - """ - super().__init__() - self.checkpoint_path = Path(checkpoint_path) - self.monitor = monitor - self.mode = mode - self.min_delta = torch.tensor(min_delta) - - if mode not in ["auto", "min", "max"]: - logger.warning(f"Checkpoint mode {mode} is unkown, fallback to auto mode.") - - self.mode = "auto" - - if self.mode == "auto": - if "accuracy" in self.monitor: - self.mode = "max" - else: - self.mode = "min" - logger.debug( - f"Checkpoint mode set to {self.mode} for monitoring {self.monitor}." - ) - - torch_inf = torch.tensor(np.inf) - self.min_delta *= 1 if self.monitor_op == torch.gt else -1 - self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf - - @property - def monitor_op(self) -> float: - """Returns the comparison method.""" - return self.mode_dict[self.mode] - - def on_epoch_end(self, epoch: int, logs: Dict) -> None: - """Saves a checkpoint for the network parameters. - - Args: - epoch (int): The current epoch. - logs (Dict): The log containing the monitored metrics. - - """ - current = self.get_monitor_value(logs) - if current is None: - return - if self.monitor_op(current - self.min_delta, self.best_score): - self.best_score = current - is_best = True - else: - is_best = False - - self.model.save_checkpoint(self.checkpoint_path, is_best, epoch, self.monitor) - - def get_monitor_value(self, logs: Dict) -> Union[float, None]: - """Extracts the monitored value.""" - monitor_value = logs.get(self.monitor) - if monitor_value is None: - logger.warning( - f"Checkpoint is conditioned on metric {self.monitor} which is not available. Available" - + f" metrics are: {','.join(list(logs.keys()))}" - ) - return None - return monitor_value diff --git a/src/training/trainer/callbacks/early_stopping.py b/src/training/trainer/callbacks/early_stopping.py deleted file mode 100644 index 02b431f..0000000 --- a/src/training/trainer/callbacks/early_stopping.py +++ /dev/null @@ -1,108 +0,0 @@ -"""Implements Early stopping for PyTorch model.""" -from typing import Dict, Union - -from loguru import logger -import numpy as np -import torch -from torch import Tensor -from training.trainer.callbacks import Callback - - -class EarlyStopping(Callback): - """Stops training when a monitored metric stops improving.""" - - mode_dict = { - "min": torch.lt, - "max": torch.gt, - } - - def __init__( - self, - monitor: str = "val_loss", - min_delta: float = 0.0, - patience: int = 3, - mode: str = "auto", - ) -> None: - """Initializes the EarlyStopping callback. - - Args: - monitor (str): Description of parameter `monitor`. Defaults to "val_loss". - min_delta (float): Description of parameter `min_delta`. Defaults to 0.0. - patience (int): Description of parameter `patience`. Defaults to 3. - mode (str): Description of parameter `mode`. Defaults to "auto". - - """ - super().__init__() - self.monitor = monitor - self.patience = patience - self.min_delta = torch.tensor(min_delta) - self.mode = mode - self.wait_count = 0 - self.stopped_epoch = 0 - - if mode not in ["auto", "min", "max"]: - logger.warning( - f"EarlyStopping mode {mode} is unkown, fallback to auto mode." - ) - - self.mode = "auto" - - if self.mode == "auto": - if "accuracy" in self.monitor: - self.mode = "max" - else: - self.mode = "min" - logger.debug( - f"EarlyStopping mode set to {self.mode} for monitoring {self.monitor}." - ) - - self.torch_inf = torch.tensor(np.inf) - self.min_delta *= 1 if self.monitor_op == torch.gt else -1 - self.best_score = ( - self.torch_inf if self.monitor_op == torch.lt else -self.torch_inf - ) - - @property - def monitor_op(self) -> float: - """Returns the comparison method.""" - return self.mode_dict[self.mode] - - def on_fit_begin(self) -> Union[torch.lt, torch.gt]: - """Reset the early stopping variables for reuse.""" - self.wait_count = 0 - self.stopped_epoch = 0 - self.best_score = ( - self.torch_inf if self.monitor_op == torch.lt else -self.torch_inf - ) - - def on_epoch_end(self, epoch: int, logs: Dict) -> None: - """Computes the early stop criterion.""" - current = self.get_monitor_value(logs) - if current is None: - return - if self.monitor_op(current - self.min_delta, self.best_score): - self.best_score = current - self.wait_count = 0 - else: - self.wait_count += 1 - if self.wait_count >= self.patience: - self.stopped_epoch = epoch - self.model.stop_training = True - - def on_fit_end(self) -> None: - """Logs if early stopping was used.""" - if self.stopped_epoch > 0: - logger.info( - f"Stopped training at epoch {self.stopped_epoch + 1} with early stopping." - ) - - def get_monitor_value(self, logs: Dict) -> Union[Tensor, None]: - """Extracts the monitor value.""" - monitor_value = logs.get(self.monitor) - if monitor_value is None: - logger.warning( - f"Early stopping is conditioned on metric {self.monitor} which is not available. Available" - + f"metrics are: {','.join(list(logs.keys()))}" - ) - return None - return torch.tensor(monitor_value) diff --git a/src/training/trainer/callbacks/lr_schedulers.py b/src/training/trainer/callbacks/lr_schedulers.py deleted file mode 100644 index 630c434..0000000 --- a/src/training/trainer/callbacks/lr_schedulers.py +++ /dev/null @@ -1,77 +0,0 @@ -"""Callbacks for learning rate schedulers.""" -from typing import Callable, Dict, List, Optional, Type - -from torch.optim.swa_utils import update_bn -from training.trainer.callbacks import Callback - -from text_recognizer.models import Model - - -class LRScheduler(Callback): - """Generic learning rate scheduler callback.""" - - def __init__(self) -> None: - super().__init__() - - def set_model(self, model: Type[Model]) -> None: - """Sets the model and lr scheduler.""" - self.model = model - self.lr_scheduler = self.model.lr_scheduler["lr_scheduler"] - self.interval = self.model.lr_scheduler["interval"] - - def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: - """Takes a step at the end of every epoch.""" - if self.interval == "epoch": - if "ReduceLROnPlateau" in self.lr_scheduler.__class__.__name__: - self.lr_scheduler.step(logs["val_loss"]) - else: - self.lr_scheduler.step() - - def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: - """Takes a step at the end of every training batch.""" - if self.interval == "step": - self.lr_scheduler.step() - - -class SWA(Callback): - """Stochastic Weight Averaging callback.""" - - def __init__(self) -> None: - """Initializes the callback.""" - super().__init__() - self.lr_scheduler = None - self.interval = None - self.swa_scheduler = None - self.swa_start = None - self.current_epoch = 1 - - def set_model(self, model: Type[Model]) -> None: - """Sets the model and lr scheduler.""" - self.model = model - self.lr_scheduler = self.model.lr_scheduler["lr_scheduler"] - self.interval = self.model.lr_scheduler["interval"] - self.swa_scheduler = self.model.swa_scheduler["swa_scheduler"] - self.swa_start = self.model.swa_scheduler["swa_start"] - - def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: - """Takes a step at the end of every training batch.""" - if epoch > self.swa_start: - self.model.swa_network.update_parameters(self.model.network) - self.swa_scheduler.step() - elif self.interval == "epoch": - self.lr_scheduler.step() - self.current_epoch = epoch - - def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: - """Takes a step at the end of every training batch.""" - if self.current_epoch < self.swa_start and self.interval == "step": - self.lr_scheduler.step() - - def on_fit_end(self) -> None: - """Update batch norm statistics for the swa model at the end of training.""" - if self.model.swa_network: - update_bn( - self.model.val_dataloader(), - self.model.swa_network, - device=self.model.device, - ) diff --git a/src/training/trainer/callbacks/progress_bar.py b/src/training/trainer/callbacks/progress_bar.py deleted file mode 100644 index 6c4305a..0000000 --- a/src/training/trainer/callbacks/progress_bar.py +++ /dev/null @@ -1,65 +0,0 @@ -"""Progress bar callback for the training loop.""" -from typing import Dict, Optional - -from tqdm import tqdm -from training.trainer.callbacks import Callback - - -class ProgressBar(Callback): - """A TQDM progress bar for the training loop.""" - - def __init__(self, epochs: int, log_batch_frequency: int = None) -> None: - """Initializes the tqdm callback.""" - self.epochs = epochs - print(epochs, type(epochs)) - self.log_batch_frequency = log_batch_frequency - self.progress_bar = None - self.val_metrics = {} - - def _configure_progress_bar(self) -> None: - """Configures the tqdm progress bar with custom bar format.""" - self.progress_bar = tqdm( - total=len(self.model.train_dataloader()), - leave=False, - unit="steps", - mininterval=self.log_batch_frequency, - bar_format="{desc} |{bar:32}| {n_fmt}/{total_fmt} ETA: {remaining} {rate_fmt}{postfix}", - ) - - def _key_abbreviations(self, logs: Dict) -> Dict: - """Changes the length of keys, so that the progress bar fits better.""" - - def rename(key: str) -> str: - """Renames accuracy to acc.""" - return key.replace("accuracy", "acc") - - return {rename(key): value for key, value in logs.items()} - - # def on_fit_begin(self) -> None: - # """Creates a tqdm progress bar.""" - # self._configure_progress_bar() - - def on_epoch_begin(self, epoch: int, logs: Optional[Dict]) -> None: - """Updates the description with the current epoch.""" - if epoch == 1: - self._configure_progress_bar() - else: - self.progress_bar.reset() - self.progress_bar.set_description(f"Epoch {epoch}/{self.epochs}") - - def on_epoch_end(self, epoch: int, logs: Dict) -> None: - """At the end of each epoch, the validation metrics are updated to the progress bar.""" - self.val_metrics = logs - self.progress_bar.set_postfix(**self._key_abbreviations(logs)) - self.progress_bar.update() - - def on_train_batch_end(self, batch: int, logs: Dict) -> None: - """Updates the progress bar for each training step.""" - if self.val_metrics: - logs.update(self.val_metrics) - self.progress_bar.set_postfix(**self._key_abbreviations(logs)) - self.progress_bar.update() - - def on_fit_end(self) -> None: - """Closes the tqdm progress bar.""" - self.progress_bar.close() diff --git a/src/training/trainer/callbacks/wandb_callbacks.py b/src/training/trainer/callbacks/wandb_callbacks.py deleted file mode 100644 index 552a4f4..0000000 --- a/src/training/trainer/callbacks/wandb_callbacks.py +++ /dev/null @@ -1,261 +0,0 @@ -"""Callback for W&B.""" -from typing import Callable, Dict, List, Optional, Type - -import numpy as np -from training.trainer.callbacks import Callback -import wandb - -import text_recognizer.datasets.transforms as transforms -from text_recognizer.models.base import Model - - -class WandbCallback(Callback): - """A custom W&B metric logger for the trainer.""" - - def __init__(self, log_batch_frequency: int = None) -> None: - """Short summary. - - Args: - log_batch_frequency (int): If None, metrics will be logged every epoch. - If set to an integer, callback will log every metrics every log_batch_frequency. - - """ - super().__init__() - self.log_batch_frequency = log_batch_frequency - - def _on_batch_end(self, batch: int, logs: Dict) -> None: - if self.log_batch_frequency and batch % self.log_batch_frequency == 0: - wandb.log(logs, commit=True) - - def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: - """Logs training metrics.""" - if logs is not None: - logs["lr"] = self.model.optimizer.param_groups[0]["lr"] - self._on_batch_end(batch, logs) - - def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: - """Logs validation metrics.""" - if logs is not None: - self._on_batch_end(batch, logs) - - def on_epoch_end(self, epoch: int, logs: Dict) -> None: - """Logs at epoch end.""" - wandb.log(logs, commit=True) - - -class WandbImageLogger(Callback): - """Custom W&B callback for image logging.""" - - def __init__( - self, - example_indices: Optional[List] = None, - num_examples: int = 4, - transform: Optional[bool] = None, - ) -> None: - """Initializes the WandbImageLogger with the model to train. - - Args: - example_indices (Optional[List]): Indices for validation images. Defaults to None. - num_examples (int): Number of random samples to take if example_indices are not specified. Defaults to 4. - transform (Optional[Dict]): Use transform on image or not. Defaults to None. - - """ - - super().__init__() - self.caption = None - self.example_indices = example_indices - self.test_sample_indices = None - self.num_examples = num_examples - self.transform = ( - self._configure_transform(transform) if transform is not None else None - ) - - def _configure_transform(self, transform: Dict) -> Callable: - args = transform["args"] or {} - return getattr(transforms, transform["type"])(**args) - - def set_model(self, model: Type[Model]) -> None: - """Sets the model and extracts validation images from the dataset.""" - self.model = model - self.caption = "Validation Examples" - if self.example_indices is None: - self.example_indices = np.random.randint( - 0, len(self.model.val_dataset), self.num_examples - ) - self.images = self.model.val_dataset.dataset.data[self.example_indices] - self.targets = self.model.val_dataset.dataset.targets[self.example_indices] - self.targets = self.targets.tolist() - - def on_test_begin(self) -> None: - """Get samples from test dataset.""" - self.caption = "Test Examples" - if self.test_sample_indices is None: - self.test_sample_indices = np.random.randint( - 0, len(self.model.test_dataset), self.num_examples - ) - self.images = self.model.test_dataset.data[self.test_sample_indices] - self.targets = self.model.test_dataset.targets[self.test_sample_indices] - self.targets = self.targets.tolist() - - def on_test_end(self) -> None: - """Log test images.""" - self.on_epoch_end(0, {}) - - def on_epoch_end(self, epoch: int, logs: Dict) -> None: - """Get network predictions on validation images.""" - images = [] - for i, image in enumerate(self.images): - image = self.transform(image) if self.transform is not None else image - pred, conf = self.model.predict_on_image(image) - if isinstance(self.targets[i], list): - ground_truth = "".join( - [ - self.model.mapper(int(target_index) - 26) - if target_index > 35 - else self.model.mapper(int(target_index)) - for target_index in self.targets[i] - ] - ).rstrip("_") - else: - ground_truth = self.model.mapper(int(self.targets[i])) - caption = f"Prediction: {pred} Confidence: {conf:.3f} Ground Truth: {ground_truth}" - images.append(wandb.Image(image, caption=caption)) - - wandb.log({f"{self.caption}": images}, commit=False) - - -class WandbSegmentationLogger(Callback): - """Custom W&B callback for image logging.""" - - def __init__( - self, - class_labels: Dict, - example_indices: Optional[List] = None, - num_examples: int = 4, - ) -> None: - """Initializes the WandbImageLogger with the model to train. - - Args: - class_labels (Dict): A dict with int as key and class string as value. - example_indices (Optional[List]): Indices for validation images. Defaults to None. - num_examples (int): Number of random samples to take if example_indices are not specified. Defaults to 4. - - """ - - super().__init__() - self.caption = None - self.class_labels = {int(k): v for k, v in class_labels.items()} - self.example_indices = example_indices - self.test_sample_indices = None - self.num_examples = num_examples - - def set_model(self, model: Type[Model]) -> None: - """Sets the model and extracts validation images from the dataset.""" - self.model = model - self.caption = "Validation Segmentation Examples" - if self.example_indices is None: - self.example_indices = np.random.randint( - 0, len(self.model.val_dataset), self.num_examples - ) - self.images = self.model.val_dataset.dataset.data[self.example_indices] - self.targets = self.model.val_dataset.dataset.targets[self.example_indices] - self.targets = self.targets.tolist() - - def on_test_begin(self) -> None: - """Get samples from test dataset.""" - self.caption = "Test Segmentation Examples" - if self.test_sample_indices is None: - self.test_sample_indices = np.random.randint( - 0, len(self.model.test_dataset), self.num_examples - ) - self.images = self.model.test_dataset.data[self.test_sample_indices] - self.targets = self.model.test_dataset.targets[self.test_sample_indices] - self.targets = self.targets.tolist() - - def on_test_end(self) -> None: - """Log test images.""" - self.on_epoch_end(0, {}) - - def on_epoch_end(self, epoch: int, logs: Dict) -> None: - """Get network predictions on validation images.""" - images = [] - for i, image in enumerate(self.images): - pred_mask = ( - self.model.predict_on_image(image).detach().squeeze(0).cpu().numpy() - ) - gt_mask = np.array(self.targets[i]) - images.append( - wandb.Image( - image, - masks={ - "predictions": { - "mask_data": pred_mask, - "class_labels": self.class_labels, - }, - "ground_truth": { - "mask_data": gt_mask, - "class_labels": self.class_labels, - }, - }, - ) - ) - - wandb.log({f"{self.caption}": images}, commit=False) - - -class WandbReconstructionLogger(Callback): - """Custom W&B callback for image reconstructions logging.""" - - def __init__( - self, example_indices: Optional[List] = None, num_examples: int = 4, - ) -> None: - """Initializes the WandbImageLogger with the model to train. - - Args: - example_indices (Optional[List]): Indices for validation images. Defaults to None. - num_examples (int): Number of random samples to take if example_indices are not specified. Defaults to 4. - - """ - - super().__init__() - self.caption = None - self.example_indices = example_indices - self.test_sample_indices = None - self.num_examples = num_examples - - def set_model(self, model: Type[Model]) -> None: - """Sets the model and extracts validation images from the dataset.""" - self.model = model - self.caption = "Validation Reconstructions Examples" - if self.example_indices is None: - self.example_indices = np.random.randint( - 0, len(self.model.val_dataset), self.num_examples - ) - self.images = self.model.val_dataset.dataset.data[self.example_indices] - - def on_test_begin(self) -> None: - """Get samples from test dataset.""" - self.caption = "Test Reconstructions Examples" - if self.test_sample_indices is None: - self.test_sample_indices = np.random.randint( - 0, len(self.model.test_dataset), self.num_examples - ) - self.images = self.model.test_dataset.data[self.test_sample_indices] - - def on_test_end(self) -> None: - """Log test images.""" - self.on_epoch_end(0, {}) - - def on_epoch_end(self, epoch: int, logs: Dict) -> None: - """Get network predictions on validation images.""" - images = [] - for image in self.images: - reconstructed_image = ( - self.model.predict_on_image(image).detach().squeeze(0).cpu().numpy() - ) - images.append(image) - images.append(reconstructed_image) - - wandb.log( - {f"{self.caption}": [wandb.Image(image) for image in images]}, commit=False, - ) diff --git a/src/training/trainer/train.py b/src/training/trainer/train.py deleted file mode 100644 index b770c94..0000000 --- a/src/training/trainer/train.py +++ /dev/null @@ -1,325 +0,0 @@ -"""Training script for PyTorch models.""" - -from pathlib import Path -import time -from typing import Dict, List, Optional, Tuple, Type -import warnings - -from einops import rearrange -from loguru import logger -import numpy as np -import torch -from torch import Tensor -from torch.optim.swa_utils import update_bn -from training.trainer.callbacks import Callback, CallbackList, LRScheduler, SWA -from training.trainer.util import log_val_metric -import wandb - -from text_recognizer.models import Model - - -torch.backends.cudnn.benchmark = True -np.random.seed(4711) -torch.manual_seed(4711) -torch.cuda.manual_seed(4711) - - -warnings.filterwarnings("ignore") - - -class Trainer: - """Trainer for training PyTorch models.""" - - def __init__( - self, - max_epochs: int, - callbacks: List[Type[Callback]], - transformer_model: bool = False, - max_norm: float = 0.0, - freeze_backbone: Optional[int] = None, - ) -> None: - """Initialization of the Trainer. - - Args: - max_epochs (int): The maximum number of epochs in the training loop. - callbacks (CallbackList): List of callbacks to be called. - transformer_model (bool): Transformer model flag, modifies the input to the model. Default is False. - max_norm (float): Max norm for gradient cl:ipping. Defaults to 0.0. - freeze_backbone (Optional[int]): How many epochs to freeze the backbone for. Used when training - Transformers. Default is None. - - """ - # Training arguments. - self.start_epoch = 1 - self.max_epochs = max_epochs - self.callbacks = callbacks - self.freeze_backbone = freeze_backbone - - # Flag for setting callbacks. - self.callbacks_configured = False - - self.transformer_model = transformer_model - - self.max_norm = max_norm - - # Model placeholders - self.model = None - - def _configure_callbacks(self) -> None: - """Instantiate the CallbackList.""" - if not self.callbacks_configured: - # If learning rate schedulers are present, they need to be added to the callbacks. - if self.model.swa_scheduler is not None: - self.callbacks.append(SWA()) - elif self.model.lr_scheduler is not None: - self.callbacks.append(LRScheduler()) - - self.callbacks = CallbackList(self.model, self.callbacks) - - def compute_metrics( - self, output: Tensor, targets: Tensor, loss: Tensor, batch_size: int - ) -> Dict: - """Computes metrics for output and target pairs.""" - # Compute metrics. - loss = loss.detach().float().item() - output = output.detach() - targets = targets.detach() - if self.model.metrics is not None: - metrics = {} - for metric in self.model.metrics: - if metric == "cer" or metric == "wer": - metrics[metric] = self.model.metrics[metric]( - output, - targets, - batch_size, - self.model.mapper(self.model.pad_token), - ) - else: - metrics[metric] = self.model.metrics[metric](output, targets) - else: - metrics = {} - metrics["loss"] = loss - - return metrics - - def training_step(self, batch: int, samples: Tuple[Tensor, Tensor],) -> Dict: - """Performs the training step.""" - # Pass the tensor to the device for computation. - data, targets = samples - data, targets = ( - data.to(self.model.device), - targets.to(self.model.device), - ) - - batch_size = data.shape[0] - - # Placeholder for uxiliary loss. - aux_loss = None - - # Forward pass. - # Get the network prediction. - if self.transformer_model: - if self.freeze_backbone is not None and batch < self.freeze_backbone: - with torch.no_grad(): - image_features = self.model.network.extract_image_features(data) - - if isinstance(image_features, Tuple): - image_features, _ = image_features - - output = self.model.network.decode_image_features( - image_features, targets[:, :-1] - ) - else: - output = self.model.network.forward(data, targets[:, :-1]) - if isinstance(output, Tuple): - output, aux_loss = output - output = rearrange(output, "b t v -> (b t) v") - targets = rearrange(targets[:, 1:], "b t -> (b t)").long() - else: - output = self.model.forward(data) - - if isinstance(output, Tuple): - output, aux_loss = output - targets = data - - # Compute the loss. - loss = self.model.criterion(output, targets) - - if aux_loss is not None: - loss += aux_loss - - # Backward pass. - # Clear the previous gradients. - for p in self.model.network.parameters(): - p.grad = None - - # Compute the gradients. - loss.backward() - - if self.max_norm > 0: - torch.nn.utils.clip_grad_norm_( - self.model.network.parameters(), self.max_norm - ) - - # Perform updates using calculated gradients. - self.model.optimizer.step() - - metrics = self.compute_metrics(output, targets, loss, batch_size) - - return metrics - - def train(self) -> None: - """Runs the training loop for one epoch.""" - # Set model to traning mode. - self.model.train() - - for batch, samples in enumerate(self.model.train_dataloader()): - self.callbacks.on_train_batch_begin(batch) - metrics = self.training_step(batch, samples) - self.callbacks.on_train_batch_end(batch, logs=metrics) - - @torch.no_grad() - def validation_step(self, batch: int, samples: Tuple[Tensor, Tensor],) -> Dict: - """Performs the validation step.""" - - # Pass the tensor to the device for computation. - data, targets = samples - data, targets = ( - data.to(self.model.device), - targets.to(self.model.device), - ) - - batch_size = data.shape[0] - - # Placeholder for uxiliary loss. - aux_loss = None - - # Forward pass. - # Get the network prediction. - # Use SWA if available and using test dataset. - if self.transformer_model: - output = self.model.network.forward(data, targets[:, :-1]) - if isinstance(output, Tuple): - output, aux_loss = output - output = rearrange(output, "b t v -> (b t) v") - targets = rearrange(targets[:, 1:], "b t -> (b t)").long() - else: - output = self.model.forward(data) - - if isinstance(output, Tuple): - output, aux_loss = output - targets = data - - # Compute the loss. - loss = self.model.criterion(output, targets) - - if aux_loss is not None: - loss += aux_loss - - # Compute metrics. - metrics = self.compute_metrics(output, targets, loss, batch_size) - - return metrics - - def validate(self) -> Dict: - """Runs the validation loop for one epoch.""" - # Set model to eval mode. - self.model.eval() - - # Summary for the current eval loop. - summary = [] - - for batch, samples in enumerate(self.model.val_dataloader()): - self.callbacks.on_validation_batch_begin(batch) - metrics = self.validation_step(batch, samples) - self.callbacks.on_validation_batch_end(batch, logs=metrics) - summary.append(metrics) - - # Compute mean of all metrics. - metrics_mean = { - "val_" + metric: np.mean([x[metric] for x in summary]) - for metric in summary[0] - } - - return metrics_mean - - def fit(self, model: Type[Model]) -> None: - """Runs the training and evaluation loop.""" - - # Sets model, loads the data, criterion, and optimizers. - self.model = model - self.model.prepare_data() - self.model.configure_model() - - # Configure callbacks. - self._configure_callbacks() - - # Set start time. - t_start = time.time() - - self.callbacks.on_fit_begin() - - # Run the training loop. - for epoch in range(self.start_epoch, self.max_epochs + 1): - self.callbacks.on_epoch_begin(epoch) - - # Perform one training pass over the training set. - self.train() - - # Evaluate the model on the validation set. - val_metrics = self.validate() - log_val_metric(val_metrics, epoch) - - self.callbacks.on_epoch_end(epoch, logs=val_metrics) - - if self.model.stop_training: - break - - # Calculate the total training time. - t_end = time.time() - t_training = t_end - t_start - - self.callbacks.on_fit_end() - - logger.info(f"Training took {t_training:.2f} s.") - - # "Teardown". - self.model = None - - def test(self, model: Type[Model]) -> Dict: - """Run inference on test data.""" - - # Sets model, loads the data, criterion, and optimizers. - self.model = model - self.model.prepare_data() - self.model.configure_model() - - # Configure callbacks. - self._configure_callbacks() - - self.callbacks.on_test_begin() - - self.model.eval() - - # Check if SWA network is available. - self.model.use_swa_model() - - # Summary for the current test loop. - summary = [] - - for batch, samples in enumerate(self.model.test_dataloader()): - metrics = self.validation_step(batch, samples) - summary.append(metrics) - - self.callbacks.on_test_end() - - # Compute mean of all test metrics. - metrics_mean = { - "test_" + metric: np.mean([x[metric] for x in summary]) - for metric in summary[0] - } - - # "Teardown". - self.model = None - - return metrics_mean diff --git a/src/training/trainer/util.py b/src/training/trainer/util.py deleted file mode 100644 index 7cf1b45..0000000 --- a/src/training/trainer/util.py +++ /dev/null @@ -1,28 +0,0 @@ -"""Utility functions for training neural networks.""" -from typing import Dict, Optional - -from loguru import logger - - -def log_val_metric(metrics_mean: Dict, epoch: Optional[int] = None) -> None: - """Logging of val metrics to file/terminal.""" - log_str = "Validation metrics " + (f"at epoch {epoch} - " if epoch else " - ") - logger.debug(log_str + " - ".join(f"{k}: {v:.4f}" for k, v in metrics_mean.items())) - - -class RunningAverage: - """Maintains a running average.""" - - def __init__(self) -> None: - """Initializes the parameters.""" - self.steps = 0 - self.total = 0 - - def update(self, val: float) -> None: - """Updates the parameters.""" - self.total += val - self.steps += 1 - - def __call__(self) -> float: - """Computes the running average.""" - return self.total / float(self.steps) |