diff options
Diffstat (limited to 'src/training/callbacks')
-rw-r--r-- | src/training/callbacks/wandb_callbacks.py | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/src/training/callbacks/wandb_callbacks.py b/src/training/callbacks/wandb_callbacks.py index f64cbe1..6ada6df 100644 --- a/src/training/callbacks/wandb_callbacks.py +++ b/src/training/callbacks/wandb_callbacks.py @@ -72,7 +72,7 @@ class WandbImageLogger(Callback): def set_model(self, model: Type[Model]) -> None: """Sets the model and extracts validation images from the dataset.""" self.model = model - data_loader = self.model.data_loaders("val") + data_loader = self.model.data_loaders["val"] if self.example_indices is None: self.example_indices = np.random.randint( 0, len(data_loader.dataset.data), self.num_examples @@ -86,7 +86,7 @@ class WandbImageLogger(Callback): for i, image in enumerate(self.val_images): image = self.transforms(image) pred, conf = self.model.predict_on_image(image) - ground_truth = self.model._mapping[self.val_targets[i]] + ground_truth = self.model.mapper(int(self.val_targets[i])) caption = f"Prediction: {pred} Confidence: {conf:.3f} Ground Truth: {ground_truth}" images.append(wandb.Image(image, caption=caption)) |