diff options
Diffstat (limited to 'training/trainer/callbacks')
| -rw-r--r-- | training/trainer/callbacks/__init__.py | 29 | ||||
| -rw-r--r-- | training/trainer/callbacks/base.py | 188 | ||||
| -rw-r--r-- | training/trainer/callbacks/checkpoint.py | 95 | ||||
| -rw-r--r-- | training/trainer/callbacks/early_stopping.py | 108 | ||||
| -rw-r--r-- | training/trainer/callbacks/lr_schedulers.py | 77 | ||||
| -rw-r--r-- | training/trainer/callbacks/progress_bar.py | 65 | ||||
| -rw-r--r-- | training/trainer/callbacks/wandb_callbacks.py | 261 | 
7 files changed, 0 insertions, 823 deletions
| diff --git a/training/trainer/callbacks/__init__.py b/training/trainer/callbacks/__init__.py deleted file mode 100644 index 80c4177..0000000 --- a/training/trainer/callbacks/__init__.py +++ /dev/null @@ -1,29 +0,0 @@ -"""The callback modules used in the training script.""" -from .base import Callback, CallbackList -from .checkpoint import Checkpoint -from .early_stopping import EarlyStopping -from .lr_schedulers import ( -    LRScheduler, -    SWA, -) -from .progress_bar import ProgressBar -from .wandb_callbacks import ( -    WandbCallback, -    WandbImageLogger, -    WandbReconstructionLogger, -    WandbSegmentationLogger, -) - -__all__ = [ -    "Callback", -    "CallbackList", -    "Checkpoint", -    "EarlyStopping", -    "LRScheduler", -    "WandbCallback", -    "WandbImageLogger", -    "WandbReconstructionLogger", -    "WandbSegmentationLogger", -    "ProgressBar", -    "SWA", -] diff --git a/training/trainer/callbacks/base.py b/training/trainer/callbacks/base.py deleted file mode 100644 index 500b642..0000000 --- a/training/trainer/callbacks/base.py +++ /dev/null @@ -1,188 +0,0 @@ -"""Metaclass for callback functions.""" - -from enum import Enum -from typing import Callable, Dict, List, Optional, Type, Union - -from loguru import logger -import numpy as np -import torch - -from text_recognizer.models import Model - - -class ModeKeys: -    """Mode keys for CallbackList.""" - -    TRAIN = "train" -    VALIDATION = "validation" - - -class Callback: -    """Metaclass for callbacks used in training.""" - -    def __init__(self) -> None: -        """Initializes the Callback instance.""" -        self.model = None - -    def set_model(self, model: Type[Model]) -> None: -        """Set the model.""" -        self.model = model - -    def on_fit_begin(self) -> None: -        """Called when fit begins.""" -        pass - -    def on_fit_end(self) -> None: -        """Called when fit ends.""" -        pass - -    def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None) -> None: -        """Called at the beginning of an epoch. Only used in training mode.""" -        pass - -    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: -        """Called at the end of an epoch. Only used in training mode.""" -        pass - -    def on_train_batch_begin(self, batch: int, logs: Optional[Dict] = None) -> None: -        """Called at the beginning of an epoch.""" -        pass - -    def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: -        """Called at the end of an epoch.""" -        pass - -    def on_validation_batch_begin( -        self, batch: int, logs: Optional[Dict] = None -    ) -> None: -        """Called at the beginning of an epoch.""" -        pass - -    def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: -        """Called at the end of an epoch.""" -        pass - -    def on_test_begin(self) -> None: -        """Called at the beginning of test.""" -        pass - -    def on_test_end(self) -> None: -        """Called at the end of test.""" -        pass - - -class CallbackList: -    """Container for abstracting away callback calls.""" - -    mode_keys = ModeKeys() - -    def __init__(self, model: Type[Model], callbacks: List[Callback] = None) -> None: -        """Container for `Callback` instances. - -        This object wraps a list of `Callback` instances and allows them all to be -        called via a single end point. - -        Args: -            model (Type[Model]): A `Model` instance. -            callbacks (List[Callback]): List of `Callback` instances. Defaults to None. - -        """ - -        self._callbacks = callbacks or [] -        if model: -            self.set_model(model) - -    def set_model(self, model: Type[Model]) -> None: -        """Set the model for all callbacks.""" -        self.model = model -        for callback in self._callbacks: -            callback.set_model(model=self.model) - -    def append(self, callback: Type[Callback]) -> None: -        """Append new callback to callback list.""" -        self._callbacks.append(callback) - -    def on_fit_begin(self) -> None: -        """Called when fit begins.""" -        for callback in self._callbacks: -            callback.on_fit_begin() - -    def on_fit_end(self) -> None: -        """Called when fit ends.""" -        for callback in self._callbacks: -            callback.on_fit_end() - -    def on_test_begin(self) -> None: -        """Called when test begins.""" -        for callback in self._callbacks: -            callback.on_test_begin() - -    def on_test_end(self) -> None: -        """Called when test ends.""" -        for callback in self._callbacks: -            callback.on_test_end() - -    def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None) -> None: -        """Called at the beginning of an epoch.""" -        for callback in self._callbacks: -            callback.on_epoch_begin(epoch, logs) - -    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: -        """Called at the end of an epoch.""" -        for callback in self._callbacks: -            callback.on_epoch_end(epoch, logs) - -    def _call_batch_hook( -        self, mode: str, hook: str, batch: int, logs: Optional[Dict] = None -    ) -> None: -        """Helper function for all batch_{begin | end} methods.""" -        if hook == "begin": -            self._call_batch_begin_hook(mode, batch, logs) -        elif hook == "end": -            self._call_batch_end_hook(mode, batch, logs) -        else: -            raise ValueError(f"Unrecognized hook {hook}.") - -    def _call_batch_begin_hook( -        self, mode: str, batch: int, logs: Optional[Dict] = None -    ) -> None: -        """Helper function for all `on_*_batch_begin` methods.""" -        hook_name = f"on_{mode}_batch_begin" -        self._call_batch_hook_helper(hook_name, batch, logs) - -    def _call_batch_end_hook( -        self, mode: str, batch: int, logs: Optional[Dict] = None -    ) -> None: -        """Helper function for all `on_*_batch_end` methods.""" -        hook_name = f"on_{mode}_batch_end" -        self._call_batch_hook_helper(hook_name, batch, logs) - -    def _call_batch_hook_helper( -        self, hook_name: str, batch: int, logs: Optional[Dict] = None -    ) -> None: -        """Helper function for `on_*_batch_begin` methods.""" -        for callback in self._callbacks: -            hook = getattr(callback, hook_name) -            hook(batch, logs) - -    def on_train_batch_begin(self, batch: int, logs: Optional[Dict] = None) -> None: -        """Called at the beginning of an epoch.""" -        self._call_batch_hook(self.mode_keys.TRAIN, "begin", batch, logs) - -    def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: -        """Called at the end of an epoch.""" -        self._call_batch_hook(self.mode_keys.TRAIN, "end", batch, logs) - -    def on_validation_batch_begin( -        self, batch: int, logs: Optional[Dict] = None -    ) -> None: -        """Called at the beginning of an epoch.""" -        self._call_batch_hook(self.mode_keys.VALIDATION, "begin", batch, logs) - -    def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: -        """Called at the end of an epoch.""" -        self._call_batch_hook(self.mode_keys.VALIDATION, "end", batch, logs) - -    def __iter__(self) -> iter: -        """Iter function for callback list.""" -        return iter(self._callbacks) diff --git a/training/trainer/callbacks/checkpoint.py b/training/trainer/callbacks/checkpoint.py deleted file mode 100644 index a54e0a9..0000000 --- a/training/trainer/callbacks/checkpoint.py +++ /dev/null @@ -1,95 +0,0 @@ -"""Callback checkpoint for training models.""" -from enum import Enum -from pathlib import Path -from typing import Callable, Dict, List, Optional, Type, Union - -from loguru import logger -import numpy as np -import torch -from training.trainer.callbacks import Callback - -from text_recognizer.models import Model - - -class Checkpoint(Callback): -    """Saving model parameters at the end of each epoch.""" - -    mode_dict = { -        "min": torch.lt, -        "max": torch.gt, -    } - -    def __init__( -        self, -        checkpoint_path: Union[str, Path], -        monitor: str = "accuracy", -        mode: str = "auto", -        min_delta: float = 0.0, -    ) -> None: -        """Monitors a quantity that will allow us to determine the best model weights. - -        Args: -            checkpoint_path (Union[str, Path]): Path to the experiment with the checkpoint. -            monitor (str): Name of the quantity to monitor. Defaults to "accuracy". -            mode (str): Description of parameter `mode`. Defaults to "auto". -            min_delta (float): Description of parameter `min_delta`. Defaults to 0.0. - -        """ -        super().__init__() -        self.checkpoint_path = Path(checkpoint_path) -        self.monitor = monitor -        self.mode = mode -        self.min_delta = torch.tensor(min_delta) - -        if mode not in ["auto", "min", "max"]: -            logger.warning(f"Checkpoint mode {mode} is unkown, fallback to auto mode.") - -            self.mode = "auto" - -        if self.mode == "auto": -            if "accuracy" in self.monitor: -                self.mode = "max" -            else: -                self.mode = "min" -            logger.debug( -                f"Checkpoint mode set to {self.mode} for monitoring {self.monitor}." -            ) - -        torch_inf = torch.tensor(np.inf) -        self.min_delta *= 1 if self.monitor_op == torch.gt else -1 -        self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf - -    @property -    def monitor_op(self) -> float: -        """Returns the comparison method.""" -        return self.mode_dict[self.mode] - -    def on_epoch_end(self, epoch: int, logs: Dict) -> None: -        """Saves a checkpoint for the network parameters. - -        Args: -            epoch (int): The current epoch. -            logs (Dict): The log containing the monitored metrics. - -        """ -        current = self.get_monitor_value(logs) -        if current is None: -            return -        if self.monitor_op(current - self.min_delta, self.best_score): -            self.best_score = current -            is_best = True -        else: -            is_best = False - -        self.model.save_checkpoint(self.checkpoint_path, is_best, epoch, self.monitor) - -    def get_monitor_value(self, logs: Dict) -> Union[float, None]: -        """Extracts the monitored value.""" -        monitor_value = logs.get(self.monitor) -        if monitor_value is None: -            logger.warning( -                f"Checkpoint is conditioned on metric {self.monitor} which is not available. Available" -                + f" metrics are: {','.join(list(logs.keys()))}" -            ) -            return None -        return monitor_value diff --git a/training/trainer/callbacks/early_stopping.py b/training/trainer/callbacks/early_stopping.py deleted file mode 100644 index 02b431f..0000000 --- a/training/trainer/callbacks/early_stopping.py +++ /dev/null @@ -1,108 +0,0 @@ -"""Implements Early stopping for PyTorch model.""" -from typing import Dict, Union - -from loguru import logger -import numpy as np -import torch -from torch import Tensor -from training.trainer.callbacks import Callback - - -class EarlyStopping(Callback): -    """Stops training when a monitored metric stops improving.""" - -    mode_dict = { -        "min": torch.lt, -        "max": torch.gt, -    } - -    def __init__( -        self, -        monitor: str = "val_loss", -        min_delta: float = 0.0, -        patience: int = 3, -        mode: str = "auto", -    ) -> None: -        """Initializes the EarlyStopping callback. - -        Args: -            monitor (str): Description of parameter `monitor`. Defaults to "val_loss". -            min_delta (float): Description of parameter `min_delta`. Defaults to 0.0. -            patience (int): Description of parameter `patience`. Defaults to 3. -            mode (str): Description of parameter `mode`. Defaults to "auto". - -        """ -        super().__init__() -        self.monitor = monitor -        self.patience = patience -        self.min_delta = torch.tensor(min_delta) -        self.mode = mode -        self.wait_count = 0 -        self.stopped_epoch = 0 - -        if mode not in ["auto", "min", "max"]: -            logger.warning( -                f"EarlyStopping mode {mode} is unkown, fallback to auto mode." -            ) - -            self.mode = "auto" - -        if self.mode == "auto": -            if "accuracy" in self.monitor: -                self.mode = "max" -            else: -                self.mode = "min" -            logger.debug( -                f"EarlyStopping mode set to {self.mode} for monitoring {self.monitor}." -            ) - -        self.torch_inf = torch.tensor(np.inf) -        self.min_delta *= 1 if self.monitor_op == torch.gt else -1 -        self.best_score = ( -            self.torch_inf if self.monitor_op == torch.lt else -self.torch_inf -        ) - -    @property -    def monitor_op(self) -> float: -        """Returns the comparison method.""" -        return self.mode_dict[self.mode] - -    def on_fit_begin(self) -> Union[torch.lt, torch.gt]: -        """Reset the early stopping variables for reuse.""" -        self.wait_count = 0 -        self.stopped_epoch = 0 -        self.best_score = ( -            self.torch_inf if self.monitor_op == torch.lt else -self.torch_inf -        ) - -    def on_epoch_end(self, epoch: int, logs: Dict) -> None: -        """Computes the early stop criterion.""" -        current = self.get_monitor_value(logs) -        if current is None: -            return -        if self.monitor_op(current - self.min_delta, self.best_score): -            self.best_score = current -            self.wait_count = 0 -        else: -            self.wait_count += 1 -            if self.wait_count >= self.patience: -                self.stopped_epoch = epoch -                self.model.stop_training = True - -    def on_fit_end(self) -> None: -        """Logs if early stopping was used.""" -        if self.stopped_epoch > 0: -            logger.info( -                f"Stopped training at epoch {self.stopped_epoch + 1} with early stopping." -            ) - -    def get_monitor_value(self, logs: Dict) -> Union[Tensor, None]: -        """Extracts the monitor value.""" -        monitor_value = logs.get(self.monitor) -        if monitor_value is None: -            logger.warning( -                f"Early stopping is conditioned on metric {self.monitor} which is not available. Available" -                + f"metrics are: {','.join(list(logs.keys()))}" -            ) -            return None -        return torch.tensor(monitor_value) diff --git a/training/trainer/callbacks/lr_schedulers.py b/training/trainer/callbacks/lr_schedulers.py deleted file mode 100644 index 630c434..0000000 --- a/training/trainer/callbacks/lr_schedulers.py +++ /dev/null @@ -1,77 +0,0 @@ -"""Callbacks for learning rate schedulers.""" -from typing import Callable, Dict, List, Optional, Type - -from torch.optim.swa_utils import update_bn -from training.trainer.callbacks import Callback - -from text_recognizer.models import Model - - -class LRScheduler(Callback): -    """Generic learning rate scheduler callback.""" - -    def __init__(self) -> None: -        super().__init__() - -    def set_model(self, model: Type[Model]) -> None: -        """Sets the model and lr scheduler.""" -        self.model = model -        self.lr_scheduler = self.model.lr_scheduler["lr_scheduler"] -        self.interval = self.model.lr_scheduler["interval"] - -    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: -        """Takes a step at the end of every epoch.""" -        if self.interval == "epoch": -            if "ReduceLROnPlateau" in self.lr_scheduler.__class__.__name__: -                self.lr_scheduler.step(logs["val_loss"]) -            else: -                self.lr_scheduler.step() - -    def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: -        """Takes a step at the end of every training batch.""" -        if self.interval == "step": -            self.lr_scheduler.step() - - -class SWA(Callback): -    """Stochastic Weight Averaging callback.""" - -    def __init__(self) -> None: -        """Initializes the callback.""" -        super().__init__() -        self.lr_scheduler = None -        self.interval = None -        self.swa_scheduler = None -        self.swa_start = None -        self.current_epoch = 1 - -    def set_model(self, model: Type[Model]) -> None: -        """Sets the model and lr scheduler.""" -        self.model = model -        self.lr_scheduler = self.model.lr_scheduler["lr_scheduler"] -        self.interval = self.model.lr_scheduler["interval"] -        self.swa_scheduler = self.model.swa_scheduler["swa_scheduler"] -        self.swa_start = self.model.swa_scheduler["swa_start"] - -    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: -        """Takes a step at the end of every training batch.""" -        if epoch > self.swa_start: -            self.model.swa_network.update_parameters(self.model.network) -            self.swa_scheduler.step() -        elif self.interval == "epoch": -            self.lr_scheduler.step() -        self.current_epoch = epoch - -    def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: -        """Takes a step at the end of every training batch.""" -        if self.current_epoch < self.swa_start and self.interval == "step": -            self.lr_scheduler.step() - -    def on_fit_end(self) -> None: -        """Update batch norm statistics for the swa model at the end of training.""" -        if self.model.swa_network: -            update_bn( -                self.model.val_dataloader(), -                self.model.swa_network, -                device=self.model.device, -            ) diff --git a/training/trainer/callbacks/progress_bar.py b/training/trainer/callbacks/progress_bar.py deleted file mode 100644 index 6c4305a..0000000 --- a/training/trainer/callbacks/progress_bar.py +++ /dev/null @@ -1,65 +0,0 @@ -"""Progress bar callback for the training loop.""" -from typing import Dict, Optional - -from tqdm import tqdm -from training.trainer.callbacks import Callback - - -class ProgressBar(Callback): -    """A TQDM progress bar for the training loop.""" - -    def __init__(self, epochs: int, log_batch_frequency: int = None) -> None: -        """Initializes the tqdm callback.""" -        self.epochs = epochs -        print(epochs, type(epochs)) -        self.log_batch_frequency = log_batch_frequency -        self.progress_bar = None -        self.val_metrics = {} - -    def _configure_progress_bar(self) -> None: -        """Configures the tqdm progress bar with custom bar format.""" -        self.progress_bar = tqdm( -            total=len(self.model.train_dataloader()), -            leave=False, -            unit="steps", -            mininterval=self.log_batch_frequency, -            bar_format="{desc} |{bar:32}| {n_fmt}/{total_fmt} ETA: {remaining} {rate_fmt}{postfix}", -        ) - -    def _key_abbreviations(self, logs: Dict) -> Dict: -        """Changes the length of keys, so that the progress bar fits better.""" - -        def rename(key: str) -> str: -            """Renames accuracy to acc.""" -            return key.replace("accuracy", "acc") - -        return {rename(key): value for key, value in logs.items()} - -    # def on_fit_begin(self) -> None: -    #     """Creates a tqdm progress bar.""" -    #     self._configure_progress_bar() - -    def on_epoch_begin(self, epoch: int, logs: Optional[Dict]) -> None: -        """Updates the description with the current epoch.""" -        if epoch == 1: -            self._configure_progress_bar() -        else: -            self.progress_bar.reset() -        self.progress_bar.set_description(f"Epoch {epoch}/{self.epochs}") - -    def on_epoch_end(self, epoch: int, logs: Dict) -> None: -        """At the end of each epoch, the validation metrics are updated to the progress bar.""" -        self.val_metrics = logs -        self.progress_bar.set_postfix(**self._key_abbreviations(logs)) -        self.progress_bar.update() - -    def on_train_batch_end(self, batch: int, logs: Dict) -> None: -        """Updates the progress bar for each training step.""" -        if self.val_metrics: -            logs.update(self.val_metrics) -        self.progress_bar.set_postfix(**self._key_abbreviations(logs)) -        self.progress_bar.update() - -    def on_fit_end(self) -> None: -        """Closes the tqdm progress bar.""" -        self.progress_bar.close() diff --git a/training/trainer/callbacks/wandb_callbacks.py b/training/trainer/callbacks/wandb_callbacks.py deleted file mode 100644 index 552a4f4..0000000 --- a/training/trainer/callbacks/wandb_callbacks.py +++ /dev/null @@ -1,261 +0,0 @@ -"""Callback for W&B.""" -from typing import Callable, Dict, List, Optional, Type - -import numpy as np -from training.trainer.callbacks import Callback -import wandb - -import text_recognizer.datasets.transforms as transforms -from text_recognizer.models.base import Model - - -class WandbCallback(Callback): -    """A custom W&B metric logger for the trainer.""" - -    def __init__(self, log_batch_frequency: int = None) -> None: -        """Short summary. - -        Args: -            log_batch_frequency (int): If None, metrics will be logged every epoch. -                If set to an integer, callback will log every metrics every log_batch_frequency. - -        """ -        super().__init__() -        self.log_batch_frequency = log_batch_frequency - -    def _on_batch_end(self, batch: int, logs: Dict) -> None: -        if self.log_batch_frequency and batch % self.log_batch_frequency == 0: -            wandb.log(logs, commit=True) - -    def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: -        """Logs training metrics.""" -        if logs is not None: -            logs["lr"] = self.model.optimizer.param_groups[0]["lr"] -            self._on_batch_end(batch, logs) - -    def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: -        """Logs validation metrics.""" -        if logs is not None: -            self._on_batch_end(batch, logs) - -    def on_epoch_end(self, epoch: int, logs: Dict) -> None: -        """Logs at epoch end.""" -        wandb.log(logs, commit=True) - - -class WandbImageLogger(Callback): -    """Custom W&B callback for image logging.""" - -    def __init__( -        self, -        example_indices: Optional[List] = None, -        num_examples: int = 4, -        transform: Optional[bool] = None, -    ) -> None: -        """Initializes the WandbImageLogger with the model to train. - -        Args: -            example_indices (Optional[List]): Indices for validation images. Defaults to None. -            num_examples (int): Number of random samples to take if example_indices are not specified. Defaults to 4. -            transform (Optional[Dict]): Use transform on image or not. Defaults to None. - -        """ - -        super().__init__() -        self.caption = None -        self.example_indices = example_indices -        self.test_sample_indices = None -        self.num_examples = num_examples -        self.transform = ( -            self._configure_transform(transform) if transform is not None else None -        ) - -    def _configure_transform(self, transform: Dict) -> Callable: -        args = transform["args"] or {} -        return getattr(transforms, transform["type"])(**args) - -    def set_model(self, model: Type[Model]) -> None: -        """Sets the model and extracts validation images from the dataset.""" -        self.model = model -        self.caption = "Validation Examples" -        if self.example_indices is None: -            self.example_indices = np.random.randint( -                0, len(self.model.val_dataset), self.num_examples -            ) -        self.images = self.model.val_dataset.dataset.data[self.example_indices] -        self.targets = self.model.val_dataset.dataset.targets[self.example_indices] -        self.targets = self.targets.tolist() - -    def on_test_begin(self) -> None: -        """Get samples from test dataset.""" -        self.caption = "Test Examples" -        if self.test_sample_indices is None: -            self.test_sample_indices = np.random.randint( -                0, len(self.model.test_dataset), self.num_examples -            ) -        self.images = self.model.test_dataset.data[self.test_sample_indices] -        self.targets = self.model.test_dataset.targets[self.test_sample_indices] -        self.targets = self.targets.tolist() - -    def on_test_end(self) -> None: -        """Log test images.""" -        self.on_epoch_end(0, {}) - -    def on_epoch_end(self, epoch: int, logs: Dict) -> None: -        """Get network predictions on validation images.""" -        images = [] -        for i, image in enumerate(self.images): -            image = self.transform(image) if self.transform is not None else image -            pred, conf = self.model.predict_on_image(image) -            if isinstance(self.targets[i], list): -                ground_truth = "".join( -                    [ -                        self.model.mapper(int(target_index) - 26) -                        if target_index > 35 -                        else self.model.mapper(int(target_index)) -                        for target_index in self.targets[i] -                    ] -                ).rstrip("_") -            else: -                ground_truth = self.model.mapper(int(self.targets[i])) -            caption = f"Prediction: {pred} Confidence: {conf:.3f} Ground Truth: {ground_truth}" -            images.append(wandb.Image(image, caption=caption)) - -        wandb.log({f"{self.caption}": images}, commit=False) - - -class WandbSegmentationLogger(Callback): -    """Custom W&B callback for image logging.""" - -    def __init__( -        self, -        class_labels: Dict, -        example_indices: Optional[List] = None, -        num_examples: int = 4, -    ) -> None: -        """Initializes the WandbImageLogger with the model to train. - -        Args: -            class_labels (Dict): A dict with int as key and class string as value. -            example_indices (Optional[List]): Indices for validation images. Defaults to None. -            num_examples (int): Number of random samples to take if example_indices are not specified. Defaults to 4. - -        """ - -        super().__init__() -        self.caption = None -        self.class_labels = {int(k): v for k, v in class_labels.items()} -        self.example_indices = example_indices -        self.test_sample_indices = None -        self.num_examples = num_examples - -    def set_model(self, model: Type[Model]) -> None: -        """Sets the model and extracts validation images from the dataset.""" -        self.model = model -        self.caption = "Validation Segmentation Examples" -        if self.example_indices is None: -            self.example_indices = np.random.randint( -                0, len(self.model.val_dataset), self.num_examples -            ) -        self.images = self.model.val_dataset.dataset.data[self.example_indices] -        self.targets = self.model.val_dataset.dataset.targets[self.example_indices] -        self.targets = self.targets.tolist() - -    def on_test_begin(self) -> None: -        """Get samples from test dataset.""" -        self.caption = "Test Segmentation Examples" -        if self.test_sample_indices is None: -            self.test_sample_indices = np.random.randint( -                0, len(self.model.test_dataset), self.num_examples -            ) -        self.images = self.model.test_dataset.data[self.test_sample_indices] -        self.targets = self.model.test_dataset.targets[self.test_sample_indices] -        self.targets = self.targets.tolist() - -    def on_test_end(self) -> None: -        """Log test images.""" -        self.on_epoch_end(0, {}) - -    def on_epoch_end(self, epoch: int, logs: Dict) -> None: -        """Get network predictions on validation images.""" -        images = [] -        for i, image in enumerate(self.images): -            pred_mask = ( -                self.model.predict_on_image(image).detach().squeeze(0).cpu().numpy() -            ) -            gt_mask = np.array(self.targets[i]) -            images.append( -                wandb.Image( -                    image, -                    masks={ -                        "predictions": { -                            "mask_data": pred_mask, -                            "class_labels": self.class_labels, -                        }, -                        "ground_truth": { -                            "mask_data": gt_mask, -                            "class_labels": self.class_labels, -                        }, -                    }, -                ) -            ) - -        wandb.log({f"{self.caption}": images}, commit=False) - - -class WandbReconstructionLogger(Callback): -    """Custom W&B callback for image reconstructions logging.""" - -    def __init__( -        self, example_indices: Optional[List] = None, num_examples: int = 4, -    ) -> None: -        """Initializes the  WandbImageLogger with the model to train. - -        Args: -            example_indices (Optional[List]): Indices for validation images. Defaults to None. -            num_examples (int): Number of random samples to take if example_indices are not specified. Defaults to 4. - -        """ - -        super().__init__() -        self.caption = None -        self.example_indices = example_indices -        self.test_sample_indices = None -        self.num_examples = num_examples - -    def set_model(self, model: Type[Model]) -> None: -        """Sets the model and extracts validation images from the dataset.""" -        self.model = model -        self.caption = "Validation Reconstructions Examples" -        if self.example_indices is None: -            self.example_indices = np.random.randint( -                0, len(self.model.val_dataset), self.num_examples -            ) -        self.images = self.model.val_dataset.dataset.data[self.example_indices] - -    def on_test_begin(self) -> None: -        """Get samples from test dataset.""" -        self.caption = "Test Reconstructions Examples" -        if self.test_sample_indices is None: -            self.test_sample_indices = np.random.randint( -                0, len(self.model.test_dataset), self.num_examples -            ) -        self.images = self.model.test_dataset.data[self.test_sample_indices] - -    def on_test_end(self) -> None: -        """Log test images.""" -        self.on_epoch_end(0, {}) - -    def on_epoch_end(self, epoch: int, logs: Dict) -> None: -        """Get network predictions on validation images.""" -        images = [] -        for image in self.images: -            reconstructed_image = ( -                self.model.predict_on_image(image).detach().squeeze(0).cpu().numpy() -            ) -            images.append(image) -            images.append(reconstructed_image) - -        wandb.log( -            {f"{self.caption}": [wandb.Image(image) for image in images]}, commit=False, -        ) |