summaryrefslogtreecommitdiff
path: root/src/training/trainer/train.py
diff options
context:
space:
mode:
authoraktersnurra <gustaf.rydholm@gmail.com>2020-11-08 12:41:04 +0100
committeraktersnurra <gustaf.rydholm@gmail.com>2020-11-08 12:41:04 +0100
commitbeeaef529e7c893a3475fe27edc880e283373725 (patch)
tree59eb72562bf7a5a9470c2586e6280600ad94f1ae /src/training/trainer/train.py
parent4d7713746eb936832e84852e90292936b933e87d (diff)
Trying to get the CNNTransformer to work, but it is hard.
Diffstat (limited to 'src/training/trainer/train.py')
-rw-r--r--src/training/trainer/train.py9
1 files changed, 9 insertions, 0 deletions
diff --git a/src/training/trainer/train.py b/src/training/trainer/train.py
index fb49103..223d9c6 100644
--- a/src/training/trainer/train.py
+++ b/src/training/trainer/train.py
@@ -33,6 +33,7 @@ class Trainer:
max_epochs: int,
callbacks: List[Type[Callback]],
transformer_model: bool = False,
+ max_norm: float = 0.0,
) -> None:
"""Initialization of the Trainer.
@@ -40,6 +41,7 @@ class Trainer:
max_epochs (int): The maximum number of epochs in the training loop.
callbacks (CallbackList): List of callbacks to be called.
transformer_model (bool): Transformer model flag, modifies the input to the model. Default is False.
+ max_norm (float): Max norm for gradient clipping. Defaults to 0.0.
"""
# Training arguments.
@@ -52,6 +54,8 @@ class Trainer:
self.transformer_model = transformer_model
+ self.max_norm = max_norm
+
# Model placeholders
self.model = None
@@ -124,6 +128,11 @@ class Trainer:
# Compute the gradients.
loss.backward()
+ if self.max_norm > 0:
+ torch.nn.utils.clip_grad_norm_(
+ self.model.network.parameters(), self.max_norm
+ )
+
# Perform updates using calculated gradients.
self.model.optimizer.step()