diff options
Diffstat (limited to 'src/training/trainer')
| -rw-r--r-- | src/training/trainer/__init__.py | 2 | ||||
| -rw-r--r-- | src/training/trainer/callbacks/__init__.py | 21 | ||||
| -rw-r--r-- | src/training/trainer/callbacks/base.py | 248 | ||||
| -rw-r--r-- | src/training/trainer/callbacks/early_stopping.py | 108 | ||||
| -rw-r--r-- | src/training/trainer/callbacks/lr_schedulers.py | 97 | ||||
| -rw-r--r-- | src/training/trainer/callbacks/progress_bar.py | 61 | ||||
| -rw-r--r-- | src/training/trainer/callbacks/wandb_callbacks.py | 93 | ||||
| -rw-r--r-- | src/training/trainer/population_based_training/__init__.py | 1 | ||||
| -rw-r--r-- | src/training/trainer/population_based_training/population_based_training.py | 1 | ||||
| -rw-r--r-- | src/training/trainer/train.py | 216 | ||||
| -rw-r--r-- | src/training/trainer/util.py | 19 | 
11 files changed, 867 insertions, 0 deletions
| diff --git a/src/training/trainer/__init__.py b/src/training/trainer/__init__.py new file mode 100644 index 0000000..de41bfb --- /dev/null +++ b/src/training/trainer/__init__.py @@ -0,0 +1,2 @@ +"""Trainer modules.""" +from .train import Trainer diff --git a/src/training/trainer/callbacks/__init__.py b/src/training/trainer/callbacks/__init__.py new file mode 100644 index 0000000..5942276 --- /dev/null +++ b/src/training/trainer/callbacks/__init__.py @@ -0,0 +1,21 @@ +"""The callback modules used in the training script.""" +from .base import Callback, CallbackList, Checkpoint +from .early_stopping import EarlyStopping +from .lr_schedulers import CyclicLR, MultiStepLR, OneCycleLR, ReduceLROnPlateau, StepLR +from .progress_bar import ProgressBar +from .wandb_callbacks import WandbCallback, WandbImageLogger + +__all__ = [ +    "Callback", +    "CallbackList", +    "Checkpoint", +    "EarlyStopping", +    "WandbCallback", +    "WandbImageLogger", +    "CyclicLR", +    "MultiStepLR", +    "OneCycleLR", +    "ProgressBar", +    "ReduceLROnPlateau", +    "StepLR", +] diff --git a/src/training/trainer/callbacks/base.py b/src/training/trainer/callbacks/base.py new file mode 100644 index 0000000..8df94f3 --- /dev/null +++ b/src/training/trainer/callbacks/base.py @@ -0,0 +1,248 @@ +"""Metaclass for callback functions.""" + +from enum import Enum +from typing import Callable, Dict, List, Optional, Type, Union + +from loguru import logger +import numpy as np +import torch + +from text_recognizer.models import Model + + +class ModeKeys: +    """Mode keys for CallbackList.""" + +    TRAIN = "train" +    VALIDATION = "validation" + + +class Callback: +    """Metaclass for callbacks used in training.""" + +    def __init__(self) -> None: +        """Initializes the Callback instance.""" +        self.model = None + +    def set_model(self, model: Type[Model]) -> None: +        """Set the model.""" +        self.model = model + +    def on_fit_begin(self) -> None: +        """Called when fit begins.""" +        pass + +    def on_fit_end(self) -> None: +        """Called when fit ends.""" +        pass + +    def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None) -> None: +        """Called at the beginning of an epoch. Only used in training mode.""" +        pass + +    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: +        """Called at the end of an epoch. Only used in training mode.""" +        pass + +    def on_train_batch_begin(self, batch: int, logs: Optional[Dict] = None) -> None: +        """Called at the beginning of an epoch.""" +        pass + +    def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: +        """Called at the end of an epoch.""" +        pass + +    def on_validation_batch_begin( +        self, batch: int, logs: Optional[Dict] = None +    ) -> None: +        """Called at the beginning of an epoch.""" +        pass + +    def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: +        """Called at the end of an epoch.""" +        pass + + +class CallbackList: +    """Container for abstracting away callback calls.""" + +    mode_keys = ModeKeys() + +    def __init__(self, model: Type[Model], callbacks: List[Callback] = None) -> None: +        """Container for `Callback` instances. + +        This object wraps a list of `Callback` instances and allows them all to be +        called via a single end point. + +        Args: +            model (Type[Model]): A `Model` instance. +            callbacks (List[Callback]): List of `Callback` instances. Defaults to None. + +        """ + +        self._callbacks = callbacks or [] +        if model: +            self.set_model(model) + +    def set_model(self, model: Type[Model]) -> None: +        """Set the model for all callbacks.""" +        self.model = model +        for callback in self._callbacks: +            callback.set_model(model=self.model) + +    def append(self, callback: Type[Callback]) -> None: +        """Append new callback to callback list.""" +        self.callbacks.append(callback) + +    def on_fit_begin(self) -> None: +        """Called when fit begins.""" +        for callback in self._callbacks: +            callback.on_fit_begin() + +    def on_fit_end(self) -> None: +        """Called when fit ends.""" +        for callback in self._callbacks: +            callback.on_fit_end() + +    def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None) -> None: +        """Called at the beginning of an epoch.""" +        for callback in self._callbacks: +            callback.on_epoch_begin(epoch, logs) + +    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: +        """Called at the end of an epoch.""" +        for callback in self._callbacks: +            callback.on_epoch_end(epoch, logs) + +    def _call_batch_hook( +        self, mode: str, hook: str, batch: int, logs: Optional[Dict] = None +    ) -> None: +        """Helper function for all batch_{begin | end} methods.""" +        if hook == "begin": +            self._call_batch_begin_hook(mode, batch, logs) +        elif hook == "end": +            self._call_batch_end_hook(mode, batch, logs) +        else: +            raise ValueError(f"Unrecognized hook {hook}.") + +    def _call_batch_begin_hook( +        self, mode: str, batch: int, logs: Optional[Dict] = None +    ) -> None: +        """Helper function for all `on_*_batch_begin` methods.""" +        hook_name = f"on_{mode}_batch_begin" +        self._call_batch_hook_helper(hook_name, batch, logs) + +    def _call_batch_end_hook( +        self, mode: str, batch: int, logs: Optional[Dict] = None +    ) -> None: +        """Helper function for all `on_*_batch_end` methods.""" +        hook_name = f"on_{mode}_batch_end" +        self._call_batch_hook_helper(hook_name, batch, logs) + +    def _call_batch_hook_helper( +        self, hook_name: str, batch: int, logs: Optional[Dict] = None +    ) -> None: +        """Helper function for `on_*_batch_begin` methods.""" +        for callback in self._callbacks: +            hook = getattr(callback, hook_name) +            hook(batch, logs) + +    def on_train_batch_begin(self, batch: int, logs: Optional[Dict] = None) -> None: +        """Called at the beginning of an epoch.""" +        self._call_batch_hook(self.mode_keys.TRAIN, "begin", batch, logs) + +    def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: +        """Called at the end of an epoch.""" +        self._call_batch_hook(self.mode_keys.TRAIN, "end", batch, logs) + +    def on_validation_batch_begin( +        self, batch: int, logs: Optional[Dict] = None +    ) -> None: +        """Called at the beginning of an epoch.""" +        self._call_batch_hook(self.mode_keys.VALIDATION, "begin", batch, logs) + +    def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: +        """Called at the end of an epoch.""" +        self._call_batch_hook(self.mode_keys.VALIDATION, "end", batch, logs) + +    def __iter__(self) -> iter: +        """Iter function for callback list.""" +        return iter(self._callbacks) + + +class Checkpoint(Callback): +    """Saving model parameters at the end of each epoch.""" + +    mode_dict = { +        "min": torch.lt, +        "max": torch.gt, +    } + +    def __init__( +        self, monitor: str = "accuracy", mode: str = "auto", min_delta: float = 0.0 +    ) -> None: +        """Monitors a quantity that will allow us to determine the best model weights. + +        Args: +            monitor (str): Name of the quantity to monitor. Defaults to "accuracy". +            mode (str): Description of parameter `mode`. Defaults to "auto". +            min_delta (float): Description of parameter `min_delta`. Defaults to 0.0. + +        """ +        super().__init__() +        self.monitor = monitor +        self.mode = mode +        self.min_delta = torch.tensor(min_delta) + +        if mode not in ["auto", "min", "max"]: +            logger.warning(f"Checkpoint mode {mode} is unkown, fallback to auto mode.") + +            self.mode = "auto" + +        if self.mode == "auto": +            if "accuracy" in self.monitor: +                self.mode = "max" +            else: +                self.mode = "min" +            logger.debug( +                f"Checkpoint mode set to {self.mode} for monitoring {self.monitor}." +            ) + +        torch_inf = torch.tensor(np.inf) +        self.min_delta *= 1 if self.monitor_op == torch.gt else -1 +        self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf + +    @property +    def monitor_op(self) -> float: +        """Returns the comparison method.""" +        return self.mode_dict[self.mode] + +    def on_epoch_end(self, epoch: int, logs: Dict) -> None: +        """Saves a checkpoint for the network parameters. + +        Args: +            epoch (int): The current epoch. +            logs (Dict): The log containing the monitored metrics. + +        """ +        current = self.get_monitor_value(logs) +        if current is None: +            return +        if self.monitor_op(current - self.min_delta, self.best_score): +            self.best_score = current +            is_best = True +        else: +            is_best = False + +        self.model.save_checkpoint(is_best, epoch, self.monitor) + +    def get_monitor_value(self, logs: Dict) -> Union[float, None]: +        """Extracts the monitored value.""" +        monitor_value = logs.get(self.monitor) +        if monitor_value is None: +            logger.warning( +                f"Checkpoint is conditioned on metric {self.monitor} which is not available. Available" +                + f"metrics are: {','.join(list(logs.keys()))}" +            ) +            return None +        return monitor_value diff --git a/src/training/trainer/callbacks/early_stopping.py b/src/training/trainer/callbacks/early_stopping.py new file mode 100644 index 0000000..02b431f --- /dev/null +++ b/src/training/trainer/callbacks/early_stopping.py @@ -0,0 +1,108 @@ +"""Implements Early stopping for PyTorch model.""" +from typing import Dict, Union + +from loguru import logger +import numpy as np +import torch +from torch import Tensor +from training.trainer.callbacks import Callback + + +class EarlyStopping(Callback): +    """Stops training when a monitored metric stops improving.""" + +    mode_dict = { +        "min": torch.lt, +        "max": torch.gt, +    } + +    def __init__( +        self, +        monitor: str = "val_loss", +        min_delta: float = 0.0, +        patience: int = 3, +        mode: str = "auto", +    ) -> None: +        """Initializes the EarlyStopping callback. + +        Args: +            monitor (str): Description of parameter `monitor`. Defaults to "val_loss". +            min_delta (float): Description of parameter `min_delta`. Defaults to 0.0. +            patience (int): Description of parameter `patience`. Defaults to 3. +            mode (str): Description of parameter `mode`. Defaults to "auto". + +        """ +        super().__init__() +        self.monitor = monitor +        self.patience = patience +        self.min_delta = torch.tensor(min_delta) +        self.mode = mode +        self.wait_count = 0 +        self.stopped_epoch = 0 + +        if mode not in ["auto", "min", "max"]: +            logger.warning( +                f"EarlyStopping mode {mode} is unkown, fallback to auto mode." +            ) + +            self.mode = "auto" + +        if self.mode == "auto": +            if "accuracy" in self.monitor: +                self.mode = "max" +            else: +                self.mode = "min" +            logger.debug( +                f"EarlyStopping mode set to {self.mode} for monitoring {self.monitor}." +            ) + +        self.torch_inf = torch.tensor(np.inf) +        self.min_delta *= 1 if self.monitor_op == torch.gt else -1 +        self.best_score = ( +            self.torch_inf if self.monitor_op == torch.lt else -self.torch_inf +        ) + +    @property +    def monitor_op(self) -> float: +        """Returns the comparison method.""" +        return self.mode_dict[self.mode] + +    def on_fit_begin(self) -> Union[torch.lt, torch.gt]: +        """Reset the early stopping variables for reuse.""" +        self.wait_count = 0 +        self.stopped_epoch = 0 +        self.best_score = ( +            self.torch_inf if self.monitor_op == torch.lt else -self.torch_inf +        ) + +    def on_epoch_end(self, epoch: int, logs: Dict) -> None: +        """Computes the early stop criterion.""" +        current = self.get_monitor_value(logs) +        if current is None: +            return +        if self.monitor_op(current - self.min_delta, self.best_score): +            self.best_score = current +            self.wait_count = 0 +        else: +            self.wait_count += 1 +            if self.wait_count >= self.patience: +                self.stopped_epoch = epoch +                self.model.stop_training = True + +    def on_fit_end(self) -> None: +        """Logs if early stopping was used.""" +        if self.stopped_epoch > 0: +            logger.info( +                f"Stopped training at epoch {self.stopped_epoch + 1} with early stopping." +            ) + +    def get_monitor_value(self, logs: Dict) -> Union[Tensor, None]: +        """Extracts the monitor value.""" +        monitor_value = logs.get(self.monitor) +        if monitor_value is None: +            logger.warning( +                f"Early stopping is conditioned on metric {self.monitor} which is not available. Available" +                + f"metrics are: {','.join(list(logs.keys()))}" +            ) +            return None +        return torch.tensor(monitor_value) diff --git a/src/training/trainer/callbacks/lr_schedulers.py b/src/training/trainer/callbacks/lr_schedulers.py new file mode 100644 index 0000000..ba2226a --- /dev/null +++ b/src/training/trainer/callbacks/lr_schedulers.py @@ -0,0 +1,97 @@ +"""Callbacks for learning rate schedulers.""" +from typing import Callable, Dict, List, Optional, Type + +from training.trainer.callbacks import Callback + +from text_recognizer.models import Model + + +class StepLR(Callback): +    """Callback for StepLR.""" + +    def __init__(self) -> None: +        """Initializes the callback.""" +        super().__init__() +        self.lr_scheduler = None + +    def set_model(self, model: Type[Model]) -> None: +        """Sets the model and lr scheduler.""" +        self.model = model +        self.lr_scheduler = self.model.lr_scheduler + +    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: +        """Takes a step at the end of every epoch.""" +        self.lr_scheduler.step() + + +class MultiStepLR(Callback): +    """Callback for MultiStepLR.""" + +    def __init__(self) -> None: +        """Initializes the callback.""" +        super().__init__() +        self.lr_scheduler = None + +    def set_model(self, model: Type[Model]) -> None: +        """Sets the model and lr scheduler.""" +        self.model = model +        self.lr_scheduler = self.model.lr_scheduler + +    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: +        """Takes a step at the end of every epoch.""" +        self.lr_scheduler.step() + + +class ReduceLROnPlateau(Callback): +    """Callback for ReduceLROnPlateau.""" + +    def __init__(self) -> None: +        """Initializes the callback.""" +        super().__init__() +        self.lr_scheduler = None + +    def set_model(self, model: Type[Model]) -> None: +        """Sets the model and lr scheduler.""" +        self.model = model +        self.lr_scheduler = self.model.lr_scheduler + +    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: +        """Takes a step at the end of every epoch.""" +        val_loss = logs["val_loss"] +        self.lr_scheduler.step(val_loss) + + +class CyclicLR(Callback): +    """Callback for CyclicLR.""" + +    def __init__(self) -> None: +        """Initializes the callback.""" +        super().__init__() +        self.lr_scheduler = None + +    def set_model(self, model: Type[Model]) -> None: +        """Sets the model and lr scheduler.""" +        self.model = model +        self.lr_scheduler = self.model.lr_scheduler + +    def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: +        """Takes a step at the end of every training batch.""" +        self.lr_scheduler.step() + + +class OneCycleLR(Callback): +    """Callback for OneCycleLR.""" + +    def __init__(self) -> None: +        """Initializes the callback.""" +        super().__init__() +        self.lr_scheduler = None + +    def set_model(self, model: Type[Model]) -> None: +        """Sets the model and lr scheduler.""" +        self.model = model +        self.lr_scheduler = self.model.lr_scheduler + +    def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: +        """Takes a step at the end of every training batch.""" +        self.lr_scheduler.step() diff --git a/src/training/trainer/callbacks/progress_bar.py b/src/training/trainer/callbacks/progress_bar.py new file mode 100644 index 0000000..1970747 --- /dev/null +++ b/src/training/trainer/callbacks/progress_bar.py @@ -0,0 +1,61 @@ +"""Progress bar callback for the training loop.""" +from typing import Dict, Optional + +from tqdm import tqdm +from training.trainer.callbacks import Callback + + +class ProgressBar(Callback): +    """A TQDM progress bar for the training loop.""" + +    def __init__(self, epochs: int, log_batch_frequency: int = None) -> None: +        """Initializes the tqdm callback.""" +        self.epochs = epochs +        self.log_batch_frequency = log_batch_frequency +        self.progress_bar = None +        self.val_metrics = {} + +    def _configure_progress_bar(self) -> None: +        """Configures the tqdm progress bar with custom bar format.""" +        self.progress_bar = tqdm( +            total=len(self.model.data_loaders["train"]), +            leave=True, +            unit="step", +            mininterval=self.log_batch_frequency, +            bar_format="{desc} |{bar:30}| {n_fmt}/{total_fmt} ETA: {remaining} {rate_fmt}{postfix}", +        ) + +    def _key_abbreviations(self, logs: Dict) -> Dict: +        """Changes the length of keys, so that the progress bar fits better.""" + +        def rename(key: str) -> str: +            """Renames accuracy to acc.""" +            return key.replace("accuracy", "acc") + +        return {rename(key): value for key, value in logs.items()} + +    def on_fit_begin(self) -> None: +        """Creates a tqdm progress bar.""" +        self._configure_progress_bar() + +    def on_epoch_begin(self, epoch: int, logs: Optional[Dict]) -> None: +        """Updates the description with the current epoch.""" +        self.progress_bar.reset() +        self.progress_bar.set_description(f"Epoch {epoch}/{self.epochs}") + +    def on_epoch_end(self, epoch: int, logs: Dict) -> None: +        """At the end of each epoch, the validation metrics are updated to the progress bar.""" +        self.val_metrics = logs +        self.progress_bar.set_postfix(**self._key_abbreviations(logs)) +        self.progress_bar.update() + +    def on_train_batch_end(self, batch: int, logs: Dict) -> None: +        """Updates the progress bar for each training step.""" +        if self.val_metrics: +            logs.update(self.val_metrics) +        self.progress_bar.set_postfix(**self._key_abbreviations(logs)) +        self.progress_bar.update() + +    def on_fit_end(self) -> None: +        """Closes the tqdm progress bar.""" +        self.progress_bar.close() diff --git a/src/training/trainer/callbacks/wandb_callbacks.py b/src/training/trainer/callbacks/wandb_callbacks.py new file mode 100644 index 0000000..e44c745 --- /dev/null +++ b/src/training/trainer/callbacks/wandb_callbacks.py @@ -0,0 +1,93 @@ +"""Callback for W&B.""" +from typing import Callable, Dict, List, Optional, Type + +import numpy as np +from torchvision.transforms import Compose, ToTensor +from training.trainer.callbacks import Callback +import wandb + +from text_recognizer.datasets import Transpose +from text_recognizer.models.base import Model + + +class WandbCallback(Callback): +    """A custom W&B metric logger for the trainer.""" + +    def __init__(self, log_batch_frequency: int = None) -> None: +        """Short summary. + +        Args: +            log_batch_frequency (int): If None, metrics will be logged every epoch. +                If set to an integer, callback will log every metrics every log_batch_frequency. + +        """ +        super().__init__() +        self.log_batch_frequency = log_batch_frequency + +    def _on_batch_end(self, batch: int, logs: Dict) -> None: +        if self.log_batch_frequency and batch % self.log_batch_frequency == 0: +            wandb.log(logs, commit=True) + +    def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: +        """Logs training metrics.""" +        if logs is not None: +            self._on_batch_end(batch, logs) + +    def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: +        """Logs validation metrics.""" +        if logs is not None: +            self._on_batch_end(batch, logs) + +    def on_epoch_end(self, epoch: int, logs: Dict) -> None: +        """Logs at epoch end.""" +        wandb.log(logs, commit=True) + + +class WandbImageLogger(Callback): +    """Custom W&B callback for image logging.""" + +    def __init__( +        self, +        example_indices: Optional[List] = None, +        num_examples: int = 4, +        transfroms: Optional[Callable] = None, +    ) -> None: +        """Initializes the WandbImageLogger with the model to train. + +        Args: +            example_indices (Optional[List]): Indices for validation images. Defaults to None. +            num_examples (int): Number of random samples to take if example_indices are not specified. Defaults to 4. +            transfroms (Optional[Callable]): Transforms to use on the validation images, e.g. transpose. Defaults to +                None. + +        """ + +        super().__init__() +        self.example_indices = example_indices +        self.num_examples = num_examples +        self.transfroms = transfroms +        if self.transfroms is None: +            self.transforms = Compose([Transpose()]) + +    def set_model(self, model: Type[Model]) -> None: +        """Sets the model and extracts validation images from the dataset.""" +        self.model = model +        data_loader = self.model.data_loaders["val"] +        if self.example_indices is None: +            self.example_indices = np.random.randint( +                0, len(data_loader.dataset.data), self.num_examples +            ) +        self.val_images = data_loader.dataset.data[self.example_indices] +        self.val_targets = data_loader.dataset.targets[self.example_indices].numpy() + +    def on_epoch_end(self, epoch: int, logs: Dict) -> None: +        """Get network predictions on validation images.""" +        images = [] +        for i, image in enumerate(self.val_images): +            image = self.transforms(image) +            pred, conf = self.model.predict_on_image(image) +            ground_truth = self.model.mapper(int(self.val_targets[i])) +            caption = f"Prediction: {pred} Confidence: {conf:.3f} Ground Truth: {ground_truth}" +            images.append(wandb.Image(image, caption=caption)) + +        wandb.log({"examples": images}, commit=False) diff --git a/src/training/trainer/population_based_training/__init__.py b/src/training/trainer/population_based_training/__init__.py new file mode 100644 index 0000000..868d739 --- /dev/null +++ b/src/training/trainer/population_based_training/__init__.py @@ -0,0 +1 @@ +"""TBC.""" diff --git a/src/training/trainer/population_based_training/population_based_training.py b/src/training/trainer/population_based_training/population_based_training.py new file mode 100644 index 0000000..868d739 --- /dev/null +++ b/src/training/trainer/population_based_training/population_based_training.py @@ -0,0 +1 @@ +"""TBC.""" diff --git a/src/training/trainer/train.py b/src/training/trainer/train.py new file mode 100644 index 0000000..a75ae8f --- /dev/null +++ b/src/training/trainer/train.py @@ -0,0 +1,216 @@ +"""Training script for PyTorch models.""" + +from pathlib import Path +import time +from typing import Dict, List, Optional, Tuple, Type + +from loguru import logger +import numpy as np +import torch +from torch import Tensor +from training.trainer.callbacks import Callback, CallbackList +from training.trainer.util import RunningAverage +import wandb + +from text_recognizer.models import Model + + +torch.backends.cudnn.benchmark = True +np.random.seed(4711) +torch.manual_seed(4711) +torch.cuda.manual_seed(4711) + + +class Trainer: +    """Trainer for training PyTorch models.""" + +    def __init__( +        self, +        model: Type[Model], +        model_dir: Path, +        train_args: Dict, +        callbacks: CallbackList, +        checkpoint_path: Optional[Path] = None, +    ) -> None: +        """Initialization of the Trainer. + +        Args: +            model (Type[Model]): A model object. +            model_dir (Path): Path to the model directory. +            train_args (Dict): The training arguments. +            callbacks (CallbackList): List of callbacks to be called. +            checkpoint_path (Optional[Path]): The path to a previously trained model. Defaults to None. + +        """ +        self.model = model +        self.model_dir = model_dir +        self.checkpoint_path = checkpoint_path +        self.start_epoch = 1 +        self.epochs = train_args["epochs"] +        self.callbacks = callbacks + +        if self.checkpoint_path is not None: +            self.start_epoch = self.model.load_checkpoint(self.checkpoint_path) + +        # Parse the name of the experiment. +        experiment_dir = str(self.model_dir.parents[1]).split("/") +        self.experiment_name = experiment_dir[-2] + "/" + experiment_dir[-1] + +    def training_step( +        self, +        batch: int, +        samples: Tuple[Tensor, Tensor], +        loss_avg: Type[RunningAverage], +    ) -> Dict: +        """Performs the training step.""" +        # Pass the tensor to the device for computation. +        data, targets = samples +        data, targets = ( +            data.to(self.model.device), +            targets.to(self.model.device), +        ) + +        # Forward pass. +        # Get the network prediction. +        output = self.model.network(data) + +        # Compute the loss. +        loss = self.model.criterion(output, targets) + +        # Backward pass. +        # Clear the previous gradients. +        self.model.optimizer.zero_grad() + +        # Compute the gradients. +        loss.backward() + +        # Perform updates using calculated gradients. +        self.model.optimizer.step() + +        # Compute metrics. +        loss_avg.update(loss.item()) +        output = output.data.cpu() +        targets = targets.data.cpu() +        metrics = { +            metric: self.model.metrics[metric](output, targets) +            for metric in self.model.metrics +        } +        metrics["loss"] = loss_avg() +        return metrics + +    def train(self) -> None: +        """Runs the training loop for one epoch.""" +        # Set model to traning mode. +        self.model.train() + +        # Running average for the loss. +        loss_avg = RunningAverage() + +        data_loader = self.model.data_loaders["train"] + +        for batch, samples in enumerate(data_loader): +            self.callbacks.on_train_batch_begin(batch) +            metrics = self.training_step(batch, samples, loss_avg) +            self.callbacks.on_train_batch_end(batch, logs=metrics) + +    @torch.no_grad() +    def validation_step( +        self, +        batch: int, +        samples: Tuple[Tensor, Tensor], +        loss_avg: Type[RunningAverage], +    ) -> Dict: +        """Performs the validation step.""" +        # Pass the tensor to the device for computation. +        data, targets = samples +        data, targets = ( +            data.to(self.model.device), +            targets.to(self.model.device), +        ) + +        # Forward pass. +        # Get the network prediction. +        output = self.model.network(data) + +        # Compute the loss. +        loss = self.model.criterion(output, targets) + +        # Compute metrics. +        loss_avg.update(loss.item()) +        output = output.data.cpu() +        targets = targets.data.cpu() +        metrics = { +            metric: self.model.metrics[metric](output, targets) +            for metric in self.model.metrics +        } +        metrics["loss"] = loss.item() + +        return metrics + +    def _log_val_metric(self, metrics_mean: Dict, epoch: Optional[int] = None) -> None: +        log_str = "Validation metrics " + (f"at epoch {epoch} - " if epoch else " - ") +        logger.debug( +            log_str + " - ".join(f"{k}: {v:.4f}" for k, v in metrics_mean.items()) +        ) + +    def validate(self, epoch: Optional[int] = None) -> Dict: +        """Runs the validation loop for one epoch.""" +        # Set model to eval mode. +        self.model.eval() + +        # Running average for the loss. +        data_loader = self.model.data_loaders["val"] + +        # Running average for the loss. +        loss_avg = RunningAverage() + +        # Summary for the current eval loop. +        summary = [] + +        for batch, samples in enumerate(data_loader): +            self.callbacks.on_validation_batch_begin(batch) +            metrics = self.validation_step(batch, samples, loss_avg) +            self.callbacks.on_validation_batch_end(batch, logs=metrics) +            summary.append(metrics) + +        # Compute mean of all metrics. +        metrics_mean = { +            "val_" + metric: np.mean([x[metric] for x in summary]) +            for metric in summary[0] +        } +        self._log_val_metric(metrics_mean, epoch) + +        return metrics_mean + +    def fit(self) -> None: +        """Runs the training and evaluation loop.""" + +        logger.debug(f"Running an experiment called {self.experiment_name}.") + +        # Set start time. +        t_start = time.time() + +        self.callbacks.on_fit_begin() + +        # Run the training loop. +        for epoch in range(self.start_epoch, self.epochs + 1): +            self.callbacks.on_epoch_begin(epoch) + +            # Perform one training pass over the training set. +            self.train() + +            # Evaluate the model on the validation set. +            val_metrics = self.validate(epoch) + +            self.callbacks.on_epoch_end(epoch, logs=val_metrics) + +            if self.model.stop_training: +                break + +        # Calculate the total training time. +        t_end = time.time() +        t_training = t_end - t_start + +        self.callbacks.on_fit_end() + +        logger.info(f"Training took {t_training:.2f} s.") diff --git a/src/training/trainer/util.py b/src/training/trainer/util.py new file mode 100644 index 0000000..132b2dc --- /dev/null +++ b/src/training/trainer/util.py @@ -0,0 +1,19 @@ +"""Utility functions for training neural networks.""" + + +class RunningAverage: +    """Maintains a running average.""" + +    def __init__(self) -> None: +        """Initializes the parameters.""" +        self.steps = 0 +        self.total = 0 + +    def update(self, val: float) -> None: +        """Updates the parameters.""" +        self.total += val +        self.steps += 1 + +    def __call__(self) -> float: +        """Computes the running average.""" +        return self.total / float(self.steps) |