diff options
Diffstat (limited to 'src/training/trainer')
5 files changed, 14 insertions, 4 deletions
diff --git a/src/training/trainer/callbacks/lr_schedulers.py b/src/training/trainer/callbacks/lr_schedulers.py index 907e292..630c434 100644 --- a/src/training/trainer/callbacks/lr_schedulers.py +++ b/src/training/trainer/callbacks/lr_schedulers.py @@ -22,7 +22,10 @@ class LRScheduler(Callback): def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: """Takes a step at the end of every epoch.""" if self.interval == "epoch": - self.lr_scheduler.step() + if "ReduceLROnPlateau" in self.lr_scheduler.__class__.__name__: + self.lr_scheduler.step(logs["val_loss"]) + else: + self.lr_scheduler.step() def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: """Takes a step at the end of every training batch.""" diff --git a/src/training/trainer/callbacks/wandb_callbacks.py b/src/training/trainer/callbacks/wandb_callbacks.py index f24e5cc..1627f17 100644 --- a/src/training/trainer/callbacks/wandb_callbacks.py +++ b/src/training/trainer/callbacks/wandb_callbacks.py @@ -111,7 +111,7 @@ class WandbImageLogger(Callback): ] ).rstrip("_") else: - ground_truth = self.targets[i] + ground_truth = self.model.mapper(int(self.targets[i])) caption = f"Prediction: {pred} Confidence: {conf:.3f} Ground Truth: {ground_truth}" images.append(wandb.Image(image, caption=caption)) diff --git a/src/training/trainer/population_based_training/__init__.py b/src/training/trainer/population_based_training/__init__.py deleted file mode 100644 index 868d739..0000000 --- a/src/training/trainer/population_based_training/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""TBC.""" diff --git a/src/training/trainer/population_based_training/population_based_training.py b/src/training/trainer/population_based_training/population_based_training.py deleted file mode 100644 index 868d739..0000000 --- a/src/training/trainer/population_based_training/population_based_training.py +++ /dev/null @@ -1 +0,0 @@ -"""TBC.""" diff --git a/src/training/trainer/train.py b/src/training/trainer/train.py index fb49103..223d9c6 100644 --- a/src/training/trainer/train.py +++ b/src/training/trainer/train.py @@ -33,6 +33,7 @@ class Trainer: max_epochs: int, callbacks: List[Type[Callback]], transformer_model: bool = False, + max_norm: float = 0.0, ) -> None: """Initialization of the Trainer. @@ -40,6 +41,7 @@ class Trainer: max_epochs (int): The maximum number of epochs in the training loop. callbacks (CallbackList): List of callbacks to be called. transformer_model (bool): Transformer model flag, modifies the input to the model. Default is False. + max_norm (float): Max norm for gradient clipping. Defaults to 0.0. """ # Training arguments. @@ -52,6 +54,8 @@ class Trainer: self.transformer_model = transformer_model + self.max_norm = max_norm + # Model placeholders self.model = None @@ -124,6 +128,11 @@ class Trainer: # Compute the gradients. loss.backward() + if self.max_norm > 0: + torch.nn.utils.clip_grad_norm_( + self.model.network.parameters(), self.max_norm + ) + # Perform updates using calculated gradients. self.model.optimizer.step() |