summaryrefslogtreecommitdiff
path: root/src/training/trainer/train.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/training/trainer/train.py')
-rw-r--r--src/training/trainer/train.py94
1 files changed, 54 insertions, 40 deletions
diff --git a/src/training/trainer/train.py b/src/training/trainer/train.py
index 40a25da..b770c94 100644
--- a/src/training/trainer/train.py
+++ b/src/training/trainer/train.py
@@ -12,7 +12,7 @@ import torch
from torch import Tensor
from torch.optim.swa_utils import update_bn
from training.trainer.callbacks import Callback, CallbackList, LRScheduler, SWA
-from training.trainer.util import log_val_metric, RunningAverage
+from training.trainer.util import log_val_metric
import wandb
from text_recognizer.models import Model
@@ -30,8 +30,6 @@ warnings.filterwarnings("ignore")
class Trainer:
"""Trainer for training PyTorch models."""
- # TODO: proper add teardown?
-
def __init__(
self,
max_epochs: int,
@@ -46,7 +44,7 @@ class Trainer:
max_epochs (int): The maximum number of epochs in the training loop.
callbacks (CallbackList): List of callbacks to be called.
transformer_model (bool): Transformer model flag, modifies the input to the model. Default is False.
- max_norm (float): Max norm for gradient clipping. Defaults to 0.0.
+ max_norm (float): Max norm for gradient cl:ipping. Defaults to 0.0.
freeze_backbone (Optional[int]): How many epochs to freeze the backbone for. Used when training
Transformers. Default is None.
@@ -79,35 +77,32 @@ class Trainer:
self.callbacks = CallbackList(self.model, self.callbacks)
def compute_metrics(
- self,
- output: Tensor,
- targets: Tensor,
- loss: Tensor,
- loss_avg: Type[RunningAverage],
+ self, output: Tensor, targets: Tensor, loss: Tensor, batch_size: int
) -> Dict:
"""Computes metrics for output and target pairs."""
# Compute metrics.
loss = loss.detach().float().item()
- loss_avg.update(loss)
output = output.detach()
targets = targets.detach()
if self.model.metrics is not None:
- metrics = {
- metric: self.model.metrics[metric](output, targets)
- for metric in self.model.metrics
- }
+ metrics = {}
+ for metric in self.model.metrics:
+ if metric == "cer" or metric == "wer":
+ metrics[metric] = self.model.metrics[metric](
+ output,
+ targets,
+ batch_size,
+ self.model.mapper(self.model.pad_token),
+ )
+ else:
+ metrics[metric] = self.model.metrics[metric](output, targets)
else:
metrics = {}
metrics["loss"] = loss
return metrics
- def training_step(
- self,
- batch: int,
- samples: Tuple[Tensor, Tensor],
- loss_avg: Type[RunningAverage],
- ) -> Dict:
+ def training_step(self, batch: int, samples: Tuple[Tensor, Tensor],) -> Dict:
"""Performs the training step."""
# Pass the tensor to the device for computation.
data, targets = samples
@@ -116,25 +111,43 @@ class Trainer:
targets.to(self.model.device),
)
+ batch_size = data.shape[0]
+
+ # Placeholder for uxiliary loss.
+ aux_loss = None
+
# Forward pass.
# Get the network prediction.
if self.transformer_model:
if self.freeze_backbone is not None and batch < self.freeze_backbone:
with torch.no_grad():
image_features = self.model.network.extract_image_features(data)
+
+ if isinstance(image_features, Tuple):
+ image_features, _ = image_features
+
output = self.model.network.decode_image_features(
image_features, targets[:, :-1]
)
else:
output = self.model.network.forward(data, targets[:, :-1])
+ if isinstance(output, Tuple):
+ output, aux_loss = output
output = rearrange(output, "b t v -> (b t) v")
targets = rearrange(targets[:, 1:], "b t -> (b t)").long()
else:
output = self.model.forward(data)
+ if isinstance(output, Tuple):
+ output, aux_loss = output
+ targets = data
+
# Compute the loss.
loss = self.model.criterion(output, targets)
+ if aux_loss is not None:
+ loss += aux_loss
+
# Backward pass.
# Clear the previous gradients.
for p in self.model.network.parameters():
@@ -151,7 +164,7 @@ class Trainer:
# Perform updates using calculated gradients.
self.model.optimizer.step()
- metrics = self.compute_metrics(output, targets, loss, loss_avg)
+ metrics = self.compute_metrics(output, targets, loss, batch_size)
return metrics
@@ -160,22 +173,15 @@ class Trainer:
# Set model to traning mode.
self.model.train()
- # Running average for the loss.
- loss_avg = RunningAverage()
-
for batch, samples in enumerate(self.model.train_dataloader()):
self.callbacks.on_train_batch_begin(batch)
- metrics = self.training_step(batch, samples, loss_avg)
+ metrics = self.training_step(batch, samples)
self.callbacks.on_train_batch_end(batch, logs=metrics)
@torch.no_grad()
- def validation_step(
- self,
- batch: int,
- samples: Tuple[Tensor, Tensor],
- loss_avg: Type[RunningAverage],
- ) -> Dict:
+ def validation_step(self, batch: int, samples: Tuple[Tensor, Tensor],) -> Dict:
"""Performs the validation step."""
+
# Pass the tensor to the device for computation.
data, targets = samples
data, targets = (
@@ -183,21 +189,35 @@ class Trainer:
targets.to(self.model.device),
)
+ batch_size = data.shape[0]
+
+ # Placeholder for uxiliary loss.
+ aux_loss = None
+
# Forward pass.
# Get the network prediction.
# Use SWA if available and using test dataset.
if self.transformer_model:
output = self.model.network.forward(data, targets[:, :-1])
+ if isinstance(output, Tuple):
+ output, aux_loss = output
output = rearrange(output, "b t v -> (b t) v")
targets = rearrange(targets[:, 1:], "b t -> (b t)").long()
else:
output = self.model.forward(data)
+ if isinstance(output, Tuple):
+ output, aux_loss = output
+ targets = data
+
# Compute the loss.
loss = self.model.criterion(output, targets)
+ if aux_loss is not None:
+ loss += aux_loss
+
# Compute metrics.
- metrics = self.compute_metrics(output, targets, loss, loss_avg)
+ metrics = self.compute_metrics(output, targets, loss, batch_size)
return metrics
@@ -206,15 +226,12 @@ class Trainer:
# Set model to eval mode.
self.model.eval()
- # Running average for the loss.
- loss_avg = RunningAverage()
-
# Summary for the current eval loop.
summary = []
for batch, samples in enumerate(self.model.val_dataloader()):
self.callbacks.on_validation_batch_begin(batch)
- metrics = self.validation_step(batch, samples, loss_avg)
+ metrics = self.validation_step(batch, samples)
self.callbacks.on_validation_batch_end(batch, logs=metrics)
summary.append(metrics)
@@ -287,14 +304,11 @@ class Trainer:
# Check if SWA network is available.
self.model.use_swa_model()
- # Running average for the loss.
- loss_avg = RunningAverage()
-
# Summary for the current test loop.
summary = []
for batch, samples in enumerate(self.model.test_dataloader()):
- metrics = self.validation_step(batch, samples, loss_avg)
+ metrics = self.validation_step(batch, samples)
summary.append(metrics)
self.callbacks.on_test_end()