summaryrefslogtreecommitdiff
path: root/src/training/trainer/callbacks/wandb_callbacks.py
diff options
context:
space:
mode:
authoraktersnurra <gustaf.rydholm@gmail.com>2021-01-24 22:14:17 +0100
committeraktersnurra <gustaf.rydholm@gmail.com>2021-01-24 22:14:17 +0100
commit4a54d7e690897dd6e6c719fb908fd371a44c2952 (patch)
tree04722ac94b9c3960baa5db7939d7ef01dbf535a6 /src/training/trainer/callbacks/wandb_callbacks.py
parentd691b548cd0b6fc4ea184d64261f633789fee021 (diff)
Many updates, cool stuff on the way.
Diffstat (limited to 'src/training/trainer/callbacks/wandb_callbacks.py')
-rw-r--r--src/training/trainer/callbacks/wandb_callbacks.py58
1 files changed, 58 insertions, 0 deletions
diff --git a/src/training/trainer/callbacks/wandb_callbacks.py b/src/training/trainer/callbacks/wandb_callbacks.py
index 20414df..552a4f4 100644
--- a/src/training/trainer/callbacks/wandb_callbacks.py
+++ b/src/training/trainer/callbacks/wandb_callbacks.py
@@ -201,3 +201,61 @@ class WandbSegmentationLogger(Callback):
)
wandb.log({f"{self.caption}": images}, commit=False)
+
+
+class WandbReconstructionLogger(Callback):
+ """Custom W&B callback for image reconstructions logging."""
+
+ def __init__(
+ self, example_indices: Optional[List] = None, num_examples: int = 4,
+ ) -> None:
+ """Initializes the WandbImageLogger with the model to train.
+
+ Args:
+ example_indices (Optional[List]): Indices for validation images. Defaults to None.
+ num_examples (int): Number of random samples to take if example_indices are not specified. Defaults to 4.
+
+ """
+
+ super().__init__()
+ self.caption = None
+ self.example_indices = example_indices
+ self.test_sample_indices = None
+ self.num_examples = num_examples
+
+ def set_model(self, model: Type[Model]) -> None:
+ """Sets the model and extracts validation images from the dataset."""
+ self.model = model
+ self.caption = "Validation Reconstructions Examples"
+ if self.example_indices is None:
+ self.example_indices = np.random.randint(
+ 0, len(self.model.val_dataset), self.num_examples
+ )
+ self.images = self.model.val_dataset.dataset.data[self.example_indices]
+
+ def on_test_begin(self) -> None:
+ """Get samples from test dataset."""
+ self.caption = "Test Reconstructions Examples"
+ if self.test_sample_indices is None:
+ self.test_sample_indices = np.random.randint(
+ 0, len(self.model.test_dataset), self.num_examples
+ )
+ self.images = self.model.test_dataset.data[self.test_sample_indices]
+
+ def on_test_end(self) -> None:
+ """Log test images."""
+ self.on_epoch_end(0, {})
+
+ def on_epoch_end(self, epoch: int, logs: Dict) -> None:
+ """Get network predictions on validation images."""
+ images = []
+ for image in self.images:
+ reconstructed_image = (
+ self.model.predict_on_image(image).detach().squeeze(0).cpu().numpy()
+ )
+ images.append(image)
+ images.append(reconstructed_image)
+
+ wandb.log(
+ {f"{self.caption}": [wandb.Image(image) for image in images]}, commit=False,
+ )