summaryrefslogtreecommitdiff
path: root/src/training/trainer/callbacks/wandb_callbacks.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/training/trainer/callbacks/wandb_callbacks.py')
-rw-r--r--src/training/trainer/callbacks/wandb_callbacks.py34
1 files changed, 26 insertions, 8 deletions
diff --git a/src/training/trainer/callbacks/wandb_callbacks.py b/src/training/trainer/callbacks/wandb_callbacks.py
index d2df4d7..1627f17 100644
--- a/src/training/trainer/callbacks/wandb_callbacks.py
+++ b/src/training/trainer/callbacks/wandb_callbacks.py
@@ -64,37 +64,55 @@ class WandbImageLogger(Callback):
"""
super().__init__()
+ self.caption = None
self.example_indices = example_indices
+ self.test_sample_indices = None
self.num_examples = num_examples
self.transpose = Transpose() if use_transpose else None
def set_model(self, model: Type[Model]) -> None:
"""Sets the model and extracts validation images from the dataset."""
self.model = model
+ self.caption = "Validation Examples"
if self.example_indices is None:
self.example_indices = np.random.randint(
0, len(self.model.val_dataset), self.num_examples
)
- self.val_images = self.model.val_dataset.dataset.data[self.example_indices]
- self.val_targets = self.model.val_dataset.dataset.targets[self.example_indices]
- self.val_targets = self.val_targets.tolist()
+ self.images = self.model.val_dataset.dataset.data[self.example_indices]
+ self.targets = self.model.val_dataset.dataset.targets[self.example_indices]
+ self.targets = self.targets.tolist()
+
+ def on_test_begin(self) -> None:
+ """Get samples from test dataset."""
+ self.caption = "Test Examples"
+ if self.test_sample_indices is None:
+ self.test_sample_indices = np.random.randint(
+ 0, len(self.model.test_dataset), self.num_examples
+ )
+ self.images = self.model.test_dataset.data[self.test_sample_indices]
+ self.targets = self.model.test_dataset.targets[self.test_sample_indices]
+ self.targets = self.targets.tolist()
+
+ def on_test_end(self) -> None:
+ """Log test images."""
+ self.on_epoch_end(0, {})
def on_epoch_end(self, epoch: int, logs: Dict) -> None:
"""Get network predictions on validation images."""
images = []
- for i, image in enumerate(self.val_images):
+ for i, image in enumerate(self.images):
image = self.transpose(image) if self.transpose is not None else image
pred, conf = self.model.predict_on_image(image)
- if isinstance(self.val_targets[i], list):
+ if isinstance(self.targets[i], list):
ground_truth = "".join(
[
self.model.mapper(int(target_index))
- for target_index in self.val_targets[i]
+ for target_index in self.targets[i]
]
).rstrip("_")
else:
- ground_truth = self.val_targets[i]
+ ground_truth = self.model.mapper(int(self.targets[i]))
caption = f"Prediction: {pred} Confidence: {conf:.3f} Ground Truth: {ground_truth}"
images.append(wandb.Image(image, caption=caption))
- wandb.log({"examples": images}, commit=False)
+ wandb.log({f"{self.caption}": images}, commit=False)