summaryrefslogtreecommitdiff
path: root/src/training/trainer/train.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/training/trainer/train.py')
-rw-r--r--src/training/trainer/train.py9
1 files changed, 9 insertions, 0 deletions
diff --git a/src/training/trainer/train.py b/src/training/trainer/train.py
index fb49103..223d9c6 100644
--- a/src/training/trainer/train.py
+++ b/src/training/trainer/train.py
@@ -33,6 +33,7 @@ class Trainer:
max_epochs: int,
callbacks: List[Type[Callback]],
transformer_model: bool = False,
+ max_norm: float = 0.0,
) -> None:
"""Initialization of the Trainer.
@@ -40,6 +41,7 @@ class Trainer:
max_epochs (int): The maximum number of epochs in the training loop.
callbacks (CallbackList): List of callbacks to be called.
transformer_model (bool): Transformer model flag, modifies the input to the model. Default is False.
+ max_norm (float): Max norm for gradient clipping. Defaults to 0.0.
"""
# Training arguments.
@@ -52,6 +54,8 @@ class Trainer:
self.transformer_model = transformer_model
+ self.max_norm = max_norm
+
# Model placeholders
self.model = None
@@ -124,6 +128,11 @@ class Trainer:
# Compute the gradients.
loss.backward()
+ if self.max_norm > 0:
+ torch.nn.utils.clip_grad_norm_(
+ self.model.network.parameters(), self.max_norm
+ )
+
# Perform updates using calculated gradients.
self.model.optimizer.step()