summaryrefslogtreecommitdiff
path: root/src/training/trainer/callbacks
diff options
context:
space:
mode:
authoraktersnurra <gustaf.rydholm@gmail.com>2020-11-08 12:41:04 +0100
committeraktersnurra <gustaf.rydholm@gmail.com>2020-11-08 12:41:04 +0100
commitbeeaef529e7c893a3475fe27edc880e283373725 (patch)
tree59eb72562bf7a5a9470c2586e6280600ad94f1ae /src/training/trainer/callbacks
parent4d7713746eb936832e84852e90292936b933e87d (diff)
Trying to get the CNNTransformer to work, but it is hard.
Diffstat (limited to 'src/training/trainer/callbacks')
-rw-r--r--src/training/trainer/callbacks/lr_schedulers.py5
-rw-r--r--src/training/trainer/callbacks/wandb_callbacks.py2
2 files changed, 5 insertions, 2 deletions
diff --git a/src/training/trainer/callbacks/lr_schedulers.py b/src/training/trainer/callbacks/lr_schedulers.py
index 907e292..630c434 100644
--- a/src/training/trainer/callbacks/lr_schedulers.py
+++ b/src/training/trainer/callbacks/lr_schedulers.py
@@ -22,7 +22,10 @@ class LRScheduler(Callback):
def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
"""Takes a step at the end of every epoch."""
if self.interval == "epoch":
- self.lr_scheduler.step()
+ if "ReduceLROnPlateau" in self.lr_scheduler.__class__.__name__:
+ self.lr_scheduler.step(logs["val_loss"])
+ else:
+ self.lr_scheduler.step()
def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
"""Takes a step at the end of every training batch."""
diff --git a/src/training/trainer/callbacks/wandb_callbacks.py b/src/training/trainer/callbacks/wandb_callbacks.py
index f24e5cc..1627f17 100644
--- a/src/training/trainer/callbacks/wandb_callbacks.py
+++ b/src/training/trainer/callbacks/wandb_callbacks.py
@@ -111,7 +111,7 @@ class WandbImageLogger(Callback):
]
).rstrip("_")
else:
- ground_truth = self.targets[i]
+ ground_truth = self.model.mapper(int(self.targets[i]))
caption = f"Prediction: {pred} Confidence: {conf:.3f} Ground Truth: {ground_truth}"
images.append(wandb.Image(image, caption=caption))