From 4a54d7e690897dd6e6c719fb908fd371a44c2952 Mon Sep 17 00:00:00 2001 From: aktersnurra Date: Sun, 24 Jan 2021 22:14:17 +0100 Subject: Many updates, cool stuff on the way. --- src/training/trainer/callbacks/__init__.py | 8 +++- src/training/trainer/callbacks/wandb_callbacks.py | 58 +++++++++++++++++++++++ 2 files changed, 65 insertions(+), 1 deletion(-) (limited to 'src/training/trainer/callbacks') diff --git a/src/training/trainer/callbacks/__init__.py b/src/training/trainer/callbacks/__init__.py index 95ec142..80c4177 100644 --- a/src/training/trainer/callbacks/__init__.py +++ b/src/training/trainer/callbacks/__init__.py @@ -7,7 +7,12 @@ from .lr_schedulers import ( SWA, ) from .progress_bar import ProgressBar -from .wandb_callbacks import WandbCallback, WandbImageLogger, WandbSegmentationLogger +from .wandb_callbacks import ( + WandbCallback, + WandbImageLogger, + WandbReconstructionLogger, + WandbSegmentationLogger, +) __all__ = [ "Callback", @@ -17,6 +22,7 @@ __all__ = [ "LRScheduler", "WandbCallback", "WandbImageLogger", + "WandbReconstructionLogger", "WandbSegmentationLogger", "ProgressBar", "SWA", diff --git a/src/training/trainer/callbacks/wandb_callbacks.py b/src/training/trainer/callbacks/wandb_callbacks.py index 20414df..552a4f4 100644 --- a/src/training/trainer/callbacks/wandb_callbacks.py +++ b/src/training/trainer/callbacks/wandb_callbacks.py @@ -201,3 +201,61 @@ class WandbSegmentationLogger(Callback): ) wandb.log({f"{self.caption}": images}, commit=False) + + +class WandbReconstructionLogger(Callback): + """Custom W&B callback for image reconstructions logging.""" + + def __init__( + self, example_indices: Optional[List] = None, num_examples: int = 4, + ) -> None: + """Initializes the WandbImageLogger with the model to train. + + Args: + example_indices (Optional[List]): Indices for validation images. Defaults to None. + num_examples (int): Number of random samples to take if example_indices are not specified. Defaults to 4. + + """ + + super().__init__() + self.caption = None + self.example_indices = example_indices + self.test_sample_indices = None + self.num_examples = num_examples + + def set_model(self, model: Type[Model]) -> None: + """Sets the model and extracts validation images from the dataset.""" + self.model = model + self.caption = "Validation Reconstructions Examples" + if self.example_indices is None: + self.example_indices = np.random.randint( + 0, len(self.model.val_dataset), self.num_examples + ) + self.images = self.model.val_dataset.dataset.data[self.example_indices] + + def on_test_begin(self) -> None: + """Get samples from test dataset.""" + self.caption = "Test Reconstructions Examples" + if self.test_sample_indices is None: + self.test_sample_indices = np.random.randint( + 0, len(self.model.test_dataset), self.num_examples + ) + self.images = self.model.test_dataset.data[self.test_sample_indices] + + def on_test_end(self) -> None: + """Log test images.""" + self.on_epoch_end(0, {}) + + def on_epoch_end(self, epoch: int, logs: Dict) -> None: + """Get network predictions on validation images.""" + images = [] + for image in self.images: + reconstructed_image = ( + self.model.predict_on_image(image).detach().squeeze(0).cpu().numpy() + ) + images.append(image) + images.append(reconstructed_image) + + wandb.log( + {f"{self.caption}": [wandb.Image(image) for image in images]}, commit=False, + ) -- cgit v1.2.3-70-g09d2