From 4a54d7e690897dd6e6c719fb908fd371a44c2952 Mon Sep 17 00:00:00 2001
From: aktersnurra <gustaf.rydholm@gmail.com>
Date: Sun, 24 Jan 2021 22:14:17 +0100
Subject: Many updates, cool stuff on the way.

---
 src/training/trainer/callbacks/__init__.py        |  8 +++-
 src/training/trainer/callbacks/wandb_callbacks.py | 58 +++++++++++++++++++++++
 2 files changed, 65 insertions(+), 1 deletion(-)

(limited to 'src/training/trainer/callbacks')

diff --git a/src/training/trainer/callbacks/__init__.py b/src/training/trainer/callbacks/__init__.py
index 95ec142..80c4177 100644
--- a/src/training/trainer/callbacks/__init__.py
+++ b/src/training/trainer/callbacks/__init__.py
@@ -7,7 +7,12 @@ from .lr_schedulers import (
     SWA,
 )
 from .progress_bar import ProgressBar
-from .wandb_callbacks import WandbCallback, WandbImageLogger, WandbSegmentationLogger
+from .wandb_callbacks import (
+    WandbCallback,
+    WandbImageLogger,
+    WandbReconstructionLogger,
+    WandbSegmentationLogger,
+)
 
 __all__ = [
     "Callback",
@@ -17,6 +22,7 @@ __all__ = [
     "LRScheduler",
     "WandbCallback",
     "WandbImageLogger",
+    "WandbReconstructionLogger",
     "WandbSegmentationLogger",
     "ProgressBar",
     "SWA",
diff --git a/src/training/trainer/callbacks/wandb_callbacks.py b/src/training/trainer/callbacks/wandb_callbacks.py
index 20414df..552a4f4 100644
--- a/src/training/trainer/callbacks/wandb_callbacks.py
+++ b/src/training/trainer/callbacks/wandb_callbacks.py
@@ -201,3 +201,61 @@ class WandbSegmentationLogger(Callback):
             )
 
         wandb.log({f"{self.caption}": images}, commit=False)
+
+
+class WandbReconstructionLogger(Callback):
+    """Custom W&B callback for image reconstructions logging."""
+
+    def __init__(
+        self, example_indices: Optional[List] = None, num_examples: int = 4,
+    ) -> None:
+        """Initializes the  WandbImageLogger with the model to train.
+
+        Args:
+            example_indices (Optional[List]): Indices for validation images. Defaults to None.
+            num_examples (int): Number of random samples to take if example_indices are not specified. Defaults to 4.
+
+        """
+
+        super().__init__()
+        self.caption = None
+        self.example_indices = example_indices
+        self.test_sample_indices = None
+        self.num_examples = num_examples
+
+    def set_model(self, model: Type[Model]) -> None:
+        """Sets the model and extracts validation images from the dataset."""
+        self.model = model
+        self.caption = "Validation Reconstructions Examples"
+        if self.example_indices is None:
+            self.example_indices = np.random.randint(
+                0, len(self.model.val_dataset), self.num_examples
+            )
+        self.images = self.model.val_dataset.dataset.data[self.example_indices]
+
+    def on_test_begin(self) -> None:
+        """Get samples from test dataset."""
+        self.caption = "Test Reconstructions Examples"
+        if self.test_sample_indices is None:
+            self.test_sample_indices = np.random.randint(
+                0, len(self.model.test_dataset), self.num_examples
+            )
+        self.images = self.model.test_dataset.data[self.test_sample_indices]
+
+    def on_test_end(self) -> None:
+        """Log test images."""
+        self.on_epoch_end(0, {})
+
+    def on_epoch_end(self, epoch: int, logs: Dict) -> None:
+        """Get network predictions on validation images."""
+        images = []
+        for image in self.images:
+            reconstructed_image = (
+                self.model.predict_on_image(image).detach().squeeze(0).cpu().numpy()
+            )
+            images.append(image)
+            images.append(reconstructed_image)
+
+        wandb.log(
+            {f"{self.caption}": [wandb.Image(image) for image in images]}, commit=False,
+        )
-- 
cgit v1.2.3-70-g09d2