summaryrefslogtreecommitdiff
path: root/src/text_recognizer
diff options
context:
space:
mode:
Diffstat (limited to 'src/text_recognizer')
-rw-r--r--src/text_recognizer/character_predictor.py7
-rw-r--r--src/text_recognizer/datasets/__init__.py4
-rw-r--r--src/text_recognizer/datasets/dataset.py22
-rw-r--r--src/text_recognizer/datasets/emnist_dataset.py6
-rw-r--r--src/text_recognizer/datasets/emnist_lines_dataset.py51
-rw-r--r--src/text_recognizer/datasets/iam_lines_dataset.py6
-rw-r--r--src/text_recognizer/datasets/iam_paragraphs_dataset.py7
-rw-r--r--src/text_recognizer/datasets/transforms.py40
-rw-r--r--src/text_recognizer/datasets/util.py31
-rw-r--r--src/text_recognizer/line_predictor.py28
-rw-r--r--src/text_recognizer/models/__init__.py16
-rw-r--r--src/text_recognizer/models/base.py66
-rw-r--r--src/text_recognizer/models/character_model.py4
-rw-r--r--src/text_recognizer/models/crnn_model.py (renamed from src/text_recognizer/models/line_ctc_model.py)16
-rw-r--r--src/text_recognizer/models/metrics.py5
-rw-r--r--src/text_recognizer/models/transformer_encoder_model.py111
-rw-r--r--src/text_recognizer/models/vision_transformer_model.py119
-rw-r--r--src/text_recognizer/networks/__init__.py20
-rw-r--r--src/text_recognizer/networks/cnn_transformer.py135
-rw-r--r--src/text_recognizer/networks/cnn_transformer_encoder.py73
-rw-r--r--src/text_recognizer/networks/crnn.py108
-rw-r--r--src/text_recognizer/networks/ctc.py2
-rw-r--r--src/text_recognizer/networks/densenet.py225
-rw-r--r--src/text_recognizer/networks/lenet.py6
-rw-r--r--src/text_recognizer/networks/line_lstm_ctc.py120
-rw-r--r--src/text_recognizer/networks/loss.py69
-rw-r--r--src/text_recognizer/networks/losses.py31
-rw-r--r--src/text_recognizer/networks/mlp.py6
-rw-r--r--src/text_recognizer/networks/residual_network.py6
-rw-r--r--src/text_recognizer/networks/sparse_mlp.py78
-rw-r--r--src/text_recognizer/networks/transformer.py5
-rw-r--r--src/text_recognizer/networks/transformer/__init__.py3
-rw-r--r--src/text_recognizer/networks/transformer/attention.py93
-rw-r--r--src/text_recognizer/networks/transformer/positional_encoding.py32
-rw-r--r--src/text_recognizer/networks/transformer/transformer.py242
-rw-r--r--src/text_recognizer/networks/util.py (renamed from src/text_recognizer/networks/misc.py)42
-rw-r--r--src/text_recognizer/networks/vision_transformer.py159
-rw-r--r--src/text_recognizer/networks/wide_resnet.py6
-rw-r--r--src/text_recognizer/tests/support/create_emnist_support_files.py13
-rw-r--r--src/text_recognizer/tests/test_line_predictor.py35
-rw-r--r--src/text_recognizer/weights/CRNNModel_IamLinesDataset_ConvolutionalRecurrentNetwork_weights.ptbin0 -> 5628749 bytes
-rw-r--r--src/text_recognizer/weights/CharacterModel_EmnistDataset_DenseNet_weights.ptbin0 -> 1273881 bytes
-rw-r--r--src/text_recognizer/weights/CharacterModel_EmnistDataset_LeNet_weights.ptbin14485362 -> 0 bytes
-rw-r--r--src/text_recognizer/weights/CharacterModel_EmnistDataset_MLP_weights.ptbin17938163 -> 0 bytes
-rw-r--r--src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetworkEncoder_weights.ptbin26090486 -> 0 bytes
-rw-r--r--src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetwork_weights.ptbin32765213 -> 0 bytes
-rw-r--r--src/text_recognizer/weights/CharacterModel_EmnistDataset_SpinalVGG_weights.ptbin44089479 -> 0 bytes
-rw-r--r--src/text_recognizer/weights/CharacterModel_EmnistDataset_WideResidualNetwork_weights.pt (renamed from src/text_recognizer/weights/CharacterModel_Emnist_LeNet_weights.pt)bin14485342 -> 14953410 bytes
-rw-r--r--src/text_recognizer/weights/CharacterModel_Emnist_MLP_weights.ptbin1704096 -> 0 bytes
-rw-r--r--src/text_recognizer/weights/LineCTCModel_IamLinesDataset_LineRecurrentNetwork_weights.ptbin20694308 -> 3457858 bytes
50 files changed, 1804 insertions, 244 deletions
diff --git a/src/text_recognizer/character_predictor.py b/src/text_recognizer/character_predictor.py
index df37e68..ad71289 100644
--- a/src/text_recognizer/character_predictor.py
+++ b/src/text_recognizer/character_predictor.py
@@ -4,6 +4,7 @@ from typing import Dict, Tuple, Type, Union
import numpy as np
from torch import nn
+from text_recognizer import datasets, networks
from text_recognizer.models import CharacterModel
from text_recognizer.util import read_image
@@ -11,9 +12,11 @@ from text_recognizer.util import read_image
class CharacterPredictor:
"""Recognizes the character in handwritten character images."""
- def __init__(self, network_fn: Type[nn.Module]) -> None:
+ def __init__(self, network_fn: str, dataset: str) -> None:
"""Intializes the CharacterModel and load the pretrained weights."""
- self.model = CharacterModel(network_fn=network_fn)
+ network_fn = getattr(networks, network_fn)
+ dataset = getattr(datasets, dataset)
+ self.model = CharacterModel(network_fn=network_fn, dataset=dataset)
self.model.eval()
self.model.use_swa_model()
diff --git a/src/text_recognizer/datasets/__init__.py b/src/text_recognizer/datasets/__init__.py
index a3af9b1..d8372e3 100644
--- a/src/text_recognizer/datasets/__init__.py
+++ b/src/text_recognizer/datasets/__init__.py
@@ -1,5 +1,5 @@
"""Dataset modules."""
-from .emnist_dataset import EmnistDataset, Transpose
+from .emnist_dataset import EmnistDataset
from .emnist_lines_dataset import (
construct_image_from_string,
EmnistLinesDataset,
@@ -8,6 +8,7 @@ from .emnist_lines_dataset import (
from .iam_dataset import IamDataset
from .iam_lines_dataset import IamLinesDataset
from .iam_paragraphs_dataset import IamParagraphsDataset
+from .transforms import AddTokens, Transpose
from .util import (
_download_raw_dataset,
compute_sha256,
@@ -19,6 +20,7 @@ from .util import (
__all__ = [
"_download_raw_dataset",
+ "AddTokens",
"compute_sha256",
"construct_image_from_string",
"DATA_DIRNAME",
diff --git a/src/text_recognizer/datasets/dataset.py b/src/text_recognizer/datasets/dataset.py
index 05520e5..2de7f09 100644
--- a/src/text_recognizer/datasets/dataset.py
+++ b/src/text_recognizer/datasets/dataset.py
@@ -18,6 +18,9 @@ class Dataset(data.Dataset):
subsample_fraction: float = None,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
+ init_token: Optional[str] = None,
+ pad_token: Optional[str] = None,
+ eos_token: Optional[str] = None,
) -> None:
"""Initialization of Dataset class.
@@ -26,12 +29,14 @@ class Dataset(data.Dataset):
subsample_fraction (float): The fraction of the dataset to use for training. Defaults to None.
transform (Optional[Callable]): Transform(s) for input data. Defaults to None.
target_transform (Optional[Callable]): Transform(s) for output data. Defaults to None.
+ init_token (Optional[str]): String representing the start of sequence token. Defaults to None.
+ pad_token (Optional[str]): String representing the pad token. Defaults to None.
+ eos_token (Optional[str]): String representing the end of sequence token. Defaults to None.
Raises:
ValueError: If subsample_fraction is not None and outside the range (0, 1).
"""
-
self.train = train
self.split = "train" if self.train else "test"
@@ -40,19 +45,18 @@ class Dataset(data.Dataset):
raise ValueError("The subsample fraction must be in (0, 1).")
self.subsample_fraction = subsample_fraction
- self._mapper = EmnistMapper()
+ self._mapper = EmnistMapper(
+ init_token=init_token, eos_token=eos_token, pad_token=pad_token
+ )
self._input_shape = self._mapper.input_shape
self._output_shape = self._mapper._num_classes
self.num_classes = self.mapper.num_classes
# Set transforms.
- self.transform = transform
- if self.transform is None:
- self.transform = ToTensor()
-
- self.target_transform = target_transform
- if self.target_transform is None:
- self.target_transform = torch.tensor
+ self.transform = transform if transform is not None else ToTensor()
+ self.target_transform = (
+ target_transform if target_transform is not None else torch.tensor
+ )
self._data = None
self._targets = None
diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py
index d01dcee..9884fdf 100644
--- a/src/text_recognizer/datasets/emnist_dataset.py
+++ b/src/text_recognizer/datasets/emnist_dataset.py
@@ -22,6 +22,7 @@ class EmnistDataset(Dataset):
def __init__(
self,
+ pad_token: str = None,
train: bool = False,
sample_to_balance: bool = False,
subsample_fraction: float = None,
@@ -32,6 +33,7 @@ class EmnistDataset(Dataset):
"""Loads the dataset and the mappings.
Args:
+ pad_token (str): The pad token symbol. Defaults to _.
train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to False.
sample_to_balance (bool): Resamples the dataset to make it balanced. Defaults to False.
subsample_fraction (float): Description of parameter `subsample_fraction`. Defaults to None.
@@ -45,6 +47,7 @@ class EmnistDataset(Dataset):
subsample_fraction=subsample_fraction,
transform=transform,
target_transform=target_transform,
+ pad_token=pad_token,
)
self.sample_to_balance = sample_to_balance
@@ -53,8 +56,7 @@ class EmnistDataset(Dataset):
if transform is None:
self.transform = Compose([Transpose(), ToTensor()])
- # The EMNIST dataset is already casted to tensors.
- self.target_transform = target_transform
+ self.target_transform = None
self.seed = seed
diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py
index 6268a01..6871492 100644
--- a/src/text_recognizer/datasets/emnist_lines_dataset.py
+++ b/src/text_recognizer/datasets/emnist_lines_dataset.py
@@ -4,6 +4,7 @@ from collections import defaultdict
from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple, Union
+import click
import h5py
from loguru import logger
import numpy as np
@@ -37,6 +38,9 @@ class EmnistLinesDataset(Dataset):
max_overlap: float = 0.33,
num_samples: int = 10000,
seed: int = 4711,
+ init_token: Optional[str] = None,
+ pad_token: Optional[str] = None,
+ eos_token: Optional[str] = None,
) -> None:
"""Set attributes and loads the dataset.
@@ -50,13 +54,21 @@ class EmnistLinesDataset(Dataset):
max_overlap (float): The maximum overlap between concatenated images. Defaults to 0.33.
num_samples (int): Number of samples to generate. Defaults to 10000.
seed (int): Seed number. Defaults to 4711.
+ init_token (Optional[str]): String representing the start of sequence token. Defaults to None.
+ pad_token (Optional[str]): String representing the pad token. Defaults to None.
+ eos_token (Optional[str]): String representing the end of sequence token. Defaults to None.
"""
+ self.pad_token = "_" if pad_token is None else pad_token
+
super().__init__(
train=train,
transform=transform,
target_transform=target_transform,
subsample_fraction=subsample_fraction,
+ init_token=init_token,
+ pad_token=self.pad_token,
+ eos_token=eos_token,
)
# Extract dataset information.
@@ -118,11 +130,7 @@ class EmnistLinesDataset(Dataset):
@property
def data_filename(self) -> Path:
"""Path to the h5 file."""
- filename = f"ml_{self.max_length}_o{self.min_overlap}_{self.max_overlap}_n{self.num_samples}.pt"
- if self.train:
- filename = "train_" + filename
- else:
- filename = "test_" + filename
+ filename = "train.pt" if self.train else "test.pt"
return DATA_DIRNAME / filename
def load_or_generate_data(self) -> None:
@@ -138,8 +146,8 @@ class EmnistLinesDataset(Dataset):
"""Loads the dataset from the h5 file."""
logger.debug("EmnistLinesDataset loading data from HDF5...")
with h5py.File(self.data_filename, "r") as f:
- self._data = f["data"][:]
- self._targets = f["targets"][:]
+ self._data = f["data"][()]
+ self._targets = f["targets"][()]
def _generate_data(self) -> str:
"""Generates a dataset with the Brown corpus and Emnist characters."""
@@ -148,7 +156,10 @@ class EmnistLinesDataset(Dataset):
sentence_generator = SentenceGenerator(self.max_length)
# Load emnist dataset.
- emnist = EmnistDataset(train=self.train, sample_to_balance=True)
+ emnist = EmnistDataset(
+ train=self.train, sample_to_balance=True, pad_token=self.pad_token
+ )
+ emnist.load_or_generate_data()
samples_by_character = get_samples_by_character(
emnist.data.numpy(), emnist.targets.numpy(), self.mapper.mapping,
@@ -298,6 +309,18 @@ def convert_strings_to_categorical_labels(
return np.array([[mapping[c] for c in label] for label in labels])
+@click.command()
+@click.option(
+ "--max_length", type=int, default=34, help="Number of characters in a sentence."
+)
+@click.option(
+ "--min_overlap", type=float, default=0.0, help="Min overlap between characters."
+)
+@click.option(
+ "--max_overlap", type=float, default=0.33, help="Max overlap between characters."
+)
+@click.option("--num_train", type=int, default=10_000, help="Number of train examples.")
+@click.option("--num_test", type=int, default=1_000, help="Number of test examples.")
def create_datasets(
max_length: int = 34,
min_overlap: float = 0,
@@ -306,17 +329,17 @@ def create_datasets(
num_test: int = 1000,
) -> None:
"""Creates a training an validation dataset of Emnist lines."""
- emnist_train = EmnistDataset(train=True, sample_to_balance=True)
- emnist_test = EmnistDataset(train=False, sample_to_balance=True)
- datasets = [emnist_train, emnist_test]
num_samples = [num_train, num_test]
- for num, train, dataset in zip(num_samples, [True, False], datasets):
+ for num, train in zip(num_samples, [True, False]):
emnist_lines = EmnistLinesDataset(
train=train,
- emnist=dataset,
max_length=max_length,
min_overlap=min_overlap,
max_overlap=max_overlap,
num_samples=num,
)
- emnist_lines._load_or_generate_data()
+ emnist_lines.load_or_generate_data()
+
+
+if __name__ == "__main__":
+ create_datasets()
diff --git a/src/text_recognizer/datasets/iam_lines_dataset.py b/src/text_recognizer/datasets/iam_lines_dataset.py
index 4a74b2b..fdd2fe6 100644
--- a/src/text_recognizer/datasets/iam_lines_dataset.py
+++ b/src/text_recognizer/datasets/iam_lines_dataset.py
@@ -32,12 +32,18 @@ class IamLinesDataset(Dataset):
subsample_fraction: float = None,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
+ init_token: Optional[str] = None,
+ pad_token: Optional[str] = None,
+ eos_token: Optional[str] = None,
) -> None:
super().__init__(
train=train,
subsample_fraction=subsample_fraction,
transform=transform,
target_transform=target_transform,
+ init_token=init_token,
+ pad_token=pad_token,
+ eos_token=eos_token,
)
@property
diff --git a/src/text_recognizer/datasets/iam_paragraphs_dataset.py b/src/text_recognizer/datasets/iam_paragraphs_dataset.py
index 4b34bd1..c1e8fe2 100644
--- a/src/text_recognizer/datasets/iam_paragraphs_dataset.py
+++ b/src/text_recognizer/datasets/iam_paragraphs_dataset.py
@@ -266,11 +266,16 @@ def _load_iam_paragraphs() -> None:
@click.option(
"--subsample_fraction",
type=float,
- default=0.0,
+ default=None,
help="The subsampling factor of the dataset.",
)
def main(subsample_fraction: float) -> None:
"""Load dataset and print info."""
+ logger.info("Creating train set...")
+ dataset = IamParagraphsDataset(train=True, subsample_fraction=subsample_fraction)
+ dataset.load_or_generate_data()
+ print(dataset)
+ logger.info("Creating test set...")
dataset = IamParagraphsDataset(subsample_fraction=subsample_fraction)
dataset.load_or_generate_data()
print(dataset)
diff --git a/src/text_recognizer/datasets/transforms.py b/src/text_recognizer/datasets/transforms.py
index 17231a8..8deac7f 100644
--- a/src/text_recognizer/datasets/transforms.py
+++ b/src/text_recognizer/datasets/transforms.py
@@ -3,6 +3,9 @@ import numpy as np
from PIL import Image
import torch
from torch import Tensor
+from torchvision.transforms import Compose, Resize, ToPILImage, ToTensor
+
+from text_recognizer.datasets.util import EmnistMapper
class Transpose:
@@ -11,3 +14,40 @@ class Transpose:
def __call__(self, image: Image) -> np.ndarray:
"""Swaps axis."""
return np.array(image).swapaxes(0, 1)
+
+
+class AddTokens:
+ """Adds start of sequence and end of sequence tokens to target tensor."""
+
+ def __init__(self, pad_token: str, eos_token: str, init_token: str = None) -> None:
+ self.init_token = init_token
+ self.pad_token = pad_token
+ self.eos_token = eos_token
+ if self.init_token is not None:
+ self.emnist_mapper = EmnistMapper(
+ init_token=self.init_token,
+ pad_token=self.pad_token,
+ eos_token=self.eos_token,
+ )
+ else:
+ self.emnist_mapper = EmnistMapper(
+ pad_token=self.pad_token, eos_token=self.eos_token,
+ )
+ self.pad_value = self.emnist_mapper(self.pad_token)
+ self.eos_value = self.emnist_mapper(self.eos_token)
+
+ def __call__(self, target: Tensor) -> Tensor:
+ """Adds a sos token to the begining and a eos token to the end of a target sequence."""
+ dtype, device = target.dtype, target.device
+
+ # Find the where padding starts.
+ pad_index = torch.nonzero(target == self.pad_value, as_tuple=False)[0].item()
+
+ target[pad_index] = self.eos_value
+
+ if self.init_token is not None:
+ self.sos_value = self.emnist_mapper(self.init_token)
+ sos = torch.tensor([self.sos_value], dtype=dtype, device=device)
+ target = torch.cat([sos, target], dim=0)
+
+ return target
diff --git a/src/text_recognizer/datasets/util.py b/src/text_recognizer/datasets/util.py
index 73968a1..d2df8b5 100644
--- a/src/text_recognizer/datasets/util.py
+++ b/src/text_recognizer/datasets/util.py
@@ -4,6 +4,7 @@ import importlib
import json
import os
from pathlib import Path
+import string
from typing import Callable, Dict, List, Optional, Type, Union
from urllib.request import urlopen, urlretrieve
@@ -26,7 +27,7 @@ def save_emnist_essentials(emnsit_dataset: type = EMNIST) -> None:
mapping = [(i, str(label)) for i, label in enumerate(labels)]
essentials = {
"mapping": mapping,
- "input_shape": tuple(emnsit_dataset[0][0].shape[:]),
+ "input_shape": tuple(np.array(emnsit_dataset[0][0]).shape[:]),
}
logger.info("Saving emnist essentials...")
with open(ESSENTIALS_FILENAME, "w") as f:
@@ -43,11 +44,21 @@ def download_emnist() -> None:
class EmnistMapper:
"""Mapper between network output to Emnist character."""
- def __init__(self) -> None:
+ def __init__(
+ self,
+ pad_token: str,
+ init_token: Optional[str] = None,
+ eos_token: Optional[str] = None,
+ ) -> None:
"""Loads the emnist essentials file with the mapping and input shape."""
+ self.init_token = init_token
+ self.pad_token = pad_token
+ self.eos_token = eos_token
+
self.essentials = self._load_emnist_essentials()
# Load dataset infromation.
- self._mapping = self._augment_emnist_mapping(dict(self.essentials["mapping"]))
+ self._mapping = dict(self.essentials["mapping"])
+ self._augment_emnist_mapping()
self._inverse_mapping = {v: k for k, v in self.mapping.items()}
self._num_classes = len(self.mapping)
self._input_shape = self.essentials["input_shape"]
@@ -103,7 +114,7 @@ class EmnistMapper:
essentials = json.load(f)
return essentials
- def _augment_emnist_mapping(self, mapping: Dict) -> Dict:
+ def _augment_emnist_mapping(self) -> None:
"""Augment the mapping with extra symbols."""
# Extra symbols in IAM dataset
extra_symbols = [
@@ -127,14 +138,20 @@ class EmnistMapper:
]
# padding symbol, and acts as blank symbol as well.
- extra_symbols.append("_")
+ extra_symbols.append(self.pad_token)
+
+ if self.init_token is not None:
+ extra_symbols.append(self.init_token)
+
+ if self.eos_token is not None:
+ extra_symbols.append(self.eos_token)
- max_key = max(mapping.keys())
+ max_key = max(self.mapping.keys())
extra_mapping = {}
for i, symbol in enumerate(extra_symbols):
extra_mapping[max_key + 1 + i] = symbol
- return {**mapping, **extra_mapping}
+ self._mapping = {**self.mapping, **extra_mapping}
def compute_sha256(filename: Union[Path, str]) -> str:
diff --git a/src/text_recognizer/line_predictor.py b/src/text_recognizer/line_predictor.py
new file mode 100644
index 0000000..981e2c9
--- /dev/null
+++ b/src/text_recognizer/line_predictor.py
@@ -0,0 +1,28 @@
+"""LinePredictor class."""
+import importlib
+from typing import Tuple, Union
+
+import numpy as np
+from torch import nn
+
+from text_recognizer import datasets, networks
+from text_recognizer.models import VisionTransformerModel
+from text_recognizer.util import read_image
+
+
+class LinePredictor:
+ """Given an image of a line of handwritten text, recognizes the text content."""
+
+ def __init__(self, dataset: str, network_fn: str) -> None:
+ network_fn = getattr(networks, network_fn)
+ dataset = getattr(datasets, dataset)
+ self.model = VisionTransformerModel(network_fn=network_fn, dataset=dataset)
+ self.model.eval()
+
+ def predict(self, image_or_filename: Union[np.ndarray, str]) -> Tuple[str, float]:
+ """Predict on a single images contianing a handwritten character."""
+ if isinstance(image_or_filename, str):
+ image = read_image(image_or_filename, grayscale=True)
+ else:
+ image = image_or_filename
+ return self.model.predict_on_image(image)
diff --git a/src/text_recognizer/models/__init__.py b/src/text_recognizer/models/__init__.py
index a3cfc15..28aa52e 100644
--- a/src/text_recognizer/models/__init__.py
+++ b/src/text_recognizer/models/__init__.py
@@ -1,7 +1,19 @@
"""Model modules."""
from .base import Model
from .character_model import CharacterModel
-from .line_ctc_model import LineCTCModel
+from .crnn_model import CRNNModel
from .metrics import accuracy, cer, wer
+from .transformer_encoder_model import TransformerEncoderModel
+from .vision_transformer_model import VisionTransformerModel
-__all__ = ["Model", "cer", "CharacterModel", "LineCTCModel", "accuracy", "wer"]
+__all__ = [
+ "Model",
+ "cer",
+ "CharacterModel",
+ "CRNNModel",
+ "CNNTransfromerModel",
+ "accuracy",
+ "TransformerEncoderModel",
+ "VisionTransformerModel",
+ "wer",
+]
diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py
index caf8065..cc44c92 100644
--- a/src/text_recognizer/models/base.py
+++ b/src/text_recognizer/models/base.py
@@ -6,7 +6,7 @@ import importlib
from pathlib import Path
import re
import shutil
-from typing import Callable, Dict, Optional, Tuple, Type
+from typing import Callable, Dict, List, Optional, Tuple, Type, Union
from loguru import logger
import torch
@@ -15,6 +15,7 @@ from torch import Tensor
from torch.optim.swa_utils import AveragedModel, SWALR
from torch.utils.data import DataLoader, Dataset, random_split
from torchsummary import summary
+from torchvision.transforms import Compose
from text_recognizer.datasets import EmnistMapper
@@ -128,16 +129,42 @@ class Model(ABC):
self._configure_criterion()
self._configure_optimizers()
- # Prints a summary of the network in terminal.
- self.summary()
-
# Set this flag to true to prevent the model from configuring again.
self.is_configured = True
+ def _configure_transforms(self) -> None:
+ # Load transforms.
+ transforms_module = importlib.import_module(
+ "text_recognizer.datasets.transforms"
+ )
+ if (
+ "transform" in self.dataset_args["args"]
+ and self.dataset_args["args"]["transform"] is not None
+ ):
+ transform_ = []
+ for t in self.dataset_args["args"]["transform"]:
+ args = t["args"] or {}
+ transform_.append(getattr(transforms_module, t["type"])(**args))
+ self.dataset_args["args"]["transform"] = Compose(transform_)
+
+ if (
+ "target_transform" in self.dataset_args["args"]
+ and self.dataset_args["args"]["target_transform"] is not None
+ ):
+ target_transform_ = [
+ torch.tensor,
+ ]
+ for t in self.dataset_args["args"]["target_transform"]:
+ args = t["args"] or {}
+ target_transform_.append(getattr(transforms_module, t["type"])(**args))
+ self.dataset_args["args"]["target_transform"] = Compose(target_transform_)
+
def prepare_data(self) -> None:
"""Prepare data for training."""
# TODO add downloading.
if not self.data_prepared:
+ self._configure_transforms()
+
# Load train dataset.
train_dataset = self.dataset(train=True, **self.dataset_args["args"])
train_dataset.load_or_generate_data()
@@ -327,20 +354,20 @@ class Model(ABC):
else:
return self.network(x)
- def loss_fn(self, output: Tensor, targets: Tensor) -> Tensor:
- """Compute the loss."""
- return self.criterion(output, targets)
-
def summary(
- self, input_shape: Optional[Tuple[int, int, int]] = None, depth: int = 3
+ self,
+ input_shape: Optional[Union[List, Tuple]] = None,
+ depth: int = 4,
+ device: Optional[str] = None,
) -> None:
"""Prints a summary of the network architecture."""
+ device = self.device if device is None else device
if input_shape is not None:
- summary(self.network, input_shape, depth=depth, device=self.device)
+ summary(self.network, input_shape, depth=depth, device=device)
elif self._input_shape is not None:
input_shape = (1,) + tuple(self._input_shape)
- summary(self.network, input_shape, depth=depth, device=self.device)
+ summary(self.network, input_shape, depth=depth, device=device)
else:
logger.warning("Could not print summary as input shape is not set.")
@@ -356,25 +383,29 @@ class Model(ABC):
state["optimizer_state"] = self._optimizer.state_dict()
if self._lr_scheduler is not None:
- state["scheduler_state"] = self._lr_scheduler.state_dict()
+ state["scheduler_state"] = self._lr_scheduler["lr_scheduler"].state_dict()
+ state["scheduler_interval"] = self._lr_scheduler["interval"]
if self._swa_network is not None:
state["swa_network"] = self._swa_network.state_dict()
return state
- def load_from_checkpoint(self, checkpoint_path: Path) -> None:
+ def load_from_checkpoint(self, checkpoint_path: Union[str, Path]) -> None:
"""Load a previously saved checkpoint.
Args:
checkpoint_path (Path): Path to the experiment with the checkpoint.
"""
+ checkpoint_path = Path(checkpoint_path)
+ self.prepare_data()
+ self.configure_model()
logger.debug("Loading checkpoint...")
if not checkpoint_path.exists():
logger.debug("File does not exist {str(checkpoint_path)}")
- checkpoint = torch.load(str(checkpoint_path))
+ checkpoint = torch.load(str(checkpoint_path), map_location=self.device)
self._network.load_state_dict(checkpoint["model_state"])
if self._optimizer is not None:
@@ -383,8 +414,11 @@ class Model(ABC):
if self._lr_scheduler is not None:
# Does not work when loadning from previous checkpoint and trying to train beyond the last max epochs
# with OneCycleLR.
- if self._lr_scheduler.__class__.__name__ != "OneCycleLR":
- self._lr_scheduler.load_state_dict(checkpoint["scheduler_state"])
+ if self._lr_scheduler["lr_scheduler"].__class__.__name__ != "OneCycleLR":
+ self._lr_scheduler["lr_scheduler"].load_state_dict(
+ checkpoint["scheduler_state"]
+ )
+ self._lr_scheduler["interval"] = checkpoint["scheduler_interval"]
if self._swa_network is not None:
self._swa_network.load_state_dict(checkpoint["swa_network"])
diff --git a/src/text_recognizer/models/character_model.py b/src/text_recognizer/models/character_model.py
index 50e94a2..f9944f3 100644
--- a/src/text_recognizer/models/character_model.py
+++ b/src/text_recognizer/models/character_model.py
@@ -47,8 +47,9 @@ class CharacterModel(Model):
swa_args,
device,
)
+ self.pad_token = dataset_args["args"]["pad_token"]
if self._mapper is None:
- self._mapper = EmnistMapper()
+ self._mapper = EmnistMapper(pad_token=self.pad_token,)
self.tensor_transform = ToTensor()
self.softmax = nn.Softmax(dim=0)
@@ -65,6 +66,7 @@ class CharacterModel(Model):
Tuple[str, float]: The predicted character and the confidence in the prediction.
"""
+ self.eval()
if image.dtype == np.uint8:
# Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1].
diff --git a/src/text_recognizer/models/line_ctc_model.py b/src/text_recognizer/models/crnn_model.py
index 16eaed3..1e01a83 100644
--- a/src/text_recognizer/models/line_ctc_model.py
+++ b/src/text_recognizer/models/crnn_model.py
@@ -1,4 +1,4 @@
-"""Defines the LineCTCModel class."""
+"""Defines the CRNNModel class."""
from typing import Callable, Dict, Optional, Tuple, Type, Union
import numpy as np
@@ -13,7 +13,7 @@ from text_recognizer.models.base import Model
from text_recognizer.networks import greedy_decoder
-class LineCTCModel(Model):
+class CRNNModel(Model):
"""Model for predicting a sequence of characters from an image of a text line."""
def __init__(
@@ -47,11 +47,13 @@ class LineCTCModel(Model):
swa_args,
device,
)
+
+ self.pad_token = dataset_args["args"]["pad_token"]
if self._mapper is None:
- self._mapper = EmnistMapper()
+ self._mapper = EmnistMapper(pad_token=self.pad_token,)
self.tensor_transform = ToTensor()
- def loss_fn(self, output: Tensor, targets: Tensor) -> Tensor:
+ def criterion(self, output: Tensor, targets: Tensor) -> Tensor:
"""Computes the CTC loss.
Args:
@@ -82,11 +84,13 @@ class LineCTCModel(Model):
torch.Tensor(target_lengths).type(dtype=torch.long).to(self.device)
)
- return self.criterion(output, targets, input_lengths, target_lengths)
+ return self._criterion(output, targets, input_lengths, target_lengths)
@torch.no_grad()
def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]:
"""Predict on a single input."""
+ self.eval()
+
if image.dtype == np.uint8:
# Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1].
image = self.tensor_transform(image)
@@ -110,6 +114,6 @@ class LineCTCModel(Model):
log_probs, _ = log_probs.max(dim=2)
predicted_characters = "".join(raw_pred[0])
- confidence_of_prediction = torch.exp(log_probs.sum()).item()
+ confidence_of_prediction = log_probs.cumprod(dim=0)[-1].item()
return predicted_characters, confidence_of_prediction
diff --git a/src/text_recognizer/models/metrics.py b/src/text_recognizer/models/metrics.py
index 6a26216..42c3c6e 100644
--- a/src/text_recognizer/models/metrics.py
+++ b/src/text_recognizer/models/metrics.py
@@ -17,7 +17,10 @@ def accuracy(outputs: Tensor, labels: Tensor) -> float:
float: The accuracy for the batch.
"""
- _, predicted = torch.max(outputs.data, dim=1)
+ # eos_index = torch.nonzero(labels == eos, as_tuple=False)
+ # eos_index = eos_index[0].item() if eos_index.nelement() else -1
+
+ _, predicted = torch.max(outputs, dim=-1)
acc = (predicted == labels).sum().float() / labels.shape[0]
acc = acc.item()
return acc
diff --git a/src/text_recognizer/models/transformer_encoder_model.py b/src/text_recognizer/models/transformer_encoder_model.py
new file mode 100644
index 0000000..e35e298
--- /dev/null
+++ b/src/text_recognizer/models/transformer_encoder_model.py
@@ -0,0 +1,111 @@
+"""Defines the CNN-Transformer class."""
+from typing import Callable, Dict, List, Optional, Tuple, Type, Union
+
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+from torch.utils.data import Dataset
+from torchvision.transforms import ToTensor
+
+from text_recognizer.datasets import EmnistMapper
+from text_recognizer.models.base import Model
+
+
+class TransformerEncoderModel(Model):
+ """A class for only using the encoder part in the sequence modelling."""
+
+ def __init__(
+ self,
+ network_fn: Type[nn.Module],
+ dataset: Type[Dataset],
+ network_args: Optional[Dict] = None,
+ dataset_args: Optional[Dict] = None,
+ metrics: Optional[Dict] = None,
+ criterion: Optional[Callable] = None,
+ criterion_args: Optional[Dict] = None,
+ optimizer: Optional[Callable] = None,
+ optimizer_args: Optional[Dict] = None,
+ lr_scheduler: Optional[Callable] = None,
+ lr_scheduler_args: Optional[Dict] = None,
+ swa_args: Optional[Dict] = None,
+ device: Optional[str] = None,
+ ) -> None:
+ super().__init__(
+ network_fn,
+ dataset,
+ network_args,
+ dataset_args,
+ metrics,
+ criterion,
+ criterion_args,
+ optimizer,
+ optimizer_args,
+ lr_scheduler,
+ lr_scheduler_args,
+ swa_args,
+ device,
+ )
+ # self.init_token = dataset_args["args"]["init_token"]
+ self.pad_token = dataset_args["args"]["pad_token"]
+ self.eos_token = dataset_args["args"]["eos_token"]
+ if network_args is not None:
+ self.max_len = network_args["max_len"]
+ else:
+ self.max_len = 128
+
+ if self._mapper is None:
+ self._mapper = EmnistMapper(
+ # init_token=self.init_token,
+ pad_token=self.pad_token,
+ eos_token=self.eos_token,
+ )
+ self.tensor_transform = ToTensor()
+
+ self.softmax = nn.Softmax(dim=2)
+
+ @torch.no_grad()
+ def _generate_sentence(self, image: Tensor) -> Tuple[List, float]:
+ logits = self.network(image)
+ # Convert logits to probabilities.
+ probs = self.softmax(logits).squeeze(0)
+
+ confidence, pred_tokens = probs.max(1)
+ pred_tokens = pred_tokens
+
+ eos_index = torch.nonzero(
+ pred_tokens == self._mapper(self.eos_token), as_tuple=False,
+ )
+
+ eos_index = eos_index[0].item() if eos_index.nelement() else -1
+
+ predicted_characters = "".join(
+ [self.mapper(x) for x in pred_tokens[:eos_index].tolist()]
+ )
+
+ confidence = np.min(confidence.tolist())
+
+ return predicted_characters, confidence
+
+ @torch.no_grad()
+ def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]:
+ """Predict on a single input."""
+ self.eval()
+
+ if image.dtype == np.uint8:
+ # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1].
+ image = self.tensor_transform(image)
+
+ # Rescale image between 0 and 1.
+ if image.dtype == torch.uint8:
+ # If the image is an unscaled tensor.
+ image = image.type("torch.FloatTensor") / 255
+
+ # Put the image tensor on the device the model weights are on.
+ image = image.to(self.device)
+
+ (predicted_characters, confidence_of_prediction,) = self._generate_sentence(
+ image
+ )
+
+ return predicted_characters, confidence_of_prediction
diff --git a/src/text_recognizer/models/vision_transformer_model.py b/src/text_recognizer/models/vision_transformer_model.py
new file mode 100644
index 0000000..3d36437
--- /dev/null
+++ b/src/text_recognizer/models/vision_transformer_model.py
@@ -0,0 +1,119 @@
+"""Defines the CNN-Transformer class."""
+from typing import Callable, Dict, List, Optional, Tuple, Type, Union
+
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+from torch.utils.data import Dataset
+from torchvision.transforms import ToTensor
+
+from text_recognizer.datasets import EmnistMapper
+from text_recognizer.models.base import Model
+from text_recognizer.networks import greedy_decoder
+
+
+class VisionTransformerModel(Model):
+ """Model for predicting a sequence of characters from an image of a text line with a cnn-transformer."""
+
+ def __init__(
+ self,
+ network_fn: Type[nn.Module],
+ dataset: Type[Dataset],
+ network_args: Optional[Dict] = None,
+ dataset_args: Optional[Dict] = None,
+ metrics: Optional[Dict] = None,
+ criterion: Optional[Callable] = None,
+ criterion_args: Optional[Dict] = None,
+ optimizer: Optional[Callable] = None,
+ optimizer_args: Optional[Dict] = None,
+ lr_scheduler: Optional[Callable] = None,
+ lr_scheduler_args: Optional[Dict] = None,
+ swa_args: Optional[Dict] = None,
+ device: Optional[str] = None,
+ ) -> None:
+ super().__init__(
+ network_fn,
+ dataset,
+ network_args,
+ dataset_args,
+ metrics,
+ criterion,
+ criterion_args,
+ optimizer,
+ optimizer_args,
+ lr_scheduler,
+ lr_scheduler_args,
+ swa_args,
+ device,
+ )
+ self.init_token = dataset_args["args"]["init_token"]
+ self.pad_token = dataset_args["args"]["pad_token"]
+ self.eos_token = dataset_args["args"]["eos_token"]
+ if network_args is not None:
+ self.max_len = network_args["max_len"]
+ else:
+ self.max_len = 120
+
+ if self._mapper is None:
+ self._mapper = EmnistMapper(
+ init_token=self.init_token,
+ pad_token=self.pad_token,
+ eos_token=self.eos_token,
+ )
+ self.tensor_transform = ToTensor()
+
+ self.softmax = nn.Softmax(dim=2)
+
+ @torch.no_grad()
+ def _generate_sentence(self, image: Tensor) -> Tuple[List, float]:
+ src = self.network.preprocess_input(image)
+ memory = self.network.encoder(src)
+
+ confidence_of_predictions = []
+ trg_indices = [self.mapper(self.init_token)]
+
+ for _ in range(self.max_len - 1):
+ trg = torch.tensor(trg_indices, device=self.device)[None, :].long()
+ trg = self.network.preprocess_target(trg)
+ logits = self.network.decoder(trg=trg, memory=memory, trg_mask=None)
+
+ # Convert logits to probabilities.
+ probs = self.softmax(logits)
+
+ pred_token = probs.argmax(2)[:, -1].item()
+ confidence = probs.max(2).values[:, -1].item()
+
+ trg_indices.append(pred_token)
+ confidence_of_predictions.append(confidence)
+
+ if pred_token == self.mapper(self.eos_token):
+ break
+
+ confidence = np.min(confidence_of_predictions)
+ predicted_characters = "".join([self.mapper(x) for x in trg_indices[1:]])
+
+ return predicted_characters, confidence
+
+ @torch.no_grad()
+ def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]:
+ """Predict on a single input."""
+ self.eval()
+
+ if image.dtype == np.uint8:
+ # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1].
+ image = self.tensor_transform(image)
+
+ # Rescale image between 0 and 1.
+ if image.dtype == torch.uint8:
+ # If the image is an unscaled tensor.
+ image = image.type("torch.FloatTensor") / 255
+
+ # Put the image tensor on the device the model weights are on.
+ image = image.to(self.device)
+
+ (predicted_characters, confidence_of_prediction,) = self._generate_sentence(
+ image
+ )
+
+ return predicted_characters, confidence_of_prediction
diff --git a/src/text_recognizer/networks/__init__.py b/src/text_recognizer/networks/__init__.py
index a39975f..6d88768 100644
--- a/src/text_recognizer/networks/__init__.py
+++ b/src/text_recognizer/networks/__init__.py
@@ -1,21 +1,33 @@
"""Network modules."""
+from .cnn_transformer import CNNTransformer
+from .cnn_transformer_encoder import CNNTransformerEncoder
+from .crnn import ConvolutionalRecurrentNetwork
from .ctc import greedy_decoder
+from .densenet import DenseNet
from .lenet import LeNet
-from .line_lstm_ctc import LineRecurrentNetwork
-from .losses import EmbeddingLoss
-from .misc import sliding_window
+from .loss import EmbeddingLoss
from .mlp import MLP
from .residual_network import ResidualNetwork, ResidualNetworkEncoder
+from .sparse_mlp import SparseMLP
+from .transformer import Transformer
+from .util import sliding_window
+from .vision_transformer import VisionTransformer
from .wide_resnet import WideResidualNetwork
__all__ = [
+ "CNNTransformer",
+ "CNNTransformerEncoder",
+ "ConvolutionalRecurrentNetwork",
+ "DenseNet",
"EmbeddingLoss",
"greedy_decoder",
"MLP",
"LeNet",
- "LineRecurrentNetwork",
"ResidualNetwork",
"ResidualNetworkEncoder",
"sliding_window",
+ "Transformer",
+ "SparseMLP",
+ "VisionTransformer",
"WideResidualNetwork",
]
diff --git a/src/text_recognizer/networks/cnn_transformer.py b/src/text_recognizer/networks/cnn_transformer.py
new file mode 100644
index 0000000..3da2c9f
--- /dev/null
+++ b/src/text_recognizer/networks/cnn_transformer.py
@@ -0,0 +1,135 @@
+"""A DETR style transfomers but for text recognition."""
+from typing import Dict, Optional, Tuple
+
+from einops import rearrange
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.transformer import PositionalEncoding, Transformer
+from text_recognizer.networks.util import configure_backbone
+
+
+class CNNTransformer(nn.Module):
+ """CNN+Transfomer for image to sequence prediction, sort of based on the ideas from DETR."""
+
+ def __init__(
+ self,
+ num_encoder_layers: int,
+ num_decoder_layers: int,
+ hidden_dim: int,
+ vocab_size: int,
+ num_heads: int,
+ adaptive_pool_dim: Tuple,
+ expansion_dim: int,
+ dropout_rate: float,
+ trg_pad_index: int,
+ backbone: str,
+ out_channels: int,
+ max_len: int,
+ backbone_args: Optional[Dict] = None,
+ activation: str = "gelu",
+ ) -> None:
+ super().__init__()
+ self.trg_pad_index = trg_pad_index
+
+ self.backbone = configure_backbone(backbone, backbone_args)
+ self.character_embedding = nn.Embedding(vocab_size, hidden_dim)
+
+ # self.conv = nn.Conv2d(out_channels, max_len, kernel_size=1)
+
+ self.position_encoding = PositionalEncoding(hidden_dim, dropout_rate)
+ self.row_embed = nn.Parameter(torch.rand(max_len, max_len // 2))
+ self.col_embed = nn.Parameter(torch.rand(max_len, max_len // 2))
+
+ self.adaptive_pool = (
+ nn.AdaptiveAvgPool2d((adaptive_pool_dim)) if adaptive_pool_dim else None
+ )
+
+ self.transformer = Transformer(
+ num_encoder_layers,
+ num_decoder_layers,
+ hidden_dim,
+ num_heads,
+ expansion_dim,
+ dropout_rate,
+ activation,
+ )
+
+ self.head = nn.Sequential(nn.Linear(hidden_dim, vocab_size),)
+
+ def _create_trg_mask(self, trg: Tensor) -> Tensor:
+ # Move this outside the transformer.
+ trg_pad_mask = (trg != self.trg_pad_index)[:, None, None]
+ trg_len = trg.shape[1]
+ trg_sub_mask = torch.tril(
+ torch.ones((trg_len, trg_len), device=trg.device)
+ ).bool()
+ trg_mask = trg_pad_mask & trg_sub_mask
+ return trg_mask
+
+ def encoder(self, src: Tensor) -> Tensor:
+ """Forward pass with the encoder of the transformer."""
+ return self.transformer.encoder(src)
+
+ def decoder(self, trg: Tensor, memory: Tensor, trg_mask: Tensor) -> Tensor:
+ """Forward pass with the decoder of the transformer + classification head."""
+ return self.head(
+ self.transformer.decoder(trg=trg, memory=memory, trg_mask=trg_mask)
+ )
+
+ def preprocess_input(self, src: Tensor) -> Tensor:
+ """Encodes src with a backbone network and a positional encoding.
+
+ Args:
+ src (Tensor): Input tensor.
+
+ Returns:
+ Tensor: A input src to the transformer.
+
+ """
+ # If batch dimenstion is missing, it needs to be added.
+ if len(src.shape) < 4:
+ src = src[(None,) * (4 - len(src.shape))]
+ src = self.backbone(src)
+ # src = self.conv(src)
+ if self.adaptive_pool is not None:
+ src = self.adaptive_pool(src)
+ H, W = src.shape[-2:]
+ src = rearrange(src, "b t h w -> b t (h w)")
+
+ # construct positional encodings
+ pos = torch.cat(
+ [
+ self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
+ self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
+ ],
+ dim=-1,
+ ).unsqueeze(0)
+ pos = rearrange(pos, "b h w l -> b l (h w)")
+ src = pos + 0.1 * src
+ return src
+
+ def preprocess_target(self, trg: Tensor) -> Tuple[Tensor, Tensor]:
+ """Encodes target tensor with embedding and postion.
+
+ Args:
+ trg (Tensor): Target tensor.
+
+ Returns:
+ Tuple[Tensor, Tensor]: Encoded target tensor and target mask.
+
+ """
+ trg = self.character_embedding(trg.long())
+ trg = self.position_encoding(trg)
+ return trg
+
+ def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor:
+ """Forward pass with CNN transfomer."""
+ h = self.preprocess_input(x)
+ trg_mask = self._create_trg_mask(trg)
+ trg = self.preprocess_target(trg)
+ out = self.transformer(h, trg, trg_mask=trg_mask)
+
+ logits = self.head(out)
+ return logits
diff --git a/src/text_recognizer/networks/cnn_transformer_encoder.py b/src/text_recognizer/networks/cnn_transformer_encoder.py
new file mode 100644
index 0000000..93626bf
--- /dev/null
+++ b/src/text_recognizer/networks/cnn_transformer_encoder.py
@@ -0,0 +1,73 @@
+"""Network with a CNN backend and a transformer encoder head."""
+from typing import Dict
+
+from einops import rearrange
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.transformer import PositionalEncoding
+from text_recognizer.networks.util import configure_backbone
+
+
+class CNNTransformerEncoder(nn.Module):
+ """A CNN backbone with Transformer Encoder frontend for sequence prediction."""
+
+ def __init__(
+ self,
+ backbone: str,
+ backbone_args: Dict,
+ mlp_dim: int,
+ d_model: int,
+ nhead: int = 8,
+ dropout_rate: float = 0.1,
+ activation: str = "relu",
+ num_layers: int = 6,
+ num_classes: int = 80,
+ num_channels: int = 256,
+ max_len: int = 97,
+ ) -> None:
+ super().__init__()
+ self.d_model = d_model
+ self.nhead = nhead
+ self.dropout_rate = dropout_rate
+ self.activation = activation
+ self.num_layers = num_layers
+
+ self.backbone = configure_backbone(backbone, backbone_args)
+ self.position_encoding = PositionalEncoding(d_model, dropout_rate)
+ self.encoder = self._configure_encoder()
+
+ self.conv = nn.Conv2d(num_channels, max_len, kernel_size=1)
+
+ self.mlp = nn.Linear(mlp_dim, d_model)
+
+ self.head = nn.Linear(d_model, num_classes)
+
+ def _configure_encoder(self) -> nn.TransformerEncoder:
+ encoder_layer = nn.TransformerEncoderLayer(
+ d_model=self.d_model,
+ nhead=self.nhead,
+ dropout=self.dropout_rate,
+ activation=self.activation,
+ )
+ norm = nn.LayerNorm(self.d_model)
+ return nn.TransformerEncoder(
+ encoder_layer=encoder_layer, num_layers=self.num_layers, norm=norm
+ )
+
+ def forward(self, x: Tensor, targets: Tensor = None) -> Tensor:
+ """Forward pass through the network."""
+ if len(x.shape) < 4:
+ x = x[(None,) * (4 - len(x.shape))]
+
+ x = self.conv(self.backbone(x))
+ x = rearrange(x, "b c h w -> b c (h w)")
+ x = self.mlp(x)
+ x = self.position_encoding(x)
+ x = rearrange(x, "b c h-> c b h")
+ x = self.encoder(x)
+ x = rearrange(x, "c b h-> b c h")
+ logits = self.head(x)
+
+ return logits
diff --git a/src/text_recognizer/networks/crnn.py b/src/text_recognizer/networks/crnn.py
new file mode 100644
index 0000000..9747429
--- /dev/null
+++ b/src/text_recognizer/networks/crnn.py
@@ -0,0 +1,108 @@
+"""LSTM with CTC for handwritten text recognition within a line."""
+from typing import Dict, Tuple
+
+from einops import rearrange, reduce
+from einops.layers.torch import Rearrange
+from loguru import logger
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.util import configure_backbone
+
+
+class ConvolutionalRecurrentNetwork(nn.Module):
+ """Network that takes a image of a text line and predicts tokens that are in the image."""
+
+ def __init__(
+ self,
+ backbone: str,
+ backbone_args: Dict = None,
+ input_size: int = 128,
+ hidden_size: int = 128,
+ bidirectional: bool = False,
+ num_layers: int = 1,
+ num_classes: int = 80,
+ patch_size: Tuple[int, int] = (28, 28),
+ stride: Tuple[int, int] = (1, 14),
+ recurrent_cell: str = "lstm",
+ avg_pool: bool = False,
+ use_sliding_window: bool = True,
+ ) -> None:
+ super().__init__()
+ self.backbone_args = backbone_args or {}
+ self.patch_size = patch_size
+ self.stride = stride
+ self.sliding_window = (
+ self._configure_sliding_window() if use_sliding_window else None
+ )
+ self.input_size = input_size
+ self.hidden_size = hidden_size
+ self.backbone = configure_backbone(backbone, backbone_args)
+ self.bidirectional = bidirectional
+ self.avg_pool = avg_pool
+
+ if recurrent_cell.upper() in ["LSTM", "GRU"]:
+ recurrent_cell = getattr(nn, recurrent_cell)
+ else:
+ logger.warning(
+ f"Option {recurrent_cell} not valid, defaulting to LSTM cell."
+ )
+ recurrent_cell = nn.LSTM
+
+ self.rnn = recurrent_cell(
+ input_size=self.input_size,
+ hidden_size=self.hidden_size,
+ bidirectional=bidirectional,
+ num_layers=num_layers,
+ )
+
+ decoder_size = self.hidden_size * 2 if self.bidirectional else self.hidden_size
+
+ self.decoder = nn.Sequential(
+ nn.Linear(in_features=decoder_size, out_features=num_classes),
+ nn.LogSoftmax(dim=2),
+ )
+
+ def _configure_sliding_window(self) -> nn.Sequential:
+ return nn.Sequential(
+ nn.Unfold(kernel_size=self.patch_size, stride=self.stride),
+ Rearrange(
+ "b (c h w) t -> b t c h w",
+ h=self.patch_size[0],
+ w=self.patch_size[1],
+ c=1,
+ ),
+ )
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Converts images to sequence of patches, feeds them to a CNN, then predictions are made with an LSTM."""
+ if len(x.shape) < 4:
+ x = x[(None,) * (4 - len(x.shape))]
+
+ if self.sliding_window is not None:
+ # Create image patches with a sliding window kernel.
+ x = self.sliding_window(x)
+
+ # Rearrange from a sequence of patches for feedforward network.
+ b, t = x.shape[:2]
+ x = rearrange(x, "b t c h w -> (b t) c h w", b=b, t=t)
+
+ x = self.backbone(x)
+
+ # Avgerage pooling.
+ if self.avg_pool:
+ x = reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t)
+ else:
+ x = rearrange(x, "(b t) h -> t b h", b=b, t=t)
+ else:
+ # Encode the entire image with a CNN, and use the channels as temporal dimension.
+ b = x.shape[0]
+ x = self.backbone(x)
+ x = rearrange(x, "b c h w -> c b (h w)", b=b)
+
+ # Sequence predictions.
+ x, _ = self.rnn(x)
+
+ # Sequence to classifcation layer.
+ x = self.decoder(x)
+ return x
diff --git a/src/text_recognizer/networks/ctc.py b/src/text_recognizer/networks/ctc.py
index 2493d5c..af9b700 100644
--- a/src/text_recognizer/networks/ctc.py
+++ b/src/text_recognizer/networks/ctc.py
@@ -33,7 +33,7 @@ def greedy_decoder(
"""
if character_mapper is None:
- character_mapper = EmnistMapper()
+ character_mapper = EmnistMapper(pad_token="_") # noqa: S106
predictions = rearrange(torch.argmax(predictions, dim=2), "t b -> b t")
decoded_predictions = []
diff --git a/src/text_recognizer/networks/densenet.py b/src/text_recognizer/networks/densenet.py
new file mode 100644
index 0000000..7dc58d9
--- /dev/null
+++ b/src/text_recognizer/networks/densenet.py
@@ -0,0 +1,225 @@
+"""Defines a Densely Connected Convolutional Networks in PyTorch.
+
+Sources:
+https://arxiv.org/abs/1608.06993
+https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py
+
+"""
+from typing import List, Optional, Union
+
+from einops.layers.torch import Rearrange
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.util import activation_function
+
+
+class _DenseLayer(nn.Module):
+ """A dense layer with pre-batch norm -> activation function -> Conv-layer x 2."""
+
+ def __init__(
+ self,
+ in_channels: int,
+ growth_rate: int,
+ bn_size: int,
+ dropout_rate: float,
+ activation: str = "relu",
+ ) -> None:
+ super().__init__()
+ activation_fn = activation_function(activation)
+ self.dense_layer = [
+ nn.BatchNorm2d(in_channels),
+ activation_fn,
+ nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=bn_size * growth_rate,
+ kernel_size=1,
+ stride=1,
+ bias=False,
+ ),
+ nn.BatchNorm2d(bn_size * growth_rate),
+ activation_fn,
+ nn.Conv2d(
+ in_channels=bn_size * growth_rate,
+ out_channels=growth_rate,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False,
+ ),
+ ]
+ if dropout_rate:
+ self.dense_layer.append(nn.Dropout(p=dropout_rate))
+
+ self.dense_layer = nn.Sequential(*self.dense_layer)
+
+ def forward(self, x: Union[Tensor, List[Tensor]]) -> Tensor:
+ if isinstance(x, list):
+ x = torch.cat(x, 1)
+ return self.dense_layer(x)
+
+
+class _DenseBlock(nn.Module):
+ def __init__(
+ self,
+ num_layers: int,
+ in_channels: int,
+ bn_size: int,
+ growth_rate: int,
+ dropout_rate: float,
+ activation: str = "relu",
+ ) -> None:
+ super().__init__()
+ self.dense_block = self._build_dense_blocks(
+ num_layers, in_channels, bn_size, growth_rate, dropout_rate, activation,
+ )
+
+ def _build_dense_blocks(
+ self,
+ num_layers: int,
+ in_channels: int,
+ bn_size: int,
+ growth_rate: int,
+ dropout_rate: float,
+ activation: str = "relu",
+ ) -> nn.ModuleList:
+ dense_block = []
+ for i in range(num_layers):
+ dense_block.append(
+ _DenseLayer(
+ in_channels=in_channels + i * growth_rate,
+ growth_rate=growth_rate,
+ bn_size=bn_size,
+ dropout_rate=dropout_rate,
+ activation=activation,
+ )
+ )
+ return nn.ModuleList(dense_block)
+
+ def forward(self, x: Tensor) -> Tensor:
+ feature_maps = [x]
+ for layer in self.dense_block:
+ x = layer(feature_maps)
+ feature_maps.append(x)
+ return torch.cat(feature_maps, 1)
+
+
+class _Transition(nn.Module):
+ def __init__(
+ self, in_channels: int, out_channels: int, activation: str = "relu",
+ ) -> None:
+ super().__init__()
+ activation_fn = activation_function(activation)
+ self.transition = nn.Sequential(
+ nn.BatchNorm2d(in_channels),
+ activation_fn,
+ nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=1,
+ stride=1,
+ bias=False,
+ ),
+ nn.AvgPool2d(kernel_size=2, stride=2),
+ )
+
+ def forward(self, x: Tensor) -> Tensor:
+ return self.transition(x)
+
+
+class DenseNet(nn.Module):
+ """Implementation of Densenet, a network archtecture that concats previous layers for maximum infomation flow."""
+
+ def __init__(
+ self,
+ growth_rate: int = 32,
+ block_config: List[int] = (6, 12, 24, 16),
+ in_channels: int = 1,
+ base_channels: int = 64,
+ num_classes: int = 80,
+ bn_size: int = 4,
+ dropout_rate: float = 0,
+ classifier: bool = True,
+ activation: str = "relu",
+ ) -> None:
+ super().__init__()
+ self.densenet = self._configure_densenet(
+ in_channels,
+ base_channels,
+ num_classes,
+ growth_rate,
+ block_config,
+ bn_size,
+ dropout_rate,
+ classifier,
+ activation,
+ )
+
+ def _configure_densenet(
+ self,
+ in_channels: int,
+ base_channels: int,
+ num_classes: int,
+ growth_rate: int,
+ block_config: List[int],
+ bn_size: int,
+ dropout_rate: float,
+ classifier: bool,
+ activation: str,
+ ) -> nn.Sequential:
+ activation_fn = activation_function(activation)
+ densenet = [
+ nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=base_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False,
+ ),
+ nn.BatchNorm2d(base_channels),
+ activation_fn,
+ ]
+
+ num_features = base_channels
+
+ for i, num_layers in enumerate(block_config):
+ densenet.append(
+ _DenseBlock(
+ num_layers=num_layers,
+ in_channels=num_features,
+ bn_size=bn_size,
+ growth_rate=growth_rate,
+ dropout_rate=dropout_rate,
+ activation=activation,
+ )
+ )
+ num_features = num_features + num_layers * growth_rate
+ if i != len(block_config) - 1:
+ densenet.append(
+ _Transition(
+ in_channels=num_features,
+ out_channels=num_features // 2,
+ activation=activation,
+ )
+ )
+ num_features = num_features // 2
+
+ densenet.append(activation_fn)
+
+ if classifier:
+ densenet.append(nn.AdaptiveAvgPool2d((1, 1)))
+ densenet.append(Rearrange("b c h w -> b (c h w)"))
+ densenet.append(
+ nn.Linear(in_features=num_features, out_features=num_classes)
+ )
+
+ return nn.Sequential(*densenet)
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Forward pass of Densenet."""
+ # If batch dimenstion is missing, it will be added.
+ if len(x.shape) < 4:
+ x = x[(None,) * (4 - len(x.shape))]
+ return self.densenet(x)
diff --git a/src/text_recognizer/networks/lenet.py b/src/text_recognizer/networks/lenet.py
index 53c575e..527e1a0 100644
--- a/src/text_recognizer/networks/lenet.py
+++ b/src/text_recognizer/networks/lenet.py
@@ -5,7 +5,7 @@ from einops.layers.torch import Rearrange
import torch
from torch import nn
-from text_recognizer.networks.misc import activation_function
+from text_recognizer.networks.util import activation_function
class LeNet(nn.Module):
@@ -63,6 +63,6 @@ class LeNet(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""The feedforward pass."""
# If batch dimenstion is missing, it needs to be added.
- if len(x.shape) == 3:
- x = x.unsqueeze(0)
+ if len(x.shape) < 4:
+ x = x[(None,) * (4 - len(x.shape))]
return self.layers(x)
diff --git a/src/text_recognizer/networks/line_lstm_ctc.py b/src/text_recognizer/networks/line_lstm_ctc.py
deleted file mode 100644
index 9009f94..0000000
--- a/src/text_recognizer/networks/line_lstm_ctc.py
+++ /dev/null
@@ -1,120 +0,0 @@
-"""LSTM with CTC for handwritten text recognition within a line."""
-import importlib
-from pathlib import Path
-from typing import Callable, Dict, List, Optional, Tuple, Type, Union
-
-from einops import rearrange, reduce
-from einops.layers.torch import Rearrange, Reduce
-from loguru import logger
-import torch
-from torch import nn
-from torch import Tensor
-
-
-class LineRecurrentNetwork(nn.Module):
- """Network that takes a image of a text line and predicts tokens that are in the image."""
-
- def __init__(
- self,
- backbone: str,
- backbone_args: Dict = None,
- flatten: bool = True,
- input_size: int = 128,
- hidden_size: int = 128,
- bidirectional: bool = False,
- num_layers: int = 1,
- num_classes: int = 80,
- patch_size: Tuple[int, int] = (28, 28),
- stride: Tuple[int, int] = (1, 14),
- ) -> None:
- super().__init__()
- self.backbone_args = backbone_args or {}
- self.patch_size = patch_size
- self.stride = stride
- self.sliding_window = self._configure_sliding_window()
- self.input_size = input_size
- self.hidden_size = hidden_size
- self.backbone = self._configure_backbone(backbone)
- self.bidirectional = bidirectional
- self.flatten = flatten
-
- if self.flatten:
- self.fc = nn.Linear(
- in_features=self.input_size, out_features=self.hidden_size
- )
-
- self.rnn = nn.LSTM(
- input_size=self.hidden_size,
- hidden_size=self.hidden_size,
- bidirectional=bidirectional,
- num_layers=num_layers,
- )
-
- decoder_size = self.hidden_size * 2 if self.bidirectional else self.hidden_size
-
- self.decoder = nn.Sequential(
- nn.Linear(in_features=decoder_size, out_features=num_classes),
- nn.LogSoftmax(dim=2),
- )
-
- def _configure_backbone(self, backbone: str) -> Type[nn.Module]:
- network_module = importlib.import_module("text_recognizer.networks")
- backbone_ = getattr(network_module, backbone)
-
- if "pretrained" in self.backbone_args:
- logger.info("Loading pretrained backbone.")
- checkpoint_file = Path(__file__).resolve().parents[
- 2
- ] / self.backbone_args.pop("pretrained")
-
- # Loading state directory.
- state_dict = torch.load(checkpoint_file)
- network_args = state_dict["network_args"]
- weights = state_dict["model_state"]
-
- # Initializes the network with trained weights.
- backbone = backbone_(**network_args)
- backbone.load_state_dict(weights)
- if "freeze" in self.backbone_args and self.backbone_args["freeze"] is True:
- for params in backbone.parameters():
- params.requires_grad = False
-
- return backbone
- else:
- return backbone_(**self.backbone_args)
-
- def _configure_sliding_window(self) -> nn.Sequential:
- return nn.Sequential(
- nn.Unfold(kernel_size=self.patch_size, stride=self.stride),
- Rearrange(
- "b (c h w) t -> b t c h w",
- h=self.patch_size[0],
- w=self.patch_size[1],
- c=1,
- ),
- )
-
- def forward(self, x: Tensor) -> Tensor:
- """Converts images to sequence of patches, feeds them to a CNN, then predictions are made with an LSTM."""
- if len(x.shape) == 3:
- x = x.unsqueeze(0)
- x = self.sliding_window(x)
-
- # Rearrange from a sequence of patches for feedforward network.
- b, t = x.shape[:2]
- x = rearrange(x, "b t c h w -> (b t) c h w", b=b, t=t)
- x = self.backbone(x)
-
- # Avgerage pooling.
- x = (
- self.fc(reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t))
- if self.flatten
- else rearrange(x, "(b t) h -> t b h", b=b, t=t)
- )
-
- # Sequence predictions.
- x, _ = self.rnn(x)
-
- # Sequence to classifcation layer.
- x = self.decoder(x)
- return x
diff --git a/src/text_recognizer/networks/loss.py b/src/text_recognizer/networks/loss.py
new file mode 100644
index 0000000..cf9fa0d
--- /dev/null
+++ b/src/text_recognizer/networks/loss.py
@@ -0,0 +1,69 @@
+"""Implementations of custom loss functions."""
+from pytorch_metric_learning import distances, losses, miners, reducers
+import torch
+from torch import nn
+from torch import Tensor
+from torch.autograd import Variable
+import torch.nn.functional as F
+
+__all__ = ["EmbeddingLoss", "LabelSmoothingCrossEntropy"]
+
+
+class EmbeddingLoss:
+ """Metric loss for training encoders to produce information-rich latent embeddings."""
+
+ def __init__(self, margin: float = 0.2, type_of_triplets: str = "semihard") -> None:
+ self.distance = distances.CosineSimilarity()
+ self.reducer = reducers.ThresholdReducer(low=0)
+ self.loss_fn = losses.TripletMarginLoss(
+ margin=margin, distance=self.distance, reducer=self.reducer
+ )
+ self.miner = miners.MultiSimilarityMiner(epsilon=margin, distance=self.distance)
+
+ def __call__(self, embeddings: Tensor, labels: Tensor) -> Tensor:
+ """Computes the metric loss for the embeddings based on their labels.
+
+ Args:
+ embeddings (Tensor): The laten vectors encoded by the network.
+ labels (Tensor): Labels of the embeddings.
+
+ Returns:
+ Tensor: The metric loss for the embeddings.
+
+ """
+ hard_pairs = self.miner(embeddings, labels)
+ loss = self.loss_fn(embeddings, labels, hard_pairs)
+ return loss
+
+
+class LabelSmoothingCrossEntropy(nn.Module):
+ """Label smoothing loss function."""
+
+ def __init__(
+ self,
+ classes: int,
+ smoothing: float = 0.0,
+ ignore_index: int = None,
+ dim: int = -1,
+ ) -> None:
+ super().__init__()
+ self.confidence = 1.0 - smoothing
+ self.smoothing = smoothing
+ self.ignore_index = ignore_index
+ self.cls = classes
+ self.dim = dim
+
+ def forward(self, pred: Tensor, target: Tensor) -> Tensor:
+ """Calculates the loss."""
+ pred = pred.log_softmax(dim=self.dim)
+ with torch.no_grad():
+ # true_dist = pred.data.clone()
+ true_dist = torch.zeros_like(pred)
+ true_dist.fill_(self.smoothing / (self.cls - 1))
+ true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
+ if self.ignore_index is not None:
+ true_dist[:, self.ignore_index] = 0
+ mask = torch.nonzero(target == self.ignore_index, as_tuple=False)
+ if mask.dim() > 0:
+ true_dist.index_fill_(0, mask.squeeze(), 0.0)
+ return torch.mean(torch.sum(-true_dist * pred, dim=self.dim))
diff --git a/src/text_recognizer/networks/losses.py b/src/text_recognizer/networks/losses.py
deleted file mode 100644
index 73e0641..0000000
--- a/src/text_recognizer/networks/losses.py
+++ /dev/null
@@ -1,31 +0,0 @@
-"""Implementations of custom loss functions."""
-from pytorch_metric_learning import distances, losses, miners, reducers
-from torch import nn
-from torch import Tensor
-
-
-class EmbeddingLoss:
- """Metric loss for training encoders to produce information-rich latent embeddings."""
-
- def __init__(self, margin: float = 0.2, type_of_triplets: str = "semihard") -> None:
- self.distance = distances.CosineSimilarity()
- self.reducer = reducers.ThresholdReducer(low=0)
- self.loss_fn = losses.TripletMarginLoss(
- margin=margin, distance=self.distance, reducer=self.reducer
- )
- self.miner = miners.MultiSimilarityMiner(epsilon=margin, distance=self.distance)
-
- def __call__(self, embeddings: Tensor, labels: Tensor) -> Tensor:
- """Computes the metric loss for the embeddings based on their labels.
-
- Args:
- embeddings (Tensor): The laten vectors encoded by the network.
- labels (Tensor): Labels of the embeddings.
-
- Returns:
- Tensor: The metric loss for the embeddings.
-
- """
- hard_pairs = self.miner(embeddings, labels)
- loss = self.loss_fn(embeddings, labels, hard_pairs)
- return loss
diff --git a/src/text_recognizer/networks/mlp.py b/src/text_recognizer/networks/mlp.py
index d66af28..1101912 100644
--- a/src/text_recognizer/networks/mlp.py
+++ b/src/text_recognizer/networks/mlp.py
@@ -5,7 +5,7 @@ from einops.layers.torch import Rearrange
import torch
from torch import nn
-from text_recognizer.networks.misc import activation_function
+from text_recognizer.networks.util import activation_function
class MLP(nn.Module):
@@ -63,8 +63,8 @@ class MLP(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""The feedforward pass."""
# If batch dimenstion is missing, it needs to be added.
- if len(x.shape) == 3:
- x = x.unsqueeze(0)
+ if len(x.shape) < 4:
+ x = x[(None,) * (4 - len(x.shape))]
return self.layers(x)
@property
diff --git a/src/text_recognizer/networks/residual_network.py b/src/text_recognizer/networks/residual_network.py
index 046600d..6405192 100644
--- a/src/text_recognizer/networks/residual_network.py
+++ b/src/text_recognizer/networks/residual_network.py
@@ -7,8 +7,8 @@ import torch
from torch import nn
from torch import Tensor
-from text_recognizer.networks.misc import activation_function
from text_recognizer.networks.stn import SpatialTransformerNetwork
+from text_recognizer.networks.util import activation_function
class Conv2dAuto(nn.Conv2d):
@@ -225,8 +225,8 @@ class ResidualNetworkEncoder(nn.Module):
in_channels=in_channels,
out_channels=self.block_sizes[0],
kernel_size=3,
- stride=2,
- padding=3,
+ stride=1,
+ padding=1,
bias=False,
),
nn.BatchNorm2d(self.block_sizes[0]),
diff --git a/src/text_recognizer/networks/sparse_mlp.py b/src/text_recognizer/networks/sparse_mlp.py
new file mode 100644
index 0000000..53cf166
--- /dev/null
+++ b/src/text_recognizer/networks/sparse_mlp.py
@@ -0,0 +1,78 @@
+"""Defines the Sparse MLP network."""
+from typing import Callable, Dict, List, Optional, Union
+import warnings
+
+from einops.layers.torch import Rearrange
+from pytorch_block_sparse import BlockSparseLinear
+import torch
+from torch import nn
+
+from text_recognizer.networks.util import activation_function
+
+warnings.filterwarnings("ignore", category=DeprecationWarning)
+
+
+class SparseMLP(nn.Module):
+ """Sparse multi layered perceptron network."""
+
+ def __init__(
+ self,
+ input_size: int = 784,
+ num_classes: int = 10,
+ hidden_size: Union[int, List] = 128,
+ num_layers: int = 3,
+ density: float = 0.1,
+ activation_fn: str = "relu",
+ ) -> None:
+ """Initialization of the MLP network.
+
+ Args:
+ input_size (int): The input shape of the network. Defaults to 784.
+ num_classes (int): Number of classes in the dataset. Defaults to 10.
+ hidden_size (Union[int, List]): The number of `neurons` in each hidden layer. Defaults to 128.
+ num_layers (int): The number of hidden layers. Defaults to 3.
+ density (float): The density of activation at each layer. Default to 0.1.
+ activation_fn (str): Name of the activation function in the hidden layers. Defaults to
+ relu.
+
+ """
+ super().__init__()
+
+ activation_fn = activation_function(activation_fn)
+
+ if isinstance(hidden_size, int):
+ hidden_size = [hidden_size] * num_layers
+
+ self.layers = [
+ Rearrange("b c h w -> b (c h w)"),
+ nn.Linear(in_features=input_size, out_features=hidden_size[0]),
+ activation_fn,
+ ]
+
+ for i in range(num_layers - 1):
+ self.layers += [
+ BlockSparseLinear(
+ in_features=hidden_size[i],
+ out_features=hidden_size[i + 1],
+ density=density,
+ ),
+ activation_fn,
+ ]
+
+ self.layers.append(
+ nn.Linear(in_features=hidden_size[-1], out_features=num_classes)
+ )
+
+ self.layers = nn.Sequential(*self.layers)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """The feedforward pass."""
+ # If batch dimenstion is missing, it needs to be added.
+ if len(x.shape) < 4:
+ x = x[(None,) * (4 - len(x.shape))]
+ return self.layers(x)
+
+ @property
+ def __name__(self) -> str:
+ """Returns the name of the network."""
+ return "mlp"
diff --git a/src/text_recognizer/networks/transformer.py b/src/text_recognizer/networks/transformer.py
deleted file mode 100644
index c091ba0..0000000
--- a/src/text_recognizer/networks/transformer.py
+++ /dev/null
@@ -1,5 +0,0 @@
-"""TBC."""
-from typing import Dict
-
-import torch
-from torch import Tensor
diff --git a/src/text_recognizer/networks/transformer/__init__.py b/src/text_recognizer/networks/transformer/__init__.py
new file mode 100644
index 0000000..020a917
--- /dev/null
+++ b/src/text_recognizer/networks/transformer/__init__.py
@@ -0,0 +1,3 @@
+"""Transformer modules."""
+from .positional_encoding import PositionalEncoding
+from .transformer import Decoder, Encoder, Transformer
diff --git a/src/text_recognizer/networks/transformer/attention.py b/src/text_recognizer/networks/transformer/attention.py
new file mode 100644
index 0000000..cce1ecc
--- /dev/null
+++ b/src/text_recognizer/networks/transformer/attention.py
@@ -0,0 +1,93 @@
+"""Implementes the attention module for the transformer."""
+from typing import Optional, Tuple
+
+from einops import rearrange
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+
+
+class MultiHeadAttention(nn.Module):
+ """Implementation of multihead attention."""
+
+ def __init__(
+ self, hidden_dim: int, num_heads: int = 8, dropout_rate: float = 0.0
+ ) -> None:
+ super().__init__()
+ self.hidden_dim = hidden_dim
+ self.num_heads = num_heads
+ self.fc_q = nn.Linear(
+ in_features=hidden_dim, out_features=hidden_dim, bias=False
+ )
+ self.fc_k = nn.Linear(
+ in_features=hidden_dim, out_features=hidden_dim, bias=False
+ )
+ self.fc_v = nn.Linear(
+ in_features=hidden_dim, out_features=hidden_dim, bias=False
+ )
+ self.fc_out = nn.Linear(in_features=hidden_dim, out_features=hidden_dim)
+
+ self._init_weights()
+
+ self.dropout = nn.Dropout(p=dropout_rate)
+
+ def _init_weights(self) -> None:
+ nn.init.normal_(
+ self.fc_q.weight,
+ mean=0,
+ std=np.sqrt(self.hidden_dim + int(self.hidden_dim / self.num_heads)),
+ )
+ nn.init.normal_(
+ self.fc_k.weight,
+ mean=0,
+ std=np.sqrt(self.hidden_dim + int(self.hidden_dim / self.num_heads)),
+ )
+ nn.init.normal_(
+ self.fc_v.weight,
+ mean=0,
+ std=np.sqrt(self.hidden_dim + int(self.hidden_dim / self.num_heads)),
+ )
+ nn.init.xavier_normal_(self.fc_out.weight)
+
+ def scaled_dot_product_attention(
+ self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None
+ ) -> Tensor:
+ """Calculates the scaled dot product attention."""
+
+ # Compute the energy.
+ energy = torch.einsum("bhlk,bhtk->bhlt", [query, key]) / np.sqrt(
+ query.shape[-1]
+ )
+
+ # If we have a mask for padding some inputs.
+ if mask is not None:
+ energy = energy.masked_fill(mask == 0, -np.inf)
+
+ # Compute the attention from the energy.
+ attention = torch.softmax(energy, dim=3)
+
+ out = torch.einsum("bhlt,bhtv->bhlv", [attention, value])
+ out = rearrange(out, "b head l v -> b l (head v)")
+ return out, attention
+
+ def forward(
+ self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None
+ ) -> Tuple[Tensor, Tensor]:
+ """Forward pass for computing the multihead attention."""
+ # Get the query, key, and value tensor.
+ query = rearrange(
+ self.fc_q(query), "b l (head k) -> b head l k", head=self.num_heads
+ )
+ key = rearrange(
+ self.fc_k(key), "b t (head k) -> b head t k", head=self.num_heads
+ )
+ value = rearrange(
+ self.fc_v(value), "b t (head v) -> b head t v", head=self.num_heads
+ )
+
+ out, attention = self.scaled_dot_product_attention(query, key, value, mask)
+
+ out = self.fc_out(out)
+ out = self.dropout(out)
+ return out, attention
diff --git a/src/text_recognizer/networks/transformer/positional_encoding.py b/src/text_recognizer/networks/transformer/positional_encoding.py
new file mode 100644
index 0000000..1ba5537
--- /dev/null
+++ b/src/text_recognizer/networks/transformer/positional_encoding.py
@@ -0,0 +1,32 @@
+"""A positional encoding for the image features, as the transformer has no notation of the order of the sequence."""
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+
+
+class PositionalEncoding(nn.Module):
+ """Encodes a sense of distance or time for transformer networks."""
+
+ def __init__(
+ self, hidden_dim: int, dropout_rate: float, max_len: int = 1000
+ ) -> None:
+ super().__init__()
+ self.dropout = nn.Dropout(p=dropout_rate)
+ self.max_len = max_len
+
+ pe = torch.zeros(max_len, hidden_dim)
+ position = torch.arange(0, max_len).unsqueeze(1)
+ div_term = torch.exp(
+ torch.arange(0, hidden_dim, 2) * -(np.log(10000.0) / hidden_dim)
+ )
+
+ pe[:, 0::2] = torch.sin(position * div_term)
+ pe[:, 1::2] = torch.cos(position * div_term)
+ pe = pe.unsqueeze(0)
+ self.register_buffer("pe", pe)
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Encodes the tensor with a postional embedding."""
+ x = x + self.pe[:, : x.shape[1]]
+ return self.dropout(x)
diff --git a/src/text_recognizer/networks/transformer/transformer.py b/src/text_recognizer/networks/transformer/transformer.py
new file mode 100644
index 0000000..c6e943e
--- /dev/null
+++ b/src/text_recognizer/networks/transformer/transformer.py
@@ -0,0 +1,242 @@
+"""Transfomer module."""
+import copy
+from typing import Dict, Optional, Type, Union
+
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.transformer.attention import MultiHeadAttention
+from text_recognizer.networks.util import activation_function
+
+
+def _get_clones(module: Type[nn.Module], num_layers: int) -> nn.ModuleList:
+ return nn.ModuleList([copy.deepcopy(module) for _ in range(num_layers)])
+
+
+class _IntraLayerConnection(nn.Module):
+ """Preforms the residual connection inside the transfomer blocks and applies layernorm."""
+
+ def __init__(self, dropout_rate: float, hidden_dim: int) -> None:
+ super().__init__()
+ self.norm = nn.LayerNorm(normalized_shape=hidden_dim)
+ self.dropout = nn.Dropout(p=dropout_rate)
+
+ def forward(self, src: Tensor, residual: Tensor) -> Tensor:
+ return self.norm(self.dropout(src) + residual)
+
+
+class _ConvolutionalLayer(nn.Module):
+ def __init__(
+ self,
+ hidden_dim: int,
+ expansion_dim: int,
+ dropout_rate: float,
+ activation: str = "relu",
+ ) -> None:
+ super().__init__()
+ self.layer = nn.Sequential(
+ nn.Linear(in_features=hidden_dim, out_features=expansion_dim),
+ activation_function(activation),
+ nn.Dropout(p=dropout_rate),
+ nn.Linear(in_features=expansion_dim, out_features=hidden_dim),
+ )
+
+ def forward(self, x: Tensor) -> Tensor:
+ return self.layer(x)
+
+
+class EncoderLayer(nn.Module):
+ """Transfomer encoding layer."""
+
+ def __init__(
+ self,
+ hidden_dim: int,
+ num_heads: int,
+ expansion_dim: int,
+ dropout_rate: float,
+ activation: str = "relu",
+ ) -> None:
+ super().__init__()
+ self.self_attention = MultiHeadAttention(hidden_dim, num_heads, dropout_rate)
+ self.cnn = _ConvolutionalLayer(
+ hidden_dim, expansion_dim, dropout_rate, activation
+ )
+ self.block1 = _IntraLayerConnection(dropout_rate, hidden_dim)
+ self.block2 = _IntraLayerConnection(dropout_rate, hidden_dim)
+
+ def forward(self, src: Tensor, mask: Optional[Tensor] = None) -> Tensor:
+ """Forward pass through the encoder."""
+ # First block.
+ # Multi head attention.
+ out, _ = self.self_attention(src, src, src, mask)
+
+ # Add & norm.
+ out = self.block1(out, src)
+
+ # Second block.
+ # Apply 1D-convolution.
+ cnn_out = self.cnn(out)
+
+ # Add & norm.
+ out = self.block2(cnn_out, out)
+
+ return out
+
+
+class Encoder(nn.Module):
+ """Transfomer encoder module."""
+
+ def __init__(
+ self,
+ num_layers: int,
+ encoder_layer: Type[nn.Module],
+ norm: Optional[Type[nn.Module]] = None,
+ ) -> None:
+ super().__init__()
+ self.layers = _get_clones(encoder_layer, num_layers)
+ self.norm = norm
+
+ def forward(self, src: Tensor, src_mask: Optional[Tensor] = None) -> Tensor:
+ """Forward pass through all encoder layers."""
+ for layer in self.layers:
+ src = layer(src, src_mask)
+
+ if self.norm is not None:
+ src = self.norm(src)
+
+ return src
+
+
+class DecoderLayer(nn.Module):
+ """Transfomer decoder layer."""
+
+ def __init__(
+ self,
+ hidden_dim: int,
+ num_heads: int,
+ expansion_dim: int,
+ dropout_rate: float = 0.0,
+ activation: str = "relu",
+ ) -> None:
+ super().__init__()
+ self.hidden_dim = hidden_dim
+ self.self_attention = MultiHeadAttention(hidden_dim, num_heads, dropout_rate)
+ self.multihead_attention = MultiHeadAttention(
+ hidden_dim, num_heads, dropout_rate
+ )
+ self.cnn = _ConvolutionalLayer(
+ hidden_dim, expansion_dim, dropout_rate, activation
+ )
+ self.block1 = _IntraLayerConnection(dropout_rate, hidden_dim)
+ self.block2 = _IntraLayerConnection(dropout_rate, hidden_dim)
+ self.block3 = _IntraLayerConnection(dropout_rate, hidden_dim)
+
+ def forward(
+ self,
+ trg: Tensor,
+ memory: Tensor,
+ trg_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ """Forward pass of the layer."""
+ out, _ = self.self_attention(trg, trg, trg, trg_mask)
+ trg = self.block1(out, trg)
+
+ out, _ = self.multihead_attention(trg, memory, memory, memory_mask)
+ trg = self.block2(out, trg)
+
+ out = self.cnn(trg)
+ out = self.block3(out, trg)
+
+ return out
+
+
+class Decoder(nn.Module):
+ """Transfomer decoder module."""
+
+ def __init__(
+ self,
+ decoder_layer: Type[nn.Module],
+ num_layers: int,
+ norm: Optional[Type[nn.Module]] = None,
+ ) -> None:
+ super().__init__()
+ self.layers = _get_clones(decoder_layer, num_layers)
+ self.num_layers = num_layers
+ self.norm = norm
+
+ def forward(
+ self,
+ trg: Tensor,
+ memory: Tensor,
+ trg_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ """Forward pass through the decoder."""
+ for layer in self.layers:
+ trg = layer(trg, memory, trg_mask, memory_mask)
+
+ if self.norm is not None:
+ trg = self.norm(trg)
+
+ return trg
+
+
+class Transformer(nn.Module):
+ """Transformer network."""
+
+ def __init__(
+ self,
+ num_encoder_layers: int,
+ num_decoder_layers: int,
+ hidden_dim: int,
+ num_heads: int,
+ expansion_dim: int,
+ dropout_rate: float,
+ activation: str = "relu",
+ ) -> None:
+ super().__init__()
+
+ # Configure encoder.
+ encoder_norm = nn.LayerNorm(hidden_dim)
+ encoder_layer = EncoderLayer(
+ hidden_dim, num_heads, expansion_dim, dropout_rate, activation
+ )
+ self.encoder = Encoder(num_encoder_layers, encoder_layer, encoder_norm)
+
+ # Configure decoder.
+ decoder_norm = nn.LayerNorm(hidden_dim)
+ decoder_layer = DecoderLayer(
+ hidden_dim, num_heads, expansion_dim, dropout_rate, activation
+ )
+ self.decoder = Decoder(decoder_layer, num_decoder_layers, decoder_norm)
+
+ self._reset_parameters()
+
+ def _reset_parameters(self) -> None:
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+
+ def forward(
+ self,
+ src: Tensor,
+ trg: Tensor,
+ src_mask: Optional[Tensor] = None,
+ trg_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ """Forward pass through the transformer."""
+ if src.shape[0] != trg.shape[0]:
+ print(trg.shape)
+ raise RuntimeError("The batch size of the src and trg must be the same.")
+ if src.shape[2] != trg.shape[2]:
+ raise RuntimeError(
+ "The number of features for the src and trg must be the same."
+ )
+
+ memory = self.encoder(src, src_mask)
+ output = self.decoder(trg, memory, trg_mask, memory_mask)
+ return output
diff --git a/src/text_recognizer/networks/misc.py b/src/text_recognizer/networks/util.py
index 1f853e9..b31e640 100644
--- a/src/text_recognizer/networks/misc.py
+++ b/src/text_recognizer/networks/util.py
@@ -1,7 +1,10 @@
"""Miscellaneous neural network functionality."""
-from typing import Tuple, Type
+import importlib
+from pathlib import Path
+from typing import Dict, Tuple, Type
from einops import rearrange
+from loguru import logger
import torch
from torch import nn
@@ -25,7 +28,7 @@ def sliding_window(
c = images.shape[1]
patches = unfold(images)
patches = rearrange(
- patches, "b (c h w) t -> b t c h w", c=c, h=patch_size[0], w=patch_size[1]
+ patches, "b (c h w) t -> b t c h w", c=c, h=patch_size[0], w=patch_size[1],
)
return patches
@@ -43,3 +46,38 @@ def activation_function(activation: str) -> Type[nn.Module]:
]
)
return activation_fns[activation.lower()]
+
+
+def configure_backbone(backbone: str, backbone_args: Dict) -> Type[nn.Module]:
+ """Loads a backbone network."""
+ network_module = importlib.import_module("text_recognizer.networks")
+ backbone_ = getattr(network_module, backbone)
+
+ if "pretrained" in backbone_args:
+ logger.info("Loading pretrained backbone.")
+ checkpoint_file = Path(__file__).resolve().parents[2] / backbone_args.pop(
+ "pretrained"
+ )
+
+ # Loading state directory.
+ state_dict = torch.load(checkpoint_file)
+ network_args = state_dict["network_args"]
+ weights = state_dict["model_state"]
+
+ # Initializes the network with trained weights.
+ backbone = backbone_(**network_args)
+ backbone.load_state_dict(weights)
+ if "freeze" in backbone_args and backbone_args["freeze"] is True:
+ for params in backbone.parameters():
+ params.requires_grad = False
+
+ else:
+ backbone_ = getattr(network_module, backbone)
+ backbone = backbone_(**backbone_args)
+
+ if "remove_layers" in backbone_args and backbone_args["remove_layers"] is not None:
+ backbone = nn.Sequential(
+ *list(backbone.children())[:][: -backbone_args["remove_layers"]]
+ )
+
+ return backbone
diff --git a/src/text_recognizer/networks/vision_transformer.py b/src/text_recognizer/networks/vision_transformer.py
new file mode 100644
index 0000000..f227954
--- /dev/null
+++ b/src/text_recognizer/networks/vision_transformer.py
@@ -0,0 +1,159 @@
+"""VisionTransformer module.
+
+Splits each image into patches and feeds them to a transformer.
+
+"""
+
+from typing import Dict, Optional, Tuple, Type
+
+from einops import rearrange, reduce
+from einops.layers.torch import Rearrange
+from loguru import logger
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.transformer import PositionalEncoding, Transformer
+from text_recognizer.networks.util import configure_backbone
+
+
+class VisionTransformer(nn.Module):
+ """Linear projection+Transfomer for image to sequence prediction, sort of based on the ideas from ViT."""
+
+ def __init__(
+ self,
+ num_encoder_layers: int,
+ num_decoder_layers: int,
+ hidden_dim: int,
+ vocab_size: int,
+ num_heads: int,
+ max_len: int,
+ expansion_dim: int,
+ dropout_rate: float,
+ trg_pad_index: int,
+ mlp_dim: Optional[int] = None,
+ patch_size: Tuple[int, int] = (28, 28),
+ stride: Tuple[int, int] = (1, 14),
+ activation: str = "gelu",
+ backbone: Optional[str] = None,
+ backbone_args: Optional[Dict] = None,
+ ) -> None:
+ super().__init__()
+
+ self.patch_size = patch_size
+ self.stride = stride
+ self.trg_pad_index = trg_pad_index
+ self.slidning_window = self._configure_sliding_window()
+ self.character_embedding = nn.Embedding(vocab_size, hidden_dim)
+ self.position_encoding = PositionalEncoding(hidden_dim, dropout_rate, max_len)
+ self.mlp_dim = mlp_dim
+
+ self.use_backbone = False
+ if backbone is None:
+ self.linear_projection = nn.Linear(
+ self.patch_size[0] * self.patch_size[1], hidden_dim
+ )
+ else:
+ self.backbone = configure_backbone(backbone, backbone_args)
+ if mlp_dim:
+ self.mlp = nn.Linear(mlp_dim, hidden_dim)
+ self.use_backbone = True
+
+ self.transformer = Transformer(
+ num_encoder_layers,
+ num_decoder_layers,
+ hidden_dim,
+ num_heads,
+ expansion_dim,
+ dropout_rate,
+ activation,
+ )
+
+ self.head = nn.Sequential(nn.Linear(hidden_dim, vocab_size),)
+
+ def _configure_sliding_window(self) -> nn.Sequential:
+ return nn.Sequential(
+ nn.Unfold(kernel_size=self.patch_size, stride=self.stride),
+ Rearrange(
+ "b (c h w) t -> b t c h w",
+ h=self.patch_size[0],
+ w=self.patch_size[1],
+ c=1,
+ ),
+ )
+
+ def _create_trg_mask(self, trg: Tensor) -> Tensor:
+ # Move this outside the transformer.
+ trg_pad_mask = (trg != self.trg_pad_index)[:, None, None]
+ trg_len = trg.shape[1]
+ trg_sub_mask = torch.tril(
+ torch.ones((trg_len, trg_len), device=trg.device)
+ ).bool()
+ trg_mask = trg_pad_mask & trg_sub_mask
+ return trg_mask
+
+ def encoder(self, src: Tensor) -> Tensor:
+ """Forward pass with the encoder of the transformer."""
+ return self.transformer.encoder(src)
+
+ def decoder(self, trg: Tensor, memory: Tensor, trg_mask: Tensor) -> Tensor:
+ """Forward pass with the decoder of the transformer + classification head."""
+ return self.head(
+ self.transformer.decoder(trg=trg, memory=memory, trg_mask=trg_mask)
+ )
+
+ def _backbone(self, x: Tensor) -> Tensor:
+ b, t = x.shape[:2]
+ if self.use_backbone:
+ x = rearrange(x, "b t c h w -> (b t) c h w", b=b, t=t)
+ x = self.backbone(x)
+ if self.mlp_dim:
+ x = rearrange(x, "(b t) c h w -> b t (c h w)", b=b, t=t)
+ x = self.mlp(x)
+ else:
+ x = rearrange(x, "(b t) h -> b t h", b=b, t=t)
+ else:
+ x = rearrange(x, "b t c h w -> b t (c h w)", b=b, t=t)
+ x = self.linear_projection(x)
+ return x
+
+ def preprocess_input(self, src: Tensor) -> Tensor:
+ """Encodes src with a backbone network and a positional encoding.
+
+ Args:
+ src (Tensor): Input tensor.
+
+ Returns:
+ Tensor: A input src to the transformer.
+
+ """
+ # If batch dimenstion is missing, it needs to be added.
+ if len(src.shape) < 4:
+ src = src[(None,) * (4 - len(src.shape))]
+ src = self.slidning_window(src) # .squeeze(-2)
+ src = self._backbone(src)
+ src = self.position_encoding(src)
+ return src
+
+ def preprocess_target(self, trg: Tensor) -> Tuple[Tensor, Tensor]:
+ """Encodes target tensor with embedding and postion.
+
+ Args:
+ trg (Tensor): Target tensor.
+
+ Returns:
+ Tuple[Tensor, Tensor]: Encoded target tensor and target mask.
+
+ """
+ trg_mask = self._create_trg_mask(trg)
+ trg = self.character_embedding(trg.long())
+ trg = self.position_encoding(trg)
+ return trg, trg_mask
+
+ def forward(self, x: Tensor, trg: Tensor) -> Tensor:
+ """Forward pass with vision transfomer."""
+ src = self.preprocess_input(x)
+ trg, trg_mask = self.preprocess_target(trg)
+ out = self.transformer(src, trg, trg_mask=trg_mask)
+ logits = self.head(out)
+ return logits
diff --git a/src/text_recognizer/networks/wide_resnet.py b/src/text_recognizer/networks/wide_resnet.py
index 618f414..aa79c12 100644
--- a/src/text_recognizer/networks/wide_resnet.py
+++ b/src/text_recognizer/networks/wide_resnet.py
@@ -8,7 +8,7 @@ import torch
from torch import nn
from torch import Tensor
-from text_recognizer.networks.misc import activation_function
+from text_recognizer.networks.util import activation_function
def conv3x3(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
@@ -206,8 +206,8 @@ class WideResidualNetwork(nn.Module):
def forward(self, x: Tensor) -> Tensor:
"""Feedforward pass."""
- if len(x.shape) == 3:
- x = x.unsqueeze(0)
+ if len(x.shape) < 4:
+ x = x[(None,) * int(4 - len(x.shape))]
x = self.encoder(x)
if self.decoder is not None:
x = self.decoder(x)
diff --git a/src/text_recognizer/tests/support/create_emnist_support_files.py b/src/text_recognizer/tests/support/create_emnist_support_files.py
index 5dd1a81..c04860d 100644
--- a/src/text_recognizer/tests/support/create_emnist_support_files.py
+++ b/src/text_recognizer/tests/support/create_emnist_support_files.py
@@ -2,10 +2,8 @@
from pathlib import Path
import shutil
-from text_recognizer.datasets.emnist_dataset import (
- fetch_emnist_dataset,
- load_emnist_mapping,
-)
+from text_recognizer.datasets.emnist_dataset import EmnistDataset
+from text_recognizer.datasets.util import EmnistMapper
from text_recognizer.util import write_image
SUPPORT_DIRNAME = Path(__file__).parents[0].resolve() / "emnist"
@@ -16,15 +14,16 @@ def create_emnist_support_files() -> None:
shutil.rmtree(SUPPORT_DIRNAME, ignore_errors=True)
SUPPORT_DIRNAME.mkdir()
- dataset = fetch_emnist_dataset(split="byclass", train=False)
- mapping = load_emnist_mapping()
+ dataset = EmnistDataset(train=False)
+ dataset.load_or_generate_data()
+ mapping = EmnistMapper()
for index in [5, 7, 9]:
image, label = dataset[index]
if len(image.shape) == 3:
image = image.squeeze(0)
image = image.numpy()
- label = mapping[int(label)]
+ label = mapping(int(label))
print(index, label)
write_image(image, str(SUPPORT_DIRNAME / f"{label}.png"))
diff --git a/src/text_recognizer/tests/test_line_predictor.py b/src/text_recognizer/tests/test_line_predictor.py
new file mode 100644
index 0000000..eede4d4
--- /dev/null
+++ b/src/text_recognizer/tests/test_line_predictor.py
@@ -0,0 +1,35 @@
+"""Tests for LinePredictor."""
+import os
+from pathlib import Path
+import unittest
+
+
+import editdistance
+import numpy as np
+
+from text_recognizer.datasets import IamLinesDataset
+from text_recognizer.line_predictor import LinePredictor
+import text_recognizer.util as util
+
+SUPPORT_DIRNAME = Path(__file__).parents[0].resolve() / "support"
+
+os.environ["CUDA_VISIBLE_DEVICES"] = ""
+
+
+class TestEmnistLinePredictor(unittest.TestCase):
+ """Test LinePredictor class on the EmnistLines dataset."""
+
+ def test_filename(self) -> None:
+ """Test that LinePredictor correctly predicts on single images, for several test images."""
+ predictor = LinePredictor(
+ dataset="EmnistLineDataset", network_fn="CNNTransformer"
+ )
+
+ for filename in (SUPPORT_DIRNAME / "emnist_lines").glob("*.png"):
+ pred, conf = predictor.predict(str(filename))
+ true = str(filename.stem)
+ edit_distance = editdistance.eval(pred, true) / len(pred)
+ print(
+ f'Pred: "{pred}" | Confidence: {conf} | True: {true} | Edit distance: {edit_distance}'
+ )
+ self.assertLess(edit_distance, 0.2)
diff --git a/src/text_recognizer/weights/CRNNModel_IamLinesDataset_ConvolutionalRecurrentNetwork_weights.pt b/src/text_recognizer/weights/CRNNModel_IamLinesDataset_ConvolutionalRecurrentNetwork_weights.pt
new file mode 100644
index 0000000..726c723
--- /dev/null
+++ b/src/text_recognizer/weights/CRNNModel_IamLinesDataset_ConvolutionalRecurrentNetwork_weights.pt
Binary files differ
diff --git a/src/text_recognizer/weights/CharacterModel_EmnistDataset_DenseNet_weights.pt b/src/text_recognizer/weights/CharacterModel_EmnistDataset_DenseNet_weights.pt
new file mode 100644
index 0000000..6a9a915
--- /dev/null
+++ b/src/text_recognizer/weights/CharacterModel_EmnistDataset_DenseNet_weights.pt
Binary files differ
diff --git a/src/text_recognizer/weights/CharacterModel_EmnistDataset_LeNet_weights.pt b/src/text_recognizer/weights/CharacterModel_EmnistDataset_LeNet_weights.pt
deleted file mode 100644
index 676eb44..0000000
--- a/src/text_recognizer/weights/CharacterModel_EmnistDataset_LeNet_weights.pt
+++ /dev/null
Binary files differ
diff --git a/src/text_recognizer/weights/CharacterModel_EmnistDataset_MLP_weights.pt b/src/text_recognizer/weights/CharacterModel_EmnistDataset_MLP_weights.pt
deleted file mode 100644
index 32c83cc..0000000
--- a/src/text_recognizer/weights/CharacterModel_EmnistDataset_MLP_weights.pt
+++ /dev/null
Binary files differ
diff --git a/src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetworkEncoder_weights.pt b/src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetworkEncoder_weights.pt
deleted file mode 100644
index 9f9deee..0000000
--- a/src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetworkEncoder_weights.pt
+++ /dev/null
Binary files differ
diff --git a/src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetwork_weights.pt b/src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetwork_weights.pt
deleted file mode 100644
index 0dc7eb5..0000000
--- a/src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetwork_weights.pt
+++ /dev/null
Binary files differ
diff --git a/src/text_recognizer/weights/CharacterModel_EmnistDataset_SpinalVGG_weights.pt b/src/text_recognizer/weights/CharacterModel_EmnistDataset_SpinalVGG_weights.pt
deleted file mode 100644
index e720299..0000000
--- a/src/text_recognizer/weights/CharacterModel_EmnistDataset_SpinalVGG_weights.pt
+++ /dev/null
Binary files differ
diff --git a/src/text_recognizer/weights/CharacterModel_Emnist_LeNet_weights.pt b/src/text_recognizer/weights/CharacterModel_EmnistDataset_WideResidualNetwork_weights.pt
index ed73c09..2d5a89b 100644
--- a/src/text_recognizer/weights/CharacterModel_Emnist_LeNet_weights.pt
+++ b/src/text_recognizer/weights/CharacterModel_EmnistDataset_WideResidualNetwork_weights.pt
Binary files differ
diff --git a/src/text_recognizer/weights/CharacterModel_Emnist_MLP_weights.pt b/src/text_recognizer/weights/CharacterModel_Emnist_MLP_weights.pt
deleted file mode 100644
index 4ec12c1..0000000
--- a/src/text_recognizer/weights/CharacterModel_Emnist_MLP_weights.pt
+++ /dev/null
Binary files differ
diff --git a/src/text_recognizer/weights/LineCTCModel_IamLinesDataset_LineRecurrentNetwork_weights.pt b/src/text_recognizer/weights/LineCTCModel_IamLinesDataset_LineRecurrentNetwork_weights.pt
index 93d34d7..7fe1fa3 100644
--- a/src/text_recognizer/weights/LineCTCModel_IamLinesDataset_LineRecurrentNetwork_weights.pt
+++ b/src/text_recognizer/weights/LineCTCModel_IamLinesDataset_LineRecurrentNetwork_weights.pt
Binary files differ