Diffstat (limited to 'src/training/trainer/train.py')
-rw-r--r-- | src/training/trainer/train.py | 94
1 file changed, 54 insertions, 40 deletions
diff --git a/src/training/trainer/train.py b/src/training/trainer/train.py
index 40a25da..b770c94 100644
--- a/src/training/trainer/train.py
+++ b/src/training/trainer/train.py
@@ -12,7 +12,7 @@ import torch
 from torch import Tensor
 from torch.optim.swa_utils import update_bn
 from training.trainer.callbacks import Callback, CallbackList, LRScheduler, SWA
-from training.trainer.util import log_val_metric, RunningAverage
+from training.trainer.util import log_val_metric
 import wandb
 
 from text_recognizer.models import Model
@@ -30,8 +30,6 @@ warnings.filterwarnings("ignore")
 class Trainer:
     """Trainer for training PyTorch models."""
 
-    # TODO: proper add teardown?
-
     def __init__(
         self,
         max_epochs: int,
@@ -46,7 +44,7 @@ class Trainer:
             max_epochs (int): The maximum number of epochs in the training loop.
             callbacks (CallbackList): List of callbacks to be called.
             transformer_model (bool): Transformer model flag, modifies the input to the model. Default is False.
-            max_norm (float): Max norm for gradient cl:ipping. Defaults to 0.0.
+            max_norm (float): Max norm for gradient clipping. Defaults to 0.0.
             freeze_backbone (Optional[int]): How many epochs to freeze the backbone for. Used when
                 training Transformers. Default is None.
 
@@ -79,35 +77,32 @@ class Trainer:
         self.callbacks = CallbackList(self.model, self.callbacks)
 
     def compute_metrics(
-        self,
-        output: Tensor,
-        targets: Tensor,
-        loss: Tensor,
-        loss_avg: Type[RunningAverage],
+        self, output: Tensor, targets: Tensor, loss: Tensor, batch_size: int
     ) -> Dict:
         """Computes metrics for output and target pairs."""
         # Compute metrics.
         loss = loss.detach().float().item()
-        loss_avg.update(loss)
         output = output.detach()
         targets = targets.detach()
 
         if self.model.metrics is not None:
-            metrics = {
-                metric: self.model.metrics[metric](output, targets)
-                for metric in self.model.metrics
-            }
+            metrics = {}
+            for metric in self.model.metrics:
+                if metric == "cer" or metric == "wer":
+                    metrics[metric] = self.model.metrics[metric](
+                        output,
+                        targets,
+                        batch_size,
+                        self.model.mapper(self.model.pad_token),
+                    )
+                else:
+                    metrics[metric] = self.model.metrics[metric](output, targets)
         else:
             metrics = {}
         metrics["loss"] = loss
 
         return metrics
 
-    def training_step(
-        self,
-        batch: int,
-        samples: Tuple[Tensor, Tensor],
-        loss_avg: Type[RunningAverage],
-    ) -> Dict:
+    def training_step(self, batch: int, samples: Tuple[Tensor, Tensor],) -> Dict:
         """Performs the training step."""
         # Pass the tensor to the device for computation.
         data, targets = samples
@@ -116,25 +111,43 @@ class Trainer:
             targets.to(self.model.device),
         )
 
+        batch_size = data.shape[0]
+
+        # Placeholder for auxiliary loss.
+        aux_loss = None
+
         # Forward pass.
         # Get the network prediction.
         if self.transformer_model:
             if self.freeze_backbone is not None and batch < self.freeze_backbone:
                 with torch.no_grad():
                     image_features = self.model.network.extract_image_features(data)
+
+                if isinstance(image_features, Tuple):
+                    image_features, _ = image_features
+
                 output = self.model.network.decode_image_features(
                     image_features, targets[:, :-1]
                 )
             else:
                 output = self.model.network.forward(data, targets[:, :-1])
+                if isinstance(output, Tuple):
+                    output, aux_loss = output
             output = rearrange(output, "b t v -> (b t) v")
             targets = rearrange(targets[:, 1:], "b t -> (b t)").long()
         else:
             output = self.model.forward(data)
 
+            if isinstance(output, Tuple):
+                output, aux_loss = output
+
+            targets = data
+
         # Compute the loss.
         loss = self.model.criterion(output, targets)
 
+        if aux_loss is not None:
+            loss += aux_loss
+
         # Backward pass.
         # Clear the previous gradients.
         for p in self.model.network.parameters():
@@ -151,7 +164,7 @@ class Trainer:
 
         # Perform updates using calculated gradients.
         self.model.optimizer.step()
 
-        metrics = self.compute_metrics(output, targets, loss, loss_avg)
+        metrics = self.compute_metrics(output, targets, loss, batch_size)
 
         return metrics
 
@@ -160,22 +173,15 @@ class Trainer:
         # Set model to training mode.
         self.model.train()
 
-        # Running average for the loss.
-        loss_avg = RunningAverage()
-
         for batch, samples in enumerate(self.model.train_dataloader()):
             self.callbacks.on_train_batch_begin(batch)
-            metrics = self.training_step(batch, samples, loss_avg)
+            metrics = self.training_step(batch, samples)
             self.callbacks.on_train_batch_end(batch, logs=metrics)
 
     @torch.no_grad()
-    def validation_step(
-        self,
-        batch: int,
-        samples: Tuple[Tensor, Tensor],
-        loss_avg: Type[RunningAverage],
-    ) -> Dict:
+    def validation_step(self, batch: int, samples: Tuple[Tensor, Tensor],) -> Dict:
         """Performs the validation step."""
+        # Pass the tensor to the device for computation.
         data, targets = samples
         data, targets = (
@@ -183,21 +189,35 @@ class Trainer:
             targets.to(self.model.device),
         )
 
+        batch_size = data.shape[0]
+
+        # Placeholder for auxiliary loss.
+        aux_loss = None
+
         # Forward pass.
         # Get the network prediction.
         # Use SWA if available and using test dataset.
         if self.transformer_model:
             output = self.model.network.forward(data, targets[:, :-1])
+            if isinstance(output, Tuple):
+                output, aux_loss = output
             output = rearrange(output, "b t v -> (b t) v")
             targets = rearrange(targets[:, 1:], "b t -> (b t)").long()
         else:
             output = self.model.forward(data)
 
+            if isinstance(output, Tuple):
+                output, aux_loss = output
+
+            targets = data
+
         # Compute the loss.
         loss = self.model.criterion(output, targets)
 
+        if aux_loss is not None:
+            loss += aux_loss
+
         # Compute metrics.
-        metrics = self.compute_metrics(output, targets, loss, loss_avg)
+        metrics = self.compute_metrics(output, targets, loss, batch_size)
 
         return metrics
 
@@ -206,15 +226,12 @@ class Trainer:
         # Set model to eval mode.
         self.model.eval()
 
-        # Running average for the loss.
-        loss_avg = RunningAverage()
-
         # Summary for the current eval loop.
         summary = []
 
         for batch, samples in enumerate(self.model.val_dataloader()):
             self.callbacks.on_validation_batch_begin(batch)
-            metrics = self.validation_step(batch, samples, loss_avg)
+            metrics = self.validation_step(batch, samples)
             self.callbacks.on_validation_batch_end(batch, logs=metrics)
             summary.append(metrics)
@@ -287,14 +304,11 @@ class Trainer:
         # Check if SWA network is available.
         self.model.use_swa_model()
 
-        # Running average for the loss.
-        loss_avg = RunningAverage()
-
         # Summary for the current test loop.
         summary = []
 
         for batch, samples in enumerate(self.model.test_dataloader()):
-            metrics = self.validation_step(batch, samples, loss_avg)
+            metrics = self.validation_step(batch, samples)
            summary.append(metrics)
 
         self.callbacks.on_test_end()
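
A note on the compute_metrics change above: metrics registered as "cer" or "wer" are now called with the batch size and the mapped index of the padding token, so any such metric has to accept that extended signature. Below is a minimal sketch of a compatible character error rate, assuming the flattened "(b t) v" outputs and "(b t)" targets produced by the step functions; the decoding details and the use of the editdistance package are illustrative assumptions, not taken from this repository.

import editdistance
from torch import Tensor


def cer(outputs: Tensor, targets: Tensor, batch_size: int, pad_index: int) -> float:
    """Character error rate with the signature compute_metrics now expects.

    NOTE: hypothetical sketch; only the signature comes from the diff.
    """
    # Undo the "(b t) v" / "(b t)" flattening done in the step functions.
    predictions = outputs.argmax(dim=-1).reshape(batch_size, -1)
    targets = targets.reshape(batch_size, -1)

    distance, length = 0, 0
    for pred, target in zip(predictions, targets):
        # Strip padding before comparing the two token sequences.
        pred = [int(t) for t in pred if int(t) != pad_index]
        target = [int(t) for t in target if int(t) != pad_index]
        distance += editdistance.eval(pred, target)
        length += len(target)
    return distance / max(length, 1)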
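The repeated isinstance(output, Tuple) checks exist so the trainer can handle networks whose forward pass returns an auxiliary loss term next to the prediction, and targets = data in the non-transformer branch points at a reconstruction objective. A sketch of that convention with a hypothetical autoencoder; the module and its penalty term are assumptions, only the tuple-unpacking pattern comes from the diff.

from typing import Tuple

import torch
import torch.nn.functional as F
from torch import Tensor, nn


class AutoEncoder(nn.Module):
    """Hypothetical network returning (reconstruction, aux_loss)."""

    def __init__(self, dim: int = 64) -> None:
        super().__init__()
        self.encoder = nn.Linear(784, dim)
        self.decoder = nn.Linear(dim, 784)

    def forward(self, data: Tensor) -> Tuple[Tensor, Tensor]:
        z = self.encoder(data)
        # Stand-in for e.g. a codebook or commitment loss.
        aux_loss = 1e-4 * z.pow(2).mean()
        return self.decoder(z), aux_loss


network = AutoEncoder()
data = torch.randn(8, 784)

# Trainer-side handling, mirroring the diff: unpack if a tuple, add the term.
aux_loss = None
output = network(data)
if isinstance(output, tuple):
    output, aux_loss = output
targets = data  # reconstruction target, as in the non-transformer branch
loss = F.mse_loss(output, targets)
if aux_loss is not None:
    loss += aux_loss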
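Finally, the max_norm argument documented in the docstring hunk: the clipping itself lives in the part of training_step the hunks do not show, between the backward pass and optimizer.step(). A sketch of the usual PyTorch pattern, assuming (per the documented default of 0.0) that zero disables clipping.

import torch
from torch import nn

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
max_norm = 1.0  # 0.0 would mean "no clipping", matching the documented default

loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
if max_norm > 0:
    # Rescale all gradients in place so their global norm is at most max_norm.
    nn.utils.clip_grad_norm_(model.parameters(), max_norm)
optimizer.step()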