From 7e8e54e84c63171e748bbf09516fd517e6821ace Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Sat, 20 Mar 2021 18:09:06 +0100
Subject: Inital commit for refactoring to lightning

---
 training/experiments/default_config_emnist.yml |  70 +++++
 training/experiments/embedding_experiment.yml  |  64 +++++
 training/experiments/sample_experiment.yml     |  99 +++++++
 training/gpu_manager.py                        |  62 ++++
 training/prepare_experiments.py                |  34 +++
 training/run_experiment.py                     | 382 +++++++++++++++++++++++++
 training/run_sweep.py                          |  92 ++++++
 training/sweep_emnist.yml                      |  26 ++
 training/sweep_emnist_resnet.yml               |  50 ++++
 training/trainer/__init__.py                   |   2 +
 training/trainer/callbacks/__init__.py         |  29 ++
 training/trainer/callbacks/base.py             | 188 ++++++++++++
 training/trainer/callbacks/checkpoint.py       |  95 ++++++
 training/trainer/callbacks/early_stopping.py   | 108 +++++++
 training/trainer/callbacks/lr_schedulers.py    |  77 +++++
 training/trainer/callbacks/progress_bar.py     |  65 +++++
 training/trainer/callbacks/wandb_callbacks.py  | 261 +++++++++++++++++
 training/trainer/train.py                      | 325 +++++++++++++++++++++
 training/trainer/util.py                       |  28 ++
 19 files changed, 2057 insertions(+)
 create mode 100644 training/experiments/default_config_emnist.yml
 create mode 100644 training/experiments/embedding_experiment.yml
 create mode 100644 training/experiments/sample_experiment.yml
 create mode 100644 training/gpu_manager.py
 create mode 100644 training/prepare_experiments.py
 create mode 100644 training/run_experiment.py
 create mode 100644 training/run_sweep.py
 create mode 100644 training/sweep_emnist.yml
 create mode 100644 training/sweep_emnist_resnet.yml
 create mode 100644 training/trainer/__init__.py
 create mode 100644 training/trainer/callbacks/__init__.py
 create mode 100644 training/trainer/callbacks/base.py
 create mode 100644 training/trainer/callbacks/checkpoint.py
 create mode 100644 training/trainer/callbacks/early_stopping.py
 create mode 100644 training/trainer/callbacks/lr_schedulers.py
 create mode 100644 training/trainer/callbacks/progress_bar.py
 create mode 100644 training/trainer/callbacks/wandb_callbacks.py
 create mode 100644 training/trainer/train.py
 create mode 100644 training/trainer/util.py

(limited to 'training')

diff --git a/training/experiments/default_config_emnist.yml b/training/experiments/default_config_emnist.yml
new file mode 100644
index 0000000..bf2ed0a
--- /dev/null
+++ b/training/experiments/default_config_emnist.yml
@@ -0,0 +1,70 @@
+dataset: EmnistDataset
+dataset_args:
+  sample_to_balance: true
+  subsample_fraction: 0.33
+  transform: null
+  target_transform: null
+  seed: 4711
+
+data_loader_args:
+  splits: [train, val]
+  shuffle: true
+  num_workers: 8
+  cuda: true
+
+model: CharacterModel
+metrics: [accuracy]
+
+network_args:
+  in_channels: 1
+  num_classes: 80
+  depths: [2]
+  block_sizes: [256]
+
+train_args:
+  batch_size: 256
+  epochs: 5
+
+criterion: CrossEntropyLoss
+criterion_args:
+  weight: null
+  ignore_index: -100
+  reduction: mean
+
+optimizer: AdamW
+optimizer_args:
+  lr: 1.e-03
+  betas: [0.9, 0.999]
+  eps: 1.e-08
+  # weight_decay: 5.e-4
+  amsgrad: false
+
+lr_scheduler: OneCycleLR
+lr_scheduler_args:
+  max_lr: 1.e-03
+  epochs: 5
+  anneal_strategy: linear
+
+
+callbacks: [Checkpoint, ProgressBar, EarlyStopping, WandbCallback, WandbImageLogger, OneCycleLR]
+callback_args:
+  Checkpoint:
+    monitor: val_accuracy
+  ProgressBar:
+    epochs: 5
+    log_batch_frequency: 100
+  EarlyStopping:
+    monitor: val_loss
+    min_delta: 0.0
+    patience: 3
+    mode: min
+  WandbCallback:
+    log_batch_frequency: 10
+  WandbImageLogger:
+    num_examples: 4
+  OneCycleLR:
+    null
+verbosity: 1 # 0, 1, 2
+resume_experiment: null
+train: true
+validation_metric: val_accuracy
diff --git a/training/experiments/embedding_experiment.yml b/training/experiments/embedding_experiment.yml
new file mode 100644
index 0000000..1e5f941
--- /dev/null
+++ b/training/experiments/embedding_experiment.yml
@@ -0,0 +1,64 @@
+experiment_group: Embedding Experiments
+experiments:
+    - train_args:
+        transformer_model: false
+        batch_size: &batch_size 256
+        max_epochs: &max_epochs 32
+        input_shape: [[1, 28, 28]]
+      dataset:
+        type: EmnistDataset
+        args:
+          sample_to_balance: true
+          subsample_fraction: null
+          transform: null
+          target_transform: null
+          seed: 4711
+        train_args:
+          num_workers: 8
+          train_fraction: 0.85
+          batch_size: *batch_size
+      model: CharacterModel
+      metrics: []
+      network:
+        type: DenseNet
+        args:
+          growth_rate: 4
+          block_config: [4, 4]
+          in_channels: 1
+          base_channels: 24
+          num_classes: 128
+          bn_size: 4
+          dropout_rate: 0.1
+          classifier: true
+          activation: elu
+      criterion:
+        type: EmbeddingLoss
+        args:
+          margin: 0.2
+          type_of_triplets: semihard
+      optimizer:
+        type: AdamW
+        args:
+          lr: 1.e-02
+          betas: [0.9, 0.999]
+          eps: 1.e-08
+          weight_decay: 5.e-4
+          amsgrad: false
+      lr_scheduler:
+        type: CosineAnnealingLR
+        args:
+          T_max: *max_epochs
+      callbacks: [Checkpoint, ProgressBar, WandbCallback]
+      callback_args:
+        Checkpoint:
+          monitor: val_loss
+          mode: min
+        ProgressBar:
+          epochs: *max_epochs
+        WandbCallback:
+          log_batch_frequency: 10
+      verbosity: 1 # 0, 1, 2
+      resume_experiment: null
+      train: true
+      test: true
+      test_metric: mean_average_precision_at_r
diff --git a/training/experiments/sample_experiment.yml b/training/experiments/sample_experiment.yml
new file mode 100644
index 0000000..8f94475
--- /dev/null
+++ b/training/experiments/sample_experiment.yml
@@ -0,0 +1,99 @@
+experiment_group: Sample Experiments
+experiments:
+    - train_args:
+        batch_size: 256
+        max_epochs: &max_epochs 32
+      dataset:
+        type: EmnistDataset
+        args:
+          sample_to_balance: true
+          subsample_fraction: null
+          transform: null
+          target_transform: null
+          seed: 4711
+        train_args:
+          num_workers: 6
+          train_fraction: 0.8
+
+      model: CharacterModel
+      metrics: [accuracy]
+      # network: MLP
+      # network_args:
+      #   input_size: 784
+      #   hidden_size: 512
+      #   output_size: 80
+      #   num_layers: 5
+      #   dropout_rate: 0.2
+      #   activation_fn: SELU
+      network:
+        type: ResidualNetwork
+        args:
+          in_channels: 1
+          num_classes: 80
+          depths: [2, 2]
+          block_sizes: [64, 64]
+          activation: leaky_relu
+      # network:
+      #   type: WideResidualNetwork
+      #   args:
+      #     in_channels: 1
+      #     num_classes: 80
+      #     depth: 10
+      #     num_layers: 3
+      #     width_factor: 4
+      #     dropout_rate: 0.2
+      #     activation: SELU
+      # network: LeNet
+      # network_args:
+      #   output_size: 62
+      #   activation_fn: GELU
+      criterion:
+        type: CrossEntropyLoss
+        args:
+          weight: null
+          ignore_index: -100
+          reduction: mean
+      optimizer:
+        type: AdamW
+        args:
+          lr: 1.e-02
+          betas: [0.9, 0.999]
+          eps: 1.e-08
+          # weight_decay: 5.e-4
+          amsgrad: false
+      # lr_scheduler:
+      #   type: OneCycleLR
+      #   args:
+      #     max_lr: 1.e-03
+      #     epochs: *max_epochs
+      #     anneal_strategy: linear
+      lr_scheduler:
+        type: CosineAnnealingLR
+        args:
+          T_max: *max_epochs
+          interval: epoch
+      swa_args:
+        start: 2
+        lr: 5.e-2
+      callbacks: [Checkpoint, ProgressBar, WandbCallback, WandbImageLogger, EarlyStopping]
+      callback_args:
+        Checkpoint:
+          monitor: val_accuracy
+        ProgressBar:
+          epochs: null
+          log_batch_frequency: 100
+        EarlyStopping:
+          monitor: val_loss
+          min_delta: 0.0
+          patience: 5
+          mode: min
+        WandbCallback:
+          log_batch_frequency: 10
+        WandbImageLogger:
+          num_examples: 4
+          use_transpose: true
+      verbosity: 0 # 0, 1, 2
+      resume_experiment: null
+      train: true
+      test: true
+      test_metric: test_accuracy
diff --git a/training/gpu_manager.py b/training/gpu_manager.py
new file mode 100644
index 0000000..ce1b3dd
--- /dev/null
+++ b/training/gpu_manager.py
@@ -0,0 +1,62 @@
+"""GPUManager class."""
+import os
+import time
+from typing import Optional
+
+import gpustat
+from loguru import logger
+import numpy as np
+from redlock import Redlock
+
+
+GPU_LOCK_TIMEOUT = 5000  # ms
+
+
+class GPUManager:
+    """Class for allocating GPUs."""
+
+    def __init__(self, verbose: bool = False) -> None:
+        """Initializes Redlock manager."""
+        self.lock_manager = Redlock([{"host": "localhost", "port": 6379, "db": 0}])
+        self.verbose = verbose
+
+    def get_free_gpu(self) -> int:
+        """Gets a free GPU.
+
+        If some GPUs are available, try reserving one by checking out an exclusive redis lock.
+        If none available or can not get lock, sleep and check again.
+
+        Returns:
+            int: The gpu index.
+
+        """
+        while True:
+            gpu_index = self._get_free_gpu()
+            if gpu_index is not None:
+                return gpu_index
+
+            if self.verbose:
+                logger.debug(f"pid {os.getpid()} sleeping")
+            time.sleep(GPU_LOCK_TIMEOUT / 1000)
+
+    def _get_free_gpu(self) -> Optional[int]:
+        """Fetches an available GPU index."""
+        try:
+            available_gpu_indices = [
+                gpu.index
+                for gpu in gpustat.GPUStatCollection.new_query()
+                if gpu.memory_used < 0.5 * gpu.memory_total
+            ]
+        except Exception as e:
+            logger.debug(f"Got the following exception: {e}")
+            return None
+
+        if available_gpu_indices:
+            gpu_index = np.random.choice(available_gpu_indices)
+            if self.verbose:
+                logger.debug(f"pid {os.getpid()} picking gpu {gpu_index}")
+            if self.lock_manager.lock(f"gpu_{gpu_index}", GPU_LOCK_TIMEOUT):
+                return int(gpu_index)
+            if self.verbose:
+                logger.debug(f"pid {os.getpid()} could not get lock.")
+        return None
diff --git a/training/prepare_experiments.py b/training/prepare_experiments.py
new file mode 100644
index 0000000..21997af
--- /dev/null
+++ b/training/prepare_experiments.py
@@ -0,0 +1,34 @@
+"""Run a experiment from a config file."""
+import json
+
+import click
+import yaml
+
+
+def run_experiments(experiments_filename: str) -> None:
+    """Run experiment from file."""
+    with open(experiments_filename, "r") as f:
+        experiments_config = yaml.safe_load(f)
+
+    num_experiments = len(experiments_config["experiments"])
+    for index in range(num_experiments):
+        experiment_config = experiments_config["experiments"][index]
+        experiment_config["experiment_group"] = experiments_config["experiment_group"]
+        cmd = f"poetry run run-experiment --gpu=-1 --save '{json.dumps(experiment_config)}'"
+        print(cmd)
+
+
+@click.command()
+@click.option(
+    "--experiments_filename",
+    required=True,
+    type=str,
+    help="Filename of Yaml file of experiments to run.",
+)
+def run_cli(experiments_filename: str) -> None:
+    """Parse command-line arguments and run experiments from provided file."""
+    run_experiments(experiments_filename)
+
+
+if __name__ == "__main__":
+    run_cli()
diff --git a/training/run_experiment.py b/training/run_experiment.py
new file mode 100644
index 0000000..faafea6
--- /dev/null
+++ b/training/run_experiment.py
@@ -0,0 +1,382 @@
+"""Script to run experiments."""
+from datetime import datetime
+from glob import glob
+import importlib
+import json
+import os
+from pathlib import Path
+import re
+from typing import Callable, Dict, List, Optional, Tuple, Type
+import warnings
+
+import click
+from loguru import logger
+import numpy as np
+import torch
+from torchsummary import summary
+from tqdm import tqdm
+from training.gpu_manager import GPUManager
+from training.trainer.callbacks import CallbackList
+from training.trainer.train import Trainer
+import wandb
+import yaml
+
+import text_recognizer.models
+from text_recognizer.models import Model
+import text_recognizer.networks
+from text_recognizer.networks.loss import loss as custom_loss_module
+
+EXPERIMENTS_DIRNAME = Path(__file__).parents[0].resolve() / "experiments"
+
+
+def _get_level(verbose: int) -> int:
+    """Sets the logger level."""
+    levels = {0: 40, 1: 20, 2: 10}
+    verbose = verbose if verbose <= 2 else 2
+    return levels[verbose]
+
+
+def _create_experiment_dir(
+    experiment_config: Dict, checkpoint: Optional[str] = None
+) -> Path:
+    """Create new experiment."""
+    EXPERIMENTS_DIRNAME.mkdir(parents=True, exist_ok=True)
+    experiment_dir = EXPERIMENTS_DIRNAME / (
+        f"{experiment_config['model']}_"
+        + f"{experiment_config['dataset']['type']}_"
+        + f"{experiment_config['network']['type']}"
+    )
+
+    if checkpoint is None:
+        experiment = datetime.now().strftime("%m%d_%H%M%S")
+        logger.debug(f"Creating a new experiment called {experiment}")
+    else:
+        available_experiments = glob(str(experiment_dir) + "/*")
+        available_experiments.sort()
+        if checkpoint == "last":
+            experiment = available_experiments[-1]
+            logger.debug(f"Resuming the latest experiment {experiment}")
+        else:
+            experiment = checkpoint
+            if not str(experiment_dir / experiment) in available_experiments:
+                raise FileNotFoundError("Experiment does not exist.")
+            logger.debug(f"Resuming the from experiment {checkpoint}")
+
+    experiment_dir = experiment_dir / experiment
+
+    # Create log and model directories.
+    log_dir = experiment_dir / "log"
+    model_dir = experiment_dir / "model"
+
+    return experiment_dir, log_dir, model_dir
+
+
+def _load_modules_and_arguments(experiment_config: Dict,) -> Tuple[Callable, Dict]:
+    """Loads all modules and arguments."""
+    # Load the dataset module.
+    dataset_args = experiment_config.get("dataset", {})
+    dataset_ = dataset_args["type"]
+
+    # Import the model module and model arguments.
+    model_class_ = getattr(text_recognizer.models, experiment_config["model"])
+
+    # Import metrics.
+    metric_fns_ = (
+        {
+            metric: getattr(text_recognizer.networks, metric)
+            for metric in experiment_config["metrics"]
+        }
+        if experiment_config["metrics"] is not None
+        else None
+    )
+
+    # Import network module and arguments.
+    network_fn_ = experiment_config["network"]["type"]
+    network_args = experiment_config["network"].get("args", {})
+
+    # Criterion
+    if experiment_config["criterion"]["type"] in custom_loss_module.__all__:
+        criterion_ = getattr(custom_loss_module, experiment_config["criterion"]["type"])
+    else:
+        criterion_ = getattr(torch.nn, experiment_config["criterion"]["type"])
+    criterion_args = experiment_config["criterion"].get("args", {}) or {}
+
+    # Optimizers
+    optimizer_ = getattr(torch.optim, experiment_config["optimizer"]["type"])
+    optimizer_args = experiment_config["optimizer"].get("args", {})
+
+    # Learning rate scheduler
+    lr_scheduler_ = None
+    lr_scheduler_args = None
+    if "lr_scheduler" in experiment_config:
+        lr_scheduler_ = getattr(
+            torch.optim.lr_scheduler, experiment_config["lr_scheduler"]["type"]
+        )
+        lr_scheduler_args = experiment_config["lr_scheduler"].get("args", {}) or {}
+
+    # SWA scheduler.
+    if "swa_args" in experiment_config:
+        swa_args = experiment_config.get("swa_args", {}) or {}
+    else:
+        swa_args = None
+
+    model_args = {
+        "dataset": dataset_,
+        "dataset_args": dataset_args,
+        "metrics": metric_fns_,
+        "network_fn": network_fn_,
+        "network_args": network_args,
+        "criterion": criterion_,
+        "criterion_args": criterion_args,
+        "optimizer": optimizer_,
+        "optimizer_args": optimizer_args,
+        "lr_scheduler": lr_scheduler_,
+        "lr_scheduler_args": lr_scheduler_args,
+        "swa_args": swa_args,
+    }
+
+    return model_class_, model_args
+
+
+def _configure_callbacks(experiment_config: Dict, model_dir: Path) -> CallbackList:
+    """Configure a callback list for trainer."""
+    if "Checkpoint" in experiment_config["callback_args"]:
+        experiment_config["callback_args"]["Checkpoint"]["checkpoint_path"] = str(
+            model_dir
+        )
+
+    # Initializes callbacks.
+    callback_modules = importlib.import_module("training.trainer.callbacks")
+    callbacks = []
+    for callback in experiment_config["callbacks"]:
+        args = experiment_config["callback_args"][callback] or {}
+        callbacks.append(getattr(callback_modules, callback)(**args))
+
+    return callbacks
+
+
+def _configure_logger(log_dir: Path, verbose: int = 0) -> None:
+    """Configure the loguru logger for output to terminal and disk."""
+    # Have to remove default logger to get tqdm to work properly.
+    logger.remove()
+
+    # Fetch verbosity level.
+    level = _get_level(verbose)
+
+    logger.add(lambda msg: tqdm.write(msg, end=""), colorize=True, level=level)
+    logger.add(
+        str(log_dir / "train.log"),
+        format="{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}",
+    )
+
+
+def _save_config(experiment_dir: Path, experiment_config: Dict) -> None:
+    """Copy config to experiment directory."""
+    config_path = experiment_dir / "config.yml"
+    with open(str(config_path), "w") as f:
+        yaml.dump(experiment_config, f)
+
+
+def _load_from_checkpoint(
+    model: Type[Model], model_dir: Path, pretrained_weights: str = None,
+) -> None:
+    """If checkpoint exists, load model weights and optimizers from checkpoint."""
+    # Get checkpoint path.
+    if pretrained_weights is not None:
+        logger.info(f"Loading weights from {pretrained_weights}.")
+        checkpoint_path = (
+            EXPERIMENTS_DIRNAME / Path(pretrained_weights) / "model" / "best.pt"
+        )
+    else:
+        logger.info(f"Loading weights from {model_dir}.")
+        checkpoint_path = model_dir / "last.pt"
+    if checkpoint_path.exists():
+        logger.info("Loading and resuming training from checkpoint.")
+        model.load_from_checkpoint(checkpoint_path)
+
+
+def evaluate_embedding(model: Type[Model]) -> Dict:
+    """Evaluates the embedding space."""
+    from pytorch_metric_learning import testers
+    from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator
+
+    accuracy_calculator = AccuracyCalculator(
+        include=("mean_average_precision_at_r",), k=10
+    )
+
+    def get_all_embeddings(model: Type[Model]) -> Tuple:
+        tester = testers.BaseTester()
+        return tester.get_all_embeddings(model.test_dataset, model.network)
+
+    embeddings, labels = get_all_embeddings(model)
+    logger.info("Computing embedding accuracy")
+    accuracies = accuracy_calculator.get_accuracy(
+        embeddings, embeddings, np.squeeze(labels), np.squeeze(labels), True
+    )
+    logger.info(
+        f"Test set accuracy (MAP@10) = {accuracies['mean_average_precision_at_r']}"
+    )
+    return accuracies
+
+
+def run_experiment(
+    experiment_config: Dict,
+    save_weights: bool,
+    device: str,
+    use_wandb: bool,
+    train: bool,
+    test: bool,
+    verbose: int = 0,
+    checkpoint: Optional[str] = None,
+    pretrained_weights: Optional[str] = None,
+) -> None:
+    """Runs an experiment."""
+    logger.info(f"Experiment config: {json.dumps(experiment_config)}")
+
+    # Create new experiment.
+    experiment_dir, log_dir, model_dir = _create_experiment_dir(
+        experiment_config, checkpoint
+    )
+
+    # Make sure the log/model directory exists.
+    log_dir.mkdir(parents=True, exist_ok=True)
+    model_dir.mkdir(parents=True, exist_ok=True)
+
+    # Load the modules and model arguments.
+    model_class_, model_args = _load_modules_and_arguments(experiment_config)
+
+    # Initializes the model with experiment config.
+    model = model_class_(**model_args, device=device)
+
+    callbacks = _configure_callbacks(experiment_config, model_dir)
+
+    # Setup logger.
+    _configure_logger(log_dir, verbose)
+
+    # Load from checkpoint if resuming an experiment.
+    resume = False
+    if checkpoint is not None or pretrained_weights is not None:
+        # resume = True
+        _load_from_checkpoint(model, model_dir, pretrained_weights)
+
+    logger.info(f"The class mapping is {model.mapping}")
+
+    # Initializes Weights & Biases
+    if use_wandb:
+        wandb.init(project="text-recognizer", config=experiment_config, resume=resume)
+
+        # Lets W&B save the model and track the gradients and optional parameters.
+        wandb.watch(model.network)
+
+    experiment_config["experiment_group"] = experiment_config.get(
+        "experiment_group", None
+    )
+
+    experiment_config["device"] = device
+
+    # Save the config used in the experiment folder.
+    _save_config(experiment_dir, experiment_config)
+
+    # Prints a summary of the network in terminal.
+    model.summary(experiment_config["train_args"]["input_shape"])
+
+    # Load trainer.
+    trainer = Trainer(
+        max_epochs=experiment_config["train_args"]["max_epochs"],
+        callbacks=callbacks,
+        transformer_model=experiment_config["train_args"]["transformer_model"],
+        max_norm=experiment_config["train_args"]["max_norm"],
+        freeze_backbone=experiment_config["train_args"]["freeze_backbone"],
+    )
+
+    # Train the model.
+    if train:
+        trainer.fit(model)
+
+    # Run inference over test set.
+    if test:
+        logger.info("Loading checkpoint with the best weights.")
+        if "checkpoint" in experiment_config["train_args"]:
+            model.load_from_checkpoint(
+                model_dir / experiment_config["train_args"]["checkpoint"]
+            )
+        else:
+            model.load_from_checkpoint(model_dir / "best.pt")
+
+        logger.info("Running inference on test set.")
+        if experiment_config["criterion"]["type"] == "EmbeddingLoss":
+            logger.info("Evaluating embedding.")
+            score = evaluate_embedding(model)
+        else:
+            score = trainer.test(model)
+
+        logger.info(f"Test set evaluation: {score}")
+
+        if use_wandb:
+            wandb.log(
+                {
+                    experiment_config["test_metric"]: score[
+                        experiment_config["test_metric"]
+                    ]
+                }
+            )
+
+    if save_weights:
+        model.save_weights(model_dir)
+
+
+@click.command()
+@click.argument("experiment_config",)
+@click.option("--gpu", type=int, default=0, help="Provide the index of the GPU to use.")
+@click.option(
+    "--save",
+    is_flag=True,
+    help="If set, the final weights will be saved to a canonical, version-controlled location.",
+)
+@click.option(
+    "--nowandb", is_flag=False, help="If true, do not use wandb for this run."
+)
+@click.option("--test", is_flag=True, help="If true, test the model.")
+@click.option("-v", "--verbose", count=True)
+@click.option("--checkpoint", type=str, help="Path to the experiment.")
+@click.option(
+    "--pretrained_weights", type=str, help="Path to pretrained model weights."
+)
+@click.option(
+    "--notrain", is_flag=False, help="Do not train the model.",
+)
+def run_cli(
+    experiment_config: str,
+    gpu: int,
+    save: bool,
+    nowandb: bool,
+    notrain: bool,
+    test: bool,
+    verbose: int,
+    checkpoint: Optional[str] = None,
+    pretrained_weights: Optional[str] = None,
+) -> None:
+    """Run experiment."""
+    if gpu < 0:
+        gpu_manager = GPUManager(True)
+        gpu = gpu_manager.get_free_gpu()
+    device = "cuda:" + str(gpu)
+
+    experiment_config = json.loads(experiment_config)
+    os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu}"
+
+    run_experiment(
+        experiment_config,
+        save,
+        device,
+        use_wandb=not nowandb,
+        train=not notrain,
+        test=test,
+        verbose=verbose,
+        checkpoint=checkpoint,
+        pretrained_weights=pretrained_weights,
+    )
+
+
+if __name__ == "__main__":
+    run_cli()
diff --git a/training/run_sweep.py b/training/run_sweep.py
new file mode 100644
index 0000000..a578592
--- /dev/null
+++ b/training/run_sweep.py
@@ -0,0 +1,92 @@
+"""W&B Sweep Functionality."""
+from ast import literal_eval
+import json
+import os
+from pathlib import Path
+import signal
+import subprocess  # nosec
+import sys
+from typing import Dict, List, Tuple
+
+import click
+import yaml
+
+EXPERIMENTS_DIRNAME = Path(__file__).parents[0].resolve() / "experiments"
+
+
+def load_config() -> Dict:
+    """Load base hyperparameter config."""
+    with open(str(EXPERIMENTS_DIRNAME / "default_config_emnist.yml"), "r") as f:
+        default_config = yaml.safe_load(f)
+    return default_config
+
+
+def args_to_json(
+    default_config: dict, preserve_args: tuple = ("gpu", "save")
+) -> Tuple[dict, list]:
+    """Convert command line arguments to nested config values.
+
+    i.e. run_sweep.py --dataset_args.foo=1.7
+    {
+        "dataset_args": {
+            "foo": 1.7
+        }
+    }
+
+    Args:
+        default_config (dict): The base config used for every experiment.
+        preserve_args (tuple): Arguments preserved for all runs. Defaults to ("gpu", "save").
+
+    Returns:
+        Tuple[dict, list]: Tuple of config dictionary and list of arguments.
+
+    """
+
+    args = []
+    config = default_config.copy()
+    key, val = None, None
+    for arg in sys.argv[1:]:
+        if "=" in arg:
+            key, val = arg.split("=")
+        elif key:
+            val = arg
+        else:
+            key = arg
+        if key and val:
+            parsed_key = key.lstrip("-").split(".")
+            if parsed_key[0] in preserve_args:
+                args.append("--{}={}".format(parsed_key[0], val))
+            else:
+                nested = config
+                for level in parsed_key[:-1]:
+                    nested[level] = config.get(level, {})
+                    nested = nested[level]
+                try:
+                    # Convert numerics to floats / ints
+                    val = literal_eval(val)
+                except ValueError:
+                    pass
+                nested[parsed_key[-1]] = val
+            key, val = None, None
+    return config, args
+
+
+def main() -> None:
+    """Runs a W&B sweep."""
+    default_config = load_config()
+    config, args = args_to_json(default_config)
+    env = {
+        k: v for k, v in os.environ.items() if k not in ("WANDB_PROGRAM", "WANDB_ARGS")
+    }
+    # pylint: disable=subprocess-popen-preexec-fn
+    run = subprocess.Popen(
+        ["python", "training/run_experiment.py", *args, json.dumps(config)],
+        env=env,
+        preexec_fn=os.setsid,
+    )  # nosec
+    signal.signal(signal.SIGTERM, lambda *args: run.terminate())
+    run.wait()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/training/sweep_emnist.yml b/training/sweep_emnist.yml
new file mode 100644
index 0000000..48d7261
--- /dev/null
+++ b/training/sweep_emnist.yml
@@ -0,0 +1,26 @@
+program: training/run_sweep.py
+method: bayes
+metric:
+  name: val_loss
+  goal: minimize
+parameters:
+  dataset:
+    value: EmnistDataset
+  model:
+    value: CharacterModel
+  network:
+    value: MLP
+  network_args.hidden_size:
+    values: [128, 256]
+  network_args.dropout_rate:
+    values: [0.2, 0.4]
+  network_args.num_layers:
+    values: [3, 6]
+  optimizer_args.lr:
+    values: [1.e-1, 1.e-5]
+  lr_scheduler_args.max_lr:
+    values: [1.0e-1, 1.0e-5]
+  train_args.batch_size:
+    values: [64, 128]
+  train_args.epochs:
+    value: 5
diff --git a/training/sweep_emnist_resnet.yml b/training/sweep_emnist_resnet.yml
new file mode 100644
index 0000000..19a3040
--- /dev/null
+++ b/training/sweep_emnist_resnet.yml
@@ -0,0 +1,50 @@
+program: training/run_sweep.py
+method: bayes
+metric:
+  name: val_accuracy
+  goal: maximize
+parameters:
+  dataset:
+    value: EmnistDataset
+  model:
+    value: CharacterModel
+  network:
+    value: ResidualNetwork
+  network_args.block_sizes:
+    distribution: q_uniform
+    min: 16
+    max: 256
+    q: 8
+  network_args.depths:
+    distribution: int_uniform
+    min: 1
+    max: 3
+  network_args.levels:
+      distribution: int_uniform
+      min: 1
+      max: 2
+  network_args.activation:
+    distribution: categorical
+    values:
+      - gelu
+      - leaky_relu
+      - relu
+      - selu
+  optimizer_args.lr:
+    distribution: uniform
+    min: 1.e-5
+    max: 1.e-1
+  lr_scheduler_args.max_lr:
+    distribution: uniform
+    min: 1.e-5
+    max: 1.e-1
+  train_args.batch_size:
+    distribution: q_uniform
+    min: 32
+    max: 256
+    q: 8
+  train_args.epochs:
+    value: 5
+early_terminate:
+   type: hyperband
+   min_iter: 2
diff --git a/training/trainer/__init__.py b/training/trainer/__init__.py
new file mode 100644
index 0000000..de41bfb
--- /dev/null
+++ b/training/trainer/__init__.py
@@ -0,0 +1,2 @@
+"""Trainer modules."""
+from .train import Trainer
diff --git a/training/trainer/callbacks/__init__.py b/training/trainer/callbacks/__init__.py
new file mode 100644
index 0000000..80c4177
--- /dev/null
+++ b/training/trainer/callbacks/__init__.py
@@ -0,0 +1,29 @@
+"""The callback modules used in the training script."""
+from .base import Callback, CallbackList
+from .checkpoint import Checkpoint
+from .early_stopping import EarlyStopping
+from .lr_schedulers import (
+    LRScheduler,
+    SWA,
+)
+from .progress_bar import ProgressBar
+from .wandb_callbacks import (
+    WandbCallback,
+    WandbImageLogger,
+    WandbReconstructionLogger,
+    WandbSegmentationLogger,
+)
+
+__all__ = [
+    "Callback",
+    "CallbackList",
+    "Checkpoint",
+    "EarlyStopping",
+    "LRScheduler",
+    "WandbCallback",
+    "WandbImageLogger",
+    "WandbReconstructionLogger",
+    "WandbSegmentationLogger",
+    "ProgressBar",
+    "SWA",
+]
diff --git a/training/trainer/callbacks/base.py b/training/trainer/callbacks/base.py
new file mode 100644
index 0000000..500b642
--- /dev/null
+++ b/training/trainer/callbacks/base.py
@@ -0,0 +1,188 @@
+"""Metaclass for callback functions."""
+
+from enum import Enum
+from typing import Callable, Dict, List, Optional, Type, Union
+
+from loguru import logger
+import numpy as np
+import torch
+
+from text_recognizer.models import Model
+
+
+class ModeKeys:
+    """Mode keys for CallbackList."""
+
+    TRAIN = "train"
+    VALIDATION = "validation"
+
+
+class Callback:
+    """Metaclass for callbacks used in training."""
+
+    def __init__(self) -> None:
+        """Initializes the Callback instance."""
+        self.model = None
+
+    def set_model(self, model: Type[Model]) -> None:
+        """Set the model."""
+        self.model = model
+
+    def on_fit_begin(self) -> None:
+        """Called when fit begins."""
+        pass
+
+    def on_fit_end(self) -> None:
+        """Called when fit ends."""
+        pass
+
+    def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None) -> None:
+        """Called at the beginning of an epoch. Only used in training mode."""
+        pass
+
+    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
+        """Called at the end of an epoch. Only used in training mode."""
+        pass
+
+    def on_train_batch_begin(self, batch: int, logs: Optional[Dict] = None) -> None:
+        """Called at the beginning of an epoch."""
+        pass
+
+    def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
+        """Called at the end of an epoch."""
+        pass
+
+    def on_validation_batch_begin(
+        self, batch: int, logs: Optional[Dict] = None
+    ) -> None:
+        """Called at the beginning of an epoch."""
+        pass
+
+    def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
+        """Called at the end of an epoch."""
+        pass
+
+    def on_test_begin(self) -> None:
+        """Called at the beginning of test."""
+        pass
+
+    def on_test_end(self) -> None:
+        """Called at the end of test."""
+        pass
+
+
+class CallbackList:
+    """Container for abstracting away callback calls."""
+
+    mode_keys = ModeKeys()
+
+    def __init__(self, model: Type[Model], callbacks: List[Callback] = None) -> None:
+        """Container for `Callback` instances.
+
+        This object wraps a list of `Callback` instances and allows them all to be
+        called via a single end point.
+
+        Args:
+            model (Type[Model]): A `Model` instance.
+            callbacks (List[Callback]): List of `Callback` instances. Defaults to None.
+
+        """
+
+        self._callbacks = callbacks or []
+        if model:
+            self.set_model(model)
+
+    def set_model(self, model: Type[Model]) -> None:
+        """Set the model for all callbacks."""
+        self.model = model
+        for callback in self._callbacks:
+            callback.set_model(model=self.model)
+
+    def append(self, callback: Type[Callback]) -> None:
+        """Append new callback to callback list."""
+        self._callbacks.append(callback)
+
+    def on_fit_begin(self) -> None:
+        """Called when fit begins."""
+        for callback in self._callbacks:
+            callback.on_fit_begin()
+
+    def on_fit_end(self) -> None:
+        """Called when fit ends."""
+        for callback in self._callbacks:
+            callback.on_fit_end()
+
+    def on_test_begin(self) -> None:
+        """Called when test begins."""
+        for callback in self._callbacks:
+            callback.on_test_begin()
+
+    def on_test_end(self) -> None:
+        """Called when test ends."""
+        for callback in self._callbacks:
+            callback.on_test_end()
+
+    def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None) -> None:
+        """Called at the beginning of an epoch."""
+        for callback in self._callbacks:
+            callback.on_epoch_begin(epoch, logs)
+
+    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
+        """Called at the end of an epoch."""
+        for callback in self._callbacks:
+            callback.on_epoch_end(epoch, logs)
+
+    def _call_batch_hook(
+        self, mode: str, hook: str, batch: int, logs: Optional[Dict] = None
+    ) -> None:
+        """Helper function for all batch_{begin | end} methods."""
+        if hook == "begin":
+            self._call_batch_begin_hook(mode, batch, logs)
+        elif hook == "end":
+            self._call_batch_end_hook(mode, batch, logs)
+        else:
+            raise ValueError(f"Unrecognized hook {hook}.")
+
+    def _call_batch_begin_hook(
+        self, mode: str, batch: int, logs: Optional[Dict] = None
+    ) -> None:
+        """Helper function for all `on_*_batch_begin` methods."""
+        hook_name = f"on_{mode}_batch_begin"
+        self._call_batch_hook_helper(hook_name, batch, logs)
+
+    def _call_batch_end_hook(
+        self, mode: str, batch: int, logs: Optional[Dict] = None
+    ) -> None:
+        """Helper function for all `on_*_batch_end` methods."""
+        hook_name = f"on_{mode}_batch_end"
+        self._call_batch_hook_helper(hook_name, batch, logs)
+
+    def _call_batch_hook_helper(
+        self, hook_name: str, batch: int, logs: Optional[Dict] = None
+    ) -> None:
+        """Helper function for `on_*_batch_begin` methods."""
+        for callback in self._callbacks:
+            hook = getattr(callback, hook_name)
+            hook(batch, logs)
+
+    def on_train_batch_begin(self, batch: int, logs: Optional[Dict] = None) -> None:
+        """Called at the beginning of an epoch."""
+        self._call_batch_hook(self.mode_keys.TRAIN, "begin", batch, logs)
+
+    def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
+        """Called at the end of an epoch."""
+        self._call_batch_hook(self.mode_keys.TRAIN, "end", batch, logs)
+
+    def on_validation_batch_begin(
+        self, batch: int, logs: Optional[Dict] = None
+    ) -> None:
+        """Called at the beginning of an epoch."""
+        self._call_batch_hook(self.mode_keys.VALIDATION, "begin", batch, logs)
+
+    def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
+        """Called at the end of an epoch."""
+        self._call_batch_hook(self.mode_keys.VALIDATION, "end", batch, logs)
+
+    def __iter__(self) -> iter:
+        """Iter function for callback list."""
+        return iter(self._callbacks)
diff --git a/training/trainer/callbacks/checkpoint.py b/training/trainer/callbacks/checkpoint.py
new file mode 100644
index 0000000..a54e0a9
--- /dev/null
+++ b/training/trainer/callbacks/checkpoint.py
@@ -0,0 +1,95 @@
+"""Callback checkpoint for training models."""
+from enum import Enum
+from pathlib import Path
+from typing import Callable, Dict, List, Optional, Type, Union
+
+from loguru import logger
+import numpy as np
+import torch
+from training.trainer.callbacks import Callback
+
+from text_recognizer.models import Model
+
+
+class Checkpoint(Callback):
+    """Saving model parameters at the end of each epoch."""
+
+    mode_dict = {
+        "min": torch.lt,
+        "max": torch.gt,
+    }
+
+    def __init__(
+        self,
+        checkpoint_path: Union[str, Path],
+        monitor: str = "accuracy",
+        mode: str = "auto",
+        min_delta: float = 0.0,
+    ) -> None:
+        """Monitors a quantity that will allow us to determine the best model weights.
+
+        Args:
+            checkpoint_path (Union[str, Path]): Path to the experiment with the checkpoint.
+            monitor (str): Name of the quantity to monitor. Defaults to "accuracy".
+            mode (str): Description of parameter `mode`. Defaults to "auto".
+            min_delta (float): Description of parameter `min_delta`. Defaults to 0.0.
+
+        """
+        super().__init__()
+        self.checkpoint_path = Path(checkpoint_path)
+        self.monitor = monitor
+        self.mode = mode
+        self.min_delta = torch.tensor(min_delta)
+
+        if mode not in ["auto", "min", "max"]:
+            logger.warning(f"Checkpoint mode {mode} is unkown, fallback to auto mode.")
+
+            self.mode = "auto"
+
+        if self.mode == "auto":
+            if "accuracy" in self.monitor:
+                self.mode = "max"
+            else:
+                self.mode = "min"
+            logger.debug(
+                f"Checkpoint mode set to {self.mode} for monitoring {self.monitor}."
+            )
+
+        torch_inf = torch.tensor(np.inf)
+        self.min_delta *= 1 if self.monitor_op == torch.gt else -1
+        self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf
+
+    @property
+    def monitor_op(self) -> float:
+        """Returns the comparison method."""
+        return self.mode_dict[self.mode]
+
+    def on_epoch_end(self, epoch: int, logs: Dict) -> None:
+        """Saves a checkpoint for the network parameters.
+
+        Args:
+            epoch (int): The current epoch.
+            logs (Dict): The log containing the monitored metrics.
+
+        """
+        current = self.get_monitor_value(logs)
+        if current is None:
+            return
+        if self.monitor_op(current - self.min_delta, self.best_score):
+            self.best_score = current
+            is_best = True
+        else:
+            is_best = False
+
+        self.model.save_checkpoint(self.checkpoint_path, is_best, epoch, self.monitor)
+
+    def get_monitor_value(self, logs: Dict) -> Union[float, None]:
+        """Extracts the monitored value."""
+        monitor_value = logs.get(self.monitor)
+        if monitor_value is None:
+            logger.warning(
+                f"Checkpoint is conditioned on metric {self.monitor} which is not available. Available"
+                + f" metrics are: {','.join(list(logs.keys()))}"
+            )
+            return None
+        return monitor_value
diff --git a/training/trainer/callbacks/early_stopping.py b/training/trainer/callbacks/early_stopping.py
new file mode 100644
index 0000000..02b431f
--- /dev/null
+++ b/training/trainer/callbacks/early_stopping.py
@@ -0,0 +1,108 @@
+"""Implements Early stopping for PyTorch model."""
+from typing import Dict, Union
+
+from loguru import logger
+import numpy as np
+import torch
+from torch import Tensor
+from training.trainer.callbacks import Callback
+
+
+class EarlyStopping(Callback):
+    """Stops training when a monitored metric stops improving."""
+
+    mode_dict = {
+        "min": torch.lt,
+        "max": torch.gt,
+    }
+
+    def __init__(
+        self,
+        monitor: str = "val_loss",
+        min_delta: float = 0.0,
+        patience: int = 3,
+        mode: str = "auto",
+    ) -> None:
+        """Initializes the EarlyStopping callback.
+
+        Args:
+            monitor (str): Description of parameter `monitor`. Defaults to "val_loss".
+            min_delta (float): Description of parameter `min_delta`. Defaults to 0.0.
+            patience (int): Description of parameter `patience`. Defaults to 3.
+            mode (str): Description of parameter `mode`. Defaults to "auto".
+
+        """
+        super().__init__()
+        self.monitor = monitor
+        self.patience = patience
+        self.min_delta = torch.tensor(min_delta)
+        self.mode = mode
+        self.wait_count = 0
+        self.stopped_epoch = 0
+
+        if mode not in ["auto", "min", "max"]:
+            logger.warning(
+                f"EarlyStopping mode {mode} is unkown, fallback to auto mode."
+            )
+
+            self.mode = "auto"
+
+        if self.mode == "auto":
+            if "accuracy" in self.monitor:
+                self.mode = "max"
+            else:
+                self.mode = "min"
+            logger.debug(
+                f"EarlyStopping mode set to {self.mode} for monitoring {self.monitor}."
+            )
+
+        self.torch_inf = torch.tensor(np.inf)
+        self.min_delta *= 1 if self.monitor_op == torch.gt else -1
+        self.best_score = (
+            self.torch_inf if self.monitor_op == torch.lt else -self.torch_inf
+        )
+
+    @property
+    def monitor_op(self) -> float:
+        """Returns the comparison method."""
+        return self.mode_dict[self.mode]
+
+    def on_fit_begin(self) -> Union[torch.lt, torch.gt]:
+        """Reset the early stopping variables for reuse."""
+        self.wait_count = 0
+        self.stopped_epoch = 0
+        self.best_score = (
+            self.torch_inf if self.monitor_op == torch.lt else -self.torch_inf
+        )
+
+    def on_epoch_end(self, epoch: int, logs: Dict) -> None:
+        """Computes the early stop criterion."""
+        current = self.get_monitor_value(logs)
+        if current is None:
+            return
+        if self.monitor_op(current - self.min_delta, self.best_score):
+            self.best_score = current
+            self.wait_count = 0
+        else:
+            self.wait_count += 1
+            if self.wait_count >= self.patience:
+                self.stopped_epoch = epoch
+                self.model.stop_training = True
+
+    def on_fit_end(self) -> None:
+        """Logs if early stopping was used."""
+        if self.stopped_epoch > 0:
+            logger.info(
+                f"Stopped training at epoch {self.stopped_epoch + 1} with early stopping."
+            )
+
+    def get_monitor_value(self, logs: Dict) -> Union[Tensor, None]:
+        """Extracts the monitor value."""
+        monitor_value = logs.get(self.monitor)
+        if monitor_value is None:
+            logger.warning(
+                f"Early stopping is conditioned on metric {self.monitor} which is not available. Available"
+                + f"metrics are: {','.join(list(logs.keys()))}"
+            )
+            return None
+        return torch.tensor(monitor_value)
diff --git a/training/trainer/callbacks/lr_schedulers.py b/training/trainer/callbacks/lr_schedulers.py
new file mode 100644
index 0000000..630c434
--- /dev/null
+++ b/training/trainer/callbacks/lr_schedulers.py
@@ -0,0 +1,77 @@
+"""Callbacks for learning rate schedulers."""
+from typing import Callable, Dict, List, Optional, Type
+
+from torch.optim.swa_utils import update_bn
+from training.trainer.callbacks import Callback
+
+from text_recognizer.models import Model
+
+
+class LRScheduler(Callback):
+    """Generic learning rate scheduler callback."""
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def set_model(self, model: Type[Model]) -> None:
+        """Sets the model and lr scheduler."""
+        self.model = model
+        self.lr_scheduler = self.model.lr_scheduler["lr_scheduler"]
+        self.interval = self.model.lr_scheduler["interval"]
+
+    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
+        """Takes a step at the end of every epoch."""
+        if self.interval == "epoch":
+            if "ReduceLROnPlateau" in self.lr_scheduler.__class__.__name__:
+                self.lr_scheduler.step(logs["val_loss"])
+            else:
+                self.lr_scheduler.step()
+
+    def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
+        """Takes a step at the end of every training batch."""
+        if self.interval == "step":
+            self.lr_scheduler.step()
+
+
+class SWA(Callback):
+    """Stochastic Weight Averaging callback."""
+
+    def __init__(self) -> None:
+        """Initializes the callback."""
+        super().__init__()
+        self.lr_scheduler = None
+        self.interval = None
+        self.swa_scheduler = None
+        self.swa_start = None
+        self.current_epoch = 1
+
+    def set_model(self, model: Type[Model]) -> None:
+        """Sets the model and lr scheduler."""
+        self.model = model
+        self.lr_scheduler = self.model.lr_scheduler["lr_scheduler"]
+        self.interval = self.model.lr_scheduler["interval"]
+        self.swa_scheduler = self.model.swa_scheduler["swa_scheduler"]
+        self.swa_start = self.model.swa_scheduler["swa_start"]
+
+    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
+        """Takes a step at the end of every training batch."""
+        if epoch > self.swa_start:
+            self.model.swa_network.update_parameters(self.model.network)
+            self.swa_scheduler.step()
+        elif self.interval == "epoch":
+            self.lr_scheduler.step()
+        self.current_epoch = epoch
+
+    def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
+        """Takes a step at the end of every training batch."""
+        if self.current_epoch < self.swa_start and self.interval == "step":
+            self.lr_scheduler.step()
+
+    def on_fit_end(self) -> None:
+        """Update batch norm statistics for the swa model at the end of training."""
+        if self.model.swa_network:
+            update_bn(
+                self.model.val_dataloader(),
+                self.model.swa_network,
+                device=self.model.device,
+            )
diff --git a/training/trainer/callbacks/progress_bar.py b/training/trainer/callbacks/progress_bar.py
new file mode 100644
index 0000000..6c4305a
--- /dev/null
+++ b/training/trainer/callbacks/progress_bar.py
@@ -0,0 +1,65 @@
+"""Progress bar callback for the training loop."""
+from typing import Dict, Optional
+
+from tqdm import tqdm
+from training.trainer.callbacks import Callback
+
+
+class ProgressBar(Callback):
+    """A TQDM progress bar for the training loop."""
+
+    def __init__(self, epochs: int, log_batch_frequency: int = None) -> None:
+        """Initializes the tqdm callback."""
+        self.epochs = epochs
+        print(epochs, type(epochs))
+        self.log_batch_frequency = log_batch_frequency
+        self.progress_bar = None
+        self.val_metrics = {}
+
+    def _configure_progress_bar(self) -> None:
+        """Configures the tqdm progress bar with custom bar format."""
+        self.progress_bar = tqdm(
+            total=len(self.model.train_dataloader()),
+            leave=False,
+            unit="steps",
+            mininterval=self.log_batch_frequency,
+            bar_format="{desc} |{bar:32}| {n_fmt}/{total_fmt} ETA: {remaining} {rate_fmt}{postfix}",
+        )
+
+    def _key_abbreviations(self, logs: Dict) -> Dict:
+        """Changes the length of keys, so that the progress bar fits better."""
+
+        def rename(key: str) -> str:
+            """Renames accuracy to acc."""
+            return key.replace("accuracy", "acc")
+
+        return {rename(key): value for key, value in logs.items()}
+
+    # def on_fit_begin(self) -> None:
+    #     """Creates a tqdm progress bar."""
+    #     self._configure_progress_bar()
+
+    def on_epoch_begin(self, epoch: int, logs: Optional[Dict]) -> None:
+        """Updates the description with the current epoch."""
+        if epoch == 1:
+            self._configure_progress_bar()
+        else:
+            self.progress_bar.reset()
+        self.progress_bar.set_description(f"Epoch {epoch}/{self.epochs}")
+
+    def on_epoch_end(self, epoch: int, logs: Dict) -> None:
+        """At the end of each epoch, the validation metrics are updated to the progress bar."""
+        self.val_metrics = logs
+        self.progress_bar.set_postfix(**self._key_abbreviations(logs))
+        self.progress_bar.update()
+
+    def on_train_batch_end(self, batch: int, logs: Dict) -> None:
+        """Updates the progress bar for each training step."""
+        if self.val_metrics:
+            logs.update(self.val_metrics)
+        self.progress_bar.set_postfix(**self._key_abbreviations(logs))
+        self.progress_bar.update()
+
+    def on_fit_end(self) -> None:
+        """Closes the tqdm progress bar."""
+        self.progress_bar.close()
diff --git a/training/trainer/callbacks/wandb_callbacks.py b/training/trainer/callbacks/wandb_callbacks.py
new file mode 100644
index 0000000..552a4f4
--- /dev/null
+++ b/training/trainer/callbacks/wandb_callbacks.py
@@ -0,0 +1,261 @@
+"""Callback for W&B."""
+from typing import Callable, Dict, List, Optional, Type
+
+import numpy as np
+from training.trainer.callbacks import Callback
+import wandb
+
+import text_recognizer.datasets.transforms as transforms
+from text_recognizer.models.base import Model
+
+
+class WandbCallback(Callback):
+    """A custom W&B metric logger for the trainer."""
+
+    def __init__(self, log_batch_frequency: int = None) -> None:
+        """Short summary.
+
+        Args:
+            log_batch_frequency (int): If None, metrics will be logged every epoch.
+                If set to an integer, callback will log every metrics every log_batch_frequency.
+
+        """
+        super().__init__()
+        self.log_batch_frequency = log_batch_frequency
+
+    def _on_batch_end(self, batch: int, logs: Dict) -> None:
+        if self.log_batch_frequency and batch % self.log_batch_frequency == 0:
+            wandb.log(logs, commit=True)
+
+    def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
+        """Logs training metrics."""
+        if logs is not None:
+            logs["lr"] = self.model.optimizer.param_groups[0]["lr"]
+            self._on_batch_end(batch, logs)
+
+    def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
+        """Logs validation metrics."""
+        if logs is not None:
+            self._on_batch_end(batch, logs)
+
+    def on_epoch_end(self, epoch: int, logs: Dict) -> None:
+        """Logs at epoch end."""
+        wandb.log(logs, commit=True)
+
+
+class WandbImageLogger(Callback):
+    """Custom W&B callback for image logging."""
+
+    def __init__(
+        self,
+        example_indices: Optional[List] = None,
+        num_examples: int = 4,
+        transform: Optional[bool] = None,
+    ) -> None:
+        """Initializes the WandbImageLogger with the model to train.
+
+        Args:
+            example_indices (Optional[List]): Indices for validation images. Defaults to None.
+            num_examples (int): Number of random samples to take if example_indices are not specified. Defaults to 4.
+            transform (Optional[Dict]): Use transform on image or not. Defaults to None.
+
+        """
+
+        super().__init__()
+        self.caption = None
+        self.example_indices = example_indices
+        self.test_sample_indices = None
+        self.num_examples = num_examples
+        self.transform = (
+            self._configure_transform(transform) if transform is not None else None
+        )
+
+    def _configure_transform(self, transform: Dict) -> Callable:
+        args = transform["args"] or {}
+        return getattr(transforms, transform["type"])(**args)
+
+    def set_model(self, model: Type[Model]) -> None:
+        """Sets the model and extracts validation images from the dataset."""
+        self.model = model
+        self.caption = "Validation Examples"
+        if self.example_indices is None:
+            self.example_indices = np.random.randint(
+                0, len(self.model.val_dataset), self.num_examples
+            )
+        self.images = self.model.val_dataset.dataset.data[self.example_indices]
+        self.targets = self.model.val_dataset.dataset.targets[self.example_indices]
+        self.targets = self.targets.tolist()
+
+    def on_test_begin(self) -> None:
+        """Get samples from test dataset."""
+        self.caption = "Test Examples"
+        if self.test_sample_indices is None:
+            self.test_sample_indices = np.random.randint(
+                0, len(self.model.test_dataset), self.num_examples
+            )
+        self.images = self.model.test_dataset.data[self.test_sample_indices]
+        self.targets = self.model.test_dataset.targets[self.test_sample_indices]
+        self.targets = self.targets.tolist()
+
+    def on_test_end(self) -> None:
+        """Log test images."""
+        self.on_epoch_end(0, {})
+
+    def on_epoch_end(self, epoch: int, logs: Dict) -> None:
+        """Get network predictions on validation images."""
+        images = []
+        for i, image in enumerate(self.images):
+            image = self.transform(image) if self.transform is not None else image
+            pred, conf = self.model.predict_on_image(image)
+            if isinstance(self.targets[i], list):
+                ground_truth = "".join(
+                    [
+                        self.model.mapper(int(target_index) - 26)
+                        if target_index > 35
+                        else self.model.mapper(int(target_index))
+                        for target_index in self.targets[i]
+                    ]
+                ).rstrip("_")
+            else:
+                ground_truth = self.model.mapper(int(self.targets[i]))
+            caption = f"Prediction: {pred} Confidence: {conf:.3f} Ground Truth: {ground_truth}"
+            images.append(wandb.Image(image, caption=caption))
+
+        wandb.log({f"{self.caption}": images}, commit=False)
+
+
+class WandbSegmentationLogger(Callback):
+    """Custom W&B callback for image logging."""
+
+    def __init__(
+        self,
+        class_labels: Dict,
+        example_indices: Optional[List] = None,
+        num_examples: int = 4,
+    ) -> None:
+        """Initializes the WandbImageLogger with the model to train.
+
+        Args:
+            class_labels (Dict): A dict with int as key and class string as value.
+            example_indices (Optional[List]): Indices for validation images. Defaults to None.
+            num_examples (int): Number of random samples to take if example_indices are not specified. Defaults to 4.
+
+        """
+
+        super().__init__()
+        self.caption = None
+        self.class_labels = {int(k): v for k, v in class_labels.items()}
+        self.example_indices = example_indices
+        self.test_sample_indices = None
+        self.num_examples = num_examples
+
+    def set_model(self, model: Type[Model]) -> None:
+        """Sets the model and extracts validation images from the dataset."""
+        self.model = model
+        self.caption = "Validation Segmentation Examples"
+        if self.example_indices is None:
+            self.example_indices = np.random.randint(
+                0, len(self.model.val_dataset), self.num_examples
+            )
+        self.images = self.model.val_dataset.dataset.data[self.example_indices]
+        self.targets = self.model.val_dataset.dataset.targets[self.example_indices]
+        self.targets = self.targets.tolist()
+
+    def on_test_begin(self) -> None:
+        """Get samples from test dataset."""
+        self.caption = "Test Segmentation Examples"
+        if self.test_sample_indices is None:
+            self.test_sample_indices = np.random.randint(
+                0, len(self.model.test_dataset), self.num_examples
+            )
+        self.images = self.model.test_dataset.data[self.test_sample_indices]
+        self.targets = self.model.test_dataset.targets[self.test_sample_indices]
+        self.targets = self.targets.tolist()
+
+    def on_test_end(self) -> None:
+        """Log test images."""
+        self.on_epoch_end(0, {})
+
+    def on_epoch_end(self, epoch: int, logs: Dict) -> None:
+        """Get network predictions on validation images."""
+        images = []
+        for i, image in enumerate(self.images):
+            pred_mask = (
+                self.model.predict_on_image(image).detach().squeeze(0).cpu().numpy()
+            )
+            gt_mask = np.array(self.targets[i])
+            images.append(
+                wandb.Image(
+                    image,
+                    masks={
+                        "predictions": {
+                            "mask_data": pred_mask,
+                            "class_labels": self.class_labels,
+                        },
+                        "ground_truth": {
+                            "mask_data": gt_mask,
+                            "class_labels": self.class_labels,
+                        },
+                    },
+                )
+            )
+
+        wandb.log({f"{self.caption}": images}, commit=False)
+
+
+class WandbReconstructionLogger(Callback):
+    """Custom W&B callback for image reconstructions logging."""
+
+    def __init__(
+        self, example_indices: Optional[List] = None, num_examples: int = 4,
+    ) -> None:
+        """Initializes the  WandbImageLogger with the model to train.
+
+        Args:
+            example_indices (Optional[List]): Indices for validation images. Defaults to None.
+            num_examples (int): Number of random samples to take if example_indices are not specified. Defaults to 4.
+
+        """
+
+        super().__init__()
+        self.caption = None
+        self.example_indices = example_indices
+        self.test_sample_indices = None
+        self.num_examples = num_examples
+
+    def set_model(self, model: Type[Model]) -> None:
+        """Sets the model and extracts validation images from the dataset."""
+        self.model = model
+        self.caption = "Validation Reconstructions Examples"
+        if self.example_indices is None:
+            self.example_indices = np.random.randint(
+                0, len(self.model.val_dataset), self.num_examples
+            )
+        self.images = self.model.val_dataset.dataset.data[self.example_indices]
+
+    def on_test_begin(self) -> None:
+        """Get samples from test dataset."""
+        self.caption = "Test Reconstructions Examples"
+        if self.test_sample_indices is None:
+            self.test_sample_indices = np.random.randint(
+                0, len(self.model.test_dataset), self.num_examples
+            )
+        self.images = self.model.test_dataset.data[self.test_sample_indices]
+
+    def on_test_end(self) -> None:
+        """Log test images."""
+        self.on_epoch_end(0, {})
+
+    def on_epoch_end(self, epoch: int, logs: Dict) -> None:
+        """Get network predictions on validation images."""
+        images = []
+        for image in self.images:
+            reconstructed_image = (
+                self.model.predict_on_image(image).detach().squeeze(0).cpu().numpy()
+            )
+            images.append(image)
+            images.append(reconstructed_image)
+
+        wandb.log(
+            {f"{self.caption}": [wandb.Image(image) for image in images]}, commit=False,
+        )
diff --git a/training/trainer/train.py b/training/trainer/train.py
new file mode 100644
index 0000000..b770c94
--- /dev/null
+++ b/training/trainer/train.py
@@ -0,0 +1,325 @@
+"""Training script for PyTorch models."""
+
+from pathlib import Path
+import time
+from typing import Dict, List, Optional, Tuple, Type
+import warnings
+
+from einops import rearrange
+from loguru import logger
+import numpy as np
+import torch
+from torch import Tensor
+from torch.optim.swa_utils import update_bn
+from training.trainer.callbacks import Callback, CallbackList, LRScheduler, SWA
+from training.trainer.util import log_val_metric
+import wandb
+
+from text_recognizer.models import Model
+
+
+torch.backends.cudnn.benchmark = True
+np.random.seed(4711)
+torch.manual_seed(4711)
+torch.cuda.manual_seed(4711)
+
+
+warnings.filterwarnings("ignore")
+
+
+class Trainer:
+    """Trainer for training PyTorch models."""
+
+    def __init__(
+        self,
+        max_epochs: int,
+        callbacks: List[Type[Callback]],
+        transformer_model: bool = False,
+        max_norm: float = 0.0,
+        freeze_backbone: Optional[int] = None,
+    ) -> None:
+        """Initialization of the Trainer.
+
+        Args:
+            max_epochs (int): The maximum number of epochs in the training loop.
+            callbacks (CallbackList): List of callbacks to be called.
+            transformer_model (bool): Transformer model flag, modifies the input to the model. Default is False.
+            max_norm (float): Max norm for gradient cl:ipping. Defaults to 0.0.
+            freeze_backbone (Optional[int]): How many epochs to freeze the backbone for. Used when training
+                Transformers. Default is None.
+
+        """
+        # Training arguments.
+        self.start_epoch = 1
+        self.max_epochs = max_epochs
+        self.callbacks = callbacks
+        self.freeze_backbone = freeze_backbone
+
+        # Flag for setting callbacks.
+        self.callbacks_configured = False
+
+        self.transformer_model = transformer_model
+
+        self.max_norm = max_norm
+
+        # Model placeholders
+        self.model = None
+
+    def _configure_callbacks(self) -> None:
+        """Instantiate the CallbackList."""
+        if not self.callbacks_configured:
+            # If learning rate schedulers are present, they need to be added to the callbacks.
+            if self.model.swa_scheduler is not None:
+                self.callbacks.append(SWA())
+            elif self.model.lr_scheduler is not None:
+                self.callbacks.append(LRScheduler())
+
+            self.callbacks = CallbackList(self.model, self.callbacks)
+
+    def compute_metrics(
+        self, output: Tensor, targets: Tensor, loss: Tensor, batch_size: int
+    ) -> Dict:
+        """Computes metrics for output and target pairs."""
+        # Compute metrics.
+        loss = loss.detach().float().item()
+        output = output.detach()
+        targets = targets.detach()
+        if self.model.metrics is not None:
+            metrics = {}
+            for metric in self.model.metrics:
+                if metric == "cer" or metric == "wer":
+                    metrics[metric] = self.model.metrics[metric](
+                        output,
+                        targets,
+                        batch_size,
+                        self.model.mapper(self.model.pad_token),
+                    )
+                else:
+                    metrics[metric] = self.model.metrics[metric](output, targets)
+        else:
+            metrics = {}
+        metrics["loss"] = loss
+
+        return metrics
+
+    def training_step(self, batch: int, samples: Tuple[Tensor, Tensor],) -> Dict:
+        """Performs the training step."""
+        # Pass the tensor to the device for computation.
+        data, targets = samples
+        data, targets = (
+            data.to(self.model.device),
+            targets.to(self.model.device),
+        )
+
+        batch_size = data.shape[0]
+
+        # Placeholder for uxiliary loss.
+        aux_loss = None
+
+        # Forward pass.
+        # Get the network prediction.
+        if self.transformer_model:
+            if self.freeze_backbone is not None and batch < self.freeze_backbone:
+                with torch.no_grad():
+                    image_features = self.model.network.extract_image_features(data)
+
+                if isinstance(image_features, Tuple):
+                    image_features, _ = image_features
+
+                output = self.model.network.decode_image_features(
+                    image_features, targets[:, :-1]
+                )
+            else:
+                output = self.model.network.forward(data, targets[:, :-1])
+                if isinstance(output, Tuple):
+                    output, aux_loss = output
+            output = rearrange(output, "b t v -> (b t) v")
+            targets = rearrange(targets[:, 1:], "b t -> (b t)").long()
+        else:
+            output = self.model.forward(data)
+
+            if isinstance(output, Tuple):
+                output, aux_loss = output
+                targets = data
+
+        # Compute the loss.
+        loss = self.model.criterion(output, targets)
+
+        if aux_loss is not None:
+            loss += aux_loss
+
+        # Backward pass.
+        # Clear the previous gradients.
+        for p in self.model.network.parameters():
+            p.grad = None
+
+        # Compute the gradients.
+        loss.backward()
+
+        if self.max_norm > 0:
+            torch.nn.utils.clip_grad_norm_(
+                self.model.network.parameters(), self.max_norm
+            )
+
+        # Perform updates using calculated gradients.
+        self.model.optimizer.step()
+
+        metrics = self.compute_metrics(output, targets, loss, batch_size)
+
+        return metrics
+
+    def train(self) -> None:
+        """Runs the training loop for one epoch."""
+        # Set model to traning mode.
+        self.model.train()
+
+        for batch, samples in enumerate(self.model.train_dataloader()):
+            self.callbacks.on_train_batch_begin(batch)
+            metrics = self.training_step(batch, samples)
+            self.callbacks.on_train_batch_end(batch, logs=metrics)
+
+    @torch.no_grad()
+    def validation_step(self, batch: int, samples: Tuple[Tensor, Tensor],) -> Dict:
+        """Performs the validation step."""
+
+        # Pass the tensor to the device for computation.
+        data, targets = samples
+        data, targets = (
+            data.to(self.model.device),
+            targets.to(self.model.device),
+        )
+
+        batch_size = data.shape[0]
+
+        # Placeholder for uxiliary loss.
+        aux_loss = None
+
+        # Forward pass.
+        # Get the network prediction.
+        # Use SWA if available and using test dataset.
+        if self.transformer_model:
+            output = self.model.network.forward(data, targets[:, :-1])
+            if isinstance(output, Tuple):
+                output, aux_loss = output
+            output = rearrange(output, "b t v -> (b t) v")
+            targets = rearrange(targets[:, 1:], "b t -> (b t)").long()
+        else:
+            output = self.model.forward(data)
+
+            if isinstance(output, Tuple):
+                output, aux_loss = output
+                targets = data
+
+        # Compute the loss.
+        loss = self.model.criterion(output, targets)
+
+        if aux_loss is not None:
+            loss += aux_loss
+
+        # Compute metrics.
+        metrics = self.compute_metrics(output, targets, loss, batch_size)
+
+        return metrics
+
+    def validate(self) -> Dict:
+        """Runs the validation loop for one epoch."""
+        # Set model to eval mode.
+        self.model.eval()
+
+        # Summary for the current eval loop.
+        summary = []
+
+        for batch, samples in enumerate(self.model.val_dataloader()):
+            self.callbacks.on_validation_batch_begin(batch)
+            metrics = self.validation_step(batch, samples)
+            self.callbacks.on_validation_batch_end(batch, logs=metrics)
+            summary.append(metrics)
+
+        # Compute mean of all metrics.
+        metrics_mean = {
+            "val_" + metric: np.mean([x[metric] for x in summary])
+            for metric in summary[0]
+        }
+
+        return metrics_mean
+
+    def fit(self, model: Type[Model]) -> None:
+        """Runs the training and evaluation loop."""
+
+        # Sets model, loads the data, criterion, and optimizers.
+        self.model = model
+        self.model.prepare_data()
+        self.model.configure_model()
+
+        # Configure callbacks.
+        self._configure_callbacks()
+
+        # Set start time.
+        t_start = time.time()
+
+        self.callbacks.on_fit_begin()
+
+        # Run the training loop.
+        for epoch in range(self.start_epoch, self.max_epochs + 1):
+            self.callbacks.on_epoch_begin(epoch)
+
+            # Perform one training pass over the training set.
+            self.train()
+
+            # Evaluate the model on the validation set.
+            val_metrics = self.validate()
+            log_val_metric(val_metrics, epoch)
+
+            self.callbacks.on_epoch_end(epoch, logs=val_metrics)
+
+            if self.model.stop_training:
+                break
+
+        # Calculate the total training time.
+        t_end = time.time()
+        t_training = t_end - t_start
+
+        self.callbacks.on_fit_end()
+
+        logger.info(f"Training took {t_training:.2f} s.")
+
+        # "Teardown".
+        self.model = None
+
+    def test(self, model: Type[Model]) -> Dict:
+        """Run inference on test data."""
+
+        # Sets model, loads the data, criterion, and optimizers.
+        self.model = model
+        self.model.prepare_data()
+        self.model.configure_model()
+
+        # Configure callbacks.
+        self._configure_callbacks()
+
+        self.callbacks.on_test_begin()
+
+        self.model.eval()
+
+        # Check if SWA network is available.
+        self.model.use_swa_model()
+
+        # Summary for the current test loop.
+        summary = []
+
+        for batch, samples in enumerate(self.model.test_dataloader()):
+            metrics = self.validation_step(batch, samples)
+            summary.append(metrics)
+
+        self.callbacks.on_test_end()
+
+        # Compute mean of all test metrics.
+        metrics_mean = {
+            "test_" + metric: np.mean([x[metric] for x in summary])
+            for metric in summary[0]
+        }
+
+        # "Teardown".
+        self.model = None
+
+        return metrics_mean
diff --git a/training/trainer/util.py b/training/trainer/util.py
new file mode 100644
index 0000000..7cf1b45
--- /dev/null
+++ b/training/trainer/util.py
@@ -0,0 +1,28 @@
+"""Utility functions for training neural networks."""
+from typing import Dict, Optional
+
+from loguru import logger
+
+
+def log_val_metric(metrics_mean: Dict, epoch: Optional[int] = None) -> None:
+    """Logging of val metrics to file/terminal."""
+    log_str = "Validation metrics " + (f"at epoch {epoch} - " if epoch else " - ")
+    logger.debug(log_str + " - ".join(f"{k}: {v:.4f}" for k, v in metrics_mean.items()))
+
+
+class RunningAverage:
+    """Maintains a running average."""
+
+    def __init__(self) -> None:
+        """Initializes the parameters."""
+        self.steps = 0
+        self.total = 0
+
+    def update(self, val: float) -> None:
+        """Updates the parameters."""
+        self.total += val
+        self.steps += 1
+
+    def __call__(self) -> float:
+        """Computes the running average."""
+        return self.total / float(self.steps)
-- 
cgit v1.2.3-70-g09d2