summary refs log tree commit diff
path: root/src/training/trainer/train.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/training/trainer/train.py')
-rw-r--r-- src/training/trainer/train.py 42
1 files changed, 37 insertions, 5 deletions
diff --git a/src/training/trainer/train.py b/src/training/trainer/train.py
index bd6a491..223d9c6 100644
--- a/src/training/trainer/train.py
+++ b/src/training/trainer/train.py
@@ -4,6 +4,7 @@ from pathlib import Path
import time
from typing import Dict, List, Optional, Tuple, Type
+from einops import rearrange
from loguru import logger
import numpy as np
import torch
@@ -27,12 +28,20 @@ class Trainer:
# TODO: proper add teardown?
- def __init__(self, max_epochs: int, callbacks: List[Type[Callback]],) -> None:
+ def __init__(
+ self,
+ max_epochs: int,
+ callbacks: List[Type[Callback]],
+ transformer_model: bool = False,
+ max_norm: float = 0.0,
+ ) -> None:
"""Initialization of the Trainer.
Args:
max_epochs (int): The maximum number of epochs in the training loop.
callbacks (CallbackList): List of callbacks to be called.
+ transformer_model (bool): Transformer model flag; modifies the input to the model. Defaults to False.
+ max_norm (float): Max norm for gradient clipping. Defaults to 0.0.
"""
# Training arguments.
@@ -43,6 +52,10 @@ class Trainer:
# Flag for setting callbacks.
self.callbacks_configured = False
+ self.transformer_model = transformer_model
+
+ self.max_norm = max_norm
+
# Model placeholders
self.model = None
@@ -97,10 +110,15 @@ class Trainer:
# Forward pass.
# Get the network prediction.
- output = self.model.forward(data)
+ if self.transformer_model:
+ output = self.model.network.forward(data, targets[:, :-1])
+ output = rearrange(output, "b t v -> (b t) v")
+ targets = rearrange(targets[:, 1:], "b t -> (b t)").long()
+ else:
+ output = self.model.forward(data)
# Compute the loss.
- loss = self.model.loss_fn(output, targets)
+ loss = self.model.criterion(output, targets)
# Backward pass.
# Clear the previous gradients.
@@ -110,6 +128,11 @@ class Trainer:
# Compute the gradients.
loss.backward()
+ if self.max_norm > 0:
+ torch.nn.utils.clip_grad_norm_(
+ self.model.network.parameters(), self.max_norm
+ )
+
# Perform updates using calculated gradients.
self.model.optimizer.step()
@@ -148,10 +171,15 @@ class Trainer:
# Forward pass.
# Get the network prediction.
# Use SWA if available and using test dataset.
- output = self.model.forward(data)
+ if self.transformer_model:
+ output = self.model.network.forward(data, targets[:, :-1])
+ output = rearrange(output, "b t v -> (b t) v")
+ targets = rearrange(targets[:, 1:], "b t -> (b t)").long()
+ else:
+ output = self.model.forward(data)
# Compute the loss.
- loss = self.model.loss_fn(output, targets)
+ loss = self.model.criterion(output, targets)
# Compute metrics.
metrics = self.compute_metrics(output, targets, loss, loss_avg)
@@ -237,6 +265,8 @@ class Trainer:
# Configure callbacks.
self._configure_callbacks()
+ self.callbacks.on_test_begin()
+
self.model.eval()
# Check if SWA network is available.
@@ -252,6 +282,8 @@ class Trainer:
metrics = self.validation_step(batch, samples, loss_avg)
summary.append(metrics)
+ self.callbacks.on_test_end()
+
# Compute mean of all test metrics.
metrics_mean = {
"test_" + metric: np.mean([x[metric] for x in summary])