summaryrefslogtreecommitdiff
path: root/src/training/trainer/train.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/training/trainer/train.py')
-rw-r--r--src/training/trainer/train.py22
1 files changed, 12 insertions, 10 deletions
diff --git a/src/training/trainer/train.py b/src/training/trainer/train.py
index b240157..bd6a491 100644
--- a/src/training/trainer/train.py
+++ b/src/training/trainer/train.py
@@ -9,7 +9,7 @@ import numpy as np
import torch
from torch import Tensor
from torch.optim.swa_utils import update_bn
-from training.trainer.callbacks import Callback, CallbackList
+from training.trainer.callbacks import Callback, CallbackList, LRScheduler, SWA
from training.trainer.util import log_val_metric, RunningAverage
import wandb
@@ -47,8 +47,14 @@ class Trainer:
self.model = None
def _configure_callbacks(self) -> None:
+ """Instantiate the CallbackList."""
if not self.callbacks_configured:
- # Instantiate a CallbackList.
+ # If learning rate schedulers are present, they need to be added to the callbacks.
+ if self.model.swa_scheduler is not None:
+ self.callbacks.append(SWA())
+ elif self.model.lr_scheduler is not None:
+ self.callbacks.append(LRScheduler())
+
self.callbacks = CallbackList(self.model, self.callbacks)
def compute_metrics(
@@ -91,7 +97,7 @@ class Trainer:
# Forward pass.
# Get the network prediction.
- output = self.model.network(data)
+ output = self.model.forward(data)
# Compute the loss.
loss = self.model.loss_fn(output, targets)
@@ -130,7 +136,6 @@ class Trainer:
batch: int,
samples: Tuple[Tensor, Tensor],
loss_avg: Type[RunningAverage],
- use_swa: bool = False,
) -> Dict:
"""Performs the validation step."""
# Pass the tensor to the device for computation.
@@ -143,10 +148,7 @@ class Trainer:
# Forward pass.
# Get the network prediction.
# Use SWA if available and using test dataset.
- if use_swa and self.model.swa_network is None:
- output = self.model.swa_network(data)
- else:
- output = self.model.network(data)
+ output = self.model.forward(data)
# Compute the loss.
loss = self.model.loss_fn(output, targets)
@@ -238,7 +240,7 @@ class Trainer:
self.model.eval()
# Check if SWA network is available.
- use_swa = True if self.model.swa_network is not None else False
+ self.model.use_swa_model()
# Running average for the loss.
loss_avg = RunningAverage()
@@ -247,7 +249,7 @@ class Trainer:
summary = []
for batch, samples in enumerate(self.model.test_dataloader()):
- metrics = self.validation_step(batch, samples, loss_avg, use_swa)
+ metrics = self.validation_step(batch, samples, loss_avg)
summary.append(metrics)
# Compute mean of all test metrics.