diff options
Diffstat (limited to 'src/training/trainer/train.py')
-rw-r--r-- | src/training/trainer/train.py | 33 |
1 files changed, 28 insertions, 5 deletions
diff --git a/src/training/trainer/train.py b/src/training/trainer/train.py index bd6a491..fb49103 100644 --- a/src/training/trainer/train.py +++ b/src/training/trainer/train.py @@ -4,6 +4,7 @@ from pathlib import Path import time from typing import Dict, List, Optional, Tuple, Type +from einops import rearrange from loguru import logger import numpy as np import torch @@ -27,12 +28,18 @@ class Trainer: # TODO: proper add teardown? - def __init__(self, max_epochs: int, callbacks: List[Type[Callback]],) -> None: + def __init__( + self, + max_epochs: int, + callbacks: List[Type[Callback]], + transformer_model: bool = False, + ) -> None: """Initialization of the Trainer. Args: max_epochs (int): The maximum number of epochs in the training loop. callbacks (CallbackList): List of callbacks to be called. + transformer_model (bool): Transformer model flag, modifies the input to the model. Default is False. """ # Training arguments. @@ -43,6 +50,8 @@ class Trainer: # Flag for setting callbacks. self.callbacks_configured = False + self.transformer_model = transformer_model + # Model placeholders self.model = None @@ -97,10 +106,15 @@ class Trainer: # Forward pass. # Get the network prediction. - output = self.model.forward(data) + if self.transformer_model: + output = self.model.network.forward(data, targets[:, :-1]) + output = rearrange(output, "b t v -> (b t) v") + targets = rearrange(targets[:, 1:], "b t -> (b t)").long() + else: + output = self.model.forward(data) # Compute the loss. - loss = self.model.loss_fn(output, targets) + loss = self.model.criterion(output, targets) # Backward pass. # Clear the previous gradients. @@ -148,10 +162,15 @@ class Trainer: # Forward pass. # Get the network prediction. # Use SWA if available and using test dataset. - output = self.model.forward(data) + if self.transformer_model: + output = self.model.network.forward(data, targets[:, :-1]) + output = rearrange(output, "b t v -> (b t) v") + targets = rearrange(targets[:, 1:], "b t -> (b t)").long() + else: + output = self.model.forward(data) # Compute the loss. - loss = self.model.loss_fn(output, targets) + loss = self.model.criterion(output, targets) # Compute metrics. metrics = self.compute_metrics(output, targets, loss, loss_avg) @@ -237,6 +256,8 @@ class Trainer: # Configure callbacks. self._configure_callbacks() + self.callbacks.on_test_begin() + self.model.eval() # Check if SWA network is available. @@ -252,6 +273,8 @@ class Trainer: metrics = self.validation_step(batch, samples, loss_avg) summary.append(metrics) + self.callbacks.on_test_end() + # Compute mean of all test metrics. metrics_mean = { "test_" + metric: np.mean([x[metric] for x in summary]) |