Diffstat (limited to 'src/training/trainer')
-rw-r--r-- | src/training/trainer/callbacks/__init__.py        |  8
-rw-r--r-- | src/training/trainer/callbacks/wandb_callbacks.py | 58
-rw-r--r-- | src/training/trainer/train.py                     | 94
3 files changed, 119 insertions, 41 deletions
diff --git a/src/training/trainer/callbacks/__init__.py b/src/training/trainer/callbacks/__init__.py
index 95ec142..80c4177 100644
--- a/src/training/trainer/callbacks/__init__.py
+++ b/src/training/trainer/callbacks/__init__.py
@@ -7,7 +7,12 @@ from .lr_schedulers import (
     SWA,
 )
 from .progress_bar import ProgressBar
-from .wandb_callbacks import WandbCallback, WandbImageLogger, WandbSegmentationLogger
+from .wandb_callbacks import (
+    WandbCallback,
+    WandbImageLogger,
+    WandbReconstructionLogger,
+    WandbSegmentationLogger,
+)
 
 __all__ = [
     "Callback",
@@ -17,6 +22,7 @@ __all__ = [
     "LRScheduler",
     "WandbCallback",
     "WandbImageLogger",
+    "WandbReconstructionLogger",
     "WandbSegmentationLogger",
     "ProgressBar",
     "SWA",
diff --git a/src/training/trainer/callbacks/wandb_callbacks.py b/src/training/trainer/callbacks/wandb_callbacks.py
index 20414df..552a4f4 100644
--- a/src/training/trainer/callbacks/wandb_callbacks.py
+++ b/src/training/trainer/callbacks/wandb_callbacks.py
@@ -201,3 +201,61 @@ class WandbSegmentationLogger(Callback):
             )
 
         wandb.log({f"{self.caption}": images}, commit=False)
+
+
+class WandbReconstructionLogger(Callback):
+    """Custom W&B callback for logging image reconstructions."""
+
+    def __init__(
+        self, example_indices: Optional[List] = None, num_examples: int = 4,
+    ) -> None:
+        """Initializes the WandbReconstructionLogger with the model to train.
+
+        Args:
+            example_indices (Optional[List]): Indices for validation images. Defaults to None.
+            num_examples (int): Number of random samples to take if example_indices are not specified. Defaults to 4.
+
+        """
+
+        super().__init__()
+        self.caption = None
+        self.example_indices = example_indices
+        self.test_sample_indices = None
+        self.num_examples = num_examples
+
+    def set_model(self, model: Type[Model]) -> None:
+        """Sets the model and extracts validation images from the dataset."""
+        self.model = model
+        self.caption = "Validation Reconstructions Examples"
+        if self.example_indices is None:
+            self.example_indices = np.random.randint(
+                0, len(self.model.val_dataset), self.num_examples
+            )
+        self.images = self.model.val_dataset.dataset.data[self.example_indices]
+
+    def on_test_begin(self) -> None:
+        """Get samples from test dataset."""
+        self.caption = "Test Reconstructions Examples"
+        if self.test_sample_indices is None:
+            self.test_sample_indices = np.random.randint(
+                0, len(self.model.test_dataset), self.num_examples
+            )
+        self.images = self.model.test_dataset.data[self.test_sample_indices]
+
+    def on_test_end(self) -> None:
+        """Log test images."""
+        self.on_epoch_end(0, {})
+
+    def on_epoch_end(self, epoch: int, logs: Dict) -> None:
+        """Get network predictions on validation images."""
+        images = []
+        for image in self.images:
+            reconstructed_image = (
+                self.model.predict_on_image(image).detach().squeeze(0).cpu().numpy()
+            )
+            images.append(image)
+            images.append(reconstructed_image)
+
+        wandb.log(
+            {f"{self.caption}": [wandb.Image(image) for image in images]}, commit=False,
+        )
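The new WandbReconstructionLogger follows the same life cycle as the other W&B callbacks: set_model() samples validation images, on_epoch_end() logs input/reconstruction pairs, and on_test_begin()/on_test_end() repeat this for the test split. A minimal usage sketch, assuming the usual callback-registration pattern of this Trainer; only the logger's own arguments are taken from the diff above, the surrounding setup is illustrative:

    from training.trainer.callbacks import WandbReconstructionLogger

    # Log four random validation images and their reconstructions each epoch.
    reconstruction_logger = WandbReconstructionLogger(num_examples=4)

    # Passed along with the other callbacks; the Trainer wraps them in a
    # CallbackList and calls set_model(), which is where the images are sampled.
    callbacks = [reconstruction_logger]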
diff --git a/src/training/trainer/train.py b/src/training/trainer/train.py
index 40a25da..b770c94 100644
--- a/src/training/trainer/train.py
+++ b/src/training/trainer/train.py
@@ -12,7 +12,7 @@ import torch
 from torch import Tensor
 from torch.optim.swa_utils import update_bn
 from training.trainer.callbacks import Callback, CallbackList, LRScheduler, SWA
-from training.trainer.util import log_val_metric, RunningAverage
+from training.trainer.util import log_val_metric
 import wandb
 
 from text_recognizer.models import Model
@@ -30,8 +30,6 @@ warnings.filterwarnings("ignore")
 class Trainer:
     """Trainer for training PyTorch models."""
 
-    # TODO: proper add teardown?
-
     def __init__(
         self,
         max_epochs: int,
@@ -46,7 +44,7 @@ class Trainer:
             max_epochs (int): The maximum number of epochs in the training loop.
             callbacks (CallbackList): List of callbacks to be called.
             transformer_model (bool): Transformer model flag, modifies the input to the model. Default is False.
-            max_norm (float): Max norm for gradient cl:ipping. Defaults to 0.0.
+            max_norm (float): Max norm for gradient clipping. Defaults to 0.0.
             freeze_backbone (Optional[int]): How many epochs to freeze the backbone for. Used when training
                 Transformers. Default is None.
 
@@ -79,35 +77,32 @@ class Trainer:
         self.callbacks = CallbackList(self.model, self.callbacks)
 
     def compute_metrics(
-        self,
-        output: Tensor,
-        targets: Tensor,
-        loss: Tensor,
-        loss_avg: Type[RunningAverage],
+        self, output: Tensor, targets: Tensor, loss: Tensor, batch_size: int
     ) -> Dict:
         """Computes metrics for output and target pairs."""
         # Compute metrics.
         loss = loss.detach().float().item()
-        loss_avg.update(loss)
         output = output.detach()
         targets = targets.detach()
         if self.model.metrics is not None:
-            metrics = {
-                metric: self.model.metrics[metric](output, targets)
-                for metric in self.model.metrics
-            }
+            metrics = {}
+            for metric in self.model.metrics:
+                if metric == "cer" or metric == "wer":
+                    metrics[metric] = self.model.metrics[metric](
+                        output,
+                        targets,
+                        batch_size,
+                        self.model.mapper(self.model.pad_token),
+                    )
+                else:
+                    metrics[metric] = self.model.metrics[metric](output, targets)
         else:
             metrics = {}
         metrics["loss"] = loss
 
         return metrics
 
-    def training_step(
-        self,
-        batch: int,
-        samples: Tuple[Tensor, Tensor],
-        loss_avg: Type[RunningAverage],
-    ) -> Dict:
+    def training_step(self, batch: int, samples: Tuple[Tensor, Tensor],) -> Dict:
         """Performs the training step."""
         # Pass the tensor to the device for computation.
         data, targets = samples
@@ -116,25 +111,43 @@ class Trainer:
             targets.to(self.model.device),
         )
 
+        batch_size = data.shape[0]
+
+        # Placeholder for auxiliary loss.
+        aux_loss = None
+
         # Forward pass.
         # Get the network prediction.
         if self.transformer_model:
             if self.freeze_backbone is not None and batch < self.freeze_backbone:
                 with torch.no_grad():
                     image_features = self.model.network.extract_image_features(data)
+
+                if isinstance(image_features, Tuple):
+                    image_features, _ = image_features
+
                 output = self.model.network.decode_image_features(
                     image_features, targets[:, :-1]
                 )
             else:
                 output = self.model.network.forward(data, targets[:, :-1])
+                if isinstance(output, Tuple):
+                    output, aux_loss = output
             output = rearrange(output, "b t v -> (b t) v")
             targets = rearrange(targets[:, 1:], "b t -> (b t)").long()
         else:
             output = self.model.forward(data)
+            if isinstance(output, Tuple):
+                output, aux_loss = output
+
+            targets = data
 
         # Compute the loss.
         loss = self.model.criterion(output, targets)
 
+        if aux_loss is not None:
+            loss += aux_loss
+
         # Backward pass.
         # Clear the previous gradients.
         for p in self.model.network.parameters():
@@ -151,7 +164,7 @@
         # Perform updates using calculated gradients.
         self.model.optimizer.step()
 
-        metrics = self.compute_metrics(output, targets, loss, loss_avg)
+        metrics = self.compute_metrics(output, targets, loss, batch_size)
 
         return metrics
 
@@ -160,22 +173,15 @@
         # Set model to training mode.
         self.model.train()
 
-        # Running average for the loss.
-        loss_avg = RunningAverage()
-
         for batch, samples in enumerate(self.model.train_dataloader()):
             self.callbacks.on_train_batch_begin(batch)
-            metrics = self.training_step(batch, samples, loss_avg)
+            metrics = self.training_step(batch, samples)
             self.callbacks.on_train_batch_end(batch, logs=metrics)
 
     @torch.no_grad()
-    def validation_step(
-        self,
-        batch: int,
-        samples: Tuple[Tensor, Tensor],
-        loss_avg: Type[RunningAverage],
-    ) -> Dict:
+    def validation_step(self, batch: int, samples: Tuple[Tensor, Tensor],) -> Dict:
         """Performs the validation step."""
+        # Pass the tensor to the device for computation.
         data, targets = samples
         data, targets = (
@@ -183,21 +189,35 @@
             targets.to(self.model.device),
         )
 
+        batch_size = data.shape[0]
+
+        # Placeholder for auxiliary loss.
+        aux_loss = None
+
         # Forward pass.
         # Get the network prediction.
         # Use SWA if available and using test dataset.
         if self.transformer_model:
             output = self.model.network.forward(data, targets[:, :-1])
+            if isinstance(output, Tuple):
+                output, aux_loss = output
             output = rearrange(output, "b t v -> (b t) v")
             targets = rearrange(targets[:, 1:], "b t -> (b t)").long()
         else:
             output = self.model.forward(data)
+            if isinstance(output, Tuple):
+                output, aux_loss = output
+
+            targets = data
 
         # Compute the loss.
         loss = self.model.criterion(output, targets)
 
+        if aux_loss is not None:
+            loss += aux_loss
+
         # Compute metrics.
-        metrics = self.compute_metrics(output, targets, loss, loss_avg)
+        metrics = self.compute_metrics(output, targets, loss, batch_size)
 
         return metrics
 
@@ -206,15 +226,12 @@
         # Set model to eval mode.
         self.model.eval()
 
-        # Running average for the loss.
-        loss_avg = RunningAverage()
-
         # Summary for the current eval loop.
         summary = []
 
         for batch, samples in enumerate(self.model.val_dataloader()):
             self.callbacks.on_validation_batch_begin(batch)
-            metrics = self.validation_step(batch, samples, loss_avg)
+            metrics = self.validation_step(batch, samples)
             self.callbacks.on_validation_batch_end(batch, logs=metrics)
             summary.append(metrics)
@@ -287,14 +304,11 @@ class Trainer:
         # Check if SWA network is available.
         self.model.use_swa_model()
 
-        # Running average for the loss.
-        loss_avg = RunningAverage()
-
         # Summary for the current test loop.
         summary = []
 
         for batch, samples in enumerate(self.model.test_dataloader()):
-            metrics = self.validation_step(batch, samples, loss_avg)
+            metrics = self.validation_step(batch, samples)
             summary.append(metrics)
 
         self.callbacks.on_test_end()
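With these changes, train.py tolerates networks whose forward() returns a tuple of (output, auxiliary loss): the tuple is unpacked, the targets fall back to the input (a reconstruction objective), and the auxiliary term is added to the criterion's loss. A minimal sketch of a network following that contract, assuming a VQ-VAE-style extra loss term; the toy model below is illustrative, not code from the repo:

    import torch
    from torch import Tensor, nn
    from typing import Tuple


    class ToyAutoencoder(nn.Module):
        """Toy network matching the (output, aux_loss) convention the Trainer unpacks."""

        def __init__(self) -> None:
            super().__init__()
            self.encoder = nn.Conv2d(1, 8, kernel_size=3, padding=1)
            self.decoder = nn.Conv2d(8, 1, kernel_size=3, padding=1)

        def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
            z = self.encoder(x)
            # Stand-in for e.g. a VQ commitment loss; any extra scalar term works.
            aux_loss = 1e-3 * z.pow(2).mean()
            return self.decoder(z), aux_loss

With such a network, training_step sets targets = data and computes loss = criterion(output, targets) + aux_loss, exactly as in the hunks above.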
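compute_metrics now also dispatches on the metric name: "cer" and "wer" receive the batch size and the mapped pad token in addition to the output/target pair. A sketch of a callable with that call shape, assuming flattened (batch * seq_len, vocab) logits and (batch * seq_len,) targets as produced by the rearranges in training_step; it counts per-position mismatches rather than true edit distance, so the repo's actual metric is likely implemented differently:

    import torch
    from torch import Tensor


    def cer(output: Tensor, targets: Tensor, batch_size: int, pad_index: int) -> float:
        """Simplified per-position character error rate (illustrative only)."""
        predictions = output.argmax(dim=-1).reshape(batch_size, -1)
        targets = targets.reshape(batch_size, -1)
        errors, total = 0, 0
        for pred, target in zip(predictions, targets):
            mask = target != pad_index  # Ignore padding positions.
            total += int(mask.sum())
            errors += int((pred[mask] != target[mask]).sum())
        return errors / max(total, 1)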