summaryrefslogtreecommitdiff
path: root/src/training/trainer/callbacks/wandb_callbacks.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/training/trainer/callbacks/wandb_callbacks.py')
-rw-r--r--src/training/trainer/callbacks/wandb_callbacks.py32
1 files changed, 19 insertions, 13 deletions
diff --git a/src/training/trainer/callbacks/wandb_callbacks.py b/src/training/trainer/callbacks/wandb_callbacks.py
index e44c745..6643a44 100644
--- a/src/training/trainer/callbacks/wandb_callbacks.py
+++ b/src/training/trainer/callbacks/wandb_callbacks.py
@@ -2,7 +2,8 @@
from typing import Callable, Dict, List, Optional, Type
import numpy as np
-from torchvision.transforms import Compose, ToTensor
+import torch
+from torchvision.transforms import ToTensor
from training.trainer.callbacks import Callback
import wandb
@@ -50,43 +51,48 @@ class WandbImageLogger(Callback):
self,
example_indices: Optional[List] = None,
num_examples: int = 4,
- transfroms: Optional[Callable] = None,
+ use_transpose: Optional[bool] = False,
) -> None:
"""Initializes the WandbImageLogger with the model to train.
Args:
example_indices (Optional[List]): Indices for validation images. Defaults to None.
num_examples (int): Number of random samples to take if example_indices are not specified. Defaults to 4.
- transfroms (Optional[Callable]): Transforms to use on the validation images, e.g. transpose. Defaults to
- None.
+ use_transpose (Optional[bool]): Use transpose on image or not. Defaults to False.
"""
super().__init__()
self.example_indices = example_indices
self.num_examples = num_examples
- self.transfroms = transfroms
- if self.transfroms is None:
- self.transforms = Compose([Transpose()])
+ self.transpose = Transpose() if use_transpose else None
def set_model(self, model: Type[Model]) -> None:
"""Sets the model and extracts validation images from the dataset."""
self.model = model
- data_loader = self.model.data_loaders["val"]
if self.example_indices is None:
self.example_indices = np.random.randint(
- 0, len(data_loader.dataset.data), self.num_examples
+ 0, len(self.model.val_dataset), self.num_examples
)
- self.val_images = data_loader.dataset.data[self.example_indices]
- self.val_targets = data_loader.dataset.targets[self.example_indices].numpy()
+ self.val_images = self.model.val_dataset.dataset.data[self.example_indices]
+ self.val_targets = self.model.val_dataset.dataset.targets[self.example_indices]
+ self.val_targets = self.val_targets.tolist()
def on_epoch_end(self, epoch: int, logs: Dict) -> None:
"""Get network predictions on validation images."""
images = []
for i, image in enumerate(self.val_images):
- image = self.transforms(image)
+ image = self.transpose(image) if self.transpose is not None else image
pred, conf = self.model.predict_on_image(image)
- ground_truth = self.model.mapper(int(self.val_targets[i]))
+ if isinstance(self.val_targets[i], list):
+ ground_truth = "".join(
+ [
+ self.model.mapper(int(target_index))
+ for target_index in self.val_targets[i]
+ ]
+ ).rstrip("_")
+ else:
+ ground_truth = self.val_targets[i]
caption = f"Prediction: {pred} Confidence: {conf:.3f} Ground Truth: {ground_truth}"
images.append(wandb.Image(image, caption=caption))