summaryrefslogtreecommitdiff
path: root/src/training/trainer/callbacks
diff options
context:
space:
mode:
authoraktersnurra <gustaf.rydholm@gmail.com>2020-12-07 22:54:04 +0100
committeraktersnurra <gustaf.rydholm@gmail.com>2020-12-07 22:54:04 +0100
commit25b5d6983d51e0e791b96a76beb7e49f392cd9a8 (patch)
tree526ba739714b3d040f7810c1a6be3ff0ba37fdb1 /src/training/trainer/callbacks
parent5529e0fc9ca39e81fe0f08a54f257d32f0afe120 (diff)
Segmentation working!
Diffstat (limited to 'src/training/trainer/callbacks')
-rw-r--r--src/training/trainer/callbacks/__init__.py3
-rw-r--r--src/training/trainer/callbacks/wandb_callbacks.py97
2 files changed, 92 insertions, 8 deletions
diff --git a/src/training/trainer/callbacks/__init__.py b/src/training/trainer/callbacks/__init__.py
index e1bd858..95ec142 100644
--- a/src/training/trainer/callbacks/__init__.py
+++ b/src/training/trainer/callbacks/__init__.py
@@ -7,7 +7,7 @@ from .lr_schedulers import (
SWA,
)
from .progress_bar import ProgressBar
-from .wandb_callbacks import WandbCallback, WandbImageLogger
+from .wandb_callbacks import WandbCallback, WandbImageLogger, WandbSegmentationLogger
__all__ = [
"Callback",
@@ -17,6 +17,7 @@ __all__ = [
"LRScheduler",
"WandbCallback",
"WandbImageLogger",
+ "WandbSegmentationLogger",
"ProgressBar",
"SWA",
]
diff --git a/src/training/trainer/callbacks/wandb_callbacks.py b/src/training/trainer/callbacks/wandb_callbacks.py
index 1627f17..df1fd8f 100644
--- a/src/training/trainer/callbacks/wandb_callbacks.py
+++ b/src/training/trainer/callbacks/wandb_callbacks.py
@@ -2,12 +2,10 @@
from typing import Callable, Dict, List, Optional, Type
import numpy as np
-import torch
-from torchvision.transforms import ToTensor
from training.trainer.callbacks import Callback
import wandb
-from text_recognizer.datasets import Transpose
+import text_recognizer.datasets.transforms as transforms
from text_recognizer.models.base import Model
@@ -52,14 +50,14 @@ class WandbImageLogger(Callback):
self,
example_indices: Optional[List] = None,
num_examples: int = 4,
- use_transpose: Optional[bool] = False,
+ transform: Optional[bool] = None,
) -> None:
"""Initializes the WandbImageLogger with the model to train.
Args:
example_indices (Optional[List]): Indices for validation images. Defaults to None.
num_examples (int): Number of random samples to take if example_indices are not specified. Defaults to 4.
- use_transpose (Optional[bool]): Use transpose on image or not. Defaults to False.
+ transform (Optional[Dict]): Use transform on image or not. Defaults to None.
"""
@@ -68,7 +66,13 @@ class WandbImageLogger(Callback):
self.example_indices = example_indices
self.test_sample_indices = None
self.num_examples = num_examples
- self.transpose = Transpose() if use_transpose else None
+ self.transform = (
+ self._configure_transform(transform) if transform is not None else None
+ )
+
+ def _configure_transform(self, transform: Dict) -> Callable:
+ args = transform["args"] or {}
+ return getattr(transforms, transform["type"])(**args)
def set_model(self, model: Type[Model]) -> None:
"""Sets the model and extracts validation images from the dataset."""
@@ -101,7 +105,7 @@ class WandbImageLogger(Callback):
"""Get network predictions on validation images."""
images = []
for i, image in enumerate(self.images):
- image = self.transpose(image) if self.transpose is not None else image
+ image = self.transform(image) if self.transform is not None else image
pred, conf = self.model.predict_on_image(image)
if isinstance(self.targets[i], list):
ground_truth = "".join(
@@ -116,3 +120,82 @@ class WandbImageLogger(Callback):
images.append(wandb.Image(image, caption=caption))
wandb.log({f"{self.caption}": images}, commit=False)
+
+
+class WandbSegmentationLogger(Callback):
+ """Custom W&B callback for image logging."""
+
+ def __init__(
+ self,
+ class_labels: Dict,
+ example_indices: Optional[List] = None,
+ num_examples: int = 4,
+ ) -> None:
+ """Initializes the WandbImageLogger with the model to train.
+
+ Args:
+ class_labels (Dict): A dict with int as key and class string as value.
+ example_indices (Optional[List]): Indices for validation images. Defaults to None.
+ num_examples (int): Number of random samples to take if example_indices are not specified. Defaults to 4.
+
+ """
+
+ super().__init__()
+ self.caption = None
+ self.class_labels = {int(k): v for k, v in class_labels.items()}
+ self.example_indices = example_indices
+ self.test_sample_indices = None
+ self.num_examples = num_examples
+
+ def set_model(self, model: Type[Model]) -> None:
+ """Sets the model and extracts validation images from the dataset."""
+ self.model = model
+ self.caption = "Validation Segmentation Examples"
+ if self.example_indices is None:
+ self.example_indices = np.random.randint(
+ 0, len(self.model.val_dataset), self.num_examples
+ )
+ self.images = self.model.val_dataset.dataset.data[self.example_indices]
+ self.targets = self.model.val_dataset.dataset.targets[self.example_indices]
+ self.targets = self.targets.tolist()
+
+ def on_test_begin(self) -> None:
+ """Get samples from test dataset."""
+ self.caption = "Test Segmentation Examples"
+ if self.test_sample_indices is None:
+ self.test_sample_indices = np.random.randint(
+ 0, len(self.model.test_dataset), self.num_examples
+ )
+ self.images = self.model.test_dataset.data[self.test_sample_indices]
+ self.targets = self.model.test_dataset.targets[self.test_sample_indices]
+ self.targets = self.targets.tolist()
+
+ def on_test_end(self) -> None:
+ """Log test images."""
+ self.on_epoch_end(0, {})
+
+ def on_epoch_end(self, epoch: int, logs: Dict) -> None:
+ """Get network predictions on validation images."""
+ images = []
+ for i, image in enumerate(self.images):
+ pred_mask = (
+ self.model.predict_on_image(image).detach().squeeze(0).cpu().numpy()
+ )
+ gt_mask = np.array(self.targets[i])
+ images.append(
+ wandb.Image(
+ image,
+ masks={
+ "predictions": {
+ "mask_data": pred_mask,
+ "class_labels": self.class_labels,
+ },
+ "ground_truth": {
+ "mask_data": gt_mask,
+ "class_labels": self.class_labels,
+ },
+ },
+ )
+ )
+
+ wandb.log({f"{self.caption}": images}, commit=False)