summaryrefslogtreecommitdiff
path: root/src/training/trainer
diff options
context:
space:
mode:
Diffstat (limited to 'src/training/trainer')
-rw-r--r--src/training/trainer/callbacks/lr_schedulers.py5
-rw-r--r--src/training/trainer/callbacks/wandb_callbacks.py2
-rw-r--r--src/training/trainer/population_based_training/__init__.py1
-rw-r--r--src/training/trainer/population_based_training/population_based_training.py1
-rw-r--r--src/training/trainer/train.py9
5 files changed, 14 insertions, 4 deletions
diff --git a/src/training/trainer/callbacks/lr_schedulers.py b/src/training/trainer/callbacks/lr_schedulers.py
index 907e292..630c434 100644
--- a/src/training/trainer/callbacks/lr_schedulers.py
+++ b/src/training/trainer/callbacks/lr_schedulers.py
@@ -22,7 +22,10 @@ class LRScheduler(Callback):
def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
"""Takes a step at the end of every epoch."""
if self.interval == "epoch":
- self.lr_scheduler.step()
+ if "ReduceLROnPlateau" in self.lr_scheduler.__class__.__name__:
+ self.lr_scheduler.step(logs["val_loss"])
+ else:
+ self.lr_scheduler.step()
def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
"""Takes a step at the end of every training batch."""
diff --git a/src/training/trainer/callbacks/wandb_callbacks.py b/src/training/trainer/callbacks/wandb_callbacks.py
index f24e5cc..1627f17 100644
--- a/src/training/trainer/callbacks/wandb_callbacks.py
+++ b/src/training/trainer/callbacks/wandb_callbacks.py
@@ -111,7 +111,7 @@ class WandbImageLogger(Callback):
]
).rstrip("_")
else:
- ground_truth = self.targets[i]
+ ground_truth = self.model.mapper(int(self.targets[i]))
caption = f"Prediction: {pred} Confidence: {conf:.3f} Ground Truth: {ground_truth}"
images.append(wandb.Image(image, caption=caption))
diff --git a/src/training/trainer/population_based_training/__init__.py b/src/training/trainer/population_based_training/__init__.py
deleted file mode 100644
index 868d739..0000000
--- a/src/training/trainer/population_based_training/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-"""TBC."""
diff --git a/src/training/trainer/population_based_training/population_based_training.py b/src/training/trainer/population_based_training/population_based_training.py
deleted file mode 100644
index 868d739..0000000
--- a/src/training/trainer/population_based_training/population_based_training.py
+++ /dev/null
@@ -1 +0,0 @@
-"""TBC."""
diff --git a/src/training/trainer/train.py b/src/training/trainer/train.py
index fb49103..223d9c6 100644
--- a/src/training/trainer/train.py
+++ b/src/training/trainer/train.py
@@ -33,6 +33,7 @@ class Trainer:
max_epochs: int,
callbacks: List[Type[Callback]],
transformer_model: bool = False,
+ max_norm: float = 0.0,
) -> None:
"""Initialization of the Trainer.
@@ -40,6 +41,7 @@ class Trainer:
max_epochs (int): The maximum number of epochs in the training loop.
callbacks (CallbackList): List of callbacks to be called.
transformer_model (bool): Transformer model flag, modifies the input to the model. Default is False.
+ max_norm (float): Max norm for gradient clipping. Defaults to 0.0.
"""
# Training arguments.
@@ -52,6 +54,8 @@ class Trainer:
self.transformer_model = transformer_model
+ self.max_norm = max_norm
+
# Model placeholders
self.model = None
@@ -124,6 +128,11 @@ class Trainer:
# Compute the gradients.
loss.backward()
+ if self.max_norm > 0:
+ torch.nn.utils.clip_grad_norm_(
+ self.model.network.parameters(), self.max_norm
+ )
+
# Perform updates using calculated gradients.
self.model.optimizer.step()