summary refs log tree commit diff
path: root/src/training/trainer/train.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/training/trainer/train.py')
-rw-r--r-- src/training/trainer/train.py 170
1 files changed, 108 insertions, 62 deletions
diff --git a/src/training/trainer/train.py b/src/training/trainer/train.py
index a75ae8f..b240157 100644
--- a/src/training/trainer/train.py
+++ b/src/training/trainer/train.py
@@ -8,8 +8,9 @@ from loguru import logger
import numpy as np
import torch
from torch import Tensor
+from torch.optim.swa_utils import update_bn
from training.trainer.callbacks import Callback, CallbackList
-from training.trainer.util import RunningAverage
+from training.trainer.util import log_val_metric, RunningAverage
import wandb
from text_recognizer.models import Model
@@ -24,37 +25,55 @@ torch.cuda.manual_seed(4711)
class Trainer:
"""Trainer for training PyTorch models."""
- def __init__(
- self,
- model: Type[Model],
- model_dir: Path,
- train_args: Dict,
- callbacks: CallbackList,
- checkpoint_path: Optional[Path] = None,
- ) -> None:
+ # TODO: add a proper teardown?
+
+ def __init__(self, max_epochs: int, callbacks: List[Type[Callback]],) -> None:
"""Initialization of the Trainer.
Args:
- model (Type[Model]): A model object.
- model_dir (Path): Path to the model directory.
- train_args (Dict): The training arguments.
+ max_epochs (int): The maximum number of epochs in the training loop.
callbacks (CallbackList): List of callbacks to be called.
- checkpoint_path (Optional[Path]): The path to a previously trained model. Defaults to None.
"""
- self.model = model
- self.model_dir = model_dir
- self.checkpoint_path = checkpoint_path
+ # Training arguments: fit() iterates epochs from start_epoch through max_epochs inclusive.
self.start_epoch = 1
- self.epochs = train_args["epochs"]
+ self.max_epochs = max_epochs
+ # NOTE(review): the docstring names the type CallbackList, but the signature takes
+ # List[Type[Callback]]; the list is wrapped in a CallbackList by _configure_callbacks.
self.callbacks = callbacks
- if self.checkpoint_path is not None:
- self.start_epoch = self.model.load_checkpoint(self.checkpoint_path)
+ # Flag checked by _configure_callbacks before it wraps self.callbacks in a CallbackList.
+ self.callbacks_configured = False
+
+ # Model placeholder; fit() and test() assign the model they are given.
+ self.model = None
+
+ def _configure_callbacks(self) -> None:
+     """Wrap the raw callbacks in a CallbackList bound to the current model.
+
+     The wrapping happens at most once per Trainer: later calls are no-ops.
+     """
+     if self.callbacks_configured:
+         return
+     # Instantiate a CallbackList.
+     self.callbacks = CallbackList(self.model, self.callbacks)
+     # Record that the wrapping happened; without this, a second call (for
+     # example fit() followed by test()) would wrap the CallbackList in
+     # another CallbackList.
+     self.callbacks_configured = True
+
+ def compute_metrics(
+     self,
+     output: Tensor,
+     targets: Tensor,
+     loss: Tensor,
+     loss_avg: Type[RunningAverage],
+ ) -> Dict:
+     """Return the model's metrics for one batch together with the batch loss.
+
+     The scalar loss is also fed into the running average `loss_avg`.
+     """
+     # Convert the loss tensor to a plain Python float and record it.
+     batch_loss = loss.detach().float().item()
+     loss_avg.update(batch_loss)
+
+     # Detach so the metric functions receive tensors without autograd history.
+     predictions = output.detach()
+     labels = targets.detach()
+
+     # A model may define no metrics; then only the loss is reported.
+     metrics = {}
+     if self.model.metrics is not None:
+         for name in self.model.metrics:
+             metrics[name] = self.model.metrics[name](predictions, labels)
+     metrics["loss"] = batch_loss
- # Parse the name of the experiment.
- experiment_dir = str(self.model_dir.parents[1]).split("/")
- self.experiment_name = experiment_dir[-2] + "/" + experiment_dir[-1]
+     return metrics
def training_step(
self,
@@ -75,11 +94,12 @@ class Trainer:
output = self.model.network(data)
# Compute the loss.
- loss = self.model.criterion(output, targets)
+ loss = self.model.loss_fn(output, targets)
# Backward pass.
# Clear the previous gradients.
- self.model.optimizer.zero_grad()
+ for p in self.model.network.parameters():
+ p.grad = None
# Compute the gradients.
loss.backward()
@@ -87,15 +107,8 @@ class Trainer:
# Perform updates using calculated gradients.
self.model.optimizer.step()
- # Compute metrics.
- loss_avg.update(loss.item())
- output = output.data.cpu()
- targets = targets.data.cpu()
- metrics = {
- metric: self.model.metrics[metric](output, targets)
- for metric in self.model.metrics
- }
- metrics["loss"] = loss_avg()
+ metrics = self.compute_metrics(output, targets, loss, loss_avg)
+
return metrics
def train(self) -> None:
@@ -106,9 +119,7 @@ class Trainer:
# Running average for the loss.
loss_avg = RunningAverage()
- data_loader = self.model.data_loaders["train"]
-
- for batch, samples in enumerate(data_loader):
+ for batch, samples in enumerate(self.model.train_dataloader()):
self.callbacks.on_train_batch_begin(batch)
metrics = self.training_step(batch, samples, loss_avg)
self.callbacks.on_train_batch_end(batch, logs=metrics)
@@ -119,6 +130,7 @@ class Trainer:
batch: int,
samples: Tuple[Tensor, Tensor],
loss_avg: Type[RunningAverage],
+ use_swa: bool = False,
) -> Dict:
"""Performs the validation step."""
# Pass the tensor to the device for computation.
@@ -130,44 +142,32 @@ class Trainer:
# Forward pass.
# Get the network prediction.
- output = self.model.network(data)
+ # Use the SWA network only when it was requested AND actually exists;
+ # the previous `is None` check was inverted, so it called None(data) when
+ # SWA was missing and never used the SWA network when it was present.
+ if use_swa and self.model.swa_network is not None:
+     output = self.model.swa_network(data)
+ else:
+     output = self.model.network(data)
# Compute the loss.
- loss = self.model.criterion(output, targets)
+ loss = self.model.loss_fn(output, targets)
# Compute metrics.
- loss_avg.update(loss.item())
- output = output.data.cpu()
- targets = targets.data.cpu()
- metrics = {
- metric: self.model.metrics[metric](output, targets)
- for metric in self.model.metrics
- }
- metrics["loss"] = loss.item()
+ metrics = self.compute_metrics(output, targets, loss, loss_avg)
return metrics
- def _log_val_metric(self, metrics_mean: Dict, epoch: Optional[int] = None) -> None:
- log_str = "Validation metrics " + (f"at epoch {epoch} - " if epoch else " - ")
- logger.debug(
- log_str + " - ".join(f"{k}: {v:.4f}" for k, v in metrics_mean.items())
- )
-
- def validate(self, epoch: Optional[int] = None) -> Dict:
+ def validate(self) -> Dict:
"""Runs the validation loop for one epoch."""
# Set model to eval mode.
self.model.eval()
# Running average for the loss.
- data_loader = self.model.data_loaders["val"]
-
- # Running average for the loss.
loss_avg = RunningAverage()
# Summary for the current eval loop.
summary = []
- for batch, samples in enumerate(data_loader):
+ for batch, samples in enumerate(self.model.val_dataloader()):
self.callbacks.on_validation_batch_begin(batch)
metrics = self.validation_step(batch, samples, loss_avg)
self.callbacks.on_validation_batch_end(batch, logs=metrics)
@@ -178,14 +178,19 @@ class Trainer:
"val_" + metric: np.mean([x[metric] for x in summary])
for metric in summary[0]
}
- self._log_val_metric(metrics_mean, epoch)
return metrics_mean
- def fit(self) -> None:
+ def fit(self, model: Type[Model]) -> None:
"""Runs the training and evaluation loop."""
- logger.debug(f"Running an experiment called {self.experiment_name}.")
+ # Sets model, loads the data, criterion, and optimizers.
+ self.model = model
+ self.model.prepare_data()
+ self.model.configure_model()
+
+ # Configure callbacks.
+ self._configure_callbacks()
# Set start time.
t_start = time.time()
@@ -193,14 +198,15 @@ class Trainer:
self.callbacks.on_fit_begin()
# Run the training loop.
- for epoch in range(self.start_epoch, self.epochs + 1):
+ for epoch in range(self.start_epoch, self.max_epochs + 1):
self.callbacks.on_epoch_begin(epoch)
# Perform one training pass over the training set.
self.train()
# Evaluate the model on the validation set.
- val_metrics = self.validate(epoch)
+ val_metrics = self.validate()
+ log_val_metric(val_metrics, epoch)
self.callbacks.on_epoch_end(epoch, logs=val_metrics)
@@ -214,3 +220,43 @@ class Trainer:
self.callbacks.on_fit_end()
logger.info(f"Training took {t_training:.2f} s.")
+
+ # "Teardown".
+ self.model = None
+
+ def test(self, model: Type[Model]) -> Dict:
+     """Run inference on test data.
+
+     Args:
+         model (Type[Model]): The model to evaluate. It is attached to the
+             trainer for the duration of the call and detached afterwards.
+
+     Returns:
+         Dict: The mean of every per-batch metric, with keys prefixed by
+             "test_". Empty if the test data loader yields no batches.
+     """
+     # Attach the model, then let it prepare its data and configure itself.
+     self.model = model
+     try:
+         self.model.prepare_data()
+         self.model.configure_model()
+
+         # Configure callbacks.
+         self._configure_callbacks()
+
+         self.model.eval()
+
+         # Check if SWA network is available.
+         use_swa = self.model.swa_network is not None
+
+         # Running average for the loss.
+         loss_avg = RunningAverage()
+
+         # Per-batch metrics for the current test loop.
+         summary = [
+             self.validation_step(batch, samples, loss_avg, use_swa)
+             for batch, samples in enumerate(self.model.test_dataloader())
+         ]
+
+         # Nothing to average when the loader is empty; indexing summary[0]
+         # would raise IndexError.
+         if not summary:
+             return {}
+
+         # Compute mean of all test metrics.
+         return {
+             "test_" + metric: np.mean([x[metric] for x in summary])
+             for metric in summary[0]
+         }
+     finally:
+         # "Teardown": release the model even if a step raised.
+         self.model = None