summaryrefslogtreecommitdiff
path: root/src/training/trainer
diff options
context:
space:
mode:
Diffstat (limited to 'src/training/trainer')
-rw-r--r--src/training/trainer/__init__.py2
-rw-r--r--src/training/trainer/callbacks/__init__.py21
-rw-r--r--src/training/trainer/callbacks/base.py248
-rw-r--r--src/training/trainer/callbacks/early_stopping.py108
-rw-r--r--src/training/trainer/callbacks/lr_schedulers.py97
-rw-r--r--src/training/trainer/callbacks/progress_bar.py61
-rw-r--r--src/training/trainer/callbacks/wandb_callbacks.py93
-rw-r--r--src/training/trainer/population_based_training/__init__.py1
-rw-r--r--src/training/trainer/population_based_training/population_based_training.py1
-rw-r--r--src/training/trainer/train.py216
-rw-r--r--src/training/trainer/util.py19
11 files changed, 867 insertions, 0 deletions
diff --git a/src/training/trainer/__init__.py b/src/training/trainer/__init__.py
new file mode 100644
index 0000000..de41bfb
--- /dev/null
+++ b/src/training/trainer/__init__.py
@@ -0,0 +1,2 @@
+"""Trainer modules."""
+from .train import Trainer
diff --git a/src/training/trainer/callbacks/__init__.py b/src/training/trainer/callbacks/__init__.py
new file mode 100644
index 0000000..5942276
--- /dev/null
+++ b/src/training/trainer/callbacks/__init__.py
@@ -0,0 +1,21 @@
+"""The callback modules used in the training script."""
+from .base import Callback, CallbackList, Checkpoint
+from .early_stopping import EarlyStopping
+from .lr_schedulers import CyclicLR, MultiStepLR, OneCycleLR, ReduceLROnPlateau, StepLR
+from .progress_bar import ProgressBar
+from .wandb_callbacks import WandbCallback, WandbImageLogger
+
+__all__ = [
+ "Callback",
+ "CallbackList",
+ "Checkpoint",
+ "EarlyStopping",
+ "WandbCallback",
+ "WandbImageLogger",
+ "CyclicLR",
+ "MultiStepLR",
+ "OneCycleLR",
+ "ProgressBar",
+ "ReduceLROnPlateau",
+ "StepLR",
+]
diff --git a/src/training/trainer/callbacks/base.py b/src/training/trainer/callbacks/base.py
new file mode 100644
index 0000000..8df94f3
--- /dev/null
+++ b/src/training/trainer/callbacks/base.py
@@ -0,0 +1,248 @@
+"""Metaclass for callback functions."""
+
+from enum import Enum
+from typing import Callable, Dict, List, Optional, Type, Union
+
+from loguru import logger
+import numpy as np
+import torch
+
+from text_recognizer.models import Model
+
+
+class ModeKeys:
+ """Mode keys for CallbackList."""
+
+ TRAIN = "train"
+ VALIDATION = "validation"
+
+
+class Callback:
+ """Metaclass for callbacks used in training."""
+
+ def __init__(self) -> None:
+ """Initializes the Callback instance."""
+ self.model = None
+
+ def set_model(self, model: Type[Model]) -> None:
+ """Set the model."""
+ self.model = model
+
+ def on_fit_begin(self) -> None:
+ """Called when fit begins."""
+ pass
+
+ def on_fit_end(self) -> None:
+ """Called when fit ends."""
+ pass
+
+ def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None) -> None:
+ """Called at the beginning of an epoch. Only used in training mode."""
+ pass
+
+ def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
+ """Called at the end of an epoch. Only used in training mode."""
+ pass
+
+ def on_train_batch_begin(self, batch: int, logs: Optional[Dict] = None) -> None:
+ """Called at the beginning of an epoch."""
+ pass
+
+ def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
+ """Called at the end of an epoch."""
+ pass
+
+ def on_validation_batch_begin(
+ self, batch: int, logs: Optional[Dict] = None
+ ) -> None:
+ """Called at the beginning of an epoch."""
+ pass
+
+ def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
+ """Called at the end of an epoch."""
+ pass
+
+
+class CallbackList:
+ """Container for abstracting away callback calls."""
+
+ mode_keys = ModeKeys()
+
+ def __init__(self, model: Type[Model], callbacks: List[Callback] = None) -> None:
+ """Container for `Callback` instances.
+
+ This object wraps a list of `Callback` instances and allows them all to be
+ called via a single end point.
+
+ Args:
+ model (Type[Model]): A `Model` instance.
+ callbacks (List[Callback]): List of `Callback` instances. Defaults to None.
+
+ """
+
+ self._callbacks = callbacks or []
+ if model:
+ self.set_model(model)
+
+ def set_model(self, model: Type[Model]) -> None:
+ """Set the model for all callbacks."""
+ self.model = model
+ for callback in self._callbacks:
+ callback.set_model(model=self.model)
+
+ def append(self, callback: Type[Callback]) -> None:
+ """Append new callback to callback list."""
+ self.callbacks.append(callback)
+
+ def on_fit_begin(self) -> None:
+ """Called when fit begins."""
+ for callback in self._callbacks:
+ callback.on_fit_begin()
+
+ def on_fit_end(self) -> None:
+ """Called when fit ends."""
+ for callback in self._callbacks:
+ callback.on_fit_end()
+
+ def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None) -> None:
+ """Called at the beginning of an epoch."""
+ for callback in self._callbacks:
+ callback.on_epoch_begin(epoch, logs)
+
+ def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
+ """Called at the end of an epoch."""
+ for callback in self._callbacks:
+ callback.on_epoch_end(epoch, logs)
+
+ def _call_batch_hook(
+ self, mode: str, hook: str, batch: int, logs: Optional[Dict] = None
+ ) -> None:
+ """Helper function for all batch_{begin | end} methods."""
+ if hook == "begin":
+ self._call_batch_begin_hook(mode, batch, logs)
+ elif hook == "end":
+ self._call_batch_end_hook(mode, batch, logs)
+ else:
+ raise ValueError(f"Unrecognized hook {hook}.")
+
+ def _call_batch_begin_hook(
+ self, mode: str, batch: int, logs: Optional[Dict] = None
+ ) -> None:
+ """Helper function for all `on_*_batch_begin` methods."""
+ hook_name = f"on_{mode}_batch_begin"
+ self._call_batch_hook_helper(hook_name, batch, logs)
+
+ def _call_batch_end_hook(
+ self, mode: str, batch: int, logs: Optional[Dict] = None
+ ) -> None:
+ """Helper function for all `on_*_batch_end` methods."""
+ hook_name = f"on_{mode}_batch_end"
+ self._call_batch_hook_helper(hook_name, batch, logs)
+
+ def _call_batch_hook_helper(
+ self, hook_name: str, batch: int, logs: Optional[Dict] = None
+ ) -> None:
+ """Helper function for `on_*_batch_begin` methods."""
+ for callback in self._callbacks:
+ hook = getattr(callback, hook_name)
+ hook(batch, logs)
+
+ def on_train_batch_begin(self, batch: int, logs: Optional[Dict] = None) -> None:
+ """Called at the beginning of an epoch."""
+ self._call_batch_hook(self.mode_keys.TRAIN, "begin", batch, logs)
+
+ def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
+ """Called at the end of an epoch."""
+ self._call_batch_hook(self.mode_keys.TRAIN, "end", batch, logs)
+
+ def on_validation_batch_begin(
+ self, batch: int, logs: Optional[Dict] = None
+ ) -> None:
+ """Called at the beginning of an epoch."""
+ self._call_batch_hook(self.mode_keys.VALIDATION, "begin", batch, logs)
+
+ def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
+ """Called at the end of an epoch."""
+ self._call_batch_hook(self.mode_keys.VALIDATION, "end", batch, logs)
+
+ def __iter__(self) -> iter:
+ """Iter function for callback list."""
+ return iter(self._callbacks)
+
+
+class Checkpoint(Callback):
+ """Saving model parameters at the end of each epoch."""
+
+ mode_dict = {
+ "min": torch.lt,
+ "max": torch.gt,
+ }
+
+ def __init__(
+ self, monitor: str = "accuracy", mode: str = "auto", min_delta: float = 0.0
+ ) -> None:
+ """Monitors a quantity that will allow us to determine the best model weights.
+
+ Args:
+ monitor (str): Name of the quantity to monitor. Defaults to "accuracy".
+ mode (str): Description of parameter `mode`. Defaults to "auto".
+ min_delta (float): Description of parameter `min_delta`. Defaults to 0.0.
+
+ """
+ super().__init__()
+ self.monitor = monitor
+ self.mode = mode
+ self.min_delta = torch.tensor(min_delta)
+
+ if mode not in ["auto", "min", "max"]:
+ logger.warning(f"Checkpoint mode {mode} is unkown, fallback to auto mode.")
+
+ self.mode = "auto"
+
+ if self.mode == "auto":
+ if "accuracy" in self.monitor:
+ self.mode = "max"
+ else:
+ self.mode = "min"
+ logger.debug(
+ f"Checkpoint mode set to {self.mode} for monitoring {self.monitor}."
+ )
+
+ torch_inf = torch.tensor(np.inf)
+ self.min_delta *= 1 if self.monitor_op == torch.gt else -1
+ self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf
+
+ @property
+ def monitor_op(self) -> float:
+ """Returns the comparison method."""
+ return self.mode_dict[self.mode]
+
+ def on_epoch_end(self, epoch: int, logs: Dict) -> None:
+ """Saves a checkpoint for the network parameters.
+
+ Args:
+ epoch (int): The current epoch.
+ logs (Dict): The log containing the monitored metrics.
+
+ """
+ current = self.get_monitor_value(logs)
+ if current is None:
+ return
+ if self.monitor_op(current - self.min_delta, self.best_score):
+ self.best_score = current
+ is_best = True
+ else:
+ is_best = False
+
+ self.model.save_checkpoint(is_best, epoch, self.monitor)
+
+ def get_monitor_value(self, logs: Dict) -> Union[float, None]:
+ """Extracts the monitored value."""
+ monitor_value = logs.get(self.monitor)
+ if monitor_value is None:
+ logger.warning(
+ f"Checkpoint is conditioned on metric {self.monitor} which is not available. Available"
+ + f"metrics are: {','.join(list(logs.keys()))}"
+ )
+ return None
+ return monitor_value
diff --git a/src/training/trainer/callbacks/early_stopping.py b/src/training/trainer/callbacks/early_stopping.py
new file mode 100644
index 0000000..02b431f
--- /dev/null
+++ b/src/training/trainer/callbacks/early_stopping.py
@@ -0,0 +1,108 @@
+"""Implements Early stopping for PyTorch model."""
+from typing import Dict, Union
+
+from loguru import logger
+import numpy as np
+import torch
+from torch import Tensor
+from training.trainer.callbacks import Callback
+
+
+class EarlyStopping(Callback):
+ """Stops training when a monitored metric stops improving."""
+
+ mode_dict = {
+ "min": torch.lt,
+ "max": torch.gt,
+ }
+
+ def __init__(
+ self,
+ monitor: str = "val_loss",
+ min_delta: float = 0.0,
+ patience: int = 3,
+ mode: str = "auto",
+ ) -> None:
+ """Initializes the EarlyStopping callback.
+
+ Args:
+ monitor (str): Description of parameter `monitor`. Defaults to "val_loss".
+ min_delta (float): Description of parameter `min_delta`. Defaults to 0.0.
+ patience (int): Description of parameter `patience`. Defaults to 3.
+ mode (str): Description of parameter `mode`. Defaults to "auto".
+
+ """
+ super().__init__()
+ self.monitor = monitor
+ self.patience = patience
+ self.min_delta = torch.tensor(min_delta)
+ self.mode = mode
+ self.wait_count = 0
+ self.stopped_epoch = 0
+
+ if mode not in ["auto", "min", "max"]:
+ logger.warning(
+ f"EarlyStopping mode {mode} is unkown, fallback to auto mode."
+ )
+
+ self.mode = "auto"
+
+ if self.mode == "auto":
+ if "accuracy" in self.monitor:
+ self.mode = "max"
+ else:
+ self.mode = "min"
+ logger.debug(
+ f"EarlyStopping mode set to {self.mode} for monitoring {self.monitor}."
+ )
+
+ self.torch_inf = torch.tensor(np.inf)
+ self.min_delta *= 1 if self.monitor_op == torch.gt else -1
+ self.best_score = (
+ self.torch_inf if self.monitor_op == torch.lt else -self.torch_inf
+ )
+
+ @property
+ def monitor_op(self) -> float:
+ """Returns the comparison method."""
+ return self.mode_dict[self.mode]
+
+ def on_fit_begin(self) -> Union[torch.lt, torch.gt]:
+ """Reset the early stopping variables for reuse."""
+ self.wait_count = 0
+ self.stopped_epoch = 0
+ self.best_score = (
+ self.torch_inf if self.monitor_op == torch.lt else -self.torch_inf
+ )
+
+ def on_epoch_end(self, epoch: int, logs: Dict) -> None:
+ """Computes the early stop criterion."""
+ current = self.get_monitor_value(logs)
+ if current is None:
+ return
+ if self.monitor_op(current - self.min_delta, self.best_score):
+ self.best_score = current
+ self.wait_count = 0
+ else:
+ self.wait_count += 1
+ if self.wait_count >= self.patience:
+ self.stopped_epoch = epoch
+ self.model.stop_training = True
+
+ def on_fit_end(self) -> None:
+ """Logs if early stopping was used."""
+ if self.stopped_epoch > 0:
+ logger.info(
+ f"Stopped training at epoch {self.stopped_epoch + 1} with early stopping."
+ )
+
+ def get_monitor_value(self, logs: Dict) -> Union[Tensor, None]:
+ """Extracts the monitor value."""
+ monitor_value = logs.get(self.monitor)
+ if monitor_value is None:
+ logger.warning(
+ f"Early stopping is conditioned on metric {self.monitor} which is not available. Available"
+ + f"metrics are: {','.join(list(logs.keys()))}"
+ )
+ return None
+ return torch.tensor(monitor_value)
diff --git a/src/training/trainer/callbacks/lr_schedulers.py b/src/training/trainer/callbacks/lr_schedulers.py
new file mode 100644
index 0000000..ba2226a
--- /dev/null
+++ b/src/training/trainer/callbacks/lr_schedulers.py
@@ -0,0 +1,97 @@
+"""Callbacks for learning rate schedulers."""
+from typing import Callable, Dict, List, Optional, Type
+
+from training.trainer.callbacks import Callback
+
+from text_recognizer.models import Model
+
+
+class StepLR(Callback):
+ """Callback for StepLR."""
+
+ def __init__(self) -> None:
+ """Initializes the callback."""
+ super().__init__()
+ self.lr_scheduler = None
+
+ def set_model(self, model: Type[Model]) -> None:
+ """Sets the model and lr scheduler."""
+ self.model = model
+ self.lr_scheduler = self.model.lr_scheduler
+
+ def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
+ """Takes a step at the end of every epoch."""
+ self.lr_scheduler.step()
+
+
+class MultiStepLR(Callback):
+ """Callback for MultiStepLR."""
+
+ def __init__(self) -> None:
+ """Initializes the callback."""
+ super().__init__()
+ self.lr_scheduler = None
+
+ def set_model(self, model: Type[Model]) -> None:
+ """Sets the model and lr scheduler."""
+ self.model = model
+ self.lr_scheduler = self.model.lr_scheduler
+
+ def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
+ """Takes a step at the end of every epoch."""
+ self.lr_scheduler.step()
+
+
+class ReduceLROnPlateau(Callback):
+ """Callback for ReduceLROnPlateau."""
+
+ def __init__(self) -> None:
+ """Initializes the callback."""
+ super().__init__()
+ self.lr_scheduler = None
+
+ def set_model(self, model: Type[Model]) -> None:
+ """Sets the model and lr scheduler."""
+ self.model = model
+ self.lr_scheduler = self.model.lr_scheduler
+
+ def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
+ """Takes a step at the end of every epoch."""
+ val_loss = logs["val_loss"]
+ self.lr_scheduler.step(val_loss)
+
+
+class CyclicLR(Callback):
+ """Callback for CyclicLR."""
+
+ def __init__(self) -> None:
+ """Initializes the callback."""
+ super().__init__()
+ self.lr_scheduler = None
+
+ def set_model(self, model: Type[Model]) -> None:
+ """Sets the model and lr scheduler."""
+ self.model = model
+ self.lr_scheduler = self.model.lr_scheduler
+
+ def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
+ """Takes a step at the end of every training batch."""
+ self.lr_scheduler.step()
+
+
+class OneCycleLR(Callback):
+ """Callback for OneCycleLR."""
+
+ def __init__(self) -> None:
+ """Initializes the callback."""
+ super().__init__()
+ self.lr_scheduler = None
+
+ def set_model(self, model: Type[Model]) -> None:
+ """Sets the model and lr scheduler."""
+ self.model = model
+ self.lr_scheduler = self.model.lr_scheduler
+
+ def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
+ """Takes a step at the end of every training batch."""
+ self.lr_scheduler.step()
diff --git a/src/training/trainer/callbacks/progress_bar.py b/src/training/trainer/callbacks/progress_bar.py
new file mode 100644
index 0000000..1970747
--- /dev/null
+++ b/src/training/trainer/callbacks/progress_bar.py
@@ -0,0 +1,61 @@
+"""Progress bar callback for the training loop."""
+from typing import Dict, Optional
+
+from tqdm import tqdm
+from training.trainer.callbacks import Callback
+
+
+class ProgressBar(Callback):
+ """A TQDM progress bar for the training loop."""
+
+ def __init__(self, epochs: int, log_batch_frequency: int = None) -> None:
+ """Initializes the tqdm callback."""
+ self.epochs = epochs
+ self.log_batch_frequency = log_batch_frequency
+ self.progress_bar = None
+ self.val_metrics = {}
+
+ def _configure_progress_bar(self) -> None:
+ """Configures the tqdm progress bar with custom bar format."""
+ self.progress_bar = tqdm(
+ total=len(self.model.data_loaders["train"]),
+ leave=True,
+ unit="step",
+ mininterval=self.log_batch_frequency,
+ bar_format="{desc} |{bar:30}| {n_fmt}/{total_fmt} ETA: {remaining} {rate_fmt}{postfix}",
+ )
+
+ def _key_abbreviations(self, logs: Dict) -> Dict:
+ """Changes the length of keys, so that the progress bar fits better."""
+
+ def rename(key: str) -> str:
+ """Renames accuracy to acc."""
+ return key.replace("accuracy", "acc")
+
+ return {rename(key): value for key, value in logs.items()}
+
+ def on_fit_begin(self) -> None:
+ """Creates a tqdm progress bar."""
+ self._configure_progress_bar()
+
+ def on_epoch_begin(self, epoch: int, logs: Optional[Dict]) -> None:
+ """Updates the description with the current epoch."""
+ self.progress_bar.reset()
+ self.progress_bar.set_description(f"Epoch {epoch}/{self.epochs}")
+
+ def on_epoch_end(self, epoch: int, logs: Dict) -> None:
+ """At the end of each epoch, the validation metrics are updated to the progress bar."""
+ self.val_metrics = logs
+ self.progress_bar.set_postfix(**self._key_abbreviations(logs))
+ self.progress_bar.update()
+
+ def on_train_batch_end(self, batch: int, logs: Dict) -> None:
+ """Updates the progress bar for each training step."""
+ if self.val_metrics:
+ logs.update(self.val_metrics)
+ self.progress_bar.set_postfix(**self._key_abbreviations(logs))
+ self.progress_bar.update()
+
+ def on_fit_end(self) -> None:
+ """Closes the tqdm progress bar."""
+ self.progress_bar.close()
diff --git a/src/training/trainer/callbacks/wandb_callbacks.py b/src/training/trainer/callbacks/wandb_callbacks.py
new file mode 100644
index 0000000..e44c745
--- /dev/null
+++ b/src/training/trainer/callbacks/wandb_callbacks.py
@@ -0,0 +1,93 @@
+"""Callback for W&B."""
+from typing import Callable, Dict, List, Optional, Type
+
+import numpy as np
+from torchvision.transforms import Compose, ToTensor
+from training.trainer.callbacks import Callback
+import wandb
+
+from text_recognizer.datasets import Transpose
+from text_recognizer.models.base import Model
+
+
+class WandbCallback(Callback):
+ """A custom W&B metric logger for the trainer."""
+
+ def __init__(self, log_batch_frequency: int = None) -> None:
+ """Short summary.
+
+ Args:
+ log_batch_frequency (int): If None, metrics will be logged every epoch.
+ If set to an integer, callback will log every metrics every log_batch_frequency.
+
+ """
+ super().__init__()
+ self.log_batch_frequency = log_batch_frequency
+
+ def _on_batch_end(self, batch: int, logs: Dict) -> None:
+ if self.log_batch_frequency and batch % self.log_batch_frequency == 0:
+ wandb.log(logs, commit=True)
+
+ def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
+ """Logs training metrics."""
+ if logs is not None:
+ self._on_batch_end(batch, logs)
+
+ def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
+ """Logs validation metrics."""
+ if logs is not None:
+ self._on_batch_end(batch, logs)
+
+ def on_epoch_end(self, epoch: int, logs: Dict) -> None:
+ """Logs at epoch end."""
+ wandb.log(logs, commit=True)
+
+
+class WandbImageLogger(Callback):
+ """Custom W&B callback for image logging."""
+
+ def __init__(
+ self,
+ example_indices: Optional[List] = None,
+ num_examples: int = 4,
+ transfroms: Optional[Callable] = None,
+ ) -> None:
+ """Initializes the WandbImageLogger with the model to train.
+
+ Args:
+ example_indices (Optional[List]): Indices for validation images. Defaults to None.
+ num_examples (int): Number of random samples to take if example_indices are not specified. Defaults to 4.
+ transfroms (Optional[Callable]): Transforms to use on the validation images, e.g. transpose. Defaults to
+ None.
+
+ """
+
+ super().__init__()
+ self.example_indices = example_indices
+ self.num_examples = num_examples
+ self.transfroms = transfroms
+ if self.transfroms is None:
+ self.transforms = Compose([Transpose()])
+
+ def set_model(self, model: Type[Model]) -> None:
+ """Sets the model and extracts validation images from the dataset."""
+ self.model = model
+ data_loader = self.model.data_loaders["val"]
+ if self.example_indices is None:
+ self.example_indices = np.random.randint(
+ 0, len(data_loader.dataset.data), self.num_examples
+ )
+ self.val_images = data_loader.dataset.data[self.example_indices]
+ self.val_targets = data_loader.dataset.targets[self.example_indices].numpy()
+
+ def on_epoch_end(self, epoch: int, logs: Dict) -> None:
+ """Get network predictions on validation images."""
+ images = []
+ for i, image in enumerate(self.val_images):
+ image = self.transforms(image)
+ pred, conf = self.model.predict_on_image(image)
+ ground_truth = self.model.mapper(int(self.val_targets[i]))
+ caption = f"Prediction: {pred} Confidence: {conf:.3f} Ground Truth: {ground_truth}"
+ images.append(wandb.Image(image, caption=caption))
+
+ wandb.log({"examples": images}, commit=False)
diff --git a/src/training/trainer/population_based_training/__init__.py b/src/training/trainer/population_based_training/__init__.py
new file mode 100644
index 0000000..868d739
--- /dev/null
+++ b/src/training/trainer/population_based_training/__init__.py
@@ -0,0 +1 @@
+"""TBC."""
diff --git a/src/training/trainer/population_based_training/population_based_training.py b/src/training/trainer/population_based_training/population_based_training.py
new file mode 100644
index 0000000..868d739
--- /dev/null
+++ b/src/training/trainer/population_based_training/population_based_training.py
@@ -0,0 +1 @@
+"""TBC."""
diff --git a/src/training/trainer/train.py b/src/training/trainer/train.py
new file mode 100644
index 0000000..a75ae8f
--- /dev/null
+++ b/src/training/trainer/train.py
@@ -0,0 +1,216 @@
+"""Training script for PyTorch models."""
+
+from pathlib import Path
+import time
+from typing import Dict, List, Optional, Tuple, Type
+
+from loguru import logger
+import numpy as np
+import torch
+from torch import Tensor
+from training.trainer.callbacks import Callback, CallbackList
+from training.trainer.util import RunningAverage
+import wandb
+
+from text_recognizer.models import Model
+
+
+torch.backends.cudnn.benchmark = True
+np.random.seed(4711)
+torch.manual_seed(4711)
+torch.cuda.manual_seed(4711)
+
+
+class Trainer:
+ """Trainer for training PyTorch models."""
+
+ def __init__(
+ self,
+ model: Type[Model],
+ model_dir: Path,
+ train_args: Dict,
+ callbacks: CallbackList,
+ checkpoint_path: Optional[Path] = None,
+ ) -> None:
+ """Initialization of the Trainer.
+
+ Args:
+ model (Type[Model]): A model object.
+ model_dir (Path): Path to the model directory.
+ train_args (Dict): The training arguments.
+ callbacks (CallbackList): List of callbacks to be called.
+ checkpoint_path (Optional[Path]): The path to a previously trained model. Defaults to None.
+
+ """
+ self.model = model
+ self.model_dir = model_dir
+ self.checkpoint_path = checkpoint_path
+ self.start_epoch = 1
+ self.epochs = train_args["epochs"]
+ self.callbacks = callbacks
+
+ if self.checkpoint_path is not None:
+ self.start_epoch = self.model.load_checkpoint(self.checkpoint_path)
+
+ # Parse the name of the experiment.
+ experiment_dir = str(self.model_dir.parents[1]).split("/")
+ self.experiment_name = experiment_dir[-2] + "/" + experiment_dir[-1]
+
+ def training_step(
+ self,
+ batch: int,
+ samples: Tuple[Tensor, Tensor],
+ loss_avg: Type[RunningAverage],
+ ) -> Dict:
+ """Performs the training step."""
+ # Pass the tensor to the device for computation.
+ data, targets = samples
+ data, targets = (
+ data.to(self.model.device),
+ targets.to(self.model.device),
+ )
+
+ # Forward pass.
+ # Get the network prediction.
+ output = self.model.network(data)
+
+ # Compute the loss.
+ loss = self.model.criterion(output, targets)
+
+ # Backward pass.
+ # Clear the previous gradients.
+ self.model.optimizer.zero_grad()
+
+ # Compute the gradients.
+ loss.backward()
+
+ # Perform updates using calculated gradients.
+ self.model.optimizer.step()
+
+ # Compute metrics.
+ loss_avg.update(loss.item())
+ output = output.data.cpu()
+ targets = targets.data.cpu()
+ metrics = {
+ metric: self.model.metrics[metric](output, targets)
+ for metric in self.model.metrics
+ }
+ metrics["loss"] = loss_avg()
+ return metrics
+
+ def train(self) -> None:
+ """Runs the training loop for one epoch."""
+ # Set model to traning mode.
+ self.model.train()
+
+ # Running average for the loss.
+ loss_avg = RunningAverage()
+
+ data_loader = self.model.data_loaders["train"]
+
+ for batch, samples in enumerate(data_loader):
+ self.callbacks.on_train_batch_begin(batch)
+ metrics = self.training_step(batch, samples, loss_avg)
+ self.callbacks.on_train_batch_end(batch, logs=metrics)
+
+ @torch.no_grad()
+ def validation_step(
+ self,
+ batch: int,
+ samples: Tuple[Tensor, Tensor],
+ loss_avg: Type[RunningAverage],
+ ) -> Dict:
+ """Performs the validation step."""
+ # Pass the tensor to the device for computation.
+ data, targets = samples
+ data, targets = (
+ data.to(self.model.device),
+ targets.to(self.model.device),
+ )
+
+ # Forward pass.
+ # Get the network prediction.
+ output = self.model.network(data)
+
+ # Compute the loss.
+ loss = self.model.criterion(output, targets)
+
+ # Compute metrics.
+ loss_avg.update(loss.item())
+ output = output.data.cpu()
+ targets = targets.data.cpu()
+ metrics = {
+ metric: self.model.metrics[metric](output, targets)
+ for metric in self.model.metrics
+ }
+ metrics["loss"] = loss.item()
+
+ return metrics
+
+ def _log_val_metric(self, metrics_mean: Dict, epoch: Optional[int] = None) -> None:
+ log_str = "Validation metrics " + (f"at epoch {epoch} - " if epoch else " - ")
+ logger.debug(
+ log_str + " - ".join(f"{k}: {v:.4f}" for k, v in metrics_mean.items())
+ )
+
+ def validate(self, epoch: Optional[int] = None) -> Dict:
+ """Runs the validation loop for one epoch."""
+ # Set model to eval mode.
+ self.model.eval()
+
+ # Running average for the loss.
+ data_loader = self.model.data_loaders["val"]
+
+ # Running average for the loss.
+ loss_avg = RunningAverage()
+
+ # Summary for the current eval loop.
+ summary = []
+
+ for batch, samples in enumerate(data_loader):
+ self.callbacks.on_validation_batch_begin(batch)
+ metrics = self.validation_step(batch, samples, loss_avg)
+ self.callbacks.on_validation_batch_end(batch, logs=metrics)
+ summary.append(metrics)
+
+ # Compute mean of all metrics.
+ metrics_mean = {
+ "val_" + metric: np.mean([x[metric] for x in summary])
+ for metric in summary[0]
+ }
+ self._log_val_metric(metrics_mean, epoch)
+
+ return metrics_mean
+
+ def fit(self) -> None:
+ """Runs the training and evaluation loop."""
+
+ logger.debug(f"Running an experiment called {self.experiment_name}.")
+
+ # Set start time.
+ t_start = time.time()
+
+ self.callbacks.on_fit_begin()
+
+ # Run the training loop.
+ for epoch in range(self.start_epoch, self.epochs + 1):
+ self.callbacks.on_epoch_begin(epoch)
+
+ # Perform one training pass over the training set.
+ self.train()
+
+ # Evaluate the model on the validation set.
+ val_metrics = self.validate(epoch)
+
+ self.callbacks.on_epoch_end(epoch, logs=val_metrics)
+
+ if self.model.stop_training:
+ break
+
+ # Calculate the total training time.
+ t_end = time.time()
+ t_training = t_end - t_start
+
+ self.callbacks.on_fit_end()
+
+ logger.info(f"Training took {t_training:.2f} s.")
diff --git a/src/training/trainer/util.py b/src/training/trainer/util.py
new file mode 100644
index 0000000..132b2dc
--- /dev/null
+++ b/src/training/trainer/util.py
@@ -0,0 +1,19 @@
+"""Utility functions for training neural networks."""
+
+
+class RunningAverage:
+ """Maintains a running average."""
+
+ def __init__(self) -> None:
+ """Initializes the parameters."""
+ self.steps = 0
+ self.total = 0
+
+ def update(self, val: float) -> None:
+ """Updates the parameters."""
+ self.total += val
+ self.steps += 1
+
+ def __call__(self) -> float:
+ """Computes the running average."""
+ return self.total / float(self.steps)