summaryrefslogtreecommitdiff
path: root/src/text_recognizer
diff options
context:
space:
mode:
Diffstat (limited to 'src/text_recognizer')
-rw-r--r--src/text_recognizer/datasets/__init__.py4
-rw-r--r--src/text_recognizer/datasets/dataset.py22
-rw-r--r--src/text_recognizer/datasets/emnist_dataset.py3
-rw-r--r--src/text_recognizer/datasets/emnist_lines_dataset.py9
-rw-r--r--src/text_recognizer/datasets/iam_lines_dataset.py6
-rw-r--r--src/text_recognizer/datasets/transforms.py33
-rw-r--r--src/text_recognizer/datasets/util.py29
-rw-r--r--src/text_recognizer/models/__init__.py11
-rw-r--r--src/text_recognizer/models/base.py55
-rw-r--r--src/text_recognizer/models/character_model.py1
-rw-r--r--src/text_recognizer/models/line_ctc_model.py8
-rw-r--r--src/text_recognizer/models/vision_transformer_model.py117
-rw-r--r--src/text_recognizer/networks/__init__.py18
-rw-r--r--src/text_recognizer/networks/cnn_transformer.py111
-rw-r--r--src/text_recognizer/networks/crnn.py (renamed from src/text_recognizer/networks/line_lstm_ctc.py)58
-rw-r--r--src/text_recognizer/networks/densenet.py225
-rw-r--r--src/text_recognizer/networks/lenet.py6
-rw-r--r--src/text_recognizer/networks/loss.py (renamed from src/text_recognizer/networks/losses.py)3
-rw-r--r--src/text_recognizer/networks/mlp.py6
-rw-r--r--src/text_recognizer/networks/residual_network.py6
-rw-r--r--src/text_recognizer/networks/sparse_mlp.py78
-rw-r--r--src/text_recognizer/networks/transformer.py5
-rw-r--r--src/text_recognizer/networks/transformer/__init__.py3
-rw-r--r--src/text_recognizer/networks/transformer/attention.py93
-rw-r--r--src/text_recognizer/networks/transformer/positional_encoding.py31
-rw-r--r--src/text_recognizer/networks/transformer/sparse_transformer.py1
-rw-r--r--src/text_recognizer/networks/transformer/transformer.py241
-rw-r--r--src/text_recognizer/networks/util.py (renamed from src/text_recognizer/networks/misc.py)40
-rw-r--r--src/text_recognizer/networks/vision_transformer.py158
-rw-r--r--src/text_recognizer/networks/wide_resnet.py6
-rw-r--r--src/text_recognizer/weights/CharacterModel_EmnistDataset_DenseNet_weights.ptbin0 -> 1273881 bytes
-rw-r--r--src/text_recognizer/weights/LineCTCModel_IamLinesDataset_LineRecurrentNetwork_weights.ptbin5701134 -> 3457858 bytes
32 files changed, 1287 insertions, 100 deletions
diff --git a/src/text_recognizer/datasets/__init__.py b/src/text_recognizer/datasets/__init__.py
index a3af9b1..d8372e3 100644
--- a/src/text_recognizer/datasets/__init__.py
+++ b/src/text_recognizer/datasets/__init__.py
@@ -1,5 +1,5 @@
"""Dataset modules."""
-from .emnist_dataset import EmnistDataset, Transpose
+from .emnist_dataset import EmnistDataset
from .emnist_lines_dataset import (
construct_image_from_string,
EmnistLinesDataset,
@@ -8,6 +8,7 @@ from .emnist_lines_dataset import (
from .iam_dataset import IamDataset
from .iam_lines_dataset import IamLinesDataset
from .iam_paragraphs_dataset import IamParagraphsDataset
+from .transforms import AddTokens, Transpose
from .util import (
_download_raw_dataset,
compute_sha256,
@@ -19,6 +20,7 @@ from .util import (
__all__ = [
"_download_raw_dataset",
+ "AddTokens",
"compute_sha256",
"construct_image_from_string",
"DATA_DIRNAME",
diff --git a/src/text_recognizer/datasets/dataset.py b/src/text_recognizer/datasets/dataset.py
index 05520e5..2de7f09 100644
--- a/src/text_recognizer/datasets/dataset.py
+++ b/src/text_recognizer/datasets/dataset.py
@@ -18,6 +18,9 @@ class Dataset(data.Dataset):
subsample_fraction: float = None,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
+ init_token: Optional[str] = None,
+ pad_token: Optional[str] = None,
+ eos_token: Optional[str] = None,
) -> None:
"""Initialization of Dataset class.
@@ -26,12 +29,14 @@ class Dataset(data.Dataset):
subsample_fraction (float): The fraction of the dataset to use for training. Defaults to None.
transform (Optional[Callable]): Transform(s) for input data. Defaults to None.
target_transform (Optional[Callable]): Transform(s) for output data. Defaults to None.
+ init_token (Optional[str]): String representing the start of sequence token. Defaults to None.
+ pad_token (Optional[str]): String representing the pad token. Defaults to None.
+ eos_token (Optional[str]): String representing the end of sequence token. Defaults to None.
Raises:
ValueError: If subsample_fraction is not None and outside the range (0, 1).
"""
-
self.train = train
self.split = "train" if self.train else "test"
@@ -40,19 +45,18 @@ class Dataset(data.Dataset):
raise ValueError("The subsample fraction must be in (0, 1).")
self.subsample_fraction = subsample_fraction
- self._mapper = EmnistMapper()
+ self._mapper = EmnistMapper(
+ init_token=init_token, eos_token=eos_token, pad_token=pad_token
+ )
self._input_shape = self._mapper.input_shape
self._output_shape = self._mapper._num_classes
self.num_classes = self.mapper.num_classes
# Set transforms.
- self.transform = transform
- if self.transform is None:
- self.transform = ToTensor()
-
- self.target_transform = target_transform
- if self.target_transform is None:
- self.target_transform = torch.tensor
+ self.transform = transform if transform is not None else ToTensor()
+ self.target_transform = (
+ target_transform if target_transform is not None else torch.tensor
+ )
self._data = None
self._targets = None
diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py
index d01dcee..a8901d6 100644
--- a/src/text_recognizer/datasets/emnist_dataset.py
+++ b/src/text_recognizer/datasets/emnist_dataset.py
@@ -53,9 +53,6 @@ class EmnistDataset(Dataset):
if transform is None:
self.transform = Compose([Transpose(), ToTensor()])
- # The EMNIST dataset is already casted to tensors.
- self.target_transform = target_transform
-
self.seed = seed
def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]:
diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py
index beb5343..6091da8 100644
--- a/src/text_recognizer/datasets/emnist_lines_dataset.py
+++ b/src/text_recognizer/datasets/emnist_lines_dataset.py
@@ -37,6 +37,9 @@ class EmnistLinesDataset(Dataset):
max_overlap: float = 0.33,
num_samples: int = 10000,
seed: int = 4711,
+ init_token: Optional[str] = None,
+ pad_token: Optional[str] = None,
+ eos_token: Optional[str] = None,
) -> None:
"""Set attributes and loads the dataset.
@@ -50,6 +53,9 @@ class EmnistLinesDataset(Dataset):
max_overlap (float): The maximum overlap between concatenated images. Defaults to 0.33.
num_samples (int): Number of samples to generate. Defaults to 10000.
seed (int): Seed number. Defaults to 4711.
+ init_token (Optional[str]): String representing the start of sequence token. Defaults to None.
+ pad_token (Optional[str]): String representing the pad token. Defaults to None.
+ eos_token (Optional[str]): String representing the end of sequence token. Defaults to None.
"""
super().__init__(
@@ -57,6 +63,9 @@ class EmnistLinesDataset(Dataset):
transform=transform,
target_transform=target_transform,
subsample_fraction=subsample_fraction,
+ init_token=init_token,
+ pad_token=pad_token,
+ eos_token=eos_token,
)
# Extract dataset information.
diff --git a/src/text_recognizer/datasets/iam_lines_dataset.py b/src/text_recognizer/datasets/iam_lines_dataset.py
index 4a74b2b..fdd2fe6 100644
--- a/src/text_recognizer/datasets/iam_lines_dataset.py
+++ b/src/text_recognizer/datasets/iam_lines_dataset.py
@@ -32,12 +32,18 @@ class IamLinesDataset(Dataset):
subsample_fraction: float = None,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
+ init_token: Optional[str] = None,
+ pad_token: Optional[str] = None,
+ eos_token: Optional[str] = None,
) -> None:
super().__init__(
train=train,
subsample_fraction=subsample_fraction,
transform=transform,
target_transform=target_transform,
+ init_token=init_token,
+ pad_token=pad_token,
+ eos_token=eos_token,
)
@property
diff --git a/src/text_recognizer/datasets/transforms.py b/src/text_recognizer/datasets/transforms.py
index 17231a8..c058972 100644
--- a/src/text_recognizer/datasets/transforms.py
+++ b/src/text_recognizer/datasets/transforms.py
@@ -3,6 +3,9 @@ import numpy as np
from PIL import Image
import torch
from torch import Tensor
+from torchvision.transforms import Compose, ToTensor
+
+from text_recognizer.datasets.util import EmnistMapper
class Transpose:
@@ -11,3 +14,33 @@ class Transpose:
def __call__(self, image: Image) -> np.ndarray:
"""Swaps axis."""
return np.array(image).swapaxes(0, 1)
+
+
+class AddTokens:
+ """Adds start of sequence and end of sequence tokens to target tensor."""
+
+ def __init__(self, init_token: str, pad_token: str, eos_token: str,) -> None:
+ self.init_token = init_token
+ self.pad_token = pad_token
+ self.eos_token = eos_token
+ self.emnist_mapper = EmnistMapper(
+ init_token=self.init_token,
+ pad_token=self.pad_token,
+ eos_token=self.eos_token,
+ )
+ self.pad_value = self.emnist_mapper(self.pad_token)
+ self.sos_value = self.emnist_mapper(self.init_token)
+ self.eos_value = self.emnist_mapper(self.eos_token)
+
+ def __call__(self, target: Tensor) -> Tensor:
+ """Adds a sos token to the begining and a eos token to the end of a target sequence."""
+ dtype, device = target.dtype, target.device
+ sos = torch.tensor([self.sos_value], dtype=dtype, device=device)
+
+ # Find the where padding starts.
+ pad_index = torch.nonzero(target == self.pad_value, as_tuple=False)[0].item()
+
+ target[pad_index] = self.eos_value
+
+ target = torch.cat([sos, target], dim=0)
+ return target
diff --git a/src/text_recognizer/datasets/util.py b/src/text_recognizer/datasets/util.py
index 125f05a..d2df8b5 100644
--- a/src/text_recognizer/datasets/util.py
+++ b/src/text_recognizer/datasets/util.py
@@ -4,6 +4,7 @@ import importlib
import json
import os
from pathlib import Path
+import string
from typing import Callable, Dict, List, Optional, Type, Union
from urllib.request import urlopen, urlretrieve
@@ -43,11 +44,21 @@ def download_emnist() -> None:
class EmnistMapper:
"""Mapper between network output to Emnist character."""
- def __init__(self) -> None:
+ def __init__(
+ self,
+ pad_token: str,
+ init_token: Optional[str] = None,
+ eos_token: Optional[str] = None,
+ ) -> None:
"""Loads the emnist essentials file with the mapping and input shape."""
+ self.init_token = init_token
+ self.pad_token = pad_token
+ self.eos_token = eos_token
+
self.essentials = self._load_emnist_essentials()
# Load dataset infromation.
- self._mapping = self._augment_emnist_mapping(dict(self.essentials["mapping"]))
+ self._mapping = dict(self.essentials["mapping"])
+ self._augment_emnist_mapping()
self._inverse_mapping = {v: k for k, v in self.mapping.items()}
self._num_classes = len(self.mapping)
self._input_shape = self.essentials["input_shape"]
@@ -103,7 +114,7 @@ class EmnistMapper:
essentials = json.load(f)
return essentials
- def _augment_emnist_mapping(self, mapping: Dict) -> Dict:
+ def _augment_emnist_mapping(self) -> None:
"""Augment the mapping with extra symbols."""
# Extra symbols in IAM dataset
extra_symbols = [
@@ -127,14 +138,20 @@ class EmnistMapper:
]
# padding symbol, and acts as blank symbol as well.
- extra_symbols.append("_")
+ extra_symbols.append(self.pad_token)
+
+ if self.init_token is not None:
+ extra_symbols.append(self.init_token)
+
+ if self.eos_token is not None:
+ extra_symbols.append(self.eos_token)
- max_key = max(mapping.keys())
+ max_key = max(self.mapping.keys())
extra_mapping = {}
for i, symbol in enumerate(extra_symbols):
extra_mapping[max_key + 1 + i] = symbol
- return {**mapping, **extra_mapping}
+ self._mapping = {**self.mapping, **extra_mapping}
def compute_sha256(filename: Union[Path, str]) -> str:
diff --git a/src/text_recognizer/models/__init__.py b/src/text_recognizer/models/__init__.py
index a3cfc15..0855079 100644
--- a/src/text_recognizer/models/__init__.py
+++ b/src/text_recognizer/models/__init__.py
@@ -3,5 +3,14 @@ from .base import Model
from .character_model import CharacterModel
from .line_ctc_model import LineCTCModel
from .metrics import accuracy, cer, wer
+from .vision_transformer_model import VisionTransformerModel
-__all__ = ["Model", "cer", "CharacterModel", "LineCTCModel", "accuracy", "wer"]
+__all__ = [
+ "Model",
+ "cer",
+ "CharacterModel",
+ "CNNTransfromerModel",
+ "LineCTCModel",
+ "accuracy",
+ "wer",
+]
diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py
index e89b670..cbef787 100644
--- a/src/text_recognizer/models/base.py
+++ b/src/text_recognizer/models/base.py
@@ -6,7 +6,7 @@ import importlib
from pathlib import Path
import re
import shutil
-from typing import Callable, Dict, Optional, Tuple, Type
+from typing import Callable, Dict, List, Optional, Tuple, Type, Union
from loguru import logger
import torch
@@ -15,6 +15,7 @@ from torch import Tensor
from torch.optim.swa_utils import AveragedModel, SWALR
from torch.utils.data import DataLoader, Dataset, random_split
from torchsummary import summary
+from torchvision.transforms import Compose
from text_recognizer.datasets import EmnistMapper
@@ -128,16 +129,41 @@ class Model(ABC):
self._configure_criterion()
self._configure_optimizers()
- # Prints a summary of the network in terminal.
- self.summary()
-
# Set this flag to true to prevent the model from configuring again.
self.is_configured = True
+ def _configure_transforms(self) -> None:
+ # Load transforms.
+ transforms_module = importlib.import_module(
+ "text_recognizer.datasets.transforms"
+ )
+ if (
+ "transform" in self.dataset_args["args"]
+ and self.dataset_args["args"]["transform"] is not None
+ ):
+ transform_ = [
+ getattr(transforms_module, t["type"])()
+ for t in self.dataset_args["args"]["transform"]
+ ]
+ self.dataset_args["args"]["transform"] = Compose(transform_)
+ if (
+ "target_transform" in self.dataset_args["args"]
+ and self.dataset_args["args"]["target_transform"] is not None
+ ):
+ target_transform_ = [
+ torch.tensor,
+ ]
+ for t in self.dataset_args["args"]["target_transform"]:
+ args = t["args"] or {}
+ target_transform_.append(getattr(transforms_module, t["type"])(**args))
+ self.dataset_args["args"]["target_transform"] = Compose(target_transform_)
+
def prepare_data(self) -> None:
"""Prepare data for training."""
# TODO add downloading.
if not self.data_prepared:
+ self._configure_transforms()
+
# Load train dataset.
train_dataset = self.dataset(train=True, **self.dataset_args["args"])
train_dataset.load_or_generate_data()
@@ -327,20 +353,20 @@ class Model(ABC):
else:
return self.network(x)
- def loss_fn(self, output: Tensor, targets: Tensor) -> Tensor:
- """Compute the loss."""
- return self.criterion(output, targets)
-
def summary(
- self, input_shape: Optional[Tuple[int, int, int]] = None, depth: int = 3
+ self,
+ input_shape: Optional[Union[List, Tuple]] = None,
+ depth: int = 4,
+ device: Optional[str] = None,
) -> None:
"""Prints a summary of the network architecture."""
+ device = self.device if device is None else device
if input_shape is not None:
- summary(self.network, input_shape, depth=depth, device=self.device)
+ summary(self.network, input_shape, depth=depth, device=device)
elif self._input_shape is not None:
input_shape = (1,) + tuple(self._input_shape)
- summary(self.network, input_shape, depth=depth, device=self.device)
+ summary(self.network, input_shape, depth=depth, device=device)
else:
logger.warning("Could not print summary as input shape is not set.")
@@ -364,18 +390,21 @@ class Model(ABC):
return state
- def load_from_checkpoint(self, checkpoint_path: Path) -> None:
+ def load_from_checkpoint(self, checkpoint_path: Union[str, Path]) -> None:
"""Load a previously saved checkpoint.
Args:
checkpoint_path (Path): Path to the experiment with the checkpoint.
"""
+ checkpoint_path = Path(checkpoint_path)
+ self.prepare_data()
+ self.configure_model()
logger.debug("Loading checkpoint...")
if not checkpoint_path.exists():
logger.debug("File does not exist {str(checkpoint_path)}")
- checkpoint = torch.load(str(checkpoint_path))
+ checkpoint = torch.load(str(checkpoint_path), map_location=self.device)
self._network.load_state_dict(checkpoint["model_state"])
if self._optimizer is not None:
diff --git a/src/text_recognizer/models/character_model.py b/src/text_recognizer/models/character_model.py
index 50e94a2..3cf6695 100644
--- a/src/text_recognizer/models/character_model.py
+++ b/src/text_recognizer/models/character_model.py
@@ -65,6 +65,7 @@ class CharacterModel(Model):
Tuple[str, float]: The predicted character and the confidence in the prediction.
"""
+ self.eval()
if image.dtype == np.uint8:
# Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1].
diff --git a/src/text_recognizer/models/line_ctc_model.py b/src/text_recognizer/models/line_ctc_model.py
index 16eaed3..cdc2d8b 100644
--- a/src/text_recognizer/models/line_ctc_model.py
+++ b/src/text_recognizer/models/line_ctc_model.py
@@ -51,7 +51,7 @@ class LineCTCModel(Model):
self._mapper = EmnistMapper()
self.tensor_transform = ToTensor()
- def loss_fn(self, output: Tensor, targets: Tensor) -> Tensor:
+ def criterion(self, output: Tensor, targets: Tensor) -> Tensor:
"""Computes the CTC loss.
Args:
@@ -82,11 +82,13 @@ class LineCTCModel(Model):
torch.Tensor(target_lengths).type(dtype=torch.long).to(self.device)
)
- return self.criterion(output, targets, input_lengths, target_lengths)
+ return self._criterion(output, targets, input_lengths, target_lengths)
@torch.no_grad()
def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]:
"""Predict on a single input."""
+ self.eval()
+
if image.dtype == np.uint8:
# Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1].
image = self.tensor_transform(image)
@@ -110,6 +112,6 @@ class LineCTCModel(Model):
log_probs, _ = log_probs.max(dim=2)
predicted_characters = "".join(raw_pred[0])
- confidence_of_prediction = torch.exp(log_probs.sum()).item()
+ confidence_of_prediction = torch.exp(-log_probs.sum()).item()
return predicted_characters, confidence_of_prediction
diff --git a/src/text_recognizer/models/vision_transformer_model.py b/src/text_recognizer/models/vision_transformer_model.py
new file mode 100644
index 0000000..20bd4ca
--- /dev/null
+++ b/src/text_recognizer/models/vision_transformer_model.py
@@ -0,0 +1,117 @@
+"""Defines the CNN-Transformer class."""
+from typing import Callable, Dict, List, Optional, Tuple, Type, Union
+
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+from torch.utils.data import Dataset
+from torchvision.transforms import ToTensor
+
+from text_recognizer.datasets import EmnistMapper
+from text_recognizer.models.base import Model
+from text_recognizer.networks import greedy_decoder
+
+
+class VisionTransformerModel(Model):
+ """Model for predicting a sequence of characters from an image of a text line with a cnn-transformer."""
+
+ def __init__(
+ self,
+ network_fn: Type[nn.Module],
+ dataset: Type[Dataset],
+ network_args: Optional[Dict] = None,
+ dataset_args: Optional[Dict] = None,
+ metrics: Optional[Dict] = None,
+ criterion: Optional[Callable] = None,
+ criterion_args: Optional[Dict] = None,
+ optimizer: Optional[Callable] = None,
+ optimizer_args: Optional[Dict] = None,
+ lr_scheduler: Optional[Callable] = None,
+ lr_scheduler_args: Optional[Dict] = None,
+ swa_args: Optional[Dict] = None,
+ device: Optional[str] = None,
+ ) -> None:
+ super().__init__(
+ network_fn,
+ dataset,
+ network_args,
+ dataset_args,
+ metrics,
+ criterion,
+ criterion_args,
+ optimizer,
+ optimizer_args,
+ lr_scheduler,
+ lr_scheduler_args,
+ swa_args,
+ device,
+ )
+ self.init_token = dataset_args["args"]["init_token"]
+ self.pad_token = dataset_args["args"]["pad_token"]
+ self.eos_token = dataset_args["args"]["eos_token"]
+ if network_args is not None:
+ self.max_len = network_args["max_len"]
+ else:
+ self.max_len = 128
+
+ if self._mapper is None:
+ self._mapper = EmnistMapper(
+ init_token=self.init_token,
+ pad_token=self.pad_token,
+ eos_token=self.eos_token,
+ )
+ self.tensor_transform = ToTensor()
+
+ self.softmax = nn.Softmax(dim=2)
+
+ @torch.no_grad()
+ def _generate_sentence(self, image: Tensor) -> Tuple[List, float]:
+ src = self.network.preprocess_input(image)
+ memory = self.network.encoder(src)
+
+ confidence_of_predictions = []
+ trg_indices = [self.mapper(self.init_token)]
+
+ for _ in range(self.max_len):
+ trg = torch.tensor(trg_indices, device=self.device)[None, :].long()
+ trg, trg_mask = self.network.preprocess_target(trg)
+ logits = self.network.decoder(trg=trg, memory=memory, trg_mask=trg_mask)
+
+ # Convert logits to probabilities.
+ probs = self.softmax(logits)
+
+ pred_token = probs.argmax(2)[:, -1].item()
+ confidence = probs.max(2).values[:, -1].item()
+
+ trg_indices.append(pred_token)
+ confidence_of_predictions.append(confidence)
+
+ if pred_token == self.mapper(self.eos_token):
+ break
+
+ confidence = np.min(confidence_of_predictions)
+ predicted_characters = "".join([self.mapper(x) for x in trg_indices[1:]])
+
+ return predicted_characters, confidence
+
+ @torch.no_grad()
+ def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]:
+ """Predict on a single input."""
+ self.eval()
+
+ if image.dtype == np.uint8:
+ # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1].
+ image = self.tensor_transform(image)
+
+ # Rescale image between 0 and 1.
+ if image.dtype == torch.uint8:
+ # If the image is an unscaled tensor.
+ image = image.type("torch.FloatTensor") / 255
+
+ # Put the image tensor on the device the model weights are on.
+ image = image.to(self.device)
+
+ predicted_characters, confidence_of_prediction = self._generate_sentence(image)
+
+ return predicted_characters, confidence_of_prediction
diff --git a/src/text_recognizer/networks/__init__.py b/src/text_recognizer/networks/__init__.py
index a39975f..8b87797 100644
--- a/src/text_recognizer/networks/__init__.py
+++ b/src/text_recognizer/networks/__init__.py
@@ -1,21 +1,31 @@
"""Network modules."""
+from .cnn_transformer import CNNTransformer
+from .crnn import ConvolutionalRecurrentNetwork
from .ctc import greedy_decoder
+from .densenet import DenseNet
from .lenet import LeNet
-from .line_lstm_ctc import LineRecurrentNetwork
-from .losses import EmbeddingLoss
-from .misc import sliding_window
+from .loss import EmbeddingLoss
from .mlp import MLP
from .residual_network import ResidualNetwork, ResidualNetworkEncoder
+from .sparse_mlp import SparseMLP
+from .transformer import Transformer
+from .util import sliding_window
+from .vision_transformer import VisionTransformer
from .wide_resnet import WideResidualNetwork
__all__ = [
+ "CNNTransformer",
+ "ConvolutionalRecurrentNetwork",
+ "DenseNet",
"EmbeddingLoss",
"greedy_decoder",
"MLP",
"LeNet",
- "LineRecurrentNetwork",
"ResidualNetwork",
"ResidualNetworkEncoder",
"sliding_window",
+ "Transformer",
+ "SparseMLP",
+ "VisionTransformer",
"WideResidualNetwork",
]
diff --git a/src/text_recognizer/networks/cnn_transformer.py b/src/text_recognizer/networks/cnn_transformer.py
new file mode 100644
index 0000000..8666f11
--- /dev/null
+++ b/src/text_recognizer/networks/cnn_transformer.py
@@ -0,0 +1,111 @@
+"""A DETR style transfomers but for text recognition."""
+from typing import Dict, Optional, Tuple, Type
+
+from einops.layers.torch import Rearrange
+from loguru import logger
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.transformer import PositionalEncoding, Transformer
+from text_recognizer.networks.util import configure_backbone
+
+
+class CNNTransformer(nn.Module):
+ """CNN+Transfomer for image to sequence prediction, sort of based on the ideas from DETR."""
+
+ def __init__(
+ self,
+ num_encoder_layers: int,
+ num_decoder_layers: int,
+ hidden_dim: int,
+ vocab_size: int,
+ num_heads: int,
+ max_len: int,
+ expansion_dim: int,
+ dropout_rate: float,
+ trg_pad_index: int,
+ backbone: str,
+ backbone_args: Optional[Dict] = None,
+ activation: str = "gelu",
+ ) -> None:
+ super().__init__()
+ self.trg_pad_index = trg_pad_index
+ self.backbone_args = backbone_args
+ self.backbone = configure_backbone(backbone, backbone_args)
+ self.character_embedding = nn.Embedding(vocab_size, hidden_dim)
+ self.position_encoding = PositionalEncoding(hidden_dim, dropout_rate, max_len)
+ self.collapse_spatial_dim = nn.Sequential(
+ Rearrange("b t h w -> b t (h w)"), nn.AdaptiveAvgPool2d((None, hidden_dim))
+ )
+ self.transformer = Transformer(
+ num_encoder_layers,
+ num_decoder_layers,
+ hidden_dim,
+ num_heads,
+ expansion_dim,
+ dropout_rate,
+ activation,
+ )
+ self.head = nn.Linear(hidden_dim, vocab_size)
+
+ def _create_trg_mask(self, trg: Tensor) -> Tensor:
+ # Move this outside the transformer.
+ trg_pad_mask = (trg != self.trg_pad_index)[:, None, None]
+ trg_len = trg.shape[1]
+ trg_sub_mask = torch.tril(
+ torch.ones((trg_len, trg_len), device=trg.device)
+ ).bool()
+ trg_mask = trg_pad_mask & trg_sub_mask
+ return trg_mask
+
+ def encoder(self, src: Tensor) -> Tensor:
+ """Forward pass with the encoder of the transformer."""
+ return self.transformer.encoder(src)
+
+ def decoder(self, trg: Tensor, memory: Tensor, trg_mask: Tensor) -> Tensor:
+ """Forward pass with the decoder of the transformer + classification head."""
+ return self.head(
+ self.transformer.decoder(trg=trg, memory=memory, trg_mask=trg_mask)
+ )
+
+ def preprocess_input(self, src: Tensor) -> Tensor:
+ """Encodes src with a backbone network and a positional encoding.
+
+ Args:
+ src (Tensor): Input tensor.
+
+ Returns:
+ Tensor: A input src to the transformer.
+
+ """
+ # If batch dimenstion is missing, it needs to be added.
+ if len(src.shape) < 4:
+ src = src[(None,) * (4 - len(src.shape))]
+ src = self.backbone(src)
+ src = self.collapse_spatial_dim(src)
+ src = self.position_encoding(src)
+ return src
+
+ def preprocess_target(self, trg: Tensor) -> Tuple[Tensor, Tensor]:
+ """Encodes target tensor with embedding and postion.
+
+ Args:
+ trg (Tensor): Target tensor.
+
+ Returns:
+ Tuple[Tensor, Tensor]: Encoded target tensor and target mask.
+
+ """
+ trg_mask = self._create_trg_mask(trg)
+ trg = self.character_embedding(trg.long())
+ trg = self.position_encoding(trg)
+ return trg, trg_mask
+
+ def forward(self, x: Tensor, trg: Tensor) -> Tensor:
+ """Forward pass with CNN transfomer."""
+ src = self.preprocess_input(x)
+ trg, trg_mask = self.preprocess_target(trg)
+ out = self.transformer(src, trg, trg_mask=trg_mask)
+ logits = self.head(out)
+ return logits
diff --git a/src/text_recognizer/networks/line_lstm_ctc.py b/src/text_recognizer/networks/crnn.py
index 9009f94..3e605e2 100644
--- a/src/text_recognizer/networks/line_lstm_ctc.py
+++ b/src/text_recognizer/networks/crnn.py
@@ -10,15 +10,16 @@ import torch
from torch import nn
from torch import Tensor
+from text_recognizer.networks.util import configure_backbone
-class LineRecurrentNetwork(nn.Module):
+
+class ConvolutionalRecurrentNetwork(nn.Module):
"""Network that takes a image of a text line and predicts tokens that are in the image."""
def __init__(
self,
backbone: str,
backbone_args: Dict = None,
- flatten: bool = True,
input_size: int = 128,
hidden_size: int = 128,
bidirectional: bool = False,
@@ -26,6 +27,7 @@ class LineRecurrentNetwork(nn.Module):
num_classes: int = 80,
patch_size: Tuple[int, int] = (28, 28),
stride: Tuple[int, int] = (1, 14),
+ recurrent_cell: str = "lstm",
) -> None:
super().__init__()
self.backbone_args = backbone_args or {}
@@ -34,17 +36,19 @@ class LineRecurrentNetwork(nn.Module):
self.sliding_window = self._configure_sliding_window()
self.input_size = input_size
self.hidden_size = hidden_size
- self.backbone = self._configure_backbone(backbone)
+ self.backbone = configure_backbone(backbone, backbone_args)
self.bidirectional = bidirectional
- self.flatten = flatten
- if self.flatten:
- self.fc = nn.Linear(
- in_features=self.input_size, out_features=self.hidden_size
+ if recurrent_cell.upper() in ["LSTM", "GRU"]:
+ recurrent_cell = getattr(nn, recurrent_cell)
+ else:
+ logger.warning(
+ f"Option {recurrent_cell} not valid, defaulting to LSTM cell."
)
+ recurrent_cell = nn.LSTM
- self.rnn = nn.LSTM(
- input_size=self.hidden_size,
+ self.rnn = recurrent_cell(
+ input_size=self.input_size,
hidden_size=self.hidden_size,
bidirectional=bidirectional,
num_layers=num_layers,
@@ -57,32 +61,6 @@ class LineRecurrentNetwork(nn.Module):
nn.LogSoftmax(dim=2),
)
- def _configure_backbone(self, backbone: str) -> Type[nn.Module]:
- network_module = importlib.import_module("text_recognizer.networks")
- backbone_ = getattr(network_module, backbone)
-
- if "pretrained" in self.backbone_args:
- logger.info("Loading pretrained backbone.")
- checkpoint_file = Path(__file__).resolve().parents[
- 2
- ] / self.backbone_args.pop("pretrained")
-
- # Loading state directory.
- state_dict = torch.load(checkpoint_file)
- network_args = state_dict["network_args"]
- weights = state_dict["model_state"]
-
- # Initializes the network with trained weights.
- backbone = backbone_(**network_args)
- backbone.load_state_dict(weights)
- if "freeze" in self.backbone_args and self.backbone_args["freeze"] is True:
- for params in backbone.parameters():
- params.requires_grad = False
-
- return backbone
- else:
- return backbone_(**self.backbone_args)
-
def _configure_sliding_window(self) -> nn.Sequential:
return nn.Sequential(
nn.Unfold(kernel_size=self.patch_size, stride=self.stride),
@@ -96,8 +74,8 @@ class LineRecurrentNetwork(nn.Module):
def forward(self, x: Tensor) -> Tensor:
"""Converts images to sequence of patches, feeds them to a CNN, then predictions are made with an LSTM."""
- if len(x.shape) == 3:
- x = x.unsqueeze(0)
+ if len(x.shape) < 4:
+ x = x[(None,) * (4 - len(x.shape))]
x = self.sliding_window(x)
# Rearrange from a sequence of patches for feedforward network.
@@ -106,11 +84,7 @@ class LineRecurrentNetwork(nn.Module):
x = self.backbone(x)
# Avgerage pooling.
- x = (
- self.fc(reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t))
- if self.flatten
- else rearrange(x, "(b t) h -> t b h", b=b, t=t)
- )
+ x = reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t)
# Sequence predictions.
x, _ = self.rnn(x)
diff --git a/src/text_recognizer/networks/densenet.py b/src/text_recognizer/networks/densenet.py
new file mode 100644
index 0000000..d2aad60
--- /dev/null
+++ b/src/text_recognizer/networks/densenet.py
@@ -0,0 +1,225 @@
+"""Defines a Densely Connected Convolutional Networks in PyTorch.
+
+Sources:
+https://arxiv.org/abs/1608.06993
+https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py
+
+"""
+from typing import List, Optional, Union
+
+from einops.layers.torch import Rearrange
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.util import activation_function
+
+
+class _DenseLayer(nn.Module):
+ """A dense layer with pre-batch norm -> activation function -> Conv-layer x 2."""
+
+ def __init__(
+ self,
+ in_channels: int,
+ growth_rate: int,
+ bn_size: int,
+ dropout_rate: float,
+ activation: str = "relu",
+ ) -> None:
+ super().__init__()
+ activation_fn = activation_function(activation)
+ self.dense_layer = [
+ nn.BatchNorm2d(in_channels),
+ activation_fn,
+ nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=bn_size * growth_rate,
+ kernel_size=1,
+ stride=1,
+ bias=False,
+ ),
+ nn.BatchNorm2d(bn_size * growth_rate),
+ activation_fn,
+ nn.Conv2d(
+ in_channels=bn_size * growth_rate,
+ out_channels=growth_rate,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False,
+ ),
+ ]
+ if dropout_rate:
+ self.dense_layer.append(nn.Dropout(p=dropout_rate))
+
+ self.dense_layer = nn.Sequential(*self.dense_layer)
+
+ def forward(self, x: Union[Tensor, List[Tensor]]) -> Tensor:
+ if isinstance(x, list):
+ x = torch.cat(x, 1)
+ return self.dense_layer(x)
+
+
+class _DenseBlock(nn.Module):
+ def __init__(
+ self,
+ num_layers: int,
+ in_channels: int,
+ bn_size: int,
+ growth_rate: int,
+ dropout_rate: float,
+ activation: str = "relu",
+ ) -> None:
+ super().__init__()
+ self.dense_block = self._build_dense_blocks(
+ num_layers, in_channels, bn_size, growth_rate, dropout_rate, activation
+ )
+
+ def _build_dense_blocks(
+ self,
+ num_layers: int,
+ in_channels: int,
+ bn_size: int,
+ growth_rate: int,
+ dropout_rate: float,
+ activation: str = "relu",
+ ) -> nn.ModuleList:
+ dense_block = []
+ for i in range(num_layers):
+ dense_block.append(
+ _DenseLayer(
+ in_channels=in_channels + i * growth_rate,
+ growth_rate=growth_rate,
+ bn_size=bn_size,
+ dropout_rate=dropout_rate,
+ activation=activation,
+ )
+ )
+ return nn.ModuleList(dense_block)
+
+ def forward(self, x: Tensor) -> Tensor:
+ feature_maps = [x]
+ for layer in self.dense_block:
+ x = layer(feature_maps)
+ feature_maps.append(x)
+ return torch.cat(feature_maps, 1)
+
+
+class _Transition(nn.Module):
+ def __init__(
+ self, in_channels: int, out_channels: int, activation: str = "relu",
+ ) -> None:
+ super().__init__()
+ activation_fn = activation_function(activation)
+ self.transition = nn.Sequential(
+ nn.BatchNorm2d(in_channels),
+ activation_fn,
+ nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=1,
+ stride=1,
+ bias=False,
+ ),
+ nn.AvgPool2d(kernel_size=2, stride=2),
+ )
+
+ def forward(self, x: Tensor) -> Tensor:
+ return self.transition(x)
+
+
+class DenseNet(nn.Module):
+ """Implementation of Densenet, a network archtecture that concats previous layers for maximum infomation flow."""
+
+ def __init__(
+ self,
+ growth_rate: int = 32,
+ block_config: List[int] = (6, 12, 24, 16),
+ in_channels: int = 1,
+ base_channels: int = 64,
+ num_classes: int = 80,
+ bn_size: int = 4,
+ dropout_rate: float = 0,
+ classifier: bool = True,
+ activation: str = "relu",
+ ) -> None:
+ super().__init__()
+ self.densenet = self._configure_densenet(
+ in_channels,
+ base_channels,
+ num_classes,
+ growth_rate,
+ block_config,
+ bn_size,
+ dropout_rate,
+ classifier,
+ activation,
+ )
+
+ def _configure_densenet(
+ self,
+ in_channels: int,
+ base_channels: int,
+ num_classes: int,
+ growth_rate: int,
+ block_config: List[int],
+ bn_size: int,
+ dropout_rate: float,
+ classifier: bool,
+ activation: str,
+ ) -> nn.Sequential:
+ activation_fn = activation_function(activation)
+ densenet = [
+ nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=base_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False,
+ ),
+ nn.BatchNorm2d(base_channels),
+ activation_fn,
+ ]
+
+ num_features = base_channels
+
+ for i, num_layers in enumerate(block_config):
+ densenet.append(
+ _DenseBlock(
+ num_layers=num_layers,
+ in_channels=num_features,
+ bn_size=bn_size,
+ growth_rate=growth_rate,
+ dropout_rate=dropout_rate,
+ activation=activation,
+ )
+ )
+ num_features = num_features + num_layers * growth_rate
+ if i != len(block_config) - 1:
+ densenet.append(
+ _Transition(
+ in_channels=num_features,
+ out_channels=num_features // 2,
+ activation=activation,
+ )
+ )
+ num_features = num_features // 2
+
+ densenet.append(activation_fn)
+
+ if classifier:
+ densenet.append(nn.AdaptiveAvgPool2d((1, 1)))
+ densenet.append(Rearrange("b c h w -> b (c h w)"))
+ densenet.append(
+ nn.Linear(in_features=num_features, out_features=num_classes)
+ )
+
+ return nn.Sequential(*densenet)
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Forward pass of Densenet."""
+ # If batch dimenstion is missing, it needs to be added.
+ if len(x.shape) < 4:
+ x = x[(None,) * (4 - len(x.shape))]
+ return self.densenet(x)
diff --git a/src/text_recognizer/networks/lenet.py b/src/text_recognizer/networks/lenet.py
index 53c575e..527e1a0 100644
--- a/src/text_recognizer/networks/lenet.py
+++ b/src/text_recognizer/networks/lenet.py
@@ -5,7 +5,7 @@ from einops.layers.torch import Rearrange
import torch
from torch import nn
-from text_recognizer.networks.misc import activation_function
+from text_recognizer.networks.util import activation_function
class LeNet(nn.Module):
@@ -63,6 +63,6 @@ class LeNet(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""The feedforward pass."""
# If batch dimenstion is missing, it needs to be added.
- if len(x.shape) == 3:
- x = x.unsqueeze(0)
+ if len(x.shape) < 4:
+ x = x[(None,) * (4 - len(x.shape))]
return self.layers(x)
diff --git a/src/text_recognizer/networks/losses.py b/src/text_recognizer/networks/loss.py
index 73e0641..ff843cf 100644
--- a/src/text_recognizer/networks/losses.py
+++ b/src/text_recognizer/networks/loss.py
@@ -4,6 +4,9 @@ from torch import nn
from torch import Tensor
+__all__ = ["EmbeddingLoss"]
+
+
class EmbeddingLoss:
"""Metric loss for training encoders to produce information-rich latent embeddings."""
diff --git a/src/text_recognizer/networks/mlp.py b/src/text_recognizer/networks/mlp.py
index d66af28..1101912 100644
--- a/src/text_recognizer/networks/mlp.py
+++ b/src/text_recognizer/networks/mlp.py
@@ -5,7 +5,7 @@ from einops.layers.torch import Rearrange
import torch
from torch import nn
-from text_recognizer.networks.misc import activation_function
+from text_recognizer.networks.util import activation_function
class MLP(nn.Module):
@@ -63,8 +63,8 @@ class MLP(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""The feedforward pass."""
# If batch dimenstion is missing, it needs to be added.
- if len(x.shape) == 3:
- x = x.unsqueeze(0)
+ if len(x.shape) < 4:
+ x = x[(None,) * (4 - len(x.shape))]
return self.layers(x)
@property
diff --git a/src/text_recognizer/networks/residual_network.py b/src/text_recognizer/networks/residual_network.py
index 046600d..6405192 100644
--- a/src/text_recognizer/networks/residual_network.py
+++ b/src/text_recognizer/networks/residual_network.py
@@ -7,8 +7,8 @@ import torch
from torch import nn
from torch import Tensor
-from text_recognizer.networks.misc import activation_function
from text_recognizer.networks.stn import SpatialTransformerNetwork
+from text_recognizer.networks.util import activation_function
class Conv2dAuto(nn.Conv2d):
@@ -225,8 +225,8 @@ class ResidualNetworkEncoder(nn.Module):
in_channels=in_channels,
out_channels=self.block_sizes[0],
kernel_size=3,
- stride=2,
- padding=3,
+ stride=1,
+ padding=1,
bias=False,
),
nn.BatchNorm2d(self.block_sizes[0]),
diff --git a/src/text_recognizer/networks/sparse_mlp.py b/src/text_recognizer/networks/sparse_mlp.py
new file mode 100644
index 0000000..53cf166
--- /dev/null
+++ b/src/text_recognizer/networks/sparse_mlp.py
@@ -0,0 +1,78 @@
+"""Defines the Sparse MLP network."""
+from typing import Callable, Dict, List, Optional, Union
+import warnings
+
+from einops.layers.torch import Rearrange
+from pytorch_block_sparse import BlockSparseLinear
+import torch
+from torch import nn
+
+from text_recognizer.networks.util import activation_function
+
+warnings.filterwarnings("ignore", category=DeprecationWarning)
+
+
+class SparseMLP(nn.Module):
+ """Sparse multi layered perceptron network."""
+
+ def __init__(
+ self,
+ input_size: int = 784,
+ num_classes: int = 10,
+ hidden_size: Union[int, List] = 128,
+ num_layers: int = 3,
+ density: float = 0.1,
+ activation_fn: str = "relu",
+ ) -> None:
+ """Initialization of the MLP network.
+
+ Args:
+ input_size (int): The input shape of the network. Defaults to 784.
+ num_classes (int): Number of classes in the dataset. Defaults to 10.
+ hidden_size (Union[int, List]): The number of `neurons` in each hidden layer. Defaults to 128.
+ num_layers (int): The number of hidden layers. Defaults to 3.
+ density (float): The density of activation at each layer. Default to 0.1.
+ activation_fn (str): Name of the activation function in the hidden layers. Defaults to
+ relu.
+
+ """
+ super().__init__()
+
+ activation_fn = activation_function(activation_fn)
+
+ if isinstance(hidden_size, int):
+ hidden_size = [hidden_size] * num_layers
+
+ self.layers = [
+ Rearrange("b c h w -> b (c h w)"),
+ nn.Linear(in_features=input_size, out_features=hidden_size[0]),
+ activation_fn,
+ ]
+
+ for i in range(num_layers - 1):
+ self.layers += [
+ BlockSparseLinear(
+ in_features=hidden_size[i],
+ out_features=hidden_size[i + 1],
+ density=density,
+ ),
+ activation_fn,
+ ]
+
+ self.layers.append(
+ nn.Linear(in_features=hidden_size[-1], out_features=num_classes)
+ )
+
+ self.layers = nn.Sequential(*self.layers)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """The feedforward pass."""
+ # If batch dimenstion is missing, it needs to be added.
+ if len(x.shape) < 4:
+ x = x[(None,) * (4 - len(x.shape))]
+ return self.layers(x)
+
+ @property
+ def __name__(self) -> str:
+ """Returns the name of the network."""
+ return "mlp"
diff --git a/src/text_recognizer/networks/transformer.py b/src/text_recognizer/networks/transformer.py
deleted file mode 100644
index c091ba0..0000000
--- a/src/text_recognizer/networks/transformer.py
+++ /dev/null
@@ -1,5 +0,0 @@
-"""TBC."""
-from typing import Dict
-
-import torch
-from torch import Tensor
diff --git a/src/text_recognizer/networks/transformer/__init__.py b/src/text_recognizer/networks/transformer/__init__.py
new file mode 100644
index 0000000..020a917
--- /dev/null
+++ b/src/text_recognizer/networks/transformer/__init__.py
@@ -0,0 +1,3 @@
+"""Transformer modules."""
+from .positional_encoding import PositionalEncoding
+from .transformer import Decoder, Encoder, Transformer
diff --git a/src/text_recognizer/networks/transformer/attention.py b/src/text_recognizer/networks/transformer/attention.py
new file mode 100644
index 0000000..cce1ecc
--- /dev/null
+++ b/src/text_recognizer/networks/transformer/attention.py
@@ -0,0 +1,93 @@
+"""Implementes the attention module for the transformer."""
+from typing import Optional, Tuple
+
+from einops import rearrange
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+
+
+class MultiHeadAttention(nn.Module):
+ """Implementation of multihead attention."""
+
+ def __init__(
+ self, hidden_dim: int, num_heads: int = 8, dropout_rate: float = 0.0
+ ) -> None:
+ super().__init__()
+ self.hidden_dim = hidden_dim
+ self.num_heads = num_heads
+ self.fc_q = nn.Linear(
+ in_features=hidden_dim, out_features=hidden_dim, bias=False
+ )
+ self.fc_k = nn.Linear(
+ in_features=hidden_dim, out_features=hidden_dim, bias=False
+ )
+ self.fc_v = nn.Linear(
+ in_features=hidden_dim, out_features=hidden_dim, bias=False
+ )
+ self.fc_out = nn.Linear(in_features=hidden_dim, out_features=hidden_dim)
+
+ self._init_weights()
+
+ self.dropout = nn.Dropout(p=dropout_rate)
+
+ def _init_weights(self) -> None:
+ nn.init.normal_(
+ self.fc_q.weight,
+ mean=0,
+ std=np.sqrt(self.hidden_dim + int(self.hidden_dim / self.num_heads)),
+ )
+ nn.init.normal_(
+ self.fc_k.weight,
+ mean=0,
+ std=np.sqrt(self.hidden_dim + int(self.hidden_dim / self.num_heads)),
+ )
+ nn.init.normal_(
+ self.fc_v.weight,
+ mean=0,
+ std=np.sqrt(self.hidden_dim + int(self.hidden_dim / self.num_heads)),
+ )
+ nn.init.xavier_normal_(self.fc_out.weight)
+
+ def scaled_dot_product_attention(
+ self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None
+ ) -> Tensor:
+ """Calculates the scaled dot product attention."""
+
+ # Compute the energy.
+ energy = torch.einsum("bhlk,bhtk->bhlt", [query, key]) / np.sqrt(
+ query.shape[-1]
+ )
+
+ # If we have a mask for padding some inputs.
+ if mask is not None:
+ energy = energy.masked_fill(mask == 0, -np.inf)
+
+ # Compute the attention from the energy.
+ attention = torch.softmax(energy, dim=3)
+
+ out = torch.einsum("bhlt,bhtv->bhlv", [attention, value])
+ out = rearrange(out, "b head l v -> b l (head v)")
+ return out, attention
+
+ def forward(
+ self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None
+ ) -> Tuple[Tensor, Tensor]:
+ """Forward pass for computing the multihead attention."""
+ # Get the query, key, and value tensor.
+ query = rearrange(
+ self.fc_q(query), "b l (head k) -> b head l k", head=self.num_heads
+ )
+ key = rearrange(
+ self.fc_k(key), "b t (head k) -> b head t k", head=self.num_heads
+ )
+ value = rearrange(
+ self.fc_v(value), "b t (head v) -> b head t v", head=self.num_heads
+ )
+
+ out, attention = self.scaled_dot_product_attention(query, key, value, mask)
+
+ out = self.fc_out(out)
+ out = self.dropout(out)
+ return out, attention
diff --git a/src/text_recognizer/networks/transformer/positional_encoding.py b/src/text_recognizer/networks/transformer/positional_encoding.py
new file mode 100644
index 0000000..a47141b
--- /dev/null
+++ b/src/text_recognizer/networks/transformer/positional_encoding.py
@@ -0,0 +1,31 @@
+"""A positional encoding for the image features, as the transformer has no notation of the order of the sequence."""
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+
+
+class PositionalEncoding(nn.Module):
+ """Encodes a sense of distance or time for transformer networks."""
+
+ def __init__(
+ self, hidden_dim: int, dropout_rate: float, max_len: int = 1000
+ ) -> None:
+ super().__init__()
+ self.dropout = nn.Dropout(p=dropout_rate)
+
+ pe = torch.zeros(max_len, hidden_dim)
+ position = torch.arange(0, max_len).unsqueeze(1)
+ div_term = torch.exp(
+ torch.arange(0, hidden_dim, 2) * -(np.log(10000.0) / hidden_dim)
+ )
+
+ pe[:, 0::2] = torch.sin(position * div_term)
+ pe[:, 1::2] = torch.cos(position * div_term)
+ pe = pe.unsqueeze(0)
+ self.register_buffer("pe", pe)
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Encodes the tensor with a postional embedding."""
+ x = x + self.pe[:, : x.shape[1]]
+ return self.dropout(x)
diff --git a/src/text_recognizer/networks/transformer/sparse_transformer.py b/src/text_recognizer/networks/transformer/sparse_transformer.py
new file mode 100644
index 0000000..8c391c8
--- /dev/null
+++ b/src/text_recognizer/networks/transformer/sparse_transformer.py
@@ -0,0 +1 @@
+"""Encoder and Decoder modules using spares activations."""
diff --git a/src/text_recognizer/networks/transformer/transformer.py b/src/text_recognizer/networks/transformer/transformer.py
new file mode 100644
index 0000000..1c9c7dd
--- /dev/null
+++ b/src/text_recognizer/networks/transformer/transformer.py
@@ -0,0 +1,241 @@
+"""Transfomer module."""
+import copy
+from typing import Dict, Optional, Type, Union
+
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.transformer.attention import MultiHeadAttention
+from text_recognizer.networks.util import activation_function
+
+
+def _get_clones(module: Type[nn.Module], num_layers: int) -> nn.ModuleList:
+ return nn.ModuleList([copy.deepcopy(module) for _ in range(num_layers)])
+
+
+class _IntraLayerConnection(nn.Module):
+ """Preforms the residual connection inside the transfomer blocks and applies layernorm."""
+
+ def __init__(self, dropout_rate: float, hidden_dim: int) -> None:
+ super().__init__()
+ self.norm = nn.LayerNorm(normalized_shape=hidden_dim)
+ self.dropout = nn.Dropout(p=dropout_rate)
+
+ def forward(self, src: Tensor, residual: Tensor) -> Tensor:
+ return self.norm(self.dropout(src) + residual)
+
+
+class _ConvolutionalLayer(nn.Module):
+ def __init__(
+ self,
+ hidden_dim: int,
+ expansion_dim: int,
+ dropout_rate: float,
+ activation: str = "relu",
+ ) -> None:
+ super().__init__()
+ self.layer = nn.Sequential(
+ nn.Linear(in_features=hidden_dim, out_features=expansion_dim),
+ activation_function(activation),
+ nn.Dropout(p=dropout_rate),
+ nn.Linear(in_features=expansion_dim, out_features=hidden_dim),
+ )
+
+ def forward(self, x: Tensor) -> Tensor:
+ return self.layer(x)
+
+
+class EncoderLayer(nn.Module):
+ """Transfomer encoding layer."""
+
+ def __init__(
+ self,
+ hidden_dim: int,
+ num_heads: int,
+ expansion_dim: int,
+ dropout_rate: float,
+ activation: str = "relu",
+ ) -> None:
+ super().__init__()
+ self.self_attention = MultiHeadAttention(hidden_dim, num_heads, dropout_rate)
+ self.cnn = _ConvolutionalLayer(
+ hidden_dim, expansion_dim, dropout_rate, activation
+ )
+ self.block1 = _IntraLayerConnection(dropout_rate, hidden_dim)
+ self.block2 = _IntraLayerConnection(dropout_rate, hidden_dim)
+
+ def forward(self, src: Tensor, mask: Optional[Tensor] = None) -> Tensor:
+ """Forward pass through the encoder."""
+ # First block.
+ # Multi head attention.
+ out, _ = self.self_attention(src, src, src, mask)
+
+ # Add & norm.
+ out = self.block1(out, src)
+
+ # Second block.
+ # Apply 1D-convolution.
+ cnn_out = self.cnn(out)
+
+ # Add & norm.
+ out = self.block2(cnn_out, out)
+
+ return out
+
+
+class Encoder(nn.Module):
+ """Transfomer encoder module."""
+
+ def __init__(
+ self,
+ num_layers: int,
+ encoder_layer: Type[nn.Module],
+ norm: Optional[Type[nn.Module]] = None,
+ ) -> None:
+ super().__init__()
+ self.layers = _get_clones(encoder_layer, num_layers)
+ self.norm = norm
+
+ def forward(self, src: Tensor, src_mask: Optional[Tensor] = None) -> Tensor:
+ """Forward pass through all encoder layers."""
+ for layer in self.layers:
+ src = layer(src, src_mask)
+
+ if self.norm is not None:
+ src = self.norm(src)
+
+ return src
+
+
+class DecoderLayer(nn.Module):
+ """Transfomer decoder layer."""
+
+ def __init__(
+ self,
+ hidden_dim: int,
+ num_heads: int,
+ expansion_dim: int,
+ dropout_rate: float = 0.0,
+ activation: str = "relu",
+ ) -> None:
+ super().__init__()
+ self.hidden_dim = hidden_dim
+ self.self_attention = MultiHeadAttention(hidden_dim, num_heads, dropout_rate)
+ self.multihead_attention = MultiHeadAttention(
+ hidden_dim, num_heads, dropout_rate
+ )
+ self.cnn = _ConvolutionalLayer(
+ hidden_dim, expansion_dim, dropout_rate, activation
+ )
+ self.block1 = _IntraLayerConnection(dropout_rate, hidden_dim)
+ self.block2 = _IntraLayerConnection(dropout_rate, hidden_dim)
+ self.block3 = _IntraLayerConnection(dropout_rate, hidden_dim)
+
+ def forward(
+ self,
+ trg: Tensor,
+ memory: Tensor,
+ trg_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ """Forward pass of the layer."""
+ out, _ = self.self_attention(trg, trg, trg, trg_mask)
+ trg = self.block1(out, trg)
+
+ out, _ = self.multihead_attention(trg, memory, memory, memory_mask)
+ trg = self.block2(out, trg)
+
+ out = self.cnn(trg)
+ out = self.block3(out, trg)
+
+ return out
+
+
+class Decoder(nn.Module):
+ """Transfomer decoder module."""
+
+ def __init__(
+ self,
+ decoder_layer: Type[nn.Module],
+ num_layers: int,
+ norm: Optional[Type[nn.Module]] = None,
+ ) -> None:
+ super().__init__()
+ self.layers = _get_clones(decoder_layer, num_layers)
+ self.num_layers = num_layers
+ self.norm = norm
+
+ def forward(
+ self,
+ trg: Tensor,
+ memory: Tensor,
+ trg_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ """Forward pass through the decoder."""
+ for layer in self.layers:
+ trg = layer(trg, memory, trg_mask, memory_mask)
+
+ if self.norm is not None:
+ trg = self.norm(trg)
+
+ return trg
+
+
+class Transformer(nn.Module):
+ """Transformer network."""
+
+ def __init__(
+ self,
+ num_encoder_layers: int,
+ num_decoder_layers: int,
+ hidden_dim: int,
+ num_heads: int,
+ expansion_dim: int,
+ dropout_rate: float,
+ activation: str = "relu",
+ ) -> None:
+ super().__init__()
+
+ # Configure encoder.
+ encoder_norm = nn.LayerNorm(hidden_dim)
+ encoder_layer = EncoderLayer(
+ hidden_dim, num_heads, expansion_dim, dropout_rate, activation
+ )
+ self.encoder = Encoder(num_encoder_layers, encoder_layer, encoder_norm)
+
+ # Configure decoder.
+ decoder_norm = nn.LayerNorm(hidden_dim)
+ decoder_layer = DecoderLayer(
+ hidden_dim, num_heads, expansion_dim, dropout_rate, activation
+ )
+ self.decoder = Decoder(decoder_layer, num_decoder_layers, decoder_norm)
+
+ self._reset_parameters()
+
+ def _reset_parameters(self) -> None:
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+
+ def forward(
+ self,
+ src: Tensor,
+ trg: Tensor,
+ src_mask: Optional[Tensor] = None,
+ trg_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ """Forward pass through the transformer."""
+ if src.shape[0] != trg.shape[0]:
+ raise RuntimeError("The batch size of the src and trg must be the same.")
+ if src.shape[2] != trg.shape[2]:
+ raise RuntimeError(
+ "The number of features for the src and trg must be the same."
+ )
+
+ memory = self.encoder(src, src_mask)
+ output = self.decoder(trg, memory, trg_mask, memory_mask)
+ return output
diff --git a/src/text_recognizer/networks/misc.py b/src/text_recognizer/networks/util.py
index 1f853e9..0d08506 100644
--- a/src/text_recognizer/networks/misc.py
+++ b/src/text_recognizer/networks/util.py
@@ -1,7 +1,10 @@
"""Miscellaneous neural network functionality."""
-from typing import Tuple, Type
+import importlib
+from pathlib import Path
+from typing import Dict, Tuple, Type
from einops import rearrange
+from loguru import logger
import torch
from torch import nn
@@ -43,3 +46,38 @@ def activation_function(activation: str) -> Type[nn.Module]:
]
)
return activation_fns[activation.lower()]
+
+
+def configure_backbone(backbone: str, backbone_args: Dict) -> Type[nn.Module]:
+ """Loads a backbone network."""
+ network_module = importlib.import_module("text_recognizer.networks")
+ backbone_ = getattr(network_module, backbone)
+
+ if "pretrained" in backbone_args:
+ logger.info("Loading pretrained backbone.")
+ checkpoint_file = Path(__file__).resolve().parents[2] / backbone_args.pop(
+ "pretrained"
+ )
+
+ # Loading state directory.
+ state_dict = torch.load(checkpoint_file)
+ network_args = state_dict["network_args"]
+ weights = state_dict["model_state"]
+
+ # Initializes the network with trained weights.
+ backbone = backbone_(**network_args)
+ backbone.load_state_dict(weights)
+ if "freeze" in backbone_args and backbone_args["freeze"] is True:
+ for params in backbone.parameters():
+ params.requires_grad = False
+
+ else:
+ backbone_ = getattr(network_module, backbone)
+ backbone = backbone_(**backbone_args)
+
+ if "remove_layers" in backbone_args and backbone_args["remove_layers"] is not None:
+ backbone = nn.Sequential(
+ *list(backbone.children())[0][: -backbone_args["remove_layers"]]
+ )
+
+ return backbone
diff --git a/src/text_recognizer/networks/vision_transformer.py b/src/text_recognizer/networks/vision_transformer.py
new file mode 100644
index 0000000..4d204d3
--- /dev/null
+++ b/src/text_recognizer/networks/vision_transformer.py
@@ -0,0 +1,158 @@
+"""VisionTransformer module.
+
+Splits each image into patches and feeds them to a transformer.
+
+"""
+
+from typing import Dict, Optional, Tuple, Type
+
+from einops import rearrange, reduce
+from einops.layers.torch import Rearrange
+from loguru import logger
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.transformer import PositionalEncoding, Transformer
+from text_recognizer.networks.util import configure_backbone
+
+
+class VisionTransformer(nn.Module):
+ """Linear projection+Transfomer for image to sequence prediction, sort of based on the ideas from ViT."""
+
+ def __init__(
+ self,
+ num_encoder_layers: int,
+ num_decoder_layers: int,
+ hidden_dim: int,
+ vocab_size: int,
+ num_heads: int,
+ max_len: int,
+ expansion_dim: int,
+ mlp_dim: int,
+ dropout_rate: float,
+ trg_pad_index: int,
+ patch_size: Tuple[int, int] = (28, 28),
+ stride: Tuple[int, int] = (1, 14),
+ activation: str = "gelu",
+ backbone: Optional[str] = None,
+ backbone_args: Optional[Dict] = None,
+ ) -> None:
+ super().__init__()
+
+ self.patch_size = patch_size
+ self.stride = stride
+ self.trg_pad_index = trg_pad_index
+ self.slidning_window = self._configure_sliding_window()
+ self.character_embedding = nn.Embedding(vocab_size, hidden_dim)
+ self.position_encoding = PositionalEncoding(hidden_dim, dropout_rate, max_len)
+
+ self.use_backbone = False
+ if backbone is None:
+ self.linear_projection = nn.Linear(
+ self.patch_size[0] * self.patch_size[1], hidden_dim
+ )
+ else:
+ self.backbone = configure_backbone(backbone, backbone_args)
+ self.use_backbone = True
+
+ self.transformer = Transformer(
+ num_encoder_layers,
+ num_decoder_layers,
+ hidden_dim,
+ num_heads,
+ expansion_dim,
+ dropout_rate,
+ activation,
+ )
+
+ self.head = nn.Sequential(
+ nn.LayerNorm(hidden_dim),
+ nn.Linear(hidden_dim, mlp_dim),
+ nn.GELU(),
+ nn.Dropout(p=dropout_rate),
+ nn.Linear(mlp_dim, vocab_size),
+ )
+
+ def _configure_sliding_window(self) -> nn.Sequential:
+ return nn.Sequential(
+ nn.Unfold(kernel_size=self.patch_size, stride=self.stride),
+ Rearrange(
+ "b (c h w) t -> b t c h w",
+ h=self.patch_size[0],
+ w=self.patch_size[1],
+ c=1,
+ ),
+ )
+
+ def _create_trg_mask(self, trg: Tensor) -> Tensor:
+ # Move this outside the transformer.
+ trg_pad_mask = (trg != self.trg_pad_index)[:, None, None]
+ trg_len = trg.shape[1]
+ trg_sub_mask = torch.tril(
+ torch.ones((trg_len, trg_len), device=trg.device)
+ ).bool()
+ trg_mask = trg_pad_mask & trg_sub_mask
+ return trg_mask
+
+ def encoder(self, src: Tensor) -> Tensor:
+ """Forward pass with the encoder of the transformer."""
+ return self.transformer.encoder(src)
+
+ def decoder(self, trg: Tensor, memory: Tensor, trg_mask: Tensor) -> Tensor:
+ """Forward pass with the decoder of the transformer + classification head."""
+ return self.head(
+ self.transformer.decoder(trg=trg, memory=memory, trg_mask=trg_mask)
+ )
+
+ def _backbone(self, x: Tensor) -> Tensor:
+ b, t = x.shape[:2]
+ if self.use_backbone:
+ x = rearrange(x, "b t c h w -> (b t) c h w", b=b, t=t)
+ x = self.backbone(x)
+ x = rearrange(x, "(b t) h -> b t h", b=b, t=t)
+ else:
+ x = rearrange(x, "b t c h w -> b t (c h w)", b=b, t=t)
+ x = self.linear_projection(x)
+ return x
+
+ def preprocess_input(self, src: Tensor) -> Tensor:
+ """Encodes src with a backbone network and a positional encoding.
+
+ Args:
+ src (Tensor): Input tensor.
+
+ Returns:
+ Tensor: A input src to the transformer.
+
+ """
+ # If batch dimenstion is missing, it needs to be added.
+ if len(src.shape) < 4:
+ src = src[(None,) * (4 - len(src.shape))]
+ src = self.slidning_window(src) # .squeeze(-2)
+ src = self._backbone(src)
+ src = self.position_encoding(src)
+ return src
+
+ def preprocess_target(self, trg: Tensor) -> Tuple[Tensor, Tensor]:
+ """Encodes target tensor with embedding and postion.
+
+ Args:
+ trg (Tensor): Target tensor.
+
+ Returns:
+ Tuple[Tensor, Tensor]: Encoded target tensor and target mask.
+
+ """
+ trg_mask = self._create_trg_mask(trg)
+ trg = self.character_embedding(trg.long())
+ trg = self.position_encoding(trg)
+ return trg, trg_mask
+
+ def forward(self, x: Tensor, trg: Tensor) -> Tensor:
+ """Forward pass with vision transfomer."""
+ src = self.preprocess_input(x)
+ trg, trg_mask = self.preprocess_target(trg)
+ out = self.transformer(src, trg, trg_mask=trg_mask)
+ logits = self.head(out)
+ return logits
diff --git a/src/text_recognizer/networks/wide_resnet.py b/src/text_recognizer/networks/wide_resnet.py
index 618f414..aa79c12 100644
--- a/src/text_recognizer/networks/wide_resnet.py
+++ b/src/text_recognizer/networks/wide_resnet.py
@@ -8,7 +8,7 @@ import torch
from torch import nn
from torch import Tensor
-from text_recognizer.networks.misc import activation_function
+from text_recognizer.networks.util import activation_function
def conv3x3(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
@@ -206,8 +206,8 @@ class WideResidualNetwork(nn.Module):
def forward(self, x: Tensor) -> Tensor:
"""Feedforward pass."""
- if len(x.shape) == 3:
- x = x.unsqueeze(0)
+ if len(x.shape) < 4:
+ x = x[(None,) * int(4 - len(x.shape))]
x = self.encoder(x)
if self.decoder is not None:
x = self.decoder(x)
diff --git a/src/text_recognizer/weights/CharacterModel_EmnistDataset_DenseNet_weights.pt b/src/text_recognizer/weights/CharacterModel_EmnistDataset_DenseNet_weights.pt
new file mode 100644
index 0000000..6a9a915
--- /dev/null
+++ b/src/text_recognizer/weights/CharacterModel_EmnistDataset_DenseNet_weights.pt
Binary files differ
diff --git a/src/text_recognizer/weights/LineCTCModel_IamLinesDataset_LineRecurrentNetwork_weights.pt b/src/text_recognizer/weights/LineCTCModel_IamLinesDataset_LineRecurrentNetwork_weights.pt
index c001528..7fe1fa3 100644
--- a/src/text_recognizer/weights/LineCTCModel_IamLinesDataset_LineRecurrentNetwork_weights.pt
+++ b/src/text_recognizer/weights/LineCTCModel_IamLinesDataset_LineRecurrentNetwork_weights.pt
Binary files differ