diff options
Diffstat (limited to 'src/training/trainer/train.py')
-rw-r--r-- | src/training/trainer/train.py | 9 |
1 files changed, 9 insertions, 0 deletions
diff --git a/src/training/trainer/train.py b/src/training/trainer/train.py
index fb49103..223d9c6 100644
--- a/src/training/trainer/train.py
+++ b/src/training/trainer/train.py
@@ -33,6 +33,7 @@ class Trainer:
         max_epochs: int,
         callbacks: List[Type[Callback]],
         transformer_model: bool = False,
+        max_norm: float = 0.0,
     ) -> None:
         """Initialization of the Trainer.
 
@@ -40,6 +41,7 @@ class Trainer:
             max_epochs (int): The maximum number of epochs in the training loop.
             callbacks (CallbackList): List of callbacks to be called.
             transformer_model (bool): Transformer model flag, modifies the input to the model. Default is False.
+            max_norm (float): Max norm for gradient clipping. Defaults to 0.0.
         """
 
         # Training arguments.
@@ -52,6 +54,8 @@ class Trainer:
 
         self.transformer_model = transformer_model
 
+        self.max_norm = max_norm
+
         # Model placeholders
         self.model = None
 
@@ -124,6 +128,11 @@ class Trainer:
         # Compute the gradients.
         loss.backward()
 
+        if self.max_norm > 0:
+            torch.nn.utils.clip_grad_norm_(
+                self.model.network.parameters(), self.max_norm
+            )
+
         # Perform updates using calculated gradients.
         self.model.optimizer.step()
 