From 7e8e54e84c63171e748bbf09516fd517e6821ace Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Sat, 20 Mar 2021 18:09:06 +0100
Subject: Inital commit for refactoring to lightning

---
 training/trainer/callbacks/__init__.py        |  29 +++
 training/trainer/callbacks/base.py            | 188 +++++++++++++++++++
 training/trainer/callbacks/checkpoint.py      |  95 ++++++++++
 training/trainer/callbacks/early_stopping.py  | 108 +++++++++++
 training/trainer/callbacks/lr_schedulers.py   |  77 ++++++++
 training/trainer/callbacks/progress_bar.py    |  65 +++++++
 training/trainer/callbacks/wandb_callbacks.py | 261 ++++++++++++++++++++++++++
 7 files changed, 823 insertions(+)
 create mode 100644 training/trainer/callbacks/__init__.py
 create mode 100644 training/trainer/callbacks/base.py
 create mode 100644 training/trainer/callbacks/checkpoint.py
 create mode 100644 training/trainer/callbacks/early_stopping.py
 create mode 100644 training/trainer/callbacks/lr_schedulers.py
 create mode 100644 training/trainer/callbacks/progress_bar.py
 create mode 100644 training/trainer/callbacks/wandb_callbacks.py

(limited to 'training/trainer/callbacks')

diff --git a/training/trainer/callbacks/__init__.py b/training/trainer/callbacks/__init__.py
new file mode 100644
index 0000000..80c4177
--- /dev/null
+++ b/training/trainer/callbacks/__init__.py
@@ -0,0 +1,29 @@
+"""The callback modules used in the training script."""
+from .base import Callback, CallbackList
+from .checkpoint import Checkpoint
+from .early_stopping import EarlyStopping
+from .lr_schedulers import (
+    LRScheduler,
+    SWA,
+)
+from .progress_bar import ProgressBar
+from .wandb_callbacks import (
+    WandbCallback,
+    WandbImageLogger,
+    WandbReconstructionLogger,
+    WandbSegmentationLogger,
+)
+
+__all__ = [
+    "Callback",
+    "CallbackList",
+    "Checkpoint",
+    "EarlyStopping",
+    "LRScheduler",
+    "WandbCallback",
+    "WandbImageLogger",
+    "WandbReconstructionLogger",
+    "WandbSegmentationLogger",
+    "ProgressBar",
+    "SWA",
+]
diff --git a/training/trainer/callbacks/base.py b/training/trainer/callbacks/base.py
new file mode 100644
index 0000000..500b642
--- /dev/null
+++ b/training/trainer/callbacks/base.py
@@ -0,0 +1,188 @@
+"""Metaclass for callback functions."""
+
+from enum import Enum
+from typing import Callable, Dict, List, Optional, Type, Union
+
+from loguru import logger
+import numpy as np
+import torch
+
+from text_recognizer.models import Model
+
+
+class ModeKeys:
+    """Mode keys for CallbackList."""
+
+    TRAIN = "train"
+    VALIDATION = "validation"
+
+
+class Callback:
+    """Metaclass for callbacks used in training."""
+
+    def __init__(self) -> None:
+        """Initializes the Callback instance."""
+        self.model = None
+
+    def set_model(self, model: Type[Model]) -> None:
+        """Set the model."""
+        self.model = model
+
+    def on_fit_begin(self) -> None:
+        """Called when fit begins."""
+        pass
+
+    def on_fit_end(self) -> None:
+        """Called when fit ends."""
+        pass
+
+    def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None) -> None:
+        """Called at the beginning of an epoch. Only used in training mode."""
+        pass
+
+    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
+        """Called at the end of an epoch. Only used in training mode."""
+        pass
+
+    def on_train_batch_begin(self, batch: int, logs: Optional[Dict] = None) -> None:
+        """Called at the beginning of an epoch."""
+        pass
+
+    def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
+        """Called at the end of an epoch."""
+        pass
+
+    def on_validation_batch_begin(
+        self, batch: int, logs: Optional[Dict] = None
+    ) -> None:
+        """Called at the beginning of an epoch."""
+        pass
+
+    def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
+        """Called at the end of an epoch."""
+        pass
+
+    def on_test_begin(self) -> None:
+        """Called at the beginning of test."""
+        pass
+
+    def on_test_end(self) -> None:
+        """Called at the end of test."""
+        pass
+
+
+class CallbackList:
+    """Container for abstracting away callback calls."""
+
+    mode_keys = ModeKeys()
+
+    def __init__(self, model: Type[Model], callbacks: List[Callback] = None) -> None:
+        """Container for `Callback` instances.
+
+        This object wraps a list of `Callback` instances and allows them all to be
+        called via a single end point.
+
+        Args:
+            model (Type[Model]): A `Model` instance.
+            callbacks (List[Callback]): List of `Callback` instances. Defaults to None.
+
+        """
+
+        self._callbacks = callbacks or []
+        if model:
+            self.set_model(model)
+
+    def set_model(self, model: Type[Model]) -> None:
+        """Set the model for all callbacks."""
+        self.model = model
+        for callback in self._callbacks:
+            callback.set_model(model=self.model)
+
+    def append(self, callback: Type[Callback]) -> None:
+        """Append new callback to callback list."""
+        self._callbacks.append(callback)
+
+    def on_fit_begin(self) -> None:
+        """Called when fit begins."""
+        for callback in self._callbacks:
+            callback.on_fit_begin()
+
+    def on_fit_end(self) -> None:
+        """Called when fit ends."""
+        for callback in self._callbacks:
+            callback.on_fit_end()
+
+    def on_test_begin(self) -> None:
+        """Called when test begins."""
+        for callback in self._callbacks:
+            callback.on_test_begin()
+
+    def on_test_end(self) -> None:
+        """Called when test ends."""
+        for callback in self._callbacks:
+            callback.on_test_end()
+
+    def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None) -> None:
+        """Called at the beginning of an epoch."""
+        for callback in self._callbacks:
+            callback.on_epoch_begin(epoch, logs)
+
+    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
+        """Called at the end of an epoch."""
+        for callback in self._callbacks:
+            callback.on_epoch_end(epoch, logs)
+
+    def _call_batch_hook(
+        self, mode: str, hook: str, batch: int, logs: Optional[Dict] = None
+    ) -> None:
+        """Helper function for all batch_{begin | end} methods."""
+        if hook == "begin":
+            self._call_batch_begin_hook(mode, batch, logs)
+        elif hook == "end":
+            self._call_batch_end_hook(mode, batch, logs)
+        else:
+            raise ValueError(f"Unrecognized hook {hook}.")
+
+    def _call_batch_begin_hook(
+        self, mode: str, batch: int, logs: Optional[Dict] = None
+    ) -> None:
+        """Helper function for all `on_*_batch_begin` methods."""
+        hook_name = f"on_{mode}_batch_begin"
+        self._call_batch_hook_helper(hook_name, batch, logs)
+
+    def _call_batch_end_hook(
+        self, mode: str, batch: int, logs: Optional[Dict] = None
+    ) -> None:
+        """Helper function for all `on_*_batch_end` methods."""
+        hook_name = f"on_{mode}_batch_end"
+        self._call_batch_hook_helper(hook_name, batch, logs)
+
+    def _call_batch_hook_helper(
+        self, hook_name: str, batch: int, logs: Optional[Dict] = None
+    ) -> None:
+        """Helper function for `on_*_batch_begin` methods."""
+        for callback in self._callbacks:
+            hook = getattr(callback, hook_name)
+            hook(batch, logs)
+
+    def on_train_batch_begin(self, batch: int, logs: Optional[Dict] = None) -> None:
+        """Called at the beginning of an epoch."""
+        self._call_batch_hook(self.mode_keys.TRAIN, "begin", batch, logs)
+
+    def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
+        """Called at the end of an epoch."""
+        self._call_batch_hook(self.mode_keys.TRAIN, "end", batch, logs)
+
+    def on_validation_batch_begin(
+        self, batch: int, logs: Optional[Dict] = None
+    ) -> None:
+        """Called at the beginning of an epoch."""
+        self._call_batch_hook(self.mode_keys.VALIDATION, "begin", batch, logs)
+
+    def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
+        """Called at the end of an epoch."""
+        self._call_batch_hook(self.mode_keys.VALIDATION, "end", batch, logs)
+
+    def __iter__(self) -> iter:
+        """Iter function for callback list."""
+        return iter(self._callbacks)
diff --git a/training/trainer/callbacks/checkpoint.py b/training/trainer/callbacks/checkpoint.py
new file mode 100644
index 0000000..a54e0a9
--- /dev/null
+++ b/training/trainer/callbacks/checkpoint.py
@@ -0,0 +1,95 @@
+"""Callback checkpoint for training models."""
+from enum import Enum
+from pathlib import Path
+from typing import Callable, Dict, List, Optional, Type, Union
+
+from loguru import logger
+import numpy as np
+import torch
+from training.trainer.callbacks import Callback
+
+from text_recognizer.models import Model
+
+
+class Checkpoint(Callback):
+    """Saving model parameters at the end of each epoch."""
+
+    mode_dict = {
+        "min": torch.lt,
+        "max": torch.gt,
+    }
+
+    def __init__(
+        self,
+        checkpoint_path: Union[str, Path],
+        monitor: str = "accuracy",
+        mode: str = "auto",
+        min_delta: float = 0.0,
+    ) -> None:
+        """Monitors a quantity that will allow us to determine the best model weights.
+
+        Args:
+            checkpoint_path (Union[str, Path]): Path to the experiment with the checkpoint.
+            monitor (str): Name of the quantity to monitor. Defaults to "accuracy".
+            mode (str): Description of parameter `mode`. Defaults to "auto".
+            min_delta (float): Description of parameter `min_delta`. Defaults to 0.0.
+
+        """
+        super().__init__()
+        self.checkpoint_path = Path(checkpoint_path)
+        self.monitor = monitor
+        self.mode = mode
+        self.min_delta = torch.tensor(min_delta)
+
+        if mode not in ["auto", "min", "max"]:
+            logger.warning(f"Checkpoint mode {mode} is unkown, fallback to auto mode.")
+
+            self.mode = "auto"
+
+        if self.mode == "auto":
+            if "accuracy" in self.monitor:
+                self.mode = "max"
+            else:
+                self.mode = "min"
+            logger.debug(
+                f"Checkpoint mode set to {self.mode} for monitoring {self.monitor}."
+            )
+
+        torch_inf = torch.tensor(np.inf)
+        self.min_delta *= 1 if self.monitor_op == torch.gt else -1
+        self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf
+
+    @property
+    def monitor_op(self) -> float:
+        """Returns the comparison method."""
+        return self.mode_dict[self.mode]
+
+    def on_epoch_end(self, epoch: int, logs: Dict) -> None:
+        """Saves a checkpoint for the network parameters.
+
+        Args:
+            epoch (int): The current epoch.
+            logs (Dict): The log containing the monitored metrics.
+
+        """
+        current = self.get_monitor_value(logs)
+        if current is None:
+            return
+        if self.monitor_op(current - self.min_delta, self.best_score):
+            self.best_score = current
+            is_best = True
+        else:
+            is_best = False
+
+        self.model.save_checkpoint(self.checkpoint_path, is_best, epoch, self.monitor)
+
+    def get_monitor_value(self, logs: Dict) -> Union[float, None]:
+        """Extracts the monitored value."""
+        monitor_value = logs.get(self.monitor)
+        if monitor_value is None:
+            logger.warning(
+                f"Checkpoint is conditioned on metric {self.monitor} which is not available. Available"
+                + f" metrics are: {','.join(list(logs.keys()))}"
+            )
+            return None
+        return monitor_value
diff --git a/training/trainer/callbacks/early_stopping.py b/training/trainer/callbacks/early_stopping.py
new file mode 100644
index 0000000..02b431f
--- /dev/null
+++ b/training/trainer/callbacks/early_stopping.py
@@ -0,0 +1,108 @@
+"""Implements Early stopping for PyTorch model."""
+from typing import Dict, Union
+
+from loguru import logger
+import numpy as np
+import torch
+from torch import Tensor
+from training.trainer.callbacks import Callback
+
+
+class EarlyStopping(Callback):
+    """Stops training when a monitored metric stops improving."""
+
+    mode_dict = {
+        "min": torch.lt,
+        "max": torch.gt,
+    }
+
+    def __init__(
+        self,
+        monitor: str = "val_loss",
+        min_delta: float = 0.0,
+        patience: int = 3,
+        mode: str = "auto",
+    ) -> None:
+        """Initializes the EarlyStopping callback.
+
+        Args:
+            monitor (str): Description of parameter `monitor`. Defaults to "val_loss".
+            min_delta (float): Description of parameter `min_delta`. Defaults to 0.0.
+            patience (int): Description of parameter `patience`. Defaults to 3.
+            mode (str): Description of parameter `mode`. Defaults to "auto".
+
+        """
+        super().__init__()
+        self.monitor = monitor
+        self.patience = patience
+        self.min_delta = torch.tensor(min_delta)
+        self.mode = mode
+        self.wait_count = 0
+        self.stopped_epoch = 0
+
+        if mode not in ["auto", "min", "max"]:
+            logger.warning(
+                f"EarlyStopping mode {mode} is unkown, fallback to auto mode."
+            )
+
+            self.mode = "auto"
+
+        if self.mode == "auto":
+            if "accuracy" in self.monitor:
+                self.mode = "max"
+            else:
+                self.mode = "min"
+            logger.debug(
+                f"EarlyStopping mode set to {self.mode} for monitoring {self.monitor}."
+            )
+
+        self.torch_inf = torch.tensor(np.inf)
+        self.min_delta *= 1 if self.monitor_op == torch.gt else -1
+        self.best_score = (
+            self.torch_inf if self.monitor_op == torch.lt else -self.torch_inf
+        )
+
+    @property
+    def monitor_op(self) -> float:
+        """Returns the comparison method."""
+        return self.mode_dict[self.mode]
+
+    def on_fit_begin(self) -> Union[torch.lt, torch.gt]:
+        """Reset the early stopping variables for reuse."""
+        self.wait_count = 0
+        self.stopped_epoch = 0
+        self.best_score = (
+            self.torch_inf if self.monitor_op == torch.lt else -self.torch_inf
+        )
+
+    def on_epoch_end(self, epoch: int, logs: Dict) -> None:
+        """Computes the early stop criterion."""
+        current = self.get_monitor_value(logs)
+        if current is None:
+            return
+        if self.monitor_op(current - self.min_delta, self.best_score):
+            self.best_score = current
+            self.wait_count = 0
+        else:
+            self.wait_count += 1
+            if self.wait_count >= self.patience:
+                self.stopped_epoch = epoch
+                self.model.stop_training = True
+
+    def on_fit_end(self) -> None:
+        """Logs if early stopping was used."""
+        if self.stopped_epoch > 0:
+            logger.info(
+                f"Stopped training at epoch {self.stopped_epoch + 1} with early stopping."
+            )
+
+    def get_monitor_value(self, logs: Dict) -> Union[Tensor, None]:
+        """Extracts the monitor value."""
+        monitor_value = logs.get(self.monitor)
+        if monitor_value is None:
+            logger.warning(
+                f"Early stopping is conditioned on metric {self.monitor} which is not available. Available"
+                + f"metrics are: {','.join(list(logs.keys()))}"
+            )
+            return None
+        return torch.tensor(monitor_value)
diff --git a/training/trainer/callbacks/lr_schedulers.py b/training/trainer/callbacks/lr_schedulers.py
new file mode 100644
index 0000000..630c434
--- /dev/null
+++ b/training/trainer/callbacks/lr_schedulers.py
@@ -0,0 +1,77 @@
+"""Callbacks for learning rate schedulers."""
+from typing import Callable, Dict, List, Optional, Type
+
+from torch.optim.swa_utils import update_bn
+from training.trainer.callbacks import Callback
+
+from text_recognizer.models import Model
+
+
+class LRScheduler(Callback):
+    """Generic learning rate scheduler callback."""
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def set_model(self, model: Type[Model]) -> None:
+        """Sets the model and lr scheduler."""
+        self.model = model
+        self.lr_scheduler = self.model.lr_scheduler["lr_scheduler"]
+        self.interval = self.model.lr_scheduler["interval"]
+
+    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
+        """Takes a step at the end of every epoch."""
+        if self.interval == "epoch":
+            if "ReduceLROnPlateau" in self.lr_scheduler.__class__.__name__:
+                self.lr_scheduler.step(logs["val_loss"])
+            else:
+                self.lr_scheduler.step()
+
+    def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
+        """Takes a step at the end of every training batch."""
+        if self.interval == "step":
+            self.lr_scheduler.step()
+
+
+class SWA(Callback):
+    """Stochastic Weight Averaging callback."""
+
+    def __init__(self) -> None:
+        """Initializes the callback."""
+        super().__init__()
+        self.lr_scheduler = None
+        self.interval = None
+        self.swa_scheduler = None
+        self.swa_start = None
+        self.current_epoch = 1
+
+    def set_model(self, model: Type[Model]) -> None:
+        """Sets the model and lr scheduler."""
+        self.model = model
+        self.lr_scheduler = self.model.lr_scheduler["lr_scheduler"]
+        self.interval = self.model.lr_scheduler["interval"]
+        self.swa_scheduler = self.model.swa_scheduler["swa_scheduler"]
+        self.swa_start = self.model.swa_scheduler["swa_start"]
+
+    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
+        """Takes a step at the end of every training batch."""
+        if epoch > self.swa_start:
+            self.model.swa_network.update_parameters(self.model.network)
+            self.swa_scheduler.step()
+        elif self.interval == "epoch":
+            self.lr_scheduler.step()
+        self.current_epoch = epoch
+
+    def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
+        """Takes a step at the end of every training batch."""
+        if self.current_epoch < self.swa_start and self.interval == "step":
+            self.lr_scheduler.step()
+
+    def on_fit_end(self) -> None:
+        """Update batch norm statistics for the swa model at the end of training."""
+        if self.model.swa_network:
+            update_bn(
+                self.model.val_dataloader(),
+                self.model.swa_network,
+                device=self.model.device,
+            )
diff --git a/training/trainer/callbacks/progress_bar.py b/training/trainer/callbacks/progress_bar.py
new file mode 100644
index 0000000..6c4305a
--- /dev/null
+++ b/training/trainer/callbacks/progress_bar.py
@@ -0,0 +1,65 @@
+"""Progress bar callback for the training loop."""
+from typing import Dict, Optional
+
+from tqdm import tqdm
+from training.trainer.callbacks import Callback
+
+
+class ProgressBar(Callback):
+    """A TQDM progress bar for the training loop."""
+
+    def __init__(self, epochs: int, log_batch_frequency: int = None) -> None:
+        """Initializes the tqdm callback."""
+        self.epochs = epochs
+        print(epochs, type(epochs))
+        self.log_batch_frequency = log_batch_frequency
+        self.progress_bar = None
+        self.val_metrics = {}
+
+    def _configure_progress_bar(self) -> None:
+        """Configures the tqdm progress bar with custom bar format."""
+        self.progress_bar = tqdm(
+            total=len(self.model.train_dataloader()),
+            leave=False,
+            unit="steps",
+            mininterval=self.log_batch_frequency,
+            bar_format="{desc} |{bar:32}| {n_fmt}/{total_fmt} ETA: {remaining} {rate_fmt}{postfix}",
+        )
+
+    def _key_abbreviations(self, logs: Dict) -> Dict:
+        """Changes the length of keys, so that the progress bar fits better."""
+
+        def rename(key: str) -> str:
+            """Renames accuracy to acc."""
+            return key.replace("accuracy", "acc")
+
+        return {rename(key): value for key, value in logs.items()}
+
+    # def on_fit_begin(self) -> None:
+    #     """Creates a tqdm progress bar."""
+    #     self._configure_progress_bar()
+
+    def on_epoch_begin(self, epoch: int, logs: Optional[Dict]) -> None:
+        """Updates the description with the current epoch."""
+        if epoch == 1:
+            self._configure_progress_bar()
+        else:
+            self.progress_bar.reset()
+        self.progress_bar.set_description(f"Epoch {epoch}/{self.epochs}")
+
+    def on_epoch_end(self, epoch: int, logs: Dict) -> None:
+        """At the end of each epoch, the validation metrics are updated to the progress bar."""
+        self.val_metrics = logs
+        self.progress_bar.set_postfix(**self._key_abbreviations(logs))
+        self.progress_bar.update()
+
+    def on_train_batch_end(self, batch: int, logs: Dict) -> None:
+        """Updates the progress bar for each training step."""
+        if self.val_metrics:
+            logs.update(self.val_metrics)
+        self.progress_bar.set_postfix(**self._key_abbreviations(logs))
+        self.progress_bar.update()
+
+    def on_fit_end(self) -> None:
+        """Closes the tqdm progress bar."""
+        self.progress_bar.close()
diff --git a/training/trainer/callbacks/wandb_callbacks.py b/training/trainer/callbacks/wandb_callbacks.py
new file mode 100644
index 0000000..552a4f4
--- /dev/null
+++ b/training/trainer/callbacks/wandb_callbacks.py
@@ -0,0 +1,261 @@
+"""Callback for W&B."""
+from typing import Callable, Dict, List, Optional, Type
+
+import numpy as np
+from training.trainer.callbacks import Callback
+import wandb
+
+import text_recognizer.datasets.transforms as transforms
+from text_recognizer.models.base import Model
+
+
+class WandbCallback(Callback):
+    """A custom W&B metric logger for the trainer."""
+
+    def __init__(self, log_batch_frequency: int = None) -> None:
+        """Short summary.
+
+        Args:
+            log_batch_frequency (int): If None, metrics will be logged every epoch.
+                If set to an integer, callback will log every metrics every log_batch_frequency.
+
+        """
+        super().__init__()
+        self.log_batch_frequency = log_batch_frequency
+
+    def _on_batch_end(self, batch: int, logs: Dict) -> None:
+        if self.log_batch_frequency and batch % self.log_batch_frequency == 0:
+            wandb.log(logs, commit=True)
+
+    def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
+        """Logs training metrics."""
+        if logs is not None:
+            logs["lr"] = self.model.optimizer.param_groups[0]["lr"]
+            self._on_batch_end(batch, logs)
+
+    def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
+        """Logs validation metrics."""
+        if logs is not None:
+            self._on_batch_end(batch, logs)
+
+    def on_epoch_end(self, epoch: int, logs: Dict) -> None:
+        """Logs at epoch end."""
+        wandb.log(logs, commit=True)
+
+
+class WandbImageLogger(Callback):
+    """Custom W&B callback for image logging."""
+
+    def __init__(
+        self,
+        example_indices: Optional[List] = None,
+        num_examples: int = 4,
+        transform: Optional[bool] = None,
+    ) -> None:
+        """Initializes the WandbImageLogger with the model to train.
+
+        Args:
+            example_indices (Optional[List]): Indices for validation images. Defaults to None.
+            num_examples (int): Number of random samples to take if example_indices are not specified. Defaults to 4.
+            transform (Optional[Dict]): Use transform on image or not. Defaults to None.
+
+        """
+
+        super().__init__()
+        self.caption = None
+        self.example_indices = example_indices
+        self.test_sample_indices = None
+        self.num_examples = num_examples
+        self.transform = (
+            self._configure_transform(transform) if transform is not None else None
+        )
+
+    def _configure_transform(self, transform: Dict) -> Callable:
+        args = transform["args"] or {}
+        return getattr(transforms, transform["type"])(**args)
+
+    def set_model(self, model: Type[Model]) -> None:
+        """Sets the model and extracts validation images from the dataset."""
+        self.model = model
+        self.caption = "Validation Examples"
+        if self.example_indices is None:
+            self.example_indices = np.random.randint(
+                0, len(self.model.val_dataset), self.num_examples
+            )
+        self.images = self.model.val_dataset.dataset.data[self.example_indices]
+        self.targets = self.model.val_dataset.dataset.targets[self.example_indices]
+        self.targets = self.targets.tolist()
+
+    def on_test_begin(self) -> None:
+        """Get samples from test dataset."""
+        self.caption = "Test Examples"
+        if self.test_sample_indices is None:
+            self.test_sample_indices = np.random.randint(
+                0, len(self.model.test_dataset), self.num_examples
+            )
+        self.images = self.model.test_dataset.data[self.test_sample_indices]
+        self.targets = self.model.test_dataset.targets[self.test_sample_indices]
+        self.targets = self.targets.tolist()
+
+    def on_test_end(self) -> None:
+        """Log test images."""
+        self.on_epoch_end(0, {})
+
+    def on_epoch_end(self, epoch: int, logs: Dict) -> None:
+        """Get network predictions on validation images."""
+        images = []
+        for i, image in enumerate(self.images):
+            image = self.transform(image) if self.transform is not None else image
+            pred, conf = self.model.predict_on_image(image)
+            if isinstance(self.targets[i], list):
+                ground_truth = "".join(
+                    [
+                        self.model.mapper(int(target_index) - 26)
+                        if target_index > 35
+                        else self.model.mapper(int(target_index))
+                        for target_index in self.targets[i]
+                    ]
+                ).rstrip("_")
+            else:
+                ground_truth = self.model.mapper(int(self.targets[i]))
+            caption = f"Prediction: {pred} Confidence: {conf:.3f} Ground Truth: {ground_truth}"
+            images.append(wandb.Image(image, caption=caption))
+
+        wandb.log({f"{self.caption}": images}, commit=False)
+
+
+class WandbSegmentationLogger(Callback):
+    """Custom W&B callback for image logging."""
+
+    def __init__(
+        self,
+        class_labels: Dict,
+        example_indices: Optional[List] = None,
+        num_examples: int = 4,
+    ) -> None:
+        """Initializes the WandbImageLogger with the model to train.
+
+        Args:
+            class_labels (Dict): A dict with int as key and class string as value.
+            example_indices (Optional[List]): Indices for validation images. Defaults to None.
+            num_examples (int): Number of random samples to take if example_indices are not specified. Defaults to 4.
+
+        """
+
+        super().__init__()
+        self.caption = None
+        self.class_labels = {int(k): v for k, v in class_labels.items()}
+        self.example_indices = example_indices
+        self.test_sample_indices = None
+        self.num_examples = num_examples
+
+    def set_model(self, model: Type[Model]) -> None:
+        """Sets the model and extracts validation images from the dataset."""
+        self.model = model
+        self.caption = "Validation Segmentation Examples"
+        if self.example_indices is None:
+            self.example_indices = np.random.randint(
+                0, len(self.model.val_dataset), self.num_examples
+            )
+        self.images = self.model.val_dataset.dataset.data[self.example_indices]
+        self.targets = self.model.val_dataset.dataset.targets[self.example_indices]
+        self.targets = self.targets.tolist()
+
+    def on_test_begin(self) -> None:
+        """Get samples from test dataset."""
+        self.caption = "Test Segmentation Examples"
+        if self.test_sample_indices is None:
+            self.test_sample_indices = np.random.randint(
+                0, len(self.model.test_dataset), self.num_examples
+            )
+        self.images = self.model.test_dataset.data[self.test_sample_indices]
+        self.targets = self.model.test_dataset.targets[self.test_sample_indices]
+        self.targets = self.targets.tolist()
+
+    def on_test_end(self) -> None:
+        """Log test images."""
+        self.on_epoch_end(0, {})
+
+    def on_epoch_end(self, epoch: int, logs: Dict) -> None:
+        """Get network predictions on validation images."""
+        images = []
+        for i, image in enumerate(self.images):
+            pred_mask = (
+                self.model.predict_on_image(image).detach().squeeze(0).cpu().numpy()
+            )
+            gt_mask = np.array(self.targets[i])
+            images.append(
+                wandb.Image(
+                    image,
+                    masks={
+                        "predictions": {
+                            "mask_data": pred_mask,
+                            "class_labels": self.class_labels,
+                        },
+                        "ground_truth": {
+                            "mask_data": gt_mask,
+                            "class_labels": self.class_labels,
+                        },
+                    },
+                )
+            )
+
+        wandb.log({f"{self.caption}": images}, commit=False)
+
+
+class WandbReconstructionLogger(Callback):
+    """Custom W&B callback for image reconstructions logging."""
+
+    def __init__(
+        self, example_indices: Optional[List] = None, num_examples: int = 4,
+    ) -> None:
+        """Initializes the  WandbImageLogger with the model to train.
+
+        Args:
+            example_indices (Optional[List]): Indices for validation images. Defaults to None.
+            num_examples (int): Number of random samples to take if example_indices are not specified. Defaults to 4.
+
+        """
+
+        super().__init__()
+        self.caption = None
+        self.example_indices = example_indices
+        self.test_sample_indices = None
+        self.num_examples = num_examples
+
+    def set_model(self, model: Type[Model]) -> None:
+        """Sets the model and extracts validation images from the dataset."""
+        self.model = model
+        self.caption = "Validation Reconstructions Examples"
+        if self.example_indices is None:
+            self.example_indices = np.random.randint(
+                0, len(self.model.val_dataset), self.num_examples
+            )
+        self.images = self.model.val_dataset.dataset.data[self.example_indices]
+
+    def on_test_begin(self) -> None:
+        """Get samples from test dataset."""
+        self.caption = "Test Reconstructions Examples"
+        if self.test_sample_indices is None:
+            self.test_sample_indices = np.random.randint(
+                0, len(self.model.test_dataset), self.num_examples
+            )
+        self.images = self.model.test_dataset.data[self.test_sample_indices]
+
+    def on_test_end(self) -> None:
+        """Log test images."""
+        self.on_epoch_end(0, {})
+
+    def on_epoch_end(self, epoch: int, logs: Dict) -> None:
+        """Get network predictions on validation images."""
+        images = []
+        for image in self.images:
+            reconstructed_image = (
+                self.model.predict_on_image(image).detach().squeeze(0).cpu().numpy()
+            )
+            images.append(image)
+            images.append(reconstructed_image)
+
+        wandb.log(
+            {f"{self.caption}": [wandb.Image(image) for image in images]}, commit=False,
+        )
-- 
cgit v1.2.3-70-g09d2