summaryrefslogtreecommitdiff
path: root/src/training/callbacks
diff options
context:
space:
mode:
authoraktersnurra <gustaf.rydholm@gmail.com>2020-08-09 23:24:02 +0200
committeraktersnurra <gustaf.rydholm@gmail.com>2020-08-09 23:24:02 +0200
commit53677be4ec14854ea4881b0d78730e0414c8dedd (patch)
tree56eaace5e9906c7d408b6a251ca100b5c8b4e991 /src/training/callbacks
parent125d5da5fb845d03bda91426e172bca7f537584a (diff)
Working bash scripts etc.
Diffstat (limited to 'src/training/callbacks')
-rw-r--r--src/training/callbacks/wandb_callbacks.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/src/training/callbacks/wandb_callbacks.py b/src/training/callbacks/wandb_callbacks.py
index f64cbe1..6ada6df 100644
--- a/src/training/callbacks/wandb_callbacks.py
+++ b/src/training/callbacks/wandb_callbacks.py
@@ -72,7 +72,7 @@ class WandbImageLogger(Callback):
def set_model(self, model: Type[Model]) -> None:
"""Sets the model and extracts validation images from the dataset."""
self.model = model
- data_loader = self.model.data_loaders("val")
+ data_loader = self.model.data_loaders["val"]
if self.example_indices is None:
self.example_indices = np.random.randint(
0, len(data_loader.dataset.data), self.num_examples
@@ -86,7 +86,7 @@ class WandbImageLogger(Callback):
for i, image in enumerate(self.val_images):
image = self.transforms(image)
pred, conf = self.model.predict_on_image(image)
- ground_truth = self.model._mapping[self.val_targets[i]]
+ ground_truth = self.model.mapper(int(self.val_targets[i]))
caption = f"Prediction: {pred} Confidence: {conf:.3f} Ground Truth: {ground_truth}"
images.append(wandb.Image(image, caption=caption))