summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-04-05 20:47:55 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-04-05 20:47:55 +0200
commit9ae5fa1a88899180f88ddb14d4cef457ceb847e5 (patch)
tree4fe2bcd82553c8062eb0908ae6442c123addf55d
parent9e54591b7e342edc93b0bb04809a0f54045c6a15 (diff)
Add new training loop with PyTorch Lightning, remove stale files
-rw-r--r--README.md25
-rw-r--r--notebooks/00-testing-stuff-out.ipynb46
-rw-r--r--text_recognizer/models/base.py20
-rw-r--r--text_recognizer/networks/loss/__init__.py2
-rw-r--r--text_recognizer/networks/loss/loss.py32
-rw-r--r--text_recognizer/networks/util.py27
-rw-r--r--text_recognizer/networks/vq_transformer.py150
-rw-r--r--training/experiments/default_config_emnist.yml70
-rw-r--r--training/experiments/embedding_experiment.yml64
-rw-r--r--training/experiments/sample_experiment.yml99
-rw-r--r--training/gpu_manager.py62
-rw-r--r--training/prepare_experiments.py34
-rw-r--r--training/run_experiment.py419
-rw-r--r--training/run_sweep.py92
-rw-r--r--training/sweep_emnist.yml26
-rw-r--r--training/sweep_emnist_resnet.yml50
-rw-r--r--training/trainer/__init__.py2
-rw-r--r--training/trainer/callbacks/__init__.py29
-rw-r--r--training/trainer/callbacks/base.py188
-rw-r--r--training/trainer/callbacks/checkpoint.py95
-rw-r--r--training/trainer/callbacks/early_stopping.py108
-rw-r--r--training/trainer/callbacks/lr_schedulers.py77
-rw-r--r--training/trainer/callbacks/progress_bar.py65
-rw-r--r--training/trainer/callbacks/wandb_callbacks.py261
-rw-r--r--training/trainer/train.py325
-rw-r--r--training/trainer/util.py28
26 files changed, 167 insertions, 2229 deletions
diff --git a/README.md b/README.md
index cfe37ff..29a71b1 100644
--- a/README.md
+++ b/README.md
@@ -32,11 +32,11 @@ poetry run build-transitions --tokens iamdb_1kwp_tokens_1000.txt --lexicon iamdb
- [x] transform that encodes iam targets to wordpieces
- [x] transducer loss function
- [ ] Train with word pieces
- - [ ] implement wandb callback for logging
+- [ ] Local attention in first layer of transformer
+- [ ] Halonet encoder
- [ ] Implement CPC
- - [ ] Window images
- - [ ] Train backbone
-- [ ] Bert training, how?
+ - [ ] https://arxiv.org/pdf/1905.09272.pdf
+ - [ ] https://pytorch-lightning-bolts.readthedocs.io/en/latest/self_supervised_models.html?highlight=byol
- [ ] Predictive coding
@@ -60,20 +60,3 @@ wandb agent $SWEEP_ID
```
-## PyTorch Performance Guide
-Tips and tricks from ["PyTorch Performance Tuning Guide - Szymon Migacz, NVIDIA"](https://www.youtube.com/watch?v=9mS1fIYj1So&t=125s):
-
-* Always better to use `num_workers > 0`, allows asynchronous data processing
-* Use `pin_memory=True` to allow data loading and computations to happen on the GPU in parallel.
-* Have to tune `num_workers` to use based on the problem, too many and data loading becomes slower.
-* For CNNs use `torch.backends.cudnn.benchmark=True`, allows cuDNN to select the best algorithm for convolutional computations (autotuner).
-* Increase batch size to max out GPU memory.
-* Use optimizer for large batch training, e.g. LARS, LAMB etc.
-* Set `bias=False` for convolutions directly followed by BatchNorm.
-* Use `for p in model.parameters(): p.grad = None` instead of `model.zero_grad()`.
-* Careful with disable debug APIs in prod (detect_anomaly, profiler, gradcheck).
-* Use `DistributedDataParallel` not `DataParallel`, uses 1 CPU core for each GPU.
-* Important to load balance compute on all GPUs, if variably-sized inputs or GPUs will idle.
-* Use an apex fused optimizer
-* Use checkpointing to recompute memory-intensive compute-efficient ops in backward pass (e.g. activations, upsampling), `torch.utils.checkpoint`.
-* Use `@torch.jit.script`, especially to fuse long sequences of pointwise operations like GELU.
diff --git a/notebooks/00-testing-stuff-out.ipynb b/notebooks/00-testing-stuff-out.ipynb
index 0fa5a1b..81957f7 100644
--- a/notebooks/00-testing-stuff-out.ipynb
+++ b/notebooks/00-testing-stuff-out.ipynb
@@ -25,8 +25,52 @@
},
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class Hej:\n",
+ " a = 2\n",
+ " \n",
+ "class Hejjj:\n",
+ " b = 1"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
"metadata": {},
+ "outputs": [],
+ "source": [
+ "l = [Hej(), Hejjj()]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "<__main__.Hej at 0x7efefc77f370>"
+ ]
+ },
+ "execution_count": 10,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "next(o for o in l if isinstance(o, Hej))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
+ "scrolled": true
+ },
"outputs": [
{
"data": {
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index 46e5136..2d6e435 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -1,5 +1,5 @@
"""Base PyTorch Lightning model."""
-from typing import Any, Dict, Tuple, Type
+from typing import Any, Dict, List, Tuple, Type
import madgrad
import pytorch_lightning as pl
@@ -40,7 +40,7 @@ class LitBaseModel(pl.LightningModule):
args = {} or criterion_args["args"]
return getattr(nn, criterion_args["type"])(**args)
- def configure_optimizer(self) -> Dict[str, Any]:
+ def configure_optimizer(self) -> Tuple[List[type], List[Dict[str, Any]]]:
"""Configures optimizer and lr scheduler."""
args = {} or self.optimizer_args["args"]
if self.optimizer_args["type"] == "MADGRAD":
@@ -48,15 +48,15 @@ class LitBaseModel(pl.LightningModule):
else:
optimizer = getattr(torch.optim, self.optimizer_args["type"])(**args)
+ scheduler = {"monitor": self.monitor}
args = {} or self.lr_scheduler_args["args"]
- scheduler = getattr(torch.optim.lr_scheduler, self.lr_scheduler_args["type"])(
- **args
- )
- return {
- "optimizer": optimizer,
- "lr_scheduler": scheduler,
- "monitor": self.monitor,
- }
+ if "interval" in args:
+ scheduler["interval"] = args.pop("interval")
+
+ scheduler["scheduler"] = getattr(
+ torch.optim.lr_scheduler, self.lr_scheduler_args["type"]
+ )(**args)
+ return [optimizer], [scheduler]
def forward(self, data: Tensor) -> Tensor:
"""Feedforward pass."""
diff --git a/text_recognizer/networks/loss/__init__.py b/text_recognizer/networks/loss/__init__.py
index b489264..cb83608 100644
--- a/text_recognizer/networks/loss/__init__.py
+++ b/text_recognizer/networks/loss/__init__.py
@@ -1,2 +1,2 @@
"""Loss module."""
-from .loss import EmbeddingLoss, LabelSmoothingCrossEntropy
+from .loss import LabelSmoothingCrossEntropy
diff --git a/text_recognizer/networks/loss/loss.py b/text_recognizer/networks/loss/loss.py
index cf9fa0d..d12dc9c 100644
--- a/text_recognizer/networks/loss/loss.py
+++ b/text_recognizer/networks/loss/loss.py
@@ -1,39 +1,9 @@
"""Implementations of custom loss functions."""
-from pytorch_metric_learning import distances, losses, miners, reducers
import torch
from torch import nn
from torch import Tensor
-from torch.autograd import Variable
-import torch.nn.functional as F
-__all__ = ["EmbeddingLoss", "LabelSmoothingCrossEntropy"]
-
-
-class EmbeddingLoss:
- """Metric loss for training encoders to produce information-rich latent embeddings."""
-
- def __init__(self, margin: float = 0.2, type_of_triplets: str = "semihard") -> None:
- self.distance = distances.CosineSimilarity()
- self.reducer = reducers.ThresholdReducer(low=0)
- self.loss_fn = losses.TripletMarginLoss(
- margin=margin, distance=self.distance, reducer=self.reducer
- )
- self.miner = miners.MultiSimilarityMiner(epsilon=margin, distance=self.distance)
-
- def __call__(self, embeddings: Tensor, labels: Tensor) -> Tensor:
- """Computes the metric loss for the embeddings based on their labels.
-
- Args:
- embeddings (Tensor): The laten vectors encoded by the network.
- labels (Tensor): Labels of the embeddings.
-
- Returns:
- Tensor: The metric loss for the embeddings.
-
- """
- hard_pairs = self.miner(embeddings, labels)
- loss = self.loss_fn(embeddings, labels, hard_pairs)
- return loss
+__all__ = ["LabelSmoothingCrossEntropy"]
class LabelSmoothingCrossEntropy(nn.Module):
diff --git a/text_recognizer/networks/util.py b/text_recognizer/networks/util.py
index 131a6b4..d292680 100644
--- a/text_recognizer/networks/util.py
+++ b/text_recognizer/networks/util.py
@@ -1,38 +1,13 @@
"""Miscellaneous neural network functionality."""
import importlib
from pathlib import Path
-from typing import Dict, Tuple, Type
+from typing import Dict, Type
-from einops import rearrange
from loguru import logger
import torch
from torch import nn
-def sliding_window(
- images: torch.Tensor, patch_size: Tuple[int, int], stride: Tuple[int, int]
-) -> torch.Tensor:
- """Creates patches of an image.
-
- Args:
- images (torch.Tensor): A Torch tensor of a 4D image(s), i.e. (batch, channel, height, width).
- patch_size (Tuple[int, int]): The size of the patches to generate, e.g. 28x28 for EMNIST.
- stride (Tuple[int, int]): The stride of the sliding window.
-
- Returns:
- torch.Tensor: A tensor with the shape (batch, patches, height, width).
-
- """
- unfold = nn.Unfold(kernel_size=patch_size, stride=stride)
- # Preform the sliding window, unsqueeze as the channel dimesion is lost.
- c = images.shape[1]
- patches = unfold(images)
- patches = rearrange(
- patches, "b (c h w) t -> b t c h w", c=c, h=patch_size[0], w=patch_size[1],
- )
- return patches
-
-
def activation_function(activation: str) -> Type[nn.Module]:
"""Returns the callable activation function."""
activation_fns = nn.ModuleDict(
diff --git a/text_recognizer/networks/vq_transformer.py b/text_recognizer/networks/vq_transformer.py
deleted file mode 100644
index c673d96..0000000
--- a/text_recognizer/networks/vq_transformer.py
+++ /dev/null
@@ -1,150 +0,0 @@
-"""A VQ-Transformer for image to text recognition."""
-from typing import Dict, Optional, Tuple
-
-from einops import rearrange, repeat
-import torch
-from torch import nn
-from torch import Tensor
-
-from text_recognizer.networks.transformer import PositionalEncoding, Transformer
-from text_recognizer.networks.util import activation_function
-from text_recognizer.networks.util import configure_backbone
-from text_recognizer.networks.vqvae.encoder import _ResidualBlock
-
-
-class VQTransformer(nn.Module):
- """VQ+Transfomer for image to character sequence prediction."""
-
- def __init__(
- self,
- num_encoder_layers: int,
- num_decoder_layers: int,
- hidden_dim: int,
- vocab_size: int,
- num_heads: int,
- adaptive_pool_dim: Tuple,
- expansion_dim: int,
- dropout_rate: float,
- trg_pad_index: int,
- max_len: int,
- backbone: str,
- backbone_args: Optional[Dict] = None,
- activation: str = "gelu",
- ) -> None:
- super().__init__()
-
- # Configure vector quantized backbone.
- self.backbone = configure_backbone(backbone, backbone_args)
- self.conv = nn.Sequential(
- nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=2),
- nn.ReLU(inplace=True),
- )
-
- # Configure embeddings for Transformer network.
- self.trg_pad_index = trg_pad_index
- self.vocab_size = vocab_size
- self.character_embedding = nn.Embedding(self.vocab_size, hidden_dim)
- self.src_position_embedding = nn.Parameter(torch.randn(1, max_len, hidden_dim))
- self.trg_position_encoding = PositionalEncoding(hidden_dim, dropout_rate)
- nn.init.normal_(self.character_embedding.weight, std=0.02)
-
- self.adaptive_pool = (
- nn.AdaptiveAvgPool2d((adaptive_pool_dim)) if adaptive_pool_dim else None
- )
-
- self.transformer = Transformer(
- num_encoder_layers,
- num_decoder_layers,
- hidden_dim,
- num_heads,
- expansion_dim,
- dropout_rate,
- activation,
- )
-
- self.head = nn.Sequential(nn.Linear(hidden_dim, vocab_size),)
-
- def _create_trg_mask(self, trg: Tensor) -> Tensor:
- # Move this outside the transformer.
- trg_pad_mask = (trg != self.trg_pad_index)[:, None, None]
- trg_len = trg.shape[1]
- trg_sub_mask = torch.tril(
- torch.ones((trg_len, trg_len), device=trg.device)
- ).bool()
- trg_mask = trg_pad_mask & trg_sub_mask
- return trg_mask
-
- def encoder(self, src: Tensor) -> Tensor:
- """Forward pass with the encoder of the transformer."""
- return self.transformer.encoder(src)
-
- def decoder(self, trg: Tensor, memory: Tensor, trg_mask: Tensor) -> Tensor:
- """Forward pass with the decoder of the transformer + classification head."""
- return self.head(
- self.transformer.decoder(trg=trg, memory=memory, trg_mask=trg_mask)
- )
-
- def extract_image_features(self, src: Tensor) -> Tuple[Tensor, Tensor]:
- """Extracts image features with a backbone neural network.
-
- It seem like the winning idea was to swap channels and width dimension and collapse
- the height dimension. The transformer is learning like a baby with this implementation!!! :D
- Ohhhh, the joy I am experiencing right now!! Bring in the beers! :D :D :D
-
- Args:
- src (Tensor): Input tensor.
-
- Returns:
- Tensor: The input src to the transformer and the vq loss.
-
- """
- # If batch dimension is missing, it needs to be added.
- if len(src.shape) < 4:
- src = src[(None,) * (4 - len(src.shape))]
- src, vq_loss = self.backbone.encode(src)
- # src = self.backbone.decoder.res_block(src)
- src = self.conv(src)
-
- if self.adaptive_pool is not None:
- src = rearrange(src, "b c h w -> b w c h")
- src = self.adaptive_pool(src)
- src = src.squeeze(3)
- else:
- src = rearrange(src, "b c h w -> b (w h) c")
-
- b, t, _ = src.shape
-
- src += self.src_position_embedding[:, :t]
-
- return src, vq_loss
-
- def target_embedding(self, trg: Tensor) -> Tensor:
- """Encodes target tensor with embedding and postion.
-
- Args:
- trg (Tensor): Target tensor.
-
- Returns:
- Tensor: Encoded target tensor.
-
- """
- trg = self.character_embedding(trg.long())
- trg = self.trg_position_encoding(trg)
- return trg
-
- def decode_image_features(
- self, image_features: Tensor, trg: Optional[Tensor] = None
- ) -> Tensor:
- """Takes images features from the backbone and decodes them with the transformer."""
- trg_mask = self._create_trg_mask(trg)
- trg = self.target_embedding(trg)
- out = self.transformer(image_features, trg, trg_mask=trg_mask)
-
- logits = self.head(out)
- return logits
-
- def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor:
- """Forward pass with CNN transfomer."""
- image_features, vq_loss = self.extract_image_features(x)
- logits = self.decode_image_features(image_features, trg)
- return logits, vq_loss
diff --git a/training/experiments/default_config_emnist.yml b/training/experiments/default_config_emnist.yml
deleted file mode 100644
index bf2ed0a..0000000
--- a/training/experiments/default_config_emnist.yml
+++ /dev/null
@@ -1,70 +0,0 @@
-dataset: EmnistDataset
-dataset_args:
- sample_to_balance: true
- subsample_fraction: 0.33
- transform: null
- target_transform: null
- seed: 4711
-
-data_loader_args:
- splits: [train, val]
- shuffle: true
- num_workers: 8
- cuda: true
-
-model: CharacterModel
-metrics: [accuracy]
-
-network_args:
- in_channels: 1
- num_classes: 80
- depths: [2]
- block_sizes: [256]
-
-train_args:
- batch_size: 256
- epochs: 5
-
-criterion: CrossEntropyLoss
-criterion_args:
- weight: null
- ignore_index: -100
- reduction: mean
-
-optimizer: AdamW
-optimizer_args:
- lr: 1.e-03
- betas: [0.9, 0.999]
- eps: 1.e-08
- # weight_decay: 5.e-4
- amsgrad: false
-
-lr_scheduler: OneCycleLR
-lr_scheduler_args:
- max_lr: 1.e-03
- epochs: 5
- anneal_strategy: linear
-
-
-callbacks: [Checkpoint, ProgressBar, EarlyStopping, WandbCallback, WandbImageLogger, OneCycleLR]
-callback_args:
- Checkpoint:
- monitor: val_accuracy
- ProgressBar:
- epochs: 5
- log_batch_frequency: 100
- EarlyStopping:
- monitor: val_loss
- min_delta: 0.0
- patience: 3
- mode: min
- WandbCallback:
- log_batch_frequency: 10
- WandbImageLogger:
- num_examples: 4
- OneCycleLR:
- null
-verbosity: 1 # 0, 1, 2
-resume_experiment: null
-train: true
-validation_metric: val_accuracy
diff --git a/training/experiments/embedding_experiment.yml b/training/experiments/embedding_experiment.yml
deleted file mode 100644
index 1e5f941..0000000
--- a/training/experiments/embedding_experiment.yml
+++ /dev/null
@@ -1,64 +0,0 @@
-experiment_group: Embedding Experiments
-experiments:
- - train_args:
- transformer_model: false
- batch_size: &batch_size 256
- max_epochs: &max_epochs 32
- input_shape: [[1, 28, 28]]
- dataset:
- type: EmnistDataset
- args:
- sample_to_balance: true
- subsample_fraction: null
- transform: null
- target_transform: null
- seed: 4711
- train_args:
- num_workers: 8
- train_fraction: 0.85
- batch_size: *batch_size
- model: CharacterModel
- metrics: []
- network:
- type: DenseNet
- args:
- growth_rate: 4
- block_config: [4, 4]
- in_channels: 1
- base_channels: 24
- num_classes: 128
- bn_size: 4
- dropout_rate: 0.1
- classifier: true
- activation: elu
- criterion:
- type: EmbeddingLoss
- args:
- margin: 0.2
- type_of_triplets: semihard
- optimizer:
- type: AdamW
- args:
- lr: 1.e-02
- betas: [0.9, 0.999]
- eps: 1.e-08
- weight_decay: 5.e-4
- amsgrad: false
- lr_scheduler:
- type: CosineAnnealingLR
- args:
- T_max: *max_epochs
- callbacks: [Checkpoint, ProgressBar, WandbCallback]
- callback_args:
- Checkpoint:
- monitor: val_loss
- mode: min
- ProgressBar:
- epochs: *max_epochs
- WandbCallback:
- log_batch_frequency: 10
- verbosity: 1 # 0, 1, 2
- resume_experiment: null
- train: true
- test: true
- test_metric: mean_average_precision_at_r
diff --git a/training/experiments/sample_experiment.yml b/training/experiments/sample_experiment.yml
deleted file mode 100644
index 8f94475..0000000
--- a/training/experiments/sample_experiment.yml
+++ /dev/null
@@ -1,99 +0,0 @@
-experiment_group: Sample Experiments
-experiments:
- - train_args:
- batch_size: 256
- max_epochs: &max_epochs 32
- dataset:
- type: EmnistDataset
- args:
- sample_to_balance: true
- subsample_fraction: null
- transform: null
- target_transform: null
- seed: 4711
- train_args:
- num_workers: 6
- train_fraction: 0.8
-
- model: CharacterModel
- metrics: [accuracy]
- # network: MLP
- # network_args:
- # input_size: 784
- # hidden_size: 512
- # output_size: 80
- # num_layers: 5
- # dropout_rate: 0.2
- # activation_fn: SELU
- network:
- type: ResidualNetwork
- args:
- in_channels: 1
- num_classes: 80
- depths: [2, 2]
- block_sizes: [64, 64]
- activation: leaky_relu
- # network:
- # type: WideResidualNetwork
- # args:
- # in_channels: 1
- # num_classes: 80
- # depth: 10
- # num_layers: 3
- # width_factor: 4
- # dropout_rate: 0.2
- # activation: SELU
- # network: LeNet
- # network_args:
- # output_size: 62
- # activation_fn: GELU
- criterion:
- type: CrossEntropyLoss
- args:
- weight: null
- ignore_index: -100
- reduction: mean
- optimizer:
- type: AdamW
- args:
- lr: 1.e-02
- betas: [0.9, 0.999]
- eps: 1.e-08
- # weight_decay: 5.e-4
- amsgrad: false
- # lr_scheduler:
- # type: OneCycleLR
- # args:
- # max_lr: 1.e-03
- # epochs: *max_epochs
- # anneal_strategy: linear
- lr_scheduler:
- type: CosineAnnealingLR
- args:
- T_max: *max_epochs
- interval: epoch
- swa_args:
- start: 2
- lr: 5.e-2
- callbacks: [Checkpoint, ProgressBar, WandbCallback, WandbImageLogger, EarlyStopping]
- callback_args:
- Checkpoint:
- monitor: val_accuracy
- ProgressBar:
- epochs: null
- log_batch_frequency: 100
- EarlyStopping:
- monitor: val_loss
- min_delta: 0.0
- patience: 5
- mode: min
- WandbCallback:
- log_batch_frequency: 10
- WandbImageLogger:
- num_examples: 4
- use_transpose: true
- verbosity: 0 # 0, 1, 2
- resume_experiment: null
- train: true
- test: true
- test_metric: test_accuracy
diff --git a/training/gpu_manager.py b/training/gpu_manager.py
deleted file mode 100644
index ce1b3dd..0000000
--- a/training/gpu_manager.py
+++ /dev/null
@@ -1,62 +0,0 @@
-"""GPUManager class."""
-import os
-import time
-from typing import Optional
-
-import gpustat
-from loguru import logger
-import numpy as np
-from redlock import Redlock
-
-
-GPU_LOCK_TIMEOUT = 5000 # ms
-
-
-class GPUManager:
- """Class for allocating GPUs."""
-
- def __init__(self, verbose: bool = False) -> None:
- """Initializes Redlock manager."""
- self.lock_manager = Redlock([{"host": "localhost", "port": 6379, "db": 0}])
- self.verbose = verbose
-
- def get_free_gpu(self) -> int:
- """Gets a free GPU.
-
- If some GPUs are available, try reserving one by checking out an exclusive redis lock.
- If none available or can not get lock, sleep and check again.
-
- Returns:
- int: The gpu index.
-
- """
- while True:
- gpu_index = self._get_free_gpu()
- if gpu_index is not None:
- return gpu_index
-
- if self.verbose:
- logger.debug(f"pid {os.getpid()} sleeping")
- time.sleep(GPU_LOCK_TIMEOUT / 1000)
-
- def _get_free_gpu(self) -> Optional[int]:
- """Fetches an available GPU index."""
- try:
- available_gpu_indices = [
- gpu.index
- for gpu in gpustat.GPUStatCollection.new_query()
- if gpu.memory_used < 0.5 * gpu.memory_total
- ]
- except Exception as e:
- logger.debug(f"Got the following exception: {e}")
- return None
-
- if available_gpu_indices:
- gpu_index = np.random.choice(available_gpu_indices)
- if self.verbose:
- logger.debug(f"pid {os.getpid()} picking gpu {gpu_index}")
- if self.lock_manager.lock(f"gpu_{gpu_index}", GPU_LOCK_TIMEOUT):
- return int(gpu_index)
- if self.verbose:
- logger.debug(f"pid {os.getpid()} could not get lock.")
- return None
diff --git a/training/prepare_experiments.py b/training/prepare_experiments.py
deleted file mode 100644
index 21997af..0000000
--- a/training/prepare_experiments.py
+++ /dev/null
@@ -1,34 +0,0 @@
-"""Run a experiment from a config file."""
-import json
-
-import click
-import yaml
-
-
-def run_experiments(experiments_filename: str) -> None:
- """Run experiment from file."""
- with open(experiments_filename, "r") as f:
- experiments_config = yaml.safe_load(f)
-
- num_experiments = len(experiments_config["experiments"])
- for index in range(num_experiments):
- experiment_config = experiments_config["experiments"][index]
- experiment_config["experiment_group"] = experiments_config["experiment_group"]
- cmd = f"poetry run run-experiment --gpu=-1 --save '{json.dumps(experiment_config)}'"
- print(cmd)
-
-
-@click.command()
-@click.option(
- "--experiments_filename",
- required=True,
- type=str,
- help="Filename of Yaml file of experiments to run.",
-)
-def run_cli(experiments_filename: str) -> None:
- """Parse command-line arguments and run experiments from provided file."""
- run_experiments(experiments_filename)
-
-
-if __name__ == "__main__":
- run_cli()
diff --git a/training/run_experiment.py b/training/run_experiment.py
index faafea6..ff8b886 100644
--- a/training/run_experiment.py
+++ b/training/run_experiment.py
@@ -1,162 +1,34 @@
"""Script to run experiments."""
from datetime import datetime
-from glob import glob
import importlib
-import json
-import os
from pathlib import Path
-import re
-from typing import Callable, Dict, List, Optional, Tuple, Type
-import warnings
+from typing import Dict, List, Optional, Type
import click
from loguru import logger
import numpy as np
+from omegaconf import OmegaConf
+import pytorch_lightning as pl
import torch
+from torch import nn
from torchsummary import summary
from tqdm import tqdm
-from training.gpu_manager import GPUManager
-from training.trainer.callbacks import CallbackList
-from training.trainer.train import Trainer
import wandb
-import yaml
-import text_recognizer.models
-from text_recognizer.models import Model
-import text_recognizer.networks
-from text_recognizer.networks.loss import loss as custom_loss_module
+SEED = 4711
EXPERIMENTS_DIRNAME = Path(__file__).parents[0].resolve() / "experiments"
-def _get_level(verbose: int) -> int:
- """Sets the logger level."""
- levels = {0: 40, 1: 20, 2: 10}
- verbose = verbose if verbose <= 2 else 2
- return levels[verbose]
-
-
-def _create_experiment_dir(
- experiment_config: Dict, checkpoint: Optional[str] = None
-) -> Path:
- """Create new experiment."""
- EXPERIMENTS_DIRNAME.mkdir(parents=True, exist_ok=True)
- experiment_dir = EXPERIMENTS_DIRNAME / (
- f"{experiment_config['model']}_"
- + f"{experiment_config['dataset']['type']}_"
- + f"{experiment_config['network']['type']}"
- )
-
- if checkpoint is None:
- experiment = datetime.now().strftime("%m%d_%H%M%S")
- logger.debug(f"Creating a new experiment called {experiment}")
- else:
- available_experiments = glob(str(experiment_dir) + "/*")
- available_experiments.sort()
- if checkpoint == "last":
- experiment = available_experiments[-1]
- logger.debug(f"Resuming the latest experiment {experiment}")
- else:
- experiment = checkpoint
- if not str(experiment_dir / experiment) in available_experiments:
- raise FileNotFoundError("Experiment does not exist.")
- logger.debug(f"Resuming the from experiment {checkpoint}")
-
- experiment_dir = experiment_dir / experiment
-
- # Create log and model directories.
- log_dir = experiment_dir / "log"
- model_dir = experiment_dir / "model"
-
- return experiment_dir, log_dir, model_dir
-
-
-def _load_modules_and_arguments(experiment_config: Dict,) -> Tuple[Callable, Dict]:
- """Loads all modules and arguments."""
- # Load the dataset module.
- dataset_args = experiment_config.get("dataset", {})
- dataset_ = dataset_args["type"]
-
- # Import the model module and model arguments.
- model_class_ = getattr(text_recognizer.models, experiment_config["model"])
-
- # Import metrics.
- metric_fns_ = (
- {
- metric: getattr(text_recognizer.networks, metric)
- for metric in experiment_config["metrics"]
- }
- if experiment_config["metrics"] is not None
- else None
- )
-
- # Import network module and arguments.
- network_fn_ = experiment_config["network"]["type"]
- network_args = experiment_config["network"].get("args", {})
-
- # Criterion
- if experiment_config["criterion"]["type"] in custom_loss_module.__all__:
- criterion_ = getattr(custom_loss_module, experiment_config["criterion"]["type"])
- else:
- criterion_ = getattr(torch.nn, experiment_config["criterion"]["type"])
- criterion_args = experiment_config["criterion"].get("args", {}) or {}
-
- # Optimizers
- optimizer_ = getattr(torch.optim, experiment_config["optimizer"]["type"])
- optimizer_args = experiment_config["optimizer"].get("args", {})
-
- # Learning rate scheduler
- lr_scheduler_ = None
- lr_scheduler_args = None
- if "lr_scheduler" in experiment_config:
- lr_scheduler_ = getattr(
- torch.optim.lr_scheduler, experiment_config["lr_scheduler"]["type"]
- )
- lr_scheduler_args = experiment_config["lr_scheduler"].get("args", {}) or {}
-
- # SWA scheduler.
- if "swa_args" in experiment_config:
- swa_args = experiment_config.get("swa_args", {}) or {}
- else:
- swa_args = None
-
- model_args = {
- "dataset": dataset_,
- "dataset_args": dataset_args,
- "metrics": metric_fns_,
- "network_fn": network_fn_,
- "network_args": network_args,
- "criterion": criterion_,
- "criterion_args": criterion_args,
- "optimizer": optimizer_,
- "optimizer_args": optimizer_args,
- "lr_scheduler": lr_scheduler_,
- "lr_scheduler_args": lr_scheduler_args,
- "swa_args": swa_args,
- }
-
- return model_class_, model_args
-
-
-def _configure_callbacks(experiment_config: Dict, model_dir: Path) -> CallbackList:
- """Configure a callback list for trainer."""
- if "Checkpoint" in experiment_config["callback_args"]:
- experiment_config["callback_args"]["Checkpoint"]["checkpoint_path"] = str(
- model_dir
- )
-
- # Initializes callbacks.
- callback_modules = importlib.import_module("training.trainer.callbacks")
- callbacks = []
- for callback in experiment_config["callbacks"]:
- args = experiment_config["callback_args"][callback] or {}
- callbacks.append(getattr(callback_modules, callback)(**args))
-
- return callbacks
+def _configure_logging(log_dir: Optional[Path], verbose: int = 0) -> None:
+ """Configure the loguru logger for output to terminal and disk."""
+ def _get_level(verbose: int) -> int:
+ """Sets the logger level."""
+ levels = {0: 40, 1: 20, 2: 10}
+ verbose = verbose if verbose <= 2 else 2
+ return levels[verbose]
-def _configure_logger(log_dir: Path, verbose: int = 0) -> None:
- """Configure the loguru logger for output to terminal and disk."""
# Have to remove default logger to get tqdm to work properly.
logger.remove()
@@ -164,219 +36,138 @@ def _configure_logger(log_dir: Path, verbose: int = 0) -> None:
level = _get_level(verbose)
logger.add(lambda msg: tqdm.write(msg, end=""), colorize=True, level=level)
- logger.add(
- str(log_dir / "train.log"),
- format="{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}",
- )
-
-
-def _save_config(experiment_dir: Path, experiment_config: Dict) -> None:
- """Copy config to experiment directory."""
- config_path = experiment_dir / "config.yml"
- with open(str(config_path), "w") as f:
- yaml.dump(experiment_config, f)
-
-
-def _load_from_checkpoint(
- model: Type[Model], model_dir: Path, pretrained_weights: str = None,
-) -> None:
- """If checkpoint exists, load model weights and optimizers from checkpoint."""
- # Get checkpoint path.
- if pretrained_weights is not None:
- logger.info(f"Loading weights from {pretrained_weights}.")
- checkpoint_path = (
- EXPERIMENTS_DIRNAME / Path(pretrained_weights) / "model" / "best.pt"
+ if log_dir is not None:
+ logger.add(
+ str(log_dir / "train.log"),
+ format="{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}",
)
- else:
- logger.info(f"Loading weights from {model_dir}.")
- checkpoint_path = model_dir / "last.pt"
- if checkpoint_path.exists():
- logger.info("Loading and resuming training from checkpoint.")
- model.load_from_checkpoint(checkpoint_path)
-def evaluate_embedding(model: Type[Model]) -> Dict:
- """Evaluates the embedding space."""
- from pytorch_metric_learning import testers
- from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator
+def _import_class(module_and_class_name: str) -> type:
+ """Import class from module."""
+ module_name, class_name = module_and_class_name.rsplit(".", 1)
+ module = importlib.import_module(module_name)
+ return getattr(module, class_name)
- accuracy_calculator = AccuracyCalculator(
- include=("mean_average_precision_at_r",), k=10
- )
- def get_all_embeddings(model: Type[Model]) -> Tuple:
- tester = testers.BaseTester()
- return tester.get_all_embeddings(model.test_dataset, model.network)
+def _configure_pl_callbacks(args: List[Dict]) -> List[Type[pl.callbacks.Callback]]:
+ """Configures PyTorch Lightning callbacks."""
+ pl_callbacks = [
+ getattr(pl.callbacks, callback["type"])(**callback["args"]) for callback in args
+ ]
+ return pl_callbacks
- embeddings, labels = get_all_embeddings(model)
- logger.info("Computing embedding accuracy")
- accuracies = accuracy_calculator.get_accuracy(
- embeddings, embeddings, np.squeeze(labels), np.squeeze(labels), True
- )
- logger.info(
- f"Test set accuracy (MAP@10) = {accuracies['mean_average_precision_at_r']}"
- )
- return accuracies
+def _configure_wandb_callback(
+ network: Type[nn.Module], args: Dict
+) -> pl.loggers.WandbLogger:
+ """Configures wandb logger."""
+ pl_logger = pl.loggers.WandbLogger()
+ pl_logger.watch(network)
+ pl_logger.log_hyperparams(vars(args))
+ return pl_logger
-def run_experiment(
- experiment_config: Dict,
- save_weights: bool,
- device: str,
- use_wandb: bool,
- train: bool,
- test: bool,
- verbose: int = 0,
- checkpoint: Optional[str] = None,
- pretrained_weights: Optional[str] = None,
-) -> None:
- """Runs an experiment."""
- logger.info(f"Experiment config: {json.dumps(experiment_config)}")
- # Create new experiment.
- experiment_dir, log_dir, model_dir = _create_experiment_dir(
- experiment_config, checkpoint
+def _save_best_weights(
+ callbacks: List[Type[pl.callbacks.Callback]], use_wandb: bool
+) -> None:
+ """Saves the best model."""
+ model_checkpoint_callback = next(
+ callback
+ for callback in callbacks
+ if isinstance(callback, pl.callbacks.ModelCheckpoint)
)
+ best_model_path = model_checkpoint_callback.best_model_path
+ if best_model_path:
+ logger.info(f"Best model saved at: {best_model_path}")
+ if use_wandb:
+ logger.info("Uploading model to W&B...")
+ wandb.save(best_model_path)
- # Make sure the log/model directory exists.
- log_dir.mkdir(parents=True, exist_ok=True)
- model_dir.mkdir(parents=True, exist_ok=True)
-
- # Load the modules and model arguments.
- model_class_, model_args = _load_modules_and_arguments(experiment_config)
-
- # Initializes the model with experiment config.
- model = model_class_(**model_args, device=device)
-
- callbacks = _configure_callbacks(experiment_config, model_dir)
- # Setup logger.
- _configure_logger(log_dir, verbose)
+def run(path: str, train: bool, test: bool, tune: bool, use_wandb: bool) -> None:
+ """Runs experiment."""
+ logger.info("Starting experiment...")
- # Load from checkpoint if resuming an experiment.
- resume = False
- if checkpoint is not None or pretrained_weights is not None:
- # resume = True
- _load_from_checkpoint(model, model_dir, pretrained_weights)
+ # Seed everything in the experiment
+ logger.info(f"Seeding everthing with seed={SEED}")
+ pl.utilities.seed.seed_everything(SEED)
- logger.info(f"The class mapping is {model.mapping}")
+ # Load config.
+ logger.info(f"Loading config from: {path}")
+ config = OmegaConf.load(path)
- # Initializes Weights & Biases
- if use_wandb:
- wandb.init(project="text-recognizer", config=experiment_config, resume=resume)
+ # Load classes
+ data_module_class = _import_class(f"text_recognizer.data.{config.data.type}")
+ network_class = _import_class(f"text_recognizer.networks.{config.network.type}")
+ lit_model_class = _import_class(f"text_recognizer.models.{config.model.type}")
- # Lets W&B save the model and track the gradients and optional parameters.
- wandb.watch(model.network)
+ # Initialize data object and network.
+ data_module = data_module_class(**config.data.args)
+ network = network_class(**config.network.args)
- experiment_config["experiment_group"] = experiment_config.get(
- "experiment_group", None
+ # Load callback and logger
+ callbacks = _configure_pl_callbacks(config.callbacks)
+ pl_logger = (
+ _configure_wandb_callback(network, config.network.args)
+ if use_wandb
+ else pl.logger.TensorBoardLogger("training/logs")
)
- experiment_config["device"] = device
-
- # Save the config used in the experiment folder.
- _save_config(experiment_dir, experiment_config)
-
- # Prints a summary of the network in terminal.
- model.summary(experiment_config["train_args"]["input_shape"])
+ # Checkpoint
+ if config.load_checkpoint is not None:
+ logger.info(
+ f"Loading network weights from checkpoint: {config.load_checkpoint}"
+ )
+ lit_model = lit_model_class.load_from_checkpoint(
+ config.load_checkpoint, network=network, **config.model.args
+ )
+ else:
+ lit_model = lit_model_class(**config.model.args)
- # Load trainer.
- trainer = Trainer(
- max_epochs=experiment_config["train_args"]["max_epochs"],
+ trainer = pl.Trainer(
+ **config.trainer,
callbacks=callbacks,
- transformer_model=experiment_config["train_args"]["transformer_model"],
- max_norm=experiment_config["train_args"]["max_norm"],
- freeze_backbone=experiment_config["train_args"]["freeze_backbone"],
+ logger=pl_logger,
+ weigths_save_path="training/logs",
)
- # Train the model.
+ if tune:
+ logger.info(f"Tuning learning rate and batch size...")
+ trainer.tune(lit_model, datamodule=data_module)
+
if train:
- trainer.fit(model)
+ logger.info(f"Training network...")
+ trainer.fit(lit_model, datamodule=data_module)
- # Run inference over test set.
if test:
- logger.info("Loading checkpoint with the best weights.")
- if "checkpoint" in experiment_config["train_args"]:
- model.load_from_checkpoint(
- model_dir / experiment_config["train_args"]["checkpoint"]
- )
- else:
- model.load_from_checkpoint(model_dir / "best.pt")
-
- logger.info("Running inference on test set.")
- if experiment_config["criterion"]["type"] == "EmbeddingLoss":
- logger.info("Evaluating embedding.")
- score = evaluate_embedding(model)
- else:
- score = trainer.test(model)
-
- logger.info(f"Test set evaluation: {score}")
-
- if use_wandb:
- wandb.log(
- {
- experiment_config["test_metric"]: score[
- experiment_config["test_metric"]
- ]
- }
- )
+ logger.info(f"Testing network...")
+ trainer.test(lit_model, datamodule=data_module)
- if save_weights:
- model.save_weights(model_dir)
+ _save_best_weights(callbacks, use_wandb)
@click.command()
-@click.argument("experiment_config",)
-@click.option("--gpu", type=int, default=0, help="Provide the index of the GPU to use.")
+@click.option("-f", "--experiment_config", type=str, help="Path to experiment config.")
+@click.option("--use_wandb", is_flag=True, help="If true, do use wandb for logging.")
@click.option(
- "--save",
- is_flag=True,
- help="If set, the final weights will be saved to a canonical, version-controlled location.",
-)
-@click.option(
- "--nowandb", is_flag=False, help="If true, do not use wandb for this run."
+ "--tune", is_flag=True, help="If true, tune hyperparameters for training."
)
+@click.option("--train", is_flag=True, help="If true, train the model.")
@click.option("--test", is_flag=True, help="If true, test the model.")
@click.option("-v", "--verbose", count=True)
-@click.option("--checkpoint", type=str, help="Path to the experiment.")
-@click.option(
- "--pretrained_weights", type=str, help="Path to pretrained model weights."
-)
-@click.option(
- "--notrain", is_flag=False, help="Do not train the model.",
-)
-def run_cli(
+def cli(
experiment_config: str,
- gpu: int,
- save: bool,
- nowandb: bool,
- notrain: bool,
+ use_wandb: bool,
+ tune: bool,
+ train: bool,
test: bool,
verbose: int,
- checkpoint: Optional[str] = None,
- pretrained_weights: Optional[str] = None,
) -> None:
"""Run experiment."""
- if gpu < 0:
- gpu_manager = GPUManager(True)
- gpu = gpu_manager.get_free_gpu()
- device = "cuda:" + str(gpu)
-
- experiment_config = json.loads(experiment_config)
- os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu}"
-
- run_experiment(
- experiment_config,
- save,
- device,
- use_wandb=not nowandb,
- train=not notrain,
- test=test,
- verbose=verbose,
- checkpoint=checkpoint,
- pretrained_weights=pretrained_weights,
- )
+ _configure_logging(None, verbose=verbose)
+ run(path=experiment_config, train=train, test=test, tune=tune, use_wandb=use_wandb)
if __name__ == "__main__":
- run_cli()
+ cli()
diff --git a/training/run_sweep.py b/training/run_sweep.py
deleted file mode 100644
index a578592..0000000
--- a/training/run_sweep.py
+++ /dev/null
@@ -1,92 +0,0 @@
-"""W&B Sweep Functionality."""
-from ast import literal_eval
-import json
-import os
-from pathlib import Path
-import signal
-import subprocess # nosec
-import sys
-from typing import Dict, List, Tuple
-
-import click
-import yaml
-
-EXPERIMENTS_DIRNAME = Path(__file__).parents[0].resolve() / "experiments"
-
-
-def load_config() -> Dict:
- """Load base hyperparameter config."""
- with open(str(EXPERIMENTS_DIRNAME / "default_config_emnist.yml"), "r") as f:
- default_config = yaml.safe_load(f)
- return default_config
-
-
-def args_to_json(
- default_config: dict, preserve_args: tuple = ("gpu", "save")
-) -> Tuple[dict, list]:
- """Convert command line arguments to nested config values.
-
- i.e. run_sweep.py --dataset_args.foo=1.7
- {
- "dataset_args": {
- "foo": 1.7
- }
- }
-
- Args:
- default_config (dict): The base config used for every experiment.
- preserve_args (tuple): Arguments preserved for all runs. Defaults to ("gpu", "save").
-
- Returns:
- Tuple[dict, list]: Tuple of config dictionary and list of arguments.
-
- """
-
- args = []
- config = default_config.copy()
- key, val = None, None
- for arg in sys.argv[1:]:
- if "=" in arg:
- key, val = arg.split("=")
- elif key:
- val = arg
- else:
- key = arg
- if key and val:
- parsed_key = key.lstrip("-").split(".")
- if parsed_key[0] in preserve_args:
- args.append("--{}={}".format(parsed_key[0], val))
- else:
- nested = config
- for level in parsed_key[:-1]:
- nested[level] = config.get(level, {})
- nested = nested[level]
- try:
- # Convert numerics to floats / ints
- val = literal_eval(val)
- except ValueError:
- pass
- nested[parsed_key[-1]] = val
- key, val = None, None
- return config, args
-
-
-def main() -> None:
- """Runs a W&B sweep."""
- default_config = load_config()
- config, args = args_to_json(default_config)
- env = {
- k: v for k, v in os.environ.items() if k not in ("WANDB_PROGRAM", "WANDB_ARGS")
- }
- # pylint: disable=subprocess-popen-preexec-fn
- run = subprocess.Popen(
- ["python", "training/run_experiment.py", *args, json.dumps(config)],
- env=env,
- preexec_fn=os.setsid,
- ) # nosec
- signal.signal(signal.SIGTERM, lambda *args: run.terminate())
- run.wait()
-
-
-if __name__ == "__main__":
- main()
diff --git a/training/sweep_emnist.yml b/training/sweep_emnist.yml
deleted file mode 100644
index 48d7261..0000000
--- a/training/sweep_emnist.yml
+++ /dev/null
@@ -1,26 +0,0 @@
-program: training/run_sweep.py
-method: bayes
-metric:
- name: val_loss
- goal: minimize
-parameters:
- dataset:
- value: EmnistDataset
- model:
- value: CharacterModel
- network:
- value: MLP
- network_args.hidden_size:
- values: [128, 256]
- network_args.dropout_rate:
- values: [0.2, 0.4]
- network_args.num_layers:
- values: [3, 6]
- optimizer_args.lr:
- values: [1.e-1, 1.e-5]
- lr_scheduler_args.max_lr:
- values: [1.0e-1, 1.0e-5]
- train_args.batch_size:
- values: [64, 128]
- train_args.epochs:
- value: 5
diff --git a/training/sweep_emnist_resnet.yml b/training/sweep_emnist_resnet.yml
deleted file mode 100644
index 19a3040..0000000
--- a/training/sweep_emnist_resnet.yml
+++ /dev/null
@@ -1,50 +0,0 @@
-program: training/run_sweep.py
-method: bayes
-metric:
- name: val_accuracy
- goal: maximize
-parameters:
- dataset:
- value: EmnistDataset
- model:
- value: CharacterModel
- network:
- value: ResidualNetwork
- network_args.block_sizes:
- distribution: q_uniform
- min: 16
- max: 256
- q: 8
- network_args.depths:
- distribution: int_uniform
- min: 1
- max: 3
- network_args.levels:
- distribution: int_uniform
- min: 1
- max: 2
- network_args.activation:
- distribution: categorical
- values:
- - gelu
- - leaky_relu
- - relu
- - selu
- optimizer_args.lr:
- distribution: uniform
- min: 1.e-5
- max: 1.e-1
- lr_scheduler_args.max_lr:
- distribution: uniform
- min: 1.e-5
- max: 1.e-1
- train_args.batch_size:
- distribution: q_uniform
- min: 32
- max: 256
- q: 8
- train_args.epochs:
- value: 5
-early_terminate:
- type: hyperband
- min_iter: 2
diff --git a/training/trainer/__init__.py b/training/trainer/__init__.py
deleted file mode 100644
index de41bfb..0000000
--- a/training/trainer/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-"""Trainer modules."""
-from .train import Trainer
diff --git a/training/trainer/callbacks/__init__.py b/training/trainer/callbacks/__init__.py
deleted file mode 100644
index 80c4177..0000000
--- a/training/trainer/callbacks/__init__.py
+++ /dev/null
@@ -1,29 +0,0 @@
-"""The callback modules used in the training script."""
-from .base import Callback, CallbackList
-from .checkpoint import Checkpoint
-from .early_stopping import EarlyStopping
-from .lr_schedulers import (
- LRScheduler,
- SWA,
-)
-from .progress_bar import ProgressBar
-from .wandb_callbacks import (
- WandbCallback,
- WandbImageLogger,
- WandbReconstructionLogger,
- WandbSegmentationLogger,
-)
-
-__all__ = [
- "Callback",
- "CallbackList",
- "Checkpoint",
- "EarlyStopping",
- "LRScheduler",
- "WandbCallback",
- "WandbImageLogger",
- "WandbReconstructionLogger",
- "WandbSegmentationLogger",
- "ProgressBar",
- "SWA",
-]
diff --git a/training/trainer/callbacks/base.py b/training/trainer/callbacks/base.py
deleted file mode 100644
index 500b642..0000000
--- a/training/trainer/callbacks/base.py
+++ /dev/null
@@ -1,188 +0,0 @@
-"""Metaclass for callback functions."""
-
-from enum import Enum
-from typing import Callable, Dict, List, Optional, Type, Union
-
-from loguru import logger
-import numpy as np
-import torch
-
-from text_recognizer.models import Model
-
-
-class ModeKeys:
- """Mode keys for CallbackList."""
-
- TRAIN = "train"
- VALIDATION = "validation"
-
-
-class Callback:
- """Metaclass for callbacks used in training."""
-
- def __init__(self) -> None:
- """Initializes the Callback instance."""
- self.model = None
-
- def set_model(self, model: Type[Model]) -> None:
- """Set the model."""
- self.model = model
-
- def on_fit_begin(self) -> None:
- """Called when fit begins."""
- pass
-
- def on_fit_end(self) -> None:
- """Called when fit ends."""
- pass
-
- def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None) -> None:
- """Called at the beginning of an epoch. Only used in training mode."""
- pass
-
- def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
- """Called at the end of an epoch. Only used in training mode."""
- pass
-
- def on_train_batch_begin(self, batch: int, logs: Optional[Dict] = None) -> None:
- """Called at the beginning of an epoch."""
- pass
-
- def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
- """Called at the end of an epoch."""
- pass
-
- def on_validation_batch_begin(
- self, batch: int, logs: Optional[Dict] = None
- ) -> None:
- """Called at the beginning of an epoch."""
- pass
-
- def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
- """Called at the end of an epoch."""
- pass
-
- def on_test_begin(self) -> None:
- """Called at the beginning of test."""
- pass
-
- def on_test_end(self) -> None:
- """Called at the end of test."""
- pass
-
-
-class CallbackList:
- """Container for abstracting away callback calls."""
-
- mode_keys = ModeKeys()
-
- def __init__(self, model: Type[Model], callbacks: List[Callback] = None) -> None:
- """Container for `Callback` instances.
-
- This object wraps a list of `Callback` instances and allows them all to be
- called via a single end point.
-
- Args:
- model (Type[Model]): A `Model` instance.
- callbacks (List[Callback]): List of `Callback` instances. Defaults to None.
-
- """
-
- self._callbacks = callbacks or []
- if model:
- self.set_model(model)
-
- def set_model(self, model: Type[Model]) -> None:
- """Set the model for all callbacks."""
- self.model = model
- for callback in self._callbacks:
- callback.set_model(model=self.model)
-
- def append(self, callback: Type[Callback]) -> None:
- """Append new callback to callback list."""
- self._callbacks.append(callback)
-
- def on_fit_begin(self) -> None:
- """Called when fit begins."""
- for callback in self._callbacks:
- callback.on_fit_begin()
-
- def on_fit_end(self) -> None:
- """Called when fit ends."""
- for callback in self._callbacks:
- callback.on_fit_end()
-
- def on_test_begin(self) -> None:
- """Called when test begins."""
- for callback in self._callbacks:
- callback.on_test_begin()
-
- def on_test_end(self) -> None:
- """Called when test ends."""
- for callback in self._callbacks:
- callback.on_test_end()
-
- def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None) -> None:
- """Called at the beginning of an epoch."""
- for callback in self._callbacks:
- callback.on_epoch_begin(epoch, logs)
-
- def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
- """Called at the end of an epoch."""
- for callback in self._callbacks:
- callback.on_epoch_end(epoch, logs)
-
- def _call_batch_hook(
- self, mode: str, hook: str, batch: int, logs: Optional[Dict] = None
- ) -> None:
- """Helper function for all batch_{begin | end} methods."""
- if hook == "begin":
- self._call_batch_begin_hook(mode, batch, logs)
- elif hook == "end":
- self._call_batch_end_hook(mode, batch, logs)
- else:
- raise ValueError(f"Unrecognized hook {hook}.")
-
- def _call_batch_begin_hook(
- self, mode: str, batch: int, logs: Optional[Dict] = None
- ) -> None:
- """Helper function for all `on_*_batch_begin` methods."""
- hook_name = f"on_{mode}_batch_begin"
- self._call_batch_hook_helper(hook_name, batch, logs)
-
- def _call_batch_end_hook(
- self, mode: str, batch: int, logs: Optional[Dict] = None
- ) -> None:
- """Helper function for all `on_*_batch_end` methods."""
- hook_name = f"on_{mode}_batch_end"
- self._call_batch_hook_helper(hook_name, batch, logs)
-
- def _call_batch_hook_helper(
- self, hook_name: str, batch: int, logs: Optional[Dict] = None
- ) -> None:
- """Helper function for `on_*_batch_begin` methods."""
- for callback in self._callbacks:
- hook = getattr(callback, hook_name)
- hook(batch, logs)
-
- def on_train_batch_begin(self, batch: int, logs: Optional[Dict] = None) -> None:
- """Called at the beginning of an epoch."""
- self._call_batch_hook(self.mode_keys.TRAIN, "begin", batch, logs)
-
- def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
- """Called at the end of an epoch."""
- self._call_batch_hook(self.mode_keys.TRAIN, "end", batch, logs)
-
- def on_validation_batch_begin(
- self, batch: int, logs: Optional[Dict] = None
- ) -> None:
- """Called at the beginning of an epoch."""
- self._call_batch_hook(self.mode_keys.VALIDATION, "begin", batch, logs)
-
- def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
- """Called at the end of an epoch."""
- self._call_batch_hook(self.mode_keys.VALIDATION, "end", batch, logs)
-
- def __iter__(self) -> iter:
- """Iter function for callback list."""
- return iter(self._callbacks)
diff --git a/training/trainer/callbacks/checkpoint.py b/training/trainer/callbacks/checkpoint.py
deleted file mode 100644
index a54e0a9..0000000
--- a/training/trainer/callbacks/checkpoint.py
+++ /dev/null
@@ -1,95 +0,0 @@
-"""Callback checkpoint for training models."""
-from enum import Enum
-from pathlib import Path
-from typing import Callable, Dict, List, Optional, Type, Union
-
-from loguru import logger
-import numpy as np
-import torch
-from training.trainer.callbacks import Callback
-
-from text_recognizer.models import Model
-
-
-class Checkpoint(Callback):
- """Saving model parameters at the end of each epoch."""
-
- mode_dict = {
- "min": torch.lt,
- "max": torch.gt,
- }
-
- def __init__(
- self,
- checkpoint_path: Union[str, Path],
- monitor: str = "accuracy",
- mode: str = "auto",
- min_delta: float = 0.0,
- ) -> None:
- """Monitors a quantity that will allow us to determine the best model weights.
-
- Args:
- checkpoint_path (Union[str, Path]): Path to the experiment with the checkpoint.
- monitor (str): Name of the quantity to monitor. Defaults to "accuracy".
- mode (str): Description of parameter `mode`. Defaults to "auto".
- min_delta (float): Description of parameter `min_delta`. Defaults to 0.0.
-
- """
- super().__init__()
- self.checkpoint_path = Path(checkpoint_path)
- self.monitor = monitor
- self.mode = mode
- self.min_delta = torch.tensor(min_delta)
-
- if mode not in ["auto", "min", "max"]:
- logger.warning(f"Checkpoint mode {mode} is unkown, fallback to auto mode.")
-
- self.mode = "auto"
-
- if self.mode == "auto":
- if "accuracy" in self.monitor:
- self.mode = "max"
- else:
- self.mode = "min"
- logger.debug(
- f"Checkpoint mode set to {self.mode} for monitoring {self.monitor}."
- )
-
- torch_inf = torch.tensor(np.inf)
- self.min_delta *= 1 if self.monitor_op == torch.gt else -1
- self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf
-
- @property
- def monitor_op(self) -> float:
- """Returns the comparison method."""
- return self.mode_dict[self.mode]
-
- def on_epoch_end(self, epoch: int, logs: Dict) -> None:
- """Saves a checkpoint for the network parameters.
-
- Args:
- epoch (int): The current epoch.
- logs (Dict): The log containing the monitored metrics.
-
- """
- current = self.get_monitor_value(logs)
- if current is None:
- return
- if self.monitor_op(current - self.min_delta, self.best_score):
- self.best_score = current
- is_best = True
- else:
- is_best = False
-
- self.model.save_checkpoint(self.checkpoint_path, is_best, epoch, self.monitor)
-
- def get_monitor_value(self, logs: Dict) -> Union[float, None]:
- """Extracts the monitored value."""
- monitor_value = logs.get(self.monitor)
- if monitor_value is None:
- logger.warning(
- f"Checkpoint is conditioned on metric {self.monitor} which is not available. Available"
- + f" metrics are: {','.join(list(logs.keys()))}"
- )
- return None
- return monitor_value
diff --git a/training/trainer/callbacks/early_stopping.py b/training/trainer/callbacks/early_stopping.py
deleted file mode 100644
index 02b431f..0000000
--- a/training/trainer/callbacks/early_stopping.py
+++ /dev/null
@@ -1,108 +0,0 @@
-"""Implements Early stopping for PyTorch model."""
-from typing import Dict, Union
-
-from loguru import logger
-import numpy as np
-import torch
-from torch import Tensor
-from training.trainer.callbacks import Callback
-
-
-class EarlyStopping(Callback):
- """Stops training when a monitored metric stops improving."""
-
- mode_dict = {
- "min": torch.lt,
- "max": torch.gt,
- }
-
- def __init__(
- self,
- monitor: str = "val_loss",
- min_delta: float = 0.0,
- patience: int = 3,
- mode: str = "auto",
- ) -> None:
- """Initializes the EarlyStopping callback.
-
- Args:
- monitor (str): Description of parameter `monitor`. Defaults to "val_loss".
- min_delta (float): Description of parameter `min_delta`. Defaults to 0.0.
- patience (int): Description of parameter `patience`. Defaults to 3.
- mode (str): Description of parameter `mode`. Defaults to "auto".
-
- """
- super().__init__()
- self.monitor = monitor
- self.patience = patience
- self.min_delta = torch.tensor(min_delta)
- self.mode = mode
- self.wait_count = 0
- self.stopped_epoch = 0
-
- if mode not in ["auto", "min", "max"]:
- logger.warning(
- f"EarlyStopping mode {mode} is unkown, fallback to auto mode."
- )
-
- self.mode = "auto"
-
- if self.mode == "auto":
- if "accuracy" in self.monitor:
- self.mode = "max"
- else:
- self.mode = "min"
- logger.debug(
- f"EarlyStopping mode set to {self.mode} for monitoring {self.monitor}."
- )
-
- self.torch_inf = torch.tensor(np.inf)
- self.min_delta *= 1 if self.monitor_op == torch.gt else -1
- self.best_score = (
- self.torch_inf if self.monitor_op == torch.lt else -self.torch_inf
- )
-
- @property
- def monitor_op(self) -> float:
- """Returns the comparison method."""
- return self.mode_dict[self.mode]
-
- def on_fit_begin(self) -> Union[torch.lt, torch.gt]:
- """Reset the early stopping variables for reuse."""
- self.wait_count = 0
- self.stopped_epoch = 0
- self.best_score = (
- self.torch_inf if self.monitor_op == torch.lt else -self.torch_inf
- )
-
- def on_epoch_end(self, epoch: int, logs: Dict) -> None:
- """Computes the early stop criterion."""
- current = self.get_monitor_value(logs)
- if current is None:
- return
- if self.monitor_op(current - self.min_delta, self.best_score):
- self.best_score = current
- self.wait_count = 0
- else:
- self.wait_count += 1
- if self.wait_count >= self.patience:
- self.stopped_epoch = epoch
- self.model.stop_training = True
-
- def on_fit_end(self) -> None:
- """Logs if early stopping was used."""
- if self.stopped_epoch > 0:
- logger.info(
- f"Stopped training at epoch {self.stopped_epoch + 1} with early stopping."
- )
-
- def get_monitor_value(self, logs: Dict) -> Union[Tensor, None]:
- """Extracts the monitor value."""
- monitor_value = logs.get(self.monitor)
- if monitor_value is None:
- logger.warning(
- f"Early stopping is conditioned on metric {self.monitor} which is not available. Available"
- + f"metrics are: {','.join(list(logs.keys()))}"
- )
- return None
- return torch.tensor(monitor_value)
diff --git a/training/trainer/callbacks/lr_schedulers.py b/training/trainer/callbacks/lr_schedulers.py
deleted file mode 100644
index 630c434..0000000
--- a/training/trainer/callbacks/lr_schedulers.py
+++ /dev/null
@@ -1,77 +0,0 @@
-"""Callbacks for learning rate schedulers."""
-from typing import Callable, Dict, List, Optional, Type
-
-from torch.optim.swa_utils import update_bn
-from training.trainer.callbacks import Callback
-
-from text_recognizer.models import Model
-
-
-class LRScheduler(Callback):
- """Generic learning rate scheduler callback."""
-
- def __init__(self) -> None:
- super().__init__()
-
- def set_model(self, model: Type[Model]) -> None:
- """Sets the model and lr scheduler."""
- self.model = model
- self.lr_scheduler = self.model.lr_scheduler["lr_scheduler"]
- self.interval = self.model.lr_scheduler["interval"]
-
- def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
- """Takes a step at the end of every epoch."""
- if self.interval == "epoch":
- if "ReduceLROnPlateau" in self.lr_scheduler.__class__.__name__:
- self.lr_scheduler.step(logs["val_loss"])
- else:
- self.lr_scheduler.step()
-
- def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
- """Takes a step at the end of every training batch."""
- if self.interval == "step":
- self.lr_scheduler.step()
-
-
-class SWA(Callback):
- """Stochastic Weight Averaging callback."""
-
- def __init__(self) -> None:
- """Initializes the callback."""
- super().__init__()
- self.lr_scheduler = None
- self.interval = None
- self.swa_scheduler = None
- self.swa_start = None
- self.current_epoch = 1
-
- def set_model(self, model: Type[Model]) -> None:
- """Sets the model and lr scheduler."""
- self.model = model
- self.lr_scheduler = self.model.lr_scheduler["lr_scheduler"]
- self.interval = self.model.lr_scheduler["interval"]
- self.swa_scheduler = self.model.swa_scheduler["swa_scheduler"]
- self.swa_start = self.model.swa_scheduler["swa_start"]
-
- def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
- """Takes a step at the end of every training batch."""
- if epoch > self.swa_start:
- self.model.swa_network.update_parameters(self.model.network)
- self.swa_scheduler.step()
- elif self.interval == "epoch":
- self.lr_scheduler.step()
- self.current_epoch = epoch
-
- def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
- """Takes a step at the end of every training batch."""
- if self.current_epoch < self.swa_start and self.interval == "step":
- self.lr_scheduler.step()
-
- def on_fit_end(self) -> None:
- """Update batch norm statistics for the swa model at the end of training."""
- if self.model.swa_network:
- update_bn(
- self.model.val_dataloader(),
- self.model.swa_network,
- device=self.model.device,
- )
diff --git a/training/trainer/callbacks/progress_bar.py b/training/trainer/callbacks/progress_bar.py
deleted file mode 100644
index 6c4305a..0000000
--- a/training/trainer/callbacks/progress_bar.py
+++ /dev/null
@@ -1,65 +0,0 @@
-"""Progress bar callback for the training loop."""
-from typing import Dict, Optional
-
-from tqdm import tqdm
-from training.trainer.callbacks import Callback
-
-
-class ProgressBar(Callback):
- """A TQDM progress bar for the training loop."""
-
- def __init__(self, epochs: int, log_batch_frequency: int = None) -> None:
- """Initializes the tqdm callback."""
- self.epochs = epochs
- print(epochs, type(epochs))
- self.log_batch_frequency = log_batch_frequency
- self.progress_bar = None
- self.val_metrics = {}
-
- def _configure_progress_bar(self) -> None:
- """Configures the tqdm progress bar with custom bar format."""
- self.progress_bar = tqdm(
- total=len(self.model.train_dataloader()),
- leave=False,
- unit="steps",
- mininterval=self.log_batch_frequency,
- bar_format="{desc} |{bar:32}| {n_fmt}/{total_fmt} ETA: {remaining} {rate_fmt}{postfix}",
- )
-
- def _key_abbreviations(self, logs: Dict) -> Dict:
- """Changes the length of keys, so that the progress bar fits better."""
-
- def rename(key: str) -> str:
- """Renames accuracy to acc."""
- return key.replace("accuracy", "acc")
-
- return {rename(key): value for key, value in logs.items()}
-
- # def on_fit_begin(self) -> None:
- # """Creates a tqdm progress bar."""
- # self._configure_progress_bar()
-
- def on_epoch_begin(self, epoch: int, logs: Optional[Dict]) -> None:
- """Updates the description with the current epoch."""
- if epoch == 1:
- self._configure_progress_bar()
- else:
- self.progress_bar.reset()
- self.progress_bar.set_description(f"Epoch {epoch}/{self.epochs}")
-
- def on_epoch_end(self, epoch: int, logs: Dict) -> None:
- """At the end of each epoch, the validation metrics are updated to the progress bar."""
- self.val_metrics = logs
- self.progress_bar.set_postfix(**self._key_abbreviations(logs))
- self.progress_bar.update()
-
- def on_train_batch_end(self, batch: int, logs: Dict) -> None:
- """Updates the progress bar for each training step."""
- if self.val_metrics:
- logs.update(self.val_metrics)
- self.progress_bar.set_postfix(**self._key_abbreviations(logs))
- self.progress_bar.update()
-
- def on_fit_end(self) -> None:
- """Closes the tqdm progress bar."""
- self.progress_bar.close()
diff --git a/training/trainer/callbacks/wandb_callbacks.py b/training/trainer/callbacks/wandb_callbacks.py
deleted file mode 100644
index 552a4f4..0000000
--- a/training/trainer/callbacks/wandb_callbacks.py
+++ /dev/null
@@ -1,261 +0,0 @@
-"""Callback for W&B."""
-from typing import Callable, Dict, List, Optional, Type
-
-import numpy as np
-from training.trainer.callbacks import Callback
-import wandb
-
-import text_recognizer.datasets.transforms as transforms
-from text_recognizer.models.base import Model
-
-
-class WandbCallback(Callback):
- """A custom W&B metric logger for the trainer."""
-
- def __init__(self, log_batch_frequency: int = None) -> None:
- """Short summary.
-
- Args:
- log_batch_frequency (int): If None, metrics will be logged every epoch.
- If set to an integer, callback will log every metrics every log_batch_frequency.
-
- """
- super().__init__()
- self.log_batch_frequency = log_batch_frequency
-
- def _on_batch_end(self, batch: int, logs: Dict) -> None:
- if self.log_batch_frequency and batch % self.log_batch_frequency == 0:
- wandb.log(logs, commit=True)
-
- def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
- """Logs training metrics."""
- if logs is not None:
- logs["lr"] = self.model.optimizer.param_groups[0]["lr"]
- self._on_batch_end(batch, logs)
-
- def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
- """Logs validation metrics."""
- if logs is not None:
- self._on_batch_end(batch, logs)
-
- def on_epoch_end(self, epoch: int, logs: Dict) -> None:
- """Logs at epoch end."""
- wandb.log(logs, commit=True)
-
-
-class WandbImageLogger(Callback):
- """Custom W&B callback for image logging."""
-
- def __init__(
- self,
- example_indices: Optional[List] = None,
- num_examples: int = 4,
- transform: Optional[bool] = None,
- ) -> None:
- """Initializes the WandbImageLogger with the model to train.
-
- Args:
- example_indices (Optional[List]): Indices for validation images. Defaults to None.
- num_examples (int): Number of random samples to take if example_indices are not specified. Defaults to 4.
- transform (Optional[Dict]): Use transform on image or not. Defaults to None.
-
- """
-
- super().__init__()
- self.caption = None
- self.example_indices = example_indices
- self.test_sample_indices = None
- self.num_examples = num_examples
- self.transform = (
- self._configure_transform(transform) if transform is not None else None
- )
-
- def _configure_transform(self, transform: Dict) -> Callable:
- args = transform["args"] or {}
- return getattr(transforms, transform["type"])(**args)
-
- def set_model(self, model: Type[Model]) -> None:
- """Sets the model and extracts validation images from the dataset."""
- self.model = model
- self.caption = "Validation Examples"
- if self.example_indices is None:
- self.example_indices = np.random.randint(
- 0, len(self.model.val_dataset), self.num_examples
- )
- self.images = self.model.val_dataset.dataset.data[self.example_indices]
- self.targets = self.model.val_dataset.dataset.targets[self.example_indices]
- self.targets = self.targets.tolist()
-
- def on_test_begin(self) -> None:
- """Get samples from test dataset."""
- self.caption = "Test Examples"
- if self.test_sample_indices is None:
- self.test_sample_indices = np.random.randint(
- 0, len(self.model.test_dataset), self.num_examples
- )
- self.images = self.model.test_dataset.data[self.test_sample_indices]
- self.targets = self.model.test_dataset.targets[self.test_sample_indices]
- self.targets = self.targets.tolist()
-
- def on_test_end(self) -> None:
- """Log test images."""
- self.on_epoch_end(0, {})
-
- def on_epoch_end(self, epoch: int, logs: Dict) -> None:
- """Get network predictions on validation images."""
- images = []
- for i, image in enumerate(self.images):
- image = self.transform(image) if self.transform is not None else image
- pred, conf = self.model.predict_on_image(image)
- if isinstance(self.targets[i], list):
- ground_truth = "".join(
- [
- self.model.mapper(int(target_index) - 26)
- if target_index > 35
- else self.model.mapper(int(target_index))
- for target_index in self.targets[i]
- ]
- ).rstrip("_")
- else:
- ground_truth = self.model.mapper(int(self.targets[i]))
- caption = f"Prediction: {pred} Confidence: {conf:.3f} Ground Truth: {ground_truth}"
- images.append(wandb.Image(image, caption=caption))
-
- wandb.log({f"{self.caption}": images}, commit=False)
-
-
-class WandbSegmentationLogger(Callback):
- """Custom W&B callback for image logging."""
-
- def __init__(
- self,
- class_labels: Dict,
- example_indices: Optional[List] = None,
- num_examples: int = 4,
- ) -> None:
- """Initializes the WandbImageLogger with the model to train.
-
- Args:
- class_labels (Dict): A dict with int as key and class string as value.
- example_indices (Optional[List]): Indices for validation images. Defaults to None.
- num_examples (int): Number of random samples to take if example_indices are not specified. Defaults to 4.
-
- """
-
- super().__init__()
- self.caption = None
- self.class_labels = {int(k): v for k, v in class_labels.items()}
- self.example_indices = example_indices
- self.test_sample_indices = None
- self.num_examples = num_examples
-
- def set_model(self, model: Type[Model]) -> None:
- """Sets the model and extracts validation images from the dataset."""
- self.model = model
- self.caption = "Validation Segmentation Examples"
- if self.example_indices is None:
- self.example_indices = np.random.randint(
- 0, len(self.model.val_dataset), self.num_examples
- )
- self.images = self.model.val_dataset.dataset.data[self.example_indices]
- self.targets = self.model.val_dataset.dataset.targets[self.example_indices]
- self.targets = self.targets.tolist()
-
- def on_test_begin(self) -> None:
- """Get samples from test dataset."""
- self.caption = "Test Segmentation Examples"
- if self.test_sample_indices is None:
- self.test_sample_indices = np.random.randint(
- 0, len(self.model.test_dataset), self.num_examples
- )
- self.images = self.model.test_dataset.data[self.test_sample_indices]
- self.targets = self.model.test_dataset.targets[self.test_sample_indices]
- self.targets = self.targets.tolist()
-
- def on_test_end(self) -> None:
- """Log test images."""
- self.on_epoch_end(0, {})
-
- def on_epoch_end(self, epoch: int, logs: Dict) -> None:
- """Get network predictions on validation images."""
- images = []
- for i, image in enumerate(self.images):
- pred_mask = (
- self.model.predict_on_image(image).detach().squeeze(0).cpu().numpy()
- )
- gt_mask = np.array(self.targets[i])
- images.append(
- wandb.Image(
- image,
- masks={
- "predictions": {
- "mask_data": pred_mask,
- "class_labels": self.class_labels,
- },
- "ground_truth": {
- "mask_data": gt_mask,
- "class_labels": self.class_labels,
- },
- },
- )
- )
-
- wandb.log({f"{self.caption}": images}, commit=False)
-
-
-class WandbReconstructionLogger(Callback):
- """Custom W&B callback for image reconstructions logging."""
-
- def __init__(
- self, example_indices: Optional[List] = None, num_examples: int = 4,
- ) -> None:
- """Initializes the WandbImageLogger with the model to train.
-
- Args:
- example_indices (Optional[List]): Indices for validation images. Defaults to None.
- num_examples (int): Number of random samples to take if example_indices are not specified. Defaults to 4.
-
- """
-
- super().__init__()
- self.caption = None
- self.example_indices = example_indices
- self.test_sample_indices = None
- self.num_examples = num_examples
-
- def set_model(self, model: Type[Model]) -> None:
- """Sets the model and extracts validation images from the dataset."""
- self.model = model
- self.caption = "Validation Reconstructions Examples"
- if self.example_indices is None:
- self.example_indices = np.random.randint(
- 0, len(self.model.val_dataset), self.num_examples
- )
- self.images = self.model.val_dataset.dataset.data[self.example_indices]
-
- def on_test_begin(self) -> None:
- """Get samples from test dataset."""
- self.caption = "Test Reconstructions Examples"
- if self.test_sample_indices is None:
- self.test_sample_indices = np.random.randint(
- 0, len(self.model.test_dataset), self.num_examples
- )
- self.images = self.model.test_dataset.data[self.test_sample_indices]
-
- def on_test_end(self) -> None:
- """Log test images."""
- self.on_epoch_end(0, {})
-
- def on_epoch_end(self, epoch: int, logs: Dict) -> None:
- """Get network predictions on validation images."""
- images = []
- for image in self.images:
- reconstructed_image = (
- self.model.predict_on_image(image).detach().squeeze(0).cpu().numpy()
- )
- images.append(image)
- images.append(reconstructed_image)
-
- wandb.log(
- {f"{self.caption}": [wandb.Image(image) for image in images]}, commit=False,
- )
diff --git a/training/trainer/train.py b/training/trainer/train.py
deleted file mode 100644
index b770c94..0000000
--- a/training/trainer/train.py
+++ /dev/null
@@ -1,325 +0,0 @@
-"""Training script for PyTorch models."""
-
-from pathlib import Path
-import time
-from typing import Dict, List, Optional, Tuple, Type
-import warnings
-
-from einops import rearrange
-from loguru import logger
-import numpy as np
-import torch
-from torch import Tensor
-from torch.optim.swa_utils import update_bn
-from training.trainer.callbacks import Callback, CallbackList, LRScheduler, SWA
-from training.trainer.util import log_val_metric
-import wandb
-
-from text_recognizer.models import Model
-
-
-torch.backends.cudnn.benchmark = True
-np.random.seed(4711)
-torch.manual_seed(4711)
-torch.cuda.manual_seed(4711)
-
-
-warnings.filterwarnings("ignore")
-
-
-class Trainer:
- """Trainer for training PyTorch models."""
-
- def __init__(
- self,
- max_epochs: int,
- callbacks: List[Type[Callback]],
- transformer_model: bool = False,
- max_norm: float = 0.0,
- freeze_backbone: Optional[int] = None,
- ) -> None:
- """Initialization of the Trainer.
-
- Args:
- max_epochs (int): The maximum number of epochs in the training loop.
- callbacks (CallbackList): List of callbacks to be called.
- transformer_model (bool): Transformer model flag, modifies the input to the model. Default is False.
- max_norm (float): Max norm for gradient cl:ipping. Defaults to 0.0.
- freeze_backbone (Optional[int]): How many epochs to freeze the backbone for. Used when training
- Transformers. Default is None.
-
- """
- # Training arguments.
- self.start_epoch = 1
- self.max_epochs = max_epochs
- self.callbacks = callbacks
- self.freeze_backbone = freeze_backbone
-
- # Flag for setting callbacks.
- self.callbacks_configured = False
-
- self.transformer_model = transformer_model
-
- self.max_norm = max_norm
-
- # Model placeholders
- self.model = None
-
- def _configure_callbacks(self) -> None:
- """Instantiate the CallbackList."""
- if not self.callbacks_configured:
- # If learning rate schedulers are present, they need to be added to the callbacks.
- if self.model.swa_scheduler is not None:
- self.callbacks.append(SWA())
- elif self.model.lr_scheduler is not None:
- self.callbacks.append(LRScheduler())
-
- self.callbacks = CallbackList(self.model, self.callbacks)
-
- def compute_metrics(
- self, output: Tensor, targets: Tensor, loss: Tensor, batch_size: int
- ) -> Dict:
- """Computes metrics for output and target pairs."""
- # Compute metrics.
- loss = loss.detach().float().item()
- output = output.detach()
- targets = targets.detach()
- if self.model.metrics is not None:
- metrics = {}
- for metric in self.model.metrics:
- if metric == "cer" or metric == "wer":
- metrics[metric] = self.model.metrics[metric](
- output,
- targets,
- batch_size,
- self.model.mapper(self.model.pad_token),
- )
- else:
- metrics[metric] = self.model.metrics[metric](output, targets)
- else:
- metrics = {}
- metrics["loss"] = loss
-
- return metrics
-
- def training_step(self, batch: int, samples: Tuple[Tensor, Tensor],) -> Dict:
- """Performs the training step."""
- # Pass the tensor to the device for computation.
- data, targets = samples
- data, targets = (
- data.to(self.model.device),
- targets.to(self.model.device),
- )
-
- batch_size = data.shape[0]
-
- # Placeholder for uxiliary loss.
- aux_loss = None
-
- # Forward pass.
- # Get the network prediction.
- if self.transformer_model:
- if self.freeze_backbone is not None and batch < self.freeze_backbone:
- with torch.no_grad():
- image_features = self.model.network.extract_image_features(data)
-
- if isinstance(image_features, Tuple):
- image_features, _ = image_features
-
- output = self.model.network.decode_image_features(
- image_features, targets[:, :-1]
- )
- else:
- output = self.model.network.forward(data, targets[:, :-1])
- if isinstance(output, Tuple):
- output, aux_loss = output
- output = rearrange(output, "b t v -> (b t) v")
- targets = rearrange(targets[:, 1:], "b t -> (b t)").long()
- else:
- output = self.model.forward(data)
-
- if isinstance(output, Tuple):
- output, aux_loss = output
- targets = data
-
- # Compute the loss.
- loss = self.model.criterion(output, targets)
-
- if aux_loss is not None:
- loss += aux_loss
-
- # Backward pass.
- # Clear the previous gradients.
- for p in self.model.network.parameters():
- p.grad = None
-
- # Compute the gradients.
- loss.backward()
-
- if self.max_norm > 0:
- torch.nn.utils.clip_grad_norm_(
- self.model.network.parameters(), self.max_norm
- )
-
- # Perform updates using calculated gradients.
- self.model.optimizer.step()
-
- metrics = self.compute_metrics(output, targets, loss, batch_size)
-
- return metrics
-
- def train(self) -> None:
- """Runs the training loop for one epoch."""
- # Set model to traning mode.
- self.model.train()
-
- for batch, samples in enumerate(self.model.train_dataloader()):
- self.callbacks.on_train_batch_begin(batch)
- metrics = self.training_step(batch, samples)
- self.callbacks.on_train_batch_end(batch, logs=metrics)
-
- @torch.no_grad()
- def validation_step(self, batch: int, samples: Tuple[Tensor, Tensor],) -> Dict:
- """Performs the validation step."""
-
- # Pass the tensor to the device for computation.
- data, targets = samples
- data, targets = (
- data.to(self.model.device),
- targets.to(self.model.device),
- )
-
- batch_size = data.shape[0]
-
- # Placeholder for uxiliary loss.
- aux_loss = None
-
- # Forward pass.
- # Get the network prediction.
- # Use SWA if available and using test dataset.
- if self.transformer_model:
- output = self.model.network.forward(data, targets[:, :-1])
- if isinstance(output, Tuple):
- output, aux_loss = output
- output = rearrange(output, "b t v -> (b t) v")
- targets = rearrange(targets[:, 1:], "b t -> (b t)").long()
- else:
- output = self.model.forward(data)
-
- if isinstance(output, Tuple):
- output, aux_loss = output
- targets = data
-
- # Compute the loss.
- loss = self.model.criterion(output, targets)
-
- if aux_loss is not None:
- loss += aux_loss
-
- # Compute metrics.
- metrics = self.compute_metrics(output, targets, loss, batch_size)
-
- return metrics
-
- def validate(self) -> Dict:
- """Runs the validation loop for one epoch."""
- # Set model to eval mode.
- self.model.eval()
-
- # Summary for the current eval loop.
- summary = []
-
- for batch, samples in enumerate(self.model.val_dataloader()):
- self.callbacks.on_validation_batch_begin(batch)
- metrics = self.validation_step(batch, samples)
- self.callbacks.on_validation_batch_end(batch, logs=metrics)
- summary.append(metrics)
-
- # Compute mean of all metrics.
- metrics_mean = {
- "val_" + metric: np.mean([x[metric] for x in summary])
- for metric in summary[0]
- }
-
- return metrics_mean
-
- def fit(self, model: Type[Model]) -> None:
- """Runs the training and evaluation loop."""
-
- # Sets model, loads the data, criterion, and optimizers.
- self.model = model
- self.model.prepare_data()
- self.model.configure_model()
-
- # Configure callbacks.
- self._configure_callbacks()
-
- # Set start time.
- t_start = time.time()
-
- self.callbacks.on_fit_begin()
-
- # Run the training loop.
- for epoch in range(self.start_epoch, self.max_epochs + 1):
- self.callbacks.on_epoch_begin(epoch)
-
- # Perform one training pass over the training set.
- self.train()
-
- # Evaluate the model on the validation set.
- val_metrics = self.validate()
- log_val_metric(val_metrics, epoch)
-
- self.callbacks.on_epoch_end(epoch, logs=val_metrics)
-
- if self.model.stop_training:
- break
-
- # Calculate the total training time.
- t_end = time.time()
- t_training = t_end - t_start
-
- self.callbacks.on_fit_end()
-
- logger.info(f"Training took {t_training:.2f} s.")
-
- # "Teardown".
- self.model = None
-
- def test(self, model: Type[Model]) -> Dict:
- """Run inference on test data."""
-
- # Sets model, loads the data, criterion, and optimizers.
- self.model = model
- self.model.prepare_data()
- self.model.configure_model()
-
- # Configure callbacks.
- self._configure_callbacks()
-
- self.callbacks.on_test_begin()
-
- self.model.eval()
-
- # Check if SWA network is available.
- self.model.use_swa_model()
-
- # Summary for the current test loop.
- summary = []
-
- for batch, samples in enumerate(self.model.test_dataloader()):
- metrics = self.validation_step(batch, samples)
- summary.append(metrics)
-
- self.callbacks.on_test_end()
-
- # Compute mean of all test metrics.
- metrics_mean = {
- "test_" + metric: np.mean([x[metric] for x in summary])
- for metric in summary[0]
- }
-
- # "Teardown".
- self.model = None
-
- return metrics_mean
diff --git a/training/trainer/util.py b/training/trainer/util.py
deleted file mode 100644
index 7cf1b45..0000000
--- a/training/trainer/util.py
+++ /dev/null
@@ -1,28 +0,0 @@
-"""Utility functions for training neural networks."""
-from typing import Dict, Optional
-
-from loguru import logger
-
-
-def log_val_metric(metrics_mean: Dict, epoch: Optional[int] = None) -> None:
- """Logging of val metrics to file/terminal."""
- log_str = "Validation metrics " + (f"at epoch {epoch} - " if epoch else " - ")
- logger.debug(log_str + " - ".join(f"{k}: {v:.4f}" for k, v in metrics_mean.items()))
-
-
-class RunningAverage:
- """Maintains a running average."""
-
- def __init__(self) -> None:
- """Initializes the parameters."""
- self.steps = 0
- self.total = 0
-
- def update(self, val: float) -> None:
- """Updates the parameters."""
- self.total += val
- self.steps += 1
-
- def __call__(self) -> float:
- """Computes the running average."""
- return self.total / float(self.steps)