summaryrefslogtreecommitdiff
path: root/src/training/trainer
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-03-20 18:09:06 +0100
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-03-20 18:09:06 +0100
commit7e8e54e84c63171e748bbf09516fd517e6821ace (patch)
tree996093f75a5d488dddf7ea1f159ed343a561ef89 /src/training/trainer
parentb0719d84138b6bbe5f04a4982dfca673aea1a368 (diff)
Inital commit for refactoring to lightning
Diffstat (limited to 'src/training/trainer')
-rw-r--r--src/training/trainer/__init__.py2
-rw-r--r--src/training/trainer/callbacks/__init__.py29
-rw-r--r--src/training/trainer/callbacks/base.py188
-rw-r--r--src/training/trainer/callbacks/checkpoint.py95
-rw-r--r--src/training/trainer/callbacks/early_stopping.py108
-rw-r--r--src/training/trainer/callbacks/lr_schedulers.py77
-rw-r--r--src/training/trainer/callbacks/progress_bar.py65
-rw-r--r--src/training/trainer/callbacks/wandb_callbacks.py261
-rw-r--r--src/training/trainer/train.py325
-rw-r--r--src/training/trainer/util.py28
10 files changed, 0 insertions, 1178 deletions
diff --git a/src/training/trainer/__init__.py b/src/training/trainer/__init__.py
deleted file mode 100644
index de41bfb..0000000
--- a/src/training/trainer/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-"""Trainer modules."""
-from .train import Trainer
diff --git a/src/training/trainer/callbacks/__init__.py b/src/training/trainer/callbacks/__init__.py
deleted file mode 100644
index 80c4177..0000000
--- a/src/training/trainer/callbacks/__init__.py
+++ /dev/null
@@ -1,29 +0,0 @@
-"""The callback modules used in the training script."""
-from .base import Callback, CallbackList
-from .checkpoint import Checkpoint
-from .early_stopping import EarlyStopping
-from .lr_schedulers import (
- LRScheduler,
- SWA,
-)
-from .progress_bar import ProgressBar
-from .wandb_callbacks import (
- WandbCallback,
- WandbImageLogger,
- WandbReconstructionLogger,
- WandbSegmentationLogger,
-)
-
-__all__ = [
- "Callback",
- "CallbackList",
- "Checkpoint",
- "EarlyStopping",
- "LRScheduler",
- "WandbCallback",
- "WandbImageLogger",
- "WandbReconstructionLogger",
- "WandbSegmentationLogger",
- "ProgressBar",
- "SWA",
-]
diff --git a/src/training/trainer/callbacks/base.py b/src/training/trainer/callbacks/base.py
deleted file mode 100644
index 500b642..0000000
--- a/src/training/trainer/callbacks/base.py
+++ /dev/null
@@ -1,188 +0,0 @@
-"""Metaclass for callback functions."""
-
-from enum import Enum
-from typing import Callable, Dict, List, Optional, Type, Union
-
-from loguru import logger
-import numpy as np
-import torch
-
-from text_recognizer.models import Model
-
-
-class ModeKeys:
- """Mode keys for CallbackList."""
-
- TRAIN = "train"
- VALIDATION = "validation"
-
-
-class Callback:
- """Metaclass for callbacks used in training."""
-
- def __init__(self) -> None:
- """Initializes the Callback instance."""
- self.model = None
-
- def set_model(self, model: Type[Model]) -> None:
- """Set the model."""
- self.model = model
-
- def on_fit_begin(self) -> None:
- """Called when fit begins."""
- pass
-
- def on_fit_end(self) -> None:
- """Called when fit ends."""
- pass
-
- def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None) -> None:
- """Called at the beginning of an epoch. Only used in training mode."""
- pass
-
- def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
- """Called at the end of an epoch. Only used in training mode."""
- pass
-
- def on_train_batch_begin(self, batch: int, logs: Optional[Dict] = None) -> None:
- """Called at the beginning of an epoch."""
- pass
-
- def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
- """Called at the end of an epoch."""
- pass
-
- def on_validation_batch_begin(
- self, batch: int, logs: Optional[Dict] = None
- ) -> None:
- """Called at the beginning of an epoch."""
- pass
-
- def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
- """Called at the end of an epoch."""
- pass
-
- def on_test_begin(self) -> None:
- """Called at the beginning of test."""
- pass
-
- def on_test_end(self) -> None:
- """Called at the end of test."""
- pass
-
-
-class CallbackList:
- """Container for abstracting away callback calls."""
-
- mode_keys = ModeKeys()
-
- def __init__(self, model: Type[Model], callbacks: List[Callback] = None) -> None:
- """Container for `Callback` instances.
-
- This object wraps a list of `Callback` instances and allows them all to be
- called via a single end point.
-
- Args:
- model (Type[Model]): A `Model` instance.
- callbacks (List[Callback]): List of `Callback` instances. Defaults to None.
-
- """
-
- self._callbacks = callbacks or []
- if model:
- self.set_model(model)
-
- def set_model(self, model: Type[Model]) -> None:
- """Set the model for all callbacks."""
- self.model = model
- for callback in self._callbacks:
- callback.set_model(model=self.model)
-
- def append(self, callback: Type[Callback]) -> None:
- """Append new callback to callback list."""
- self._callbacks.append(callback)
-
- def on_fit_begin(self) -> None:
- """Called when fit begins."""
- for callback in self._callbacks:
- callback.on_fit_begin()
-
- def on_fit_end(self) -> None:
- """Called when fit ends."""
- for callback in self._callbacks:
- callback.on_fit_end()
-
- def on_test_begin(self) -> None:
- """Called when test begins."""
- for callback in self._callbacks:
- callback.on_test_begin()
-
- def on_test_end(self) -> None:
- """Called when test ends."""
- for callback in self._callbacks:
- callback.on_test_end()
-
- def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None) -> None:
- """Called at the beginning of an epoch."""
- for callback in self._callbacks:
- callback.on_epoch_begin(epoch, logs)
-
- def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
- """Called at the end of an epoch."""
- for callback in self._callbacks:
- callback.on_epoch_end(epoch, logs)
-
- def _call_batch_hook(
- self, mode: str, hook: str, batch: int, logs: Optional[Dict] = None
- ) -> None:
- """Helper function for all batch_{begin | end} methods."""
- if hook == "begin":
- self._call_batch_begin_hook(mode, batch, logs)
- elif hook == "end":
- self._call_batch_end_hook(mode, batch, logs)
- else:
- raise ValueError(f"Unrecognized hook {hook}.")
-
- def _call_batch_begin_hook(
- self, mode: str, batch: int, logs: Optional[Dict] = None
- ) -> None:
- """Helper function for all `on_*_batch_begin` methods."""
- hook_name = f"on_{mode}_batch_begin"
- self._call_batch_hook_helper(hook_name, batch, logs)
-
- def _call_batch_end_hook(
- self, mode: str, batch: int, logs: Optional[Dict] = None
- ) -> None:
- """Helper function for all `on_*_batch_end` methods."""
- hook_name = f"on_{mode}_batch_end"
- self._call_batch_hook_helper(hook_name, batch, logs)
-
- def _call_batch_hook_helper(
- self, hook_name: str, batch: int, logs: Optional[Dict] = None
- ) -> None:
- """Helper function for `on_*_batch_begin` methods."""
- for callback in self._callbacks:
- hook = getattr(callback, hook_name)
- hook(batch, logs)
-
- def on_train_batch_begin(self, batch: int, logs: Optional[Dict] = None) -> None:
- """Called at the beginning of an epoch."""
- self._call_batch_hook(self.mode_keys.TRAIN, "begin", batch, logs)
-
- def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
- """Called at the end of an epoch."""
- self._call_batch_hook(self.mode_keys.TRAIN, "end", batch, logs)
-
- def on_validation_batch_begin(
- self, batch: int, logs: Optional[Dict] = None
- ) -> None:
- """Called at the beginning of an epoch."""
- self._call_batch_hook(self.mode_keys.VALIDATION, "begin", batch, logs)
-
- def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
- """Called at the end of an epoch."""
- self._call_batch_hook(self.mode_keys.VALIDATION, "end", batch, logs)
-
- def __iter__(self) -> iter:
- """Iter function for callback list."""
- return iter(self._callbacks)
diff --git a/src/training/trainer/callbacks/checkpoint.py b/src/training/trainer/callbacks/checkpoint.py
deleted file mode 100644
index a54e0a9..0000000
--- a/src/training/trainer/callbacks/checkpoint.py
+++ /dev/null
@@ -1,95 +0,0 @@
-"""Callback checkpoint for training models."""
-from enum import Enum
-from pathlib import Path
-from typing import Callable, Dict, List, Optional, Type, Union
-
-from loguru import logger
-import numpy as np
-import torch
-from training.trainer.callbacks import Callback
-
-from text_recognizer.models import Model
-
-
-class Checkpoint(Callback):
- """Saving model parameters at the end of each epoch."""
-
- mode_dict = {
- "min": torch.lt,
- "max": torch.gt,
- }
-
- def __init__(
- self,
- checkpoint_path: Union[str, Path],
- monitor: str = "accuracy",
- mode: str = "auto",
- min_delta: float = 0.0,
- ) -> None:
- """Monitors a quantity that will allow us to determine the best model weights.
-
- Args:
- checkpoint_path (Union[str, Path]): Path to the experiment with the checkpoint.
- monitor (str): Name of the quantity to monitor. Defaults to "accuracy".
- mode (str): Description of parameter `mode`. Defaults to "auto".
- min_delta (float): Description of parameter `min_delta`. Defaults to 0.0.
-
- """
- super().__init__()
- self.checkpoint_path = Path(checkpoint_path)
- self.monitor = monitor
- self.mode = mode
- self.min_delta = torch.tensor(min_delta)
-
- if mode not in ["auto", "min", "max"]:
- logger.warning(f"Checkpoint mode {mode} is unkown, fallback to auto mode.")
-
- self.mode = "auto"
-
- if self.mode == "auto":
- if "accuracy" in self.monitor:
- self.mode = "max"
- else:
- self.mode = "min"
- logger.debug(
- f"Checkpoint mode set to {self.mode} for monitoring {self.monitor}."
- )
-
- torch_inf = torch.tensor(np.inf)
- self.min_delta *= 1 if self.monitor_op == torch.gt else -1
- self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf
-
- @property
- def monitor_op(self) -> float:
- """Returns the comparison method."""
- return self.mode_dict[self.mode]
-
- def on_epoch_end(self, epoch: int, logs: Dict) -> None:
- """Saves a checkpoint for the network parameters.
-
- Args:
- epoch (int): The current epoch.
- logs (Dict): The log containing the monitored metrics.
-
- """
- current = self.get_monitor_value(logs)
- if current is None:
- return
- if self.monitor_op(current - self.min_delta, self.best_score):
- self.best_score = current
- is_best = True
- else:
- is_best = False
-
- self.model.save_checkpoint(self.checkpoint_path, is_best, epoch, self.monitor)
-
- def get_monitor_value(self, logs: Dict) -> Union[float, None]:
- """Extracts the monitored value."""
- monitor_value = logs.get(self.monitor)
- if monitor_value is None:
- logger.warning(
- f"Checkpoint is conditioned on metric {self.monitor} which is not available. Available"
- + f" metrics are: {','.join(list(logs.keys()))}"
- )
- return None
- return monitor_value
diff --git a/src/training/trainer/callbacks/early_stopping.py b/src/training/trainer/callbacks/early_stopping.py
deleted file mode 100644
index 02b431f..0000000
--- a/src/training/trainer/callbacks/early_stopping.py
+++ /dev/null
@@ -1,108 +0,0 @@
-"""Implements Early stopping for PyTorch model."""
-from typing import Dict, Union
-
-from loguru import logger
-import numpy as np
-import torch
-from torch import Tensor
-from training.trainer.callbacks import Callback
-
-
-class EarlyStopping(Callback):
- """Stops training when a monitored metric stops improving."""
-
- mode_dict = {
- "min": torch.lt,
- "max": torch.gt,
- }
-
- def __init__(
- self,
- monitor: str = "val_loss",
- min_delta: float = 0.0,
- patience: int = 3,
- mode: str = "auto",
- ) -> None:
- """Initializes the EarlyStopping callback.
-
- Args:
- monitor (str): Description of parameter `monitor`. Defaults to "val_loss".
- min_delta (float): Description of parameter `min_delta`. Defaults to 0.0.
- patience (int): Description of parameter `patience`. Defaults to 3.
- mode (str): Description of parameter `mode`. Defaults to "auto".
-
- """
- super().__init__()
- self.monitor = monitor
- self.patience = patience
- self.min_delta = torch.tensor(min_delta)
- self.mode = mode
- self.wait_count = 0
- self.stopped_epoch = 0
-
- if mode not in ["auto", "min", "max"]:
- logger.warning(
- f"EarlyStopping mode {mode} is unkown, fallback to auto mode."
- )
-
- self.mode = "auto"
-
- if self.mode == "auto":
- if "accuracy" in self.monitor:
- self.mode = "max"
- else:
- self.mode = "min"
- logger.debug(
- f"EarlyStopping mode set to {self.mode} for monitoring {self.monitor}."
- )
-
- self.torch_inf = torch.tensor(np.inf)
- self.min_delta *= 1 if self.monitor_op == torch.gt else -1
- self.best_score = (
- self.torch_inf if self.monitor_op == torch.lt else -self.torch_inf
- )
-
- @property
- def monitor_op(self) -> float:
- """Returns the comparison method."""
- return self.mode_dict[self.mode]
-
- def on_fit_begin(self) -> Union[torch.lt, torch.gt]:
- """Reset the early stopping variables for reuse."""
- self.wait_count = 0
- self.stopped_epoch = 0
- self.best_score = (
- self.torch_inf if self.monitor_op == torch.lt else -self.torch_inf
- )
-
- def on_epoch_end(self, epoch: int, logs: Dict) -> None:
- """Computes the early stop criterion."""
- current = self.get_monitor_value(logs)
- if current is None:
- return
- if self.monitor_op(current - self.min_delta, self.best_score):
- self.best_score = current
- self.wait_count = 0
- else:
- self.wait_count += 1
- if self.wait_count >= self.patience:
- self.stopped_epoch = epoch
- self.model.stop_training = True
-
- def on_fit_end(self) -> None:
- """Logs if early stopping was used."""
- if self.stopped_epoch > 0:
- logger.info(
- f"Stopped training at epoch {self.stopped_epoch + 1} with early stopping."
- )
-
- def get_monitor_value(self, logs: Dict) -> Union[Tensor, None]:
- """Extracts the monitor value."""
- monitor_value = logs.get(self.monitor)
- if monitor_value is None:
- logger.warning(
- f"Early stopping is conditioned on metric {self.monitor} which is not available. Available"
- + f"metrics are: {','.join(list(logs.keys()))}"
- )
- return None
- return torch.tensor(monitor_value)
diff --git a/src/training/trainer/callbacks/lr_schedulers.py b/src/training/trainer/callbacks/lr_schedulers.py
deleted file mode 100644
index 630c434..0000000
--- a/src/training/trainer/callbacks/lr_schedulers.py
+++ /dev/null
@@ -1,77 +0,0 @@
-"""Callbacks for learning rate schedulers."""
-from typing import Callable, Dict, List, Optional, Type
-
-from torch.optim.swa_utils import update_bn
-from training.trainer.callbacks import Callback
-
-from text_recognizer.models import Model
-
-
-class LRScheduler(Callback):
- """Generic learning rate scheduler callback."""
-
- def __init__(self) -> None:
- super().__init__()
-
- def set_model(self, model: Type[Model]) -> None:
- """Sets the model and lr scheduler."""
- self.model = model
- self.lr_scheduler = self.model.lr_scheduler["lr_scheduler"]
- self.interval = self.model.lr_scheduler["interval"]
-
- def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
- """Takes a step at the end of every epoch."""
- if self.interval == "epoch":
- if "ReduceLROnPlateau" in self.lr_scheduler.__class__.__name__:
- self.lr_scheduler.step(logs["val_loss"])
- else:
- self.lr_scheduler.step()
-
- def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
- """Takes a step at the end of every training batch."""
- if self.interval == "step":
- self.lr_scheduler.step()
-
-
-class SWA(Callback):
- """Stochastic Weight Averaging callback."""
-
- def __init__(self) -> None:
- """Initializes the callback."""
- super().__init__()
- self.lr_scheduler = None
- self.interval = None
- self.swa_scheduler = None
- self.swa_start = None
- self.current_epoch = 1
-
- def set_model(self, model: Type[Model]) -> None:
- """Sets the model and lr scheduler."""
- self.model = model
- self.lr_scheduler = self.model.lr_scheduler["lr_scheduler"]
- self.interval = self.model.lr_scheduler["interval"]
- self.swa_scheduler = self.model.swa_scheduler["swa_scheduler"]
- self.swa_start = self.model.swa_scheduler["swa_start"]
-
- def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
- """Takes a step at the end of every training batch."""
- if epoch > self.swa_start:
- self.model.swa_network.update_parameters(self.model.network)
- self.swa_scheduler.step()
- elif self.interval == "epoch":
- self.lr_scheduler.step()
- self.current_epoch = epoch
-
- def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
- """Takes a step at the end of every training batch."""
- if self.current_epoch < self.swa_start and self.interval == "step":
- self.lr_scheduler.step()
-
- def on_fit_end(self) -> None:
- """Update batch norm statistics for the swa model at the end of training."""
- if self.model.swa_network:
- update_bn(
- self.model.val_dataloader(),
- self.model.swa_network,
- device=self.model.device,
- )
diff --git a/src/training/trainer/callbacks/progress_bar.py b/src/training/trainer/callbacks/progress_bar.py
deleted file mode 100644
index 6c4305a..0000000
--- a/src/training/trainer/callbacks/progress_bar.py
+++ /dev/null
@@ -1,65 +0,0 @@
-"""Progress bar callback for the training loop."""
-from typing import Dict, Optional
-
-from tqdm import tqdm
-from training.trainer.callbacks import Callback
-
-
-class ProgressBar(Callback):
- """A TQDM progress bar for the training loop."""
-
- def __init__(self, epochs: int, log_batch_frequency: int = None) -> None:
- """Initializes the tqdm callback."""
- self.epochs = epochs
- print(epochs, type(epochs))
- self.log_batch_frequency = log_batch_frequency
- self.progress_bar = None
- self.val_metrics = {}
-
- def _configure_progress_bar(self) -> None:
- """Configures the tqdm progress bar with custom bar format."""
- self.progress_bar = tqdm(
- total=len(self.model.train_dataloader()),
- leave=False,
- unit="steps",
- mininterval=self.log_batch_frequency,
- bar_format="{desc} |{bar:32}| {n_fmt}/{total_fmt} ETA: {remaining} {rate_fmt}{postfix}",
- )
-
- def _key_abbreviations(self, logs: Dict) -> Dict:
- """Changes the length of keys, so that the progress bar fits better."""
-
- def rename(key: str) -> str:
- """Renames accuracy to acc."""
- return key.replace("accuracy", "acc")
-
- return {rename(key): value for key, value in logs.items()}
-
- # def on_fit_begin(self) -> None:
- # """Creates a tqdm progress bar."""
- # self._configure_progress_bar()
-
- def on_epoch_begin(self, epoch: int, logs: Optional[Dict]) -> None:
- """Updates the description with the current epoch."""
- if epoch == 1:
- self._configure_progress_bar()
- else:
- self.progress_bar.reset()
- self.progress_bar.set_description(f"Epoch {epoch}/{self.epochs}")
-
- def on_epoch_end(self, epoch: int, logs: Dict) -> None:
- """At the end of each epoch, the validation metrics are updated to the progress bar."""
- self.val_metrics = logs
- self.progress_bar.set_postfix(**self._key_abbreviations(logs))
- self.progress_bar.update()
-
- def on_train_batch_end(self, batch: int, logs: Dict) -> None:
- """Updates the progress bar for each training step."""
- if self.val_metrics:
- logs.update(self.val_metrics)
- self.progress_bar.set_postfix(**self._key_abbreviations(logs))
- self.progress_bar.update()
-
- def on_fit_end(self) -> None:
- """Closes the tqdm progress bar."""
- self.progress_bar.close()
diff --git a/src/training/trainer/callbacks/wandb_callbacks.py b/src/training/trainer/callbacks/wandb_callbacks.py
deleted file mode 100644
index 552a4f4..0000000
--- a/src/training/trainer/callbacks/wandb_callbacks.py
+++ /dev/null
@@ -1,261 +0,0 @@
-"""Callback for W&B."""
-from typing import Callable, Dict, List, Optional, Type
-
-import numpy as np
-from training.trainer.callbacks import Callback
-import wandb
-
-import text_recognizer.datasets.transforms as transforms
-from text_recognizer.models.base import Model
-
-
-class WandbCallback(Callback):
- """A custom W&B metric logger for the trainer."""
-
- def __init__(self, log_batch_frequency: int = None) -> None:
- """Short summary.
-
- Args:
- log_batch_frequency (int): If None, metrics will be logged every epoch.
- If set to an integer, callback will log every metrics every log_batch_frequency.
-
- """
- super().__init__()
- self.log_batch_frequency = log_batch_frequency
-
- def _on_batch_end(self, batch: int, logs: Dict) -> None:
- if self.log_batch_frequency and batch % self.log_batch_frequency == 0:
- wandb.log(logs, commit=True)
-
- def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
- """Logs training metrics."""
- if logs is not None:
- logs["lr"] = self.model.optimizer.param_groups[0]["lr"]
- self._on_batch_end(batch, logs)
-
- def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
- """Logs validation metrics."""
- if logs is not None:
- self._on_batch_end(batch, logs)
-
- def on_epoch_end(self, epoch: int, logs: Dict) -> None:
- """Logs at epoch end."""
- wandb.log(logs, commit=True)
-
-
-class WandbImageLogger(Callback):
- """Custom W&B callback for image logging."""
-
- def __init__(
- self,
- example_indices: Optional[List] = None,
- num_examples: int = 4,
- transform: Optional[bool] = None,
- ) -> None:
- """Initializes the WandbImageLogger with the model to train.
-
- Args:
- example_indices (Optional[List]): Indices for validation images. Defaults to None.
- num_examples (int): Number of random samples to take if example_indices are not specified. Defaults to 4.
- transform (Optional[Dict]): Use transform on image or not. Defaults to None.
-
- """
-
- super().__init__()
- self.caption = None
- self.example_indices = example_indices
- self.test_sample_indices = None
- self.num_examples = num_examples
- self.transform = (
- self._configure_transform(transform) if transform is not None else None
- )
-
- def _configure_transform(self, transform: Dict) -> Callable:
- args = transform["args"] or {}
- return getattr(transforms, transform["type"])(**args)
-
- def set_model(self, model: Type[Model]) -> None:
- """Sets the model and extracts validation images from the dataset."""
- self.model = model
- self.caption = "Validation Examples"
- if self.example_indices is None:
- self.example_indices = np.random.randint(
- 0, len(self.model.val_dataset), self.num_examples
- )
- self.images = self.model.val_dataset.dataset.data[self.example_indices]
- self.targets = self.model.val_dataset.dataset.targets[self.example_indices]
- self.targets = self.targets.tolist()
-
- def on_test_begin(self) -> None:
- """Get samples from test dataset."""
- self.caption = "Test Examples"
- if self.test_sample_indices is None:
- self.test_sample_indices = np.random.randint(
- 0, len(self.model.test_dataset), self.num_examples
- )
- self.images = self.model.test_dataset.data[self.test_sample_indices]
- self.targets = self.model.test_dataset.targets[self.test_sample_indices]
- self.targets = self.targets.tolist()
-
- def on_test_end(self) -> None:
- """Log test images."""
- self.on_epoch_end(0, {})
-
- def on_epoch_end(self, epoch: int, logs: Dict) -> None:
- """Get network predictions on validation images."""
- images = []
- for i, image in enumerate(self.images):
- image = self.transform(image) if self.transform is not None else image
- pred, conf = self.model.predict_on_image(image)
- if isinstance(self.targets[i], list):
- ground_truth = "".join(
- [
- self.model.mapper(int(target_index) - 26)
- if target_index > 35
- else self.model.mapper(int(target_index))
- for target_index in self.targets[i]
- ]
- ).rstrip("_")
- else:
- ground_truth = self.model.mapper(int(self.targets[i]))
- caption = f"Prediction: {pred} Confidence: {conf:.3f} Ground Truth: {ground_truth}"
- images.append(wandb.Image(image, caption=caption))
-
- wandb.log({f"{self.caption}": images}, commit=False)
-
-
-class WandbSegmentationLogger(Callback):
- """Custom W&B callback for image logging."""
-
- def __init__(
- self,
- class_labels: Dict,
- example_indices: Optional[List] = None,
- num_examples: int = 4,
- ) -> None:
- """Initializes the WandbImageLogger with the model to train.
-
- Args:
- class_labels (Dict): A dict with int as key and class string as value.
- example_indices (Optional[List]): Indices for validation images. Defaults to None.
- num_examples (int): Number of random samples to take if example_indices are not specified. Defaults to 4.
-
- """
-
- super().__init__()
- self.caption = None
- self.class_labels = {int(k): v for k, v in class_labels.items()}
- self.example_indices = example_indices
- self.test_sample_indices = None
- self.num_examples = num_examples
-
- def set_model(self, model: Type[Model]) -> None:
- """Sets the model and extracts validation images from the dataset."""
- self.model = model
- self.caption = "Validation Segmentation Examples"
- if self.example_indices is None:
- self.example_indices = np.random.randint(
- 0, len(self.model.val_dataset), self.num_examples
- )
- self.images = self.model.val_dataset.dataset.data[self.example_indices]
- self.targets = self.model.val_dataset.dataset.targets[self.example_indices]
- self.targets = self.targets.tolist()
-
- def on_test_begin(self) -> None:
- """Get samples from test dataset."""
- self.caption = "Test Segmentation Examples"
- if self.test_sample_indices is None:
- self.test_sample_indices = np.random.randint(
- 0, len(self.model.test_dataset), self.num_examples
- )
- self.images = self.model.test_dataset.data[self.test_sample_indices]
- self.targets = self.model.test_dataset.targets[self.test_sample_indices]
- self.targets = self.targets.tolist()
-
- def on_test_end(self) -> None:
- """Log test images."""
- self.on_epoch_end(0, {})
-
- def on_epoch_end(self, epoch: int, logs: Dict) -> None:
- """Get network predictions on validation images."""
- images = []
- for i, image in enumerate(self.images):
- pred_mask = (
- self.model.predict_on_image(image).detach().squeeze(0).cpu().numpy()
- )
- gt_mask = np.array(self.targets[i])
- images.append(
- wandb.Image(
- image,
- masks={
- "predictions": {
- "mask_data": pred_mask,
- "class_labels": self.class_labels,
- },
- "ground_truth": {
- "mask_data": gt_mask,
- "class_labels": self.class_labels,
- },
- },
- )
- )
-
- wandb.log({f"{self.caption}": images}, commit=False)
-
-
-class WandbReconstructionLogger(Callback):
- """Custom W&B callback for image reconstructions logging."""
-
- def __init__(
- self, example_indices: Optional[List] = None, num_examples: int = 4,
- ) -> None:
- """Initializes the WandbImageLogger with the model to train.
-
- Args:
- example_indices (Optional[List]): Indices for validation images. Defaults to None.
- num_examples (int): Number of random samples to take if example_indices are not specified. Defaults to 4.
-
- """
-
- super().__init__()
- self.caption = None
- self.example_indices = example_indices
- self.test_sample_indices = None
- self.num_examples = num_examples
-
- def set_model(self, model: Type[Model]) -> None:
- """Sets the model and extracts validation images from the dataset."""
- self.model = model
- self.caption = "Validation Reconstructions Examples"
- if self.example_indices is None:
- self.example_indices = np.random.randint(
- 0, len(self.model.val_dataset), self.num_examples
- )
- self.images = self.model.val_dataset.dataset.data[self.example_indices]
-
- def on_test_begin(self) -> None:
- """Get samples from test dataset."""
- self.caption = "Test Reconstructions Examples"
- if self.test_sample_indices is None:
- self.test_sample_indices = np.random.randint(
- 0, len(self.model.test_dataset), self.num_examples
- )
- self.images = self.model.test_dataset.data[self.test_sample_indices]
-
- def on_test_end(self) -> None:
- """Log test images."""
- self.on_epoch_end(0, {})
-
- def on_epoch_end(self, epoch: int, logs: Dict) -> None:
- """Get network predictions on validation images."""
- images = []
- for image in self.images:
- reconstructed_image = (
- self.model.predict_on_image(image).detach().squeeze(0).cpu().numpy()
- )
- images.append(image)
- images.append(reconstructed_image)
-
- wandb.log(
- {f"{self.caption}": [wandb.Image(image) for image in images]}, commit=False,
- )
diff --git a/src/training/trainer/train.py b/src/training/trainer/train.py
deleted file mode 100644
index b770c94..0000000
--- a/src/training/trainer/train.py
+++ /dev/null
@@ -1,325 +0,0 @@
-"""Training script for PyTorch models."""
-
-from pathlib import Path
-import time
-from typing import Dict, List, Optional, Tuple, Type
-import warnings
-
-from einops import rearrange
-from loguru import logger
-import numpy as np
-import torch
-from torch import Tensor
-from torch.optim.swa_utils import update_bn
-from training.trainer.callbacks import Callback, CallbackList, LRScheduler, SWA
-from training.trainer.util import log_val_metric
-import wandb
-
-from text_recognizer.models import Model
-
-
-torch.backends.cudnn.benchmark = True
-np.random.seed(4711)
-torch.manual_seed(4711)
-torch.cuda.manual_seed(4711)
-
-
-warnings.filterwarnings("ignore")
-
-
-class Trainer:
- """Trainer for training PyTorch models."""
-
- def __init__(
- self,
- max_epochs: int,
- callbacks: List[Type[Callback]],
- transformer_model: bool = False,
- max_norm: float = 0.0,
- freeze_backbone: Optional[int] = None,
- ) -> None:
- """Initialization of the Trainer.
-
- Args:
- max_epochs (int): The maximum number of epochs in the training loop.
- callbacks (CallbackList): List of callbacks to be called.
- transformer_model (bool): Transformer model flag, modifies the input to the model. Default is False.
- max_norm (float): Max norm for gradient cl:ipping. Defaults to 0.0.
- freeze_backbone (Optional[int]): How many epochs to freeze the backbone for. Used when training
- Transformers. Default is None.
-
- """
- # Training arguments.
- self.start_epoch = 1
- self.max_epochs = max_epochs
- self.callbacks = callbacks
- self.freeze_backbone = freeze_backbone
-
- # Flag for setting callbacks.
- self.callbacks_configured = False
-
- self.transformer_model = transformer_model
-
- self.max_norm = max_norm
-
- # Model placeholders
- self.model = None
-
- def _configure_callbacks(self) -> None:
- """Instantiate the CallbackList."""
- if not self.callbacks_configured:
- # If learning rate schedulers are present, they need to be added to the callbacks.
- if self.model.swa_scheduler is not None:
- self.callbacks.append(SWA())
- elif self.model.lr_scheduler is not None:
- self.callbacks.append(LRScheduler())
-
- self.callbacks = CallbackList(self.model, self.callbacks)
-
- def compute_metrics(
- self, output: Tensor, targets: Tensor, loss: Tensor, batch_size: int
- ) -> Dict:
- """Computes metrics for output and target pairs."""
- # Compute metrics.
- loss = loss.detach().float().item()
- output = output.detach()
- targets = targets.detach()
- if self.model.metrics is not None:
- metrics = {}
- for metric in self.model.metrics:
- if metric == "cer" or metric == "wer":
- metrics[metric] = self.model.metrics[metric](
- output,
- targets,
- batch_size,
- self.model.mapper(self.model.pad_token),
- )
- else:
- metrics[metric] = self.model.metrics[metric](output, targets)
- else:
- metrics = {}
- metrics["loss"] = loss
-
- return metrics
-
- def training_step(self, batch: int, samples: Tuple[Tensor, Tensor],) -> Dict:
- """Performs the training step."""
- # Pass the tensor to the device for computation.
- data, targets = samples
- data, targets = (
- data.to(self.model.device),
- targets.to(self.model.device),
- )
-
- batch_size = data.shape[0]
-
- # Placeholder for uxiliary loss.
- aux_loss = None
-
- # Forward pass.
- # Get the network prediction.
- if self.transformer_model:
- if self.freeze_backbone is not None and batch < self.freeze_backbone:
- with torch.no_grad():
- image_features = self.model.network.extract_image_features(data)
-
- if isinstance(image_features, Tuple):
- image_features, _ = image_features
-
- output = self.model.network.decode_image_features(
- image_features, targets[:, :-1]
- )
- else:
- output = self.model.network.forward(data, targets[:, :-1])
- if isinstance(output, Tuple):
- output, aux_loss = output
- output = rearrange(output, "b t v -> (b t) v")
- targets = rearrange(targets[:, 1:], "b t -> (b t)").long()
- else:
- output = self.model.forward(data)
-
- if isinstance(output, Tuple):
- output, aux_loss = output
- targets = data
-
- # Compute the loss.
- loss = self.model.criterion(output, targets)
-
- if aux_loss is not None:
- loss += aux_loss
-
- # Backward pass.
- # Clear the previous gradients.
- for p in self.model.network.parameters():
- p.grad = None
-
- # Compute the gradients.
- loss.backward()
-
- if self.max_norm > 0:
- torch.nn.utils.clip_grad_norm_(
- self.model.network.parameters(), self.max_norm
- )
-
- # Perform updates using calculated gradients.
- self.model.optimizer.step()
-
- metrics = self.compute_metrics(output, targets, loss, batch_size)
-
- return metrics
-
- def train(self) -> None:
- """Runs the training loop for one epoch."""
- # Set model to traning mode.
- self.model.train()
-
- for batch, samples in enumerate(self.model.train_dataloader()):
- self.callbacks.on_train_batch_begin(batch)
- metrics = self.training_step(batch, samples)
- self.callbacks.on_train_batch_end(batch, logs=metrics)
-
- @torch.no_grad()
- def validation_step(self, batch: int, samples: Tuple[Tensor, Tensor],) -> Dict:
- """Performs the validation step."""
-
- # Pass the tensor to the device for computation.
- data, targets = samples
- data, targets = (
- data.to(self.model.device),
- targets.to(self.model.device),
- )
-
- batch_size = data.shape[0]
-
- # Placeholder for uxiliary loss.
- aux_loss = None
-
- # Forward pass.
- # Get the network prediction.
- # Use SWA if available and using test dataset.
- if self.transformer_model:
- output = self.model.network.forward(data, targets[:, :-1])
- if isinstance(output, Tuple):
- output, aux_loss = output
- output = rearrange(output, "b t v -> (b t) v")
- targets = rearrange(targets[:, 1:], "b t -> (b t)").long()
- else:
- output = self.model.forward(data)
-
- if isinstance(output, Tuple):
- output, aux_loss = output
- targets = data
-
- # Compute the loss.
- loss = self.model.criterion(output, targets)
-
- if aux_loss is not None:
- loss += aux_loss
-
- # Compute metrics.
- metrics = self.compute_metrics(output, targets, loss, batch_size)
-
- return metrics
-
- def validate(self) -> Dict:
- """Runs the validation loop for one epoch."""
- # Set model to eval mode.
- self.model.eval()
-
- # Summary for the current eval loop.
- summary = []
-
- for batch, samples in enumerate(self.model.val_dataloader()):
- self.callbacks.on_validation_batch_begin(batch)
- metrics = self.validation_step(batch, samples)
- self.callbacks.on_validation_batch_end(batch, logs=metrics)
- summary.append(metrics)
-
- # Compute mean of all metrics.
- metrics_mean = {
- "val_" + metric: np.mean([x[metric] for x in summary])
- for metric in summary[0]
- }
-
- return metrics_mean
-
- def fit(self, model: Type[Model]) -> None:
- """Runs the training and evaluation loop."""
-
- # Sets model, loads the data, criterion, and optimizers.
- self.model = model
- self.model.prepare_data()
- self.model.configure_model()
-
- # Configure callbacks.
- self._configure_callbacks()
-
- # Set start time.
- t_start = time.time()
-
- self.callbacks.on_fit_begin()
-
- # Run the training loop.
- for epoch in range(self.start_epoch, self.max_epochs + 1):
- self.callbacks.on_epoch_begin(epoch)
-
- # Perform one training pass over the training set.
- self.train()
-
- # Evaluate the model on the validation set.
- val_metrics = self.validate()
- log_val_metric(val_metrics, epoch)
-
- self.callbacks.on_epoch_end(epoch, logs=val_metrics)
-
- if self.model.stop_training:
- break
-
- # Calculate the total training time.
- t_end = time.time()
- t_training = t_end - t_start
-
- self.callbacks.on_fit_end()
-
- logger.info(f"Training took {t_training:.2f} s.")
-
- # "Teardown".
- self.model = None
-
- def test(self, model: Type[Model]) -> Dict:
- """Run inference on test data."""
-
- # Sets model, loads the data, criterion, and optimizers.
- self.model = model
- self.model.prepare_data()
- self.model.configure_model()
-
- # Configure callbacks.
- self._configure_callbacks()
-
- self.callbacks.on_test_begin()
-
- self.model.eval()
-
- # Check if SWA network is available.
- self.model.use_swa_model()
-
- # Summary for the current test loop.
- summary = []
-
- for batch, samples in enumerate(self.model.test_dataloader()):
- metrics = self.validation_step(batch, samples)
- summary.append(metrics)
-
- self.callbacks.on_test_end()
-
- # Compute mean of all test metrics.
- metrics_mean = {
- "test_" + metric: np.mean([x[metric] for x in summary])
- for metric in summary[0]
- }
-
- # "Teardown".
- self.model = None
-
- return metrics_mean
diff --git a/src/training/trainer/util.py b/src/training/trainer/util.py
deleted file mode 100644
index 7cf1b45..0000000
--- a/src/training/trainer/util.py
+++ /dev/null
@@ -1,28 +0,0 @@
-"""Utility functions for training neural networks."""
-from typing import Dict, Optional
-
-from loguru import logger
-
-
-def log_val_metric(metrics_mean: Dict, epoch: Optional[int] = None) -> None:
- """Logging of val metrics to file/terminal."""
- log_str = "Validation metrics " + (f"at epoch {epoch} - " if epoch else " - ")
- logger.debug(log_str + " - ".join(f"{k}: {v:.4f}" for k, v in metrics_mean.items()))
-
-
-class RunningAverage:
- """Maintains a running average."""
-
- def __init__(self) -> None:
- """Initializes the parameters."""
- self.steps = 0
- self.total = 0
-
- def update(self, val: float) -> None:
- """Updates the parameters."""
- self.total += val
- self.steps += 1
-
- def __call__(self) -> float:
- """Computes the running average."""
- return self.total / float(self.steps)