diff options
Diffstat (limited to 'src/training/trainer/train.py')
-rw-r--r-- | src/training/trainer/train.py | 42 |
1 files changed, 37 insertions, 5 deletions
diff --git a/src/training/trainer/train.py b/src/training/trainer/train.py index bd6a491..223d9c6 100644 --- a/src/training/trainer/train.py +++ b/src/training/trainer/train.py @@ -4,6 +4,7 @@ from pathlib import Path import time from typing import Dict, List, Optional, Tuple, Type +from einops import rearrange from loguru import logger import numpy as np import torch @@ -27,12 +28,20 @@ class Trainer: # TODO: proper add teardown? - def __init__(self, max_epochs: int, callbacks: List[Type[Callback]],) -> None: + def __init__( + self, + max_epochs: int, + callbacks: List[Type[Callback]], + transformer_model: bool = False, + max_norm: float = 0.0, + ) -> None: """Initialization of the Trainer. Args: max_epochs (int): The maximum number of epochs in the training loop. callbacks (CallbackList): List of callbacks to be called. + transformer_model (bool): Transformer model flag, modifies the input to the model. Default is False. + max_norm (float): Max norm for gradient clipping. Defaults to 0.0. """ # Training arguments. @@ -43,6 +52,10 @@ class Trainer: # Flag for setting callbacks. self.callbacks_configured = False + self.transformer_model = transformer_model + + self.max_norm = max_norm + # Model placeholders self.model = None @@ -97,10 +110,15 @@ class Trainer: # Forward pass. # Get the network prediction. - output = self.model.forward(data) + if self.transformer_model: + output = self.model.network.forward(data, targets[:, :-1]) + output = rearrange(output, "b t v -> (b t) v") + targets = rearrange(targets[:, 1:], "b t -> (b t)").long() + else: + output = self.model.forward(data) # Compute the loss. - loss = self.model.loss_fn(output, targets) + loss = self.model.criterion(output, targets) # Backward pass. # Clear the previous gradients. @@ -110,6 +128,11 @@ class Trainer: # Compute the gradients. loss.backward() + if self.max_norm > 0: + torch.nn.utils.clip_grad_norm_( + self.model.network.parameters(), self.max_norm + ) + # Perform updates using calculated gradients. self.model.optimizer.step() @@ -148,10 +171,15 @@ class Trainer: # Forward pass. # Get the network prediction. # Use SWA if available and using test dataset. - output = self.model.forward(data) + if self.transformer_model: + output = self.model.network.forward(data, targets[:, :-1]) + output = rearrange(output, "b t v -> (b t) v") + targets = rearrange(targets[:, 1:], "b t -> (b t)").long() + else: + output = self.model.forward(data) # Compute the loss. - loss = self.model.loss_fn(output, targets) + loss = self.model.criterion(output, targets) # Compute metrics. metrics = self.compute_metrics(output, targets, loss, loss_avg) @@ -237,6 +265,8 @@ class Trainer: # Configure callbacks. self._configure_callbacks() + self.callbacks.on_test_begin() + self.model.eval() # Check if SWA network is available. @@ -252,6 +282,8 @@ class Trainer: metrics = self.validation_step(batch, samples, loss_avg) summary.append(metrics) + self.callbacks.on_test_end() + # Compute mean of all test metrics. metrics_mean = { "test_" + metric: np.mean([x[metric] for x in summary]) |