diff options
Diffstat (limited to 'src/training/trainer/train.py')
-rw-r--r-- | src/training/trainer/train.py | 22 |
1 files changed, 12 insertions, 10 deletions
diff --git a/src/training/trainer/train.py b/src/training/trainer/train.py index b240157..bd6a491 100644 --- a/src/training/trainer/train.py +++ b/src/training/trainer/train.py @@ -9,7 +9,7 @@ import numpy as np import torch from torch import Tensor from torch.optim.swa_utils import update_bn -from training.trainer.callbacks import Callback, CallbackList +from training.trainer.callbacks import Callback, CallbackList, LRScheduler, SWA from training.trainer.util import log_val_metric, RunningAverage import wandb @@ -47,8 +47,14 @@ class Trainer: self.model = None def _configure_callbacks(self) -> None: + """Instantiate the CallbackList.""" if not self.callbacks_configured: - # Instantiate a CallbackList. + # If learning rate schedulers are present, they need to be added to the callbacks. + if self.model.swa_scheduler is not None: + self.callbacks.append(SWA()) + elif self.model.lr_scheduler is not None: + self.callbacks.append(LRScheduler()) + self.callbacks = CallbackList(self.model, self.callbacks) def compute_metrics( @@ -91,7 +97,7 @@ class Trainer: # Forward pass. # Get the network prediction. - output = self.model.network(data) + output = self.model.forward(data) # Compute the loss. loss = self.model.loss_fn(output, targets) @@ -130,7 +136,6 @@ class Trainer: batch: int, samples: Tuple[Tensor, Tensor], loss_avg: Type[RunningAverage], - use_swa: bool = False, ) -> Dict: """Performs the validation step.""" # Pass the tensor to the device for computation. @@ -143,10 +148,7 @@ class Trainer: # Forward pass. # Get the network prediction. # Use SWA if available and using test dataset. - if use_swa and self.model.swa_network is None: - output = self.model.swa_network(data) - else: - output = self.model.network(data) + output = self.model.forward(data) # Compute the loss. loss = self.model.loss_fn(output, targets) @@ -238,7 +240,7 @@ class Trainer: self.model.eval() # Check if SWA network is available. - use_swa = True if self.model.swa_network is not None else False + self.model.use_swa_model() # Running average for the loss. loss_avg = RunningAverage() @@ -247,7 +249,7 @@ class Trainer: summary = [] for batch, samples in enumerate(self.model.test_dataloader()): - metrics = self.validation_step(batch, samples, loss_avg, use_swa) + metrics = self.validation_step(batch, samples, loss_avg) summary.append(metrics) # Compute mean of all test metrics. |