Diffstat (limited to 'src/text_recognizer')
21 files changed, 1167 insertions, 34 deletions
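The commit introduces an IAM preprocessor (src/text_recognizer/datasets/iam_preprocessor.py, below) that extracts lower-cased line text and a grapheme/token vocabulary from the IAM database. A minimal usage sketch of the new API follows; it assumes the raw IAM data has already been downloaded to the default location the new CLI falls back to, and the path and feature count are illustrative only:

```python
from pathlib import Path

from text_recognizer.datasets import Preprocessor

# Assumed location; mirrors the default path the CLI below falls back to.
data_dir = Path("data/raw/iam/iamdb")

preprocessor = Preprocessor(data_dir, num_features=64)
preprocessor.extract_train_text()

line = preprocessor.text[0]                    # lower-cased line, "_" as word separator
indices = preprocessor.to_index(line)          # torch.LongTensor of grapheme indices
print(preprocessor.to_text(indices.tolist()))  # maps the indices back to text
```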
diff --git a/src/text_recognizer/datasets/__init__.py b/src/text_recognizer/datasets/__init__.py index d8372e3..a6c1c59 100644 --- a/src/text_recognizer/datasets/__init__.py +++ b/src/text_recognizer/datasets/__init__.py @@ -8,6 +8,7 @@ from .emnist_lines_dataset import ( from .iam_dataset import IamDataset from .iam_lines_dataset import IamLinesDataset from .iam_paragraphs_dataset import IamParagraphsDataset +from .iam_preprocessor import load_metadata, Preprocessor from .transforms import AddTokens, Transpose from .util import ( _download_raw_dataset, @@ -29,8 +30,10 @@ __all__ = [ "EmnistMapper", "EmnistLinesDataset", "get_samples_by_character", + "load_metadata", "IamDataset", "IamLinesDataset", "IamParagraphsDataset", + "Preprocessor", "Transpose", ] diff --git a/src/text_recognizer/datasets/iam_preprocessor.py b/src/text_recognizer/datasets/iam_preprocessor.py new file mode 100644 index 0000000..5a5136c --- /dev/null +++ b/src/text_recognizer/datasets/iam_preprocessor.py @@ -0,0 +1,196 @@ +"""Preprocessor for extracting word letters from the IAM dataset. + +The code is mostly stolen from: + + https://github.com/facebookresearch/gtn_applications/blob/master/datasets/iamdb.py + +""" + +import collections +import itertools +from pathlib import Path +import re +from typing import List, Optional, Union + +import click +from loguru import logger +import torch + + +def load_metadata( + data_dir: Path, wordsep: str, use_words: bool = False +) -> collections.defaultdict: + """Loads IAM metadata and returns it as a dictionary.""" + forms = collections.defaultdict(list) + filename = "words.txt" if use_words else "lines.txt" + + with open(data_dir / "ascii" / filename, "r") as f: + lines = (line.strip().split() for line in f if line[0] != "#") + for line in lines: + # Skip word segmentation errors. + if use_words and line[1] == "err": + continue + text = " ".join(line[8:]) + + # Remove garbage tokens: + text = text.replace("#", "") + + # Swap word sep form | to wordsep + text = re.sub(r"\|+|\s", wordsep, text).strip(wordsep) + form_key = "-".join(line[0].split("-")[:2]) + line_key = "-".join(line[0].split("-")[:3]) + box_idx = 4 - use_words + box = tuple(int(val) for val in line[box_idx : box_idx + 4]) + forms[form_key].append({"key": line_key, "box": box, "text": text}) + return forms + + +class Preprocessor: + """A preprocessor for the IAM dataset.""" + + # TODO: add lower case only to when generating... + + def __init__( + self, + data_dir: Union[str, Path], + num_features: int, + tokens_path: Optional[Union[str, Path]] = None, + lexicon_path: Optional[Union[str, Path]] = None, + use_words: bool = False, + prepend_wordsep: bool = False, + ) -> None: + self.wordsep = "_" + self._use_word = use_words + self._prepend_wordsep = prepend_wordsep + + self.data_dir = Path(data_dir) + + self.forms = load_metadata(self.data_dir, self.wordsep, use_words=use_words) + + # Load the set of graphemes: + graphemes = set() + for _, form in self.forms.items(): + for line in form: + graphemes.update(line["text"].lower()) + self.graphemes = sorted(graphemes) + + # Build the token-to-index and index-to-token maps. 
+ if tokens_path is not None: + with open(tokens_path, "r") as f: + self.tokens = [line.strip() for line in f] + else: + self.tokens = self.graphemes + + if lexicon_path is not None: + with open(lexicon_path, "r") as f: + lexicon = (line.strip().split() for line in f) + lexicon = {line[0]: line[1:] for line in lexicon} + self.lexicon = lexicon + else: + self.lexicon = None + + self.graphemes_to_index = {t: i for i, t in enumerate(self.graphemes)} + self.tokens_to_index = {t: i for i, t in enumerate(self.tokens)} + self.num_features = num_features + self.text = [] + + @property + def num_tokens(self) -> int: + """Returns the number or tokens.""" + return len(self.tokens) + + @property + def use_words(self) -> bool: + """If words are used.""" + return self._use_word + + def extract_train_text(self) -> None: + """Extracts training text.""" + keys = [] + with open(self.data_dir / "task" / "trainset.txt") as f: + keys.extend((line.strip() for line in f)) + + for _, examples in self.forms.items(): + for example in examples: + if example["key"] not in keys: + continue + self.text.append(example["text"].lower()) + + def to_index(self, line: str) -> torch.LongTensor: + """Converts text to a tensor of indices.""" + token_to_index = self.graphemes_to_index + if self.lexicon is not None: + if len(line) > 0: + # If the word is not found in the lexicon, fall back to letters. + line = [ + t + for w in line.split(self.wordsep) + for t in self.lexicon.get(w, self.wordsep + w) + ] + token_to_index = self.tokens_to_index + if self._prepend_wordsep: + line = itertools.chain([self.wordsep], line) + return torch.LongTensor([token_to_index[t] for t in line]) + + def to_text(self, indices: List[int]) -> str: + """Converts indices to text.""" + # Roughly the inverse of `to_index` + encoding = self.graphemes + if self.lexicon is not None: + encoding = self.tokens + return self._post_process(encoding[i] for i in indices) + + def tokens_to_text(self, indices: List[int]) -> str: + """Converts tokens to text.""" + return self._post_process(self.tokens[i] for i in indices) + + def _post_process(self, indices: List[int]) -> str: + """A list join.""" + return "".join(indices).strip(self.wordsep) + + +@click.command() +@click.option("--data_dir", type=str, default=None, help="Path to iam dataset") +@click.option( + "--use_words", is_flag=True, help="Load word segmented dataset instead of lines" +) +@click.option( + "--save_text", type=str, default=None, help="Path to save parsed train text" +) +@click.option("--save_tokens", type=str, default=None, help="Path to save tokens") +def cli( + data_dir: Optional[str], + use_words: bool, + save_text: Optional[str], + save_tokens: Optional[str], +) -> None: + """CLI for extracting text data from the iam dataset.""" + if data_dir is None: + data_dir = ( + Path(__file__).resolve().parents[3] / "data" / "raw" / "iam" / "iamdb" + ) + logger.debug(f"Using data dir: {data_dir}") + if not data_dir.exists(): + raise RuntimeError(f"Could not locate iamdb directory at {data_dir}") + else: + data_dir = Path(data_dir) + + preprocessor = Preprocessor(data_dir, 64, use_words=use_words) + preprocessor.extract_train_text() + + processed_dir = data_dir.parents[2] / "processed" / "iam_lines" + logger.debug(f"Saving processed files at: {processed_dir}") + + if save_text is not None: + logger.info("Saving training text") + with open(processed_dir / save_text, "w") as f: + f.write("\n".join(t for t in preprocessor.text)) + + if save_tokens is not None: + logger.info("Saving tokens") + with 
open(processed_dir / save_tokens, "w") as f: + f.write("\n".join(preprocessor.tokens)) + + +if __name__ == "__main__": + cli() diff --git a/src/text_recognizer/datasets/transforms.py b/src/text_recognizer/datasets/transforms.py index 8956b01..60987e0 100644 --- a/src/text_recognizer/datasets/transforms.py +++ b/src/text_recognizer/datasets/transforms.py @@ -1,14 +1,57 @@ """Transforms for PyTorch datasets.""" +import random + import numpy as np from PIL import Image import torch from torch import Tensor import torch.nn.functional as F -from torchvision.transforms import Compose, RandomAffine, RandomHorizontalFlip, ToTensor +from torchvision import transforms +from torchvision.transforms import ( + ColorJitter, + Compose, + Normalize, + RandomAffine, + RandomHorizontalFlip, + RandomRotation, + ToPILImage, + ToTensor, +) from text_recognizer.datasets.util import EmnistMapper +class RandomResizeCrop: + """Image transform with random resize and crop applied. + + Stolen from + + https://github.com/facebookresearch/gtn_applications/blob/master/datasets/iamdb.py + + """ + + def __init__(self, jitter: int = 10, ratio: float = 0.5) -> None: + self.jitter = jitter + self.ratio = ratio + + def __call__(self, img: np.ndarray) -> np.ndarray: + """Applies random crop and rotation to an image.""" + w, h = img.size + + # pad with white: + img = transforms.functional.pad(img, self.jitter, fill=255) + + # crop at random (x, y): + x = self.jitter + random.randint(-self.jitter, self.jitter) + y = self.jitter + random.randint(-self.jitter, self.jitter) + + # randomize aspect ratio: + size_w = w * random.uniform(1 - self.ratio, 1 + self.ratio) + size = (h, int(size_w)) + img = transforms.functional.resized_crop(img, y, x, h, w, size) + return img + + class Transpose: """Transposes the EMNIST image to the correct orientation.""" diff --git a/src/text_recognizer/models/__init__.py b/src/text_recognizer/models/__init__.py index eb5dbce..7647d7e 100644 --- a/src/text_recognizer/models/__init__.py +++ b/src/text_recognizer/models/__init__.py @@ -5,6 +5,7 @@ from .crnn_model import CRNNModel from .ctc_transformer_model import CTCTransformerModel from .segmentation_model import SegmentationModel from .transformer_model import TransformerModel +from .vqvae_model import VQVAEModel __all__ = [ "CharacterModel", @@ -13,4 +14,5 @@ __all__ = [ "Model", "SegmentationModel", "TransformerModel", + "VQVAEModel", ] diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py index f2cd4b8..70f4cdb 100644 --- a/src/text_recognizer/models/base.py +++ b/src/text_recognizer/models/base.py @@ -332,7 +332,7 @@ class Model(ABC): def summary( self, input_shape: Optional[Union[List, Tuple]] = None, - depth: int = 4, + depth: int = 3, device: Optional[str] = None, ) -> None: """Prints a summary of the network architecture.""" diff --git a/src/text_recognizer/models/transformer_model.py b/src/text_recognizer/models/transformer_model.py index 12e497f..3f63053 100644 --- a/src/text_recognizer/models/transformer_model.py +++ b/src/text_recognizer/models/transformer_model.py @@ -6,9 +6,9 @@ import torch from torch import nn from torch import Tensor from torch.utils.data import Dataset -from torchvision.transforms import ToTensor from text_recognizer.datasets import EmnistMapper +import text_recognizer.datasets.transforms as transforms from text_recognizer.models.base import Model from text_recognizer.networks import greedy_decoder @@ -60,13 +60,19 @@ class TransformerModel(Model): eos_token=self.eos_token, 
lower=self.lower, ) - self.tensor_transform = ToTensor() - + self.tensor_transform = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize(mean=[0.912], std=[0.168])] + ) self.softmax = nn.Softmax(dim=2) @torch.no_grad() def _generate_sentence(self, image: Tensor) -> Tuple[List, float]: src = self.network.extract_image_features(image) + + # Added for vqvae transformer. + if isinstance(src, Tuple): + src = src[0] + memory = self.network.encoder(src) confidence_of_predictions = [] diff --git a/src/text_recognizer/models/vqvae_model.py b/src/text_recognizer/models/vqvae_model.py new file mode 100644 index 0000000..70f6f1f --- /dev/null +++ b/src/text_recognizer/models/vqvae_model.py @@ -0,0 +1,80 @@ +"""Defines the VQVAEModel class.""" +from typing import Callable, Dict, Optional, Tuple, Type, Union + +import numpy as np +import torch +from torch import nn +from torch.utils.data import Dataset +from torchvision.transforms import ToTensor + +from text_recognizer.datasets import EmnistMapper +from text_recognizer.models.base import Model + + +class VQVAEModel(Model): + """Model for reconstructing images from codebook.""" + + def __init__( + self, + network_fn: Type[nn.Module], + dataset: Type[Dataset], + network_args: Optional[Dict] = None, + dataset_args: Optional[Dict] = None, + metrics: Optional[Dict] = None, + criterion: Optional[Callable] = None, + criterion_args: Optional[Dict] = None, + optimizer: Optional[Callable] = None, + optimizer_args: Optional[Dict] = None, + lr_scheduler: Optional[Callable] = None, + lr_scheduler_args: Optional[Dict] = None, + swa_args: Optional[Dict] = None, + device: Optional[str] = None, + ) -> None: + """Initializes the CharacterModel.""" + + super().__init__( + network_fn, + dataset, + network_args, + dataset_args, + metrics, + criterion, + criterion_args, + optimizer, + optimizer_args, + lr_scheduler, + lr_scheduler_args, + swa_args, + device, + ) + self.pad_token = dataset_args["args"]["pad_token"] + if self._mapper is None: + self._mapper = EmnistMapper(pad_token=self.pad_token,) + self.tensor_transform = ToTensor() + self.softmax = nn.Softmax(dim=0) + + @torch.no_grad() + def predict_on_image(self, image: Union[np.ndarray, torch.Tensor]) -> torch.Tensor: + """Reconstruction of image. + + Args: + image (Union[np.ndarray, torch.Tensor]): An image containing a character. + + Returns: + Tuple[str, float]: The predicted character and the confidence in the prediction. + + """ + self.eval() + + if image.dtype == np.uint8: + # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1]. + image = self.tensor_transform(image) + if image.dtype == torch.uint8: + # If the image is an unscaled tensor. + image = image.type("torch.FloatTensor") / 255 + + # Put the image tensor on the device the model weights are on. 
+ image = image.to(self.device) + image_reconstructed, _ = self.forward(image) + + return image_reconstructed diff --git a/src/text_recognizer/networks/__init__.py b/src/text_recognizer/networks/__init__.py index 2b624bb..bac5d28 100644 --- a/src/text_recognizer/networks/__init__.py +++ b/src/text_recognizer/networks/__init__.py @@ -1,4 +1,5 @@ """Network modules.""" +from .cnn import CNN from .cnn_transformer import CNNTransformer from .crnn import ConvolutionalRecurrentNetwork from .ctc import greedy_decoder @@ -7,15 +8,19 @@ from .lenet import LeNet from .metrics import accuracy, cer, wer from .mlp import MLP from .residual_network import ResidualNetwork, ResidualNetworkEncoder +from .transducer import TDS2d from .transformer import Transformer from .unet import UNet from .util import sliding_window from .vit import ViT +from .vq_transformer import VQTransformer +from .vqvae import VQVAE from .wide_resnet import WideResidualNetwork __all__ = [ "accuracy", "cer", + "CNN", "CNNTransformer", "ConvolutionalRecurrentNetwork", "DenseNet", @@ -27,8 +32,11 @@ __all__ = [ "ResidualNetworkEncoder", "sliding_window", "UNet", + "TDS2d", "Transformer", "ViT", + "VQTransformer", + "VQVAE", "wer", "WideResidualNetwork", ] diff --git a/src/text_recognizer/networks/cnn.py b/src/text_recognizer/networks/cnn.py new file mode 100644 index 0000000..1807bb9 --- /dev/null +++ b/src/text_recognizer/networks/cnn.py @@ -0,0 +1,101 @@ +"""Implementation of a simple backbone cnn network.""" +from typing import Callable, Dict, Optional, Tuple + +from einops.layers.torch import Rearrange +import torch +from torch import nn + +from text_recognizer.networks.util import activation_function + + +class CNN(nn.Module): + """LeNet network for character prediction.""" + + def __init__( + self, + channels: Tuple[int, ...] = (1, 32, 64, 128), + kernel_sizes: Tuple[int, ...] = (4, 4, 4), + strides: Tuple[int, ...] = (2, 2, 2), + max_pool_kernel: int = 2, + dropout_rate: float = 0.2, + activation: Optional[str] = "relu", + ) -> None: + """Initialization of the LeNet network. + + Args: + channels (Tuple[int, ...]): Channels in the convolutional layers. Defaults to (1, 32, 64). + kernel_sizes (Tuple[int, ...]): Kernel sizes in the convolutional layers. Defaults to (3, 3, 2). + strides (Tuple[int, ...]): Stride length of the convolutional filter. Defaults to (2, 2, 2). + max_pool_kernel (int): 2D max pooling kernel. Defaults to 2. + dropout_rate (float): The dropout rate. Defaults to 0.2. + activation (Optional[str]): The name of non-linear activation function. Defaults to relu. + + Raises: + RuntimeError: if the number of hyperparameters does not match in length. + + """ + super().__init__() + + if len(channels) - 1 != len(kernel_sizes) and len(kernel_sizes) != len(strides): + raise RuntimeError("The number of the hyperparameters does not match.") + + self.cnn = self._build_network( + channels, kernel_sizes, strides, max_pool_kernel, dropout_rate, activation, + ) + + def _build_network( + self, + channels: Tuple[int, ...], + kernel_sizes: Tuple[int, ...], + strides: Tuple[int, ...], + max_pool_kernel: int, + dropout_rate: float, + activation: str, + ) -> nn.Sequential: + # Load activation function. + activation_fn = activation_function(activation) + + channels = list(channels) + in_channels = channels.pop(0) + configuration = zip(channels, kernel_sizes, strides) + + modules = nn.ModuleList([]) + + for i, (out_channels, kernel_size, stride) in enumerate(configuration): + # Add max pool to reduce output size. 
+ if i == len(channels) // 2: + modules.append(nn.MaxPool2d(max_pool_kernel)) + if i == 0: + modules.append( + nn.Conv2d( + in_channels, out_channels, kernel_size, stride=stride, padding=1 + ) + ) + else: + modules.append( + nn.Sequential( + activation_fn, + nn.BatchNorm2d(in_channels), + nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=1, + ), + ) + ) + + if dropout_rate: + modules.append(nn.Dropout2d(p=dropout_rate)) + + in_channels = out_channels + + return nn.Sequential(*modules) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """The feedforward pass.""" + # If batch dimenstion is missing, it needs to be added. + if len(x.shape) < 4: + x = x[(None,) * (4 - len(x.shape))] + return self.cnn(x) diff --git a/src/text_recognizer/networks/cnn_transformer.py b/src/text_recognizer/networks/cnn_transformer.py index 43e5403..7133c26 100644 --- a/src/text_recognizer/networks/cnn_transformer.py +++ b/src/text_recognizer/networks/cnn_transformer.py @@ -29,14 +29,22 @@ class CNNTransformer(nn.Module): backbone: str, backbone_args: Optional[Dict] = None, activation: str = "gelu", + pool_kernel: Optional[Tuple[int, int]] = None, ) -> None: super().__init__() self.trg_pad_index = trg_pad_index self.vocab_size = vocab_size self.backbone = configure_backbone(backbone, backbone_args) + + if pool_kernel is not None: + self.max_pool = nn.MaxPool2d(pool_kernel, stride=2) + else: + self.max_pool = None + self.character_embedding = nn.Embedding(self.vocab_size, hidden_dim) self.src_position_embedding = nn.Parameter(torch.randn(1, max_len, hidden_dim)) + self.pos_dropout = nn.Dropout(p=dropout_rate) self.trg_position_encoding = PositionalEncoding(hidden_dim, dropout_rate) nn.init.normal_(self.character_embedding.weight, std=0.02) @@ -98,18 +106,23 @@ class CNNTransformer(nn.Module): # If batch dimension is missing, it needs to be added. if len(src.shape) < 4: src = src[(None,) * (4 - len(src.shape))] + src = self.backbone(src) + if self.max_pool is not None: + src = self.max_pool(src) + if self.adaptive_pool is not None: src = rearrange(src, "b c h w -> b w c h") src = self.adaptive_pool(src) src = src.squeeze(3) else: - src = rearrange(src, "b c h w -> b (w h) c") + src = rearrange(src, "b c h w -> b (h w) c") b, t, _ = src.shape src += self.src_position_embedding[:, :t] + src = self.pos_dropout(src) return src diff --git a/src/text_recognizer/networks/metrics.py b/src/text_recognizer/networks/metrics.py index ffad792..2605731 100644 --- a/src/text_recognizer/networks/metrics.py +++ b/src/text_recognizer/networks/metrics.py @@ -1,4 +1,7 @@ """Utility functions for models.""" +from typing import Optional + +from einops import rearrange import Levenshtein as Lev import torch from torch import Tensor @@ -32,22 +35,33 @@ def accuracy(outputs: Tensor, labels: Tensor, pad_index: int = 53) -> float: return acc -def cer(outputs: Tensor, targets: Tensor) -> float: +def cer( + outputs: Tensor, + targets: Tensor, + batch_size: Optional[int] = None, + blank_label: Optional[int] = int, +) -> float: """Computes the character error rate. Args: outputs (Tensor): The output from the network. targets (Tensor): Ground truth labels. + batch_size (Optional[int]): Batch size if target and output has been flattend. + blank_label (Optional[int]): The blank character to be ignored. Defaults to 79. Returns: float: The cer for the batch. 
""" + if len(outputs.shape) == 2 and len(targets.shape) == 1 and batch_size is not None: + targets = rearrange(targets, "(b t) -> b t", b=batch_size) + outputs = rearrange(outputs, "(b t) v -> t b v", b=batch_size) + target_lengths = torch.full( size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long, ) decoded_predictions, decoded_targets = greedy_decoder( - outputs, targets, target_lengths + outputs, targets, target_lengths, blank_label=blank_label, ) lev_dist = 0 @@ -63,22 +77,33 @@ def cer(outputs: Tensor, targets: Tensor) -> float: return lev_dist / len(decoded_predictions) -def wer(outputs: Tensor, targets: Tensor) -> float: +def wer( + outputs: Tensor, + targets: Tensor, + batch_size: Optional[int] = None, + blank_label: Optional[int] = int, +) -> float: """Computes the Word error rate. Args: outputs (Tensor): The output from the network. targets (Tensor): Ground truth labels. + batch_size (optional[int]): Batch size if target and output has been flattend. + blank_label (Optional[int]): The blank character to be ignored. Defaults to 79. Returns: float: The wer for the batch. """ + if len(outputs.shape) == 2 and len(targets.shape) == 1 and batch_size is not None: + targets = rearrange(targets, "(b t) -> b t", b=batch_size) + outputs = rearrange(outputs, "(b t) v -> t b v", b=batch_size) + target_lengths = torch.full( size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long, ) decoded_predictions, decoded_targets = greedy_decoder( - outputs, targets, target_lengths + outputs, targets, target_lengths, blank_label=blank_label, ) lev_dist = 0 diff --git a/src/text_recognizer/networks/transducer/__init__.py b/src/text_recognizer/networks/transducer/__init__.py new file mode 100644 index 0000000..fdd6662 --- /dev/null +++ b/src/text_recognizer/networks/transducer/__init__.py @@ -0,0 +1,2 @@ +"""Transducer modules.""" +from .tds_conv import TDS2d diff --git a/src/text_recognizer/networks/transducer/tds_conv.py b/src/text_recognizer/networks/transducer/tds_conv.py new file mode 100644 index 0000000..018caf2 --- /dev/null +++ b/src/text_recognizer/networks/transducer/tds_conv.py @@ -0,0 +1,205 @@ +"""Time-Depth Separable Convolutions. + +References: + https://arxiv.org/abs/1904.02619 + https://arxiv.org/pdf/2010.01003.pdf + +Code stolen from: + https://github.com/facebookresearch/gtn_applications + + +""" +from typing import List, Tuple + +from einops import rearrange +import gtn +import numpy as np +import torch +from torch import nn +from torch import Tensor + + +class TDSBlock2d(nn.Module): + """Internal block of a 2D TDSC network.""" + + def __init__( + self, + in_channels: int, + img_depth: int, + kernel_size: Tuple[int], + dropout_rate: float, + ) -> None: + super().__init__() + + self.in_channels = in_channels + self.img_depth = img_depth + self.kernel_size = kernel_size + self.dropout_rate = dropout_rate + self.fc_dim = in_channels * img_depth + + # Network placeholders. + self.conv = None + self.mlp = None + self.instance_norm = None + + self._build_block() + + def _build_block(self) -> None: + # Convolutional block. + self.conv = nn.Sequential( + nn.Conv3d( + in_channels=self.in_channels, + out_channels=self.in_channels, + kernel_size=(1, self.kernel_size[0], self.kernel_size[1]), + padding=(0, self.kernel_size[0] // 2, self.kernel_size[1] // 2), + ), + nn.ReLU(inplace=True), + nn.Dropout(self.dropout_rate), + ) + + # MLP block. 
+ self.mlp = nn.Sequential( + nn.Linear(self.fc_dim, self.fc_dim), + nn.ReLU(inplace=True), + nn.Dropout(self.dropout_rate), + nn.Linear(self.fc_dim, self.fc_dim), + nn.Dropout(self.dropout_rate), + ) + + # Instance norm. + self.instance_norm = nn.ModuleList( + [ + nn.InstanceNorm2d(self.fc_dim, affine=True), + nn.InstanceNorm2d(self.fc_dim, affine=True), + ] + ) + + def forward(self, x: Tensor) -> Tensor: + """Forward pass. + + Args: + x (Tensor): Input tensor. + + Shape: + - x: :math: `(B, CD, H, W)` + + Returns: + Tensor: Output tensor. + + """ + B, CD, H, W = x.shape + C, D = self.in_channels, self.img_depth + residual = x + x = rearrange(x, "b (c d) h w -> b c d h w", c=C, d=D) + x = self.conv(x) + x = rearrange(x, "b c d h w -> b (c d) h w") + x += residual + + x = self.instance_norm[0](x) + + x = self.mlp(x.transpose(1, 3)).transpose(1, 3) + x + x + self.instance_norm[1](x) + + # Output shape: [B, CD, H, W] + return x + + +class TDS2d(nn.Module): + """TDS Netowrk. + + Structure is the following: + Downsample layer -> TDS2d group -> ... -> Linear output layer + + + """ + + def __init__( + self, + input_dim: int, + output_dim: int, + depth: int, + tds_groups: Tuple[int], + kernel_size: Tuple[int], + dropout_rate: float, + in_channels: int = 1, + ) -> None: + super().__init__() + + self.in_channels = in_channels + self.input_dim = input_dim + self.output_dim = output_dim + self.depth = depth + self.tds_groups = tds_groups + self.kernel_size = kernel_size + self.dropout_rate = dropout_rate + + self.tds = None + self.fc = None + + def _build_network(self) -> None: + + modules = [] + stride_h = np.prod([grp["stride"][0] for grp in self.tds_groups]) + if self.input_dim % stride_h: + raise RuntimeError( + f"Image height not divisible by total stride {stride_h}." + ) + + for tds_group in self.tds_groups: + # Add downsample layer. + out_channels = self.depth * tds_group["channels"] + modules.extend( + [ + nn.Conv2d( + in_channels=self.in_channels, + out_channels=out_channels, + kernel_size=self.kernel_size, + padding=(self.kernel_size[0] // 2, self.kernel_size[1] // 2), + stride=tds_group["stride"], + ), + nn.ReLU(inplace=True), + nn.Dropout(self.dropout_rate), + nn.InstanceNorm2d(out_channels, affine=True), + ] + ) + + for _ in range(tds_group["num_blocks"]): + modules.append( + TDSBlock2d( + tds_group["channels"], + self.depth, + self.kernel_size, + self.dropout_rate, + ) + ) + + self.in_channels = out_channels + + self.tds = nn.Sequential(*modules) + self.fc = nn.Linear( + self.in_channels * self.input_dim // stride_h, self.output_dim + ) + + def forward(self, x: Tensor) -> Tensor: + """Forward pass. + + Args: + x (Tensor): Input tensor. + + Shape: + - x: :math: `(B, H, W)` + + Returns: + Tensor: Output tensor. 
+ + """ + B, H, W = x.shape + x = rearrange( + x, "b (h1 h2) w -> b h1 h2 w", h1=self.in_channels, h2=H // self.in_channels + ) + x = self.tds(x) + + # x shape: [B, C, H, W] + x = rearrange(x, "b c h w -> b w (c h)") + + return self.fc(x) diff --git a/src/text_recognizer/networks/util.py b/src/text_recognizer/networks/util.py index 711a952..131a6b4 100644 --- a/src/text_recognizer/networks/util.py +++ b/src/text_recognizer/networks/util.py @@ -65,13 +65,18 @@ def configure_backbone(backbone: str, backbone_args: Dict) -> Type[nn.Module]: network_args = state_dict["network_args"] weights = state_dict["model_state"] + freeze = False + if "freeze" in backbone_args and backbone_args["freeze"] is True: + backbone_args.pop("freeze") + freeze = True + network_args = backbone_args + # Initializes the network with trained weights. backbone = backbone_(**network_args) backbone.load_state_dict(weights) - if "freeze" in backbone_args and backbone_args["freeze"] is True: + if freeze: for params in backbone.parameters(): params.requires_grad = False - else: backbone_ = getattr(network_module, backbone) backbone = backbone_(**backbone_args) diff --git a/src/text_recognizer/networks/vq_transformer.py b/src/text_recognizer/networks/vq_transformer.py new file mode 100644 index 0000000..c673d96 --- /dev/null +++ b/src/text_recognizer/networks/vq_transformer.py @@ -0,0 +1,150 @@ +"""A VQ-Transformer for image to text recognition.""" +from typing import Dict, Optional, Tuple + +from einops import rearrange, repeat +import torch +from torch import nn +from torch import Tensor + +from text_recognizer.networks.transformer import PositionalEncoding, Transformer +from text_recognizer.networks.util import activation_function +from text_recognizer.networks.util import configure_backbone +from text_recognizer.networks.vqvae.encoder import _ResidualBlock + + +class VQTransformer(nn.Module): + """VQ+Transfomer for image to character sequence prediction.""" + + def __init__( + self, + num_encoder_layers: int, + num_decoder_layers: int, + hidden_dim: int, + vocab_size: int, + num_heads: int, + adaptive_pool_dim: Tuple, + expansion_dim: int, + dropout_rate: float, + trg_pad_index: int, + max_len: int, + backbone: str, + backbone_args: Optional[Dict] = None, + activation: str = "gelu", + ) -> None: + super().__init__() + + # Configure vector quantized backbone. + self.backbone = configure_backbone(backbone, backbone_args) + self.conv = nn.Sequential( + nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=2), + nn.ReLU(inplace=True), + ) + + # Configure embeddings for Transformer network. + self.trg_pad_index = trg_pad_index + self.vocab_size = vocab_size + self.character_embedding = nn.Embedding(self.vocab_size, hidden_dim) + self.src_position_embedding = nn.Parameter(torch.randn(1, max_len, hidden_dim)) + self.trg_position_encoding = PositionalEncoding(hidden_dim, dropout_rate) + nn.init.normal_(self.character_embedding.weight, std=0.02) + + self.adaptive_pool = ( + nn.AdaptiveAvgPool2d((adaptive_pool_dim)) if adaptive_pool_dim else None + ) + + self.transformer = Transformer( + num_encoder_layers, + num_decoder_layers, + hidden_dim, + num_heads, + expansion_dim, + dropout_rate, + activation, + ) + + self.head = nn.Sequential(nn.Linear(hidden_dim, vocab_size),) + + def _create_trg_mask(self, trg: Tensor) -> Tensor: + # Move this outside the transformer. 
+ trg_pad_mask = (trg != self.trg_pad_index)[:, None, None] + trg_len = trg.shape[1] + trg_sub_mask = torch.tril( + torch.ones((trg_len, trg_len), device=trg.device) + ).bool() + trg_mask = trg_pad_mask & trg_sub_mask + return trg_mask + + def encoder(self, src: Tensor) -> Tensor: + """Forward pass with the encoder of the transformer.""" + return self.transformer.encoder(src) + + def decoder(self, trg: Tensor, memory: Tensor, trg_mask: Tensor) -> Tensor: + """Forward pass with the decoder of the transformer + classification head.""" + return self.head( + self.transformer.decoder(trg=trg, memory=memory, trg_mask=trg_mask) + ) + + def extract_image_features(self, src: Tensor) -> Tuple[Tensor, Tensor]: + """Extracts image features with a backbone neural network. + + It seem like the winning idea was to swap channels and width dimension and collapse + the height dimension. The transformer is learning like a baby with this implementation!!! :D + Ohhhh, the joy I am experiencing right now!! Bring in the beers! :D :D :D + + Args: + src (Tensor): Input tensor. + + Returns: + Tensor: The input src to the transformer and the vq loss. + + """ + # If batch dimension is missing, it needs to be added. + if len(src.shape) < 4: + src = src[(None,) * (4 - len(src.shape))] + src, vq_loss = self.backbone.encode(src) + # src = self.backbone.decoder.res_block(src) + src = self.conv(src) + + if self.adaptive_pool is not None: + src = rearrange(src, "b c h w -> b w c h") + src = self.adaptive_pool(src) + src = src.squeeze(3) + else: + src = rearrange(src, "b c h w -> b (w h) c") + + b, t, _ = src.shape + + src += self.src_position_embedding[:, :t] + + return src, vq_loss + + def target_embedding(self, trg: Tensor) -> Tensor: + """Encodes target tensor with embedding and postion. + + Args: + trg (Tensor): Target tensor. + + Returns: + Tensor: Encoded target tensor. 
+ + """ + trg = self.character_embedding(trg.long()) + trg = self.trg_position_encoding(trg) + return trg + + def decode_image_features( + self, image_features: Tensor, trg: Optional[Tensor] = None + ) -> Tensor: + """Takes images features from the backbone and decodes them with the transformer.""" + trg_mask = self._create_trg_mask(trg) + trg = self.target_embedding(trg) + out = self.transformer(image_features, trg, trg_mask=trg_mask) + + logits = self.head(out) + return logits + + def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor: + """Forward pass with CNN transfomer.""" + image_features, vq_loss = self.extract_image_features(x) + logits = self.decode_image_features(image_features, trg) + return logits, vq_loss diff --git a/src/text_recognizer/networks/vqvae/__init__.py b/src/text_recognizer/networks/vqvae/__init__.py index e1f05fa..763953c 100644 --- a/src/text_recognizer/networks/vqvae/__init__.py +++ b/src/text_recognizer/networks/vqvae/__init__.py @@ -1 +1,5 @@ """VQ-VAE module.""" +from .decoder import Decoder +from .encoder import Encoder +from .vector_quantizer import VectorQuantizer +from .vqvae import VQVAE diff --git a/src/text_recognizer/networks/vqvae/decoder.py b/src/text_recognizer/networks/vqvae/decoder.py new file mode 100644 index 0000000..8847aba --- /dev/null +++ b/src/text_recognizer/networks/vqvae/decoder.py @@ -0,0 +1,133 @@ +"""CNN decoder for the VQ-VAE.""" + +from typing import List, Optional, Tuple, Type + +import torch +from torch import nn +from torch import Tensor + +from text_recognizer.networks.util import activation_function +from text_recognizer.networks.vqvae.encoder import _ResidualBlock + + +class Decoder(nn.Module): + """A CNN encoder network.""" + + def __init__( + self, + channels: List[int], + kernel_sizes: List[int], + strides: List[int], + num_residual_layers: int, + embedding_dim: int, + upsampling: Optional[List[List[int]]] = None, + activation: str = "leaky_relu", + dropout_rate: float = 0.0, + ) -> None: + super().__init__() + + if dropout_rate: + if activation == "selu": + dropout = nn.AlphaDropout(p=dropout_rate) + else: + dropout = nn.Dropout(p=dropout_rate) + else: + dropout = None + + self.upsampling = upsampling + + self.res_block = nn.ModuleList([]) + self.upsampling_block = nn.ModuleList([]) + + self.embedding_dim = embedding_dim + activation = activation_function(activation) + + # Configure encoder. 
+ self.decoder = self._build_decoder( + channels, kernel_sizes, strides, num_residual_layers, activation, dropout, + ) + + def _build_decompression_block( + self, + in_channels: int, + channels: int, + kernel_sizes: List[int], + strides: List[int], + activation: Type[nn.Module], + dropout: Optional[Type[nn.Module]], + ) -> nn.ModuleList: + modules = nn.ModuleList([]) + configuration = zip(channels, kernel_sizes, strides) + for i, (out_channels, kernel_size, stride) in enumerate(configuration): + modules.append( + nn.Sequential( + nn.ConvTranspose2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=1, + ), + activation, + ) + ) + + if i < len(self.upsampling): + modules.append(nn.Upsample(size=self.upsampling[i]),) + + if dropout is not None: + modules.append(dropout) + + in_channels = out_channels + + modules.extend( + nn.Sequential( + nn.ConvTranspose2d( + in_channels, 1, kernel_size=kernel_size, stride=stride, padding=1 + ), + nn.Tanh(), + ) + ) + + return modules + + def _build_decoder( + self, + channels: int, + kernel_sizes: List[int], + strides: List[int], + num_residual_layers: int, + activation: Type[nn.Module], + dropout: Optional[Type[nn.Module]], + ) -> nn.Sequential: + + self.res_block.append( + nn.Conv2d(self.embedding_dim, channels[0], kernel_size=1, stride=1,) + ) + + # Bottleneck module. + self.res_block.extend( + nn.ModuleList( + [ + _ResidualBlock(channels[0], channels[0], dropout) + for i in range(num_residual_layers) + ] + ) + ) + + # Decompression module + self.upsampling_block.extend( + self._build_decompression_block( + channels[0], channels[1:], kernel_sizes, strides, activation, dropout + ) + ) + + self.res_block = nn.Sequential(*self.res_block) + self.upsampling_block = nn.Sequential(*self.upsampling_block) + + return nn.Sequential(self.res_block, self.upsampling_block) + + def forward(self, z_q: Tensor) -> Tensor: + """Reconstruct input from given codes.""" + x_reconstruction = self.decoder(z_q) + return x_reconstruction diff --git a/src/text_recognizer/networks/vqvae/encoder.py b/src/text_recognizer/networks/vqvae/encoder.py index 60c4c43..d3adac5 100644 --- a/src/text_recognizer/networks/vqvae/encoder.py +++ b/src/text_recognizer/networks/vqvae/encoder.py @@ -1,6 +1,5 @@ """CNN encoder for the VQ-VAE.""" - -from typing import List, Optional, Type +from typing import List, Optional, Tuple, Type import torch from torch import nn @@ -12,16 +11,12 @@ from text_recognizer.networks.vqvae.vector_quantizer import VectorQuantizer class _ResidualBlock(nn.Module): def __init__( - self, - in_channels: int, - out_channels: int, - activation: Type[nn.Module], - dropout: Optional[Type[nn.Module]], + self, in_channels: int, out_channels: int, dropout: Optional[Type[nn.Module]], ) -> None: super().__init__() self.block = [ nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False), - activation, + nn.ReLU(inplace=True), nn.Conv2d(out_channels, out_channels, kernel_size=1, bias=False), ] @@ -42,23 +37,111 @@ class Encoder(nn.Module): self, in_channels: int, channels: List[int], + kernel_sizes: List[int], + strides: List[int], num_residual_layers: int, embedding_dim: int, num_embeddings: int, beta: float = 0.25, - activation: str = "elu", + activation: str = "leaky_relu", dropout_rate: float = 0.0, ) -> None: super().__init__() - pass - # if dropout_rate: - # if activation == "selu": - # dropout = nn.AlphaDropout(p=dropout_rate) - # else: - # dropout = nn.Dropout(p=dropout_rate) - # else: - # dropout = None - - def _build_encoder(self) -> 
nn.Sequential: - # TODO: Continue to implement encoder. - pass + + if dropout_rate: + if activation == "selu": + dropout = nn.AlphaDropout(p=dropout_rate) + else: + dropout = nn.Dropout(p=dropout_rate) + else: + dropout = None + + self.embedding_dim = embedding_dim + self.num_embeddings = num_embeddings + self.beta = beta + activation = activation_function(activation) + + # Configure encoder. + self.encoder = self._build_encoder( + in_channels, + channels, + kernel_sizes, + strides, + num_residual_layers, + activation, + dropout, + ) + + # Configure Vector Quantizer. + self.vector_quantizer = VectorQuantizer( + self.num_embeddings, self.embedding_dim, self.beta + ) + + def _build_compression_block( + self, + in_channels: int, + channels: int, + kernel_sizes: List[int], + strides: List[int], + activation: Type[nn.Module], + dropout: Optional[Type[nn.Module]], + ) -> nn.ModuleList: + modules = nn.ModuleList([]) + configuration = zip(channels, kernel_sizes, strides) + for out_channels, kernel_size, stride in configuration: + modules.append( + nn.Sequential( + nn.Conv2d( + in_channels, out_channels, kernel_size, stride=stride, padding=1 + ), + activation, + ) + ) + + if dropout is not None: + modules.append(dropout) + + in_channels = out_channels + + return modules + + def _build_encoder( + self, + in_channels: int, + channels: int, + kernel_sizes: List[int], + strides: List[int], + num_residual_layers: int, + activation: Type[nn.Module], + dropout: Optional[Type[nn.Module]], + ) -> nn.Sequential: + encoder = nn.ModuleList([]) + + # compression module + encoder.extend( + self._build_compression_block( + in_channels, channels, kernel_sizes, strides, activation, dropout + ) + ) + + # Bottleneck module. + encoder.extend( + nn.ModuleList( + [ + _ResidualBlock(channels[-1], channels[-1], dropout) + for i in range(num_residual_layers) + ] + ) + ) + + encoder.append( + nn.Conv2d(channels[-1], self.embedding_dim, kernel_size=1, stride=1,) + ) + + return nn.Sequential(*encoder) + + def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]: + """Encodes input into a discrete representation.""" + z_e = self.encoder(x) + z_q, vq_loss = self.vector_quantizer(z_e) + return z_q, vq_loss diff --git a/src/text_recognizer/networks/vqvae/vector_quantizer.py b/src/text_recognizer/networks/vqvae/vector_quantizer.py index 25e5583..f92c7ee 100644 --- a/src/text_recognizer/networks/vqvae/vector_quantizer.py +++ b/src/text_recognizer/networks/vqvae/vector_quantizer.py @@ -26,7 +26,7 @@ class VectorQuantizer(nn.Module): self.embedding = nn.Embedding(self.K, self.D) # Initialize the codebook. - self.embedding.weight.uniform_(-1 / self.K, 1 / self.K) + nn.init.uniform_(self.embedding.weight, -1 / self.K, 1 / self.K) def discretization_bottleneck(self, latent: Tensor) -> Tensor: """Computes the code nearest to the latent representation. 
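The encoder above now ends in a 1x1 convolution to embedding_dim channels and returns (z_q, vq_loss) from the VectorQuantizer, whose codebook initialization is switched to nn.init.uniform_ in the hunk above; the quantizer's forward pass itself is not part of this diff. For reference, a standard VQ-VAE quantization step (nearest codebook entry, codebook plus beta-weighted commitment loss, straight-through gradients) is sketched below. This is an illustration of what the (z_q, vq_loss) pair conventionally means, not the repository's implementation:

```python
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn


def quantize(z_e: torch.Tensor, codebook: nn.Embedding, beta: float = 0.25):
    """Generic VQ step: z_e is (B, D, H, W); codebook holds K vectors of dim D."""
    b, d, h, w = z_e.shape
    flat = rearrange(z_e, "b d h w -> (b h w) d")

    # Nearest codebook entry per spatial position (L2 distance).
    indices = torch.cdist(flat, codebook.weight).argmin(dim=1)
    z_q = rearrange(codebook(indices), "(b h w) d -> b d h w", b=b, h=h, w=w)

    # Codebook loss + beta-weighted commitment loss.
    vq_loss = F.mse_loss(z_q, z_e.detach()) + beta * F.mse_loss(z_q.detach(), z_e)

    # Straight-through estimator: gradients flow from z_q back to z_e.
    z_q = z_e + (z_q - z_e).detach()
    return z_q, vq_loss
```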
diff --git a/src/text_recognizer/networks/vqvae/vqvae.py b/src/text_recognizer/networks/vqvae/vqvae.py new file mode 100644 index 0000000..50448b4 --- /dev/null +++ b/src/text_recognizer/networks/vqvae/vqvae.py @@ -0,0 +1,74 @@ +"""The VQ-VAE.""" + +from typing import List, Optional, Tuple, Type + +import torch +from torch import nn +from torch import Tensor + +from text_recognizer.networks.vqvae import Decoder, Encoder + + +class VQVAE(nn.Module): + """Vector Quantized Variational AutoEncoder.""" + + def __init__( + self, + in_channels: int, + channels: List[int], + kernel_sizes: List[int], + strides: List[int], + num_residual_layers: int, + embedding_dim: int, + num_embeddings: int, + upsampling: Optional[List[List[int]]] = None, + beta: float = 0.25, + activation: str = "leaky_relu", + dropout_rate: float = 0.0, + ) -> None: + super().__init__() + + # configure encoder. + self.encoder = Encoder( + in_channels, + channels, + kernel_sizes, + strides, + num_residual_layers, + embedding_dim, + num_embeddings, + beta, + activation, + dropout_rate, + ) + + # Configure decoder. + channels.reverse() + kernel_sizes.reverse() + strides.reverse() + self.decoder = Decoder( + channels, + kernel_sizes, + strides, + num_residual_layers, + embedding_dim, + upsampling, + activation, + dropout_rate, + ) + + def encode(self, x: Tensor) -> Tuple[Tensor, Tensor]: + """Encodes input to a latent code.""" + return self.encoder(x) + + def decode(self, z_q: Tensor) -> Tensor: + """Reconstructs input from latent codes.""" + return self.decoder(z_q) + + def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]: + """Compresses and decompresses input.""" + if len(x.shape) < 4: + x = x[(None,) * (4 - len(x.shape))] + z_q, vq_loss = self.encode(x) + x_reconstruction = self.decode(z_q) + return x_reconstruction, vq_loss diff --git a/src/text_recognizer/weights/VQVAEModel_IamLinesDataset_VQVAE_weights.pt b/src/text_recognizer/weights/VQVAEModel_IamLinesDataset_VQVAE_weights.pt Binary files differnew file mode 100644 index 0000000..b5295c2 --- /dev/null +++ b/src/text_recognizer/weights/VQVAEModel_IamLinesDataset_VQVAE_weights.pt |
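The weights file above pairs with the VQVAE network added in this commit. A minimal forward-pass sketch of that network follows; every hyperparameter value here is an illustrative assumption (the trained configuration is not part of this diff), and the explicit upsampling sizes are assumed intermediate feature-map resolutions for a 28 x 952 line image:

```python
import torch

from text_recognizer.networks import VQVAE

# Illustrative hyperparameters only; not the configuration behind the shipped weights.
vqvae = VQVAE(
    in_channels=1,
    channels=[32, 64, 64],
    kernel_sizes=[4, 4, 4],
    strides=[2, 2, 2],
    num_residual_layers=2,
    embedding_dim=64,
    num_embeddings=512,
    upsampling=[[7, 238], [14, 476]],  # assumed decoder upsampling targets
    beta=0.25,
    activation="leaky_relu",
    dropout_rate=0.0,
)

x = torch.randn(1, 1, 28, 952)        # grayscale line image (size is an assumption)
x_reconstruction, vq_loss = vqvae(x)  # reconstruction plus the vector-quantization loss
```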