Diffstat (limited to 'src/training/trainer/train.py')
-rw-r--r--    src/training/trainer/train.py    170
1 file changed, 108 insertions(+), 62 deletions(-)
diff --git a/src/training/trainer/train.py b/src/training/trainer/train.py
index a75ae8f..b240157 100644
--- a/src/training/trainer/train.py
+++ b/src/training/trainer/train.py
@@ -8,8 +8,9 @@ from loguru import logger
 import numpy as np
 import torch
 from torch import Tensor
+from torch.optim.swa_utils import update_bn
 from training.trainer.callbacks import Callback, CallbackList
-from training.trainer.util import RunningAverage
+from training.trainer.util import log_val_metric, RunningAverage
 import wandb
 
 from text_recognizer.models import Model
@@ -24,37 +25,55 @@ torch.cuda.manual_seed(4711)
 class Trainer:
     """Trainer for training PyTorch models."""
 
-    def __init__(
-        self,
-        model: Type[Model],
-        model_dir: Path,
-        train_args: Dict,
-        callbacks: CallbackList,
-        checkpoint_path: Optional[Path] = None,
-    ) -> None:
+    # TODO: add a proper teardown?
+
+    def __init__(self, max_epochs: int, callbacks: List[Type[Callback]]) -> None:
         """Initialization of the Trainer.
 
         Args:
-            model (Type[Model]): A model object.
-            model_dir (Path): Path to the model directory.
-            train_args (Dict): The training arguments.
-            callbacks (CallbackList): List of callbacks to be called.
-            checkpoint_path (Optional[Path]): The path to a previously trained model. Defaults to None.
+            max_epochs (int): The maximum number of epochs in the training loop.
+            callbacks (List[Type[Callback]]): List of callbacks to be called.
 
         """
-        self.model = model
-        self.model_dir = model_dir
-        self.checkpoint_path = checkpoint_path
+        # Training arguments.
         self.start_epoch = 1
-        self.epochs = train_args["epochs"]
+        self.max_epochs = max_epochs
         self.callbacks = callbacks
 
-        if self.checkpoint_path is not None:
-            self.start_epoch = self.model.load_checkpoint(self.checkpoint_path)
+        # Flag for setting callbacks.
+        self.callbacks_configured = False
+
+        # Model placeholder.
+        self.model = None
+
+    def _configure_callbacks(self) -> None:
+        if not self.callbacks_configured:
+            # Instantiate a CallbackList.
+            self.callbacks = CallbackList(self.model, self.callbacks)
+            self.callbacks_configured = True
+
+    def compute_metrics(
+        self,
+        output: Tensor,
+        targets: Tensor,
+        loss: Tensor,
+        loss_avg: Type[RunningAverage],
+    ) -> Dict:
+        """Computes metrics for output and target pairs."""
+        # Compute metrics.
+        loss = loss.detach().float().item()
+        loss_avg.update(loss)
+        output = output.detach()
+        targets = targets.detach()
+        if self.model.metrics is not None:
+            metrics = {
+                metric: self.model.metrics[metric](output, targets)
+                for metric in self.model.metrics
+            }
+        else:
+            metrics = {}
+        metrics["loss"] = loss
 
-        # Parse the name of the experiment.
-        experiment_dir = str(self.model_dir.parents[1]).split("/")
-        self.experiment_name = experiment_dir[-2] + "/" + experiment_dir[-1]
+        return metrics
 
     def training_step(
         self,
@@ -75,11 +94,12 @@ class Trainer:
         output = self.model.network(data)
 
         # Compute the loss.
-        loss = self.model.criterion(output, targets)
+        loss = self.model.loss_fn(output, targets)
 
         # Backward pass.
         # Clear the previous gradients.
-        self.model.optimizer.zero_grad()
+        for p in self.model.network.parameters():
+            p.grad = None
 
         # Compute the gradients.
         loss.backward()
@@ -87,15 +107,8 @@ class Trainer:
         # Perform updates using calculated gradients.
         self.model.optimizer.step()
 
-        # Compute metrics.
-        loss_avg.update(loss.item())
-        output = output.data.cpu()
-        targets = targets.data.cpu()
-        metrics = {
-            metric: self.model.metrics[metric](output, targets)
-            for metric in self.model.metrics
-        }
-        metrics["loss"] = loss_avg()
+        metrics = self.compute_metrics(output, targets, loss, loss_avg)
+
         return metrics
 
     def train(self) -> None:
@@ -106,9 +119,7 @@ class Trainer:
         # Running average for the loss.
         loss_avg = RunningAverage()
 
-        data_loader = self.model.data_loaders["train"]
-
-        for batch, samples in enumerate(data_loader):
+        for batch, samples in enumerate(self.model.train_dataloader()):
             self.callbacks.on_train_batch_begin(batch)
             metrics = self.training_step(batch, samples, loss_avg)
             self.callbacks.on_train_batch_end(batch, logs=metrics)
@@ -119,6 +130,7 @@ class Trainer:
         batch: int,
         samples: Tuple[Tensor, Tensor],
         loss_avg: Type[RunningAverage],
+        use_swa: bool = False,
     ) -> Dict:
         """Performs the validation step."""
         # Pass the tensor to the device for computation.
@@ -130,44 +142,32 @@ class Trainer:
 
         # Forward pass.
         # Get the network prediction.
-        output = self.model.network(data)
+        # Use the SWA network when it is available (e.g. on the test set).
+        if use_swa and self.model.swa_network is not None:
+            output = self.model.swa_network(data)
+        else:
+            output = self.model.network(data)
 
         # Compute the loss.
-        loss = self.model.criterion(output, targets)
+        loss = self.model.loss_fn(output, targets)
 
         # Compute metrics.
-        loss_avg.update(loss.item())
-        output = output.data.cpu()
-        targets = targets.data.cpu()
-        metrics = {
-            metric: self.model.metrics[metric](output, targets)
-            for metric in self.model.metrics
-        }
-        metrics["loss"] = loss.item()
+        metrics = self.compute_metrics(output, targets, loss, loss_avg)
 
         return metrics
 
-    def _log_val_metric(self, metrics_mean: Dict, epoch: Optional[int] = None) -> None:
-        log_str = "Validation metrics " + (f"at epoch {epoch} - " if epoch else " - ")
-        logger.debug(
-            log_str + " - ".join(f"{k}: {v:.4f}" for k, v in metrics_mean.items())
-        )
-
-    def validate(self, epoch: Optional[int] = None) -> Dict:
+    def validate(self) -> Dict:
         """Runs the validation loop for one epoch."""
         # Set model to eval mode.
         self.model.eval()
 
         # Running average for the loss.
-        data_loader = self.model.data_loaders["val"]
-
-        # Running average for the loss.
         loss_avg = RunningAverage()
 
         # Summary for the current eval loop.
         summary = []
 
-        for batch, samples in enumerate(data_loader):
+        for batch, samples in enumerate(self.model.val_dataloader()):
             self.callbacks.on_validation_batch_begin(batch)
             metrics = self.validation_step(batch, samples, loss_avg)
             self.callbacks.on_validation_batch_end(batch, logs=metrics)
@@ -178,14 +178,19 @@ class Trainer:
             "val_" + metric: np.mean([x[metric] for x in summary])
             for metric in summary[0]
         }
-        self._log_val_metric(metrics_mean, epoch)
 
         return metrics_mean
 
-    def fit(self) -> None:
+    def fit(self, model: Type[Model]) -> None:
         """Runs the training and evaluation loop."""
-        logger.debug(f"Running an experiment called {self.experiment_name}.")
+        # Set the model and load the data, criterion, and optimizers.
+        self.model = model
+        self.model.prepare_data()
+        self.model.configure_model()
+
+        # Configure callbacks.
+        self._configure_callbacks()
 
         # Set start time.
         t_start = time.time()
@@ -193,14 +198,15 @@ class Trainer:
         self.callbacks.on_fit_begin()
 
         # Run the training loop.
-        for epoch in range(self.start_epoch, self.epochs + 1):
+        for epoch in range(self.start_epoch, self.max_epochs + 1):
            self.callbacks.on_epoch_begin(epoch)
 
             # Perform one training pass over the training set.
             self.train()
 
             # Evaluate the model on the validation set.
-            val_metrics = self.validate(epoch)
+            val_metrics = self.validate()
+            log_val_metric(val_metrics, epoch)
 
             self.callbacks.on_epoch_end(epoch, logs=val_metrics)
@@ -214,3 +220,43 @@ class Trainer:
         self.callbacks.on_fit_end()
 
         logger.info(f"Training took {t_training:.2f} s.")
+
+        # "Teardown".
+        self.model = None
+
+    def test(self, model: Type[Model]) -> Dict:
+        """Run inference on test data."""
+
+        # Set the model and load the data, criterion, and optimizers.
+        self.model = model
+        self.model.prepare_data()
+        self.model.configure_model()
+
+        # Configure callbacks.
+        self._configure_callbacks()
+
+        self.model.eval()
+
+        # Check if the SWA network is available.
+        use_swa = self.model.swa_network is not None
+
+        # Running average for the loss.
+        loss_avg = RunningAverage()
+
+        # Summary for the current test loop.
+        summary = []
+
+        for batch, samples in enumerate(self.model.test_dataloader()):
+            metrics = self.validation_step(batch, samples, loss_avg, use_swa)
+            summary.append(metrics)
+
+        # Compute mean of all test metrics.
+        metrics_mean = {
+            "test_" + metric: np.mean([x[metric] for x in summary])
+            for metric in summary[0]
+        }
+
+        # "Teardown".
+        self.model = None
+
+        return metrics_mean
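
For reference, a minimal usage sketch of the refactored API (not part of this commit). The CharacterModel subclass and EarlyStopping callback are illustrative assumptions; any Model subclass that implements prepare_data(), configure_model(), and the *_dataloader() hooks should work the same way.

# Hypothetical usage of the new Trainer API; CharacterModel and
# EarlyStopping are assumed names, not taken from this diff.
from text_recognizer.models import CharacterModel
from training.trainer.callbacks import EarlyStopping
from training.trainer.train import Trainer

model = CharacterModel()
trainer = Trainer(max_epochs=10, callbacks=[EarlyStopping()])

# fit() now receives the model instead of the Trainer owning it: it calls
# model.prepare_data() and model.configure_model(), wraps the callbacks in
# a CallbackList, runs the epoch loop, and finally clears self.model.
trainer.fit(model)

# test() reuses validation_step() and evaluates with the SWA network
# whenever model.swa_network is set.
test_metrics = trainer.test(model)  # e.g. {"test_loss": ..., "test_accuracy": ...}

Design note: clearing gradients with p.grad = None instead of optimizer.zero_grad() skips writing zeros into the gradient buffers; PyTorch 1.7 and later expose the same behaviour as optimizer.zero_grad(set_to_none=True).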