diff options
author | aktersnurra <gustaf.rydholm@gmail.com> | 2020-10-22 22:45:58 +0200 |
---|---|---|
committer | aktersnurra <gustaf.rydholm@gmail.com> | 2020-10-22 22:45:58 +0200 |
commit | 4d7713746eb936832e84852e90292936b933e87d (patch) | |
tree | 2b2519d1d2ce53d4e1390590f52018d55dadbc7c /src/training/trainer/callbacks | |
parent | 1b3b8073a19f939d18a0bb85247eb0d99284f7cc (diff) |
Transfomer added, many other changes.
Diffstat (limited to 'src/training/trainer/callbacks')
-rw-r--r-- | src/training/trainer/callbacks/base.py | 18 | ||||
-rw-r--r-- | src/training/trainer/callbacks/checkpoint.py | 6 | ||||
-rw-r--r-- | src/training/trainer/callbacks/wandb_callbacks.py | 34 |
3 files changed, 47 insertions, 11 deletions
diff --git a/src/training/trainer/callbacks/base.py b/src/training/trainer/callbacks/base.py index f81fc1f..500b642 100644 --- a/src/training/trainer/callbacks/base.py +++ b/src/training/trainer/callbacks/base.py @@ -62,6 +62,14 @@ class Callback: """Called at the end of an epoch.""" pass + def on_test_begin(self) -> None: + """Called at the beginning of test.""" + pass + + def on_test_end(self) -> None: + """Called at the end of test.""" + pass + class CallbackList: """Container for abstracting away callback calls.""" @@ -104,6 +112,16 @@ class CallbackList: for callback in self._callbacks: callback.on_fit_end() + def on_test_begin(self) -> None: + """Called when test begins.""" + for callback in self._callbacks: + callback.on_test_begin() + + def on_test_end(self) -> None: + """Called when test ends.""" + for callback in self._callbacks: + callback.on_test_end() + def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None) -> None: """Called at the beginning of an epoch.""" for callback in self._callbacks: diff --git a/src/training/trainer/callbacks/checkpoint.py b/src/training/trainer/callbacks/checkpoint.py index 6fe06d3..a54e0a9 100644 --- a/src/training/trainer/callbacks/checkpoint.py +++ b/src/training/trainer/callbacks/checkpoint.py @@ -21,7 +21,7 @@ class Checkpoint(Callback): def __init__( self, - checkpoint_path: Path, + checkpoint_path: Union[str, Path], monitor: str = "accuracy", mode: str = "auto", min_delta: float = 0.0, @@ -29,14 +29,14 @@ class Checkpoint(Callback): """Monitors a quantity that will allow us to determine the best model weights. Args: - checkpoint_path (Path): Path to the experiment with the checkpoint. + checkpoint_path (Union[str, Path]): Path to the experiment with the checkpoint. monitor (str): Name of the quantity to monitor. Defaults to "accuracy". mode (str): Description of parameter `mode`. Defaults to "auto". min_delta (float): Description of parameter `min_delta`. Defaults to 0.0. """ super().__init__() - self.checkpoint_path = checkpoint_path + self.checkpoint_path = Path(checkpoint_path) self.monitor = monitor self.mode = mode self.min_delta = torch.tensor(min_delta) diff --git a/src/training/trainer/callbacks/wandb_callbacks.py b/src/training/trainer/callbacks/wandb_callbacks.py index d2df4d7..f24e5cc 100644 --- a/src/training/trainer/callbacks/wandb_callbacks.py +++ b/src/training/trainer/callbacks/wandb_callbacks.py @@ -64,37 +64,55 @@ class WandbImageLogger(Callback): """ super().__init__() + self.caption = None self.example_indices = example_indices + self.test_sample_indices = None self.num_examples = num_examples self.transpose = Transpose() if use_transpose else None def set_model(self, model: Type[Model]) -> None: """Sets the model and extracts validation images from the dataset.""" self.model = model + self.caption = "Validation Examples" if self.example_indices is None: self.example_indices = np.random.randint( 0, len(self.model.val_dataset), self.num_examples ) - self.val_images = self.model.val_dataset.dataset.data[self.example_indices] - self.val_targets = self.model.val_dataset.dataset.targets[self.example_indices] - self.val_targets = self.val_targets.tolist() + self.images = self.model.val_dataset.dataset.data[self.example_indices] + self.targets = self.model.val_dataset.dataset.targets[self.example_indices] + self.targets = self.targets.tolist() + + def on_test_begin(self) -> None: + """Get samples from test dataset.""" + self.caption = "Test Examples" + if self.test_sample_indices is None: + self.test_sample_indices = np.random.randint( + 0, len(self.model.test_dataset), self.num_examples + ) + self.images = self.model.test_dataset.data[self.test_sample_indices] + self.targets = self.model.test_dataset.targets[self.test_sample_indices] + self.targets = self.targets.tolist() + + def on_test_end(self) -> None: + """Log test images.""" + self.on_epoch_end(0, {}) def on_epoch_end(self, epoch: int, logs: Dict) -> None: """Get network predictions on validation images.""" images = [] - for i, image in enumerate(self.val_images): + for i, image in enumerate(self.images): image = self.transpose(image) if self.transpose is not None else image pred, conf = self.model.predict_on_image(image) - if isinstance(self.val_targets[i], list): + if isinstance(self.targets[i], list): ground_truth = "".join( [ self.model.mapper(int(target_index)) - for target_index in self.val_targets[i] + for target_index in self.targets[i] ] ).rstrip("_") else: - ground_truth = self.val_targets[i] + ground_truth = self.targets[i] caption = f"Prediction: {pred} Confidence: {conf:.3f} Ground Truth: {ground_truth}" images.append(wandb.Image(image, caption=caption)) - wandb.log({"examples": images}, commit=False) + wandb.log({f"{self.caption}": images}, commit=False) |