diff options
author | aktersnurra <gustaf.rydholm@gmail.com> | 2020-08-03 23:33:34 +0200 |
---|---|---|
committer | aktersnurra <gustaf.rydholm@gmail.com> | 2020-08-03 23:33:34 +0200 |
commit | 07dd14116fe1d8148fb614b160245287533620fc (patch) | |
tree | 63395d88b17a14ad453c52889fcf541e6cbbdd3e /src/training/train.py | |
parent | 704451318eb6b0b600ab314cb5aabfac82416bda (diff) |
Working Emnist lines dataset.
Diffstat (limited to 'src/training/train.py')
-rw-r--r-- | src/training/train.py | 237 |
1 files changed, 132 insertions, 105 deletions
diff --git a/src/training/train.py b/src/training/train.py index 8cd5110..3334c2e 100644 --- a/src/training/train.py +++ b/src/training/train.py @@ -2,17 +2,19 @@ from pathlib import Path import time -from typing import Dict, Optional, Type +from typing import Dict, List, Optional, Tuple, Type from loguru import logger import numpy as np import torch from tqdm import tqdm, trange +from training.callbacks import Callback, CallbackList from training.util import RunningAverage import wandb from text_recognizer.models import Model + torch.backends.cudnn.benchmark = True np.random.seed(4711) torch.manual_seed(4711) @@ -22,51 +24,82 @@ torch.cuda.manual_seed(4711) class Trainer: """Trainer for training PyTorch models.""" - # TODO implement wandb. - # TODO implement Bayesian parameter search. - def __init__( self, model: Type[Model], model_dir: Path, - epochs: int, - val_metric: str = "accuracy", + train_args: Dict, + callbacks: CallbackList, checkpoint_path: Optional[Path] = None, - use_wandb: Optional[bool] = False, ) -> None: """Initialization of the Trainer. Args: model (Type[Model]): A model object. model_dir (Path): Path to the model directory. - epochs (int): Number of epochs to train. - val_metric (str): The validation metric to evaluate the model on. Defaults to "accuracy". + train_args (Dict): The training arguments. + callbacks (CallbackList): List of callbacks to be called. checkpoint_path (Optional[Path]): The path to a previously trained model. Defaults to None. - use_wandb (Optional[bool]): Sync training to wandb. """ self.model = model self.model_dir = model_dir - self.epochs = epochs self.checkpoint_path = checkpoint_path - self.start_epoch = 0 + self.start_epoch = 1 + self.epochs = train_args["epochs"] + self.start_epoch + self.callbacks = callbacks if self.checkpoint_path is not None: - self.start_epoch = self.model.load_checkpoint(self.checkpoint_path) - - if use_wandb: - # TODO implement wandb logging. - pass - - self.val_metric = val_metric - self.best_val_metric = 0.0 + self.start_epoch = self.model.load_checkpoint(self.checkpoint_path) + 1 # Parse the name of the experiment. experiment_dir = str(self.model_dir.parents[1]).split("/") self.experiment_name = experiment_dir[-2] + "/" + experiment_dir[-1] + def training_step( + self, + batch: int, + samples: Tuple[torch.Tensor, torch.Tensor], + loss_avg: Type[RunningAverage], + ) -> Dict: + """Performs the training step.""" + # Pass the tensor to the device for computation. + data, targets = samples + data, targets = ( + data.to(self.model.device), + targets.to(self.model.device), + ) + + # Forward pass. + # Get the network prediction. + output = self.model.network(data) + + # Compute the loss. + loss = self.model.criterion(output, targets) + + # Backward pass. + # Clear the previous gradients. + self.model.optimizer.zero_grad() + + # Compute the gradients. + loss.backward() + + # Perform updates using calculated gradients. + self.model.optimizer.step() + + # Compute metrics. + loss_avg.update(loss.item()) + output = output.data.cpu() + targets = targets.data.cpu() + metrics = { + metric: self.model.metrics[metric](output, targets) + for metric in self.model.metrics + } + metrics["loss"] = loss_avg() + return metrics + def train(self) -> None: - """Training loop.""" + """Runs the training loop for one epoch.""" # Set model to traning mode. self.model.train() @@ -79,57 +112,54 @@ class Trainer: total=len(data_loader), leave=False, unit="step", - bar_format="{n_fmt}/{total_fmt} |{bar:20}| {remaining} {rate_inv_fmt}{postfix}", + bar_format="{n_fmt}/{total_fmt} |{bar:30}| {remaining} {rate_inv_fmt}{postfix}", ) as t: - for data, targets in data_loader: + for batch, samples in enumerate(data_loader): + self.callbacks.on_train_batch_begin(batch) - data, targets = ( - data.to(self.model.device), - targets.to(self.model.device), - ) + metrics = self.training_step(batch, samples, loss_avg) - # Forward pass. - # Get the network prediction. - output = self.model.network(data) - - # Compute the loss. - loss = self.model.criterion(output, targets) - - # Backward pass. - # Clear the previous gradients. - self.model.optimizer.zero_grad() - - # Compute the gradients. - loss.backward() - - # Perform updates using calculated gradients. - self.model.optimizer.step() - - # Compute metrics. - loss_avg.update(loss.item()) - output = output.data.cpu() - targets = targets.data.cpu() - metrics = { - metric: self.model.metrics[metric](output, targets) - for metric in self.model.metrics - } - metrics["loss"] = loss_avg() + self.callbacks.on_train_batch_end(batch, logs=metrics) # Update Tqdm progress bar. t.set_postfix(**metrics) t.update() - # If the model has a learning rate scheduler, compute a step. - if self.model.lr_scheduler is not None: - self.model.lr_scheduler.step() - - def validate(self) -> Dict: - """Evaluation loop. + def validation_step( + self, + batch: int, + samples: Tuple[torch.Tensor, torch.Tensor], + loss_avg: Type[RunningAverage], + ) -> Dict: + """Performs the validation step.""" + # Pass the tensor to the device for computation. + data, targets = samples + data, targets = ( + data.to(self.model.device), + targets.to(self.model.device), + ) + + # Forward pass. + # Get the network prediction. + output = self.model.network(data) + + # Compute the loss. + loss = self.model.criterion(output, targets) + + # Compute metrics. + loss_avg.update(loss.item()) + output = output.data.cpu() + targets = targets.data.cpu() + metrics = { + metric: self.model.metrics[metric](output, targets) + for metric in self.model.metrics + } + metrics["loss"] = loss.item() - Returns: - Dict: A dictionary of evaluation metrics. + return metrics - """ + def validate(self, epoch: Optional[int] = None) -> Dict: + """Runs the validation loop for one epoch.""" # Set model to eval mode. self.model.eval() @@ -146,44 +176,37 @@ class Trainer: total=len(data_loader), leave=False, unit="step", - bar_format="{n_fmt}/{total_fmt} |{bar:20}| {remaining} {rate_inv_fmt}{postfix}", + bar_format="{n_fmt}/{total_fmt} |{bar:30}| {remaining} {rate_inv_fmt}{postfix}", ) as t: - for data, targets in data_loader: - data, targets = ( - data.to(self.model.device), - targets.to(self.model.device), - ) - - with torch.no_grad(): - # Forward pass. - # Get the network prediction. - output = self.model.network(data) - - # Compute the loss. - loss = self.model.criterion(output, targets) - - # Compute metrics. - loss_avg.update(loss.item()) - output = output.data.cpu() - targets = targets.data.cpu() - metrics = { - metric: self.model.metrics[metric](output, targets) - for metric in self.model.metrics - } - metrics["loss"] = loss.item() - - summary.append(metrics) + with torch.no_grad(): + for batch, samples in enumerate(data_loader): + self.callbacks.on_validation_batch_begin(batch) - # Update Tqdm progress bar. - t.set_postfix(**metrics) - t.update() + metrics = self.validation_step(batch, samples, loss_avg) + + self.callbacks.on_validation_batch_end(batch, logs=metrics) + + summary.append(metrics) + + # Update Tqdm progress bar. + t.set_postfix(**metrics) + t.update() # Compute mean of all metrics. metrics_mean = { - metric: np.mean([x[metric] for x in summary]) for metric in summary[0] + "val_" + metric: np.mean([x[metric] for x in summary]) + for metric in summary[0] } - metrics_str = " - ".join(f"{k}: {v}" for k, v in metrics_mean.items()) - logger.debug(metrics_str) + if epoch: + logger.debug( + f"Validation metrics at epoch {epoch} - " + + " - ".join(f"{k}: {v:.4f}" for k, v in metrics_mean.items()) + ) + else: + logger.debug( + "Validation metrics - " + + " - ".join(f"{k}: {v:.4f}" for k, v in metrics_mean.items()) + ) return metrics_mean @@ -192,31 +215,35 @@ class Trainer: logger.debug(f"Running an experiment called {self.experiment_name}.") t_start = time.time() + + self.callbacks.on_fit_begin() + + # TODO: fix progress bar as callback. # Run the training loop. for epoch in trange( + self.start_epoch, self.epochs, - initial=self.start_epoch, leave=False, - bar_format="{desc}: {n_fmt}/{total_fmt} |{bar:10}| {remaining}{postfix}", + bar_format="{desc}: {n_fmt}/{total_fmt} |{bar:30}| {remaining}{postfix}", desc="Epoch", ): + self.callbacks.on_epoch_begin(epoch) + # Perform one training pass over the training set. self.train() # Evaluate the model on the validation set. - val_metrics = self.validate() + val_metrics = self.validate(epoch) - # The validation metric to evaluate the model on, e.g. accuracy. - val_metric = val_metrics[self.val_metric] - is_best = val_metric >= self.best_val_metric - self.best_val_metric = val_metric if is_best else self.best_val_metric - # Save checkpoint. - self.model.save_checkpoint(self.model_dir, is_best, epoch, self.val_metric) + self.callbacks.on_epoch_end(epoch, logs=val_metrics) - if self.start_epoch > 0 and epoch + self.start_epoch == self.epochs: - logger.debug(f"Trained the model for {self.epochs} number of epochs.") + if self.model.stop_training: break + # Calculate the total training time. t_end = time.time() t_training = t_end - t_start + + self.callbacks.on_fit_end() + logger.info(f"Training took {t_training:.2f} s.") |