summaryrefslogtreecommitdiff
path: root/src/training/train.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/training/train.py')
-rw-r--r--src/training/train.py237
1 files changed, 132 insertions, 105 deletions
diff --git a/src/training/train.py b/src/training/train.py
index 8cd5110..3334c2e 100644
--- a/src/training/train.py
+++ b/src/training/train.py
@@ -2,17 +2,19 @@
from pathlib import Path
import time
-from typing import Dict, Optional, Type
+from typing import Dict, List, Optional, Tuple, Type
from loguru import logger
import numpy as np
import torch
from tqdm import tqdm, trange
+from training.callbacks import Callback, CallbackList
from training.util import RunningAverage
import wandb
from text_recognizer.models import Model
+
torch.backends.cudnn.benchmark = True
np.random.seed(4711)
torch.manual_seed(4711)
@@ -22,51 +24,82 @@ torch.cuda.manual_seed(4711)
class Trainer:
"""Trainer for training PyTorch models."""
- # TODO implement wandb.
- # TODO implement Bayesian parameter search.
-
def __init__(
self,
model: Type[Model],
model_dir: Path,
- epochs: int,
- val_metric: str = "accuracy",
+ train_args: Dict,
+ callbacks: CallbackList,
checkpoint_path: Optional[Path] = None,
- use_wandb: Optional[bool] = False,
) -> None:
"""Initialization of the Trainer.
Args:
model (Type[Model]): A model object.
model_dir (Path): Path to the model directory.
- epochs (int): Number of epochs to train.
- val_metric (str): The validation metric to evaluate the model on. Defaults to "accuracy".
+ train_args (Dict): The training arguments.
+ callbacks (CallbackList): List of callbacks to be called.
checkpoint_path (Optional[Path]): The path to a previously trained model. Defaults to None.
- use_wandb (Optional[bool]): Sync training to wandb.
"""
self.model = model
self.model_dir = model_dir
- self.epochs = epochs
self.checkpoint_path = checkpoint_path
- self.start_epoch = 0
+ self.start_epoch = 1
+ self.epochs = train_args["epochs"] + self.start_epoch
+ self.callbacks = callbacks
if self.checkpoint_path is not None:
- self.start_epoch = self.model.load_checkpoint(self.checkpoint_path)
-
- if use_wandb:
- # TODO implement wandb logging.
- pass
-
- self.val_metric = val_metric
- self.best_val_metric = 0.0
+ self.start_epoch = self.model.load_checkpoint(self.checkpoint_path) + 1
# Parse the name of the experiment.
experiment_dir = str(self.model_dir.parents[1]).split("/")
self.experiment_name = experiment_dir[-2] + "/" + experiment_dir[-1]
+ def training_step(
+ self,
+ batch: int,
+ samples: Tuple[torch.Tensor, torch.Tensor],
+ loss_avg: Type[RunningAverage],
+ ) -> Dict:
+ """Performs the training step."""
+ # Pass the tensor to the device for computation.
+ data, targets = samples
+ data, targets = (
+ data.to(self.model.device),
+ targets.to(self.model.device),
+ )
+
+ # Forward pass.
+ # Get the network prediction.
+ output = self.model.network(data)
+
+ # Compute the loss.
+ loss = self.model.criterion(output, targets)
+
+ # Backward pass.
+ # Clear the previous gradients.
+ self.model.optimizer.zero_grad()
+
+ # Compute the gradients.
+ loss.backward()
+
+ # Perform updates using calculated gradients.
+ self.model.optimizer.step()
+
+ # Compute metrics.
+ loss_avg.update(loss.item())
+ output = output.data.cpu()
+ targets = targets.data.cpu()
+ metrics = {
+ metric: self.model.metrics[metric](output, targets)
+ for metric in self.model.metrics
+ }
+ metrics["loss"] = loss_avg()
+ return metrics
+
def train(self) -> None:
- """Training loop."""
+ """Runs the training loop for one epoch."""
# Set model to traning mode.
self.model.train()
@@ -79,57 +112,54 @@ class Trainer:
total=len(data_loader),
leave=False,
unit="step",
- bar_format="{n_fmt}/{total_fmt} |{bar:20}| {remaining} {rate_inv_fmt}{postfix}",
+ bar_format="{n_fmt}/{total_fmt} |{bar:30}| {remaining} {rate_inv_fmt}{postfix}",
) as t:
- for data, targets in data_loader:
+ for batch, samples in enumerate(data_loader):
+ self.callbacks.on_train_batch_begin(batch)
- data, targets = (
- data.to(self.model.device),
- targets.to(self.model.device),
- )
+ metrics = self.training_step(batch, samples, loss_avg)
- # Forward pass.
- # Get the network prediction.
- output = self.model.network(data)
-
- # Compute the loss.
- loss = self.model.criterion(output, targets)
-
- # Backward pass.
- # Clear the previous gradients.
- self.model.optimizer.zero_grad()
-
- # Compute the gradients.
- loss.backward()
-
- # Perform updates using calculated gradients.
- self.model.optimizer.step()
-
- # Compute metrics.
- loss_avg.update(loss.item())
- output = output.data.cpu()
- targets = targets.data.cpu()
- metrics = {
- metric: self.model.metrics[metric](output, targets)
- for metric in self.model.metrics
- }
- metrics["loss"] = loss_avg()
+ self.callbacks.on_train_batch_end(batch, logs=metrics)
# Update Tqdm progress bar.
t.set_postfix(**metrics)
t.update()
- # If the model has a learning rate scheduler, compute a step.
- if self.model.lr_scheduler is not None:
- self.model.lr_scheduler.step()
-
- def validate(self) -> Dict:
- """Evaluation loop.
+ def validation_step(
+ self,
+ batch: int,
+ samples: Tuple[torch.Tensor, torch.Tensor],
+ loss_avg: Type[RunningAverage],
+ ) -> Dict:
+ """Performs the validation step."""
+ # Pass the tensor to the device for computation.
+ data, targets = samples
+ data, targets = (
+ data.to(self.model.device),
+ targets.to(self.model.device),
+ )
+
+ # Forward pass.
+ # Get the network prediction.
+ output = self.model.network(data)
+
+ # Compute the loss.
+ loss = self.model.criterion(output, targets)
+
+ # Compute metrics.
+ loss_avg.update(loss.item())
+ output = output.data.cpu()
+ targets = targets.data.cpu()
+ metrics = {
+ metric: self.model.metrics[metric](output, targets)
+ for metric in self.model.metrics
+ }
+ metrics["loss"] = loss.item()
- Returns:
- Dict: A dictionary of evaluation metrics.
+ return metrics
- """
+ def validate(self, epoch: Optional[int] = None) -> Dict:
+ """Runs the validation loop for one epoch."""
# Set model to eval mode.
self.model.eval()
@@ -146,44 +176,37 @@ class Trainer:
total=len(data_loader),
leave=False,
unit="step",
- bar_format="{n_fmt}/{total_fmt} |{bar:20}| {remaining} {rate_inv_fmt}{postfix}",
+ bar_format="{n_fmt}/{total_fmt} |{bar:30}| {remaining} {rate_inv_fmt}{postfix}",
) as t:
- for data, targets in data_loader:
- data, targets = (
- data.to(self.model.device),
- targets.to(self.model.device),
- )
-
- with torch.no_grad():
- # Forward pass.
- # Get the network prediction.
- output = self.model.network(data)
-
- # Compute the loss.
- loss = self.model.criterion(output, targets)
-
- # Compute metrics.
- loss_avg.update(loss.item())
- output = output.data.cpu()
- targets = targets.data.cpu()
- metrics = {
- metric: self.model.metrics[metric](output, targets)
- for metric in self.model.metrics
- }
- metrics["loss"] = loss.item()
-
- summary.append(metrics)
+ with torch.no_grad():
+ for batch, samples in enumerate(data_loader):
+ self.callbacks.on_validation_batch_begin(batch)
- # Update Tqdm progress bar.
- t.set_postfix(**metrics)
- t.update()
+ metrics = self.validation_step(batch, samples, loss_avg)
+
+ self.callbacks.on_validation_batch_end(batch, logs=metrics)
+
+ summary.append(metrics)
+
+ # Update Tqdm progress bar.
+ t.set_postfix(**metrics)
+ t.update()
# Compute mean of all metrics.
metrics_mean = {
- metric: np.mean([x[metric] for x in summary]) for metric in summary[0]
+ "val_" + metric: np.mean([x[metric] for x in summary])
+ for metric in summary[0]
}
- metrics_str = " - ".join(f"{k}: {v}" for k, v in metrics_mean.items())
- logger.debug(metrics_str)
+ if epoch:
+ logger.debug(
+ f"Validation metrics at epoch {epoch} - "
+ + " - ".join(f"{k}: {v:.4f}" for k, v in metrics_mean.items())
+ )
+ else:
+ logger.debug(
+ "Validation metrics - "
+ + " - ".join(f"{k}: {v:.4f}" for k, v in metrics_mean.items())
+ )
return metrics_mean
@@ -192,31 +215,35 @@ class Trainer:
logger.debug(f"Running an experiment called {self.experiment_name}.")
t_start = time.time()
+
+ self.callbacks.on_fit_begin()
+
+ # TODO: fix progress bar as callback.
# Run the training loop.
for epoch in trange(
+ self.start_epoch,
self.epochs,
- initial=self.start_epoch,
leave=False,
- bar_format="{desc}: {n_fmt}/{total_fmt} |{bar:10}| {remaining}{postfix}",
+ bar_format="{desc}: {n_fmt}/{total_fmt} |{bar:30}| {remaining}{postfix}",
desc="Epoch",
):
+ self.callbacks.on_epoch_begin(epoch)
+
# Perform one training pass over the training set.
self.train()
# Evaluate the model on the validation set.
- val_metrics = self.validate()
+ val_metrics = self.validate(epoch)
- # The validation metric to evaluate the model on, e.g. accuracy.
- val_metric = val_metrics[self.val_metric]
- is_best = val_metric >= self.best_val_metric
- self.best_val_metric = val_metric if is_best else self.best_val_metric
- # Save checkpoint.
- self.model.save_checkpoint(self.model_dir, is_best, epoch, self.val_metric)
+ self.callbacks.on_epoch_end(epoch, logs=val_metrics)
- if self.start_epoch > 0 and epoch + self.start_epoch == self.epochs:
- logger.debug(f"Trained the model for {self.epochs} number of epochs.")
+ if self.model.stop_training:
break
+ # Calculate the total training time.
t_end = time.time()
t_training = t_end - t_start
+
+ self.callbacks.on_fit_end()
+
logger.info(f"Training took {t_training:.2f} s.")