Diffstat (limited to 'src/training/trainer')
-rw-r--r--  src/training/trainer/callbacks/__init__.py          8
-rw-r--r--  src/training/trainer/callbacks/wandb_callbacks.py  58
-rw-r--r--  src/training/trainer/train.py                       94
3 files changed, 119 insertions, 41 deletions
diff --git a/src/training/trainer/callbacks/__init__.py b/src/training/trainer/callbacks/__init__.py
index 95ec142..80c4177 100644
--- a/src/training/trainer/callbacks/__init__.py
+++ b/src/training/trainer/callbacks/__init__.py
@@ -7,7 +7,12 @@ from .lr_schedulers import (
SWA,
)
from .progress_bar import ProgressBar
-from .wandb_callbacks import WandbCallback, WandbImageLogger, WandbSegmentationLogger
+from .wandb_callbacks import (
+ WandbCallback,
+ WandbImageLogger,
+ WandbReconstructionLogger,
+ WandbSegmentationLogger,
+)
__all__ = [
"Callback",
@@ -17,6 +22,7 @@ __all__ = [
"LRScheduler",
"WandbCallback",
"WandbImageLogger",
+ "WandbReconstructionLogger",
"WandbSegmentationLogger",
"ProgressBar",
"SWA",
diff --git a/src/training/trainer/callbacks/wandb_callbacks.py b/src/training/trainer/callbacks/wandb_callbacks.py
index 20414df..552a4f4 100644
--- a/src/training/trainer/callbacks/wandb_callbacks.py
+++ b/src/training/trainer/callbacks/wandb_callbacks.py
@@ -201,3 +201,61 @@ class WandbSegmentationLogger(Callback):
)
wandb.log({f"{self.caption}": images}, commit=False)
+
+
+class WandbReconstructionLogger(Callback):
+    """Custom W&B callback for logging image reconstructions."""
+
+ def __init__(
+ self, example_indices: Optional[List] = None, num_examples: int = 4,
+ ) -> None:
+        """Initializes the WandbReconstructionLogger.
+
+ Args:
+ example_indices (Optional[List]): Indices for validation images. Defaults to None.
+ num_examples (int): Number of random samples to take if example_indices are not specified. Defaults to 4.
+
+ """
+
+ super().__init__()
+ self.caption = None
+ self.example_indices = example_indices
+ self.test_sample_indices = None
+ self.num_examples = num_examples
+
+ def set_model(self, model: Type[Model]) -> None:
+ """Sets the model and extracts validation images from the dataset."""
+ self.model = model
+ self.caption = "Validation Reconstructions Examples"
+ if self.example_indices is None:
+ self.example_indices = np.random.randint(
+ 0, len(self.model.val_dataset), self.num_examples
+ )
+ self.images = self.model.val_dataset.dataset.data[self.example_indices]
+
+ def on_test_begin(self) -> None:
+ """Get samples from test dataset."""
+ self.caption = "Test Reconstructions Examples"
+ if self.test_sample_indices is None:
+ self.test_sample_indices = np.random.randint(
+ 0, len(self.model.test_dataset), self.num_examples
+ )
+ self.images = self.model.test_dataset.data[self.test_sample_indices]
+
+ def on_test_end(self) -> None:
+ """Log test images."""
+ self.on_epoch_end(0, {})
+
+ def on_epoch_end(self, epoch: int, logs: Dict) -> None:
+        """Log the example images and their reconstructions to W&B."""
+ images = []
+ for image in self.images:
+ reconstructed_image = (
+ self.model.predict_on_image(image).detach().squeeze(0).cpu().numpy()
+ )
+ images.append(image)
+ images.append(reconstructed_image)
+
+ wandb.log(
+ {f"{self.caption}": [wandb.Image(image) for image in images]}, commit=False,
+ )
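
For context, a minimal usage sketch of the new callback. Only the constructor arguments come from the code above; the Trainer/CallbackList wiring described in the comments is assumed from the rest of the repo, not introduced by this diff:

    from training.trainer.callbacks import WandbReconstructionLogger

    # Log four random validation reconstructions per epoch; pass explicit
    # example_indices instead to pin which samples are logged.
    reconstruction_logger = WandbReconstructionLogger(num_examples=4)

    # The CallbackList built by the Trainer is assumed to call set_model(),
    # after which on_epoch_end() / on_test_end() push original and
    # reconstructed image pairs to W&B under the captions set above.
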
diff --git a/src/training/trainer/train.py b/src/training/trainer/train.py
index 40a25da..b770c94 100644
--- a/src/training/trainer/train.py
+++ b/src/training/trainer/train.py
@@ -12,7 +12,7 @@ import torch
from torch import Tensor
from torch.optim.swa_utils import update_bn
from training.trainer.callbacks import Callback, CallbackList, LRScheduler, SWA
-from training.trainer.util import log_val_metric, RunningAverage
+from training.trainer.util import log_val_metric
import wandb
from text_recognizer.models import Model
@@ -30,8 +30,6 @@ warnings.filterwarnings("ignore")
class Trainer:
"""Trainer for training PyTorch models."""
- # TODO: proper add teardown?
-
def __init__(
self,
max_epochs: int,
@@ -46,7 +44,7 @@ class Trainer:
max_epochs (int): The maximum number of epochs in the training loop.
callbacks (CallbackList): List of callbacks to be called.
transformer_model (bool): Transformer model flag, modifies the input to the model. Default is False.
-            max_norm (float): Max norm for gradient cl:ipping. Defaults to 0.0.
+            max_norm (float): Max norm for gradient clipping. Defaults to 0.0.
freeze_backbone (Optional[int]): How many epochs to freeze the backbone for. Used when training
Transformers. Default is None.
@@ -79,35 +77,32 @@ class Trainer:
self.callbacks = CallbackList(self.model, self.callbacks)
def compute_metrics(
- self,
- output: Tensor,
- targets: Tensor,
- loss: Tensor,
- loss_avg: Type[RunningAverage],
+ self, output: Tensor, targets: Tensor, loss: Tensor, batch_size: int
) -> Dict:
"""Computes metrics for output and target pairs."""
# Compute metrics.
loss = loss.detach().float().item()
- loss_avg.update(loss)
output = output.detach()
targets = targets.detach()
if self.model.metrics is not None:
- metrics = {
- metric: self.model.metrics[metric](output, targets)
- for metric in self.model.metrics
- }
+ metrics = {}
+ for metric in self.model.metrics:
+ if metric == "cer" or metric == "wer":
+ metrics[metric] = self.model.metrics[metric](
+ output,
+ targets,
+ batch_size,
+ self.model.mapper(self.model.pad_token),
+ )
+ else:
+ metrics[metric] = self.model.metrics[metric](output, targets)
else:
metrics = {}
metrics["loss"] = loss
return metrics
- def training_step(
- self,
- batch: int,
- samples: Tuple[Tensor, Tensor],
- loss_avg: Type[RunningAverage],
- ) -> Dict:
+ def training_step(self, batch: int, samples: Tuple[Tensor, Tensor],) -> Dict:
"""Performs the training step."""
# Pass the tensor to the device for computation.
data, targets = samples
@@ -116,25 +111,43 @@ class Trainer:
targets.to(self.model.device),
)
+ batch_size = data.shape[0]
+
+        # Placeholder for auxiliary loss.
+ aux_loss = None
+
# Forward pass.
# Get the network prediction.
if self.transformer_model:
if self.freeze_backbone is not None and batch < self.freeze_backbone:
with torch.no_grad():
image_features = self.model.network.extract_image_features(data)
+
+ if isinstance(image_features, Tuple):
+ image_features, _ = image_features
+
output = self.model.network.decode_image_features(
image_features, targets[:, :-1]
)
else:
output = self.model.network.forward(data, targets[:, :-1])
+ if isinstance(output, Tuple):
+ output, aux_loss = output
output = rearrange(output, "b t v -> (b t) v")
targets = rearrange(targets[:, 1:], "b t -> (b t)").long()
else:
output = self.model.forward(data)
+ if isinstance(output, Tuple):
+ output, aux_loss = output
+ targets = data
+
# Compute the loss.
loss = self.model.criterion(output, targets)
+ if aux_loss is not None:
+ loss += aux_loss
+
# Backward pass.
# Clear the previous gradients.
for p in self.model.network.parameters():
@@ -151,7 +164,7 @@ class Trainer:
# Perform updates using calculated gradients.
self.model.optimizer.step()
- metrics = self.compute_metrics(output, targets, loss, loss_avg)
+ metrics = self.compute_metrics(output, targets, loss, batch_size)
return metrics
@@ -160,22 +173,15 @@ class Trainer:
        # Set model to training mode.
self.model.train()
- # Running average for the loss.
- loss_avg = RunningAverage()
-
for batch, samples in enumerate(self.model.train_dataloader()):
self.callbacks.on_train_batch_begin(batch)
- metrics = self.training_step(batch, samples, loss_avg)
+ metrics = self.training_step(batch, samples)
self.callbacks.on_train_batch_end(batch, logs=metrics)
@torch.no_grad()
- def validation_step(
- self,
- batch: int,
- samples: Tuple[Tensor, Tensor],
- loss_avg: Type[RunningAverage],
- ) -> Dict:
+ def validation_step(self, batch: int, samples: Tuple[Tensor, Tensor],) -> Dict:
"""Performs the validation step."""
+
# Pass the tensor to the device for computation.
data, targets = samples
data, targets = (
@@ -183,21 +189,35 @@ class Trainer:
targets.to(self.model.device),
)
+ batch_size = data.shape[0]
+
+        # Placeholder for auxiliary loss.
+ aux_loss = None
+
# Forward pass.
# Get the network prediction.
# Use SWA if available and using test dataset.
if self.transformer_model:
output = self.model.network.forward(data, targets[:, :-1])
+ if isinstance(output, Tuple):
+ output, aux_loss = output
output = rearrange(output, "b t v -> (b t) v")
targets = rearrange(targets[:, 1:], "b t -> (b t)").long()
else:
output = self.model.forward(data)
+ if isinstance(output, Tuple):
+ output, aux_loss = output
+ targets = data
+
# Compute the loss.
loss = self.model.criterion(output, targets)
+ if aux_loss is not None:
+ loss += aux_loss
+
# Compute metrics.
- metrics = self.compute_metrics(output, targets, loss, loss_avg)
+ metrics = self.compute_metrics(output, targets, loss, batch_size)
return metrics
@@ -206,15 +226,12 @@ class Trainer:
# Set model to eval mode.
self.model.eval()
- # Running average for the loss.
- loss_avg = RunningAverage()
-
# Summary for the current eval loop.
summary = []
for batch, samples in enumerate(self.model.val_dataloader()):
self.callbacks.on_validation_batch_begin(batch)
- metrics = self.validation_step(batch, samples, loss_avg)
+ metrics = self.validation_step(batch, samples)
self.callbacks.on_validation_batch_end(batch, logs=metrics)
summary.append(metrics)
@@ -287,14 +304,11 @@ class Trainer:
# Check if SWA network is available.
self.model.use_swa_model()
- # Running average for the loss.
- loss_avg = RunningAverage()
-
# Summary for the current test loop.
summary = []
for batch, samples in enumerate(self.model.test_dataloader()):
- metrics = self.validation_step(batch, samples, loss_avg)
+ metrics = self.validation_step(batch, samples)
summary.append(metrics)
self.callbacks.on_test_end()
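
The tuple handling added to training_step and validation_step assumes a network whose forward() may return either a plain output tensor or an (output, aux_loss) pair, with the auxiliary term added onto the criterion loss. A minimal sketch of a compatible module, a hypothetical toy autoencoder that is not part of this repo, just to illustrate the contract:

    from typing import Tuple

    from torch import Tensor, nn


    class ToyAutoencoder(nn.Module):
        """Returns (reconstruction, aux_loss), matching the unpacking in training_step."""

        def __init__(self, dim: int = 784, latent_dim: int = 32) -> None:
            super().__init__()
            self.encoder = nn.Linear(dim, latent_dim)
            self.decoder = nn.Linear(latent_dim, dim)

        def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
            z = self.encoder(x)
            # Example auxiliary term: a small L2 penalty on the latent code.
            aux_loss = 1e-3 * z.pow(2).mean()
            return self.decoder(z), aux_loss

In the non-transformer branch this pairs with targets = data, so the criterion compares the reconstruction against the input and aux_loss is added before the backward pass, exactly as in the hunks above.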