diff options
Diffstat (limited to 'src/training/trainer/callbacks/wandb_callbacks.py')
-rw-r--r-- | src/training/trainer/callbacks/wandb_callbacks.py | 32 |
1 files changed, 19 insertions, 13 deletions
diff --git a/src/training/trainer/callbacks/wandb_callbacks.py b/src/training/trainer/callbacks/wandb_callbacks.py index e44c745..6643a44 100644 --- a/src/training/trainer/callbacks/wandb_callbacks.py +++ b/src/training/trainer/callbacks/wandb_callbacks.py @@ -2,7 +2,8 @@ from typing import Callable, Dict, List, Optional, Type import numpy as np -from torchvision.transforms import Compose, ToTensor +import torch +from torchvision.transforms import ToTensor from training.trainer.callbacks import Callback import wandb @@ -50,43 +51,48 @@ class WandbImageLogger(Callback): self, example_indices: Optional[List] = None, num_examples: int = 4, - transfroms: Optional[Callable] = None, + use_transpose: Optional[bool] = False, ) -> None: """Initializes the WandbImageLogger with the model to train. Args: example_indices (Optional[List]): Indices for validation images. Defaults to None. num_examples (int): Number of random samples to take if example_indices are not specified. Defaults to 4. - transfroms (Optional[Callable]): Transforms to use on the validation images, e.g. transpose. Defaults to - None. + use_transpose (Optional[bool]): Use transpose on image or not. Defaults to False. """ super().__init__() self.example_indices = example_indices self.num_examples = num_examples - self.transfroms = transfroms - if self.transfroms is None: - self.transforms = Compose([Transpose()]) + self.transpose = Transpose() if use_transpose else None def set_model(self, model: Type[Model]) -> None: """Sets the model and extracts validation images from the dataset.""" self.model = model - data_loader = self.model.data_loaders["val"] if self.example_indices is None: self.example_indices = np.random.randint( - 0, len(data_loader.dataset.data), self.num_examples + 0, len(self.model.val_dataset), self.num_examples ) - self.val_images = data_loader.dataset.data[self.example_indices] - self.val_targets = data_loader.dataset.targets[self.example_indices].numpy() + self.val_images = self.model.val_dataset.dataset.data[self.example_indices] + self.val_targets = self.model.val_dataset.dataset.targets[self.example_indices] + self.val_targets = self.val_targets.tolist() def on_epoch_end(self, epoch: int, logs: Dict) -> None: """Get network predictions on validation images.""" images = [] for i, image in enumerate(self.val_images): - image = self.transforms(image) + image = self.transpose(image) if self.transpose is not None else image pred, conf = self.model.predict_on_image(image) - ground_truth = self.model.mapper(int(self.val_targets[i])) + if isinstance(self.val_targets[i], list): + ground_truth = "".join( + [ + self.model.mapper(int(target_index)) + for target_index in self.val_targets[i] + ] + ).rstrip("_") + else: + ground_truth = self.val_targets[i] caption = f"Prediction: {pred} Confidence: {conf:.3f} Ground Truth: {ground_truth}" images.append(wandb.Image(image, caption=caption)) |