diff options
author | aktersnurra <gustaf.rydholm@gmail.com> | 2020-12-07 22:54:04 +0100 |
---|---|---|
committer | aktersnurra <gustaf.rydholm@gmail.com> | 2020-12-07 22:54:04 +0100 |
commit | 25b5d6983d51e0e791b96a76beb7e49f392cd9a8 (patch) | |
tree | 526ba739714b3d040f7810c1a6be3ff0ba37fdb1 /src/training/trainer | |
parent | 5529e0fc9ca39e81fe0f08a54f257d32f0afe120 (diff) |
Segmentation working!
Diffstat (limited to 'src/training/trainer')
-rw-r--r-- | src/training/trainer/callbacks/__init__.py | 3 | ||||
-rw-r--r-- | src/training/trainer/callbacks/wandb_callbacks.py | 97 | ||||
-rw-r--r-- | src/training/trainer/train.py | 13 |
3 files changed, 104 insertions, 9 deletions
diff --git a/src/training/trainer/callbacks/__init__.py b/src/training/trainer/callbacks/__init__.py index e1bd858..95ec142 100644 --- a/src/training/trainer/callbacks/__init__.py +++ b/src/training/trainer/callbacks/__init__.py @@ -7,7 +7,7 @@ from .lr_schedulers import ( SWA, ) from .progress_bar import ProgressBar -from .wandb_callbacks import WandbCallback, WandbImageLogger +from .wandb_callbacks import WandbCallback, WandbImageLogger, WandbSegmentationLogger __all__ = [ "Callback", @@ -17,6 +17,7 @@ __all__ = [ "LRScheduler", "WandbCallback", "WandbImageLogger", + "WandbSegmentationLogger", "ProgressBar", "SWA", ] diff --git a/src/training/trainer/callbacks/wandb_callbacks.py b/src/training/trainer/callbacks/wandb_callbacks.py index 1627f17..df1fd8f 100644 --- a/src/training/trainer/callbacks/wandb_callbacks.py +++ b/src/training/trainer/callbacks/wandb_callbacks.py @@ -2,12 +2,10 @@ from typing import Callable, Dict, List, Optional, Type import numpy as np -import torch -from torchvision.transforms import ToTensor from training.trainer.callbacks import Callback import wandb -from text_recognizer.datasets import Transpose +import text_recognizer.datasets.transforms as transforms from text_recognizer.models.base import Model @@ -52,14 +50,14 @@ class WandbImageLogger(Callback): self, example_indices: Optional[List] = None, num_examples: int = 4, - use_transpose: Optional[bool] = False, + transform: Optional[bool] = None, ) -> None: """Initializes the WandbImageLogger with the model to train. Args: example_indices (Optional[List]): Indices for validation images. Defaults to None. num_examples (int): Number of random samples to take if example_indices are not specified. Defaults to 4. - use_transpose (Optional[bool]): Use transpose on image or not. Defaults to False. + transform (Optional[Dict]): Use transform on image or not. Defaults to None. """ @@ -68,7 +66,13 @@ class WandbImageLogger(Callback): self.example_indices = example_indices self.test_sample_indices = None self.num_examples = num_examples - self.transpose = Transpose() if use_transpose else None + self.transform = ( + self._configure_transform(transform) if transform is not None else None + ) + + def _configure_transform(self, transform: Dict) -> Callable: + args = transform["args"] or {} + return getattr(transforms, transform["type"])(**args) def set_model(self, model: Type[Model]) -> None: """Sets the model and extracts validation images from the dataset.""" @@ -101,7 +105,7 @@ class WandbImageLogger(Callback): """Get network predictions on validation images.""" images = [] for i, image in enumerate(self.images): - image = self.transpose(image) if self.transpose is not None else image + image = self.transform(image) if self.transform is not None else image pred, conf = self.model.predict_on_image(image) if isinstance(self.targets[i], list): ground_truth = "".join( @@ -116,3 +120,82 @@ class WandbImageLogger(Callback): images.append(wandb.Image(image, caption=caption)) wandb.log({f"{self.caption}": images}, commit=False) + + +class WandbSegmentationLogger(Callback): + """Custom W&B callback for image logging.""" + + def __init__( + self, + class_labels: Dict, + example_indices: Optional[List] = None, + num_examples: int = 4, + ) -> None: + """Initializes the WandbImageLogger with the model to train. + + Args: + class_labels (Dict): A dict with int as key and class string as value. + example_indices (Optional[List]): Indices for validation images. Defaults to None. + num_examples (int): Number of random samples to take if example_indices are not specified. Defaults to 4. + + """ + + super().__init__() + self.caption = None + self.class_labels = {int(k): v for k, v in class_labels.items()} + self.example_indices = example_indices + self.test_sample_indices = None + self.num_examples = num_examples + + def set_model(self, model: Type[Model]) -> None: + """Sets the model and extracts validation images from the dataset.""" + self.model = model + self.caption = "Validation Segmentation Examples" + if self.example_indices is None: + self.example_indices = np.random.randint( + 0, len(self.model.val_dataset), self.num_examples + ) + self.images = self.model.val_dataset.dataset.data[self.example_indices] + self.targets = self.model.val_dataset.dataset.targets[self.example_indices] + self.targets = self.targets.tolist() + + def on_test_begin(self) -> None: + """Get samples from test dataset.""" + self.caption = "Test Segmentation Examples" + if self.test_sample_indices is None: + self.test_sample_indices = np.random.randint( + 0, len(self.model.test_dataset), self.num_examples + ) + self.images = self.model.test_dataset.data[self.test_sample_indices] + self.targets = self.model.test_dataset.targets[self.test_sample_indices] + self.targets = self.targets.tolist() + + def on_test_end(self) -> None: + """Log test images.""" + self.on_epoch_end(0, {}) + + def on_epoch_end(self, epoch: int, logs: Dict) -> None: + """Get network predictions on validation images.""" + images = [] + for i, image in enumerate(self.images): + pred_mask = ( + self.model.predict_on_image(image).detach().squeeze(0).cpu().numpy() + ) + gt_mask = np.array(self.targets[i]) + images.append( + wandb.Image( + image, + masks={ + "predictions": { + "mask_data": pred_mask, + "class_labels": self.class_labels, + }, + "ground_truth": { + "mask_data": gt_mask, + "class_labels": self.class_labels, + }, + }, + ) + ) + + wandb.log({f"{self.caption}": images}, commit=False) diff --git a/src/training/trainer/train.py b/src/training/trainer/train.py index 8ae994a..40a25da 100644 --- a/src/training/trainer/train.py +++ b/src/training/trainer/train.py @@ -38,6 +38,7 @@ class Trainer: callbacks: List[Type[Callback]], transformer_model: bool = False, max_norm: float = 0.0, + freeze_backbone: Optional[int] = None, ) -> None: """Initialization of the Trainer. @@ -46,12 +47,15 @@ class Trainer: callbacks (CallbackList): List of callbacks to be called. transformer_model (bool): Transformer model flag, modifies the input to the model. Default is False. max_norm (float): Max norm for gradient clipping. Defaults to 0.0. + freeze_backbone (Optional[int]): How many epochs to freeze the backbone for. Used when training + Transformers. Default is None. """ # Training arguments. self.start_epoch = 1 self.max_epochs = max_epochs self.callbacks = callbacks + self.freeze_backbone = freeze_backbone # Flag for setting callbacks. self.callbacks_configured = False @@ -115,7 +119,14 @@ class Trainer: # Forward pass. # Get the network prediction. if self.transformer_model: - output = self.model.network.forward(data, targets[:, :-1]) + if self.freeze_backbone is not None and batch < self.freeze_backbone: + with torch.no_grad(): + image_features = self.model.network.extract_image_features(data) + output = self.model.network.decode_image_features( + image_features, targets[:, :-1] + ) + else: + output = self.model.network.forward(data, targets[:, :-1]) output = rearrange(output, "b t v -> (b t) v") targets = rearrange(targets[:, 1:], "b t -> (b t)").long() else: |