summaryrefslogtreecommitdiff
path: root/training/trainer
diff options
context:
space:
mode:
Diffstat (limited to 'training/trainer')
-rw-r--r--training/trainer/__init__.py2
-rw-r--r--training/trainer/callbacks/__init__.py29
-rw-r--r--training/trainer/callbacks/base.py188
-rw-r--r--training/trainer/callbacks/checkpoint.py95
-rw-r--r--training/trainer/callbacks/early_stopping.py108
-rw-r--r--training/trainer/callbacks/lr_schedulers.py77
-rw-r--r--training/trainer/callbacks/progress_bar.py65
-rw-r--r--training/trainer/callbacks/wandb_callbacks.py261
-rw-r--r--training/trainer/train.py325
-rw-r--r--training/trainer/util.py28
10 files changed, 1178 insertions, 0 deletions
diff --git a/training/trainer/__init__.py b/training/trainer/__init__.py
new file mode 100644
index 0000000..de41bfb
--- /dev/null
+++ b/training/trainer/__init__.py
@@ -0,0 +1,2 @@
+"""Trainer modules."""
+from .train import Trainer
diff --git a/training/trainer/callbacks/__init__.py b/training/trainer/callbacks/__init__.py
new file mode 100644
index 0000000..80c4177
--- /dev/null
+++ b/training/trainer/callbacks/__init__.py
@@ -0,0 +1,29 @@
+"""The callback modules used in the training script."""
+from .base import Callback, CallbackList
+from .checkpoint import Checkpoint
+from .early_stopping import EarlyStopping
+from .lr_schedulers import (
+ LRScheduler,
+ SWA,
+)
+from .progress_bar import ProgressBar
+from .wandb_callbacks import (
+ WandbCallback,
+ WandbImageLogger,
+ WandbReconstructionLogger,
+ WandbSegmentationLogger,
+)
+
+__all__ = [
+ "Callback",
+ "CallbackList",
+ "Checkpoint",
+ "EarlyStopping",
+ "LRScheduler",
+ "WandbCallback",
+ "WandbImageLogger",
+ "WandbReconstructionLogger",
+ "WandbSegmentationLogger",
+ "ProgressBar",
+ "SWA",
+]
diff --git a/training/trainer/callbacks/base.py b/training/trainer/callbacks/base.py
new file mode 100644
index 0000000..500b642
--- /dev/null
+++ b/training/trainer/callbacks/base.py
@@ -0,0 +1,188 @@
+"""Metaclass for callback functions."""
+
+from enum import Enum
+from typing import Callable, Dict, List, Optional, Type, Union
+
+from loguru import logger
+import numpy as np
+import torch
+
+from text_recognizer.models import Model
+
+
+class ModeKeys:
+ """Mode keys for CallbackList."""
+
+ TRAIN = "train"
+ VALIDATION = "validation"
+
+
+class Callback:
+ """Metaclass for callbacks used in training."""
+
+ def __init__(self) -> None:
+ """Initializes the Callback instance."""
+ self.model = None
+
+ def set_model(self, model: Type[Model]) -> None:
+ """Set the model."""
+ self.model = model
+
+ def on_fit_begin(self) -> None:
+ """Called when fit begins."""
+ pass
+
+ def on_fit_end(self) -> None:
+ """Called when fit ends."""
+ pass
+
+ def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None) -> None:
+ """Called at the beginning of an epoch. Only used in training mode."""
+ pass
+
+ def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
+ """Called at the end of an epoch. Only used in training mode."""
+ pass
+
+ def on_train_batch_begin(self, batch: int, logs: Optional[Dict] = None) -> None:
+ """Called at the beginning of an epoch."""
+ pass
+
+ def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
+ """Called at the end of an epoch."""
+ pass
+
+ def on_validation_batch_begin(
+ self, batch: int, logs: Optional[Dict] = None
+ ) -> None:
+ """Called at the beginning of an epoch."""
+ pass
+
+ def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
+ """Called at the end of an epoch."""
+ pass
+
+ def on_test_begin(self) -> None:
+ """Called at the beginning of test."""
+ pass
+
+ def on_test_end(self) -> None:
+ """Called at the end of test."""
+ pass
+
+
+class CallbackList:
+ """Container for abstracting away callback calls."""
+
+ mode_keys = ModeKeys()
+
+ def __init__(self, model: Type[Model], callbacks: List[Callback] = None) -> None:
+ """Container for `Callback` instances.
+
+ This object wraps a list of `Callback` instances and allows them all to be
+ called via a single end point.
+
+ Args:
+ model (Type[Model]): A `Model` instance.
+ callbacks (List[Callback]): List of `Callback` instances. Defaults to None.
+
+ """
+
+ self._callbacks = callbacks or []
+ if model:
+ self.set_model(model)
+
+ def set_model(self, model: Type[Model]) -> None:
+ """Set the model for all callbacks."""
+ self.model = model
+ for callback in self._callbacks:
+ callback.set_model(model=self.model)
+
+ def append(self, callback: Type[Callback]) -> None:
+ """Append new callback to callback list."""
+ self._callbacks.append(callback)
+
+ def on_fit_begin(self) -> None:
+ """Called when fit begins."""
+ for callback in self._callbacks:
+ callback.on_fit_begin()
+
+ def on_fit_end(self) -> None:
+ """Called when fit ends."""
+ for callback in self._callbacks:
+ callback.on_fit_end()
+
+ def on_test_begin(self) -> None:
+ """Called when test begins."""
+ for callback in self._callbacks:
+ callback.on_test_begin()
+
+ def on_test_end(self) -> None:
+ """Called when test ends."""
+ for callback in self._callbacks:
+ callback.on_test_end()
+
+ def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None) -> None:
+ """Called at the beginning of an epoch."""
+ for callback in self._callbacks:
+ callback.on_epoch_begin(epoch, logs)
+
+ def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
+ """Called at the end of an epoch."""
+ for callback in self._callbacks:
+ callback.on_epoch_end(epoch, logs)
+
+ def _call_batch_hook(
+ self, mode: str, hook: str, batch: int, logs: Optional[Dict] = None
+ ) -> None:
+ """Helper function for all batch_{begin | end} methods."""
+ if hook == "begin":
+ self._call_batch_begin_hook(mode, batch, logs)
+ elif hook == "end":
+ self._call_batch_end_hook(mode, batch, logs)
+ else:
+ raise ValueError(f"Unrecognized hook {hook}.")
+
+ def _call_batch_begin_hook(
+ self, mode: str, batch: int, logs: Optional[Dict] = None
+ ) -> None:
+ """Helper function for all `on_*_batch_begin` methods."""
+ hook_name = f"on_{mode}_batch_begin"
+ self._call_batch_hook_helper(hook_name, batch, logs)
+
+ def _call_batch_end_hook(
+ self, mode: str, batch: int, logs: Optional[Dict] = None
+ ) -> None:
+ """Helper function for all `on_*_batch_end` methods."""
+ hook_name = f"on_{mode}_batch_end"
+ self._call_batch_hook_helper(hook_name, batch, logs)
+
+ def _call_batch_hook_helper(
+ self, hook_name: str, batch: int, logs: Optional[Dict] = None
+ ) -> None:
+ """Helper function for `on_*_batch_begin` methods."""
+ for callback in self._callbacks:
+ hook = getattr(callback, hook_name)
+ hook(batch, logs)
+
+ def on_train_batch_begin(self, batch: int, logs: Optional[Dict] = None) -> None:
+ """Called at the beginning of an epoch."""
+ self._call_batch_hook(self.mode_keys.TRAIN, "begin", batch, logs)
+
+ def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
+ """Called at the end of an epoch."""
+ self._call_batch_hook(self.mode_keys.TRAIN, "end", batch, logs)
+
+ def on_validation_batch_begin(
+ self, batch: int, logs: Optional[Dict] = None
+ ) -> None:
+ """Called at the beginning of an epoch."""
+ self._call_batch_hook(self.mode_keys.VALIDATION, "begin", batch, logs)
+
+ def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
+ """Called at the end of an epoch."""
+ self._call_batch_hook(self.mode_keys.VALIDATION, "end", batch, logs)
+
+ def __iter__(self) -> iter:
+ """Iter function for callback list."""
+ return iter(self._callbacks)
diff --git a/training/trainer/callbacks/checkpoint.py b/training/trainer/callbacks/checkpoint.py
new file mode 100644
index 0000000..a54e0a9
--- /dev/null
+++ b/training/trainer/callbacks/checkpoint.py
@@ -0,0 +1,95 @@
+"""Callback checkpoint for training models."""
+from enum import Enum
+from pathlib import Path
+from typing import Callable, Dict, List, Optional, Type, Union
+
+from loguru import logger
+import numpy as np
+import torch
+from training.trainer.callbacks import Callback
+
+from text_recognizer.models import Model
+
+
+class Checkpoint(Callback):
+ """Saving model parameters at the end of each epoch."""
+
+ mode_dict = {
+ "min": torch.lt,
+ "max": torch.gt,
+ }
+
+ def __init__(
+ self,
+ checkpoint_path: Union[str, Path],
+ monitor: str = "accuracy",
+ mode: str = "auto",
+ min_delta: float = 0.0,
+ ) -> None:
+ """Monitors a quantity that will allow us to determine the best model weights.
+
+ Args:
+ checkpoint_path (Union[str, Path]): Path to the experiment with the checkpoint.
+ monitor (str): Name of the quantity to monitor. Defaults to "accuracy".
+ mode (str): Description of parameter `mode`. Defaults to "auto".
+ min_delta (float): Description of parameter `min_delta`. Defaults to 0.0.
+
+ """
+ super().__init__()
+ self.checkpoint_path = Path(checkpoint_path)
+ self.monitor = monitor
+ self.mode = mode
+ self.min_delta = torch.tensor(min_delta)
+
+ if mode not in ["auto", "min", "max"]:
+ logger.warning(f"Checkpoint mode {mode} is unkown, fallback to auto mode.")
+
+ self.mode = "auto"
+
+ if self.mode == "auto":
+ if "accuracy" in self.monitor:
+ self.mode = "max"
+ else:
+ self.mode = "min"
+ logger.debug(
+ f"Checkpoint mode set to {self.mode} for monitoring {self.monitor}."
+ )
+
+ torch_inf = torch.tensor(np.inf)
+ self.min_delta *= 1 if self.monitor_op == torch.gt else -1
+ self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf
+
+ @property
+ def monitor_op(self) -> float:
+ """Returns the comparison method."""
+ return self.mode_dict[self.mode]
+
+ def on_epoch_end(self, epoch: int, logs: Dict) -> None:
+ """Saves a checkpoint for the network parameters.
+
+ Args:
+ epoch (int): The current epoch.
+ logs (Dict): The log containing the monitored metrics.
+
+ """
+ current = self.get_monitor_value(logs)
+ if current is None:
+ return
+ if self.monitor_op(current - self.min_delta, self.best_score):
+ self.best_score = current
+ is_best = True
+ else:
+ is_best = False
+
+ self.model.save_checkpoint(self.checkpoint_path, is_best, epoch, self.monitor)
+
+ def get_monitor_value(self, logs: Dict) -> Union[float, None]:
+ """Extracts the monitored value."""
+ monitor_value = logs.get(self.monitor)
+ if monitor_value is None:
+ logger.warning(
+ f"Checkpoint is conditioned on metric {self.monitor} which is not available. Available"
+ + f" metrics are: {','.join(list(logs.keys()))}"
+ )
+ return None
+ return monitor_value
diff --git a/training/trainer/callbacks/early_stopping.py b/training/trainer/callbacks/early_stopping.py
new file mode 100644
index 0000000..02b431f
--- /dev/null
+++ b/training/trainer/callbacks/early_stopping.py
@@ -0,0 +1,108 @@
+"""Implements Early stopping for PyTorch model."""
+from typing import Dict, Union
+
+from loguru import logger
+import numpy as np
+import torch
+from torch import Tensor
+from training.trainer.callbacks import Callback
+
+
+class EarlyStopping(Callback):
+ """Stops training when a monitored metric stops improving."""
+
+ mode_dict = {
+ "min": torch.lt,
+ "max": torch.gt,
+ }
+
+ def __init__(
+ self,
+ monitor: str = "val_loss",
+ min_delta: float = 0.0,
+ patience: int = 3,
+ mode: str = "auto",
+ ) -> None:
+ """Initializes the EarlyStopping callback.
+
+ Args:
+ monitor (str): Description of parameter `monitor`. Defaults to "val_loss".
+ min_delta (float): Description of parameter `min_delta`. Defaults to 0.0.
+ patience (int): Description of parameter `patience`. Defaults to 3.
+ mode (str): Description of parameter `mode`. Defaults to "auto".
+
+ """
+ super().__init__()
+ self.monitor = monitor
+ self.patience = patience
+ self.min_delta = torch.tensor(min_delta)
+ self.mode = mode
+ self.wait_count = 0
+ self.stopped_epoch = 0
+
+ if mode not in ["auto", "min", "max"]:
+ logger.warning(
+ f"EarlyStopping mode {mode} is unkown, fallback to auto mode."
+ )
+
+ self.mode = "auto"
+
+ if self.mode == "auto":
+ if "accuracy" in self.monitor:
+ self.mode = "max"
+ else:
+ self.mode = "min"
+ logger.debug(
+ f"EarlyStopping mode set to {self.mode} for monitoring {self.monitor}."
+ )
+
+ self.torch_inf = torch.tensor(np.inf)
+ self.min_delta *= 1 if self.monitor_op == torch.gt else -1
+ self.best_score = (
+ self.torch_inf if self.monitor_op == torch.lt else -self.torch_inf
+ )
+
+ @property
+ def monitor_op(self) -> float:
+ """Returns the comparison method."""
+ return self.mode_dict[self.mode]
+
+ def on_fit_begin(self) -> Union[torch.lt, torch.gt]:
+ """Reset the early stopping variables for reuse."""
+ self.wait_count = 0
+ self.stopped_epoch = 0
+ self.best_score = (
+ self.torch_inf if self.monitor_op == torch.lt else -self.torch_inf
+ )
+
+ def on_epoch_end(self, epoch: int, logs: Dict) -> None:
+ """Computes the early stop criterion."""
+ current = self.get_monitor_value(logs)
+ if current is None:
+ return
+ if self.monitor_op(current - self.min_delta, self.best_score):
+ self.best_score = current
+ self.wait_count = 0
+ else:
+ self.wait_count += 1
+ if self.wait_count >= self.patience:
+ self.stopped_epoch = epoch
+ self.model.stop_training = True
+
+ def on_fit_end(self) -> None:
+ """Logs if early stopping was used."""
+ if self.stopped_epoch > 0:
+ logger.info(
+ f"Stopped training at epoch {self.stopped_epoch + 1} with early stopping."
+ )
+
+ def get_monitor_value(self, logs: Dict) -> Union[Tensor, None]:
+ """Extracts the monitor value."""
+ monitor_value = logs.get(self.monitor)
+ if monitor_value is None:
+ logger.warning(
+ f"Early stopping is conditioned on metric {self.monitor} which is not available. Available"
+ + f"metrics are: {','.join(list(logs.keys()))}"
+ )
+ return None
+ return torch.tensor(monitor_value)
diff --git a/training/trainer/callbacks/lr_schedulers.py b/training/trainer/callbacks/lr_schedulers.py
new file mode 100644
index 0000000..630c434
--- /dev/null
+++ b/training/trainer/callbacks/lr_schedulers.py
@@ -0,0 +1,77 @@
+"""Callbacks for learning rate schedulers."""
+from typing import Callable, Dict, List, Optional, Type
+
+from torch.optim.swa_utils import update_bn
+from training.trainer.callbacks import Callback
+
+from text_recognizer.models import Model
+
+
+class LRScheduler(Callback):
+ """Generic learning rate scheduler callback."""
+
+ def __init__(self) -> None:
+ super().__init__()
+
+ def set_model(self, model: Type[Model]) -> None:
+ """Sets the model and lr scheduler."""
+ self.model = model
+ self.lr_scheduler = self.model.lr_scheduler["lr_scheduler"]
+ self.interval = self.model.lr_scheduler["interval"]
+
+ def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
+ """Takes a step at the end of every epoch."""
+ if self.interval == "epoch":
+ if "ReduceLROnPlateau" in self.lr_scheduler.__class__.__name__:
+ self.lr_scheduler.step(logs["val_loss"])
+ else:
+ self.lr_scheduler.step()
+
+ def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
+ """Takes a step at the end of every training batch."""
+ if self.interval == "step":
+ self.lr_scheduler.step()
+
+
+class SWA(Callback):
+ """Stochastic Weight Averaging callback."""
+
+ def __init__(self) -> None:
+ """Initializes the callback."""
+ super().__init__()
+ self.lr_scheduler = None
+ self.interval = None
+ self.swa_scheduler = None
+ self.swa_start = None
+ self.current_epoch = 1
+
+ def set_model(self, model: Type[Model]) -> None:
+ """Sets the model and lr scheduler."""
+ self.model = model
+ self.lr_scheduler = self.model.lr_scheduler["lr_scheduler"]
+ self.interval = self.model.lr_scheduler["interval"]
+ self.swa_scheduler = self.model.swa_scheduler["swa_scheduler"]
+ self.swa_start = self.model.swa_scheduler["swa_start"]
+
+ def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
+ """Takes a step at the end of every training batch."""
+ if epoch > self.swa_start:
+ self.model.swa_network.update_parameters(self.model.network)
+ self.swa_scheduler.step()
+ elif self.interval == "epoch":
+ self.lr_scheduler.step()
+ self.current_epoch = epoch
+
+ def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
+ """Takes a step at the end of every training batch."""
+ if self.current_epoch < self.swa_start and self.interval == "step":
+ self.lr_scheduler.step()
+
+ def on_fit_end(self) -> None:
+ """Update batch norm statistics for the swa model at the end of training."""
+ if self.model.swa_network:
+ update_bn(
+ self.model.val_dataloader(),
+ self.model.swa_network,
+ device=self.model.device,
+ )
diff --git a/training/trainer/callbacks/progress_bar.py b/training/trainer/callbacks/progress_bar.py
new file mode 100644
index 0000000..6c4305a
--- /dev/null
+++ b/training/trainer/callbacks/progress_bar.py
@@ -0,0 +1,65 @@
+"""Progress bar callback for the training loop."""
+from typing import Dict, Optional
+
+from tqdm import tqdm
+from training.trainer.callbacks import Callback
+
+
+class ProgressBar(Callback):
+ """A TQDM progress bar for the training loop."""
+
+ def __init__(self, epochs: int, log_batch_frequency: int = None) -> None:
+ """Initializes the tqdm callback."""
+ self.epochs = epochs
+ print(epochs, type(epochs))
+ self.log_batch_frequency = log_batch_frequency
+ self.progress_bar = None
+ self.val_metrics = {}
+
+ def _configure_progress_bar(self) -> None:
+ """Configures the tqdm progress bar with custom bar format."""
+ self.progress_bar = tqdm(
+ total=len(self.model.train_dataloader()),
+ leave=False,
+ unit="steps",
+ mininterval=self.log_batch_frequency,
+ bar_format="{desc} |{bar:32}| {n_fmt}/{total_fmt} ETA: {remaining} {rate_fmt}{postfix}",
+ )
+
+ def _key_abbreviations(self, logs: Dict) -> Dict:
+ """Changes the length of keys, so that the progress bar fits better."""
+
+ def rename(key: str) -> str:
+ """Renames accuracy to acc."""
+ return key.replace("accuracy", "acc")
+
+ return {rename(key): value for key, value in logs.items()}
+
+ # def on_fit_begin(self) -> None:
+ # """Creates a tqdm progress bar."""
+ # self._configure_progress_bar()
+
+ def on_epoch_begin(self, epoch: int, logs: Optional[Dict]) -> None:
+ """Updates the description with the current epoch."""
+ if epoch == 1:
+ self._configure_progress_bar()
+ else:
+ self.progress_bar.reset()
+ self.progress_bar.set_description(f"Epoch {epoch}/{self.epochs}")
+
+ def on_epoch_end(self, epoch: int, logs: Dict) -> None:
+ """At the end of each epoch, the validation metrics are updated to the progress bar."""
+ self.val_metrics = logs
+ self.progress_bar.set_postfix(**self._key_abbreviations(logs))
+ self.progress_bar.update()
+
+ def on_train_batch_end(self, batch: int, logs: Dict) -> None:
+ """Updates the progress bar for each training step."""
+ if self.val_metrics:
+ logs.update(self.val_metrics)
+ self.progress_bar.set_postfix(**self._key_abbreviations(logs))
+ self.progress_bar.update()
+
+ def on_fit_end(self) -> None:
+ """Closes the tqdm progress bar."""
+ self.progress_bar.close()
diff --git a/training/trainer/callbacks/wandb_callbacks.py b/training/trainer/callbacks/wandb_callbacks.py
new file mode 100644
index 0000000..552a4f4
--- /dev/null
+++ b/training/trainer/callbacks/wandb_callbacks.py
@@ -0,0 +1,261 @@
+"""Callback for W&B."""
+from typing import Callable, Dict, List, Optional, Type
+
+import numpy as np
+from training.trainer.callbacks import Callback
+import wandb
+
+import text_recognizer.datasets.transforms as transforms
+from text_recognizer.models.base import Model
+
+
+class WandbCallback(Callback):
+ """A custom W&B metric logger for the trainer."""
+
+ def __init__(self, log_batch_frequency: int = None) -> None:
+ """Short summary.
+
+ Args:
+ log_batch_frequency (int): If None, metrics will be logged every epoch.
+ If set to an integer, callback will log every metrics every log_batch_frequency.
+
+ """
+ super().__init__()
+ self.log_batch_frequency = log_batch_frequency
+
+ def _on_batch_end(self, batch: int, logs: Dict) -> None:
+ if self.log_batch_frequency and batch % self.log_batch_frequency == 0:
+ wandb.log(logs, commit=True)
+
+ def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
+ """Logs training metrics."""
+ if logs is not None:
+ logs["lr"] = self.model.optimizer.param_groups[0]["lr"]
+ self._on_batch_end(batch, logs)
+
+ def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
+ """Logs validation metrics."""
+ if logs is not None:
+ self._on_batch_end(batch, logs)
+
+ def on_epoch_end(self, epoch: int, logs: Dict) -> None:
+ """Logs at epoch end."""
+ wandb.log(logs, commit=True)
+
+
+class WandbImageLogger(Callback):
+ """Custom W&B callback for image logging."""
+
+ def __init__(
+ self,
+ example_indices: Optional[List] = None,
+ num_examples: int = 4,
+ transform: Optional[bool] = None,
+ ) -> None:
+ """Initializes the WandbImageLogger with the model to train.
+
+ Args:
+ example_indices (Optional[List]): Indices for validation images. Defaults to None.
+ num_examples (int): Number of random samples to take if example_indices are not specified. Defaults to 4.
+ transform (Optional[Dict]): Use transform on image or not. Defaults to None.
+
+ """
+
+ super().__init__()
+ self.caption = None
+ self.example_indices = example_indices
+ self.test_sample_indices = None
+ self.num_examples = num_examples
+ self.transform = (
+ self._configure_transform(transform) if transform is not None else None
+ )
+
+ def _configure_transform(self, transform: Dict) -> Callable:
+ args = transform["args"] or {}
+ return getattr(transforms, transform["type"])(**args)
+
+ def set_model(self, model: Type[Model]) -> None:
+ """Sets the model and extracts validation images from the dataset."""
+ self.model = model
+ self.caption = "Validation Examples"
+ if self.example_indices is None:
+ self.example_indices = np.random.randint(
+ 0, len(self.model.val_dataset), self.num_examples
+ )
+ self.images = self.model.val_dataset.dataset.data[self.example_indices]
+ self.targets = self.model.val_dataset.dataset.targets[self.example_indices]
+ self.targets = self.targets.tolist()
+
+ def on_test_begin(self) -> None:
+ """Get samples from test dataset."""
+ self.caption = "Test Examples"
+ if self.test_sample_indices is None:
+ self.test_sample_indices = np.random.randint(
+ 0, len(self.model.test_dataset), self.num_examples
+ )
+ self.images = self.model.test_dataset.data[self.test_sample_indices]
+ self.targets = self.model.test_dataset.targets[self.test_sample_indices]
+ self.targets = self.targets.tolist()
+
+ def on_test_end(self) -> None:
+ """Log test images."""
+ self.on_epoch_end(0, {})
+
+ def on_epoch_end(self, epoch: int, logs: Dict) -> None:
+ """Get network predictions on validation images."""
+ images = []
+ for i, image in enumerate(self.images):
+ image = self.transform(image) if self.transform is not None else image
+ pred, conf = self.model.predict_on_image(image)
+ if isinstance(self.targets[i], list):
+ ground_truth = "".join(
+ [
+ self.model.mapper(int(target_index) - 26)
+ if target_index > 35
+ else self.model.mapper(int(target_index))
+ for target_index in self.targets[i]
+ ]
+ ).rstrip("_")
+ else:
+ ground_truth = self.model.mapper(int(self.targets[i]))
+ caption = f"Prediction: {pred} Confidence: {conf:.3f} Ground Truth: {ground_truth}"
+ images.append(wandb.Image(image, caption=caption))
+
+ wandb.log({f"{self.caption}": images}, commit=False)
+
+
+class WandbSegmentationLogger(Callback):
+ """Custom W&B callback for image logging."""
+
+ def __init__(
+ self,
+ class_labels: Dict,
+ example_indices: Optional[List] = None,
+ num_examples: int = 4,
+ ) -> None:
+ """Initializes the WandbImageLogger with the model to train.
+
+ Args:
+ class_labels (Dict): A dict with int as key and class string as value.
+ example_indices (Optional[List]): Indices for validation images. Defaults to None.
+ num_examples (int): Number of random samples to take if example_indices are not specified. Defaults to 4.
+
+ """
+
+ super().__init__()
+ self.caption = None
+ self.class_labels = {int(k): v for k, v in class_labels.items()}
+ self.example_indices = example_indices
+ self.test_sample_indices = None
+ self.num_examples = num_examples
+
+ def set_model(self, model: Type[Model]) -> None:
+ """Sets the model and extracts validation images from the dataset."""
+ self.model = model
+ self.caption = "Validation Segmentation Examples"
+ if self.example_indices is None:
+ self.example_indices = np.random.randint(
+ 0, len(self.model.val_dataset), self.num_examples
+ )
+ self.images = self.model.val_dataset.dataset.data[self.example_indices]
+ self.targets = self.model.val_dataset.dataset.targets[self.example_indices]
+ self.targets = self.targets.tolist()
+
+ def on_test_begin(self) -> None:
+ """Get samples from test dataset."""
+ self.caption = "Test Segmentation Examples"
+ if self.test_sample_indices is None:
+ self.test_sample_indices = np.random.randint(
+ 0, len(self.model.test_dataset), self.num_examples
+ )
+ self.images = self.model.test_dataset.data[self.test_sample_indices]
+ self.targets = self.model.test_dataset.targets[self.test_sample_indices]
+ self.targets = self.targets.tolist()
+
+ def on_test_end(self) -> None:
+ """Log test images."""
+ self.on_epoch_end(0, {})
+
+ def on_epoch_end(self, epoch: int, logs: Dict) -> None:
+ """Get network predictions on validation images."""
+ images = []
+ for i, image in enumerate(self.images):
+ pred_mask = (
+ self.model.predict_on_image(image).detach().squeeze(0).cpu().numpy()
+ )
+ gt_mask = np.array(self.targets[i])
+ images.append(
+ wandb.Image(
+ image,
+ masks={
+ "predictions": {
+ "mask_data": pred_mask,
+ "class_labels": self.class_labels,
+ },
+ "ground_truth": {
+ "mask_data": gt_mask,
+ "class_labels": self.class_labels,
+ },
+ },
+ )
+ )
+
+ wandb.log({f"{self.caption}": images}, commit=False)
+
+
+class WandbReconstructionLogger(Callback):
+ """Custom W&B callback for image reconstructions logging."""
+
+ def __init__(
+ self, example_indices: Optional[List] = None, num_examples: int = 4,
+ ) -> None:
+ """Initializes the WandbImageLogger with the model to train.
+
+ Args:
+ example_indices (Optional[List]): Indices for validation images. Defaults to None.
+ num_examples (int): Number of random samples to take if example_indices are not specified. Defaults to 4.
+
+ """
+
+ super().__init__()
+ self.caption = None
+ self.example_indices = example_indices
+ self.test_sample_indices = None
+ self.num_examples = num_examples
+
+ def set_model(self, model: Type[Model]) -> None:
+ """Sets the model and extracts validation images from the dataset."""
+ self.model = model
+ self.caption = "Validation Reconstructions Examples"
+ if self.example_indices is None:
+ self.example_indices = np.random.randint(
+ 0, len(self.model.val_dataset), self.num_examples
+ )
+ self.images = self.model.val_dataset.dataset.data[self.example_indices]
+
+ def on_test_begin(self) -> None:
+ """Get samples from test dataset."""
+ self.caption = "Test Reconstructions Examples"
+ if self.test_sample_indices is None:
+ self.test_sample_indices = np.random.randint(
+ 0, len(self.model.test_dataset), self.num_examples
+ )
+ self.images = self.model.test_dataset.data[self.test_sample_indices]
+
+ def on_test_end(self) -> None:
+ """Log test images."""
+ self.on_epoch_end(0, {})
+
+ def on_epoch_end(self, epoch: int, logs: Dict) -> None:
+ """Get network predictions on validation images."""
+ images = []
+ for image in self.images:
+ reconstructed_image = (
+ self.model.predict_on_image(image).detach().squeeze(0).cpu().numpy()
+ )
+ images.append(image)
+ images.append(reconstructed_image)
+
+ wandb.log(
+ {f"{self.caption}": [wandb.Image(image) for image in images]}, commit=False,
+ )
diff --git a/training/trainer/train.py b/training/trainer/train.py
new file mode 100644
index 0000000..b770c94
--- /dev/null
+++ b/training/trainer/train.py
@@ -0,0 +1,325 @@
+"""Training script for PyTorch models."""
+
+from pathlib import Path
+import time
+from typing import Dict, List, Optional, Tuple, Type
+import warnings
+
+from einops import rearrange
+from loguru import logger
+import numpy as np
+import torch
+from torch import Tensor
+from torch.optim.swa_utils import update_bn
+from training.trainer.callbacks import Callback, CallbackList, LRScheduler, SWA
+from training.trainer.util import log_val_metric
+import wandb
+
+from text_recognizer.models import Model
+
+
+torch.backends.cudnn.benchmark = True
+np.random.seed(4711)
+torch.manual_seed(4711)
+torch.cuda.manual_seed(4711)
+
+
+warnings.filterwarnings("ignore")
+
+
+class Trainer:
+ """Trainer for training PyTorch models."""
+
+ def __init__(
+ self,
+ max_epochs: int,
+ callbacks: List[Type[Callback]],
+ transformer_model: bool = False,
+ max_norm: float = 0.0,
+ freeze_backbone: Optional[int] = None,
+ ) -> None:
+ """Initialization of the Trainer.
+
+ Args:
+ max_epochs (int): The maximum number of epochs in the training loop.
+ callbacks (CallbackList): List of callbacks to be called.
+ transformer_model (bool): Transformer model flag, modifies the input to the model. Default is False.
+ max_norm (float): Max norm for gradient cl:ipping. Defaults to 0.0.
+ freeze_backbone (Optional[int]): How many epochs to freeze the backbone for. Used when training
+ Transformers. Default is None.
+
+ """
+ # Training arguments.
+ self.start_epoch = 1
+ self.max_epochs = max_epochs
+ self.callbacks = callbacks
+ self.freeze_backbone = freeze_backbone
+
+ # Flag for setting callbacks.
+ self.callbacks_configured = False
+
+ self.transformer_model = transformer_model
+
+ self.max_norm = max_norm
+
+ # Model placeholders
+ self.model = None
+
+ def _configure_callbacks(self) -> None:
+ """Instantiate the CallbackList."""
+ if not self.callbacks_configured:
+ # If learning rate schedulers are present, they need to be added to the callbacks.
+ if self.model.swa_scheduler is not None:
+ self.callbacks.append(SWA())
+ elif self.model.lr_scheduler is not None:
+ self.callbacks.append(LRScheduler())
+
+ self.callbacks = CallbackList(self.model, self.callbacks)
+
+ def compute_metrics(
+ self, output: Tensor, targets: Tensor, loss: Tensor, batch_size: int
+ ) -> Dict:
+ """Computes metrics for output and target pairs."""
+ # Compute metrics.
+ loss = loss.detach().float().item()
+ output = output.detach()
+ targets = targets.detach()
+ if self.model.metrics is not None:
+ metrics = {}
+ for metric in self.model.metrics:
+ if metric == "cer" or metric == "wer":
+ metrics[metric] = self.model.metrics[metric](
+ output,
+ targets,
+ batch_size,
+ self.model.mapper(self.model.pad_token),
+ )
+ else:
+ metrics[metric] = self.model.metrics[metric](output, targets)
+ else:
+ metrics = {}
+ metrics["loss"] = loss
+
+ return metrics
+
+ def training_step(self, batch: int, samples: Tuple[Tensor, Tensor],) -> Dict:
+ """Performs the training step."""
+ # Pass the tensor to the device for computation.
+ data, targets = samples
+ data, targets = (
+ data.to(self.model.device),
+ targets.to(self.model.device),
+ )
+
+ batch_size = data.shape[0]
+
+ # Placeholder for uxiliary loss.
+ aux_loss = None
+
+ # Forward pass.
+ # Get the network prediction.
+ if self.transformer_model:
+ if self.freeze_backbone is not None and batch < self.freeze_backbone:
+ with torch.no_grad():
+ image_features = self.model.network.extract_image_features(data)
+
+ if isinstance(image_features, Tuple):
+ image_features, _ = image_features
+
+ output = self.model.network.decode_image_features(
+ image_features, targets[:, :-1]
+ )
+ else:
+ output = self.model.network.forward(data, targets[:, :-1])
+ if isinstance(output, Tuple):
+ output, aux_loss = output
+ output = rearrange(output, "b t v -> (b t) v")
+ targets = rearrange(targets[:, 1:], "b t -> (b t)").long()
+ else:
+ output = self.model.forward(data)
+
+ if isinstance(output, Tuple):
+ output, aux_loss = output
+ targets = data
+
+ # Compute the loss.
+ loss = self.model.criterion(output, targets)
+
+ if aux_loss is not None:
+ loss += aux_loss
+
+ # Backward pass.
+ # Clear the previous gradients.
+ for p in self.model.network.parameters():
+ p.grad = None
+
+ # Compute the gradients.
+ loss.backward()
+
+ if self.max_norm > 0:
+ torch.nn.utils.clip_grad_norm_(
+ self.model.network.parameters(), self.max_norm
+ )
+
+ # Perform updates using calculated gradients.
+ self.model.optimizer.step()
+
+ metrics = self.compute_metrics(output, targets, loss, batch_size)
+
+ return metrics
+
+ def train(self) -> None:
+ """Runs the training loop for one epoch."""
+ # Set model to traning mode.
+ self.model.train()
+
+ for batch, samples in enumerate(self.model.train_dataloader()):
+ self.callbacks.on_train_batch_begin(batch)
+ metrics = self.training_step(batch, samples)
+ self.callbacks.on_train_batch_end(batch, logs=metrics)
+
+ @torch.no_grad()
+ def validation_step(self, batch: int, samples: Tuple[Tensor, Tensor],) -> Dict:
+ """Performs the validation step."""
+
+ # Pass the tensor to the device for computation.
+ data, targets = samples
+ data, targets = (
+ data.to(self.model.device),
+ targets.to(self.model.device),
+ )
+
+ batch_size = data.shape[0]
+
+ # Placeholder for uxiliary loss.
+ aux_loss = None
+
+ # Forward pass.
+ # Get the network prediction.
+ # Use SWA if available and using test dataset.
+ if self.transformer_model:
+ output = self.model.network.forward(data, targets[:, :-1])
+ if isinstance(output, Tuple):
+ output, aux_loss = output
+ output = rearrange(output, "b t v -> (b t) v")
+ targets = rearrange(targets[:, 1:], "b t -> (b t)").long()
+ else:
+ output = self.model.forward(data)
+
+ if isinstance(output, Tuple):
+ output, aux_loss = output
+ targets = data
+
+ # Compute the loss.
+ loss = self.model.criterion(output, targets)
+
+ if aux_loss is not None:
+ loss += aux_loss
+
+ # Compute metrics.
+ metrics = self.compute_metrics(output, targets, loss, batch_size)
+
+ return metrics
+
+ def validate(self) -> Dict:
+ """Runs the validation loop for one epoch."""
+ # Set model to eval mode.
+ self.model.eval()
+
+ # Summary for the current eval loop.
+ summary = []
+
+ for batch, samples in enumerate(self.model.val_dataloader()):
+ self.callbacks.on_validation_batch_begin(batch)
+ metrics = self.validation_step(batch, samples)
+ self.callbacks.on_validation_batch_end(batch, logs=metrics)
+ summary.append(metrics)
+
+ # Compute mean of all metrics.
+ metrics_mean = {
+ "val_" + metric: np.mean([x[metric] for x in summary])
+ for metric in summary[0]
+ }
+
+ return metrics_mean
+
+ def fit(self, model: Type[Model]) -> None:
+ """Runs the training and evaluation loop."""
+
+ # Sets model, loads the data, criterion, and optimizers.
+ self.model = model
+ self.model.prepare_data()
+ self.model.configure_model()
+
+ # Configure callbacks.
+ self._configure_callbacks()
+
+ # Set start time.
+ t_start = time.time()
+
+ self.callbacks.on_fit_begin()
+
+ # Run the training loop.
+ for epoch in range(self.start_epoch, self.max_epochs + 1):
+ self.callbacks.on_epoch_begin(epoch)
+
+ # Perform one training pass over the training set.
+ self.train()
+
+ # Evaluate the model on the validation set.
+ val_metrics = self.validate()
+ log_val_metric(val_metrics, epoch)
+
+ self.callbacks.on_epoch_end(epoch, logs=val_metrics)
+
+ if self.model.stop_training:
+ break
+
+ # Calculate the total training time.
+ t_end = time.time()
+ t_training = t_end - t_start
+
+ self.callbacks.on_fit_end()
+
+ logger.info(f"Training took {t_training:.2f} s.")
+
+ # "Teardown".
+ self.model = None
+
+ def test(self, model: Type[Model]) -> Dict:
+ """Run inference on test data."""
+
+ # Sets model, loads the data, criterion, and optimizers.
+ self.model = model
+ self.model.prepare_data()
+ self.model.configure_model()
+
+ # Configure callbacks.
+ self._configure_callbacks()
+
+ self.callbacks.on_test_begin()
+
+ self.model.eval()
+
+ # Check if SWA network is available.
+ self.model.use_swa_model()
+
+ # Summary for the current test loop.
+ summary = []
+
+ for batch, samples in enumerate(self.model.test_dataloader()):
+ metrics = self.validation_step(batch, samples)
+ summary.append(metrics)
+
+ self.callbacks.on_test_end()
+
+ # Compute mean of all test metrics.
+ metrics_mean = {
+ "test_" + metric: np.mean([x[metric] for x in summary])
+ for metric in summary[0]
+ }
+
+ # "Teardown".
+ self.model = None
+
+ return metrics_mean
diff --git a/training/trainer/util.py b/training/trainer/util.py
new file mode 100644
index 0000000..7cf1b45
--- /dev/null
+++ b/training/trainer/util.py
@@ -0,0 +1,28 @@
+"""Utility functions for training neural networks."""
+from typing import Dict, Optional
+
+from loguru import logger
+
+
+def log_val_metric(metrics_mean: Dict, epoch: Optional[int] = None) -> None:
+ """Logging of val metrics to file/terminal."""
+ log_str = "Validation metrics " + (f"at epoch {epoch} - " if epoch else " - ")
+ logger.debug(log_str + " - ".join(f"{k}: {v:.4f}" for k, v in metrics_mean.items()))
+
+
+class RunningAverage:
+ """Maintains a running average."""
+
+ def __init__(self) -> None:
+ """Initializes the parameters."""
+ self.steps = 0
+ self.total = 0
+
+ def update(self, val: float) -> None:
+ """Updates the parameters."""
+ self.total += val
+ self.steps += 1
+
+ def __call__(self) -> float:
+ """Computes the running average."""
+ return self.total / float(self.steps)