diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-05 20:47:55 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-05 20:47:55 +0200 |
commit | 9ae5fa1a88899180f88ddb14d4cef457ceb847e5 (patch) | |
tree | 4fe2bcd82553c8062eb0908ae6442c123addf55d | |
parent | 9e54591b7e342edc93b0bb04809a0f54045c6a15 (diff) |
Add new training loop with PyTorch Lightning, remove stale files
26 files changed, 167 insertions, 2229 deletions
@@ -32,11 +32,11 @@ poetry run build-transitions --tokens iamdb_1kwp_tokens_1000.txt --lexicon iamdb - [x] transform that encodes iam targets to wordpieces - [x] transducer loss function - [ ] Train with word pieces - - [ ] implement wandb callback for logging +- [ ] Local attention in first layer of transformer +- [ ] Halonet encoder - [ ] Implement CPC - - [ ] Window images - - [ ] Train backbone -- [ ] Bert training, how? + - [ ] https://arxiv.org/pdf/1905.09272.pdf + - [ ] https://pytorch-lightning-bolts.readthedocs.io/en/latest/self_supervised_models.html?highlight=byol - [ ] Predictive coding @@ -60,20 +60,3 @@ wandb agent $SWEEP_ID ``` -## PyTorch Performance Guide -Tips and tricks from ["PyTorch Performance Tuning Guide - Szymon Migacz, NVIDIA"](https://www.youtube.com/watch?v=9mS1fIYj1So&t=125s): - -* Always better to use `num_workers > 0`, allows asynchronous data processing -* Use `pin_memory=True` to allow data loading and computations to happen on the GPU in parallel. -* Have to tune `num_workers` to use based on the problem, too many and data loading becomes slower. -* For CNNs use `torch.backends.cudnn.benchmark=True`, allows cuDNN to select the best algorithm for convolutional computations (autotuner). -* Increase batch size to max out GPU memory. -* Use optimizer for large batch training, e.g. LARS, LAMB etc. -* Set `bias=False` for convolutions directly followed by BatchNorm. -* Use `for p in model.parameters(): p.grad = None` instead of `model.zero_grad()`. -* Careful with disable debug APIs in prod (detect_anomaly, profiler, gradcheck). -* Use `DistributedDataParallel` not `DataParallel`, uses 1 CPU core for each GPU. -* Important to load balance compute on all GPUs, if variably-sized inputs or GPUs will idle. -* Use an apex fused optimizer -* Use checkpointing to recompute memory-intensive compute-efficient ops in backward pass (e.g. activations, upsampling), `torch.utils.checkpoint`. -* Use `@torch.jit.script`, especially to fuse long sequences of pointwise operations like GELU. diff --git a/notebooks/00-testing-stuff-out.ipynb b/notebooks/00-testing-stuff-out.ipynb index 0fa5a1b..81957f7 100644 --- a/notebooks/00-testing-stuff-out.ipynb +++ b/notebooks/00-testing-stuff-out.ipynb @@ -25,8 +25,52 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "class Hej:\n", + " a = 2\n", + " \n", + "class Hejjj:\n", + " b = 1" + ] + }, + { + "cell_type": "code", + "execution_count": 7, "metadata": {}, + "outputs": [], + "source": [ + "l = [Hej(), Hejjj()]" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "<__main__.Hej at 0x7efefc77f370>" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "next(o for o in l if isinstance(o, Hej))" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "scrolled": true + }, "outputs": [ { "data": { diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index 46e5136..2d6e435 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -1,5 +1,5 @@ """Base PyTorch Lightning model.""" -from typing import Any, Dict, Tuple, Type +from typing import Any, Dict, List, Tuple, Type import madgrad import pytorch_lightning as pl @@ -40,7 +40,7 @@ class LitBaseModel(pl.LightningModule): args = {} or criterion_args["args"] return getattr(nn, criterion_args["type"])(**args) - def configure_optimizer(self) -> Dict[str, Any]: + def configure_optimizer(self) -> Tuple[List[type], List[Dict[str, Any]]]: """Configures optimizer and lr scheduler.""" args = {} or self.optimizer_args["args"] if self.optimizer_args["type"] == "MADGRAD": @@ -48,15 +48,15 @@ class LitBaseModel(pl.LightningModule): else: optimizer = getattr(torch.optim, self.optimizer_args["type"])(**args) + scheduler = {"monitor": self.monitor} args = {} or self.lr_scheduler_args["args"] - scheduler = getattr(torch.optim.lr_scheduler, self.lr_scheduler_args["type"])( - **args - ) - return { - "optimizer": optimizer, - "lr_scheduler": scheduler, - "monitor": self.monitor, - } + if "interval" in args: + scheduler["interval"] = args.pop("interval") + + scheduler["scheduler"] = getattr( + torch.optim.lr_scheduler, self.lr_scheduler_args["type"] + )(**args) + return [optimizer], [scheduler] def forward(self, data: Tensor) -> Tensor: """Feedforward pass.""" diff --git a/text_recognizer/networks/loss/__init__.py b/text_recognizer/networks/loss/__init__.py index b489264..cb83608 100644 --- a/text_recognizer/networks/loss/__init__.py +++ b/text_recognizer/networks/loss/__init__.py @@ -1,2 +1,2 @@ """Loss module.""" -from .loss import EmbeddingLoss, LabelSmoothingCrossEntropy +from .loss import LabelSmoothingCrossEntropy diff --git a/text_recognizer/networks/loss/loss.py b/text_recognizer/networks/loss/loss.py index cf9fa0d..d12dc9c 100644 --- a/text_recognizer/networks/loss/loss.py +++ b/text_recognizer/networks/loss/loss.py @@ -1,39 +1,9 @@ """Implementations of custom loss functions.""" -from pytorch_metric_learning import distances, losses, miners, reducers import torch from torch import nn from torch import Tensor -from torch.autograd import Variable -import torch.nn.functional as F -__all__ = ["EmbeddingLoss", "LabelSmoothingCrossEntropy"] - - -class EmbeddingLoss: - """Metric loss for training encoders to produce information-rich latent embeddings.""" - - def __init__(self, margin: float = 0.2, type_of_triplets: str = "semihard") -> None: - self.distance = distances.CosineSimilarity() - self.reducer = reducers.ThresholdReducer(low=0) - self.loss_fn = losses.TripletMarginLoss( - margin=margin, distance=self.distance, reducer=self.reducer - ) - self.miner = miners.MultiSimilarityMiner(epsilon=margin, distance=self.distance) - - def __call__(self, embeddings: Tensor, labels: Tensor) -> Tensor: - """Computes the metric loss for the embeddings based on their labels. - - Args: - embeddings (Tensor): The laten vectors encoded by the network. - labels (Tensor): Labels of the embeddings. - - Returns: - Tensor: The metric loss for the embeddings. - - """ - hard_pairs = self.miner(embeddings, labels) - loss = self.loss_fn(embeddings, labels, hard_pairs) - return loss +__all__ = ["LabelSmoothingCrossEntropy"] class LabelSmoothingCrossEntropy(nn.Module): diff --git a/text_recognizer/networks/util.py b/text_recognizer/networks/util.py index 131a6b4..d292680 100644 --- a/text_recognizer/networks/util.py +++ b/text_recognizer/networks/util.py @@ -1,38 +1,13 @@ """Miscellaneous neural network functionality.""" import importlib from pathlib import Path -from typing import Dict, Tuple, Type +from typing import Dict, Type -from einops import rearrange from loguru import logger import torch from torch import nn -def sliding_window( - images: torch.Tensor, patch_size: Tuple[int, int], stride: Tuple[int, int] -) -> torch.Tensor: - """Creates patches of an image. - - Args: - images (torch.Tensor): A Torch tensor of a 4D image(s), i.e. (batch, channel, height, width). - patch_size (Tuple[int, int]): The size of the patches to generate, e.g. 28x28 for EMNIST. - stride (Tuple[int, int]): The stride of the sliding window. - - Returns: - torch.Tensor: A tensor with the shape (batch, patches, height, width). - - """ - unfold = nn.Unfold(kernel_size=patch_size, stride=stride) - # Preform the sliding window, unsqueeze as the channel dimesion is lost. - c = images.shape[1] - patches = unfold(images) - patches = rearrange( - patches, "b (c h w) t -> b t c h w", c=c, h=patch_size[0], w=patch_size[1], - ) - return patches - - def activation_function(activation: str) -> Type[nn.Module]: """Returns the callable activation function.""" activation_fns = nn.ModuleDict( diff --git a/text_recognizer/networks/vq_transformer.py b/text_recognizer/networks/vq_transformer.py deleted file mode 100644 index c673d96..0000000 --- a/text_recognizer/networks/vq_transformer.py +++ /dev/null @@ -1,150 +0,0 @@ -"""A VQ-Transformer for image to text recognition.""" -from typing import Dict, Optional, Tuple - -from einops import rearrange, repeat -import torch -from torch import nn -from torch import Tensor - -from text_recognizer.networks.transformer import PositionalEncoding, Transformer -from text_recognizer.networks.util import activation_function -from text_recognizer.networks.util import configure_backbone -from text_recognizer.networks.vqvae.encoder import _ResidualBlock - - -class VQTransformer(nn.Module): - """VQ+Transfomer for image to character sequence prediction.""" - - def __init__( - self, - num_encoder_layers: int, - num_decoder_layers: int, - hidden_dim: int, - vocab_size: int, - num_heads: int, - adaptive_pool_dim: Tuple, - expansion_dim: int, - dropout_rate: float, - trg_pad_index: int, - max_len: int, - backbone: str, - backbone_args: Optional[Dict] = None, - activation: str = "gelu", - ) -> None: - super().__init__() - - # Configure vector quantized backbone. - self.backbone = configure_backbone(backbone, backbone_args) - self.conv = nn.Sequential( - nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=2), - nn.ReLU(inplace=True), - ) - - # Configure embeddings for Transformer network. - self.trg_pad_index = trg_pad_index - self.vocab_size = vocab_size - self.character_embedding = nn.Embedding(self.vocab_size, hidden_dim) - self.src_position_embedding = nn.Parameter(torch.randn(1, max_len, hidden_dim)) - self.trg_position_encoding = PositionalEncoding(hidden_dim, dropout_rate) - nn.init.normal_(self.character_embedding.weight, std=0.02) - - self.adaptive_pool = ( - nn.AdaptiveAvgPool2d((adaptive_pool_dim)) if adaptive_pool_dim else None - ) - - self.transformer = Transformer( - num_encoder_layers, - num_decoder_layers, - hidden_dim, - num_heads, - expansion_dim, - dropout_rate, - activation, - ) - - self.head = nn.Sequential(nn.Linear(hidden_dim, vocab_size),) - - def _create_trg_mask(self, trg: Tensor) -> Tensor: - # Move this outside the transformer. - trg_pad_mask = (trg != self.trg_pad_index)[:, None, None] - trg_len = trg.shape[1] - trg_sub_mask = torch.tril( - torch.ones((trg_len, trg_len), device=trg.device) - ).bool() - trg_mask = trg_pad_mask & trg_sub_mask - return trg_mask - - def encoder(self, src: Tensor) -> Tensor: - """Forward pass with the encoder of the transformer.""" - return self.transformer.encoder(src) - - def decoder(self, trg: Tensor, memory: Tensor, trg_mask: Tensor) -> Tensor: - """Forward pass with the decoder of the transformer + classification head.""" - return self.head( - self.transformer.decoder(trg=trg, memory=memory, trg_mask=trg_mask) - ) - - def extract_image_features(self, src: Tensor) -> Tuple[Tensor, Tensor]: - """Extracts image features with a backbone neural network. - - It seem like the winning idea was to swap channels and width dimension and collapse - the height dimension. The transformer is learning like a baby with this implementation!!! :D - Ohhhh, the joy I am experiencing right now!! Bring in the beers! :D :D :D - - Args: - src (Tensor): Input tensor. - - Returns: - Tensor: The input src to the transformer and the vq loss. - - """ - # If batch dimension is missing, it needs to be added. - if len(src.shape) < 4: - src = src[(None,) * (4 - len(src.shape))] - src, vq_loss = self.backbone.encode(src) - # src = self.backbone.decoder.res_block(src) - src = self.conv(src) - - if self.adaptive_pool is not None: - src = rearrange(src, "b c h w -> b w c h") - src = self.adaptive_pool(src) - src = src.squeeze(3) - else: - src = rearrange(src, "b c h w -> b (w h) c") - - b, t, _ = src.shape - - src += self.src_position_embedding[:, :t] - - return src, vq_loss - - def target_embedding(self, trg: Tensor) -> Tensor: - """Encodes target tensor with embedding and postion. - - Args: - trg (Tensor): Target tensor. - - Returns: - Tensor: Encoded target tensor. - - """ - trg = self.character_embedding(trg.long()) - trg = self.trg_position_encoding(trg) - return trg - - def decode_image_features( - self, image_features: Tensor, trg: Optional[Tensor] = None - ) -> Tensor: - """Takes images features from the backbone and decodes them with the transformer.""" - trg_mask = self._create_trg_mask(trg) - trg = self.target_embedding(trg) - out = self.transformer(image_features, trg, trg_mask=trg_mask) - - logits = self.head(out) - return logits - - def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor: - """Forward pass with CNN transfomer.""" - image_features, vq_loss = self.extract_image_features(x) - logits = self.decode_image_features(image_features, trg) - return logits, vq_loss diff --git a/training/experiments/default_config_emnist.yml b/training/experiments/default_config_emnist.yml deleted file mode 100644 index bf2ed0a..0000000 --- a/training/experiments/default_config_emnist.yml +++ /dev/null @@ -1,70 +0,0 @@ -dataset: EmnistDataset -dataset_args: - sample_to_balance: true - subsample_fraction: 0.33 - transform: null - target_transform: null - seed: 4711 - -data_loader_args: - splits: [train, val] - shuffle: true - num_workers: 8 - cuda: true - -model: CharacterModel -metrics: [accuracy] - -network_args: - in_channels: 1 - num_classes: 80 - depths: [2] - block_sizes: [256] - -train_args: - batch_size: 256 - epochs: 5 - -criterion: CrossEntropyLoss -criterion_args: - weight: null - ignore_index: -100 - reduction: mean - -optimizer: AdamW -optimizer_args: - lr: 1.e-03 - betas: [0.9, 0.999] - eps: 1.e-08 - # weight_decay: 5.e-4 - amsgrad: false - -lr_scheduler: OneCycleLR -lr_scheduler_args: - max_lr: 1.e-03 - epochs: 5 - anneal_strategy: linear - - -callbacks: [Checkpoint, ProgressBar, EarlyStopping, WandbCallback, WandbImageLogger, OneCycleLR] -callback_args: - Checkpoint: - monitor: val_accuracy - ProgressBar: - epochs: 5 - log_batch_frequency: 100 - EarlyStopping: - monitor: val_loss - min_delta: 0.0 - patience: 3 - mode: min - WandbCallback: - log_batch_frequency: 10 - WandbImageLogger: - num_examples: 4 - OneCycleLR: - null -verbosity: 1 # 0, 1, 2 -resume_experiment: null -train: true -validation_metric: val_accuracy diff --git a/training/experiments/embedding_experiment.yml b/training/experiments/embedding_experiment.yml deleted file mode 100644 index 1e5f941..0000000 --- a/training/experiments/embedding_experiment.yml +++ /dev/null @@ -1,64 +0,0 @@ -experiment_group: Embedding Experiments -experiments: - - train_args: - transformer_model: false - batch_size: &batch_size 256 - max_epochs: &max_epochs 32 - input_shape: [[1, 28, 28]] - dataset: - type: EmnistDataset - args: - sample_to_balance: true - subsample_fraction: null - transform: null - target_transform: null - seed: 4711 - train_args: - num_workers: 8 - train_fraction: 0.85 - batch_size: *batch_size - model: CharacterModel - metrics: [] - network: - type: DenseNet - args: - growth_rate: 4 - block_config: [4, 4] - in_channels: 1 - base_channels: 24 - num_classes: 128 - bn_size: 4 - dropout_rate: 0.1 - classifier: true - activation: elu - criterion: - type: EmbeddingLoss - args: - margin: 0.2 - type_of_triplets: semihard - optimizer: - type: AdamW - args: - lr: 1.e-02 - betas: [0.9, 0.999] - eps: 1.e-08 - weight_decay: 5.e-4 - amsgrad: false - lr_scheduler: - type: CosineAnnealingLR - args: - T_max: *max_epochs - callbacks: [Checkpoint, ProgressBar, WandbCallback] - callback_args: - Checkpoint: - monitor: val_loss - mode: min - ProgressBar: - epochs: *max_epochs - WandbCallback: - log_batch_frequency: 10 - verbosity: 1 # 0, 1, 2 - resume_experiment: null - train: true - test: true - test_metric: mean_average_precision_at_r diff --git a/training/experiments/sample_experiment.yml b/training/experiments/sample_experiment.yml deleted file mode 100644 index 8f94475..0000000 --- a/training/experiments/sample_experiment.yml +++ /dev/null @@ -1,99 +0,0 @@ -experiment_group: Sample Experiments -experiments: - - train_args: - batch_size: 256 - max_epochs: &max_epochs 32 - dataset: - type: EmnistDataset - args: - sample_to_balance: true - subsample_fraction: null - transform: null - target_transform: null - seed: 4711 - train_args: - num_workers: 6 - train_fraction: 0.8 - - model: CharacterModel - metrics: [accuracy] - # network: MLP - # network_args: - # input_size: 784 - # hidden_size: 512 - # output_size: 80 - # num_layers: 5 - # dropout_rate: 0.2 - # activation_fn: SELU - network: - type: ResidualNetwork - args: - in_channels: 1 - num_classes: 80 - depths: [2, 2] - block_sizes: [64, 64] - activation: leaky_relu - # network: - # type: WideResidualNetwork - # args: - # in_channels: 1 - # num_classes: 80 - # depth: 10 - # num_layers: 3 - # width_factor: 4 - # dropout_rate: 0.2 - # activation: SELU - # network: LeNet - # network_args: - # output_size: 62 - # activation_fn: GELU - criterion: - type: CrossEntropyLoss - args: - weight: null - ignore_index: -100 - reduction: mean - optimizer: - type: AdamW - args: - lr: 1.e-02 - betas: [0.9, 0.999] - eps: 1.e-08 - # weight_decay: 5.e-4 - amsgrad: false - # lr_scheduler: - # type: OneCycleLR - # args: - # max_lr: 1.e-03 - # epochs: *max_epochs - # anneal_strategy: linear - lr_scheduler: - type: CosineAnnealingLR - args: - T_max: *max_epochs - interval: epoch - swa_args: - start: 2 - lr: 5.e-2 - callbacks: [Checkpoint, ProgressBar, WandbCallback, WandbImageLogger, EarlyStopping] - callback_args: - Checkpoint: - monitor: val_accuracy - ProgressBar: - epochs: null - log_batch_frequency: 100 - EarlyStopping: - monitor: val_loss - min_delta: 0.0 - patience: 5 - mode: min - WandbCallback: - log_batch_frequency: 10 - WandbImageLogger: - num_examples: 4 - use_transpose: true - verbosity: 0 # 0, 1, 2 - resume_experiment: null - train: true - test: true - test_metric: test_accuracy diff --git a/training/gpu_manager.py b/training/gpu_manager.py deleted file mode 100644 index ce1b3dd..0000000 --- a/training/gpu_manager.py +++ /dev/null @@ -1,62 +0,0 @@ -"""GPUManager class.""" -import os -import time -from typing import Optional - -import gpustat -from loguru import logger -import numpy as np -from redlock import Redlock - - -GPU_LOCK_TIMEOUT = 5000 # ms - - -class GPUManager: - """Class for allocating GPUs.""" - - def __init__(self, verbose: bool = False) -> None: - """Initializes Redlock manager.""" - self.lock_manager = Redlock([{"host": "localhost", "port": 6379, "db": 0}]) - self.verbose = verbose - - def get_free_gpu(self) -> int: - """Gets a free GPU. - - If some GPUs are available, try reserving one by checking out an exclusive redis lock. - If none available or can not get lock, sleep and check again. - - Returns: - int: The gpu index. - - """ - while True: - gpu_index = self._get_free_gpu() - if gpu_index is not None: - return gpu_index - - if self.verbose: - logger.debug(f"pid {os.getpid()} sleeping") - time.sleep(GPU_LOCK_TIMEOUT / 1000) - - def _get_free_gpu(self) -> Optional[int]: - """Fetches an available GPU index.""" - try: - available_gpu_indices = [ - gpu.index - for gpu in gpustat.GPUStatCollection.new_query() - if gpu.memory_used < 0.5 * gpu.memory_total - ] - except Exception as e: - logger.debug(f"Got the following exception: {e}") - return None - - if available_gpu_indices: - gpu_index = np.random.choice(available_gpu_indices) - if self.verbose: - logger.debug(f"pid {os.getpid()} picking gpu {gpu_index}") - if self.lock_manager.lock(f"gpu_{gpu_index}", GPU_LOCK_TIMEOUT): - return int(gpu_index) - if self.verbose: - logger.debug(f"pid {os.getpid()} could not get lock.") - return None diff --git a/training/prepare_experiments.py b/training/prepare_experiments.py deleted file mode 100644 index 21997af..0000000 --- a/training/prepare_experiments.py +++ /dev/null @@ -1,34 +0,0 @@ -"""Run a experiment from a config file.""" -import json - -import click -import yaml - - -def run_experiments(experiments_filename: str) -> None: - """Run experiment from file.""" - with open(experiments_filename, "r") as f: - experiments_config = yaml.safe_load(f) - - num_experiments = len(experiments_config["experiments"]) - for index in range(num_experiments): - experiment_config = experiments_config["experiments"][index] - experiment_config["experiment_group"] = experiments_config["experiment_group"] - cmd = f"poetry run run-experiment --gpu=-1 --save '{json.dumps(experiment_config)}'" - print(cmd) - - -@click.command() -@click.option( - "--experiments_filename", - required=True, - type=str, - help="Filename of Yaml file of experiments to run.", -) -def run_cli(experiments_filename: str) -> None: - """Parse command-line arguments and run experiments from provided file.""" - run_experiments(experiments_filename) - - -if __name__ == "__main__": - run_cli() diff --git a/training/run_experiment.py b/training/run_experiment.py index faafea6..ff8b886 100644 --- a/training/run_experiment.py +++ b/training/run_experiment.py @@ -1,162 +1,34 @@ """Script to run experiments.""" from datetime import datetime -from glob import glob import importlib -import json -import os from pathlib import Path -import re -from typing import Callable, Dict, List, Optional, Tuple, Type -import warnings +from typing import Dict, List, Optional, Type import click from loguru import logger import numpy as np +from omegaconf import OmegaConf +import pytorch_lightning as pl import torch +from torch import nn from torchsummary import summary from tqdm import tqdm -from training.gpu_manager import GPUManager -from training.trainer.callbacks import CallbackList -from training.trainer.train import Trainer import wandb -import yaml -import text_recognizer.models -from text_recognizer.models import Model -import text_recognizer.networks -from text_recognizer.networks.loss import loss as custom_loss_module +SEED = 4711 EXPERIMENTS_DIRNAME = Path(__file__).parents[0].resolve() / "experiments" -def _get_level(verbose: int) -> int: - """Sets the logger level.""" - levels = {0: 40, 1: 20, 2: 10} - verbose = verbose if verbose <= 2 else 2 - return levels[verbose] - - -def _create_experiment_dir( - experiment_config: Dict, checkpoint: Optional[str] = None -) -> Path: - """Create new experiment.""" - EXPERIMENTS_DIRNAME.mkdir(parents=True, exist_ok=True) - experiment_dir = EXPERIMENTS_DIRNAME / ( - f"{experiment_config['model']}_" - + f"{experiment_config['dataset']['type']}_" - + f"{experiment_config['network']['type']}" - ) - - if checkpoint is None: - experiment = datetime.now().strftime("%m%d_%H%M%S") - logger.debug(f"Creating a new experiment called {experiment}") - else: - available_experiments = glob(str(experiment_dir) + "/*") - available_experiments.sort() - if checkpoint == "last": - experiment = available_experiments[-1] - logger.debug(f"Resuming the latest experiment {experiment}") - else: - experiment = checkpoint - if not str(experiment_dir / experiment) in available_experiments: - raise FileNotFoundError("Experiment does not exist.") - logger.debug(f"Resuming the from experiment {checkpoint}") - - experiment_dir = experiment_dir / experiment - - # Create log and model directories. - log_dir = experiment_dir / "log" - model_dir = experiment_dir / "model" - - return experiment_dir, log_dir, model_dir - - -def _load_modules_and_arguments(experiment_config: Dict,) -> Tuple[Callable, Dict]: - """Loads all modules and arguments.""" - # Load the dataset module. - dataset_args = experiment_config.get("dataset", {}) - dataset_ = dataset_args["type"] - - # Import the model module and model arguments. - model_class_ = getattr(text_recognizer.models, experiment_config["model"]) - - # Import metrics. - metric_fns_ = ( - { - metric: getattr(text_recognizer.networks, metric) - for metric in experiment_config["metrics"] - } - if experiment_config["metrics"] is not None - else None - ) - - # Import network module and arguments. - network_fn_ = experiment_config["network"]["type"] - network_args = experiment_config["network"].get("args", {}) - - # Criterion - if experiment_config["criterion"]["type"] in custom_loss_module.__all__: - criterion_ = getattr(custom_loss_module, experiment_config["criterion"]["type"]) - else: - criterion_ = getattr(torch.nn, experiment_config["criterion"]["type"]) - criterion_args = experiment_config["criterion"].get("args", {}) or {} - - # Optimizers - optimizer_ = getattr(torch.optim, experiment_config["optimizer"]["type"]) - optimizer_args = experiment_config["optimizer"].get("args", {}) - - # Learning rate scheduler - lr_scheduler_ = None - lr_scheduler_args = None - if "lr_scheduler" in experiment_config: - lr_scheduler_ = getattr( - torch.optim.lr_scheduler, experiment_config["lr_scheduler"]["type"] - ) - lr_scheduler_args = experiment_config["lr_scheduler"].get("args", {}) or {} - - # SWA scheduler. - if "swa_args" in experiment_config: - swa_args = experiment_config.get("swa_args", {}) or {} - else: - swa_args = None - - model_args = { - "dataset": dataset_, - "dataset_args": dataset_args, - "metrics": metric_fns_, - "network_fn": network_fn_, - "network_args": network_args, - "criterion": criterion_, - "criterion_args": criterion_args, - "optimizer": optimizer_, - "optimizer_args": optimizer_args, - "lr_scheduler": lr_scheduler_, - "lr_scheduler_args": lr_scheduler_args, - "swa_args": swa_args, - } - - return model_class_, model_args - - -def _configure_callbacks(experiment_config: Dict, model_dir: Path) -> CallbackList: - """Configure a callback list for trainer.""" - if "Checkpoint" in experiment_config["callback_args"]: - experiment_config["callback_args"]["Checkpoint"]["checkpoint_path"] = str( - model_dir - ) - - # Initializes callbacks. - callback_modules = importlib.import_module("training.trainer.callbacks") - callbacks = [] - for callback in experiment_config["callbacks"]: - args = experiment_config["callback_args"][callback] or {} - callbacks.append(getattr(callback_modules, callback)(**args)) - - return callbacks +def _configure_logging(log_dir: Optional[Path], verbose: int = 0) -> None: + """Configure the loguru logger for output to terminal and disk.""" + def _get_level(verbose: int) -> int: + """Sets the logger level.""" + levels = {0: 40, 1: 20, 2: 10} + verbose = verbose if verbose <= 2 else 2 + return levels[verbose] -def _configure_logger(log_dir: Path, verbose: int = 0) -> None: - """Configure the loguru logger for output to terminal and disk.""" # Have to remove default logger to get tqdm to work properly. logger.remove() @@ -164,219 +36,138 @@ def _configure_logger(log_dir: Path, verbose: int = 0) -> None: level = _get_level(verbose) logger.add(lambda msg: tqdm.write(msg, end=""), colorize=True, level=level) - logger.add( - str(log_dir / "train.log"), - format="{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}", - ) - - -def _save_config(experiment_dir: Path, experiment_config: Dict) -> None: - """Copy config to experiment directory.""" - config_path = experiment_dir / "config.yml" - with open(str(config_path), "w") as f: - yaml.dump(experiment_config, f) - - -def _load_from_checkpoint( - model: Type[Model], model_dir: Path, pretrained_weights: str = None, -) -> None: - """If checkpoint exists, load model weights and optimizers from checkpoint.""" - # Get checkpoint path. - if pretrained_weights is not None: - logger.info(f"Loading weights from {pretrained_weights}.") - checkpoint_path = ( - EXPERIMENTS_DIRNAME / Path(pretrained_weights) / "model" / "best.pt" + if log_dir is not None: + logger.add( + str(log_dir / "train.log"), + format="{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}", ) - else: - logger.info(f"Loading weights from {model_dir}.") - checkpoint_path = model_dir / "last.pt" - if checkpoint_path.exists(): - logger.info("Loading and resuming training from checkpoint.") - model.load_from_checkpoint(checkpoint_path) -def evaluate_embedding(model: Type[Model]) -> Dict: - """Evaluates the embedding space.""" - from pytorch_metric_learning import testers - from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator +def _import_class(module_and_class_name: str) -> type: + """Import class from module.""" + module_name, class_name = module_and_class_name.rsplit(".", 1) + module = importlib.import_module(module_name) + return getattr(module, class_name) - accuracy_calculator = AccuracyCalculator( - include=("mean_average_precision_at_r",), k=10 - ) - def get_all_embeddings(model: Type[Model]) -> Tuple: - tester = testers.BaseTester() - return tester.get_all_embeddings(model.test_dataset, model.network) +def _configure_pl_callbacks(args: List[Dict]) -> List[Type[pl.callbacks.Callback]]: + """Configures PyTorch Lightning callbacks.""" + pl_callbacks = [ + getattr(pl.callbacks, callback["type"])(**callback["args"]) for callback in args + ] + return pl_callbacks - embeddings, labels = get_all_embeddings(model) - logger.info("Computing embedding accuracy") - accuracies = accuracy_calculator.get_accuracy( - embeddings, embeddings, np.squeeze(labels), np.squeeze(labels), True - ) - logger.info( - f"Test set accuracy (MAP@10) = {accuracies['mean_average_precision_at_r']}" - ) - return accuracies +def _configure_wandb_callback( + network: Type[nn.Module], args: Dict +) -> pl.loggers.WandbLogger: + """Configures wandb logger.""" + pl_logger = pl.loggers.WandbLogger() + pl_logger.watch(network) + pl_logger.log_hyperparams(vars(args)) + return pl_logger -def run_experiment( - experiment_config: Dict, - save_weights: bool, - device: str, - use_wandb: bool, - train: bool, - test: bool, - verbose: int = 0, - checkpoint: Optional[str] = None, - pretrained_weights: Optional[str] = None, -) -> None: - """Runs an experiment.""" - logger.info(f"Experiment config: {json.dumps(experiment_config)}") - # Create new experiment. - experiment_dir, log_dir, model_dir = _create_experiment_dir( - experiment_config, checkpoint +def _save_best_weights( + callbacks: List[Type[pl.callbacks.Callback]], use_wandb: bool +) -> None: + """Saves the best model.""" + model_checkpoint_callback = next( + callback + for callback in callbacks + if isinstance(callback, pl.callbacks.ModelCheckpoint) ) + best_model_path = model_checkpoint_callback.best_model_path + if best_model_path: + logger.info(f"Best model saved at: {best_model_path}") + if use_wandb: + logger.info("Uploading model to W&B...") + wandb.save(best_model_path) - # Make sure the log/model directory exists. - log_dir.mkdir(parents=True, exist_ok=True) - model_dir.mkdir(parents=True, exist_ok=True) - - # Load the modules and model arguments. - model_class_, model_args = _load_modules_and_arguments(experiment_config) - - # Initializes the model with experiment config. - model = model_class_(**model_args, device=device) - - callbacks = _configure_callbacks(experiment_config, model_dir) - # Setup logger. - _configure_logger(log_dir, verbose) +def run(path: str, train: bool, test: bool, tune: bool, use_wandb: bool) -> None: + """Runs experiment.""" + logger.info("Starting experiment...") - # Load from checkpoint if resuming an experiment. - resume = False - if checkpoint is not None or pretrained_weights is not None: - # resume = True - _load_from_checkpoint(model, model_dir, pretrained_weights) + # Seed everything in the experiment + logger.info(f"Seeding everthing with seed={SEED}") + pl.utilities.seed.seed_everything(SEED) - logger.info(f"The class mapping is {model.mapping}") + # Load config. + logger.info(f"Loading config from: {path}") + config = OmegaConf.load(path) - # Initializes Weights & Biases - if use_wandb: - wandb.init(project="text-recognizer", config=experiment_config, resume=resume) + # Load classes + data_module_class = _import_class(f"text_recognizer.data.{config.data.type}") + network_class = _import_class(f"text_recognizer.networks.{config.network.type}") + lit_model_class = _import_class(f"text_recognizer.models.{config.model.type}") - # Lets W&B save the model and track the gradients and optional parameters. - wandb.watch(model.network) + # Initialize data object and network. + data_module = data_module_class(**config.data.args) + network = network_class(**config.network.args) - experiment_config["experiment_group"] = experiment_config.get( - "experiment_group", None + # Load callback and logger + callbacks = _configure_pl_callbacks(config.callbacks) + pl_logger = ( + _configure_wandb_callback(network, config.network.args) + if use_wandb + else pl.logger.TensorBoardLogger("training/logs") ) - experiment_config["device"] = device - - # Save the config used in the experiment folder. - _save_config(experiment_dir, experiment_config) - - # Prints a summary of the network in terminal. - model.summary(experiment_config["train_args"]["input_shape"]) + # Checkpoint + if config.load_checkpoint is not None: + logger.info( + f"Loading network weights from checkpoint: {config.load_checkpoint}" + ) + lit_model = lit_model_class.load_from_checkpoint( + config.load_checkpoint, network=network, **config.model.args + ) + else: + lit_model = lit_model_class(**config.model.args) - # Load trainer. - trainer = Trainer( - max_epochs=experiment_config["train_args"]["max_epochs"], + trainer = pl.Trainer( + **config.trainer, callbacks=callbacks, - transformer_model=experiment_config["train_args"]["transformer_model"], - max_norm=experiment_config["train_args"]["max_norm"], - freeze_backbone=experiment_config["train_args"]["freeze_backbone"], + logger=pl_logger, + weigths_save_path="training/logs", ) - # Train the model. + if tune: + logger.info(f"Tuning learning rate and batch size...") + trainer.tune(lit_model, datamodule=data_module) + if train: - trainer.fit(model) + logger.info(f"Training network...") + trainer.fit(lit_model, datamodule=data_module) - # Run inference over test set. if test: - logger.info("Loading checkpoint with the best weights.") - if "checkpoint" in experiment_config["train_args"]: - model.load_from_checkpoint( - model_dir / experiment_config["train_args"]["checkpoint"] - ) - else: - model.load_from_checkpoint(model_dir / "best.pt") - - logger.info("Running inference on test set.") - if experiment_config["criterion"]["type"] == "EmbeddingLoss": - logger.info("Evaluating embedding.") - score = evaluate_embedding(model) - else: - score = trainer.test(model) - - logger.info(f"Test set evaluation: {score}") - - if use_wandb: - wandb.log( - { - experiment_config["test_metric"]: score[ - experiment_config["test_metric"] - ] - } - ) + logger.info(f"Testing network...") + trainer.test(lit_model, datamodule=data_module) - if save_weights: - model.save_weights(model_dir) + _save_best_weights(callbacks, use_wandb) @click.command() -@click.argument("experiment_config",) -@click.option("--gpu", type=int, default=0, help="Provide the index of the GPU to use.") +@click.option("-f", "--experiment_config", type=str, help="Path to experiment config.") +@click.option("--use_wandb", is_flag=True, help="If true, do use wandb for logging.") @click.option( - "--save", - is_flag=True, - help="If set, the final weights will be saved to a canonical, version-controlled location.", -) -@click.option( - "--nowandb", is_flag=False, help="If true, do not use wandb for this run." + "--tune", is_flag=True, help="If true, tune hyperparameters for training." ) +@click.option("--train", is_flag=True, help="If true, train the model.") @click.option("--test", is_flag=True, help="If true, test the model.") @click.option("-v", "--verbose", count=True) -@click.option("--checkpoint", type=str, help="Path to the experiment.") -@click.option( - "--pretrained_weights", type=str, help="Path to pretrained model weights." -) -@click.option( - "--notrain", is_flag=False, help="Do not train the model.", -) -def run_cli( +def cli( experiment_config: str, - gpu: int, - save: bool, - nowandb: bool, - notrain: bool, + use_wandb: bool, + tune: bool, + train: bool, test: bool, verbose: int, - checkpoint: Optional[str] = None, - pretrained_weights: Optional[str] = None, ) -> None: """Run experiment.""" - if gpu < 0: - gpu_manager = GPUManager(True) - gpu = gpu_manager.get_free_gpu() - device = "cuda:" + str(gpu) - - experiment_config = json.loads(experiment_config) - os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu}" - - run_experiment( - experiment_config, - save, - device, - use_wandb=not nowandb, - train=not notrain, - test=test, - verbose=verbose, - checkpoint=checkpoint, - pretrained_weights=pretrained_weights, - ) + _configure_logging(None, verbose=verbose) + run(path=experiment_config, train=train, test=test, tune=tune, use_wandb=use_wandb) if __name__ == "__main__": - run_cli() + cli() diff --git a/training/run_sweep.py b/training/run_sweep.py deleted file mode 100644 index a578592..0000000 --- a/training/run_sweep.py +++ /dev/null @@ -1,92 +0,0 @@ -"""W&B Sweep Functionality.""" -from ast import literal_eval -import json -import os -from pathlib import Path -import signal -import subprocess # nosec -import sys -from typing import Dict, List, Tuple - -import click -import yaml - -EXPERIMENTS_DIRNAME = Path(__file__).parents[0].resolve() / "experiments" - - -def load_config() -> Dict: - """Load base hyperparameter config.""" - with open(str(EXPERIMENTS_DIRNAME / "default_config_emnist.yml"), "r") as f: - default_config = yaml.safe_load(f) - return default_config - - -def args_to_json( - default_config: dict, preserve_args: tuple = ("gpu", "save") -) -> Tuple[dict, list]: - """Convert command line arguments to nested config values. - - i.e. run_sweep.py --dataset_args.foo=1.7 - { - "dataset_args": { - "foo": 1.7 - } - } - - Args: - default_config (dict): The base config used for every experiment. - preserve_args (tuple): Arguments preserved for all runs. Defaults to ("gpu", "save"). - - Returns: - Tuple[dict, list]: Tuple of config dictionary and list of arguments. - - """ - - args = [] - config = default_config.copy() - key, val = None, None - for arg in sys.argv[1:]: - if "=" in arg: - key, val = arg.split("=") - elif key: - val = arg - else: - key = arg - if key and val: - parsed_key = key.lstrip("-").split(".") - if parsed_key[0] in preserve_args: - args.append("--{}={}".format(parsed_key[0], val)) - else: - nested = config - for level in parsed_key[:-1]: - nested[level] = config.get(level, {}) - nested = nested[level] - try: - # Convert numerics to floats / ints - val = literal_eval(val) - except ValueError: - pass - nested[parsed_key[-1]] = val - key, val = None, None - return config, args - - -def main() -> None: - """Runs a W&B sweep.""" - default_config = load_config() - config, args = args_to_json(default_config) - env = { - k: v for k, v in os.environ.items() if k not in ("WANDB_PROGRAM", "WANDB_ARGS") - } - # pylint: disable=subprocess-popen-preexec-fn - run = subprocess.Popen( - ["python", "training/run_experiment.py", *args, json.dumps(config)], - env=env, - preexec_fn=os.setsid, - ) # nosec - signal.signal(signal.SIGTERM, lambda *args: run.terminate()) - run.wait() - - -if __name__ == "__main__": - main() diff --git a/training/sweep_emnist.yml b/training/sweep_emnist.yml deleted file mode 100644 index 48d7261..0000000 --- a/training/sweep_emnist.yml +++ /dev/null @@ -1,26 +0,0 @@ -program: training/run_sweep.py -method: bayes -metric: - name: val_loss - goal: minimize -parameters: - dataset: - value: EmnistDataset - model: - value: CharacterModel - network: - value: MLP - network_args.hidden_size: - values: [128, 256] - network_args.dropout_rate: - values: [0.2, 0.4] - network_args.num_layers: - values: [3, 6] - optimizer_args.lr: - values: [1.e-1, 1.e-5] - lr_scheduler_args.max_lr: - values: [1.0e-1, 1.0e-5] - train_args.batch_size: - values: [64, 128] - train_args.epochs: - value: 5 diff --git a/training/sweep_emnist_resnet.yml b/training/sweep_emnist_resnet.yml deleted file mode 100644 index 19a3040..0000000 --- a/training/sweep_emnist_resnet.yml +++ /dev/null @@ -1,50 +0,0 @@ -program: training/run_sweep.py -method: bayes -metric: - name: val_accuracy - goal: maximize -parameters: - dataset: - value: EmnistDataset - model: - value: CharacterModel - network: - value: ResidualNetwork - network_args.block_sizes: - distribution: q_uniform - min: 16 - max: 256 - q: 8 - network_args.depths: - distribution: int_uniform - min: 1 - max: 3 - network_args.levels: - distribution: int_uniform - min: 1 - max: 2 - network_args.activation: - distribution: categorical - values: - - gelu - - leaky_relu - - relu - - selu - optimizer_args.lr: - distribution: uniform - min: 1.e-5 - max: 1.e-1 - lr_scheduler_args.max_lr: - distribution: uniform - min: 1.e-5 - max: 1.e-1 - train_args.batch_size: - distribution: q_uniform - min: 32 - max: 256 - q: 8 - train_args.epochs: - value: 5 -early_terminate: - type: hyperband - min_iter: 2 diff --git a/training/trainer/__init__.py b/training/trainer/__init__.py deleted file mode 100644 index de41bfb..0000000 --- a/training/trainer/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -"""Trainer modules.""" -from .train import Trainer diff --git a/training/trainer/callbacks/__init__.py b/training/trainer/callbacks/__init__.py deleted file mode 100644 index 80c4177..0000000 --- a/training/trainer/callbacks/__init__.py +++ /dev/null @@ -1,29 +0,0 @@ -"""The callback modules used in the training script.""" -from .base import Callback, CallbackList -from .checkpoint import Checkpoint -from .early_stopping import EarlyStopping -from .lr_schedulers import ( - LRScheduler, - SWA, -) -from .progress_bar import ProgressBar -from .wandb_callbacks import ( - WandbCallback, - WandbImageLogger, - WandbReconstructionLogger, - WandbSegmentationLogger, -) - -__all__ = [ - "Callback", - "CallbackList", - "Checkpoint", - "EarlyStopping", - "LRScheduler", - "WandbCallback", - "WandbImageLogger", - "WandbReconstructionLogger", - "WandbSegmentationLogger", - "ProgressBar", - "SWA", -] diff --git a/training/trainer/callbacks/base.py b/training/trainer/callbacks/base.py deleted file mode 100644 index 500b642..0000000 --- a/training/trainer/callbacks/base.py +++ /dev/null @@ -1,188 +0,0 @@ -"""Metaclass for callback functions.""" - -from enum import Enum -from typing import Callable, Dict, List, Optional, Type, Union - -from loguru import logger -import numpy as np -import torch - -from text_recognizer.models import Model - - -class ModeKeys: - """Mode keys for CallbackList.""" - - TRAIN = "train" - VALIDATION = "validation" - - -class Callback: - """Metaclass for callbacks used in training.""" - - def __init__(self) -> None: - """Initializes the Callback instance.""" - self.model = None - - def set_model(self, model: Type[Model]) -> None: - """Set the model.""" - self.model = model - - def on_fit_begin(self) -> None: - """Called when fit begins.""" - pass - - def on_fit_end(self) -> None: - """Called when fit ends.""" - pass - - def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None) -> None: - """Called at the beginning of an epoch. Only used in training mode.""" - pass - - def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: - """Called at the end of an epoch. Only used in training mode.""" - pass - - def on_train_batch_begin(self, batch: int, logs: Optional[Dict] = None) -> None: - """Called at the beginning of an epoch.""" - pass - - def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: - """Called at the end of an epoch.""" - pass - - def on_validation_batch_begin( - self, batch: int, logs: Optional[Dict] = None - ) -> None: - """Called at the beginning of an epoch.""" - pass - - def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: - """Called at the end of an epoch.""" - pass - - def on_test_begin(self) -> None: - """Called at the beginning of test.""" - pass - - def on_test_end(self) -> None: - """Called at the end of test.""" - pass - - -class CallbackList: - """Container for abstracting away callback calls.""" - - mode_keys = ModeKeys() - - def __init__(self, model: Type[Model], callbacks: List[Callback] = None) -> None: - """Container for `Callback` instances. - - This object wraps a list of `Callback` instances and allows them all to be - called via a single end point. - - Args: - model (Type[Model]): A `Model` instance. - callbacks (List[Callback]): List of `Callback` instances. Defaults to None. - - """ - - self._callbacks = callbacks or [] - if model: - self.set_model(model) - - def set_model(self, model: Type[Model]) -> None: - """Set the model for all callbacks.""" - self.model = model - for callback in self._callbacks: - callback.set_model(model=self.model) - - def append(self, callback: Type[Callback]) -> None: - """Append new callback to callback list.""" - self._callbacks.append(callback) - - def on_fit_begin(self) -> None: - """Called when fit begins.""" - for callback in self._callbacks: - callback.on_fit_begin() - - def on_fit_end(self) -> None: - """Called when fit ends.""" - for callback in self._callbacks: - callback.on_fit_end() - - def on_test_begin(self) -> None: - """Called when test begins.""" - for callback in self._callbacks: - callback.on_test_begin() - - def on_test_end(self) -> None: - """Called when test ends.""" - for callback in self._callbacks: - callback.on_test_end() - - def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None) -> None: - """Called at the beginning of an epoch.""" - for callback in self._callbacks: - callback.on_epoch_begin(epoch, logs) - - def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: - """Called at the end of an epoch.""" - for callback in self._callbacks: - callback.on_epoch_end(epoch, logs) - - def _call_batch_hook( - self, mode: str, hook: str, batch: int, logs: Optional[Dict] = None - ) -> None: - """Helper function for all batch_{begin | end} methods.""" - if hook == "begin": - self._call_batch_begin_hook(mode, batch, logs) - elif hook == "end": - self._call_batch_end_hook(mode, batch, logs) - else: - raise ValueError(f"Unrecognized hook {hook}.") - - def _call_batch_begin_hook( - self, mode: str, batch: int, logs: Optional[Dict] = None - ) -> None: - """Helper function for all `on_*_batch_begin` methods.""" - hook_name = f"on_{mode}_batch_begin" - self._call_batch_hook_helper(hook_name, batch, logs) - - def _call_batch_end_hook( - self, mode: str, batch: int, logs: Optional[Dict] = None - ) -> None: - """Helper function for all `on_*_batch_end` methods.""" - hook_name = f"on_{mode}_batch_end" - self._call_batch_hook_helper(hook_name, batch, logs) - - def _call_batch_hook_helper( - self, hook_name: str, batch: int, logs: Optional[Dict] = None - ) -> None: - """Helper function for `on_*_batch_begin` methods.""" - for callback in self._callbacks: - hook = getattr(callback, hook_name) - hook(batch, logs) - - def on_train_batch_begin(self, batch: int, logs: Optional[Dict] = None) -> None: - """Called at the beginning of an epoch.""" - self._call_batch_hook(self.mode_keys.TRAIN, "begin", batch, logs) - - def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: - """Called at the end of an epoch.""" - self._call_batch_hook(self.mode_keys.TRAIN, "end", batch, logs) - - def on_validation_batch_begin( - self, batch: int, logs: Optional[Dict] = None - ) -> None: - """Called at the beginning of an epoch.""" - self._call_batch_hook(self.mode_keys.VALIDATION, "begin", batch, logs) - - def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: - """Called at the end of an epoch.""" - self._call_batch_hook(self.mode_keys.VALIDATION, "end", batch, logs) - - def __iter__(self) -> iter: - """Iter function for callback list.""" - return iter(self._callbacks) diff --git a/training/trainer/callbacks/checkpoint.py b/training/trainer/callbacks/checkpoint.py deleted file mode 100644 index a54e0a9..0000000 --- a/training/trainer/callbacks/checkpoint.py +++ /dev/null @@ -1,95 +0,0 @@ -"""Callback checkpoint for training models.""" -from enum import Enum -from pathlib import Path -from typing import Callable, Dict, List, Optional, Type, Union - -from loguru import logger -import numpy as np -import torch -from training.trainer.callbacks import Callback - -from text_recognizer.models import Model - - -class Checkpoint(Callback): - """Saving model parameters at the end of each epoch.""" - - mode_dict = { - "min": torch.lt, - "max": torch.gt, - } - - def __init__( - self, - checkpoint_path: Union[str, Path], - monitor: str = "accuracy", - mode: str = "auto", - min_delta: float = 0.0, - ) -> None: - """Monitors a quantity that will allow us to determine the best model weights. - - Args: - checkpoint_path (Union[str, Path]): Path to the experiment with the checkpoint. - monitor (str): Name of the quantity to monitor. Defaults to "accuracy". - mode (str): Description of parameter `mode`. Defaults to "auto". - min_delta (float): Description of parameter `min_delta`. Defaults to 0.0. - - """ - super().__init__() - self.checkpoint_path = Path(checkpoint_path) - self.monitor = monitor - self.mode = mode - self.min_delta = torch.tensor(min_delta) - - if mode not in ["auto", "min", "max"]: - logger.warning(f"Checkpoint mode {mode} is unkown, fallback to auto mode.") - - self.mode = "auto" - - if self.mode == "auto": - if "accuracy" in self.monitor: - self.mode = "max" - else: - self.mode = "min" - logger.debug( - f"Checkpoint mode set to {self.mode} for monitoring {self.monitor}." - ) - - torch_inf = torch.tensor(np.inf) - self.min_delta *= 1 if self.monitor_op == torch.gt else -1 - self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf - - @property - def monitor_op(self) -> float: - """Returns the comparison method.""" - return self.mode_dict[self.mode] - - def on_epoch_end(self, epoch: int, logs: Dict) -> None: - """Saves a checkpoint for the network parameters. - - Args: - epoch (int): The current epoch. - logs (Dict): The log containing the monitored metrics. - - """ - current = self.get_monitor_value(logs) - if current is None: - return - if self.monitor_op(current - self.min_delta, self.best_score): - self.best_score = current - is_best = True - else: - is_best = False - - self.model.save_checkpoint(self.checkpoint_path, is_best, epoch, self.monitor) - - def get_monitor_value(self, logs: Dict) -> Union[float, None]: - """Extracts the monitored value.""" - monitor_value = logs.get(self.monitor) - if monitor_value is None: - logger.warning( - f"Checkpoint is conditioned on metric {self.monitor} which is not available. Available" - + f" metrics are: {','.join(list(logs.keys()))}" - ) - return None - return monitor_value diff --git a/training/trainer/callbacks/early_stopping.py b/training/trainer/callbacks/early_stopping.py deleted file mode 100644 index 02b431f..0000000 --- a/training/trainer/callbacks/early_stopping.py +++ /dev/null @@ -1,108 +0,0 @@ -"""Implements Early stopping for PyTorch model.""" -from typing import Dict, Union - -from loguru import logger -import numpy as np -import torch -from torch import Tensor -from training.trainer.callbacks import Callback - - -class EarlyStopping(Callback): - """Stops training when a monitored metric stops improving.""" - - mode_dict = { - "min": torch.lt, - "max": torch.gt, - } - - def __init__( - self, - monitor: str = "val_loss", - min_delta: float = 0.0, - patience: int = 3, - mode: str = "auto", - ) -> None: - """Initializes the EarlyStopping callback. - - Args: - monitor (str): Description of parameter `monitor`. Defaults to "val_loss". - min_delta (float): Description of parameter `min_delta`. Defaults to 0.0. - patience (int): Description of parameter `patience`. Defaults to 3. - mode (str): Description of parameter `mode`. Defaults to "auto". - - """ - super().__init__() - self.monitor = monitor - self.patience = patience - self.min_delta = torch.tensor(min_delta) - self.mode = mode - self.wait_count = 0 - self.stopped_epoch = 0 - - if mode not in ["auto", "min", "max"]: - logger.warning( - f"EarlyStopping mode {mode} is unkown, fallback to auto mode." - ) - - self.mode = "auto" - - if self.mode == "auto": - if "accuracy" in self.monitor: - self.mode = "max" - else: - self.mode = "min" - logger.debug( - f"EarlyStopping mode set to {self.mode} for monitoring {self.monitor}." - ) - - self.torch_inf = torch.tensor(np.inf) - self.min_delta *= 1 if self.monitor_op == torch.gt else -1 - self.best_score = ( - self.torch_inf if self.monitor_op == torch.lt else -self.torch_inf - ) - - @property - def monitor_op(self) -> float: - """Returns the comparison method.""" - return self.mode_dict[self.mode] - - def on_fit_begin(self) -> Union[torch.lt, torch.gt]: - """Reset the early stopping variables for reuse.""" - self.wait_count = 0 - self.stopped_epoch = 0 - self.best_score = ( - self.torch_inf if self.monitor_op == torch.lt else -self.torch_inf - ) - - def on_epoch_end(self, epoch: int, logs: Dict) -> None: - """Computes the early stop criterion.""" - current = self.get_monitor_value(logs) - if current is None: - return - if self.monitor_op(current - self.min_delta, self.best_score): - self.best_score = current - self.wait_count = 0 - else: - self.wait_count += 1 - if self.wait_count >= self.patience: - self.stopped_epoch = epoch - self.model.stop_training = True - - def on_fit_end(self) -> None: - """Logs if early stopping was used.""" - if self.stopped_epoch > 0: - logger.info( - f"Stopped training at epoch {self.stopped_epoch + 1} with early stopping." - ) - - def get_monitor_value(self, logs: Dict) -> Union[Tensor, None]: - """Extracts the monitor value.""" - monitor_value = logs.get(self.monitor) - if monitor_value is None: - logger.warning( - f"Early stopping is conditioned on metric {self.monitor} which is not available. Available" - + f"metrics are: {','.join(list(logs.keys()))}" - ) - return None - return torch.tensor(monitor_value) diff --git a/training/trainer/callbacks/lr_schedulers.py b/training/trainer/callbacks/lr_schedulers.py deleted file mode 100644 index 630c434..0000000 --- a/training/trainer/callbacks/lr_schedulers.py +++ /dev/null @@ -1,77 +0,0 @@ -"""Callbacks for learning rate schedulers.""" -from typing import Callable, Dict, List, Optional, Type - -from torch.optim.swa_utils import update_bn -from training.trainer.callbacks import Callback - -from text_recognizer.models import Model - - -class LRScheduler(Callback): - """Generic learning rate scheduler callback.""" - - def __init__(self) -> None: - super().__init__() - - def set_model(self, model: Type[Model]) -> None: - """Sets the model and lr scheduler.""" - self.model = model - self.lr_scheduler = self.model.lr_scheduler["lr_scheduler"] - self.interval = self.model.lr_scheduler["interval"] - - def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: - """Takes a step at the end of every epoch.""" - if self.interval == "epoch": - if "ReduceLROnPlateau" in self.lr_scheduler.__class__.__name__: - self.lr_scheduler.step(logs["val_loss"]) - else: - self.lr_scheduler.step() - - def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: - """Takes a step at the end of every training batch.""" - if self.interval == "step": - self.lr_scheduler.step() - - -class SWA(Callback): - """Stochastic Weight Averaging callback.""" - - def __init__(self) -> None: - """Initializes the callback.""" - super().__init__() - self.lr_scheduler = None - self.interval = None - self.swa_scheduler = None - self.swa_start = None - self.current_epoch = 1 - - def set_model(self, model: Type[Model]) -> None: - """Sets the model and lr scheduler.""" - self.model = model - self.lr_scheduler = self.model.lr_scheduler["lr_scheduler"] - self.interval = self.model.lr_scheduler["interval"] - self.swa_scheduler = self.model.swa_scheduler["swa_scheduler"] - self.swa_start = self.model.swa_scheduler["swa_start"] - - def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: - """Takes a step at the end of every training batch.""" - if epoch > self.swa_start: - self.model.swa_network.update_parameters(self.model.network) - self.swa_scheduler.step() - elif self.interval == "epoch": - self.lr_scheduler.step() - self.current_epoch = epoch - - def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: - """Takes a step at the end of every training batch.""" - if self.current_epoch < self.swa_start and self.interval == "step": - self.lr_scheduler.step() - - def on_fit_end(self) -> None: - """Update batch norm statistics for the swa model at the end of training.""" - if self.model.swa_network: - update_bn( - self.model.val_dataloader(), - self.model.swa_network, - device=self.model.device, - ) diff --git a/training/trainer/callbacks/progress_bar.py b/training/trainer/callbacks/progress_bar.py deleted file mode 100644 index 6c4305a..0000000 --- a/training/trainer/callbacks/progress_bar.py +++ /dev/null @@ -1,65 +0,0 @@ -"""Progress bar callback for the training loop.""" -from typing import Dict, Optional - -from tqdm import tqdm -from training.trainer.callbacks import Callback - - -class ProgressBar(Callback): - """A TQDM progress bar for the training loop.""" - - def __init__(self, epochs: int, log_batch_frequency: int = None) -> None: - """Initializes the tqdm callback.""" - self.epochs = epochs - print(epochs, type(epochs)) - self.log_batch_frequency = log_batch_frequency - self.progress_bar = None - self.val_metrics = {} - - def _configure_progress_bar(self) -> None: - """Configures the tqdm progress bar with custom bar format.""" - self.progress_bar = tqdm( - total=len(self.model.train_dataloader()), - leave=False, - unit="steps", - mininterval=self.log_batch_frequency, - bar_format="{desc} |{bar:32}| {n_fmt}/{total_fmt} ETA: {remaining} {rate_fmt}{postfix}", - ) - - def _key_abbreviations(self, logs: Dict) -> Dict: - """Changes the length of keys, so that the progress bar fits better.""" - - def rename(key: str) -> str: - """Renames accuracy to acc.""" - return key.replace("accuracy", "acc") - - return {rename(key): value for key, value in logs.items()} - - # def on_fit_begin(self) -> None: - # """Creates a tqdm progress bar.""" - # self._configure_progress_bar() - - def on_epoch_begin(self, epoch: int, logs: Optional[Dict]) -> None: - """Updates the description with the current epoch.""" - if epoch == 1: - self._configure_progress_bar() - else: - self.progress_bar.reset() - self.progress_bar.set_description(f"Epoch {epoch}/{self.epochs}") - - def on_epoch_end(self, epoch: int, logs: Dict) -> None: - """At the end of each epoch, the validation metrics are updated to the progress bar.""" - self.val_metrics = logs - self.progress_bar.set_postfix(**self._key_abbreviations(logs)) - self.progress_bar.update() - - def on_train_batch_end(self, batch: int, logs: Dict) -> None: - """Updates the progress bar for each training step.""" - if self.val_metrics: - logs.update(self.val_metrics) - self.progress_bar.set_postfix(**self._key_abbreviations(logs)) - self.progress_bar.update() - - def on_fit_end(self) -> None: - """Closes the tqdm progress bar.""" - self.progress_bar.close() diff --git a/training/trainer/callbacks/wandb_callbacks.py b/training/trainer/callbacks/wandb_callbacks.py deleted file mode 100644 index 552a4f4..0000000 --- a/training/trainer/callbacks/wandb_callbacks.py +++ /dev/null @@ -1,261 +0,0 @@ -"""Callback for W&B.""" -from typing import Callable, Dict, List, Optional, Type - -import numpy as np -from training.trainer.callbacks import Callback -import wandb - -import text_recognizer.datasets.transforms as transforms -from text_recognizer.models.base import Model - - -class WandbCallback(Callback): - """A custom W&B metric logger for the trainer.""" - - def __init__(self, log_batch_frequency: int = None) -> None: - """Short summary. - - Args: - log_batch_frequency (int): If None, metrics will be logged every epoch. - If set to an integer, callback will log every metrics every log_batch_frequency. - - """ - super().__init__() - self.log_batch_frequency = log_batch_frequency - - def _on_batch_end(self, batch: int, logs: Dict) -> None: - if self.log_batch_frequency and batch % self.log_batch_frequency == 0: - wandb.log(logs, commit=True) - - def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: - """Logs training metrics.""" - if logs is not None: - logs["lr"] = self.model.optimizer.param_groups[0]["lr"] - self._on_batch_end(batch, logs) - - def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: - """Logs validation metrics.""" - if logs is not None: - self._on_batch_end(batch, logs) - - def on_epoch_end(self, epoch: int, logs: Dict) -> None: - """Logs at epoch end.""" - wandb.log(logs, commit=True) - - -class WandbImageLogger(Callback): - """Custom W&B callback for image logging.""" - - def __init__( - self, - example_indices: Optional[List] = None, - num_examples: int = 4, - transform: Optional[bool] = None, - ) -> None: - """Initializes the WandbImageLogger with the model to train. - - Args: - example_indices (Optional[List]): Indices for validation images. Defaults to None. - num_examples (int): Number of random samples to take if example_indices are not specified. Defaults to 4. - transform (Optional[Dict]): Use transform on image or not. Defaults to None. - - """ - - super().__init__() - self.caption = None - self.example_indices = example_indices - self.test_sample_indices = None - self.num_examples = num_examples - self.transform = ( - self._configure_transform(transform) if transform is not None else None - ) - - def _configure_transform(self, transform: Dict) -> Callable: - args = transform["args"] or {} - return getattr(transforms, transform["type"])(**args) - - def set_model(self, model: Type[Model]) -> None: - """Sets the model and extracts validation images from the dataset.""" - self.model = model - self.caption = "Validation Examples" - if self.example_indices is None: - self.example_indices = np.random.randint( - 0, len(self.model.val_dataset), self.num_examples - ) - self.images = self.model.val_dataset.dataset.data[self.example_indices] - self.targets = self.model.val_dataset.dataset.targets[self.example_indices] - self.targets = self.targets.tolist() - - def on_test_begin(self) -> None: - """Get samples from test dataset.""" - self.caption = "Test Examples" - if self.test_sample_indices is None: - self.test_sample_indices = np.random.randint( - 0, len(self.model.test_dataset), self.num_examples - ) - self.images = self.model.test_dataset.data[self.test_sample_indices] - self.targets = self.model.test_dataset.targets[self.test_sample_indices] - self.targets = self.targets.tolist() - - def on_test_end(self) -> None: - """Log test images.""" - self.on_epoch_end(0, {}) - - def on_epoch_end(self, epoch: int, logs: Dict) -> None: - """Get network predictions on validation images.""" - images = [] - for i, image in enumerate(self.images): - image = self.transform(image) if self.transform is not None else image - pred, conf = self.model.predict_on_image(image) - if isinstance(self.targets[i], list): - ground_truth = "".join( - [ - self.model.mapper(int(target_index) - 26) - if target_index > 35 - else self.model.mapper(int(target_index)) - for target_index in self.targets[i] - ] - ).rstrip("_") - else: - ground_truth = self.model.mapper(int(self.targets[i])) - caption = f"Prediction: {pred} Confidence: {conf:.3f} Ground Truth: {ground_truth}" - images.append(wandb.Image(image, caption=caption)) - - wandb.log({f"{self.caption}": images}, commit=False) - - -class WandbSegmentationLogger(Callback): - """Custom W&B callback for image logging.""" - - def __init__( - self, - class_labels: Dict, - example_indices: Optional[List] = None, - num_examples: int = 4, - ) -> None: - """Initializes the WandbImageLogger with the model to train. - - Args: - class_labels (Dict): A dict with int as key and class string as value. - example_indices (Optional[List]): Indices for validation images. Defaults to None. - num_examples (int): Number of random samples to take if example_indices are not specified. Defaults to 4. - - """ - - super().__init__() - self.caption = None - self.class_labels = {int(k): v for k, v in class_labels.items()} - self.example_indices = example_indices - self.test_sample_indices = None - self.num_examples = num_examples - - def set_model(self, model: Type[Model]) -> None: - """Sets the model and extracts validation images from the dataset.""" - self.model = model - self.caption = "Validation Segmentation Examples" - if self.example_indices is None: - self.example_indices = np.random.randint( - 0, len(self.model.val_dataset), self.num_examples - ) - self.images = self.model.val_dataset.dataset.data[self.example_indices] - self.targets = self.model.val_dataset.dataset.targets[self.example_indices] - self.targets = self.targets.tolist() - - def on_test_begin(self) -> None: - """Get samples from test dataset.""" - self.caption = "Test Segmentation Examples" - if self.test_sample_indices is None: - self.test_sample_indices = np.random.randint( - 0, len(self.model.test_dataset), self.num_examples - ) - self.images = self.model.test_dataset.data[self.test_sample_indices] - self.targets = self.model.test_dataset.targets[self.test_sample_indices] - self.targets = self.targets.tolist() - - def on_test_end(self) -> None: - """Log test images.""" - self.on_epoch_end(0, {}) - - def on_epoch_end(self, epoch: int, logs: Dict) -> None: - """Get network predictions on validation images.""" - images = [] - for i, image in enumerate(self.images): - pred_mask = ( - self.model.predict_on_image(image).detach().squeeze(0).cpu().numpy() - ) - gt_mask = np.array(self.targets[i]) - images.append( - wandb.Image( - image, - masks={ - "predictions": { - "mask_data": pred_mask, - "class_labels": self.class_labels, - }, - "ground_truth": { - "mask_data": gt_mask, - "class_labels": self.class_labels, - }, - }, - ) - ) - - wandb.log({f"{self.caption}": images}, commit=False) - - -class WandbReconstructionLogger(Callback): - """Custom W&B callback for image reconstructions logging.""" - - def __init__( - self, example_indices: Optional[List] = None, num_examples: int = 4, - ) -> None: - """Initializes the WandbImageLogger with the model to train. - - Args: - example_indices (Optional[List]): Indices for validation images. Defaults to None. - num_examples (int): Number of random samples to take if example_indices are not specified. Defaults to 4. - - """ - - super().__init__() - self.caption = None - self.example_indices = example_indices - self.test_sample_indices = None - self.num_examples = num_examples - - def set_model(self, model: Type[Model]) -> None: - """Sets the model and extracts validation images from the dataset.""" - self.model = model - self.caption = "Validation Reconstructions Examples" - if self.example_indices is None: - self.example_indices = np.random.randint( - 0, len(self.model.val_dataset), self.num_examples - ) - self.images = self.model.val_dataset.dataset.data[self.example_indices] - - def on_test_begin(self) -> None: - """Get samples from test dataset.""" - self.caption = "Test Reconstructions Examples" - if self.test_sample_indices is None: - self.test_sample_indices = np.random.randint( - 0, len(self.model.test_dataset), self.num_examples - ) - self.images = self.model.test_dataset.data[self.test_sample_indices] - - def on_test_end(self) -> None: - """Log test images.""" - self.on_epoch_end(0, {}) - - def on_epoch_end(self, epoch: int, logs: Dict) -> None: - """Get network predictions on validation images.""" - images = [] - for image in self.images: - reconstructed_image = ( - self.model.predict_on_image(image).detach().squeeze(0).cpu().numpy() - ) - images.append(image) - images.append(reconstructed_image) - - wandb.log( - {f"{self.caption}": [wandb.Image(image) for image in images]}, commit=False, - ) diff --git a/training/trainer/train.py b/training/trainer/train.py deleted file mode 100644 index b770c94..0000000 --- a/training/trainer/train.py +++ /dev/null @@ -1,325 +0,0 @@ -"""Training script for PyTorch models.""" - -from pathlib import Path -import time -from typing import Dict, List, Optional, Tuple, Type -import warnings - -from einops import rearrange -from loguru import logger -import numpy as np -import torch -from torch import Tensor -from torch.optim.swa_utils import update_bn -from training.trainer.callbacks import Callback, CallbackList, LRScheduler, SWA -from training.trainer.util import log_val_metric -import wandb - -from text_recognizer.models import Model - - -torch.backends.cudnn.benchmark = True -np.random.seed(4711) -torch.manual_seed(4711) -torch.cuda.manual_seed(4711) - - -warnings.filterwarnings("ignore") - - -class Trainer: - """Trainer for training PyTorch models.""" - - def __init__( - self, - max_epochs: int, - callbacks: List[Type[Callback]], - transformer_model: bool = False, - max_norm: float = 0.0, - freeze_backbone: Optional[int] = None, - ) -> None: - """Initialization of the Trainer. - - Args: - max_epochs (int): The maximum number of epochs in the training loop. - callbacks (CallbackList): List of callbacks to be called. - transformer_model (bool): Transformer model flag, modifies the input to the model. Default is False. - max_norm (float): Max norm for gradient cl:ipping. Defaults to 0.0. - freeze_backbone (Optional[int]): How many epochs to freeze the backbone for. Used when training - Transformers. Default is None. - - """ - # Training arguments. - self.start_epoch = 1 - self.max_epochs = max_epochs - self.callbacks = callbacks - self.freeze_backbone = freeze_backbone - - # Flag for setting callbacks. - self.callbacks_configured = False - - self.transformer_model = transformer_model - - self.max_norm = max_norm - - # Model placeholders - self.model = None - - def _configure_callbacks(self) -> None: - """Instantiate the CallbackList.""" - if not self.callbacks_configured: - # If learning rate schedulers are present, they need to be added to the callbacks. - if self.model.swa_scheduler is not None: - self.callbacks.append(SWA()) - elif self.model.lr_scheduler is not None: - self.callbacks.append(LRScheduler()) - - self.callbacks = CallbackList(self.model, self.callbacks) - - def compute_metrics( - self, output: Tensor, targets: Tensor, loss: Tensor, batch_size: int - ) -> Dict: - """Computes metrics for output and target pairs.""" - # Compute metrics. - loss = loss.detach().float().item() - output = output.detach() - targets = targets.detach() - if self.model.metrics is not None: - metrics = {} - for metric in self.model.metrics: - if metric == "cer" or metric == "wer": - metrics[metric] = self.model.metrics[metric]( - output, - targets, - batch_size, - self.model.mapper(self.model.pad_token), - ) - else: - metrics[metric] = self.model.metrics[metric](output, targets) - else: - metrics = {} - metrics["loss"] = loss - - return metrics - - def training_step(self, batch: int, samples: Tuple[Tensor, Tensor],) -> Dict: - """Performs the training step.""" - # Pass the tensor to the device for computation. - data, targets = samples - data, targets = ( - data.to(self.model.device), - targets.to(self.model.device), - ) - - batch_size = data.shape[0] - - # Placeholder for uxiliary loss. - aux_loss = None - - # Forward pass. - # Get the network prediction. - if self.transformer_model: - if self.freeze_backbone is not None and batch < self.freeze_backbone: - with torch.no_grad(): - image_features = self.model.network.extract_image_features(data) - - if isinstance(image_features, Tuple): - image_features, _ = image_features - - output = self.model.network.decode_image_features( - image_features, targets[:, :-1] - ) - else: - output = self.model.network.forward(data, targets[:, :-1]) - if isinstance(output, Tuple): - output, aux_loss = output - output = rearrange(output, "b t v -> (b t) v") - targets = rearrange(targets[:, 1:], "b t -> (b t)").long() - else: - output = self.model.forward(data) - - if isinstance(output, Tuple): - output, aux_loss = output - targets = data - - # Compute the loss. - loss = self.model.criterion(output, targets) - - if aux_loss is not None: - loss += aux_loss - - # Backward pass. - # Clear the previous gradients. - for p in self.model.network.parameters(): - p.grad = None - - # Compute the gradients. - loss.backward() - - if self.max_norm > 0: - torch.nn.utils.clip_grad_norm_( - self.model.network.parameters(), self.max_norm - ) - - # Perform updates using calculated gradients. - self.model.optimizer.step() - - metrics = self.compute_metrics(output, targets, loss, batch_size) - - return metrics - - def train(self) -> None: - """Runs the training loop for one epoch.""" - # Set model to traning mode. - self.model.train() - - for batch, samples in enumerate(self.model.train_dataloader()): - self.callbacks.on_train_batch_begin(batch) - metrics = self.training_step(batch, samples) - self.callbacks.on_train_batch_end(batch, logs=metrics) - - @torch.no_grad() - def validation_step(self, batch: int, samples: Tuple[Tensor, Tensor],) -> Dict: - """Performs the validation step.""" - - # Pass the tensor to the device for computation. - data, targets = samples - data, targets = ( - data.to(self.model.device), - targets.to(self.model.device), - ) - - batch_size = data.shape[0] - - # Placeholder for uxiliary loss. - aux_loss = None - - # Forward pass. - # Get the network prediction. - # Use SWA if available and using test dataset. - if self.transformer_model: - output = self.model.network.forward(data, targets[:, :-1]) - if isinstance(output, Tuple): - output, aux_loss = output - output = rearrange(output, "b t v -> (b t) v") - targets = rearrange(targets[:, 1:], "b t -> (b t)").long() - else: - output = self.model.forward(data) - - if isinstance(output, Tuple): - output, aux_loss = output - targets = data - - # Compute the loss. - loss = self.model.criterion(output, targets) - - if aux_loss is not None: - loss += aux_loss - - # Compute metrics. - metrics = self.compute_metrics(output, targets, loss, batch_size) - - return metrics - - def validate(self) -> Dict: - """Runs the validation loop for one epoch.""" - # Set model to eval mode. - self.model.eval() - - # Summary for the current eval loop. - summary = [] - - for batch, samples in enumerate(self.model.val_dataloader()): - self.callbacks.on_validation_batch_begin(batch) - metrics = self.validation_step(batch, samples) - self.callbacks.on_validation_batch_end(batch, logs=metrics) - summary.append(metrics) - - # Compute mean of all metrics. - metrics_mean = { - "val_" + metric: np.mean([x[metric] for x in summary]) - for metric in summary[0] - } - - return metrics_mean - - def fit(self, model: Type[Model]) -> None: - """Runs the training and evaluation loop.""" - - # Sets model, loads the data, criterion, and optimizers. - self.model = model - self.model.prepare_data() - self.model.configure_model() - - # Configure callbacks. - self._configure_callbacks() - - # Set start time. - t_start = time.time() - - self.callbacks.on_fit_begin() - - # Run the training loop. - for epoch in range(self.start_epoch, self.max_epochs + 1): - self.callbacks.on_epoch_begin(epoch) - - # Perform one training pass over the training set. - self.train() - - # Evaluate the model on the validation set. - val_metrics = self.validate() - log_val_metric(val_metrics, epoch) - - self.callbacks.on_epoch_end(epoch, logs=val_metrics) - - if self.model.stop_training: - break - - # Calculate the total training time. - t_end = time.time() - t_training = t_end - t_start - - self.callbacks.on_fit_end() - - logger.info(f"Training took {t_training:.2f} s.") - - # "Teardown". - self.model = None - - def test(self, model: Type[Model]) -> Dict: - """Run inference on test data.""" - - # Sets model, loads the data, criterion, and optimizers. - self.model = model - self.model.prepare_data() - self.model.configure_model() - - # Configure callbacks. - self._configure_callbacks() - - self.callbacks.on_test_begin() - - self.model.eval() - - # Check if SWA network is available. - self.model.use_swa_model() - - # Summary for the current test loop. - summary = [] - - for batch, samples in enumerate(self.model.test_dataloader()): - metrics = self.validation_step(batch, samples) - summary.append(metrics) - - self.callbacks.on_test_end() - - # Compute mean of all test metrics. - metrics_mean = { - "test_" + metric: np.mean([x[metric] for x in summary]) - for metric in summary[0] - } - - # "Teardown". - self.model = None - - return metrics_mean diff --git a/training/trainer/util.py b/training/trainer/util.py deleted file mode 100644 index 7cf1b45..0000000 --- a/training/trainer/util.py +++ /dev/null @@ -1,28 +0,0 @@ -"""Utility functions for training neural networks.""" -from typing import Dict, Optional - -from loguru import logger - - -def log_val_metric(metrics_mean: Dict, epoch: Optional[int] = None) -> None: - """Logging of val metrics to file/terminal.""" - log_str = "Validation metrics " + (f"at epoch {epoch} - " if epoch else " - ") - logger.debug(log_str + " - ".join(f"{k}: {v:.4f}" for k, v in metrics_mean.items())) - - -class RunningAverage: - """Maintains a running average.""" - - def __init__(self) -> None: - """Initializes the parameters.""" - self.steps = 0 - self.total = 0 - - def update(self, val: float) -> None: - """Updates the parameters.""" - self.total += val - self.steps += 1 - - def __call__(self) -> float: - """Computes the running average.""" - return self.total / float(self.steps) |