diff options
| author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-05 20:47:55 +0200 | 
|---|---|---|
| committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-05 20:47:55 +0200 | 
| commit | 9ae5fa1a88899180f88ddb14d4cef457ceb847e5 (patch) | |
| tree | 4fe2bcd82553c8062eb0908ae6442c123addf55d | |
| parent | 9e54591b7e342edc93b0bb04809a0f54045c6a15 (diff) | |
Add new training loop with PyTorch Lightning, remove stale files
26 files changed, 167 insertions, 2229 deletions
| @@ -32,11 +32,11 @@ poetry run build-transitions --tokens iamdb_1kwp_tokens_1000.txt --lexicon iamdb    - [x] transform that encodes iam targets to wordpieces    - [x] transducer loss function  - [  ] Train with word pieces -  - [  ] implement wandb callback for logging +- [ ] Local attention in first layer of transformer  +- [ ] Halonet encoder  - [  ] Implement CPC -  - [  ] Window images -  - [  ] Train backbone -- [  ] Bert training, how? +  - [ ] https://arxiv.org/pdf/1905.09272.pdf +  - [ ] https://pytorch-lightning-bolts.readthedocs.io/en/latest/self_supervised_models.html?highlight=byol  - [ ] Predictive coding @@ -60,20 +60,3 @@ wandb agent $SWEEP_ID  ``` -## PyTorch Performance Guide -Tips and tricks from ["PyTorch Performance Tuning Guide - Szymon Migacz, NVIDIA"](https://www.youtube.com/watch?v=9mS1fIYj1So&t=125s): - -* Always better to use `num_workers > 0`, allows asynchronous data processing -* Use `pin_memory=True` to allow data loading and computations to happen on the GPU in parallel. -* Have to tune `num_workers` to use based on the problem, too many and data loading becomes slower. -* For CNNs use `torch.backends.cudnn.benchmark=True`, allows cuDNN to select the best algorithm for convolutional computations (autotuner). -* Increase batch size to max out GPU memory. -* Use optimizer for large batch training, e.g. LARS, LAMB etc. -* Set `bias=False` for convolutions directly followed by BatchNorm. -* Use `for p in model.parameters(): p.grad = None` instead of `model.zero_grad()`. -* Careful with disable debug APIs in prod (detect_anomaly, profiler, gradcheck). -* Use `DistributedDataParallel` not `DataParallel`, uses 1 CPU core for each GPU. -* Important to load balance compute on all GPUs, if variably-sized inputs or GPUs will idle. -* Use an apex fused optimizer -* Use checkpointing to recompute memory-intensive compute-efficient ops in backward pass (e.g. activations, upsampling), `torch.utils.checkpoint`. -* Use `@torch.jit.script`, especially to fuse long sequences of pointwise operations like GELU. diff --git a/notebooks/00-testing-stuff-out.ipynb b/notebooks/00-testing-stuff-out.ipynb index 0fa5a1b..81957f7 100644 --- a/notebooks/00-testing-stuff-out.ipynb +++ b/notebooks/00-testing-stuff-out.ipynb @@ -25,8 +25,52 @@    },    {     "cell_type": "code", -   "execution_count": 1, +   "execution_count": 5, +   "metadata": {}, +   "outputs": [], +   "source": [ +    "class Hej:\n", +    "    a = 2\n", +    "    \n", +    "class Hejjj:\n", +    "    b = 1" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": 7,     "metadata": {}, +   "outputs": [], +   "source": [ +    "l = [Hej(), Hejjj()]" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": 10, +   "metadata": {}, +   "outputs": [ +    { +     "data": { +      "text/plain": [ +       "<__main__.Hej at 0x7efefc77f370>" +      ] +     }, +     "execution_count": 10, +     "metadata": {}, +     "output_type": "execute_result" +    } +   ], +   "source": [ +    "next(o for o in l if isinstance(o, Hej))" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": 1, +   "metadata": { +    "scrolled": true +   },     "outputs": [      {       "data": { diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index 46e5136..2d6e435 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -1,5 +1,5 @@  """Base PyTorch Lightning model.""" -from typing import Any, Dict, Tuple, Type +from typing import Any, Dict, List, Tuple, Type  import madgrad  import pytorch_lightning as pl @@ -40,7 +40,7 @@ class LitBaseModel(pl.LightningModule):          args = {} or criterion_args["args"]          return getattr(nn, criterion_args["type"])(**args) -    def configure_optimizer(self) -> Dict[str, Any]: +    def configure_optimizer(self) -> Tuple[List[type], List[Dict[str, Any]]]:          """Configures optimizer and lr scheduler."""          args = {} or self.optimizer_args["args"]          if self.optimizer_args["type"] == "MADGRAD": @@ -48,15 +48,15 @@ class LitBaseModel(pl.LightningModule):          else:              optimizer = getattr(torch.optim, self.optimizer_args["type"])(**args) +        scheduler = {"monitor": self.monitor}          args = {} or self.lr_scheduler_args["args"] -        scheduler = getattr(torch.optim.lr_scheduler, self.lr_scheduler_args["type"])( -            **args -        ) -        return { -            "optimizer": optimizer, -            "lr_scheduler": scheduler, -            "monitor": self.monitor, -        } +        if "interval" in args: +            scheduler["interval"] = args.pop("interval") + +        scheduler["scheduler"] = getattr( +            torch.optim.lr_scheduler, self.lr_scheduler_args["type"] +        )(**args) +        return [optimizer], [scheduler]      def forward(self, data: Tensor) -> Tensor:          """Feedforward pass.""" diff --git a/text_recognizer/networks/loss/__init__.py b/text_recognizer/networks/loss/__init__.py index b489264..cb83608 100644 --- a/text_recognizer/networks/loss/__init__.py +++ b/text_recognizer/networks/loss/__init__.py @@ -1,2 +1,2 @@  """Loss module.""" -from .loss import EmbeddingLoss, LabelSmoothingCrossEntropy +from .loss import LabelSmoothingCrossEntropy diff --git a/text_recognizer/networks/loss/loss.py b/text_recognizer/networks/loss/loss.py index cf9fa0d..d12dc9c 100644 --- a/text_recognizer/networks/loss/loss.py +++ b/text_recognizer/networks/loss/loss.py @@ -1,39 +1,9 @@  """Implementations of custom loss functions.""" -from pytorch_metric_learning import distances, losses, miners, reducers  import torch  from torch import nn  from torch import Tensor -from torch.autograd import Variable -import torch.nn.functional as F -__all__ = ["EmbeddingLoss", "LabelSmoothingCrossEntropy"] - - -class EmbeddingLoss: -    """Metric loss for training encoders to produce information-rich latent embeddings.""" - -    def __init__(self, margin: float = 0.2, type_of_triplets: str = "semihard") -> None: -        self.distance = distances.CosineSimilarity() -        self.reducer = reducers.ThresholdReducer(low=0) -        self.loss_fn = losses.TripletMarginLoss( -            margin=margin, distance=self.distance, reducer=self.reducer -        ) -        self.miner = miners.MultiSimilarityMiner(epsilon=margin, distance=self.distance) - -    def __call__(self, embeddings: Tensor, labels: Tensor) -> Tensor: -        """Computes the metric loss for the embeddings based on their labels. - -        Args: -            embeddings (Tensor): The laten vectors encoded by the network. -            labels (Tensor): Labels of the embeddings. - -        Returns: -            Tensor: The metric loss for the embeddings. - -        """ -        hard_pairs = self.miner(embeddings, labels) -        loss = self.loss_fn(embeddings, labels, hard_pairs) -        return loss +__all__ = ["LabelSmoothingCrossEntropy"]  class LabelSmoothingCrossEntropy(nn.Module): diff --git a/text_recognizer/networks/util.py b/text_recognizer/networks/util.py index 131a6b4..d292680 100644 --- a/text_recognizer/networks/util.py +++ b/text_recognizer/networks/util.py @@ -1,38 +1,13 @@  """Miscellaneous neural network functionality."""  import importlib  from pathlib import Path -from typing import Dict, Tuple, Type +from typing import Dict, Type -from einops import rearrange  from loguru import logger  import torch  from torch import nn -def sliding_window( -    images: torch.Tensor, patch_size: Tuple[int, int], stride: Tuple[int, int] -) -> torch.Tensor: -    """Creates patches of an image. - -    Args: -        images (torch.Tensor): A Torch tensor of a 4D image(s), i.e. (batch, channel, height, width). -        patch_size (Tuple[int, int]): The size of the patches to generate, e.g. 28x28 for EMNIST. -        stride (Tuple[int, int]): The stride of the sliding window. - -    Returns: -        torch.Tensor: A tensor with the shape (batch, patches, height, width). - -    """ -    unfold = nn.Unfold(kernel_size=patch_size, stride=stride) -    # Preform the sliding window, unsqueeze as the channel dimesion is lost. -    c = images.shape[1] -    patches = unfold(images) -    patches = rearrange( -        patches, "b (c h w) t -> b t c h w", c=c, h=patch_size[0], w=patch_size[1], -    ) -    return patches - -  def activation_function(activation: str) -> Type[nn.Module]:      """Returns the callable activation function."""      activation_fns = nn.ModuleDict( diff --git a/text_recognizer/networks/vq_transformer.py b/text_recognizer/networks/vq_transformer.py deleted file mode 100644 index c673d96..0000000 --- a/text_recognizer/networks/vq_transformer.py +++ /dev/null @@ -1,150 +0,0 @@ -"""A VQ-Transformer for image to text recognition.""" -from typing import Dict, Optional, Tuple - -from einops import rearrange, repeat -import torch -from torch import nn -from torch import Tensor - -from text_recognizer.networks.transformer import PositionalEncoding, Transformer -from text_recognizer.networks.util import activation_function -from text_recognizer.networks.util import configure_backbone -from text_recognizer.networks.vqvae.encoder import _ResidualBlock - - -class VQTransformer(nn.Module): -    """VQ+Transfomer for image to character sequence prediction.""" - -    def __init__( -        self, -        num_encoder_layers: int, -        num_decoder_layers: int, -        hidden_dim: int, -        vocab_size: int, -        num_heads: int, -        adaptive_pool_dim: Tuple, -        expansion_dim: int, -        dropout_rate: float, -        trg_pad_index: int, -        max_len: int, -        backbone: str, -        backbone_args: Optional[Dict] = None, -        activation: str = "gelu", -    ) -> None: -        super().__init__() - -        # Configure vector quantized backbone. -        self.backbone = configure_backbone(backbone, backbone_args) -        self.conv = nn.Sequential( -            nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=2), -            nn.ReLU(inplace=True), -        ) - -        # Configure embeddings for Transformer network. -        self.trg_pad_index = trg_pad_index -        self.vocab_size = vocab_size -        self.character_embedding = nn.Embedding(self.vocab_size, hidden_dim) -        self.src_position_embedding = nn.Parameter(torch.randn(1, max_len, hidden_dim)) -        self.trg_position_encoding = PositionalEncoding(hidden_dim, dropout_rate) -        nn.init.normal_(self.character_embedding.weight, std=0.02) - -        self.adaptive_pool = ( -            nn.AdaptiveAvgPool2d((adaptive_pool_dim)) if adaptive_pool_dim else None -        ) - -        self.transformer = Transformer( -            num_encoder_layers, -            num_decoder_layers, -            hidden_dim, -            num_heads, -            expansion_dim, -            dropout_rate, -            activation, -        ) - -        self.head = nn.Sequential(nn.Linear(hidden_dim, vocab_size),) - -    def _create_trg_mask(self, trg: Tensor) -> Tensor: -        # Move this outside the transformer. -        trg_pad_mask = (trg != self.trg_pad_index)[:, None, None] -        trg_len = trg.shape[1] -        trg_sub_mask = torch.tril( -            torch.ones((trg_len, trg_len), device=trg.device) -        ).bool() -        trg_mask = trg_pad_mask & trg_sub_mask -        return trg_mask - -    def encoder(self, src: Tensor) -> Tensor: -        """Forward pass with the encoder of the transformer.""" -        return self.transformer.encoder(src) - -    def decoder(self, trg: Tensor, memory: Tensor, trg_mask: Tensor) -> Tensor: -        """Forward pass with the decoder of the transformer + classification head.""" -        return self.head( -            self.transformer.decoder(trg=trg, memory=memory, trg_mask=trg_mask) -        ) - -    def extract_image_features(self, src: Tensor) -> Tuple[Tensor, Tensor]: -        """Extracts image features with a backbone neural network. - -        It seem like the winning idea was to swap channels and width dimension and collapse -        the height dimension. The transformer is learning like a baby with this implementation!!! :D -        Ohhhh, the joy I am experiencing right now!! Bring in the beers! :D :D :D - -        Args: -            src (Tensor): Input tensor. - -        Returns: -            Tensor: The input src to the transformer and the vq loss. - -        """ -        # If batch dimension is missing, it needs to be added. -        if len(src.shape) < 4: -            src = src[(None,) * (4 - len(src.shape))] -        src, vq_loss = self.backbone.encode(src) -        # src = self.backbone.decoder.res_block(src) -        src = self.conv(src) - -        if self.adaptive_pool is not None: -            src = rearrange(src, "b c h w -> b w c h") -            src = self.adaptive_pool(src) -            src = src.squeeze(3) -        else: -            src = rearrange(src, "b c h w -> b (w h) c") - -        b, t, _ = src.shape - -        src += self.src_position_embedding[:, :t] - -        return src, vq_loss - -    def target_embedding(self, trg: Tensor) -> Tensor: -        """Encodes target tensor with embedding and postion. - -        Args: -            trg (Tensor): Target tensor. - -        Returns: -            Tensor: Encoded target tensor. - -        """ -        trg = self.character_embedding(trg.long()) -        trg = self.trg_position_encoding(trg) -        return trg - -    def decode_image_features( -        self, image_features: Tensor, trg: Optional[Tensor] = None -    ) -> Tensor: -        """Takes images features from the backbone and decodes them with the transformer.""" -        trg_mask = self._create_trg_mask(trg) -        trg = self.target_embedding(trg) -        out = self.transformer(image_features, trg, trg_mask=trg_mask) - -        logits = self.head(out) -        return logits - -    def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor: -        """Forward pass with CNN transfomer.""" -        image_features, vq_loss = self.extract_image_features(x) -        logits = self.decode_image_features(image_features, trg) -        return logits, vq_loss diff --git a/training/experiments/default_config_emnist.yml b/training/experiments/default_config_emnist.yml deleted file mode 100644 index bf2ed0a..0000000 --- a/training/experiments/default_config_emnist.yml +++ /dev/null @@ -1,70 +0,0 @@ -dataset: EmnistDataset -dataset_args: -  sample_to_balance: true -  subsample_fraction: 0.33 -  transform: null -  target_transform: null -  seed: 4711 - -data_loader_args: -  splits: [train, val] -  shuffle: true -  num_workers: 8 -  cuda: true - -model: CharacterModel -metrics: [accuracy] - -network_args: -  in_channels: 1 -  num_classes: 80 -  depths: [2] -  block_sizes: [256] - -train_args: -  batch_size: 256 -  epochs: 5 - -criterion: CrossEntropyLoss -criterion_args: -  weight: null -  ignore_index: -100 -  reduction: mean - -optimizer: AdamW -optimizer_args: -  lr: 1.e-03 -  betas: [0.9, 0.999] -  eps: 1.e-08 -  # weight_decay: 5.e-4 -  amsgrad: false - -lr_scheduler: OneCycleLR -lr_scheduler_args: -  max_lr: 1.e-03 -  epochs: 5 -  anneal_strategy: linear - - -callbacks: [Checkpoint, ProgressBar, EarlyStopping, WandbCallback, WandbImageLogger, OneCycleLR] -callback_args: -  Checkpoint: -    monitor: val_accuracy -  ProgressBar: -    epochs: 5 -    log_batch_frequency: 100 -  EarlyStopping: -    monitor: val_loss -    min_delta: 0.0 -    patience: 3 -    mode: min -  WandbCallback: -    log_batch_frequency: 10 -  WandbImageLogger: -    num_examples: 4 -  OneCycleLR: -    null -verbosity: 1 # 0, 1, 2 -resume_experiment: null -train: true -validation_metric: val_accuracy diff --git a/training/experiments/embedding_experiment.yml b/training/experiments/embedding_experiment.yml deleted file mode 100644 index 1e5f941..0000000 --- a/training/experiments/embedding_experiment.yml +++ /dev/null @@ -1,64 +0,0 @@ -experiment_group: Embedding Experiments -experiments: -    - train_args: -        transformer_model: false -        batch_size: &batch_size 256 -        max_epochs: &max_epochs 32 -        input_shape: [[1, 28, 28]] -      dataset: -        type: EmnistDataset -        args: -          sample_to_balance: true -          subsample_fraction: null -          transform: null -          target_transform: null -          seed: 4711 -        train_args: -          num_workers: 8 -          train_fraction: 0.85 -          batch_size: *batch_size -      model: CharacterModel -      metrics: [] -      network: -        type: DenseNet -        args: -          growth_rate: 4 -          block_config: [4, 4] -          in_channels: 1 -          base_channels: 24 -          num_classes: 128 -          bn_size: 4 -          dropout_rate: 0.1 -          classifier: true -          activation: elu -      criterion: -        type: EmbeddingLoss -        args: -          margin: 0.2 -          type_of_triplets: semihard -      optimizer: -        type: AdamW -        args: -          lr: 1.e-02 -          betas: [0.9, 0.999] -          eps: 1.e-08 -          weight_decay: 5.e-4 -          amsgrad: false -      lr_scheduler: -        type: CosineAnnealingLR -        args: -          T_max: *max_epochs -      callbacks: [Checkpoint, ProgressBar, WandbCallback] -      callback_args: -        Checkpoint: -          monitor: val_loss -          mode: min -        ProgressBar: -          epochs: *max_epochs -        WandbCallback: -          log_batch_frequency: 10 -      verbosity: 1 # 0, 1, 2 -      resume_experiment: null -      train: true -      test: true -      test_metric: mean_average_precision_at_r diff --git a/training/experiments/sample_experiment.yml b/training/experiments/sample_experiment.yml deleted file mode 100644 index 8f94475..0000000 --- a/training/experiments/sample_experiment.yml +++ /dev/null @@ -1,99 +0,0 @@ -experiment_group: Sample Experiments -experiments: -    - train_args: -        batch_size: 256 -        max_epochs: &max_epochs 32 -      dataset: -        type: EmnistDataset -        args: -          sample_to_balance: true -          subsample_fraction: null -          transform: null -          target_transform: null -          seed: 4711 -        train_args: -          num_workers: 6 -          train_fraction: 0.8 - -      model: CharacterModel -      metrics: [accuracy] -      # network: MLP -      # network_args: -      #   input_size: 784 -      #   hidden_size: 512 -      #   output_size: 80 -      #   num_layers: 5 -      #   dropout_rate: 0.2 -      #   activation_fn: SELU -      network: -        type: ResidualNetwork -        args: -          in_channels: 1 -          num_classes: 80 -          depths: [2, 2] -          block_sizes: [64, 64] -          activation: leaky_relu -      # network: -      #   type: WideResidualNetwork -      #   args: -      #     in_channels: 1 -      #     num_classes: 80 -      #     depth: 10 -      #     num_layers: 3 -      #     width_factor: 4 -      #     dropout_rate: 0.2 -      #     activation: SELU -      # network: LeNet -      # network_args: -      #   output_size: 62 -      #   activation_fn: GELU -      criterion: -        type: CrossEntropyLoss -        args: -          weight: null -          ignore_index: -100 -          reduction: mean -      optimizer: -        type: AdamW -        args: -          lr: 1.e-02 -          betas: [0.9, 0.999] -          eps: 1.e-08 -          # weight_decay: 5.e-4 -          amsgrad: false -      # lr_scheduler: -      #   type: OneCycleLR -      #   args: -      #     max_lr: 1.e-03 -      #     epochs: *max_epochs -      #     anneal_strategy: linear -      lr_scheduler: -        type: CosineAnnealingLR -        args: -          T_max: *max_epochs -          interval: epoch -      swa_args: -        start: 2 -        lr: 5.e-2 -      callbacks: [Checkpoint, ProgressBar, WandbCallback, WandbImageLogger, EarlyStopping] -      callback_args: -        Checkpoint: -          monitor: val_accuracy -        ProgressBar: -          epochs: null -          log_batch_frequency: 100 -        EarlyStopping: -          monitor: val_loss -          min_delta: 0.0 -          patience: 5 -          mode: min -        WandbCallback: -          log_batch_frequency: 10 -        WandbImageLogger: -          num_examples: 4 -          use_transpose: true -      verbosity: 0 # 0, 1, 2 -      resume_experiment: null -      train: true -      test: true -      test_metric: test_accuracy diff --git a/training/gpu_manager.py b/training/gpu_manager.py deleted file mode 100644 index ce1b3dd..0000000 --- a/training/gpu_manager.py +++ /dev/null @@ -1,62 +0,0 @@ -"""GPUManager class.""" -import os -import time -from typing import Optional - -import gpustat -from loguru import logger -import numpy as np -from redlock import Redlock - - -GPU_LOCK_TIMEOUT = 5000  # ms - - -class GPUManager: -    """Class for allocating GPUs.""" - -    def __init__(self, verbose: bool = False) -> None: -        """Initializes Redlock manager.""" -        self.lock_manager = Redlock([{"host": "localhost", "port": 6379, "db": 0}]) -        self.verbose = verbose - -    def get_free_gpu(self) -> int: -        """Gets a free GPU. - -        If some GPUs are available, try reserving one by checking out an exclusive redis lock. -        If none available or can not get lock, sleep and check again. - -        Returns: -            int: The gpu index. - -        """ -        while True: -            gpu_index = self._get_free_gpu() -            if gpu_index is not None: -                return gpu_index - -            if self.verbose: -                logger.debug(f"pid {os.getpid()} sleeping") -            time.sleep(GPU_LOCK_TIMEOUT / 1000) - -    def _get_free_gpu(self) -> Optional[int]: -        """Fetches an available GPU index.""" -        try: -            available_gpu_indices = [ -                gpu.index -                for gpu in gpustat.GPUStatCollection.new_query() -                if gpu.memory_used < 0.5 * gpu.memory_total -            ] -        except Exception as e: -            logger.debug(f"Got the following exception: {e}") -            return None - -        if available_gpu_indices: -            gpu_index = np.random.choice(available_gpu_indices) -            if self.verbose: -                logger.debug(f"pid {os.getpid()} picking gpu {gpu_index}") -            if self.lock_manager.lock(f"gpu_{gpu_index}", GPU_LOCK_TIMEOUT): -                return int(gpu_index) -            if self.verbose: -                logger.debug(f"pid {os.getpid()} could not get lock.") -        return None diff --git a/training/prepare_experiments.py b/training/prepare_experiments.py deleted file mode 100644 index 21997af..0000000 --- a/training/prepare_experiments.py +++ /dev/null @@ -1,34 +0,0 @@ -"""Run a experiment from a config file.""" -import json - -import click -import yaml - - -def run_experiments(experiments_filename: str) -> None: -    """Run experiment from file.""" -    with open(experiments_filename, "r") as f: -        experiments_config = yaml.safe_load(f) - -    num_experiments = len(experiments_config["experiments"]) -    for index in range(num_experiments): -        experiment_config = experiments_config["experiments"][index] -        experiment_config["experiment_group"] = experiments_config["experiment_group"] -        cmd = f"poetry run run-experiment --gpu=-1 --save '{json.dumps(experiment_config)}'" -        print(cmd) - - -@click.command() -@click.option( -    "--experiments_filename", -    required=True, -    type=str, -    help="Filename of Yaml file of experiments to run.", -) -def run_cli(experiments_filename: str) -> None: -    """Parse command-line arguments and run experiments from provided file.""" -    run_experiments(experiments_filename) - - -if __name__ == "__main__": -    run_cli() diff --git a/training/run_experiment.py b/training/run_experiment.py index faafea6..ff8b886 100644 --- a/training/run_experiment.py +++ b/training/run_experiment.py @@ -1,162 +1,34 @@  """Script to run experiments."""  from datetime import datetime -from glob import glob  import importlib -import json -import os  from pathlib import Path -import re -from typing import Callable, Dict, List, Optional, Tuple, Type -import warnings +from typing import Dict, List, Optional, Type  import click  from loguru import logger  import numpy as np +from omegaconf import OmegaConf +import pytorch_lightning as pl  import torch +from torch import nn  from torchsummary import summary  from tqdm import tqdm -from training.gpu_manager import GPUManager -from training.trainer.callbacks import CallbackList -from training.trainer.train import Trainer  import wandb -import yaml -import text_recognizer.models -from text_recognizer.models import Model -import text_recognizer.networks -from text_recognizer.networks.loss import loss as custom_loss_module +SEED = 4711  EXPERIMENTS_DIRNAME = Path(__file__).parents[0].resolve() / "experiments" -def _get_level(verbose: int) -> int: -    """Sets the logger level.""" -    levels = {0: 40, 1: 20, 2: 10} -    verbose = verbose if verbose <= 2 else 2 -    return levels[verbose] - - -def _create_experiment_dir( -    experiment_config: Dict, checkpoint: Optional[str] = None -) -> Path: -    """Create new experiment.""" -    EXPERIMENTS_DIRNAME.mkdir(parents=True, exist_ok=True) -    experiment_dir = EXPERIMENTS_DIRNAME / ( -        f"{experiment_config['model']}_" -        + f"{experiment_config['dataset']['type']}_" -        + f"{experiment_config['network']['type']}" -    ) - -    if checkpoint is None: -        experiment = datetime.now().strftime("%m%d_%H%M%S") -        logger.debug(f"Creating a new experiment called {experiment}") -    else: -        available_experiments = glob(str(experiment_dir) + "/*") -        available_experiments.sort() -        if checkpoint == "last": -            experiment = available_experiments[-1] -            logger.debug(f"Resuming the latest experiment {experiment}") -        else: -            experiment = checkpoint -            if not str(experiment_dir / experiment) in available_experiments: -                raise FileNotFoundError("Experiment does not exist.") -            logger.debug(f"Resuming the from experiment {checkpoint}") - -    experiment_dir = experiment_dir / experiment - -    # Create log and model directories. -    log_dir = experiment_dir / "log" -    model_dir = experiment_dir / "model" - -    return experiment_dir, log_dir, model_dir - - -def _load_modules_and_arguments(experiment_config: Dict,) -> Tuple[Callable, Dict]: -    """Loads all modules and arguments.""" -    # Load the dataset module. -    dataset_args = experiment_config.get("dataset", {}) -    dataset_ = dataset_args["type"] - -    # Import the model module and model arguments. -    model_class_ = getattr(text_recognizer.models, experiment_config["model"]) - -    # Import metrics. -    metric_fns_ = ( -        { -            metric: getattr(text_recognizer.networks, metric) -            for metric in experiment_config["metrics"] -        } -        if experiment_config["metrics"] is not None -        else None -    ) - -    # Import network module and arguments. -    network_fn_ = experiment_config["network"]["type"] -    network_args = experiment_config["network"].get("args", {}) - -    # Criterion -    if experiment_config["criterion"]["type"] in custom_loss_module.__all__: -        criterion_ = getattr(custom_loss_module, experiment_config["criterion"]["type"]) -    else: -        criterion_ = getattr(torch.nn, experiment_config["criterion"]["type"]) -    criterion_args = experiment_config["criterion"].get("args", {}) or {} - -    # Optimizers -    optimizer_ = getattr(torch.optim, experiment_config["optimizer"]["type"]) -    optimizer_args = experiment_config["optimizer"].get("args", {}) - -    # Learning rate scheduler -    lr_scheduler_ = None -    lr_scheduler_args = None -    if "lr_scheduler" in experiment_config: -        lr_scheduler_ = getattr( -            torch.optim.lr_scheduler, experiment_config["lr_scheduler"]["type"] -        ) -        lr_scheduler_args = experiment_config["lr_scheduler"].get("args", {}) or {} - -    # SWA scheduler. -    if "swa_args" in experiment_config: -        swa_args = experiment_config.get("swa_args", {}) or {} -    else: -        swa_args = None - -    model_args = { -        "dataset": dataset_, -        "dataset_args": dataset_args, -        "metrics": metric_fns_, -        "network_fn": network_fn_, -        "network_args": network_args, -        "criterion": criterion_, -        "criterion_args": criterion_args, -        "optimizer": optimizer_, -        "optimizer_args": optimizer_args, -        "lr_scheduler": lr_scheduler_, -        "lr_scheduler_args": lr_scheduler_args, -        "swa_args": swa_args, -    } - -    return model_class_, model_args - - -def _configure_callbacks(experiment_config: Dict, model_dir: Path) -> CallbackList: -    """Configure a callback list for trainer.""" -    if "Checkpoint" in experiment_config["callback_args"]: -        experiment_config["callback_args"]["Checkpoint"]["checkpoint_path"] = str( -            model_dir -        ) - -    # Initializes callbacks. -    callback_modules = importlib.import_module("training.trainer.callbacks") -    callbacks = [] -    for callback in experiment_config["callbacks"]: -        args = experiment_config["callback_args"][callback] or {} -        callbacks.append(getattr(callback_modules, callback)(**args)) - -    return callbacks +def _configure_logging(log_dir: Optional[Path], verbose: int = 0) -> None: +    """Configure the loguru logger for output to terminal and disk.""" +    def _get_level(verbose: int) -> int: +        """Sets the logger level.""" +        levels = {0: 40, 1: 20, 2: 10} +        verbose = verbose if verbose <= 2 else 2 +        return levels[verbose] -def _configure_logger(log_dir: Path, verbose: int = 0) -> None: -    """Configure the loguru logger for output to terminal and disk."""      # Have to remove default logger to get tqdm to work properly.      logger.remove() @@ -164,219 +36,138 @@ def _configure_logger(log_dir: Path, verbose: int = 0) -> None:      level = _get_level(verbose)      logger.add(lambda msg: tqdm.write(msg, end=""), colorize=True, level=level) -    logger.add( -        str(log_dir / "train.log"), -        format="{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}", -    ) - - -def _save_config(experiment_dir: Path, experiment_config: Dict) -> None: -    """Copy config to experiment directory.""" -    config_path = experiment_dir / "config.yml" -    with open(str(config_path), "w") as f: -        yaml.dump(experiment_config, f) - - -def _load_from_checkpoint( -    model: Type[Model], model_dir: Path, pretrained_weights: str = None, -) -> None: -    """If checkpoint exists, load model weights and optimizers from checkpoint.""" -    # Get checkpoint path. -    if pretrained_weights is not None: -        logger.info(f"Loading weights from {pretrained_weights}.") -        checkpoint_path = ( -            EXPERIMENTS_DIRNAME / Path(pretrained_weights) / "model" / "best.pt" +    if log_dir is not None: +        logger.add( +            str(log_dir / "train.log"), +            format="{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}",          ) -    else: -        logger.info(f"Loading weights from {model_dir}.") -        checkpoint_path = model_dir / "last.pt" -    if checkpoint_path.exists(): -        logger.info("Loading and resuming training from checkpoint.") -        model.load_from_checkpoint(checkpoint_path) -def evaluate_embedding(model: Type[Model]) -> Dict: -    """Evaluates the embedding space.""" -    from pytorch_metric_learning import testers -    from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator +def _import_class(module_and_class_name: str) -> type: +    """Import class from module.""" +    module_name, class_name = module_and_class_name.rsplit(".", 1) +    module = importlib.import_module(module_name) +    return getattr(module, class_name) -    accuracy_calculator = AccuracyCalculator( -        include=("mean_average_precision_at_r",), k=10 -    ) -    def get_all_embeddings(model: Type[Model]) -> Tuple: -        tester = testers.BaseTester() -        return tester.get_all_embeddings(model.test_dataset, model.network) +def _configure_pl_callbacks(args: List[Dict]) -> List[Type[pl.callbacks.Callback]]: +    """Configures PyTorch Lightning callbacks.""" +    pl_callbacks = [ +        getattr(pl.callbacks, callback["type"])(**callback["args"]) for callback in args +    ] +    return pl_callbacks -    embeddings, labels = get_all_embeddings(model) -    logger.info("Computing embedding accuracy") -    accuracies = accuracy_calculator.get_accuracy( -        embeddings, embeddings, np.squeeze(labels), np.squeeze(labels), True -    ) -    logger.info( -        f"Test set accuracy (MAP@10) = {accuracies['mean_average_precision_at_r']}" -    ) -    return accuracies +def _configure_wandb_callback( +    network: Type[nn.Module], args: Dict +) -> pl.loggers.WandbLogger: +    """Configures wandb logger.""" +    pl_logger = pl.loggers.WandbLogger() +    pl_logger.watch(network) +    pl_logger.log_hyperparams(vars(args)) +    return pl_logger -def run_experiment( -    experiment_config: Dict, -    save_weights: bool, -    device: str, -    use_wandb: bool, -    train: bool, -    test: bool, -    verbose: int = 0, -    checkpoint: Optional[str] = None, -    pretrained_weights: Optional[str] = None, -) -> None: -    """Runs an experiment.""" -    logger.info(f"Experiment config: {json.dumps(experiment_config)}") -    # Create new experiment. -    experiment_dir, log_dir, model_dir = _create_experiment_dir( -        experiment_config, checkpoint +def _save_best_weights( +    callbacks: List[Type[pl.callbacks.Callback]], use_wandb: bool +) -> None: +    """Saves the best model.""" +    model_checkpoint_callback = next( +        callback +        for callback in callbacks +        if isinstance(callback, pl.callbacks.ModelCheckpoint)      ) +    best_model_path = model_checkpoint_callback.best_model_path +    if best_model_path: +        logger.info(f"Best model saved at: {best_model_path}") +        if use_wandb: +            logger.info("Uploading model to W&B...") +            wandb.save(best_model_path) -    # Make sure the log/model directory exists. -    log_dir.mkdir(parents=True, exist_ok=True) -    model_dir.mkdir(parents=True, exist_ok=True) - -    # Load the modules and model arguments. -    model_class_, model_args = _load_modules_and_arguments(experiment_config) - -    # Initializes the model with experiment config. -    model = model_class_(**model_args, device=device) - -    callbacks = _configure_callbacks(experiment_config, model_dir) -    # Setup logger. -    _configure_logger(log_dir, verbose) +def run(path: str, train: bool, test: bool, tune: bool, use_wandb: bool) -> None: +    """Runs experiment.""" +    logger.info("Starting experiment...") -    # Load from checkpoint if resuming an experiment. -    resume = False -    if checkpoint is not None or pretrained_weights is not None: -        # resume = True -        _load_from_checkpoint(model, model_dir, pretrained_weights) +    # Seed everything in the experiment +    logger.info(f"Seeding everthing with seed={SEED}") +    pl.utilities.seed.seed_everything(SEED) -    logger.info(f"The class mapping is {model.mapping}") +    # Load config. +    logger.info(f"Loading config from: {path}") +    config = OmegaConf.load(path) -    # Initializes Weights & Biases -    if use_wandb: -        wandb.init(project="text-recognizer", config=experiment_config, resume=resume) +    # Load classes +    data_module_class = _import_class(f"text_recognizer.data.{config.data.type}") +    network_class = _import_class(f"text_recognizer.networks.{config.network.type}") +    lit_model_class = _import_class(f"text_recognizer.models.{config.model.type}") -        # Lets W&B save the model and track the gradients and optional parameters. -        wandb.watch(model.network) +    # Initialize data object and network. +    data_module = data_module_class(**config.data.args) +    network = network_class(**config.network.args) -    experiment_config["experiment_group"] = experiment_config.get( -        "experiment_group", None +    # Load callback and logger +    callbacks = _configure_pl_callbacks(config.callbacks) +    pl_logger = ( +        _configure_wandb_callback(network, config.network.args) +        if use_wandb +        else pl.logger.TensorBoardLogger("training/logs")      ) -    experiment_config["device"] = device - -    # Save the config used in the experiment folder. -    _save_config(experiment_dir, experiment_config) - -    # Prints a summary of the network in terminal. -    model.summary(experiment_config["train_args"]["input_shape"]) +    # Checkpoint +    if config.load_checkpoint is not None: +        logger.info( +            f"Loading network weights from checkpoint: {config.load_checkpoint}" +        ) +        lit_model = lit_model_class.load_from_checkpoint( +            config.load_checkpoint, network=network, **config.model.args +        ) +    else: +        lit_model = lit_model_class(**config.model.args) -    # Load trainer. -    trainer = Trainer( -        max_epochs=experiment_config["train_args"]["max_epochs"], +    trainer = pl.Trainer( +        **config.trainer,          callbacks=callbacks, -        transformer_model=experiment_config["train_args"]["transformer_model"], -        max_norm=experiment_config["train_args"]["max_norm"], -        freeze_backbone=experiment_config["train_args"]["freeze_backbone"], +        logger=pl_logger, +        weigths_save_path="training/logs",      ) -    # Train the model. +    if tune: +        logger.info(f"Tuning learning rate and batch size...") +        trainer.tune(lit_model, datamodule=data_module) +      if train: -        trainer.fit(model) +        logger.info(f"Training network...") +        trainer.fit(lit_model, datamodule=data_module) -    # Run inference over test set.      if test: -        logger.info("Loading checkpoint with the best weights.") -        if "checkpoint" in experiment_config["train_args"]: -            model.load_from_checkpoint( -                model_dir / experiment_config["train_args"]["checkpoint"] -            ) -        else: -            model.load_from_checkpoint(model_dir / "best.pt") - -        logger.info("Running inference on test set.") -        if experiment_config["criterion"]["type"] == "EmbeddingLoss": -            logger.info("Evaluating embedding.") -            score = evaluate_embedding(model) -        else: -            score = trainer.test(model) - -        logger.info(f"Test set evaluation: {score}") - -        if use_wandb: -            wandb.log( -                { -                    experiment_config["test_metric"]: score[ -                        experiment_config["test_metric"] -                    ] -                } -            ) +        logger.info(f"Testing network...") +        trainer.test(lit_model, datamodule=data_module) -    if save_weights: -        model.save_weights(model_dir) +    _save_best_weights(callbacks, use_wandb)  @click.command() -@click.argument("experiment_config",) -@click.option("--gpu", type=int, default=0, help="Provide the index of the GPU to use.") +@click.option("-f", "--experiment_config", type=str, help="Path to experiment config.") +@click.option("--use_wandb", is_flag=True, help="If true, do use wandb for logging.")  @click.option( -    "--save", -    is_flag=True, -    help="If set, the final weights will be saved to a canonical, version-controlled location.", -) -@click.option( -    "--nowandb", is_flag=False, help="If true, do not use wandb for this run." +    "--tune", is_flag=True, help="If true, tune hyperparameters for training."  ) +@click.option("--train", is_flag=True, help="If true, train the model.")  @click.option("--test", is_flag=True, help="If true, test the model.")  @click.option("-v", "--verbose", count=True) -@click.option("--checkpoint", type=str, help="Path to the experiment.") -@click.option( -    "--pretrained_weights", type=str, help="Path to pretrained model weights." -) -@click.option( -    "--notrain", is_flag=False, help="Do not train the model.", -) -def run_cli( +def cli(      experiment_config: str, -    gpu: int, -    save: bool, -    nowandb: bool, -    notrain: bool, +    use_wandb: bool, +    tune: bool, +    train: bool,      test: bool,      verbose: int, -    checkpoint: Optional[str] = None, -    pretrained_weights: Optional[str] = None,  ) -> None:      """Run experiment.""" -    if gpu < 0: -        gpu_manager = GPUManager(True) -        gpu = gpu_manager.get_free_gpu() -    device = "cuda:" + str(gpu) - -    experiment_config = json.loads(experiment_config) -    os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu}" - -    run_experiment( -        experiment_config, -        save, -        device, -        use_wandb=not nowandb, -        train=not notrain, -        test=test, -        verbose=verbose, -        checkpoint=checkpoint, -        pretrained_weights=pretrained_weights, -    ) +    _configure_logging(None, verbose=verbose) +    run(path=experiment_config, train=train, test=test, tune=tune, use_wandb=use_wandb)  if __name__ == "__main__": -    run_cli() +    cli() diff --git a/training/run_sweep.py b/training/run_sweep.py deleted file mode 100644 index a578592..0000000 --- a/training/run_sweep.py +++ /dev/null @@ -1,92 +0,0 @@ -"""W&B Sweep Functionality.""" -from ast import literal_eval -import json -import os -from pathlib import Path -import signal -import subprocess  # nosec -import sys -from typing import Dict, List, Tuple - -import click -import yaml - -EXPERIMENTS_DIRNAME = Path(__file__).parents[0].resolve() / "experiments" - - -def load_config() -> Dict: -    """Load base hyperparameter config.""" -    with open(str(EXPERIMENTS_DIRNAME / "default_config_emnist.yml"), "r") as f: -        default_config = yaml.safe_load(f) -    return default_config - - -def args_to_json( -    default_config: dict, preserve_args: tuple = ("gpu", "save") -) -> Tuple[dict, list]: -    """Convert command line arguments to nested config values. - -    i.e. run_sweep.py --dataset_args.foo=1.7 -    { -        "dataset_args": { -            "foo": 1.7 -        } -    } - -    Args: -        default_config (dict): The base config used for every experiment. -        preserve_args (tuple): Arguments preserved for all runs. Defaults to ("gpu", "save"). - -    Returns: -        Tuple[dict, list]: Tuple of config dictionary and list of arguments. - -    """ - -    args = [] -    config = default_config.copy() -    key, val = None, None -    for arg in sys.argv[1:]: -        if "=" in arg: -            key, val = arg.split("=") -        elif key: -            val = arg -        else: -            key = arg -        if key and val: -            parsed_key = key.lstrip("-").split(".") -            if parsed_key[0] in preserve_args: -                args.append("--{}={}".format(parsed_key[0], val)) -            else: -                nested = config -                for level in parsed_key[:-1]: -                    nested[level] = config.get(level, {}) -                    nested = nested[level] -                try: -                    # Convert numerics to floats / ints -                    val = literal_eval(val) -                except ValueError: -                    pass -                nested[parsed_key[-1]] = val -            key, val = None, None -    return config, args - - -def main() -> None: -    """Runs a W&B sweep.""" -    default_config = load_config() -    config, args = args_to_json(default_config) -    env = { -        k: v for k, v in os.environ.items() if k not in ("WANDB_PROGRAM", "WANDB_ARGS") -    } -    # pylint: disable=subprocess-popen-preexec-fn -    run = subprocess.Popen( -        ["python", "training/run_experiment.py", *args, json.dumps(config)], -        env=env, -        preexec_fn=os.setsid, -    )  # nosec -    signal.signal(signal.SIGTERM, lambda *args: run.terminate()) -    run.wait() - - -if __name__ == "__main__": -    main() diff --git a/training/sweep_emnist.yml b/training/sweep_emnist.yml deleted file mode 100644 index 48d7261..0000000 --- a/training/sweep_emnist.yml +++ /dev/null @@ -1,26 +0,0 @@ -program: training/run_sweep.py -method: bayes -metric: -  name: val_loss -  goal: minimize -parameters: -  dataset: -    value: EmnistDataset -  model: -    value: CharacterModel -  network: -    value: MLP -  network_args.hidden_size: -    values: [128, 256] -  network_args.dropout_rate: -    values: [0.2, 0.4] -  network_args.num_layers: -    values: [3, 6] -  optimizer_args.lr: -    values: [1.e-1, 1.e-5] -  lr_scheduler_args.max_lr: -    values: [1.0e-1, 1.0e-5] -  train_args.batch_size: -    values: [64, 128] -  train_args.epochs: -    value: 5 diff --git a/training/sweep_emnist_resnet.yml b/training/sweep_emnist_resnet.yml deleted file mode 100644 index 19a3040..0000000 --- a/training/sweep_emnist_resnet.yml +++ /dev/null @@ -1,50 +0,0 @@ -program: training/run_sweep.py -method: bayes -metric: -  name: val_accuracy -  goal: maximize -parameters: -  dataset: -    value: EmnistDataset -  model: -    value: CharacterModel -  network: -    value: ResidualNetwork -  network_args.block_sizes: -    distribution: q_uniform -    min: 16 -    max: 256 -    q: 8 -  network_args.depths: -    distribution: int_uniform -    min: 1 -    max: 3 -  network_args.levels: -      distribution: int_uniform -      min: 1 -      max: 2 -  network_args.activation: -    distribution: categorical -    values: -      - gelu -      - leaky_relu -      - relu -      - selu -  optimizer_args.lr: -    distribution: uniform -    min: 1.e-5 -    max: 1.e-1 -  lr_scheduler_args.max_lr: -    distribution: uniform -    min: 1.e-5 -    max: 1.e-1 -  train_args.batch_size: -    distribution: q_uniform -    min: 32 -    max: 256 -    q: 8 -  train_args.epochs: -    value: 5 -early_terminate: -   type: hyperband -   min_iter: 2 diff --git a/training/trainer/__init__.py b/training/trainer/__init__.py deleted file mode 100644 index de41bfb..0000000 --- a/training/trainer/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -"""Trainer modules.""" -from .train import Trainer diff --git a/training/trainer/callbacks/__init__.py b/training/trainer/callbacks/__init__.py deleted file mode 100644 index 80c4177..0000000 --- a/training/trainer/callbacks/__init__.py +++ /dev/null @@ -1,29 +0,0 @@ -"""The callback modules used in the training script.""" -from .base import Callback, CallbackList -from .checkpoint import Checkpoint -from .early_stopping import EarlyStopping -from .lr_schedulers import ( -    LRScheduler, -    SWA, -) -from .progress_bar import ProgressBar -from .wandb_callbacks import ( -    WandbCallback, -    WandbImageLogger, -    WandbReconstructionLogger, -    WandbSegmentationLogger, -) - -__all__ = [ -    "Callback", -    "CallbackList", -    "Checkpoint", -    "EarlyStopping", -    "LRScheduler", -    "WandbCallback", -    "WandbImageLogger", -    "WandbReconstructionLogger", -    "WandbSegmentationLogger", -    "ProgressBar", -    "SWA", -] diff --git a/training/trainer/callbacks/base.py b/training/trainer/callbacks/base.py deleted file mode 100644 index 500b642..0000000 --- a/training/trainer/callbacks/base.py +++ /dev/null @@ -1,188 +0,0 @@ -"""Metaclass for callback functions.""" - -from enum import Enum -from typing import Callable, Dict, List, Optional, Type, Union - -from loguru import logger -import numpy as np -import torch - -from text_recognizer.models import Model - - -class ModeKeys: -    """Mode keys for CallbackList.""" - -    TRAIN = "train" -    VALIDATION = "validation" - - -class Callback: -    """Metaclass for callbacks used in training.""" - -    def __init__(self) -> None: -        """Initializes the Callback instance.""" -        self.model = None - -    def set_model(self, model: Type[Model]) -> None: -        """Set the model.""" -        self.model = model - -    def on_fit_begin(self) -> None: -        """Called when fit begins.""" -        pass - -    def on_fit_end(self) -> None: -        """Called when fit ends.""" -        pass - -    def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None) -> None: -        """Called at the beginning of an epoch. Only used in training mode.""" -        pass - -    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: -        """Called at the end of an epoch. Only used in training mode.""" -        pass - -    def on_train_batch_begin(self, batch: int, logs: Optional[Dict] = None) -> None: -        """Called at the beginning of an epoch.""" -        pass - -    def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: -        """Called at the end of an epoch.""" -        pass - -    def on_validation_batch_begin( -        self, batch: int, logs: Optional[Dict] = None -    ) -> None: -        """Called at the beginning of an epoch.""" -        pass - -    def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: -        """Called at the end of an epoch.""" -        pass - -    def on_test_begin(self) -> None: -        """Called at the beginning of test.""" -        pass - -    def on_test_end(self) -> None: -        """Called at the end of test.""" -        pass - - -class CallbackList: -    """Container for abstracting away callback calls.""" - -    mode_keys = ModeKeys() - -    def __init__(self, model: Type[Model], callbacks: List[Callback] = None) -> None: -        """Container for `Callback` instances. - -        This object wraps a list of `Callback` instances and allows them all to be -        called via a single end point. - -        Args: -            model (Type[Model]): A `Model` instance. -            callbacks (List[Callback]): List of `Callback` instances. Defaults to None. - -        """ - -        self._callbacks = callbacks or [] -        if model: -            self.set_model(model) - -    def set_model(self, model: Type[Model]) -> None: -        """Set the model for all callbacks.""" -        self.model = model -        for callback in self._callbacks: -            callback.set_model(model=self.model) - -    def append(self, callback: Type[Callback]) -> None: -        """Append new callback to callback list.""" -        self._callbacks.append(callback) - -    def on_fit_begin(self) -> None: -        """Called when fit begins.""" -        for callback in self._callbacks: -            callback.on_fit_begin() - -    def on_fit_end(self) -> None: -        """Called when fit ends.""" -        for callback in self._callbacks: -            callback.on_fit_end() - -    def on_test_begin(self) -> None: -        """Called when test begins.""" -        for callback in self._callbacks: -            callback.on_test_begin() - -    def on_test_end(self) -> None: -        """Called when test ends.""" -        for callback in self._callbacks: -            callback.on_test_end() - -    def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None) -> None: -        """Called at the beginning of an epoch.""" -        for callback in self._callbacks: -            callback.on_epoch_begin(epoch, logs) - -    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: -        """Called at the end of an epoch.""" -        for callback in self._callbacks: -            callback.on_epoch_end(epoch, logs) - -    def _call_batch_hook( -        self, mode: str, hook: str, batch: int, logs: Optional[Dict] = None -    ) -> None: -        """Helper function for all batch_{begin | end} methods.""" -        if hook == "begin": -            self._call_batch_begin_hook(mode, batch, logs) -        elif hook == "end": -            self._call_batch_end_hook(mode, batch, logs) -        else: -            raise ValueError(f"Unrecognized hook {hook}.") - -    def _call_batch_begin_hook( -        self, mode: str, batch: int, logs: Optional[Dict] = None -    ) -> None: -        """Helper function for all `on_*_batch_begin` methods.""" -        hook_name = f"on_{mode}_batch_begin" -        self._call_batch_hook_helper(hook_name, batch, logs) - -    def _call_batch_end_hook( -        self, mode: str, batch: int, logs: Optional[Dict] = None -    ) -> None: -        """Helper function for all `on_*_batch_end` methods.""" -        hook_name = f"on_{mode}_batch_end" -        self._call_batch_hook_helper(hook_name, batch, logs) - -    def _call_batch_hook_helper( -        self, hook_name: str, batch: int, logs: Optional[Dict] = None -    ) -> None: -        """Helper function for `on_*_batch_begin` methods.""" -        for callback in self._callbacks: -            hook = getattr(callback, hook_name) -            hook(batch, logs) - -    def on_train_batch_begin(self, batch: int, logs: Optional[Dict] = None) -> None: -        """Called at the beginning of an epoch.""" -        self._call_batch_hook(self.mode_keys.TRAIN, "begin", batch, logs) - -    def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: -        """Called at the end of an epoch.""" -        self._call_batch_hook(self.mode_keys.TRAIN, "end", batch, logs) - -    def on_validation_batch_begin( -        self, batch: int, logs: Optional[Dict] = None -    ) -> None: -        """Called at the beginning of an epoch.""" -        self._call_batch_hook(self.mode_keys.VALIDATION, "begin", batch, logs) - -    def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: -        """Called at the end of an epoch.""" -        self._call_batch_hook(self.mode_keys.VALIDATION, "end", batch, logs) - -    def __iter__(self) -> iter: -        """Iter function for callback list.""" -        return iter(self._callbacks) diff --git a/training/trainer/callbacks/checkpoint.py b/training/trainer/callbacks/checkpoint.py deleted file mode 100644 index a54e0a9..0000000 --- a/training/trainer/callbacks/checkpoint.py +++ /dev/null @@ -1,95 +0,0 @@ -"""Callback checkpoint for training models.""" -from enum import Enum -from pathlib import Path -from typing import Callable, Dict, List, Optional, Type, Union - -from loguru import logger -import numpy as np -import torch -from training.trainer.callbacks import Callback - -from text_recognizer.models import Model - - -class Checkpoint(Callback): -    """Saving model parameters at the end of each epoch.""" - -    mode_dict = { -        "min": torch.lt, -        "max": torch.gt, -    } - -    def __init__( -        self, -        checkpoint_path: Union[str, Path], -        monitor: str = "accuracy", -        mode: str = "auto", -        min_delta: float = 0.0, -    ) -> None: -        """Monitors a quantity that will allow us to determine the best model weights. - -        Args: -            checkpoint_path (Union[str, Path]): Path to the experiment with the checkpoint. -            monitor (str): Name of the quantity to monitor. Defaults to "accuracy". -            mode (str): Description of parameter `mode`. Defaults to "auto". -            min_delta (float): Description of parameter `min_delta`. Defaults to 0.0. - -        """ -        super().__init__() -        self.checkpoint_path = Path(checkpoint_path) -        self.monitor = monitor -        self.mode = mode -        self.min_delta = torch.tensor(min_delta) - -        if mode not in ["auto", "min", "max"]: -            logger.warning(f"Checkpoint mode {mode} is unkown, fallback to auto mode.") - -            self.mode = "auto" - -        if self.mode == "auto": -            if "accuracy" in self.monitor: -                self.mode = "max" -            else: -                self.mode = "min" -            logger.debug( -                f"Checkpoint mode set to {self.mode} for monitoring {self.monitor}." -            ) - -        torch_inf = torch.tensor(np.inf) -        self.min_delta *= 1 if self.monitor_op == torch.gt else -1 -        self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf - -    @property -    def monitor_op(self) -> float: -        """Returns the comparison method.""" -        return self.mode_dict[self.mode] - -    def on_epoch_end(self, epoch: int, logs: Dict) -> None: -        """Saves a checkpoint for the network parameters. - -        Args: -            epoch (int): The current epoch. -            logs (Dict): The log containing the monitored metrics. - -        """ -        current = self.get_monitor_value(logs) -        if current is None: -            return -        if self.monitor_op(current - self.min_delta, self.best_score): -            self.best_score = current -            is_best = True -        else: -            is_best = False - -        self.model.save_checkpoint(self.checkpoint_path, is_best, epoch, self.monitor) - -    def get_monitor_value(self, logs: Dict) -> Union[float, None]: -        """Extracts the monitored value.""" -        monitor_value = logs.get(self.monitor) -        if monitor_value is None: -            logger.warning( -                f"Checkpoint is conditioned on metric {self.monitor} which is not available. Available" -                + f" metrics are: {','.join(list(logs.keys()))}" -            ) -            return None -        return monitor_value diff --git a/training/trainer/callbacks/early_stopping.py b/training/trainer/callbacks/early_stopping.py deleted file mode 100644 index 02b431f..0000000 --- a/training/trainer/callbacks/early_stopping.py +++ /dev/null @@ -1,108 +0,0 @@ -"""Implements Early stopping for PyTorch model.""" -from typing import Dict, Union - -from loguru import logger -import numpy as np -import torch -from torch import Tensor -from training.trainer.callbacks import Callback - - -class EarlyStopping(Callback): -    """Stops training when a monitored metric stops improving.""" - -    mode_dict = { -        "min": torch.lt, -        "max": torch.gt, -    } - -    def __init__( -        self, -        monitor: str = "val_loss", -        min_delta: float = 0.0, -        patience: int = 3, -        mode: str = "auto", -    ) -> None: -        """Initializes the EarlyStopping callback. - -        Args: -            monitor (str): Description of parameter `monitor`. Defaults to "val_loss". -            min_delta (float): Description of parameter `min_delta`. Defaults to 0.0. -            patience (int): Description of parameter `patience`. Defaults to 3. -            mode (str): Description of parameter `mode`. Defaults to "auto". - -        """ -        super().__init__() -        self.monitor = monitor -        self.patience = patience -        self.min_delta = torch.tensor(min_delta) -        self.mode = mode -        self.wait_count = 0 -        self.stopped_epoch = 0 - -        if mode not in ["auto", "min", "max"]: -            logger.warning( -                f"EarlyStopping mode {mode} is unkown, fallback to auto mode." -            ) - -            self.mode = "auto" - -        if self.mode == "auto": -            if "accuracy" in self.monitor: -                self.mode = "max" -            else: -                self.mode = "min" -            logger.debug( -                f"EarlyStopping mode set to {self.mode} for monitoring {self.monitor}." -            ) - -        self.torch_inf = torch.tensor(np.inf) -        self.min_delta *= 1 if self.monitor_op == torch.gt else -1 -        self.best_score = ( -            self.torch_inf if self.monitor_op == torch.lt else -self.torch_inf -        ) - -    @property -    def monitor_op(self) -> float: -        """Returns the comparison method.""" -        return self.mode_dict[self.mode] - -    def on_fit_begin(self) -> Union[torch.lt, torch.gt]: -        """Reset the early stopping variables for reuse.""" -        self.wait_count = 0 -        self.stopped_epoch = 0 -        self.best_score = ( -            self.torch_inf if self.monitor_op == torch.lt else -self.torch_inf -        ) - -    def on_epoch_end(self, epoch: int, logs: Dict) -> None: -        """Computes the early stop criterion.""" -        current = self.get_monitor_value(logs) -        if current is None: -            return -        if self.monitor_op(current - self.min_delta, self.best_score): -            self.best_score = current -            self.wait_count = 0 -        else: -            self.wait_count += 1 -            if self.wait_count >= self.patience: -                self.stopped_epoch = epoch -                self.model.stop_training = True - -    def on_fit_end(self) -> None: -        """Logs if early stopping was used.""" -        if self.stopped_epoch > 0: -            logger.info( -                f"Stopped training at epoch {self.stopped_epoch + 1} with early stopping." -            ) - -    def get_monitor_value(self, logs: Dict) -> Union[Tensor, None]: -        """Extracts the monitor value.""" -        monitor_value = logs.get(self.monitor) -        if monitor_value is None: -            logger.warning( -                f"Early stopping is conditioned on metric {self.monitor} which is not available. Available" -                + f"metrics are: {','.join(list(logs.keys()))}" -            ) -            return None -        return torch.tensor(monitor_value) diff --git a/training/trainer/callbacks/lr_schedulers.py b/training/trainer/callbacks/lr_schedulers.py deleted file mode 100644 index 630c434..0000000 --- a/training/trainer/callbacks/lr_schedulers.py +++ /dev/null @@ -1,77 +0,0 @@ -"""Callbacks for learning rate schedulers.""" -from typing import Callable, Dict, List, Optional, Type - -from torch.optim.swa_utils import update_bn -from training.trainer.callbacks import Callback - -from text_recognizer.models import Model - - -class LRScheduler(Callback): -    """Generic learning rate scheduler callback.""" - -    def __init__(self) -> None: -        super().__init__() - -    def set_model(self, model: Type[Model]) -> None: -        """Sets the model and lr scheduler.""" -        self.model = model -        self.lr_scheduler = self.model.lr_scheduler["lr_scheduler"] -        self.interval = self.model.lr_scheduler["interval"] - -    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: -        """Takes a step at the end of every epoch.""" -        if self.interval == "epoch": -            if "ReduceLROnPlateau" in self.lr_scheduler.__class__.__name__: -                self.lr_scheduler.step(logs["val_loss"]) -            else: -                self.lr_scheduler.step() - -    def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: -        """Takes a step at the end of every training batch.""" -        if self.interval == "step": -            self.lr_scheduler.step() - - -class SWA(Callback): -    """Stochastic Weight Averaging callback.""" - -    def __init__(self) -> None: -        """Initializes the callback.""" -        super().__init__() -        self.lr_scheduler = None -        self.interval = None -        self.swa_scheduler = None -        self.swa_start = None -        self.current_epoch = 1 - -    def set_model(self, model: Type[Model]) -> None: -        """Sets the model and lr scheduler.""" -        self.model = model -        self.lr_scheduler = self.model.lr_scheduler["lr_scheduler"] -        self.interval = self.model.lr_scheduler["interval"] -        self.swa_scheduler = self.model.swa_scheduler["swa_scheduler"] -        self.swa_start = self.model.swa_scheduler["swa_start"] - -    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: -        """Takes a step at the end of every training batch.""" -        if epoch > self.swa_start: -            self.model.swa_network.update_parameters(self.model.network) -            self.swa_scheduler.step() -        elif self.interval == "epoch": -            self.lr_scheduler.step() -        self.current_epoch = epoch - -    def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: -        """Takes a step at the end of every training batch.""" -        if self.current_epoch < self.swa_start and self.interval == "step": -            self.lr_scheduler.step() - -    def on_fit_end(self) -> None: -        """Update batch norm statistics for the swa model at the end of training.""" -        if self.model.swa_network: -            update_bn( -                self.model.val_dataloader(), -                self.model.swa_network, -                device=self.model.device, -            ) diff --git a/training/trainer/callbacks/progress_bar.py b/training/trainer/callbacks/progress_bar.py deleted file mode 100644 index 6c4305a..0000000 --- a/training/trainer/callbacks/progress_bar.py +++ /dev/null @@ -1,65 +0,0 @@ -"""Progress bar callback for the training loop.""" -from typing import Dict, Optional - -from tqdm import tqdm -from training.trainer.callbacks import Callback - - -class ProgressBar(Callback): -    """A TQDM progress bar for the training loop.""" - -    def __init__(self, epochs: int, log_batch_frequency: int = None) -> None: -        """Initializes the tqdm callback.""" -        self.epochs = epochs -        print(epochs, type(epochs)) -        self.log_batch_frequency = log_batch_frequency -        self.progress_bar = None -        self.val_metrics = {} - -    def _configure_progress_bar(self) -> None: -        """Configures the tqdm progress bar with custom bar format.""" -        self.progress_bar = tqdm( -            total=len(self.model.train_dataloader()), -            leave=False, -            unit="steps", -            mininterval=self.log_batch_frequency, -            bar_format="{desc} |{bar:32}| {n_fmt}/{total_fmt} ETA: {remaining} {rate_fmt}{postfix}", -        ) - -    def _key_abbreviations(self, logs: Dict) -> Dict: -        """Changes the length of keys, so that the progress bar fits better.""" - -        def rename(key: str) -> str: -            """Renames accuracy to acc.""" -            return key.replace("accuracy", "acc") - -        return {rename(key): value for key, value in logs.items()} - -    # def on_fit_begin(self) -> None: -    #     """Creates a tqdm progress bar.""" -    #     self._configure_progress_bar() - -    def on_epoch_begin(self, epoch: int, logs: Optional[Dict]) -> None: -        """Updates the description with the current epoch.""" -        if epoch == 1: -            self._configure_progress_bar() -        else: -            self.progress_bar.reset() -        self.progress_bar.set_description(f"Epoch {epoch}/{self.epochs}") - -    def on_epoch_end(self, epoch: int, logs: Dict) -> None: -        """At the end of each epoch, the validation metrics are updated to the progress bar.""" -        self.val_metrics = logs -        self.progress_bar.set_postfix(**self._key_abbreviations(logs)) -        self.progress_bar.update() - -    def on_train_batch_end(self, batch: int, logs: Dict) -> None: -        """Updates the progress bar for each training step.""" -        if self.val_metrics: -            logs.update(self.val_metrics) -        self.progress_bar.set_postfix(**self._key_abbreviations(logs)) -        self.progress_bar.update() - -    def on_fit_end(self) -> None: -        """Closes the tqdm progress bar.""" -        self.progress_bar.close() diff --git a/training/trainer/callbacks/wandb_callbacks.py b/training/trainer/callbacks/wandb_callbacks.py deleted file mode 100644 index 552a4f4..0000000 --- a/training/trainer/callbacks/wandb_callbacks.py +++ /dev/null @@ -1,261 +0,0 @@ -"""Callback for W&B.""" -from typing import Callable, Dict, List, Optional, Type - -import numpy as np -from training.trainer.callbacks import Callback -import wandb - -import text_recognizer.datasets.transforms as transforms -from text_recognizer.models.base import Model - - -class WandbCallback(Callback): -    """A custom W&B metric logger for the trainer.""" - -    def __init__(self, log_batch_frequency: int = None) -> None: -        """Short summary. - -        Args: -            log_batch_frequency (int): If None, metrics will be logged every epoch. -                If set to an integer, callback will log every metrics every log_batch_frequency. - -        """ -        super().__init__() -        self.log_batch_frequency = log_batch_frequency - -    def _on_batch_end(self, batch: int, logs: Dict) -> None: -        if self.log_batch_frequency and batch % self.log_batch_frequency == 0: -            wandb.log(logs, commit=True) - -    def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: -        """Logs training metrics.""" -        if logs is not None: -            logs["lr"] = self.model.optimizer.param_groups[0]["lr"] -            self._on_batch_end(batch, logs) - -    def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: -        """Logs validation metrics.""" -        if logs is not None: -            self._on_batch_end(batch, logs) - -    def on_epoch_end(self, epoch: int, logs: Dict) -> None: -        """Logs at epoch end.""" -        wandb.log(logs, commit=True) - - -class WandbImageLogger(Callback): -    """Custom W&B callback for image logging.""" - -    def __init__( -        self, -        example_indices: Optional[List] = None, -        num_examples: int = 4, -        transform: Optional[bool] = None, -    ) -> None: -        """Initializes the WandbImageLogger with the model to train. - -        Args: -            example_indices (Optional[List]): Indices for validation images. Defaults to None. -            num_examples (int): Number of random samples to take if example_indices are not specified. Defaults to 4. -            transform (Optional[Dict]): Use transform on image or not. Defaults to None. - -        """ - -        super().__init__() -        self.caption = None -        self.example_indices = example_indices -        self.test_sample_indices = None -        self.num_examples = num_examples -        self.transform = ( -            self._configure_transform(transform) if transform is not None else None -        ) - -    def _configure_transform(self, transform: Dict) -> Callable: -        args = transform["args"] or {} -        return getattr(transforms, transform["type"])(**args) - -    def set_model(self, model: Type[Model]) -> None: -        """Sets the model and extracts validation images from the dataset.""" -        self.model = model -        self.caption = "Validation Examples" -        if self.example_indices is None: -            self.example_indices = np.random.randint( -                0, len(self.model.val_dataset), self.num_examples -            ) -        self.images = self.model.val_dataset.dataset.data[self.example_indices] -        self.targets = self.model.val_dataset.dataset.targets[self.example_indices] -        self.targets = self.targets.tolist() - -    def on_test_begin(self) -> None: -        """Get samples from test dataset.""" -        self.caption = "Test Examples" -        if self.test_sample_indices is None: -            self.test_sample_indices = np.random.randint( -                0, len(self.model.test_dataset), self.num_examples -            ) -        self.images = self.model.test_dataset.data[self.test_sample_indices] -        self.targets = self.model.test_dataset.targets[self.test_sample_indices] -        self.targets = self.targets.tolist() - -    def on_test_end(self) -> None: -        """Log test images.""" -        self.on_epoch_end(0, {}) - -    def on_epoch_end(self, epoch: int, logs: Dict) -> None: -        """Get network predictions on validation images.""" -        images = [] -        for i, image in enumerate(self.images): -            image = self.transform(image) if self.transform is not None else image -            pred, conf = self.model.predict_on_image(image) -            if isinstance(self.targets[i], list): -                ground_truth = "".join( -                    [ -                        self.model.mapper(int(target_index) - 26) -                        if target_index > 35 -                        else self.model.mapper(int(target_index)) -                        for target_index in self.targets[i] -                    ] -                ).rstrip("_") -            else: -                ground_truth = self.model.mapper(int(self.targets[i])) -            caption = f"Prediction: {pred} Confidence: {conf:.3f} Ground Truth: {ground_truth}" -            images.append(wandb.Image(image, caption=caption)) - -        wandb.log({f"{self.caption}": images}, commit=False) - - -class WandbSegmentationLogger(Callback): -    """Custom W&B callback for image logging.""" - -    def __init__( -        self, -        class_labels: Dict, -        example_indices: Optional[List] = None, -        num_examples: int = 4, -    ) -> None: -        """Initializes the WandbImageLogger with the model to train. - -        Args: -            class_labels (Dict): A dict with int as key and class string as value. -            example_indices (Optional[List]): Indices for validation images. Defaults to None. -            num_examples (int): Number of random samples to take if example_indices are not specified. Defaults to 4. - -        """ - -        super().__init__() -        self.caption = None -        self.class_labels = {int(k): v for k, v in class_labels.items()} -        self.example_indices = example_indices -        self.test_sample_indices = None -        self.num_examples = num_examples - -    def set_model(self, model: Type[Model]) -> None: -        """Sets the model and extracts validation images from the dataset.""" -        self.model = model -        self.caption = "Validation Segmentation Examples" -        if self.example_indices is None: -            self.example_indices = np.random.randint( -                0, len(self.model.val_dataset), self.num_examples -            ) -        self.images = self.model.val_dataset.dataset.data[self.example_indices] -        self.targets = self.model.val_dataset.dataset.targets[self.example_indices] -        self.targets = self.targets.tolist() - -    def on_test_begin(self) -> None: -        """Get samples from test dataset.""" -        self.caption = "Test Segmentation Examples" -        if self.test_sample_indices is None: -            self.test_sample_indices = np.random.randint( -                0, len(self.model.test_dataset), self.num_examples -            ) -        self.images = self.model.test_dataset.data[self.test_sample_indices] -        self.targets = self.model.test_dataset.targets[self.test_sample_indices] -        self.targets = self.targets.tolist() - -    def on_test_end(self) -> None: -        """Log test images.""" -        self.on_epoch_end(0, {}) - -    def on_epoch_end(self, epoch: int, logs: Dict) -> None: -        """Get network predictions on validation images.""" -        images = [] -        for i, image in enumerate(self.images): -            pred_mask = ( -                self.model.predict_on_image(image).detach().squeeze(0).cpu().numpy() -            ) -            gt_mask = np.array(self.targets[i]) -            images.append( -                wandb.Image( -                    image, -                    masks={ -                        "predictions": { -                            "mask_data": pred_mask, -                            "class_labels": self.class_labels, -                        }, -                        "ground_truth": { -                            "mask_data": gt_mask, -                            "class_labels": self.class_labels, -                        }, -                    }, -                ) -            ) - -        wandb.log({f"{self.caption}": images}, commit=False) - - -class WandbReconstructionLogger(Callback): -    """Custom W&B callback for image reconstructions logging.""" - -    def __init__( -        self, example_indices: Optional[List] = None, num_examples: int = 4, -    ) -> None: -        """Initializes the  WandbImageLogger with the model to train. - -        Args: -            example_indices (Optional[List]): Indices for validation images. Defaults to None. -            num_examples (int): Number of random samples to take if example_indices are not specified. Defaults to 4. - -        """ - -        super().__init__() -        self.caption = None -        self.example_indices = example_indices -        self.test_sample_indices = None -        self.num_examples = num_examples - -    def set_model(self, model: Type[Model]) -> None: -        """Sets the model and extracts validation images from the dataset.""" -        self.model = model -        self.caption = "Validation Reconstructions Examples" -        if self.example_indices is None: -            self.example_indices = np.random.randint( -                0, len(self.model.val_dataset), self.num_examples -            ) -        self.images = self.model.val_dataset.dataset.data[self.example_indices] - -    def on_test_begin(self) -> None: -        """Get samples from test dataset.""" -        self.caption = "Test Reconstructions Examples" -        if self.test_sample_indices is None: -            self.test_sample_indices = np.random.randint( -                0, len(self.model.test_dataset), self.num_examples -            ) -        self.images = self.model.test_dataset.data[self.test_sample_indices] - -    def on_test_end(self) -> None: -        """Log test images.""" -        self.on_epoch_end(0, {}) - -    def on_epoch_end(self, epoch: int, logs: Dict) -> None: -        """Get network predictions on validation images.""" -        images = [] -        for image in self.images: -            reconstructed_image = ( -                self.model.predict_on_image(image).detach().squeeze(0).cpu().numpy() -            ) -            images.append(image) -            images.append(reconstructed_image) - -        wandb.log( -            {f"{self.caption}": [wandb.Image(image) for image in images]}, commit=False, -        ) diff --git a/training/trainer/train.py b/training/trainer/train.py deleted file mode 100644 index b770c94..0000000 --- a/training/trainer/train.py +++ /dev/null @@ -1,325 +0,0 @@ -"""Training script for PyTorch models.""" - -from pathlib import Path -import time -from typing import Dict, List, Optional, Tuple, Type -import warnings - -from einops import rearrange -from loguru import logger -import numpy as np -import torch -from torch import Tensor -from torch.optim.swa_utils import update_bn -from training.trainer.callbacks import Callback, CallbackList, LRScheduler, SWA -from training.trainer.util import log_val_metric -import wandb - -from text_recognizer.models import Model - - -torch.backends.cudnn.benchmark = True -np.random.seed(4711) -torch.manual_seed(4711) -torch.cuda.manual_seed(4711) - - -warnings.filterwarnings("ignore") - - -class Trainer: -    """Trainer for training PyTorch models.""" - -    def __init__( -        self, -        max_epochs: int, -        callbacks: List[Type[Callback]], -        transformer_model: bool = False, -        max_norm: float = 0.0, -        freeze_backbone: Optional[int] = None, -    ) -> None: -        """Initialization of the Trainer. - -        Args: -            max_epochs (int): The maximum number of epochs in the training loop. -            callbacks (CallbackList): List of callbacks to be called. -            transformer_model (bool): Transformer model flag, modifies the input to the model. Default is False. -            max_norm (float): Max norm for gradient cl:ipping. Defaults to 0.0. -            freeze_backbone (Optional[int]): How many epochs to freeze the backbone for. Used when training -                Transformers. Default is None. - -        """ -        # Training arguments. -        self.start_epoch = 1 -        self.max_epochs = max_epochs -        self.callbacks = callbacks -        self.freeze_backbone = freeze_backbone - -        # Flag for setting callbacks. -        self.callbacks_configured = False - -        self.transformer_model = transformer_model - -        self.max_norm = max_norm - -        # Model placeholders -        self.model = None - -    def _configure_callbacks(self) -> None: -        """Instantiate the CallbackList.""" -        if not self.callbacks_configured: -            # If learning rate schedulers are present, they need to be added to the callbacks. -            if self.model.swa_scheduler is not None: -                self.callbacks.append(SWA()) -            elif self.model.lr_scheduler is not None: -                self.callbacks.append(LRScheduler()) - -            self.callbacks = CallbackList(self.model, self.callbacks) - -    def compute_metrics( -        self, output: Tensor, targets: Tensor, loss: Tensor, batch_size: int -    ) -> Dict: -        """Computes metrics for output and target pairs.""" -        # Compute metrics. -        loss = loss.detach().float().item() -        output = output.detach() -        targets = targets.detach() -        if self.model.metrics is not None: -            metrics = {} -            for metric in self.model.metrics: -                if metric == "cer" or metric == "wer": -                    metrics[metric] = self.model.metrics[metric]( -                        output, -                        targets, -                        batch_size, -                        self.model.mapper(self.model.pad_token), -                    ) -                else: -                    metrics[metric] = self.model.metrics[metric](output, targets) -        else: -            metrics = {} -        metrics["loss"] = loss - -        return metrics - -    def training_step(self, batch: int, samples: Tuple[Tensor, Tensor],) -> Dict: -        """Performs the training step.""" -        # Pass the tensor to the device for computation. -        data, targets = samples -        data, targets = ( -            data.to(self.model.device), -            targets.to(self.model.device), -        ) - -        batch_size = data.shape[0] - -        # Placeholder for uxiliary loss. -        aux_loss = None - -        # Forward pass. -        # Get the network prediction. -        if self.transformer_model: -            if self.freeze_backbone is not None and batch < self.freeze_backbone: -                with torch.no_grad(): -                    image_features = self.model.network.extract_image_features(data) - -                if isinstance(image_features, Tuple): -                    image_features, _ = image_features - -                output = self.model.network.decode_image_features( -                    image_features, targets[:, :-1] -                ) -            else: -                output = self.model.network.forward(data, targets[:, :-1]) -                if isinstance(output, Tuple): -                    output, aux_loss = output -            output = rearrange(output, "b t v -> (b t) v") -            targets = rearrange(targets[:, 1:], "b t -> (b t)").long() -        else: -            output = self.model.forward(data) - -            if isinstance(output, Tuple): -                output, aux_loss = output -                targets = data - -        # Compute the loss. -        loss = self.model.criterion(output, targets) - -        if aux_loss is not None: -            loss += aux_loss - -        # Backward pass. -        # Clear the previous gradients. -        for p in self.model.network.parameters(): -            p.grad = None - -        # Compute the gradients. -        loss.backward() - -        if self.max_norm > 0: -            torch.nn.utils.clip_grad_norm_( -                self.model.network.parameters(), self.max_norm -            ) - -        # Perform updates using calculated gradients. -        self.model.optimizer.step() - -        metrics = self.compute_metrics(output, targets, loss, batch_size) - -        return metrics - -    def train(self) -> None: -        """Runs the training loop for one epoch.""" -        # Set model to traning mode. -        self.model.train() - -        for batch, samples in enumerate(self.model.train_dataloader()): -            self.callbacks.on_train_batch_begin(batch) -            metrics = self.training_step(batch, samples) -            self.callbacks.on_train_batch_end(batch, logs=metrics) - -    @torch.no_grad() -    def validation_step(self, batch: int, samples: Tuple[Tensor, Tensor],) -> Dict: -        """Performs the validation step.""" - -        # Pass the tensor to the device for computation. -        data, targets = samples -        data, targets = ( -            data.to(self.model.device), -            targets.to(self.model.device), -        ) - -        batch_size = data.shape[0] - -        # Placeholder for uxiliary loss. -        aux_loss = None - -        # Forward pass. -        # Get the network prediction. -        # Use SWA if available and using test dataset. -        if self.transformer_model: -            output = self.model.network.forward(data, targets[:, :-1]) -            if isinstance(output, Tuple): -                output, aux_loss = output -            output = rearrange(output, "b t v -> (b t) v") -            targets = rearrange(targets[:, 1:], "b t -> (b t)").long() -        else: -            output = self.model.forward(data) - -            if isinstance(output, Tuple): -                output, aux_loss = output -                targets = data - -        # Compute the loss. -        loss = self.model.criterion(output, targets) - -        if aux_loss is not None: -            loss += aux_loss - -        # Compute metrics. -        metrics = self.compute_metrics(output, targets, loss, batch_size) - -        return metrics - -    def validate(self) -> Dict: -        """Runs the validation loop for one epoch.""" -        # Set model to eval mode. -        self.model.eval() - -        # Summary for the current eval loop. -        summary = [] - -        for batch, samples in enumerate(self.model.val_dataloader()): -            self.callbacks.on_validation_batch_begin(batch) -            metrics = self.validation_step(batch, samples) -            self.callbacks.on_validation_batch_end(batch, logs=metrics) -            summary.append(metrics) - -        # Compute mean of all metrics. -        metrics_mean = { -            "val_" + metric: np.mean([x[metric] for x in summary]) -            for metric in summary[0] -        } - -        return metrics_mean - -    def fit(self, model: Type[Model]) -> None: -        """Runs the training and evaluation loop.""" - -        # Sets model, loads the data, criterion, and optimizers. -        self.model = model -        self.model.prepare_data() -        self.model.configure_model() - -        # Configure callbacks. -        self._configure_callbacks() - -        # Set start time. -        t_start = time.time() - -        self.callbacks.on_fit_begin() - -        # Run the training loop. -        for epoch in range(self.start_epoch, self.max_epochs + 1): -            self.callbacks.on_epoch_begin(epoch) - -            # Perform one training pass over the training set. -            self.train() - -            # Evaluate the model on the validation set. -            val_metrics = self.validate() -            log_val_metric(val_metrics, epoch) - -            self.callbacks.on_epoch_end(epoch, logs=val_metrics) - -            if self.model.stop_training: -                break - -        # Calculate the total training time. -        t_end = time.time() -        t_training = t_end - t_start - -        self.callbacks.on_fit_end() - -        logger.info(f"Training took {t_training:.2f} s.") - -        # "Teardown". -        self.model = None - -    def test(self, model: Type[Model]) -> Dict: -        """Run inference on test data.""" - -        # Sets model, loads the data, criterion, and optimizers. -        self.model = model -        self.model.prepare_data() -        self.model.configure_model() - -        # Configure callbacks. -        self._configure_callbacks() - -        self.callbacks.on_test_begin() - -        self.model.eval() - -        # Check if SWA network is available. -        self.model.use_swa_model() - -        # Summary for the current test loop. -        summary = [] - -        for batch, samples in enumerate(self.model.test_dataloader()): -            metrics = self.validation_step(batch, samples) -            summary.append(metrics) - -        self.callbacks.on_test_end() - -        # Compute mean of all test metrics. -        metrics_mean = { -            "test_" + metric: np.mean([x[metric] for x in summary]) -            for metric in summary[0] -        } - -        # "Teardown". -        self.model = None - -        return metrics_mean diff --git a/training/trainer/util.py b/training/trainer/util.py deleted file mode 100644 index 7cf1b45..0000000 --- a/training/trainer/util.py +++ /dev/null @@ -1,28 +0,0 @@ -"""Utility functions for training neural networks.""" -from typing import Dict, Optional - -from loguru import logger - - -def log_val_metric(metrics_mean: Dict, epoch: Optional[int] = None) -> None: -    """Logging of val metrics to file/terminal.""" -    log_str = "Validation metrics " + (f"at epoch {epoch} - " if epoch else " - ") -    logger.debug(log_str + " - ".join(f"{k}: {v:.4f}" for k, v in metrics_mean.items())) - - -class RunningAverage: -    """Maintains a running average.""" - -    def __init__(self) -> None: -        """Initializes the parameters.""" -        self.steps = 0 -        self.total = 0 - -    def update(self, val: float) -> None: -        """Updates the parameters.""" -        self.total += val -        self.steps += 1 - -    def __call__(self) -> float: -        """Computes the running average.""" -        return self.total / float(self.steps) |