summaryrefslogtreecommitdiff
path: root/src/text_recognizer
diff options
context:
space:
mode:
authoraktersnurra <gustaf.rydholm@gmail.com>2021-01-24 22:14:17 +0100
committeraktersnurra <gustaf.rydholm@gmail.com>2021-01-24 22:14:17 +0100
commit4a54d7e690897dd6e6c719fb908fd371a44c2952 (patch)
tree04722ac94b9c3960baa5db7939d7ef01dbf535a6 /src/text_recognizer
parentd691b548cd0b6fc4ea184d64261f633789fee021 (diff)
Many updates, cool stuff on the way.
Diffstat (limited to 'src/text_recognizer')
-rw-r--r--src/text_recognizer/datasets/__init__.py3
-rw-r--r--src/text_recognizer/datasets/iam_preprocessor.py196
-rw-r--r--src/text_recognizer/datasets/transforms.py45
-rw-r--r--src/text_recognizer/models/__init__.py2
-rw-r--r--src/text_recognizer/models/base.py2
-rw-r--r--src/text_recognizer/models/transformer_model.py12
-rw-r--r--src/text_recognizer/models/vqvae_model.py80
-rw-r--r--src/text_recognizer/networks/__init__.py8
-rw-r--r--src/text_recognizer/networks/cnn.py101
-rw-r--r--src/text_recognizer/networks/cnn_transformer.py15
-rw-r--r--src/text_recognizer/networks/metrics.py33
-rw-r--r--src/text_recognizer/networks/transducer/__init__.py2
-rw-r--r--src/text_recognizer/networks/transducer/tds_conv.py205
-rw-r--r--src/text_recognizer/networks/util.py9
-rw-r--r--src/text_recognizer/networks/vq_transformer.py150
-rw-r--r--src/text_recognizer/networks/vqvae/__init__.py4
-rw-r--r--src/text_recognizer/networks/vqvae/decoder.py133
-rw-r--r--src/text_recognizer/networks/vqvae/encoder.py125
-rw-r--r--src/text_recognizer/networks/vqvae/vector_quantizer.py2
-rw-r--r--src/text_recognizer/networks/vqvae/vqvae.py74
-rw-r--r--src/text_recognizer/weights/VQVAEModel_IamLinesDataset_VQVAE_weights.ptbin0 -> 21687018 bytes
21 files changed, 1167 insertions, 34 deletions
diff --git a/src/text_recognizer/datasets/__init__.py b/src/text_recognizer/datasets/__init__.py
index d8372e3..a6c1c59 100644
--- a/src/text_recognizer/datasets/__init__.py
+++ b/src/text_recognizer/datasets/__init__.py
@@ -8,6 +8,7 @@ from .emnist_lines_dataset import (
from .iam_dataset import IamDataset
from .iam_lines_dataset import IamLinesDataset
from .iam_paragraphs_dataset import IamParagraphsDataset
+from .iam_preprocessor import load_metadata, Preprocessor
from .transforms import AddTokens, Transpose
from .util import (
_download_raw_dataset,
@@ -29,8 +30,10 @@ __all__ = [
"EmnistMapper",
"EmnistLinesDataset",
"get_samples_by_character",
+ "load_metadata",
"IamDataset",
"IamLinesDataset",
"IamParagraphsDataset",
+ "Preprocessor",
"Transpose",
]
diff --git a/src/text_recognizer/datasets/iam_preprocessor.py b/src/text_recognizer/datasets/iam_preprocessor.py
new file mode 100644
index 0000000..5a5136c
--- /dev/null
+++ b/src/text_recognizer/datasets/iam_preprocessor.py
@@ -0,0 +1,196 @@
+"""Preprocessor for extracting word letters from the IAM dataset.
+
+The code is mostly stolen from:
+
+ https://github.com/facebookresearch/gtn_applications/blob/master/datasets/iamdb.py
+
+"""
+
+import collections
+import itertools
+from pathlib import Path
+import re
+from typing import List, Optional, Union
+
+import click
+from loguru import logger
+import torch
+
+
+def load_metadata(
+ data_dir: Path, wordsep: str, use_words: bool = False
+) -> collections.defaultdict:
+ """Loads IAM metadata and returns it as a dictionary."""
+ forms = collections.defaultdict(list)
+ filename = "words.txt" if use_words else "lines.txt"
+
+ with open(data_dir / "ascii" / filename, "r") as f:
+ lines = (line.strip().split() for line in f if line[0] != "#")
+ for line in lines:
+ # Skip word segmentation errors.
+ if use_words and line[1] == "err":
+ continue
+ text = " ".join(line[8:])
+
+ # Remove garbage tokens:
+ text = text.replace("#", "")
+
+ # Swap word sep form | to wordsep
+ text = re.sub(r"\|+|\s", wordsep, text).strip(wordsep)
+ form_key = "-".join(line[0].split("-")[:2])
+ line_key = "-".join(line[0].split("-")[:3])
+ box_idx = 4 - use_words
+ box = tuple(int(val) for val in line[box_idx : box_idx + 4])
+ forms[form_key].append({"key": line_key, "box": box, "text": text})
+ return forms
+
+
+class Preprocessor:
+ """A preprocessor for the IAM dataset."""
+
+ # TODO: add lower case only to when generating...
+
+ def __init__(
+ self,
+ data_dir: Union[str, Path],
+ num_features: int,
+ tokens_path: Optional[Union[str, Path]] = None,
+ lexicon_path: Optional[Union[str, Path]] = None,
+ use_words: bool = False,
+ prepend_wordsep: bool = False,
+ ) -> None:
+ self.wordsep = "_"
+ self._use_word = use_words
+ self._prepend_wordsep = prepend_wordsep
+
+ self.data_dir = Path(data_dir)
+
+ self.forms = load_metadata(self.data_dir, self.wordsep, use_words=use_words)
+
+ # Load the set of graphemes:
+ graphemes = set()
+ for _, form in self.forms.items():
+ for line in form:
+ graphemes.update(line["text"].lower())
+ self.graphemes = sorted(graphemes)
+
+ # Build the token-to-index and index-to-token maps.
+ if tokens_path is not None:
+ with open(tokens_path, "r") as f:
+ self.tokens = [line.strip() for line in f]
+ else:
+ self.tokens = self.graphemes
+
+ if lexicon_path is not None:
+ with open(lexicon_path, "r") as f:
+ lexicon = (line.strip().split() for line in f)
+ lexicon = {line[0]: line[1:] for line in lexicon}
+ self.lexicon = lexicon
+ else:
+ self.lexicon = None
+
+ self.graphemes_to_index = {t: i for i, t in enumerate(self.graphemes)}
+ self.tokens_to_index = {t: i for i, t in enumerate(self.tokens)}
+ self.num_features = num_features
+ self.text = []
+
+ @property
+ def num_tokens(self) -> int:
+ """Returns the number or tokens."""
+ return len(self.tokens)
+
+ @property
+ def use_words(self) -> bool:
+ """If words are used."""
+ return self._use_word
+
+ def extract_train_text(self) -> None:
+ """Extracts training text."""
+ keys = []
+ with open(self.data_dir / "task" / "trainset.txt") as f:
+ keys.extend((line.strip() for line in f))
+
+ for _, examples in self.forms.items():
+ for example in examples:
+ if example["key"] not in keys:
+ continue
+ self.text.append(example["text"].lower())
+
+ def to_index(self, line: str) -> torch.LongTensor:
+ """Converts text to a tensor of indices."""
+ token_to_index = self.graphemes_to_index
+ if self.lexicon is not None:
+ if len(line) > 0:
+ # If the word is not found in the lexicon, fall back to letters.
+ line = [
+ t
+ for w in line.split(self.wordsep)
+ for t in self.lexicon.get(w, self.wordsep + w)
+ ]
+ token_to_index = self.tokens_to_index
+ if self._prepend_wordsep:
+ line = itertools.chain([self.wordsep], line)
+ return torch.LongTensor([token_to_index[t] for t in line])
+
+ def to_text(self, indices: List[int]) -> str:
+ """Converts indices to text."""
+ # Roughly the inverse of `to_index`
+ encoding = self.graphemes
+ if self.lexicon is not None:
+ encoding = self.tokens
+ return self._post_process(encoding[i] for i in indices)
+
+ def tokens_to_text(self, indices: List[int]) -> str:
+ """Converts tokens to text."""
+ return self._post_process(self.tokens[i] for i in indices)
+
+ def _post_process(self, indices: List[int]) -> str:
+ """A list join."""
+ return "".join(indices).strip(self.wordsep)
+
+
+@click.command()
+@click.option("--data_dir", type=str, default=None, help="Path to iam dataset")
+@click.option(
+ "--use_words", is_flag=True, help="Load word segmented dataset instead of lines"
+)
+@click.option(
+ "--save_text", type=str, default=None, help="Path to save parsed train text"
+)
+@click.option("--save_tokens", type=str, default=None, help="Path to save tokens")
+def cli(
+ data_dir: Optional[str],
+ use_words: bool,
+ save_text: Optional[str],
+ save_tokens: Optional[str],
+) -> None:
+ """CLI for extracting text data from the iam dataset."""
+ if data_dir is None:
+ data_dir = (
+ Path(__file__).resolve().parents[3] / "data" / "raw" / "iam" / "iamdb"
+ )
+ logger.debug(f"Using data dir: {data_dir}")
+ if not data_dir.exists():
+ raise RuntimeError(f"Could not locate iamdb directory at {data_dir}")
+ else:
+ data_dir = Path(data_dir)
+
+ preprocessor = Preprocessor(data_dir, 64, use_words=use_words)
+ preprocessor.extract_train_text()
+
+ processed_dir = data_dir.parents[2] / "processed" / "iam_lines"
+ logger.debug(f"Saving processed files at: {processed_dir}")
+
+ if save_text is not None:
+ logger.info("Saving training text")
+ with open(processed_dir / save_text, "w") as f:
+ f.write("\n".join(t for t in preprocessor.text))
+
+ if save_tokens is not None:
+ logger.info("Saving tokens")
+ with open(processed_dir / save_tokens, "w") as f:
+ f.write("\n".join(preprocessor.tokens))
+
+
+if __name__ == "__main__":
+ cli()
diff --git a/src/text_recognizer/datasets/transforms.py b/src/text_recognizer/datasets/transforms.py
index 8956b01..60987e0 100644
--- a/src/text_recognizer/datasets/transforms.py
+++ b/src/text_recognizer/datasets/transforms.py
@@ -1,14 +1,57 @@
"""Transforms for PyTorch datasets."""
+import random
+
import numpy as np
from PIL import Image
import torch
from torch import Tensor
import torch.nn.functional as F
-from torchvision.transforms import Compose, RandomAffine, RandomHorizontalFlip, ToTensor
+from torchvision import transforms
+from torchvision.transforms import (
+ ColorJitter,
+ Compose,
+ Normalize,
+ RandomAffine,
+ RandomHorizontalFlip,
+ RandomRotation,
+ ToPILImage,
+ ToTensor,
+)
from text_recognizer.datasets.util import EmnistMapper
+class RandomResizeCrop:
+ """Image transform with random resize and crop applied.
+
+ Stolen from
+
+ https://github.com/facebookresearch/gtn_applications/blob/master/datasets/iamdb.py
+
+ """
+
+ def __init__(self, jitter: int = 10, ratio: float = 0.5) -> None:
+ self.jitter = jitter
+ self.ratio = ratio
+
+ def __call__(self, img: np.ndarray) -> np.ndarray:
+ """Applies random crop and rotation to an image."""
+ w, h = img.size
+
+ # pad with white:
+ img = transforms.functional.pad(img, self.jitter, fill=255)
+
+ # crop at random (x, y):
+ x = self.jitter + random.randint(-self.jitter, self.jitter)
+ y = self.jitter + random.randint(-self.jitter, self.jitter)
+
+ # randomize aspect ratio:
+ size_w = w * random.uniform(1 - self.ratio, 1 + self.ratio)
+ size = (h, int(size_w))
+ img = transforms.functional.resized_crop(img, y, x, h, w, size)
+ return img
+
+
class Transpose:
"""Transposes the EMNIST image to the correct orientation."""
diff --git a/src/text_recognizer/models/__init__.py b/src/text_recognizer/models/__init__.py
index eb5dbce..7647d7e 100644
--- a/src/text_recognizer/models/__init__.py
+++ b/src/text_recognizer/models/__init__.py
@@ -5,6 +5,7 @@ from .crnn_model import CRNNModel
from .ctc_transformer_model import CTCTransformerModel
from .segmentation_model import SegmentationModel
from .transformer_model import TransformerModel
+from .vqvae_model import VQVAEModel
__all__ = [
"CharacterModel",
@@ -13,4 +14,5 @@ __all__ = [
"Model",
"SegmentationModel",
"TransformerModel",
+ "VQVAEModel",
]
diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py
index f2cd4b8..70f4cdb 100644
--- a/src/text_recognizer/models/base.py
+++ b/src/text_recognizer/models/base.py
@@ -332,7 +332,7 @@ class Model(ABC):
def summary(
self,
input_shape: Optional[Union[List, Tuple]] = None,
- depth: int = 4,
+ depth: int = 3,
device: Optional[str] = None,
) -> None:
"""Prints a summary of the network architecture."""
diff --git a/src/text_recognizer/models/transformer_model.py b/src/text_recognizer/models/transformer_model.py
index 12e497f..3f63053 100644
--- a/src/text_recognizer/models/transformer_model.py
+++ b/src/text_recognizer/models/transformer_model.py
@@ -6,9 +6,9 @@ import torch
from torch import nn
from torch import Tensor
from torch.utils.data import Dataset
-from torchvision.transforms import ToTensor
from text_recognizer.datasets import EmnistMapper
+import text_recognizer.datasets.transforms as transforms
from text_recognizer.models.base import Model
from text_recognizer.networks import greedy_decoder
@@ -60,13 +60,19 @@ class TransformerModel(Model):
eos_token=self.eos_token,
lower=self.lower,
)
- self.tensor_transform = ToTensor()
-
+ self.tensor_transform = transforms.Compose(
+ [transforms.ToTensor(), transforms.Normalize(mean=[0.912], std=[0.168])]
+ )
self.softmax = nn.Softmax(dim=2)
@torch.no_grad()
def _generate_sentence(self, image: Tensor) -> Tuple[List, float]:
src = self.network.extract_image_features(image)
+
+ # Added for vqvae transformer.
+ if isinstance(src, Tuple):
+ src = src[0]
+
memory = self.network.encoder(src)
confidence_of_predictions = []
diff --git a/src/text_recognizer/models/vqvae_model.py b/src/text_recognizer/models/vqvae_model.py
new file mode 100644
index 0000000..70f6f1f
--- /dev/null
+++ b/src/text_recognizer/models/vqvae_model.py
@@ -0,0 +1,80 @@
+"""Defines the VQVAEModel class."""
+from typing import Callable, Dict, Optional, Tuple, Type, Union
+
+import numpy as np
+import torch
+from torch import nn
+from torch.utils.data import Dataset
+from torchvision.transforms import ToTensor
+
+from text_recognizer.datasets import EmnistMapper
+from text_recognizer.models.base import Model
+
+
+class VQVAEModel(Model):
+ """Model for reconstructing images from codebook."""
+
+ def __init__(
+ self,
+ network_fn: Type[nn.Module],
+ dataset: Type[Dataset],
+ network_args: Optional[Dict] = None,
+ dataset_args: Optional[Dict] = None,
+ metrics: Optional[Dict] = None,
+ criterion: Optional[Callable] = None,
+ criterion_args: Optional[Dict] = None,
+ optimizer: Optional[Callable] = None,
+ optimizer_args: Optional[Dict] = None,
+ lr_scheduler: Optional[Callable] = None,
+ lr_scheduler_args: Optional[Dict] = None,
+ swa_args: Optional[Dict] = None,
+ device: Optional[str] = None,
+ ) -> None:
+ """Initializes the CharacterModel."""
+
+ super().__init__(
+ network_fn,
+ dataset,
+ network_args,
+ dataset_args,
+ metrics,
+ criterion,
+ criterion_args,
+ optimizer,
+ optimizer_args,
+ lr_scheduler,
+ lr_scheduler_args,
+ swa_args,
+ device,
+ )
+ self.pad_token = dataset_args["args"]["pad_token"]
+ if self._mapper is None:
+ self._mapper = EmnistMapper(pad_token=self.pad_token,)
+ self.tensor_transform = ToTensor()
+ self.softmax = nn.Softmax(dim=0)
+
+ @torch.no_grad()
+ def predict_on_image(self, image: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
+ """Reconstruction of image.
+
+ Args:
+ image (Union[np.ndarray, torch.Tensor]): An image containing a character.
+
+ Returns:
+ Tuple[str, float]: The predicted character and the confidence in the prediction.
+
+ """
+ self.eval()
+
+ if image.dtype == np.uint8:
+ # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1].
+ image = self.tensor_transform(image)
+ if image.dtype == torch.uint8:
+ # If the image is an unscaled tensor.
+ image = image.type("torch.FloatTensor") / 255
+
+ # Put the image tensor on the device the model weights are on.
+ image = image.to(self.device)
+ image_reconstructed, _ = self.forward(image)
+
+ return image_reconstructed
diff --git a/src/text_recognizer/networks/__init__.py b/src/text_recognizer/networks/__init__.py
index 2b624bb..bac5d28 100644
--- a/src/text_recognizer/networks/__init__.py
+++ b/src/text_recognizer/networks/__init__.py
@@ -1,4 +1,5 @@
"""Network modules."""
+from .cnn import CNN
from .cnn_transformer import CNNTransformer
from .crnn import ConvolutionalRecurrentNetwork
from .ctc import greedy_decoder
@@ -7,15 +8,19 @@ from .lenet import LeNet
from .metrics import accuracy, cer, wer
from .mlp import MLP
from .residual_network import ResidualNetwork, ResidualNetworkEncoder
+from .transducer import TDS2d
from .transformer import Transformer
from .unet import UNet
from .util import sliding_window
from .vit import ViT
+from .vq_transformer import VQTransformer
+from .vqvae import VQVAE
from .wide_resnet import WideResidualNetwork
__all__ = [
"accuracy",
"cer",
+ "CNN",
"CNNTransformer",
"ConvolutionalRecurrentNetwork",
"DenseNet",
@@ -27,8 +32,11 @@ __all__ = [
"ResidualNetworkEncoder",
"sliding_window",
"UNet",
+ "TDS2d",
"Transformer",
"ViT",
+ "VQTransformer",
+ "VQVAE",
"wer",
"WideResidualNetwork",
]
diff --git a/src/text_recognizer/networks/cnn.py b/src/text_recognizer/networks/cnn.py
new file mode 100644
index 0000000..1807bb9
--- /dev/null
+++ b/src/text_recognizer/networks/cnn.py
@@ -0,0 +1,101 @@
+"""Implementation of a simple backbone cnn network."""
+from typing import Callable, Dict, Optional, Tuple
+
+from einops.layers.torch import Rearrange
+import torch
+from torch import nn
+
+from text_recognizer.networks.util import activation_function
+
+
+class CNN(nn.Module):
+ """LeNet network for character prediction."""
+
+ def __init__(
+ self,
+ channels: Tuple[int, ...] = (1, 32, 64, 128),
+ kernel_sizes: Tuple[int, ...] = (4, 4, 4),
+ strides: Tuple[int, ...] = (2, 2, 2),
+ max_pool_kernel: int = 2,
+ dropout_rate: float = 0.2,
+ activation: Optional[str] = "relu",
+ ) -> None:
+ """Initialization of the LeNet network.
+
+ Args:
+ channels (Tuple[int, ...]): Channels in the convolutional layers. Defaults to (1, 32, 64).
+ kernel_sizes (Tuple[int, ...]): Kernel sizes in the convolutional layers. Defaults to (3, 3, 2).
+ strides (Tuple[int, ...]): Stride length of the convolutional filter. Defaults to (2, 2, 2).
+ max_pool_kernel (int): 2D max pooling kernel. Defaults to 2.
+ dropout_rate (float): The dropout rate. Defaults to 0.2.
+ activation (Optional[str]): The name of non-linear activation function. Defaults to relu.
+
+ Raises:
+ RuntimeError: if the number of hyperparameters does not match in length.
+
+ """
+ super().__init__()
+
+ if len(channels) - 1 != len(kernel_sizes) and len(kernel_sizes) != len(strides):
+ raise RuntimeError("The number of the hyperparameters does not match.")
+
+ self.cnn = self._build_network(
+ channels, kernel_sizes, strides, max_pool_kernel, dropout_rate, activation,
+ )
+
+ def _build_network(
+ self,
+ channels: Tuple[int, ...],
+ kernel_sizes: Tuple[int, ...],
+ strides: Tuple[int, ...],
+ max_pool_kernel: int,
+ dropout_rate: float,
+ activation: str,
+ ) -> nn.Sequential:
+ # Load activation function.
+ activation_fn = activation_function(activation)
+
+ channels = list(channels)
+ in_channels = channels.pop(0)
+ configuration = zip(channels, kernel_sizes, strides)
+
+ modules = nn.ModuleList([])
+
+ for i, (out_channels, kernel_size, stride) in enumerate(configuration):
+ # Add max pool to reduce output size.
+ if i == len(channels) // 2:
+ modules.append(nn.MaxPool2d(max_pool_kernel))
+ if i == 0:
+ modules.append(
+ nn.Conv2d(
+ in_channels, out_channels, kernel_size, stride=stride, padding=1
+ )
+ )
+ else:
+ modules.append(
+ nn.Sequential(
+ activation_fn,
+ nn.BatchNorm2d(in_channels),
+ nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=stride,
+ padding=1,
+ ),
+ )
+ )
+
+ if dropout_rate:
+ modules.append(nn.Dropout2d(p=dropout_rate))
+
+ in_channels = out_channels
+
+ return nn.Sequential(*modules)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """The feedforward pass."""
+ # If batch dimenstion is missing, it needs to be added.
+ if len(x.shape) < 4:
+ x = x[(None,) * (4 - len(x.shape))]
+ return self.cnn(x)
diff --git a/src/text_recognizer/networks/cnn_transformer.py b/src/text_recognizer/networks/cnn_transformer.py
index 43e5403..7133c26 100644
--- a/src/text_recognizer/networks/cnn_transformer.py
+++ b/src/text_recognizer/networks/cnn_transformer.py
@@ -29,14 +29,22 @@ class CNNTransformer(nn.Module):
backbone: str,
backbone_args: Optional[Dict] = None,
activation: str = "gelu",
+ pool_kernel: Optional[Tuple[int, int]] = None,
) -> None:
super().__init__()
self.trg_pad_index = trg_pad_index
self.vocab_size = vocab_size
self.backbone = configure_backbone(backbone, backbone_args)
+
+ if pool_kernel is not None:
+ self.max_pool = nn.MaxPool2d(pool_kernel, stride=2)
+ else:
+ self.max_pool = None
+
self.character_embedding = nn.Embedding(self.vocab_size, hidden_dim)
self.src_position_embedding = nn.Parameter(torch.randn(1, max_len, hidden_dim))
+ self.pos_dropout = nn.Dropout(p=dropout_rate)
self.trg_position_encoding = PositionalEncoding(hidden_dim, dropout_rate)
nn.init.normal_(self.character_embedding.weight, std=0.02)
@@ -98,18 +106,23 @@ class CNNTransformer(nn.Module):
# If batch dimension is missing, it needs to be added.
if len(src.shape) < 4:
src = src[(None,) * (4 - len(src.shape))]
+
src = self.backbone(src)
+ if self.max_pool is not None:
+ src = self.max_pool(src)
+
if self.adaptive_pool is not None:
src = rearrange(src, "b c h w -> b w c h")
src = self.adaptive_pool(src)
src = src.squeeze(3)
else:
- src = rearrange(src, "b c h w -> b (w h) c")
+ src = rearrange(src, "b c h w -> b (h w) c")
b, t, _ = src.shape
src += self.src_position_embedding[:, :t]
+ src = self.pos_dropout(src)
return src
diff --git a/src/text_recognizer/networks/metrics.py b/src/text_recognizer/networks/metrics.py
index ffad792..2605731 100644
--- a/src/text_recognizer/networks/metrics.py
+++ b/src/text_recognizer/networks/metrics.py
@@ -1,4 +1,7 @@
"""Utility functions for models."""
+from typing import Optional
+
+from einops import rearrange
import Levenshtein as Lev
import torch
from torch import Tensor
@@ -32,22 +35,33 @@ def accuracy(outputs: Tensor, labels: Tensor, pad_index: int = 53) -> float:
return acc
-def cer(outputs: Tensor, targets: Tensor) -> float:
+def cer(
+ outputs: Tensor,
+ targets: Tensor,
+ batch_size: Optional[int] = None,
+ blank_label: Optional[int] = int,
+) -> float:
"""Computes the character error rate.
Args:
outputs (Tensor): The output from the network.
targets (Tensor): Ground truth labels.
+ batch_size (Optional[int]): Batch size if target and output has been flattend.
+ blank_label (Optional[int]): The blank character to be ignored. Defaults to 79.
Returns:
float: The cer for the batch.
"""
+ if len(outputs.shape) == 2 and len(targets.shape) == 1 and batch_size is not None:
+ targets = rearrange(targets, "(b t) -> b t", b=batch_size)
+ outputs = rearrange(outputs, "(b t) v -> t b v", b=batch_size)
+
target_lengths = torch.full(
size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long,
)
decoded_predictions, decoded_targets = greedy_decoder(
- outputs, targets, target_lengths
+ outputs, targets, target_lengths, blank_label=blank_label,
)
lev_dist = 0
@@ -63,22 +77,33 @@ def cer(outputs: Tensor, targets: Tensor) -> float:
return lev_dist / len(decoded_predictions)
-def wer(outputs: Tensor, targets: Tensor) -> float:
+def wer(
+ outputs: Tensor,
+ targets: Tensor,
+ batch_size: Optional[int] = None,
+ blank_label: Optional[int] = int,
+) -> float:
"""Computes the Word error rate.
Args:
outputs (Tensor): The output from the network.
targets (Tensor): Ground truth labels.
+ batch_size (optional[int]): Batch size if target and output has been flattend.
+ blank_label (Optional[int]): The blank character to be ignored. Defaults to 79.
Returns:
float: The wer for the batch.
"""
+ if len(outputs.shape) == 2 and len(targets.shape) == 1 and batch_size is not None:
+ targets = rearrange(targets, "(b t) -> b t", b=batch_size)
+ outputs = rearrange(outputs, "(b t) v -> t b v", b=batch_size)
+
target_lengths = torch.full(
size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long,
)
decoded_predictions, decoded_targets = greedy_decoder(
- outputs, targets, target_lengths
+ outputs, targets, target_lengths, blank_label=blank_label,
)
lev_dist = 0
diff --git a/src/text_recognizer/networks/transducer/__init__.py b/src/text_recognizer/networks/transducer/__init__.py
new file mode 100644
index 0000000..fdd6662
--- /dev/null
+++ b/src/text_recognizer/networks/transducer/__init__.py
@@ -0,0 +1,2 @@
+"""Transducer modules."""
+from .tds_conv import TDS2d
diff --git a/src/text_recognizer/networks/transducer/tds_conv.py b/src/text_recognizer/networks/transducer/tds_conv.py
new file mode 100644
index 0000000..018caf2
--- /dev/null
+++ b/src/text_recognizer/networks/transducer/tds_conv.py
@@ -0,0 +1,205 @@
+"""Time-Depth Separable Convolutions.
+
+References:
+ https://arxiv.org/abs/1904.02619
+ https://arxiv.org/pdf/2010.01003.pdf
+
+Code stolen from:
+ https://github.com/facebookresearch/gtn_applications
+
+
+"""
+from typing import List, Tuple
+
+from einops import rearrange
+import gtn
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+
+
+class TDSBlock2d(nn.Module):
+ """Internal block of a 2D TDSC network."""
+
+ def __init__(
+ self,
+ in_channels: int,
+ img_depth: int,
+ kernel_size: Tuple[int],
+ dropout_rate: float,
+ ) -> None:
+ super().__init__()
+
+ self.in_channels = in_channels
+ self.img_depth = img_depth
+ self.kernel_size = kernel_size
+ self.dropout_rate = dropout_rate
+ self.fc_dim = in_channels * img_depth
+
+ # Network placeholders.
+ self.conv = None
+ self.mlp = None
+ self.instance_norm = None
+
+ self._build_block()
+
+ def _build_block(self) -> None:
+ # Convolutional block.
+ self.conv = nn.Sequential(
+ nn.Conv3d(
+ in_channels=self.in_channels,
+ out_channels=self.in_channels,
+ kernel_size=(1, self.kernel_size[0], self.kernel_size[1]),
+ padding=(0, self.kernel_size[0] // 2, self.kernel_size[1] // 2),
+ ),
+ nn.ReLU(inplace=True),
+ nn.Dropout(self.dropout_rate),
+ )
+
+ # MLP block.
+ self.mlp = nn.Sequential(
+ nn.Linear(self.fc_dim, self.fc_dim),
+ nn.ReLU(inplace=True),
+ nn.Dropout(self.dropout_rate),
+ nn.Linear(self.fc_dim, self.fc_dim),
+ nn.Dropout(self.dropout_rate),
+ )
+
+ # Instance norm.
+ self.instance_norm = nn.ModuleList(
+ [
+ nn.InstanceNorm2d(self.fc_dim, affine=True),
+ nn.InstanceNorm2d(self.fc_dim, affine=True),
+ ]
+ )
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Forward pass.
+
+ Args:
+ x (Tensor): Input tensor.
+
+ Shape:
+ - x: :math: `(B, CD, H, W)`
+
+ Returns:
+ Tensor: Output tensor.
+
+ """
+ B, CD, H, W = x.shape
+ C, D = self.in_channels, self.img_depth
+ residual = x
+ x = rearrange(x, "b (c d) h w -> b c d h w", c=C, d=D)
+ x = self.conv(x)
+ x = rearrange(x, "b c d h w -> b (c d) h w")
+ x += residual
+
+ x = self.instance_norm[0](x)
+
+ x = self.mlp(x.transpose(1, 3)).transpose(1, 3) + x
+ x + self.instance_norm[1](x)
+
+ # Output shape: [B, CD, H, W]
+ return x
+
+
+class TDS2d(nn.Module):
+ """TDS Netowrk.
+
+ Structure is the following:
+ Downsample layer -> TDS2d group -> ... -> Linear output layer
+
+
+ """
+
+ def __init__(
+ self,
+ input_dim: int,
+ output_dim: int,
+ depth: int,
+ tds_groups: Tuple[int],
+ kernel_size: Tuple[int],
+ dropout_rate: float,
+ in_channels: int = 1,
+ ) -> None:
+ super().__init__()
+
+ self.in_channels = in_channels
+ self.input_dim = input_dim
+ self.output_dim = output_dim
+ self.depth = depth
+ self.tds_groups = tds_groups
+ self.kernel_size = kernel_size
+ self.dropout_rate = dropout_rate
+
+ self.tds = None
+ self.fc = None
+
+ def _build_network(self) -> None:
+
+ modules = []
+ stride_h = np.prod([grp["stride"][0] for grp in self.tds_groups])
+ if self.input_dim % stride_h:
+ raise RuntimeError(
+ f"Image height not divisible by total stride {stride_h}."
+ )
+
+ for tds_group in self.tds_groups:
+ # Add downsample layer.
+ out_channels = self.depth * tds_group["channels"]
+ modules.extend(
+ [
+ nn.Conv2d(
+ in_channels=self.in_channels,
+ out_channels=out_channels,
+ kernel_size=self.kernel_size,
+ padding=(self.kernel_size[0] // 2, self.kernel_size[1] // 2),
+ stride=tds_group["stride"],
+ ),
+ nn.ReLU(inplace=True),
+ nn.Dropout(self.dropout_rate),
+ nn.InstanceNorm2d(out_channels, affine=True),
+ ]
+ )
+
+ for _ in range(tds_group["num_blocks"]):
+ modules.append(
+ TDSBlock2d(
+ tds_group["channels"],
+ self.depth,
+ self.kernel_size,
+ self.dropout_rate,
+ )
+ )
+
+ self.in_channels = out_channels
+
+ self.tds = nn.Sequential(*modules)
+ self.fc = nn.Linear(
+ self.in_channels * self.input_dim // stride_h, self.output_dim
+ )
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Forward pass.
+
+ Args:
+ x (Tensor): Input tensor.
+
+ Shape:
+ - x: :math: `(B, H, W)`
+
+ Returns:
+ Tensor: Output tensor.
+
+ """
+ B, H, W = x.shape
+ x = rearrange(
+ x, "b (h1 h2) w -> b h1 h2 w", h1=self.in_channels, h2=H // self.in_channels
+ )
+ x = self.tds(x)
+
+ # x shape: [B, C, H, W]
+ x = rearrange(x, "b c h w -> b w (c h)")
+
+ return self.fc(x)
diff --git a/src/text_recognizer/networks/util.py b/src/text_recognizer/networks/util.py
index 711a952..131a6b4 100644
--- a/src/text_recognizer/networks/util.py
+++ b/src/text_recognizer/networks/util.py
@@ -65,13 +65,18 @@ def configure_backbone(backbone: str, backbone_args: Dict) -> Type[nn.Module]:
network_args = state_dict["network_args"]
weights = state_dict["model_state"]
+ freeze = False
+ if "freeze" in backbone_args and backbone_args["freeze"] is True:
+ backbone_args.pop("freeze")
+ freeze = True
+ network_args = backbone_args
+
# Initializes the network with trained weights.
backbone = backbone_(**network_args)
backbone.load_state_dict(weights)
- if "freeze" in backbone_args and backbone_args["freeze"] is True:
+ if freeze:
for params in backbone.parameters():
params.requires_grad = False
-
else:
backbone_ = getattr(network_module, backbone)
backbone = backbone_(**backbone_args)
diff --git a/src/text_recognizer/networks/vq_transformer.py b/src/text_recognizer/networks/vq_transformer.py
new file mode 100644
index 0000000..c673d96
--- /dev/null
+++ b/src/text_recognizer/networks/vq_transformer.py
@@ -0,0 +1,150 @@
+"""A VQ-Transformer for image to text recognition."""
+from typing import Dict, Optional, Tuple
+
+from einops import rearrange, repeat
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.transformer import PositionalEncoding, Transformer
+from text_recognizer.networks.util import activation_function
+from text_recognizer.networks.util import configure_backbone
+from text_recognizer.networks.vqvae.encoder import _ResidualBlock
+
+
+class VQTransformer(nn.Module):
+ """VQ+Transfomer for image to character sequence prediction."""
+
+ def __init__(
+ self,
+ num_encoder_layers: int,
+ num_decoder_layers: int,
+ hidden_dim: int,
+ vocab_size: int,
+ num_heads: int,
+ adaptive_pool_dim: Tuple,
+ expansion_dim: int,
+ dropout_rate: float,
+ trg_pad_index: int,
+ max_len: int,
+ backbone: str,
+ backbone_args: Optional[Dict] = None,
+ activation: str = "gelu",
+ ) -> None:
+ super().__init__()
+
+ # Configure vector quantized backbone.
+ self.backbone = configure_backbone(backbone, backbone_args)
+ self.conv = nn.Sequential(
+ nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=2),
+ nn.ReLU(inplace=True),
+ )
+
+ # Configure embeddings for Transformer network.
+ self.trg_pad_index = trg_pad_index
+ self.vocab_size = vocab_size
+ self.character_embedding = nn.Embedding(self.vocab_size, hidden_dim)
+ self.src_position_embedding = nn.Parameter(torch.randn(1, max_len, hidden_dim))
+ self.trg_position_encoding = PositionalEncoding(hidden_dim, dropout_rate)
+ nn.init.normal_(self.character_embedding.weight, std=0.02)
+
+ self.adaptive_pool = (
+ nn.AdaptiveAvgPool2d((adaptive_pool_dim)) if adaptive_pool_dim else None
+ )
+
+ self.transformer = Transformer(
+ num_encoder_layers,
+ num_decoder_layers,
+ hidden_dim,
+ num_heads,
+ expansion_dim,
+ dropout_rate,
+ activation,
+ )
+
+ self.head = nn.Sequential(nn.Linear(hidden_dim, vocab_size),)
+
+ def _create_trg_mask(self, trg: Tensor) -> Tensor:
+ # Move this outside the transformer.
+ trg_pad_mask = (trg != self.trg_pad_index)[:, None, None]
+ trg_len = trg.shape[1]
+ trg_sub_mask = torch.tril(
+ torch.ones((trg_len, trg_len), device=trg.device)
+ ).bool()
+ trg_mask = trg_pad_mask & trg_sub_mask
+ return trg_mask
+
+ def encoder(self, src: Tensor) -> Tensor:
+ """Forward pass with the encoder of the transformer."""
+ return self.transformer.encoder(src)
+
+ def decoder(self, trg: Tensor, memory: Tensor, trg_mask: Tensor) -> Tensor:
+ """Forward pass with the decoder of the transformer + classification head."""
+ return self.head(
+ self.transformer.decoder(trg=trg, memory=memory, trg_mask=trg_mask)
+ )
+
+ def extract_image_features(self, src: Tensor) -> Tuple[Tensor, Tensor]:
+ """Extracts image features with a backbone neural network.
+
+ It seem like the winning idea was to swap channels and width dimension and collapse
+ the height dimension. The transformer is learning like a baby with this implementation!!! :D
+ Ohhhh, the joy I am experiencing right now!! Bring in the beers! :D :D :D
+
+ Args:
+ src (Tensor): Input tensor.
+
+ Returns:
+ Tensor: The input src to the transformer and the vq loss.
+
+ """
+ # If batch dimension is missing, it needs to be added.
+ if len(src.shape) < 4:
+ src = src[(None,) * (4 - len(src.shape))]
+ src, vq_loss = self.backbone.encode(src)
+ # src = self.backbone.decoder.res_block(src)
+ src = self.conv(src)
+
+ if self.adaptive_pool is not None:
+ src = rearrange(src, "b c h w -> b w c h")
+ src = self.adaptive_pool(src)
+ src = src.squeeze(3)
+ else:
+ src = rearrange(src, "b c h w -> b (w h) c")
+
+ b, t, _ = src.shape
+
+ src += self.src_position_embedding[:, :t]
+
+ return src, vq_loss
+
+ def target_embedding(self, trg: Tensor) -> Tensor:
+ """Encodes target tensor with embedding and postion.
+
+ Args:
+ trg (Tensor): Target tensor.
+
+ Returns:
+ Tensor: Encoded target tensor.
+
+ """
+ trg = self.character_embedding(trg.long())
+ trg = self.trg_position_encoding(trg)
+ return trg
+
+ def decode_image_features(
+ self, image_features: Tensor, trg: Optional[Tensor] = None
+ ) -> Tensor:
+ """Takes images features from the backbone and decodes them with the transformer."""
+ trg_mask = self._create_trg_mask(trg)
+ trg = self.target_embedding(trg)
+ out = self.transformer(image_features, trg, trg_mask=trg_mask)
+
+ logits = self.head(out)
+ return logits
+
+ def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor:
+ """Forward pass with CNN transfomer."""
+ image_features, vq_loss = self.extract_image_features(x)
+ logits = self.decode_image_features(image_features, trg)
+ return logits, vq_loss
diff --git a/src/text_recognizer/networks/vqvae/__init__.py b/src/text_recognizer/networks/vqvae/__init__.py
index e1f05fa..763953c 100644
--- a/src/text_recognizer/networks/vqvae/__init__.py
+++ b/src/text_recognizer/networks/vqvae/__init__.py
@@ -1 +1,5 @@
"""VQ-VAE module."""
+from .decoder import Decoder
+from .encoder import Encoder
+from .vector_quantizer import VectorQuantizer
+from .vqvae import VQVAE
diff --git a/src/text_recognizer/networks/vqvae/decoder.py b/src/text_recognizer/networks/vqvae/decoder.py
new file mode 100644
index 0000000..8847aba
--- /dev/null
+++ b/src/text_recognizer/networks/vqvae/decoder.py
@@ -0,0 +1,133 @@
+"""CNN decoder for the VQ-VAE."""
+
+from typing import List, Optional, Tuple, Type
+
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.util import activation_function
+from text_recognizer.networks.vqvae.encoder import _ResidualBlock
+
+
+class Decoder(nn.Module):
+ """A CNN encoder network."""
+
+ def __init__(
+ self,
+ channels: List[int],
+ kernel_sizes: List[int],
+ strides: List[int],
+ num_residual_layers: int,
+ embedding_dim: int,
+ upsampling: Optional[List[List[int]]] = None,
+ activation: str = "leaky_relu",
+ dropout_rate: float = 0.0,
+ ) -> None:
+ super().__init__()
+
+ if dropout_rate:
+ if activation == "selu":
+ dropout = nn.AlphaDropout(p=dropout_rate)
+ else:
+ dropout = nn.Dropout(p=dropout_rate)
+ else:
+ dropout = None
+
+ self.upsampling = upsampling
+
+ self.res_block = nn.ModuleList([])
+ self.upsampling_block = nn.ModuleList([])
+
+ self.embedding_dim = embedding_dim
+ activation = activation_function(activation)
+
+ # Configure encoder.
+ self.decoder = self._build_decoder(
+ channels, kernel_sizes, strides, num_residual_layers, activation, dropout,
+ )
+
+ def _build_decompression_block(
+ self,
+ in_channels: int,
+ channels: int,
+ kernel_sizes: List[int],
+ strides: List[int],
+ activation: Type[nn.Module],
+ dropout: Optional[Type[nn.Module]],
+ ) -> nn.ModuleList:
+ modules = nn.ModuleList([])
+ configuration = zip(channels, kernel_sizes, strides)
+ for i, (out_channels, kernel_size, stride) in enumerate(configuration):
+ modules.append(
+ nn.Sequential(
+ nn.ConvTranspose2d(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=stride,
+ padding=1,
+ ),
+ activation,
+ )
+ )
+
+ if i < len(self.upsampling):
+ modules.append(nn.Upsample(size=self.upsampling[i]),)
+
+ if dropout is not None:
+ modules.append(dropout)
+
+ in_channels = out_channels
+
+ modules.extend(
+ nn.Sequential(
+ nn.ConvTranspose2d(
+ in_channels, 1, kernel_size=kernel_size, stride=stride, padding=1
+ ),
+ nn.Tanh(),
+ )
+ )
+
+ return modules
+
+ def _build_decoder(
+ self,
+ channels: int,
+ kernel_sizes: List[int],
+ strides: List[int],
+ num_residual_layers: int,
+ activation: Type[nn.Module],
+ dropout: Optional[Type[nn.Module]],
+ ) -> nn.Sequential:
+
+ self.res_block.append(
+ nn.Conv2d(self.embedding_dim, channels[0], kernel_size=1, stride=1,)
+ )
+
+ # Bottleneck module.
+ self.res_block.extend(
+ nn.ModuleList(
+ [
+ _ResidualBlock(channels[0], channels[0], dropout)
+ for i in range(num_residual_layers)
+ ]
+ )
+ )
+
+ # Decompression module
+ self.upsampling_block.extend(
+ self._build_decompression_block(
+ channels[0], channels[1:], kernel_sizes, strides, activation, dropout
+ )
+ )
+
+ self.res_block = nn.Sequential(*self.res_block)
+ self.upsampling_block = nn.Sequential(*self.upsampling_block)
+
+ return nn.Sequential(self.res_block, self.upsampling_block)
+
+ def forward(self, z_q: Tensor) -> Tensor:
+ """Reconstruct input from given codes."""
+ x_reconstruction = self.decoder(z_q)
+ return x_reconstruction
diff --git a/src/text_recognizer/networks/vqvae/encoder.py b/src/text_recognizer/networks/vqvae/encoder.py
index 60c4c43..d3adac5 100644
--- a/src/text_recognizer/networks/vqvae/encoder.py
+++ b/src/text_recognizer/networks/vqvae/encoder.py
@@ -1,6 +1,5 @@
"""CNN encoder for the VQ-VAE."""
-
-from typing import List, Optional, Type
+from typing import List, Optional, Tuple, Type
import torch
from torch import nn
@@ -12,16 +11,12 @@ from text_recognizer.networks.vqvae.vector_quantizer import VectorQuantizer
class _ResidualBlock(nn.Module):
def __init__(
- self,
- in_channels: int,
- out_channels: int,
- activation: Type[nn.Module],
- dropout: Optional[Type[nn.Module]],
+ self, in_channels: int, out_channels: int, dropout: Optional[Type[nn.Module]],
) -> None:
super().__init__()
self.block = [
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
- activation,
+ nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size=1, bias=False),
]
@@ -42,23 +37,111 @@ class Encoder(nn.Module):
self,
in_channels: int,
channels: List[int],
+ kernel_sizes: List[int],
+ strides: List[int],
num_residual_layers: int,
embedding_dim: int,
num_embeddings: int,
beta: float = 0.25,
- activation: str = "elu",
+ activation: str = "leaky_relu",
dropout_rate: float = 0.0,
) -> None:
super().__init__()
- pass
- # if dropout_rate:
- # if activation == "selu":
- # dropout = nn.AlphaDropout(p=dropout_rate)
- # else:
- # dropout = nn.Dropout(p=dropout_rate)
- # else:
- # dropout = None
-
- def _build_encoder(self) -> nn.Sequential:
- # TODO: Continue to implement encoder.
- pass
+
+ if dropout_rate:
+ if activation == "selu":
+ dropout = nn.AlphaDropout(p=dropout_rate)
+ else:
+ dropout = nn.Dropout(p=dropout_rate)
+ else:
+ dropout = None
+
+ self.embedding_dim = embedding_dim
+ self.num_embeddings = num_embeddings
+ self.beta = beta
+ activation = activation_function(activation)
+
+ # Configure encoder.
+ self.encoder = self._build_encoder(
+ in_channels,
+ channels,
+ kernel_sizes,
+ strides,
+ num_residual_layers,
+ activation,
+ dropout,
+ )
+
+ # Configure Vector Quantizer.
+ self.vector_quantizer = VectorQuantizer(
+ self.num_embeddings, self.embedding_dim, self.beta
+ )
+
+ def _build_compression_block(
+ self,
+ in_channels: int,
+ channels: int,
+ kernel_sizes: List[int],
+ strides: List[int],
+ activation: Type[nn.Module],
+ dropout: Optional[Type[nn.Module]],
+ ) -> nn.ModuleList:
+ modules = nn.ModuleList([])
+ configuration = zip(channels, kernel_sizes, strides)
+ for out_channels, kernel_size, stride in configuration:
+ modules.append(
+ nn.Sequential(
+ nn.Conv2d(
+ in_channels, out_channels, kernel_size, stride=stride, padding=1
+ ),
+ activation,
+ )
+ )
+
+ if dropout is not None:
+ modules.append(dropout)
+
+ in_channels = out_channels
+
+ return modules
+
+ def _build_encoder(
+ self,
+ in_channels: int,
+ channels: int,
+ kernel_sizes: List[int],
+ strides: List[int],
+ num_residual_layers: int,
+ activation: Type[nn.Module],
+ dropout: Optional[Type[nn.Module]],
+ ) -> nn.Sequential:
+ encoder = nn.ModuleList([])
+
+ # compression module
+ encoder.extend(
+ self._build_compression_block(
+ in_channels, channels, kernel_sizes, strides, activation, dropout
+ )
+ )
+
+ # Bottleneck module.
+ encoder.extend(
+ nn.ModuleList(
+ [
+ _ResidualBlock(channels[-1], channels[-1], dropout)
+ for i in range(num_residual_layers)
+ ]
+ )
+ )
+
+ encoder.append(
+ nn.Conv2d(channels[-1], self.embedding_dim, kernel_size=1, stride=1,)
+ )
+
+ return nn.Sequential(*encoder)
+
+ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
+ """Encodes input into a discrete representation."""
+ z_e = self.encoder(x)
+ z_q, vq_loss = self.vector_quantizer(z_e)
+ return z_q, vq_loss
diff --git a/src/text_recognizer/networks/vqvae/vector_quantizer.py b/src/text_recognizer/networks/vqvae/vector_quantizer.py
index 25e5583..f92c7ee 100644
--- a/src/text_recognizer/networks/vqvae/vector_quantizer.py
+++ b/src/text_recognizer/networks/vqvae/vector_quantizer.py
@@ -26,7 +26,7 @@ class VectorQuantizer(nn.Module):
self.embedding = nn.Embedding(self.K, self.D)
# Initialize the codebook.
- self.embedding.weight.uniform_(-1 / self.K, 1 / self.K)
+ nn.init.uniform_(self.embedding.weight, -1 / self.K, 1 / self.K)
def discretization_bottleneck(self, latent: Tensor) -> Tensor:
"""Computes the code nearest to the latent representation.
diff --git a/src/text_recognizer/networks/vqvae/vqvae.py b/src/text_recognizer/networks/vqvae/vqvae.py
new file mode 100644
index 0000000..50448b4
--- /dev/null
+++ b/src/text_recognizer/networks/vqvae/vqvae.py
@@ -0,0 +1,74 @@
+"""The VQ-VAE."""
+
+from typing import List, Optional, Tuple, Type
+
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.vqvae import Decoder, Encoder
+
+
+class VQVAE(nn.Module):
+ """Vector Quantized Variational AutoEncoder."""
+
+ def __init__(
+ self,
+ in_channels: int,
+ channels: List[int],
+ kernel_sizes: List[int],
+ strides: List[int],
+ num_residual_layers: int,
+ embedding_dim: int,
+ num_embeddings: int,
+ upsampling: Optional[List[List[int]]] = None,
+ beta: float = 0.25,
+ activation: str = "leaky_relu",
+ dropout_rate: float = 0.0,
+ ) -> None:
+ super().__init__()
+
+ # configure encoder.
+ self.encoder = Encoder(
+ in_channels,
+ channels,
+ kernel_sizes,
+ strides,
+ num_residual_layers,
+ embedding_dim,
+ num_embeddings,
+ beta,
+ activation,
+ dropout_rate,
+ )
+
+ # Configure decoder.
+ channels.reverse()
+ kernel_sizes.reverse()
+ strides.reverse()
+ self.decoder = Decoder(
+ channels,
+ kernel_sizes,
+ strides,
+ num_residual_layers,
+ embedding_dim,
+ upsampling,
+ activation,
+ dropout_rate,
+ )
+
+ def encode(self, x: Tensor) -> Tuple[Tensor, Tensor]:
+ """Encodes input to a latent code."""
+ return self.encoder(x)
+
+ def decode(self, z_q: Tensor) -> Tensor:
+ """Reconstructs input from latent codes."""
+ return self.decoder(z_q)
+
+ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
+ """Compresses and decompresses input."""
+ if len(x.shape) < 4:
+ x = x[(None,) * (4 - len(x.shape))]
+ z_q, vq_loss = self.encode(x)
+ x_reconstruction = self.decode(z_q)
+ return x_reconstruction, vq_loss
diff --git a/src/text_recognizer/weights/VQVAEModel_IamLinesDataset_VQVAE_weights.pt b/src/text_recognizer/weights/VQVAEModel_IamLinesDataset_VQVAE_weights.pt
new file mode 100644
index 0000000..b5295c2
--- /dev/null
+++ b/src/text_recognizer/weights/VQVAEModel_IamLinesDataset_VQVAE_weights.pt
Binary files differ