summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--poetry.lock14
-rw-r--r--pyproject.toml1
-rw-r--r--text_recognizer/__init__.py1
-rw-r--r--text_recognizer/character_predictor.py29
-rw-r--r--text_recognizer/data/iam_paragraphs_dataset.py291
-rw-r--r--text_recognizer/data/util.py209
-rw-r--r--text_recognizer/line_predictor.py28
-rw-r--r--text_recognizer/models/__init__.py18
-rw-r--r--text_recognizer/models/base.py455
-rw-r--r--text_recognizer/models/character_model.py88
-rw-r--r--text_recognizer/models/crnn_model.py119
-rw-r--r--text_recognizer/models/ctc_transformer_model.py120
-rw-r--r--text_recognizer/models/segmentation_model.py75
-rw-r--r--text_recognizer/models/transformer_model.py124
-rw-r--r--text_recognizer/models/vqvae_model.py80
-rw-r--r--text_recognizer/networks/__init__.py43
-rw-r--r--text_recognizer/networks/beam.py83
-rw-r--r--text_recognizer/networks/cnn.py101
-rw-r--r--text_recognizer/networks/crnn.py110
-rw-r--r--text_recognizer/networks/ctc.py58
-rw-r--r--text_recognizer/networks/densenet.py225
-rw-r--r--text_recognizer/networks/lenet.py68
-rw-r--r--text_recognizer/networks/metrics.py123
-rw-r--r--text_recognizer/networks/mlp.py73
-rw-r--r--text_recognizer/networks/stn.py44
-rw-r--r--text_recognizer/networks/unet.py255
-rw-r--r--text_recognizer/networks/vit.py150
-rw-r--r--text_recognizer/paragraph_text_recognizer.py153
28 files changed, 14 insertions, 3124 deletions
diff --git a/poetry.lock b/poetry.lock
index fa374b6..6926f1f 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -853,6 +853,14 @@ win32-setctime = {version = ">=1.0.0", markers = "sys_platform == \"win32\""}
dev = ["codecov (>=2.0.15)", "colorama (>=0.3.4)", "flake8 (>=3.7.7)", "tox (>=3.9.0)", "tox-travis (>=0.12)", "pytest (>=4.6.2)", "pytest-cov (>=2.7.1)", "Sphinx (>=2.2.1)", "sphinx-autobuild (>=0.7.1)", "sphinx-rtd-theme (>=0.4.3)", "black (>=19.10b0)", "isort (>=5.1.1)"]
[[package]]
+name = "madgrad"
+version = "1.0"
+description = "A general purpose PyTorch Optimizer"
+category = "main"
+optional = false
+python-versions = ">=3.6"
+
+[[package]]
name = "markdown"
version = "3.3.4"
description = "Python implementation of Markdown."
@@ -2162,7 +2170,7 @@ multidict = ">=4.0"
[metadata]
lock-version = "1.1"
python-versions = "^3.8"
-content-hash = "cffb5a23a46f3be6be0b8ea8289cfb97c8ba9722f869bbfc2af75cb80a737877"
+content-hash = "db4253add1258abaf637f127ba576b1ec5e0b415c8f7f93b18ecdca40bc6f042"
[metadata.files]
absl-py = [
@@ -2641,6 +2649,10 @@ loguru = [
{file = "loguru-0.5.3-py3-none-any.whl", hash = "sha256:f8087ac396b5ee5f67c963b495d615ebbceac2796379599820e324419d53667c"},
{file = "loguru-0.5.3.tar.gz", hash = "sha256:b28e72ac7a98be3d28ad28570299a393dfcd32e5e3f6a353dec94675767b6319"},
]
+madgrad = [
+ {file = "madgrad-1.0-py3-none-any.whl", hash = "sha256:cd5239a1274ee025abec14c99d2af06b11783a379da32cbe2f4b07fc81ef20ea"},
+ {file = "madgrad-1.0.tar.gz", hash = "sha256:5a34e1d295ebb2f85fbf9e09ed3b548e27908471bbe2506dda35de5a471c0cbe"},
+]
markdown = [
{file = "Markdown-3.3.4-py3-none-any.whl", hash = "sha256:96c3ba1261de2f7547b46a00ea8463832c921d3f9d6aba3f255a6f71386db20c"},
{file = "Markdown-3.3.4.tar.gz", hash = "sha256:31b5b491868dcc87d6c24b7e3d19a0d730d59d3e46f4eea6430a321bed387a49"},
diff --git a/pyproject.toml b/pyproject.toml
index e791dd9..32bdb0e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -38,6 +38,7 @@ gtn = "^0.0.0"
sentencepiece = "^0.1.95"
pytorch-lightning = "^1.2.4"
Pillow = "^8.1.2"
+madgrad = "^1.0"
[tool.poetry.dev-dependencies]
pytest = "^5.4.2"
diff --git a/text_recognizer/__init__.py b/text_recognizer/__init__.py
index 3dc1f76..e69de29 100644
--- a/text_recognizer/__init__.py
+++ b/text_recognizer/__init__.py
@@ -1 +0,0 @@
-__version__ = "0.1.0"
diff --git a/text_recognizer/character_predictor.py b/text_recognizer/character_predictor.py
deleted file mode 100644
index ad71289..0000000
--- a/text_recognizer/character_predictor.py
+++ /dev/null
@@ -1,29 +0,0 @@
-"""CharacterPredictor class."""
-from typing import Dict, Tuple, Type, Union
-
-import numpy as np
-from torch import nn
-
-from text_recognizer import datasets, networks
-from text_recognizer.models import CharacterModel
-from text_recognizer.util import read_image
-
-
-class CharacterPredictor:
- """Recognizes the character in handwritten character images."""
-
- def __init__(self, network_fn: str, dataset: str) -> None:
- """Intializes the CharacterModel and load the pretrained weights."""
- network_fn = getattr(networks, network_fn)
- dataset = getattr(datasets, dataset)
- self.model = CharacterModel(network_fn=network_fn, dataset=dataset)
- self.model.eval()
- self.model.use_swa_model()
-
- def predict(self, image_or_filename: Union[np.ndarray, str]) -> Tuple[str, float]:
- """Predict on a single images contianing a handwritten character."""
- if isinstance(image_or_filename, str):
- image = read_image(image_or_filename, grayscale=True)
- else:
- image = image_or_filename
- return self.model.predict_on_image(image)
diff --git a/text_recognizer/data/iam_paragraphs_dataset.py b/text_recognizer/data/iam_paragraphs_dataset.py
deleted file mode 100644
index 8ba5142..0000000
--- a/text_recognizer/data/iam_paragraphs_dataset.py
+++ /dev/null
@@ -1,291 +0,0 @@
-"""IamParagraphsDataset class and functions for data processing."""
-import random
-from typing import Callable, Dict, List, Optional, Tuple, Union
-
-import click
-import cv2
-import h5py
-from loguru import logger
-import numpy as np
-import torch
-from torch import Tensor
-from torchvision.transforms import ToTensor
-
-from text_recognizer import util
-from text_recognizer.datasets.dataset import Dataset
-from text_recognizer.datasets.iam_dataset import IamDataset
-from text_recognizer.datasets.util import (
- compute_sha256,
- DATA_DIRNAME,
- download_url,
- EmnistMapper,
-)
-
-INTERIM_DATA_DIRNAME = DATA_DIRNAME / "interim" / "iam_paragraphs"
-DEBUG_CROPS_DIRNAME = INTERIM_DATA_DIRNAME / "debug_crops"
-PROCESSED_DATA_DIRNAME = DATA_DIRNAME / "processed" / "iam_paragraphs"
-CROPS_DIRNAME = PROCESSED_DATA_DIRNAME / "crops"
-GT_DIRNAME = PROCESSED_DATA_DIRNAME / "gt"
-
-PARAGRAPH_BUFFER = 50 # Pixels in the IAM form images to leave around the lines.
-TEST_FRACTION = 0.2
-SEED = 4711
-
-
-class IamParagraphsDataset(Dataset):
- """IAM Paragraphs dataset for paragraphs of handwritten text."""
-
- def __init__(
- self,
- train: bool = False,
- subsample_fraction: float = None,
- transform: Optional[Callable] = None,
- target_transform: Optional[Callable] = None,
- ) -> None:
- super().__init__(
- train=train,
- subsample_fraction=subsample_fraction,
- transform=transform,
- target_transform=target_transform,
- )
- # Load Iam dataset.
- self.iam_dataset = IamDataset()
-
- self.num_classes = 3
- self._input_shape = (256, 256)
- self._output_shape = self._input_shape + (self.num_classes,)
- self._ids = None
-
- def __getitem__(self, index: Union[Tensor, int]) -> Tuple[Tensor, Tensor]:
- """Fetches data, target pair of the dataset for a given and index or indices.
-
- Args:
- index (Union[int, Tensor]): Either a list or int of indices/index.
-
- Returns:
- Tuple[Tensor, Tensor]: Data target pair.
-
- """
- if torch.is_tensor(index):
- index = index.tolist()
-
- data = self.data[index]
- targets = self.targets[index]
-
- seed = np.random.randint(SEED)
- random.seed(seed) # apply this seed to target tranfsorms
- torch.manual_seed(seed) # needed for torchvision 0.7
- if self.transform:
- data = self.transform(data)
-
- random.seed(seed) # apply this seed to target tranfsorms
- torch.manual_seed(seed) # needed for torchvision 0.7
- if self.target_transform:
- targets = self.target_transform(targets)
-
- return data, targets.long()
-
- @property
- def ids(self) -> Tensor:
- """Ids of the dataset."""
- return self._ids
-
- def get_data_and_target_from_id(self, id_: str) -> Tuple[Tensor, Tensor]:
- """Get data target pair from id."""
- ind = self.ids.index(id_)
- return self.data[ind], self.targets[ind]
-
- def load_or_generate_data(self) -> None:
- """Load or generate dataset data."""
- num_actual = len(list(CROPS_DIRNAME.glob("*.jpg")))
- num_targets = len(self.iam_dataset.line_regions_by_id)
-
- if num_actual < num_targets - 2:
- self._process_iam_paragraphs()
-
- self._data, self._targets, self._ids = _load_iam_paragraphs()
- self._get_random_split()
- self._subsample()
-
- def _get_random_split(self) -> None:
- np.random.seed(SEED)
- num_train = int((1 - TEST_FRACTION) * self.data.shape[0])
- indices = np.random.permutation(self.data.shape[0])
- train_indices, test_indices = indices[:num_train], indices[num_train:]
- if self.train:
- self._data = self.data[train_indices]
- self._targets = self.targets[train_indices]
- else:
- self._data = self.data[test_indices]
- self._targets = self.targets[test_indices]
-
- def _process_iam_paragraphs(self) -> None:
- """Crop the part with the text.
-
- For each page, crop out the part of it that correspond to the paragraph of text, and make sure all crops are
- self.input_shape. The ground truth data is the same size, with a one-hot vector at each pixel
- corresponding to labels 0=background, 1=odd-numbered line, 2=even-numbered line
- """
- crop_dims = self._decide_on_crop_dims()
- CROPS_DIRNAME.mkdir(parents=True, exist_ok=True)
- DEBUG_CROPS_DIRNAME.mkdir(parents=True, exist_ok=True)
- GT_DIRNAME.mkdir(parents=True, exist_ok=True)
- logger.info(
- f"Cropping paragraphs, generating ground truth, and saving debugging images to {DEBUG_CROPS_DIRNAME}"
- )
- for filename in self.iam_dataset.form_filenames:
- id_ = filename.stem
- line_region = self.iam_dataset.line_regions_by_id[id_]
- _crop_paragraph_image(filename, line_region, crop_dims, self.input_shape)
-
- def _decide_on_crop_dims(self) -> Tuple[int, int]:
- """Decide on the dimensions to crop out of the form image.
-
- Since image width is larger than a comfortable crop around the longest paragraph,
- we will make the crop a square form factor.
- And since the found dimensions 610x610 are pretty close to 512x512,
- we might as well resize crops and make it exactly that, which lets us
- do all kinds of power-of-2 pooling and upsampling should we choose to.
-
- Returns:
- Tuple[int, int]: A tuple of crop dimensions.
-
- Raises:
- RuntimeError: When max crop height is larger than max crop width.
-
- """
-
- sample_form_filename = self.iam_dataset.form_filenames[0]
- sample_image = util.read_image(sample_form_filename, grayscale=True)
- max_crop_width = sample_image.shape[1]
- max_crop_height = _get_max_paragraph_crop_height(
- self.iam_dataset.line_regions_by_id
- )
- if not max_crop_height <= max_crop_width:
- raise RuntimeError(
- f"Max crop height is larger then max crop width: {max_crop_height} >= {max_crop_width}"
- )
-
- crop_dims = (max_crop_width, max_crop_width)
- logger.info(
- f"Max crop width and height were found to be {max_crop_width}x{max_crop_height}."
- )
- logger.info(f"Setting them to {max_crop_width}x{max_crop_width}")
- return crop_dims
-
- def __repr__(self) -> str:
- """Return info about the dataset."""
- return (
- "IAM Paragraph Dataset\n" # pylint: disable=no-member
- f"Num classes: {self.num_classes}\n"
- f"Data: {self.data.shape}\n"
- f"Targets: {self.targets.shape}\n"
- )
-
-
-def _get_max_paragraph_crop_height(line_regions_by_id: Dict) -> int:
- heights = []
- for regions in line_regions_by_id.values():
- min_y1 = min(region["y1"] for region in regions) - PARAGRAPH_BUFFER
- max_y2 = max(region["y2"] for region in regions) + PARAGRAPH_BUFFER
- height = max_y2 - min_y1
- heights.append(height)
- return max(heights)
-
-
-def _crop_paragraph_image(
- filename: str, line_regions: Dict, crop_dims: Tuple[int, int], final_dims: Tuple
-) -> None:
- image = util.read_image(filename, grayscale=True)
-
- min_y1 = min(region["y1"] for region in line_regions) - PARAGRAPH_BUFFER
- max_y2 = max(region["y2"] for region in line_regions) + PARAGRAPH_BUFFER
- height = max_y2 - min_y1
- crop_height = crop_dims[0]
- buffer = (crop_height - height) // 2
-
- # Generate image crop.
- image_crop = 255 * np.ones(crop_dims, dtype=np.uint8)
- try:
- image_crop[buffer : buffer + height] = image[min_y1:max_y2]
- except Exception as e: # pylint: disable=broad-except
- logger.error(f"Rescued {filename}: {e}")
- return
-
- # Generate ground truth.
- gt_image = np.zeros_like(image_crop, dtype=np.uint8)
- for index, region in enumerate(line_regions):
- gt_image[
- (region["y1"] - min_y1 + buffer) : (region["y2"] - min_y1 + buffer),
- region["x1"] : region["x2"],
- ] = (index % 2 + 1)
-
- # Generate image for debugging.
- import matplotlib.pyplot as plt
-
- cmap = plt.get_cmap("Set1")
- image_crop_for_debug = np.dstack([image_crop, image_crop, image_crop])
- for index, region in enumerate(line_regions):
- color = [255 * _ for _ in cmap(index)[:-1]]
- cv2.rectangle(
- image_crop_for_debug,
- (region["x1"], region["y1"] - min_y1 + buffer),
- (region["x2"], region["y2"] - min_y1 + buffer),
- color,
- 3,
- )
- image_crop_for_debug = cv2.resize(
- image_crop_for_debug, final_dims, interpolation=cv2.INTER_AREA
- )
- util.write_image(image_crop_for_debug, DEBUG_CROPS_DIRNAME / f"{filename.stem}.jpg")
-
- image_crop = cv2.resize(image_crop, final_dims, interpolation=cv2.INTER_AREA)
- util.write_image(image_crop, CROPS_DIRNAME / f"{filename.stem}.jpg")
-
- gt_image = cv2.resize(gt_image, final_dims, interpolation=cv2.INTER_NEAREST)
- util.write_image(gt_image, GT_DIRNAME / f"{filename.stem}.png")
-
-
-def _load_iam_paragraphs() -> None:
- logger.info("Loading IAM paragraph crops and ground truth from image files...")
- images = []
- gt_images = []
- ids = []
- for filename in CROPS_DIRNAME.glob("*.jpg"):
- id_ = filename.stem
- image = util.read_image(filename, grayscale=True)
- image = 1.0 - image / 255
-
- gt_filename = GT_DIRNAME / f"{id_}.png"
- gt_image = util.read_image(gt_filename, grayscale=True)
-
- images.append(image)
- gt_images.append(gt_image)
- ids.append(id_)
- images = np.array(images).astype(np.float32)
- gt_images = np.array(gt_images).astype(np.uint8)
- ids = np.array(ids)
- return images, gt_images, ids
-
-
-@click.command()
-@click.option(
- "--subsample_fraction",
- type=float,
- default=None,
- help="The subsampling factor of the dataset.",
-)
-def main(subsample_fraction: float) -> None:
- """Load dataset and print info."""
- logger.info("Creating train set...")
- dataset = IamParagraphsDataset(train=True, subsample_fraction=subsample_fraction)
- dataset.load_or_generate_data()
- print(dataset)
- logger.info("Creating test set...")
- dataset = IamParagraphsDataset(subsample_fraction=subsample_fraction)
- dataset.load_or_generate_data()
- print(dataset)
-
-
-if __name__ == "__main__":
- main()
diff --git a/text_recognizer/data/util.py b/text_recognizer/data/util.py
deleted file mode 100644
index da87756..0000000
--- a/text_recognizer/data/util.py
+++ /dev/null
@@ -1,209 +0,0 @@
-"""Util functions for datasets."""
-import hashlib
-import json
-import os
-from pathlib import Path
-import string
-from typing import Dict, List, Optional, Union
-from urllib.request import urlretrieve
-
-from loguru import logger
-import numpy as np
-import torch
-from torch import Tensor
-from torchvision.datasets import EMNIST
-from tqdm import tqdm
-
-DATA_DIRNAME = Path(__file__).resolve().parents[3] / "data"
-ESSENTIALS_FILENAME = Path(__file__).resolve().parents[0] / "emnist_essentials.json"
-
-
-def save_emnist_essentials(emnsit_dataset: EMNIST = EMNIST) -> None:
- """Extract and saves EMNIST essentials."""
- labels = emnsit_dataset.classes
- labels.sort()
- mapping = [(i, str(label)) for i, label in enumerate(labels)]
- essentials = {
- "mapping": mapping,
- "input_shape": tuple(np.array(emnsit_dataset[0][0]).shape[:]),
- }
- logger.info("Saving emnist essentials...")
- with open(ESSENTIALS_FILENAME, "w") as f:
- json.dump(essentials, f)
-
-
-def download_emnist() -> None:
- """Download the EMNIST dataset via the PyTorch class."""
- logger.info(f"Data directory is: {DATA_DIRNAME}")
- dataset = EMNIST(root=DATA_DIRNAME, split="byclass", download=True)
- save_emnist_essentials(dataset)
-
-
-class EmnistMapper:
- """Mapper between network output to Emnist character."""
-
- def __init__(
- self,
- pad_token: str,
- init_token: Optional[str] = None,
- eos_token: Optional[str] = None,
- lower: bool = False,
- ) -> None:
- """Loads the emnist essentials file with the mapping and input shape."""
- self.init_token = init_token
- self.pad_token = pad_token
- self.eos_token = eos_token
- self.lower = lower
-
- self.essentials = self._load_emnist_essentials()
- # Load dataset information.
- self._mapping = dict(self.essentials["mapping"])
- self._augment_emnist_mapping()
- self._inverse_mapping = {v: k for k, v in self.mapping.items()}
- self._num_classes = len(self.mapping)
- self._input_shape = self.essentials["input_shape"]
-
- def __call__(self, token: Union[str, int, np.uint8, Tensor]) -> Union[str, int]:
- """Maps the token to emnist character or character index.
-
- If the token is an integer (index), the method will return the Emnist character corresponding to that index.
- If the token is a str (Emnist character), the method will return the corresponding index for that character.
-
- Args:
- token (Union[str, int, np.uint8, Tensor]): Either a string or index (integer).
-
- Returns:
- Union[str, int]: The mapping result.
-
- Raises:
- KeyError: If the index or string does not exist in the mapping.
-
- """
- if (
- (isinstance(token, np.uint8) or isinstance(token, int))
- or torch.is_tensor(token)
- and int(token) in self.mapping
- ):
- return self.mapping[int(token)]
- elif isinstance(token, str) and token in self._inverse_mapping:
- return self._inverse_mapping[token]
- else:
- raise KeyError(f"Token {token} does not exist in the mappings.")
-
- @property
- def mapping(self) -> Dict:
- """Returns the mapping between index and character."""
- return self._mapping
-
- @property
- def inverse_mapping(self) -> Dict:
- """Returns the mapping between character and index."""
- return self._inverse_mapping
-
- @property
- def num_classes(self) -> int:
- """Returns the number of classes in the dataset."""
- return self._num_classes
-
- @property
- def input_shape(self) -> List[int]:
- """Returns the input shape of the Emnist characters."""
- return self._input_shape
-
- def _load_emnist_essentials(self) -> Dict:
- """Load the EMNIST mapping."""
- with open(str(ESSENTIALS_FILENAME)) as f:
- essentials = json.load(f)
- return essentials
-
- def _augment_emnist_mapping(self) -> None:
- """Augment the mapping with extra symbols."""
- # Extra symbols in IAM dataset
- if self.lower:
- self._mapping = {
- k: str(v)
- for k, v in enumerate(list(range(10)) + list(string.ascii_lowercase))
- }
-
- extra_symbols = [
- " ",
- "!",
- '"',
- "#",
- "&",
- "'",
- "(",
- ")",
- "*",
- "+",
- ",",
- "-",
- ".",
- "/",
- ":",
- ";",
- "?",
- ]
-
- # padding symbol, and acts as blank symbol as well.
- extra_symbols.append(self.pad_token)
-
- if self.init_token is not None:
- extra_symbols.append(self.init_token)
-
- if self.eos_token is not None:
- extra_symbols.append(self.eos_token)
-
- max_key = max(self.mapping.keys())
- extra_mapping = {}
- for i, symbol in enumerate(extra_symbols):
- extra_mapping[max_key + 1 + i] = symbol
-
- self._mapping = {**self.mapping, **extra_mapping}
-
-
-def compute_sha256(filename: Union[Path, str]) -> str:
- """Returns the SHA256 checksum of a file."""
- with open(filename, "rb") as f:
- return hashlib.sha256(f.read()).hexdigest()
-
-
-class TqdmUpTo(tqdm):
- """TQDM progress bar when downloading files.
-
- From https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py
-
- """
-
- def update_to(
- self, blocks: int = 1, block_size: int = 1, total_size: Optional[int] = None
- ) -> None:
- """Updates the progress bar.
-
- Args:
- blocks (int): Number of blocks transferred so far. Defaults to 1.
- block_size (int): Size of each block, in tqdm units. Defaults to 1.
- total_size (Optional[int]): Total size in tqdm units. Defaults to None.
- """
- if total_size is not None:
- self.total = total_size # pylint: disable=attribute-defined-outside-init
- self.update(blocks * block_size - self.n)
-
-
-def download_url(url: str, filename: str) -> None:
- """Downloads a file from url to filename, with a progress bar."""
- with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t:
- urlretrieve(url, filename, reporthook=t.update_to, data=None) # nosec
-
-
-def _download_raw_dataset(metadata: Dict) -> None:
- if os.path.exists(metadata["filename"]):
- return
- logger.info(f"Downloading raw dataset from {metadata['url']}...")
- download_url(metadata["url"], metadata["filename"])
- logger.info("Computing SHA-256...")
- sha256 = compute_sha256(metadata["filename"])
- if sha256 != metadata["sha256"]:
- raise ValueError(
- "Downloaded data file SHA-256 does not match that listed in metadata document."
- )
diff --git a/text_recognizer/line_predictor.py b/text_recognizer/line_predictor.py
deleted file mode 100644
index 8e348fe..0000000
--- a/text_recognizer/line_predictor.py
+++ /dev/null
@@ -1,28 +0,0 @@
-"""LinePredictor class."""
-import importlib
-from typing import Tuple, Union
-
-import numpy as np
-from torch import nn
-
-from text_recognizer import datasets, networks
-from text_recognizer.models import TransformerModel
-from text_recognizer.util import read_image
-
-
-class LinePredictor:
- """Given an image of a line of handwritten text, recognizes the text content."""
-
- def __init__(self, dataset: str, network_fn: str) -> None:
- network_fn = getattr(networks, network_fn)
- dataset = getattr(datasets, dataset)
- self.model = TransformerModel(network_fn=network_fn, dataset=dataset)
- self.model.eval()
-
- def predict(self, image_or_filename: Union[np.ndarray, str]) -> Tuple[str, float]:
- """Predict on a single images contianing a handwritten character."""
- if isinstance(image_or_filename, str):
- image = read_image(image_or_filename, grayscale=True)
- else:
- image = image_or_filename
- return self.model.predict_on_image(image)
diff --git a/text_recognizer/models/__init__.py b/text_recognizer/models/__init__.py
index 7647d7e..e69de29 100644
--- a/text_recognizer/models/__init__.py
+++ b/text_recognizer/models/__init__.py
@@ -1,18 +0,0 @@
-"""Model modules."""
-from .base import Model
-from .character_model import CharacterModel
-from .crnn_model import CRNNModel
-from .ctc_transformer_model import CTCTransformerModel
-from .segmentation_model import SegmentationModel
-from .transformer_model import TransformerModel
-from .vqvae_model import VQVAEModel
-
-__all__ = [
- "CharacterModel",
- "CRNNModel",
- "CTCTransformerModel",
- "Model",
- "SegmentationModel",
- "TransformerModel",
- "VQVAEModel",
-]
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
deleted file mode 100644
index 70f4cdb..0000000
--- a/text_recognizer/models/base.py
+++ /dev/null
@@ -1,455 +0,0 @@
-"""Abstract Model class for PyTorch neural networks."""
-
-from abc import ABC, abstractmethod
-from glob import glob
-import importlib
-from pathlib import Path
-import re
-import shutil
-from typing import Callable, Dict, List, Optional, Tuple, Type, Union
-
-from loguru import logger
-import torch
-from torch import nn
-from torch import Tensor
-from torch.optim.swa_utils import AveragedModel, SWALR
-from torch.utils.data import DataLoader, Dataset, random_split
-from torchsummary import summary
-
-from text_recognizer import datasets
-from text_recognizer import networks
-from text_recognizer.datasets import EmnistMapper
-
-WEIGHT_DIRNAME = Path(__file__).parents[1].resolve() / "weights"
-
-
-class Model(ABC):
- """Abstract Model class with composition of different parts defining a PyTorch neural network."""
-
- def __init__(
- self,
- network_fn: str,
- dataset: str,
- network_args: Optional[Dict] = None,
- dataset_args: Optional[Dict] = None,
- metrics: Optional[Dict] = None,
- criterion: Optional[Callable] = None,
- criterion_args: Optional[Dict] = None,
- optimizer: Optional[Callable] = None,
- optimizer_args: Optional[Dict] = None,
- lr_scheduler: Optional[Callable] = None,
- lr_scheduler_args: Optional[Dict] = None,
- swa_args: Optional[Dict] = None,
- device: Optional[str] = None,
- ) -> None:
- """Base class, to be inherited by model for specific type of data.
-
- Args:
- network_fn (str): The name of network.
- dataset (str): The name dataset class.
- network_args (Optional[Dict]): Arguments for the network. Defaults to None.
- dataset_args (Optional[Dict]): Arguments for the dataset.
- metrics (Optional[Dict]): Metrics to evaluate the performance with. Defaults to None.
- criterion (Optional[Callable]): The criterion to evaluate the performance of the network.
- Defaults to None.
- criterion_args (Optional[Dict]): Dict of arguments for criterion. Defaults to None.
- optimizer (Optional[Callable]): The optimizer for updating the weights. Defaults to None.
- optimizer_args (Optional[Dict]): Dict of arguments for optimizer. Defaults to None.
- lr_scheduler (Optional[Callable]): A PyTorch learning rate scheduler. Defaults to None.
- lr_scheduler_args (Optional[Dict]): Dict of arguments for learning rate scheduler. Defaults to
- None.
- swa_args (Optional[Dict]): Dict of arguments for stochastic weight averaging. Defaults to
- None.
- device (Optional[str]): Name of the device to train on. Defaults to None.
-
- """
- self._name = f"{self.__class__.__name__}_{dataset}_{network_fn}"
- # Has to be set in subclass.
- self._mapper = None
-
- # Placeholder.
- self._input_shape = None
-
- self.dataset_name = dataset
- self.dataset = None
- self.dataset_args = dataset_args
-
- # Placeholders for datasets.
- self.train_dataset = None
- self.val_dataset = None
- self.test_dataset = None
-
- # Stochastic Weight Averaging placeholders.
- self.swa_args = swa_args
- self._swa_scheduler = None
- self._swa_network = None
- self._use_swa_model = False
-
- # Experiment directory.
- self.model_dir = None
-
- # Flag for configured model.
- self.is_configured = False
- self.data_prepared = False
-
- # Flag for stopping training.
- self.stop_training = False
-
- self._metrics = metrics if metrics is not None else None
-
- # Set the device.
- self._device = (
- torch.device("cuda" if torch.cuda.is_available() else "cpu")
- if device is None
- else device
- )
-
- # Configure network.
- self._network = None
- self._network_args = network_args
- self._configure_network(network_fn)
-
- # Place network on device (GPU).
- self.to_device()
-
- # Loss and Optimizer placeholders for before loading.
- self._criterion = criterion
- self.criterion_args = criterion_args
-
- self._optimizer = optimizer
- self.optimizer_args = optimizer_args
-
- self._lr_scheduler = lr_scheduler
- self.lr_scheduler_args = lr_scheduler_args
-
- def configure_model(self) -> None:
- """Configures criterion and optimizers."""
- if not self.is_configured:
- self._configure_criterion()
- self._configure_optimizers()
-
- # Set this flag to true to prevent the model from configuring again.
- self.is_configured = True
-
- def prepare_data(self) -> None:
- """Prepare data for training."""
- # TODO add downloading.
- if not self.data_prepared:
- # Load dataset module.
- self.dataset = getattr(datasets, self.dataset_name)
-
- # Load train dataset.
- train_dataset = self.dataset(train=True, **self.dataset_args["args"])
- train_dataset.load_or_generate_data()
-
- # Set input shape.
- self._input_shape = train_dataset.input_shape
-
- # Split train dataset into a training and validation partition.
- dataset_len = len(train_dataset)
- train_len = int(
- self.dataset_args["train_args"]["train_fraction"] * dataset_len
- )
- val_len = dataset_len - train_len
- self.train_dataset, self.val_dataset = random_split(
- train_dataset, lengths=[train_len, val_len]
- )
-
- # Load test dataset.
- self.test_dataset = self.dataset(train=False, **self.dataset_args["args"])
- self.test_dataset.load_or_generate_data()
-
- # Set the flag to true to disable ability to load data again.
- self.data_prepared = True
-
- def train_dataloader(self) -> DataLoader:
- """Returns data loader for training set."""
- return DataLoader(
- self.train_dataset,
- batch_size=self.dataset_args["train_args"]["batch_size"],
- num_workers=self.dataset_args["train_args"]["num_workers"],
- shuffle=True,
- pin_memory=True,
- )
-
- def val_dataloader(self) -> DataLoader:
- """Returns data loader for validation set."""
- return DataLoader(
- self.val_dataset,
- batch_size=self.dataset_args["train_args"]["batch_size"],
- num_workers=self.dataset_args["train_args"]["num_workers"],
- shuffle=True,
- pin_memory=True,
- )
-
- def test_dataloader(self) -> DataLoader:
- """Returns data loader for test set."""
- return DataLoader(
- self.test_dataset,
- batch_size=self.dataset_args["train_args"]["batch_size"],
- num_workers=self.dataset_args["train_args"]["num_workers"],
- shuffle=False,
- pin_memory=True,
- )
-
- def _configure_network(self, network_fn: Type[nn.Module]) -> None:
- """Loads the network."""
- # If no network arguments are given, load pretrained weights if they exist.
- # Load network module.
- network_fn = getattr(networks, network_fn)
- if self._network_args is None:
- self.load_weights(network_fn)
- else:
- self._network = network_fn(**self._network_args)
-
- def _configure_criterion(self) -> None:
- """Loads the criterion."""
- self._criterion = (
- self._criterion(**self.criterion_args)
- if self._criterion is not None
- else None
- )
-
- def _configure_optimizers(self,) -> None:
- """Loads the optimizers."""
- if self._optimizer is not None:
- self._optimizer = self._optimizer(
- self._network.parameters(), **self.optimizer_args
- )
- else:
- self._optimizer = None
-
- if self._optimizer and self._lr_scheduler is not None:
- if "steps_per_epoch" in self.lr_scheduler_args:
- self.lr_scheduler_args["steps_per_epoch"] = len(self.train_dataloader())
-
- # Assume lr scheduler should update at each epoch if not specified.
- if "interval" not in self.lr_scheduler_args:
- interval = "epoch"
- else:
- interval = self.lr_scheduler_args.pop("interval")
- self._lr_scheduler = {
- "lr_scheduler": self._lr_scheduler(
- self._optimizer, **self.lr_scheduler_args
- ),
- "interval": interval,
- }
-
- if self.swa_args is not None:
- self._swa_scheduler = {
- "swa_scheduler": SWALR(self._optimizer, swa_lr=self.swa_args["lr"]),
- "swa_start": self.swa_args["start"],
- }
- self._swa_network = AveragedModel(self._network).to(self.device)
-
- @property
- def name(self) -> str:
- """Returns the name of the model."""
- return self._name
-
- @property
- def input_shape(self) -> Tuple[int, ...]:
- """The input shape."""
- return self._input_shape
-
- @property
- def mapper(self) -> EmnistMapper:
- """Returns the mapper that maps between ints and chars."""
- return self._mapper
-
- @property
- def mapping(self) -> Dict:
- """Returns the mapping between network output and Emnist character."""
- return self._mapper.mapping if self._mapper is not None else None
-
- def eval(self) -> None:
- """Sets the network to evaluation mode."""
- self._network.eval()
-
- def train(self) -> None:
- """Sets the network to train mode."""
- self._network.train()
-
- @property
- def device(self) -> str:
- """Device where the weights are stored, i.e. cpu or cuda."""
- return self._device
-
- @property
- def metrics(self) -> Optional[Dict]:
- """Metrics."""
- return self._metrics
-
- @property
- def criterion(self) -> Optional[Callable]:
- """Criterion."""
- return self._criterion
-
- @property
- def optimizer(self) -> Optional[Callable]:
- """Optimizer."""
- return self._optimizer
-
- @property
- def lr_scheduler(self) -> Optional[Dict]:
- """Returns a directory with the learning rate scheduler."""
- return self._lr_scheduler
-
- @property
- def swa_scheduler(self) -> Optional[Dict]:
- """Returns a directory with the stochastic weight averaging scheduler."""
- return self._swa_scheduler
-
- @property
- def swa_network(self) -> Optional[Callable]:
- """Returns the stochastic weight averaging network."""
- return self._swa_network
-
- @property
- def network(self) -> Type[nn.Module]:
- """Neural network."""
- # Returns the SWA network if available.
- return self._network
-
- @property
- def weights_filename(self) -> str:
- """Filepath to the network weights."""
- WEIGHT_DIRNAME.mkdir(parents=True, exist_ok=True)
- return str(WEIGHT_DIRNAME / f"{self._name}_weights.pt")
-
- def use_swa_model(self) -> None:
- """Set to use predictions from SWA model."""
- if self.swa_network is not None:
- self._use_swa_model = True
-
- def forward(self, x: Tensor) -> Tensor:
- """Feedforward pass with the network."""
- if self._use_swa_model:
- return self.swa_network(x)
- else:
- return self.network(x)
-
- def summary(
- self,
- input_shape: Optional[Union[List, Tuple]] = None,
- depth: int = 3,
- device: Optional[str] = None,
- ) -> None:
- """Prints a summary of the network architecture."""
- device = self.device if device is None else device
-
- if input_shape is not None:
- summary(self.network, input_shape, depth=depth, device=device)
- elif self._input_shape is not None:
- input_shape = tuple(self._input_shape)
- summary(self.network, input_shape, depth=depth, device=device)
- else:
- logger.warning("Could not print summary as input shape is not set.")
-
- def to_device(self) -> None:
- """Places the network on the device (GPU)."""
- self._network.to(self._device)
-
- def _get_state_dict(self) -> Dict:
- """Get the state dict of the model."""
- state = {"model_state": self._network.state_dict()}
-
- if self._optimizer is not None:
- state["optimizer_state"] = self._optimizer.state_dict()
-
- if self._lr_scheduler is not None:
- state["scheduler_state"] = self._lr_scheduler["lr_scheduler"].state_dict()
- state["scheduler_interval"] = self._lr_scheduler["interval"]
-
- if self._swa_network is not None:
- state["swa_network"] = self._swa_network.state_dict()
-
- return state
-
- def load_from_checkpoint(self, checkpoint_path: Union[str, Path]) -> None:
- """Load a previously saved checkpoint.
-
- Args:
- checkpoint_path (Path): Path to the experiment with the checkpoint.
-
- """
- checkpoint_path = Path(checkpoint_path)
- self.prepare_data()
- self.configure_model()
- logger.debug("Loading checkpoint...")
- if not checkpoint_path.exists():
- logger.debug("File does not exist {str(checkpoint_path)}")
-
- checkpoint = torch.load(str(checkpoint_path), map_location=self.device)
- self._network.load_state_dict(checkpoint["model_state"])
-
- if self._optimizer is not None:
- self._optimizer.load_state_dict(checkpoint["optimizer_state"])
-
- if self._lr_scheduler is not None:
- # Does not work when loading from previous checkpoint and trying to train beyond the last max epochs
- # with OneCycleLR.
- if self._lr_scheduler["lr_scheduler"].__class__.__name__ != "OneCycleLR":
- self._lr_scheduler["lr_scheduler"].load_state_dict(
- checkpoint["scheduler_state"]
- )
- self._lr_scheduler["interval"] = checkpoint["scheduler_interval"]
-
- if self._swa_network is not None:
- self._swa_network.load_state_dict(checkpoint["swa_network"])
-
- def save_checkpoint(
- self, checkpoint_path: Path, is_best: bool, epoch: int, val_metric: str
- ) -> None:
- """Saves a checkpoint of the model.
-
- Args:
- checkpoint_path (Path): Path to the experiment with the checkpoint.
- is_best (bool): If it is the currently best model.
- epoch (int): The epoch of the checkpoint.
- val_metric (str): Validation metric.
-
- """
- state = self._get_state_dict()
- state["is_best"] = is_best
- state["epoch"] = epoch
- state["network_args"] = self._network_args
-
- checkpoint_path.mkdir(parents=True, exist_ok=True)
-
- logger.debug("Saving checkpoint...")
- filepath = str(checkpoint_path / "last.pt")
- torch.save(state, filepath)
-
- if is_best:
- logger.debug(
- f"Found a new best {val_metric}. Saving best checkpoint and weights."
- )
- shutil.copyfile(filepath, str(checkpoint_path / "best.pt"))
-
- def load_weights(self, network_fn: Optional[Type[nn.Module]] = None) -> None:
- """Load the network weights."""
- logger.debug("Loading network with pretrained weights.")
- filename = glob(self.weights_filename)[0]
- if not filename:
- raise FileNotFoundError(
- f"Could not find any pretrained weights at {self.weights_filename}"
- )
- # Loading state directory.
- state_dict = torch.load(filename, map_location=torch.device(self._device))
- self._network_args = state_dict["network_args"]
- weights = state_dict["model_state"]
-
- # Initializes the network with trained weights.
- if network_fn is not None:
- self._network = network_fn(**self._network_args)
- self._network.load_state_dict(weights)
-
- if "swa_network" in state_dict:
- self._swa_network = AveragedModel(self._network).to(self.device)
- self._swa_network.load_state_dict(state_dict["swa_network"])
-
- def save_weights(self, path: Path) -> None:
- """Save the network weights."""
- logger.debug("Saving the best network weights.")
- shutil.copyfile(str(path / "best.pt"), self.weights_filename)
diff --git a/text_recognizer/models/character_model.py b/text_recognizer/models/character_model.py
deleted file mode 100644
index f9944f3..0000000
--- a/text_recognizer/models/character_model.py
+++ /dev/null
@@ -1,88 +0,0 @@
-"""Defines the CharacterModel class."""
-from typing import Callable, Dict, Optional, Tuple, Type, Union
-
-import numpy as np
-import torch
-from torch import nn
-from torch.utils.data import Dataset
-from torchvision.transforms import ToTensor
-
-from text_recognizer.datasets import EmnistMapper
-from text_recognizer.models.base import Model
-
-
-class CharacterModel(Model):
- """Model for predicting characters from images."""
-
- def __init__(
- self,
- network_fn: Type[nn.Module],
- dataset: Type[Dataset],
- network_args: Optional[Dict] = None,
- dataset_args: Optional[Dict] = None,
- metrics: Optional[Dict] = None,
- criterion: Optional[Callable] = None,
- criterion_args: Optional[Dict] = None,
- optimizer: Optional[Callable] = None,
- optimizer_args: Optional[Dict] = None,
- lr_scheduler: Optional[Callable] = None,
- lr_scheduler_args: Optional[Dict] = None,
- swa_args: Optional[Dict] = None,
- device: Optional[str] = None,
- ) -> None:
- """Initializes the CharacterModel."""
-
- super().__init__(
- network_fn,
- dataset,
- network_args,
- dataset_args,
- metrics,
- criterion,
- criterion_args,
- optimizer,
- optimizer_args,
- lr_scheduler,
- lr_scheduler_args,
- swa_args,
- device,
- )
- self.pad_token = dataset_args["args"]["pad_token"]
- if self._mapper is None:
- self._mapper = EmnistMapper(pad_token=self.pad_token,)
- self.tensor_transform = ToTensor()
- self.softmax = nn.Softmax(dim=0)
-
- @torch.no_grad()
- def predict_on_image(
- self, image: Union[np.ndarray, torch.Tensor]
- ) -> Tuple[str, float]:
- """Character prediction on an image.
-
- Args:
- image (Union[np.ndarray, torch.Tensor]): An image containing a character.
-
- Returns:
- Tuple[str, float]: The predicted character and the confidence in the prediction.
-
- """
- self.eval()
-
- if image.dtype == np.uint8:
- # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1].
- image = self.tensor_transform(image)
- if image.dtype == torch.uint8:
- # If the image is an unscaled tensor.
- image = image.type("torch.FloatTensor") / 255
-
- # Put the image tensor on the device the model weights are on.
- image = image.to(self.device)
- logits = self.forward(image)
-
- prediction = self.softmax(logits.squeeze(0))
-
- index = int(torch.argmax(prediction, dim=0))
- confidence_of_prediction = prediction[index]
- predicted_character = self.mapper(index)
-
- return predicted_character, confidence_of_prediction
diff --git a/text_recognizer/models/crnn_model.py b/text_recognizer/models/crnn_model.py
deleted file mode 100644
index 1e01a83..0000000
--- a/text_recognizer/models/crnn_model.py
+++ /dev/null
@@ -1,119 +0,0 @@
-"""Defines the CRNNModel class."""
-from typing import Callable, Dict, Optional, Tuple, Type, Union
-
-import numpy as np
-import torch
-from torch import nn
-from torch import Tensor
-from torch.utils.data import Dataset
-from torchvision.transforms import ToTensor
-
-from text_recognizer.datasets import EmnistMapper
-from text_recognizer.models.base import Model
-from text_recognizer.networks import greedy_decoder
-
-
-class CRNNModel(Model):
- """Model for predicting a sequence of characters from an image of a text line."""
-
- def __init__(
- self,
- network_fn: Type[nn.Module],
- dataset: Type[Dataset],
- network_args: Optional[Dict] = None,
- dataset_args: Optional[Dict] = None,
- metrics: Optional[Dict] = None,
- criterion: Optional[Callable] = None,
- criterion_args: Optional[Dict] = None,
- optimizer: Optional[Callable] = None,
- optimizer_args: Optional[Dict] = None,
- lr_scheduler: Optional[Callable] = None,
- lr_scheduler_args: Optional[Dict] = None,
- swa_args: Optional[Dict] = None,
- device: Optional[str] = None,
- ) -> None:
- super().__init__(
- network_fn,
- dataset,
- network_args,
- dataset_args,
- metrics,
- criterion,
- criterion_args,
- optimizer,
- optimizer_args,
- lr_scheduler,
- lr_scheduler_args,
- swa_args,
- device,
- )
-
- self.pad_token = dataset_args["args"]["pad_token"]
- if self._mapper is None:
- self._mapper = EmnistMapper(pad_token=self.pad_token,)
- self.tensor_transform = ToTensor()
-
- def criterion(self, output: Tensor, targets: Tensor) -> Tensor:
- """Computes the CTC loss.
-
- Args:
- output (Tensor): Model predictions.
- targets (Tensor): Correct output sequence.
-
- Returns:
- Tensor: The CTC loss.
-
- """
-
- # Input lengths on the form [T, B]
- input_lengths = torch.full(
- size=(output.shape[1],), fill_value=output.shape[0], dtype=torch.long,
- )
-
- # Configure target tensors for ctc loss.
- targets_ = Tensor([]).to(self.device)
- target_lengths = []
- for t in targets:
- # Remove padding symbol as it acts as the blank symbol.
- t = t[t < 79]
- targets_ = torch.cat([targets_, t])
- target_lengths.append(len(t))
-
- targets = targets_.type(dtype=torch.long)
- target_lengths = (
- torch.Tensor(target_lengths).type(dtype=torch.long).to(self.device)
- )
-
- return self._criterion(output, targets, input_lengths, target_lengths)
-
- @torch.no_grad()
- def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]:
- """Predict on a single input."""
- self.eval()
-
- if image.dtype == np.uint8:
- # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1].
- image = self.tensor_transform(image)
-
- # Rescale image between 0 and 1.
- if image.dtype == torch.uint8:
- # If the image is an unscaled tensor.
- image = image.type("torch.FloatTensor") / 255
-
- # Put the image tensor on the device the model weights are on.
- image = image.to(self.device)
- log_probs = self.forward(image)
-
- raw_pred, _ = greedy_decoder(
- predictions=log_probs,
- character_mapper=self.mapper,
- blank_label=79,
- collapse_repeated=True,
- )
-
- log_probs, _ = log_probs.max(dim=2)
-
- predicted_characters = "".join(raw_pred[0])
- confidence_of_prediction = log_probs.cumprod(dim=0)[-1].item()
-
- return predicted_characters, confidence_of_prediction
diff --git a/text_recognizer/models/ctc_transformer_model.py b/text_recognizer/models/ctc_transformer_model.py
deleted file mode 100644
index 25925f2..0000000
--- a/text_recognizer/models/ctc_transformer_model.py
+++ /dev/null
@@ -1,120 +0,0 @@
-"""Defines the CTC Transformer Model class."""
-from typing import Callable, Dict, Optional, Tuple, Type, Union
-
-import numpy as np
-import torch
-from torch import nn
-from torch import Tensor
-from torch.utils.data import Dataset
-from torchvision.transforms import ToTensor
-
-from text_recognizer.datasets import EmnistMapper
-from text_recognizer.models.base import Model
-from text_recognizer.networks import greedy_decoder
-
-
-class CTCTransformerModel(Model):
- """Model for predicting a sequence of characters from an image of a text line."""
-
- def __init__(
- self,
- network_fn: Type[nn.Module],
- dataset: Type[Dataset],
- network_args: Optional[Dict] = None,
- dataset_args: Optional[Dict] = None,
- metrics: Optional[Dict] = None,
- criterion: Optional[Callable] = None,
- criterion_args: Optional[Dict] = None,
- optimizer: Optional[Callable] = None,
- optimizer_args: Optional[Dict] = None,
- lr_scheduler: Optional[Callable] = None,
- lr_scheduler_args: Optional[Dict] = None,
- swa_args: Optional[Dict] = None,
- device: Optional[str] = None,
- ) -> None:
- super().__init__(
- network_fn,
- dataset,
- network_args,
- dataset_args,
- metrics,
- criterion,
- criterion_args,
- optimizer,
- optimizer_args,
- lr_scheduler,
- lr_scheduler_args,
- swa_args,
- device,
- )
- self.pad_token = dataset_args["args"]["pad_token"]
- self.lower = dataset_args["args"]["lower"]
-
- if self._mapper is None:
- self._mapper = EmnistMapper(pad_token=self.pad_token, lower=self.lower,)
-
- self.tensor_transform = ToTensor()
-
- def criterion(self, output: Tensor, targets: Tensor) -> Tensor:
- """Computes the CTC loss.
-
- Args:
- output (Tensor): Model predictions.
- targets (Tensor): Correct output sequence.
-
- Returns:
- Tensor: The CTC loss.
-
- """
- # Input lengths on the form [T, B]
- input_lengths = torch.full(
- size=(output.shape[1],), fill_value=output.shape[0], dtype=torch.long,
- )
-
- # Configure target tensors for ctc loss.
- targets_ = Tensor([]).to(self.device)
- target_lengths = []
- for t in targets:
- # Remove padding symbol as it acts as the blank symbol.
- t = t[t < 53]
- targets_ = torch.cat([targets_, t])
- target_lengths.append(len(t))
-
- targets = targets_.type(dtype=torch.long)
- target_lengths = (
- torch.Tensor(target_lengths).type(dtype=torch.long).to(self.device)
- )
-
- return self._criterion(output, targets, input_lengths, target_lengths)
-
- @torch.no_grad()
- def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]:
- """Predict on a single input."""
- self.eval()
-
- if image.dtype == np.uint8:
- # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1].
- image = self.tensor_transform(image)
-
- # Rescale image between 0 and 1.
- if image.dtype == torch.uint8:
- # If the image is an unscaled tensor.
- image = image.type("torch.FloatTensor") / 255
-
- # Put the image tensor on the device the model weights are on.
- image = image.to(self.device)
- log_probs = self.forward(image)
-
- raw_pred, _ = greedy_decoder(
- predictions=log_probs,
- character_mapper=self.mapper,
- blank_label=53,
- collapse_repeated=True,
- )
-
- log_probs, _ = log_probs.max(dim=2)
-
- predicted_characters = "".join(raw_pred[0])
- confidence_of_prediction = log_probs.cumprod(dim=0)[-1].item()
-
- return predicted_characters, confidence_of_prediction
diff --git a/text_recognizer/models/segmentation_model.py b/text_recognizer/models/segmentation_model.py
deleted file mode 100644
index 613108a..0000000
--- a/text_recognizer/models/segmentation_model.py
+++ /dev/null
@@ -1,75 +0,0 @@
-"""Segmentation model for detecting and segmenting lines."""
-from typing import Callable, Dict, Optional, Type, Union
-
-import numpy as np
-import torch
-from torch import nn
-from torch import Tensor
-from torch.utils.data import Dataset
-from torchvision.transforms import ToTensor
-
-from text_recognizer.models.base import Model
-
-
-class SegmentationModel(Model):
- """Model for segmenting lines in an image."""
-
- def __init__(
- self,
- network_fn: str,
- dataset: str,
- network_args: Optional[Dict] = None,
- dataset_args: Optional[Dict] = None,
- metrics: Optional[Dict] = None,
- criterion: Optional[Callable] = None,
- criterion_args: Optional[Dict] = None,
- optimizer: Optional[Callable] = None,
- optimizer_args: Optional[Dict] = None,
- lr_scheduler: Optional[Callable] = None,
- lr_scheduler_args: Optional[Dict] = None,
- swa_args: Optional[Dict] = None,
- device: Optional[str] = None,
- ) -> None:
- super().__init__(
- network_fn,
- dataset,
- network_args,
- dataset_args,
- metrics,
- criterion,
- criterion_args,
- optimizer,
- optimizer_args,
- lr_scheduler,
- lr_scheduler_args,
- swa_args,
- device,
- )
- self.tensor_transform = ToTensor()
- self.softmax = nn.Softmax(dim=2)
-
- @torch.no_grad()
- def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tensor:
- """Predict on a single input."""
- self.eval()
-
- if image.dtype is np.uint8:
- # Converts an image with range [0, 255] with to PyTorch Tensor with range [0, 1].
- image = self.tensor_transform(image)
-
- # Rescale image between 0 and 1.
- if image.dtype is torch.uint8 or image.dtype is torch.int64:
- # If the image is an unscaled tensor.
- image = image.type("torch.FloatTensor") / 255
-
- if not torch.is_tensor(image):
- image = Tensor(image)
-
- # Put the image tensor on the device the model weights are on.
- image = image.to(self.device)
-
- logits = self.forward(image)
-
- segmentation_mask = torch.argmax(logits, dim=1)
-
- return segmentation_mask
diff --git a/text_recognizer/models/transformer_model.py b/text_recognizer/models/transformer_model.py
deleted file mode 100644
index 3f63053..0000000
--- a/text_recognizer/models/transformer_model.py
+++ /dev/null
@@ -1,124 +0,0 @@
-"""Defines the CNN-Transformer class."""
-from typing import Callable, Dict, List, Optional, Tuple, Type, Union
-
-import numpy as np
-import torch
-from torch import nn
-from torch import Tensor
-from torch.utils.data import Dataset
-
-from text_recognizer.datasets import EmnistMapper
-import text_recognizer.datasets.transforms as transforms
-from text_recognizer.models.base import Model
-from text_recognizer.networks import greedy_decoder
-
-
-class TransformerModel(Model):
- """Model for predicting a sequence of characters from an image of a text line with a cnn-transformer."""
-
- def __init__(
- self,
- network_fn: str,
- dataset: str,
- network_args: Optional[Dict] = None,
- dataset_args: Optional[Dict] = None,
- metrics: Optional[Dict] = None,
- criterion: Optional[Callable] = None,
- criterion_args: Optional[Dict] = None,
- optimizer: Optional[Callable] = None,
- optimizer_args: Optional[Dict] = None,
- lr_scheduler: Optional[Callable] = None,
- lr_scheduler_args: Optional[Dict] = None,
- swa_args: Optional[Dict] = None,
- device: Optional[str] = None,
- ) -> None:
- super().__init__(
- network_fn,
- dataset,
- network_args,
- dataset_args,
- metrics,
- criterion,
- criterion_args,
- optimizer,
- optimizer_args,
- lr_scheduler,
- lr_scheduler_args,
- swa_args,
- device,
- )
- self.init_token = dataset_args["args"]["init_token"]
- self.pad_token = dataset_args["args"]["pad_token"]
- self.eos_token = dataset_args["args"]["eos_token"]
- self.lower = dataset_args["args"]["lower"]
- self.max_len = 100
-
- if self._mapper is None:
- self._mapper = EmnistMapper(
- init_token=self.init_token,
- pad_token=self.pad_token,
- eos_token=self.eos_token,
- lower=self.lower,
- )
- self.tensor_transform = transforms.Compose(
- [transforms.ToTensor(), transforms.Normalize(mean=[0.912], std=[0.168])]
- )
- self.softmax = nn.Softmax(dim=2)
-
- @torch.no_grad()
- def _generate_sentence(self, image: Tensor) -> Tuple[List, float]:
- src = self.network.extract_image_features(image)
-
- # Added for vqvae transformer.
- if isinstance(src, Tuple):
- src = src[0]
-
- memory = self.network.encoder(src)
-
- confidence_of_predictions = []
- trg_indices = [self.mapper(self.init_token)]
-
- for _ in range(self.max_len - 1):
- trg = torch.tensor(trg_indices, device=self.device)[None, :].long()
- trg = self.network.target_embedding(trg)
- logits = self.network.decoder(trg=trg, memory=memory, trg_mask=None)
-
- # Convert logits to probabilities.
- probs = self.softmax(logits)
-
- pred_token = probs.argmax(2)[:, -1].item()
- confidence = probs.max(2).values[:, -1].item()
-
- trg_indices.append(pred_token)
- confidence_of_predictions.append(confidence)
-
- if pred_token == self.mapper(self.eos_token):
- break
-
- confidence = np.min(confidence_of_predictions)
- predicted_characters = "".join([self.mapper(x) for x in trg_indices[1:]])
-
- return predicted_characters, confidence
-
- @torch.no_grad()
- def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]:
- """Predict on a single input."""
- self.eval()
-
- if image.dtype == np.uint8:
- # Converts an image with range [0, 255] with to PyTorch Tensor with range [0, 1].
- image = self.tensor_transform(image)
-
- # Rescale image between 0 and 1.
- if image.dtype == torch.uint8:
- # If the image is an unscaled tensor.
- image = image.type("torch.FloatTensor") / 255
-
- # Put the image tensor on the device the model weights are on.
- image = image.to(self.device)
-
- (predicted_characters, confidence_of_prediction,) = self._generate_sentence(
- image
- )
-
- return predicted_characters, confidence_of_prediction
diff --git a/text_recognizer/models/vqvae_model.py b/text_recognizer/models/vqvae_model.py
deleted file mode 100644
index 70f6f1f..0000000
--- a/text_recognizer/models/vqvae_model.py
+++ /dev/null
@@ -1,80 +0,0 @@
-"""Defines the VQVAEModel class."""
-from typing import Callable, Dict, Optional, Tuple, Type, Union
-
-import numpy as np
-import torch
-from torch import nn
-from torch.utils.data import Dataset
-from torchvision.transforms import ToTensor
-
-from text_recognizer.datasets import EmnistMapper
-from text_recognizer.models.base import Model
-
-
-class VQVAEModel(Model):
- """Model for reconstructing images from codebook."""
-
- def __init__(
- self,
- network_fn: Type[nn.Module],
- dataset: Type[Dataset],
- network_args: Optional[Dict] = None,
- dataset_args: Optional[Dict] = None,
- metrics: Optional[Dict] = None,
- criterion: Optional[Callable] = None,
- criterion_args: Optional[Dict] = None,
- optimizer: Optional[Callable] = None,
- optimizer_args: Optional[Dict] = None,
- lr_scheduler: Optional[Callable] = None,
- lr_scheduler_args: Optional[Dict] = None,
- swa_args: Optional[Dict] = None,
- device: Optional[str] = None,
- ) -> None:
- """Initializes the CharacterModel."""
-
- super().__init__(
- network_fn,
- dataset,
- network_args,
- dataset_args,
- metrics,
- criterion,
- criterion_args,
- optimizer,
- optimizer_args,
- lr_scheduler,
- lr_scheduler_args,
- swa_args,
- device,
- )
- self.pad_token = dataset_args["args"]["pad_token"]
- if self._mapper is None:
- self._mapper = EmnistMapper(pad_token=self.pad_token,)
- self.tensor_transform = ToTensor()
- self.softmax = nn.Softmax(dim=0)
-
- @torch.no_grad()
- def predict_on_image(self, image: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
- """Reconstruction of image.
-
- Args:
- image (Union[np.ndarray, torch.Tensor]): An image containing a character.
-
- Returns:
- Tuple[str, float]: The predicted character and the confidence in the prediction.
-
- """
- self.eval()
-
- if image.dtype == np.uint8:
- # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1].
- image = self.tensor_transform(image)
- if image.dtype == torch.uint8:
- # If the image is an unscaled tensor.
- image = image.type("torch.FloatTensor") / 255
-
- # Put the image tensor on the device the model weights are on.
- image = image.to(self.device)
- image_reconstructed, _ = self.forward(image)
-
- return image_reconstructed
diff --git a/text_recognizer/networks/__init__.py b/text_recognizer/networks/__init__.py
index 1521355..e69de29 100644
--- a/text_recognizer/networks/__init__.py
+++ b/text_recognizer/networks/__init__.py
@@ -1,43 +0,0 @@
-"""Network modules."""
-from .cnn import CNN
-from .cnn_transformer import CNNTransformer
-from .crnn import ConvolutionalRecurrentNetwork
-from .ctc import greedy_decoder
-from .densenet import DenseNet
-from .lenet import LeNet
-from .metrics import accuracy, cer, wer
-from .mlp import MLP
-from .residual_network import ResidualNetwork, ResidualNetworkEncoder
-from .transducer import load_transducer_loss, TDS2d
-from .transformer import Transformer
-from .unet import UNet
-from .util import sliding_window
-from .vit import ViT
-from .vq_transformer import VQTransformer
-from .vqvae import VQVAE
-from .wide_resnet import WideResidualNetwork
-
-__all__ = [
- "accuracy",
- "cer",
- "CNN",
- "CNNTransformer",
- "ConvolutionalRecurrentNetwork",
- "DenseNet",
- "FCN",
- "greedy_decoder",
- "MLP",
- "LeNet",
- "load_transducer_loss",
- "ResidualNetwork",
- "ResidualNetworkEncoder",
- "sliding_window",
- "UNet",
- "TDS2d",
- "Transformer",
- "ViT",
- "VQTransformer",
- "VQVAE",
- "wer",
- "WideResidualNetwork",
-]
diff --git a/text_recognizer/networks/beam.py b/text_recognizer/networks/beam.py
deleted file mode 100644
index dccccdb..0000000
--- a/text_recognizer/networks/beam.py
+++ /dev/null
@@ -1,83 +0,0 @@
-"""Implementation of beam search decoder for a sequence to sequence network.
-
-Stolen from: https://github.com/budzianowski/PyTorch-Beam-Search-Decoding/blob/master/decode_beam.py
-
-"""
-# from typing import List
-# from Queue import PriorityQueue
-
-# from loguru import logger
-# import torch
-# from torch import nn
-# from torch import Tensor
-# import torch.nn.functional as F
-
-
-# class Node:
-# def __init__(
-# self, parent: Node, target_index: int, log_prob: Tensor, length: int
-# ) -> None:
-# self.parent = parent
-# self.target_index = target_index
-# self.log_prob = log_prob
-# self.length = length
-# self.reward = 0.0
-
-# def eval(self, alpha: float = 1.0) -> Tensor:
-# return self.log_prob / (self.length - 1 + 1e-6) + alpha * self.reward
-
-
-# @torch.no_grad()
-# def beam_decoder(
-# network, mapper, device, memory: Tensor = None, max_len: int = 97,
-# ) -> Tensor:
-# beam_width = 10
-# topk = 1 # How many sentences to generate.
-
-# trg_indices = [mapper(mapper.init_token)]
-
-# end_nodes = []
-
-# node = Node(None, trg_indices, 0, 1)
-# nodes = PriorityQueue()
-
-# nodes.put((node.eval(), node))
-# q_size = 1
-
-# # Beam search
-# for _ in range(max_len):
-# if q_size > 2000:
-# logger.warning("Could not decoder input")
-# break
-
-# # Fetch the best node.
-# score, n = nodes.get()
-# decoder_input = n.target_index
-
-# if n.target_index == mapper(mapper.eos_token) and n.parent is not None:
-# end_nodes.append((score, n))
-
-# # If we reached the maximum number of sentences required.
-# if len(end_nodes) >= 1:
-# break
-# else:
-# continue
-
-# # Forward pass with transformer.
-# trg = torch.tensor(trg_indices, device=device)[None, :].long()
-# trg = network.target_embedding(trg)
-# logits = network.decoder(trg=trg, memory=memory, trg_mask=None)
-# log_prob = F.log_softmax(logits, dim=2)
-
-# log_prob, indices = torch.topk(log_prob, beam_width)
-
-# for new_k in range(beam_width):
-# # TODO: continue from here
-# token_index = indices[0][new_k].view(1, -1)
-# log_p = log_prob[0][new_k].item()
-
-# node = Node()
-
-# pass
-
-# pass
diff --git a/text_recognizer/networks/cnn.py b/text_recognizer/networks/cnn.py
deleted file mode 100644
index 1807bb9..0000000
--- a/text_recognizer/networks/cnn.py
+++ /dev/null
@@ -1,101 +0,0 @@
-"""Implementation of a simple backbone cnn network."""
-from typing import Callable, Dict, Optional, Tuple
-
-from einops.layers.torch import Rearrange
-import torch
-from torch import nn
-
-from text_recognizer.networks.util import activation_function
-
-
-class CNN(nn.Module):
- """LeNet network for character prediction."""
-
- def __init__(
- self,
- channels: Tuple[int, ...] = (1, 32, 64, 128),
- kernel_sizes: Tuple[int, ...] = (4, 4, 4),
- strides: Tuple[int, ...] = (2, 2, 2),
- max_pool_kernel: int = 2,
- dropout_rate: float = 0.2,
- activation: Optional[str] = "relu",
- ) -> None:
- """Initialization of the LeNet network.
-
- Args:
- channels (Tuple[int, ...]): Channels in the convolutional layers. Defaults to (1, 32, 64).
- kernel_sizes (Tuple[int, ...]): Kernel sizes in the convolutional layers. Defaults to (3, 3, 2).
- strides (Tuple[int, ...]): Stride length of the convolutional filter. Defaults to (2, 2, 2).
- max_pool_kernel (int): 2D max pooling kernel. Defaults to 2.
- dropout_rate (float): The dropout rate. Defaults to 0.2.
- activation (Optional[str]): The name of non-linear activation function. Defaults to relu.
-
- Raises:
- RuntimeError: if the number of hyperparameters does not match in length.
-
- """
- super().__init__()
-
- if len(channels) - 1 != len(kernel_sizes) and len(kernel_sizes) != len(strides):
- raise RuntimeError("The number of the hyperparameters does not match.")
-
- self.cnn = self._build_network(
- channels, kernel_sizes, strides, max_pool_kernel, dropout_rate, activation,
- )
-
- def _build_network(
- self,
- channels: Tuple[int, ...],
- kernel_sizes: Tuple[int, ...],
- strides: Tuple[int, ...],
- max_pool_kernel: int,
- dropout_rate: float,
- activation: str,
- ) -> nn.Sequential:
- # Load activation function.
- activation_fn = activation_function(activation)
-
- channels = list(channels)
- in_channels = channels.pop(0)
- configuration = zip(channels, kernel_sizes, strides)
-
- modules = nn.ModuleList([])
-
- for i, (out_channels, kernel_size, stride) in enumerate(configuration):
- # Add max pool to reduce output size.
- if i == len(channels) // 2:
- modules.append(nn.MaxPool2d(max_pool_kernel))
- if i == 0:
- modules.append(
- nn.Conv2d(
- in_channels, out_channels, kernel_size, stride=stride, padding=1
- )
- )
- else:
- modules.append(
- nn.Sequential(
- activation_fn,
- nn.BatchNorm2d(in_channels),
- nn.Conv2d(
- in_channels,
- out_channels,
- kernel_size,
- stride=stride,
- padding=1,
- ),
- )
- )
-
- if dropout_rate:
- modules.append(nn.Dropout2d(p=dropout_rate))
-
- in_channels = out_channels
-
- return nn.Sequential(*modules)
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- """The feedforward pass."""
- # If batch dimenstion is missing, it needs to be added.
- if len(x.shape) < 4:
- x = x[(None,) * (4 - len(x.shape))]
- return self.cnn(x)
diff --git a/text_recognizer/networks/crnn.py b/text_recognizer/networks/crnn.py
deleted file mode 100644
index 778e232..0000000
--- a/text_recognizer/networks/crnn.py
+++ /dev/null
@@ -1,110 +0,0 @@
-"""CRNN for handwritten text recognition."""
-from typing import Dict, Tuple
-
-from einops import rearrange, reduce
-from einops.layers.torch import Rearrange
-from loguru import logger
-from torch import nn
-from torch import Tensor
-
-from text_recognizer.networks.util import configure_backbone
-
-
-class ConvolutionalRecurrentNetwork(nn.Module):
- """Network that takes a image of a text line and predicts tokens that are in the image."""
-
- def __init__(
- self,
- backbone: str,
- backbone_args: Dict = None,
- input_size: int = 128,
- hidden_size: int = 128,
- bidirectional: bool = False,
- num_layers: int = 1,
- num_classes: int = 80,
- patch_size: Tuple[int, int] = (28, 28),
- stride: Tuple[int, int] = (1, 14),
- recurrent_cell: str = "lstm",
- avg_pool: bool = False,
- use_sliding_window: bool = True,
- ) -> None:
- super().__init__()
- self.backbone_args = backbone_args or {}
- self.patch_size = patch_size
- self.stride = stride
- self.sliding_window = (
- self._configure_sliding_window() if use_sliding_window else None
- )
- self.input_size = input_size
- self.hidden_size = hidden_size
- self.backbone = configure_backbone(backbone, backbone_args)
- self.bidirectional = bidirectional
- self.avg_pool = avg_pool
-
- if recurrent_cell.upper() in ["LSTM", "GRU"]:
- recurrent_cell = getattr(nn, recurrent_cell)
- else:
- logger.warning(
- f"Option {recurrent_cell} not valid, defaulting to LSTM cell."
- )
- recurrent_cell = nn.LSTM
-
- self.rnn = recurrent_cell(
- input_size=self.input_size,
- hidden_size=self.hidden_size,
- bidirectional=bidirectional,
- num_layers=num_layers,
- )
-
- decoder_size = self.hidden_size * 2 if self.bidirectional else self.hidden_size
-
- self.decoder = nn.Sequential(
- nn.Linear(in_features=decoder_size, out_features=num_classes),
- nn.LogSoftmax(dim=2),
- )
-
- def _configure_sliding_window(self) -> nn.Sequential:
- return nn.Sequential(
- nn.Unfold(kernel_size=self.patch_size, stride=self.stride),
- Rearrange(
- "b (c h w) t -> b t c h w",
- h=self.patch_size[0],
- w=self.patch_size[1],
- c=1,
- ),
- )
-
- def forward(self, x: Tensor) -> Tensor:
- """Converts images to sequence of patches, feeds them to a CNN, then predictions are made with an LSTM."""
- if len(x.shape) < 4:
- x = x[(None,) * (4 - len(x.shape))]
-
- if self.sliding_window is not None:
- # Create image patches with a sliding window kernel.
- x = self.sliding_window(x)
-
- # Rearrange from a sequence of patches for feedforward network.
- b, t = x.shape[:2]
- x = rearrange(x, "b t c h w -> (b t) c h w", b=b, t=t)
-
- x = self.backbone(x)
-
- # Average pooling.
- if self.avg_pool:
- x = reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t)
- else:
- x = rearrange(x, "(b t) h -> t b h", b=b, t=t)
- else:
- # Encode the entire image with a CNN, and use the channels as temporal dimension.
- x = self.backbone(x)
- x = rearrange(x, "b c h w -> b w c h")
- if self.adaptive_pool is not None:
- x = self.adaptive_pool(x)
- x = x.squeeze(3)
-
- # Sequence predictions.
- x, _ = self.rnn(x)
-
- # Sequence to classification layer.
- x = self.decoder(x)
- return x
diff --git a/text_recognizer/networks/ctc.py b/text_recognizer/networks/ctc.py
deleted file mode 100644
index af9b700..0000000
--- a/text_recognizer/networks/ctc.py
+++ /dev/null
@@ -1,58 +0,0 @@
-"""Decodes the CTC output."""
-from typing import Callable, List, Optional, Tuple
-
-from einops import rearrange
-import torch
-from torch import Tensor
-
-from text_recognizer.datasets.util import EmnistMapper
-
-
-def greedy_decoder(
- predictions: Tensor,
- targets: Optional[Tensor] = None,
- target_lengths: Optional[Tensor] = None,
- character_mapper: Optional[Callable] = None,
- blank_label: int = 79,
- collapse_repeated: bool = True,
-) -> Tuple[List[str], List[str]]:
- """Greedy CTC decoder.
-
- Args:
- predictions (Tensor): Tenor of network predictions, shape [time, batch, classes].
- targets (Optional[Tensor]): Target tensor, shape is [batch, targets]. Defaults to None.
- target_lengths (Optional[Tensor]): Length of each target tensor. Defaults to None.
- character_mapper (Optional[Callable]): A emnist/character mapper for mapping integers to characters. Defaults
- to None.
- blank_label (int): The blank character to be ignored. Defaults to 80.
- collapse_repeated (bool): Collapase consecutive predictions of the same character. Defaults to True.
-
- Returns:
- Tuple[List[str], List[str]]: Tuple of decoded predictions and decoded targets.
-
- """
-
- if character_mapper is None:
- character_mapper = EmnistMapper(pad_token="_") # noqa: S106
-
- predictions = rearrange(torch.argmax(predictions, dim=2), "t b -> b t")
- decoded_predictions = []
- decoded_targets = []
- for i, prediction in enumerate(predictions):
- decoded_prediction = []
- decoded_target = []
- if targets is not None and target_lengths is not None:
- for target_index in targets[i][: target_lengths[i]]:
- if target_index == blank_label:
- continue
- decoded_target.append(character_mapper(int(target_index)))
- decoded_targets.append(decoded_target)
- for j, index in enumerate(prediction):
- if index != blank_label:
- if collapse_repeated and j != 0 and index == prediction[j - 1]:
- continue
- decoded_prediction.append(index.item())
- decoded_predictions.append(
- [character_mapper(int(pred_index)) for pred_index in decoded_prediction]
- )
- return decoded_predictions, decoded_targets
diff --git a/text_recognizer/networks/densenet.py b/text_recognizer/networks/densenet.py
deleted file mode 100644
index 7dc58d9..0000000
--- a/text_recognizer/networks/densenet.py
+++ /dev/null
@@ -1,225 +0,0 @@
-"""Defines a Densely Connected Convolutional Networks in PyTorch.
-
-Sources:
-https://arxiv.org/abs/1608.06993
-https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py
-
-"""
-from typing import List, Optional, Union
-
-from einops.layers.torch import Rearrange
-import torch
-from torch import nn
-from torch import Tensor
-
-from text_recognizer.networks.util import activation_function
-
-
-class _DenseLayer(nn.Module):
- """A dense layer with pre-batch norm -> activation function -> Conv-layer x 2."""
-
- def __init__(
- self,
- in_channels: int,
- growth_rate: int,
- bn_size: int,
- dropout_rate: float,
- activation: str = "relu",
- ) -> None:
- super().__init__()
- activation_fn = activation_function(activation)
- self.dense_layer = [
- nn.BatchNorm2d(in_channels),
- activation_fn,
- nn.Conv2d(
- in_channels=in_channels,
- out_channels=bn_size * growth_rate,
- kernel_size=1,
- stride=1,
- bias=False,
- ),
- nn.BatchNorm2d(bn_size * growth_rate),
- activation_fn,
- nn.Conv2d(
- in_channels=bn_size * growth_rate,
- out_channels=growth_rate,
- kernel_size=3,
- stride=1,
- padding=1,
- bias=False,
- ),
- ]
- if dropout_rate:
- self.dense_layer.append(nn.Dropout(p=dropout_rate))
-
- self.dense_layer = nn.Sequential(*self.dense_layer)
-
- def forward(self, x: Union[Tensor, List[Tensor]]) -> Tensor:
- if isinstance(x, list):
- x = torch.cat(x, 1)
- return self.dense_layer(x)
-
-
-class _DenseBlock(nn.Module):
- def __init__(
- self,
- num_layers: int,
- in_channels: int,
- bn_size: int,
- growth_rate: int,
- dropout_rate: float,
- activation: str = "relu",
- ) -> None:
- super().__init__()
- self.dense_block = self._build_dense_blocks(
- num_layers, in_channels, bn_size, growth_rate, dropout_rate, activation,
- )
-
- def _build_dense_blocks(
- self,
- num_layers: int,
- in_channels: int,
- bn_size: int,
- growth_rate: int,
- dropout_rate: float,
- activation: str = "relu",
- ) -> nn.ModuleList:
- dense_block = []
- for i in range(num_layers):
- dense_block.append(
- _DenseLayer(
- in_channels=in_channels + i * growth_rate,
- growth_rate=growth_rate,
- bn_size=bn_size,
- dropout_rate=dropout_rate,
- activation=activation,
- )
- )
- return nn.ModuleList(dense_block)
-
- def forward(self, x: Tensor) -> Tensor:
- feature_maps = [x]
- for layer in self.dense_block:
- x = layer(feature_maps)
- feature_maps.append(x)
- return torch.cat(feature_maps, 1)
-
-
-class _Transition(nn.Module):
- def __init__(
- self, in_channels: int, out_channels: int, activation: str = "relu",
- ) -> None:
- super().__init__()
- activation_fn = activation_function(activation)
- self.transition = nn.Sequential(
- nn.BatchNorm2d(in_channels),
- activation_fn,
- nn.Conv2d(
- in_channels=in_channels,
- out_channels=out_channels,
- kernel_size=1,
- stride=1,
- bias=False,
- ),
- nn.AvgPool2d(kernel_size=2, stride=2),
- )
-
- def forward(self, x: Tensor) -> Tensor:
- return self.transition(x)
-
-
-class DenseNet(nn.Module):
- """Implementation of Densenet, a network archtecture that concats previous layers for maximum infomation flow."""
-
- def __init__(
- self,
- growth_rate: int = 32,
- block_config: List[int] = (6, 12, 24, 16),
- in_channels: int = 1,
- base_channels: int = 64,
- num_classes: int = 80,
- bn_size: int = 4,
- dropout_rate: float = 0,
- classifier: bool = True,
- activation: str = "relu",
- ) -> None:
- super().__init__()
- self.densenet = self._configure_densenet(
- in_channels,
- base_channels,
- num_classes,
- growth_rate,
- block_config,
- bn_size,
- dropout_rate,
- classifier,
- activation,
- )
-
- def _configure_densenet(
- self,
- in_channels: int,
- base_channels: int,
- num_classes: int,
- growth_rate: int,
- block_config: List[int],
- bn_size: int,
- dropout_rate: float,
- classifier: bool,
- activation: str,
- ) -> nn.Sequential:
- activation_fn = activation_function(activation)
- densenet = [
- nn.Conv2d(
- in_channels=in_channels,
- out_channels=base_channels,
- kernel_size=3,
- stride=1,
- padding=1,
- bias=False,
- ),
- nn.BatchNorm2d(base_channels),
- activation_fn,
- ]
-
- num_features = base_channels
-
- for i, num_layers in enumerate(block_config):
- densenet.append(
- _DenseBlock(
- num_layers=num_layers,
- in_channels=num_features,
- bn_size=bn_size,
- growth_rate=growth_rate,
- dropout_rate=dropout_rate,
- activation=activation,
- )
- )
- num_features = num_features + num_layers * growth_rate
- if i != len(block_config) - 1:
- densenet.append(
- _Transition(
- in_channels=num_features,
- out_channels=num_features // 2,
- activation=activation,
- )
- )
- num_features = num_features // 2
-
- densenet.append(activation_fn)
-
- if classifier:
- densenet.append(nn.AdaptiveAvgPool2d((1, 1)))
- densenet.append(Rearrange("b c h w -> b (c h w)"))
- densenet.append(
- nn.Linear(in_features=num_features, out_features=num_classes)
- )
-
- return nn.Sequential(*densenet)
-
- def forward(self, x: Tensor) -> Tensor:
- """Forward pass of Densenet."""
- # If batch dimenstion is missing, it will be added.
- if len(x.shape) < 4:
- x = x[(None,) * (4 - len(x.shape))]
- return self.densenet(x)
diff --git a/text_recognizer/networks/lenet.py b/text_recognizer/networks/lenet.py
deleted file mode 100644
index 527e1a0..0000000
--- a/text_recognizer/networks/lenet.py
+++ /dev/null
@@ -1,68 +0,0 @@
-"""Implementation of the LeNet network."""
-from typing import Callable, Dict, Optional, Tuple
-
-from einops.layers.torch import Rearrange
-import torch
-from torch import nn
-
-from text_recognizer.networks.util import activation_function
-
-
-class LeNet(nn.Module):
- """LeNet network for character prediction."""
-
- def __init__(
- self,
- channels: Tuple[int, ...] = (1, 32, 64),
- kernel_sizes: Tuple[int, ...] = (3, 3, 2),
- hidden_size: Tuple[int, ...] = (9216, 128),
- dropout_rate: float = 0.2,
- num_classes: int = 10,
- activation_fn: Optional[str] = "relu",
- ) -> None:
- """Initialization of the LeNet network.
-
- Args:
- channels (Tuple[int, ...]): Channels in the convolutional layers. Defaults to (1, 32, 64).
- kernel_sizes (Tuple[int, ...]): Kernel sizes in the convolutional layers. Defaults to (3, 3, 2).
- hidden_size (Tuple[int, ...]): Size of the flattend output form the convolutional layers.
- Defaults to (9216, 128).
- dropout_rate (float): The dropout rate. Defaults to 0.2.
- num_classes (int): Number of classes. Defaults to 10.
- activation_fn (Optional[str]): The name of non-linear activation function. Defaults to relu.
-
- """
- super().__init__()
-
- activation_fn = activation_function(activation_fn)
-
- self.layers = [
- nn.Conv2d(
- in_channels=channels[0],
- out_channels=channels[1],
- kernel_size=kernel_sizes[0],
- ),
- activation_fn,
- nn.Conv2d(
- in_channels=channels[1],
- out_channels=channels[2],
- kernel_size=kernel_sizes[1],
- ),
- activation_fn,
- nn.MaxPool2d(kernel_sizes[2]),
- nn.Dropout(p=dropout_rate),
- Rearrange("b c h w -> b (c h w)"),
- nn.Linear(in_features=hidden_size[0], out_features=hidden_size[1]),
- activation_fn,
- nn.Dropout(p=dropout_rate),
- nn.Linear(in_features=hidden_size[1], out_features=num_classes),
- ]
-
- self.layers = nn.Sequential(*self.layers)
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- """The feedforward pass."""
- # If batch dimenstion is missing, it needs to be added.
- if len(x.shape) < 4:
- x = x[(None,) * (4 - len(x.shape))]
- return self.layers(x)
diff --git a/text_recognizer/networks/metrics.py b/text_recognizer/networks/metrics.py
deleted file mode 100644
index 2605731..0000000
--- a/text_recognizer/networks/metrics.py
+++ /dev/null
@@ -1,123 +0,0 @@
-"""Utility functions for models."""
-from typing import Optional
-
-from einops import rearrange
-import Levenshtein as Lev
-import torch
-from torch import Tensor
-
-from text_recognizer.networks import greedy_decoder
-
-
-def accuracy(outputs: Tensor, labels: Tensor, pad_index: int = 53) -> float:
- """Computes the accuracy.
-
- Args:
- outputs (Tensor): The output from the network.
- labels (Tensor): Ground truth labels.
- pad_index (int): Padding index.
-
- Returns:
- float: The accuracy for the batch.
-
- """
-
- _, predicted = torch.max(outputs, dim=-1)
-
- # Mask out the pad tokens
- mask = labels != pad_index
-
- predicted *= mask
- labels *= mask
-
- acc = (predicted == labels).sum().float() / labels.shape[0]
- acc = acc.item()
- return acc
-
-
-def cer(
- outputs: Tensor,
- targets: Tensor,
- batch_size: Optional[int] = None,
- blank_label: Optional[int] = int,
-) -> float:
- """Computes the character error rate.
-
- Args:
- outputs (Tensor): The output from the network.
- targets (Tensor): Ground truth labels.
- batch_size (Optional[int]): Batch size if target and output has been flattend.
- blank_label (Optional[int]): The blank character to be ignored. Defaults to 79.
-
- Returns:
- float: The cer for the batch.
-
- """
- if len(outputs.shape) == 2 and len(targets.shape) == 1 and batch_size is not None:
- targets = rearrange(targets, "(b t) -> b t", b=batch_size)
- outputs = rearrange(outputs, "(b t) v -> t b v", b=batch_size)
-
- target_lengths = torch.full(
- size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long,
- )
- decoded_predictions, decoded_targets = greedy_decoder(
- outputs, targets, target_lengths, blank_label=blank_label,
- )
-
- lev_dist = 0
-
- for prediction, target in zip(decoded_predictions, decoded_targets):
- prediction = "".join(prediction)
- target = "".join(target)
- prediction, target = (
- prediction.replace(" ", ""),
- target.replace(" ", ""),
- )
- lev_dist += Lev.distance(prediction, target)
- return lev_dist / len(decoded_predictions)
-
-
-def wer(
- outputs: Tensor,
- targets: Tensor,
- batch_size: Optional[int] = None,
- blank_label: Optional[int] = int,
-) -> float:
- """Computes the Word error rate.
-
- Args:
- outputs (Tensor): The output from the network.
- targets (Tensor): Ground truth labels.
- batch_size (optional[int]): Batch size if target and output has been flattend.
- blank_label (Optional[int]): The blank character to be ignored. Defaults to 79.
-
- Returns:
- float: The wer for the batch.
-
- """
- if len(outputs.shape) == 2 and len(targets.shape) == 1 and batch_size is not None:
- targets = rearrange(targets, "(b t) -> b t", b=batch_size)
- outputs = rearrange(outputs, "(b t) v -> t b v", b=batch_size)
-
- target_lengths = torch.full(
- size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long,
- )
- decoded_predictions, decoded_targets = greedy_decoder(
- outputs, targets, target_lengths, blank_label=blank_label,
- )
-
- lev_dist = 0
-
- for prediction, target in zip(decoded_predictions, decoded_targets):
- prediction = "".join(prediction)
- target = "".join(target)
-
- b = set(prediction.split() + target.split())
- word2char = dict(zip(b, range(len(b))))
-
- w1 = [chr(word2char[w]) for w in prediction.split()]
- w2 = [chr(word2char[w]) for w in target.split()]
-
- lev_dist += Lev.distance("".join(w1), "".join(w2))
-
- return lev_dist / len(decoded_predictions)
diff --git a/text_recognizer/networks/mlp.py b/text_recognizer/networks/mlp.py
deleted file mode 100644
index 1101912..0000000
--- a/text_recognizer/networks/mlp.py
+++ /dev/null
@@ -1,73 +0,0 @@
-"""Defines the MLP network."""
-from typing import Callable, Dict, List, Optional, Union
-
-from einops.layers.torch import Rearrange
-import torch
-from torch import nn
-
-from text_recognizer.networks.util import activation_function
-
-
-class MLP(nn.Module):
- """Multi layered perceptron network."""
-
- def __init__(
- self,
- input_size: int = 784,
- num_classes: int = 10,
- hidden_size: Union[int, List] = 128,
- num_layers: int = 3,
- dropout_rate: float = 0.2,
- activation_fn: str = "relu",
- ) -> None:
- """Initialization of the MLP network.
-
- Args:
- input_size (int): The input shape of the network. Defaults to 784.
- num_classes (int): Number of classes in the dataset. Defaults to 10.
- hidden_size (Union[int, List]): The number of `neurons` in each hidden layer. Defaults to 128.
- num_layers (int): The number of hidden layers. Defaults to 3.
- dropout_rate (float): The dropout rate at each layer. Defaults to 0.2.
- activation_fn (str): Name of the activation function in the hidden layers. Defaults to
- relu.
-
- """
- super().__init__()
-
- activation_fn = activation_function(activation_fn)
-
- if isinstance(hidden_size, int):
- hidden_size = [hidden_size] * num_layers
-
- self.layers = [
- Rearrange("b c h w -> b (c h w)"),
- nn.Linear(in_features=input_size, out_features=hidden_size[0]),
- activation_fn,
- ]
-
- for i in range(num_layers - 1):
- self.layers += [
- nn.Linear(in_features=hidden_size[i], out_features=hidden_size[i + 1]),
- activation_fn,
- ]
-
- if dropout_rate:
- self.layers.append(nn.Dropout(p=dropout_rate))
-
- self.layers.append(
- nn.Linear(in_features=hidden_size[-1], out_features=num_classes)
- )
-
- self.layers = nn.Sequential(*self.layers)
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- """The feedforward pass."""
- # If batch dimenstion is missing, it needs to be added.
- if len(x.shape) < 4:
- x = x[(None,) * (4 - len(x.shape))]
- return self.layers(x)
-
- @property
- def __name__(self) -> str:
- """Returns the name of the network."""
- return "mlp"
diff --git a/text_recognizer/networks/stn.py b/text_recognizer/networks/stn.py
deleted file mode 100644
index e9d216f..0000000
--- a/text_recognizer/networks/stn.py
+++ /dev/null
@@ -1,44 +0,0 @@
-"""Spatial Transformer Network."""
-
-from einops.layers.torch import Rearrange
-import torch
-from torch import nn
-from torch import Tensor
-import torch.nn.functional as F
-
-
-class SpatialTransformerNetwork(nn.Module):
- """A network with differentiable attention.
-
- Network that learns how to perform spatial transformations on the input image in order to enhance the
- geometric invariance of the model.
-
- # TODO: add arguments to make it more general.
-
- """
-
- def __init__(self) -> None:
- super().__init__()
- # Initialize the identity transformation and its weights and biases.
- linear = nn.Linear(32, 3 * 2)
- linear.weight.data.zero_()
- linear.bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))
-
- self.theta = nn.Sequential(
- nn.Conv2d(in_channels=1, out_channels=8, kernel_size=7),
- nn.MaxPool2d(kernel_size=2, stride=2),
- nn.ReLU(inplace=True),
- nn.Conv2d(in_channels=8, out_channels=10, kernel_size=5),
- nn.MaxPool2d(kernel_size=2, stride=2),
- nn.ReLU(inplace=True),
- Rearrange("b c h w -> b (c h w)", h=3, w=3),
- nn.Linear(in_features=10 * 3 * 3, out_features=32),
- nn.ReLU(inplace=True),
- linear,
- Rearrange("b (row col) -> b row col", row=2, col=3),
- )
-
- def forward(self, x: Tensor) -> Tensor:
- """The spatial transformation."""
- grid = F.affine_grid(self.theta(x), x.shape)
- return F.grid_sample(x, grid, align_corners=False)
diff --git a/text_recognizer/networks/unet.py b/text_recognizer/networks/unet.py
deleted file mode 100644
index 510910f..0000000
--- a/text_recognizer/networks/unet.py
+++ /dev/null
@@ -1,255 +0,0 @@
-"""UNet for segmentation."""
-from typing import List, Optional, Tuple, Union
-
-import torch
-from torch import nn
-from torch import Tensor
-
-from text_recognizer.networks.util import activation_function
-
-
-class _ConvBlock(nn.Module):
- """Modified UNet convolutional block with dilation."""
-
- def __init__(
- self,
- channels: List[int],
- activation: str,
- num_groups: int,
- dropout_rate: float = 0.1,
- kernel_size: int = 3,
- dilation: int = 1,
- padding: int = 0,
- ) -> None:
- super().__init__()
- self.channels = channels
- self.dropout_rate = dropout_rate
- self.kernel_size = kernel_size
- self.dilation = dilation
- self.padding = padding
- self.num_groups = num_groups
- self.activation = activation_function(activation)
- self.block = self._configure_block()
- self.residual_conv = nn.Sequential(
- nn.Conv2d(
- self.channels[0], self.channels[-1], kernel_size=3, stride=1, padding=1
- ),
- self.activation,
- )
-
- def _configure_block(self) -> nn.Sequential:
- block = []
- for i in range(len(self.channels) - 1):
- block += [
- nn.Dropout(p=self.dropout_rate),
- nn.GroupNorm(self.num_groups, self.channels[i]),
- self.activation,
- nn.Conv2d(
- self.channels[i],
- self.channels[i + 1],
- kernel_size=self.kernel_size,
- padding=self.padding,
- stride=1,
- dilation=self.dilation,
- ),
- ]
-
- return nn.Sequential(*block)
-
- def forward(self, x: Tensor) -> Tensor:
- """Apply the convolutional block."""
- residual = self.residual_conv(x)
- return self.block(x) + residual
-
-
-class _DownSamplingBlock(nn.Module):
- """Basic down sampling block."""
-
- def __init__(
- self,
- channels: List[int],
- activation: str,
- num_groups: int,
- pooling_kernel: Union[int, bool] = 2,
- dropout_rate: float = 0.1,
- kernel_size: int = 3,
- dilation: int = 1,
- padding: int = 0,
- ) -> None:
- super().__init__()
- self.conv_block = _ConvBlock(
- channels,
- activation,
- num_groups,
- dropout_rate,
- kernel_size,
- dilation,
- padding,
- )
- self.down_sampling = nn.MaxPool2d(pooling_kernel) if pooling_kernel else None
-
- def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
- """Return the convolutional block output and a down sampled tensor."""
- x = self.conv_block(x)
- x_down = self.down_sampling(x) if self.down_sampling is not None else x
-
- return x_down, x
-
-
-class _UpSamplingBlock(nn.Module):
- """The upsampling block of the UNet."""
-
- def __init__(
- self,
- channels: List[int],
- activation: str,
- num_groups: int,
- scale_factor: int = 2,
- dropout_rate: float = 0.1,
- kernel_size: int = 3,
- dilation: int = 1,
- padding: int = 0,
- ) -> None:
- super().__init__()
- self.conv_block = _ConvBlock(
- channels,
- activation,
- num_groups,
- dropout_rate,
- kernel_size,
- dilation,
- padding,
- )
- self.up_sampling = nn.Upsample(
- scale_factor=scale_factor, mode="bilinear", align_corners=True
- )
-
- def forward(self, x: Tensor, x_skip: Optional[Tensor] = None) -> Tensor:
- """Apply the up sampling and convolutional block."""
- x = self.up_sampling(x)
- if x_skip is not None:
- x = torch.cat((x, x_skip), dim=1)
- return self.conv_block(x)
-
-
-class UNet(nn.Module):
- """UNet architecture."""
-
- def __init__(
- self,
- in_channels: int = 1,
- base_channels: int = 64,
- num_classes: int = 3,
- depth: int = 4,
- activation: str = "relu",
- num_groups: int = 8,
- dropout_rate: float = 0.1,
- pooling_kernel: int = 2,
- scale_factor: int = 2,
- kernel_size: Optional[List[int]] = None,
- dilation: Optional[List[int]] = None,
- padding: Optional[List[int]] = None,
- ) -> None:
- super().__init__()
- self.depth = depth
- self.num_groups = num_groups
-
- if kernel_size is not None and dilation is not None and padding is not None:
- if (
- len(kernel_size) != depth
- and len(dilation) != depth
- and len(padding) != depth
- ):
- raise RuntimeError(
- "Length of convolutional parameters does not match the depth."
- )
- self.kernel_size = kernel_size
- self.padding = padding
- self.dilation = dilation
-
- else:
- self.kernel_size = [3] * depth
- self.padding = [1] * depth
- self.dilation = [1] * depth
-
- self.dropout_rate = dropout_rate
- self.conv = nn.Conv2d(
- in_channels, base_channels, kernel_size=3, stride=1, padding=1
- )
-
- channels = [base_channels] + [base_channels * 2 ** i for i in range(depth)]
- self.encoder_blocks = self._configure_down_sampling_blocks(
- channels, activation, pooling_kernel
- )
- self.decoder_blocks = self._configure_up_sampling_blocks(
- channels, activation, scale_factor
- )
-
- self.head = nn.Conv2d(base_channels, num_classes, kernel_size=1)
-
- def _configure_down_sampling_blocks(
- self, channels: List[int], activation: str, pooling_kernel: int
- ) -> nn.ModuleList:
- blocks = nn.ModuleList([])
- for i in range(len(channels) - 1):
- pooling_kernel = pooling_kernel if i < self.depth - 1 else False
- dropout_rate = self.dropout_rate if i < 0 else 0
- blocks += [
- _DownSamplingBlock(
- [channels[i], channels[i + 1], channels[i + 1]],
- activation,
- self.num_groups,
- pooling_kernel,
- dropout_rate,
- self.kernel_size[i],
- self.dilation[i],
- self.padding[i],
- )
- ]
-
- return blocks
-
- def _configure_up_sampling_blocks(
- self, channels: List[int], activation: str, scale_factor: int,
- ) -> nn.ModuleList:
- channels.reverse()
- self.kernel_size.reverse()
- self.dilation.reverse()
- self.padding.reverse()
- return nn.ModuleList(
- [
- _UpSamplingBlock(
- [channels[i] + channels[i + 1], channels[i + 1], channels[i + 1]],
- activation,
- self.num_groups,
- scale_factor,
- self.dropout_rate,
- self.kernel_size[i],
- self.dilation[i],
- self.padding[i],
- )
- for i in range(len(channels) - 2)
- ]
- )
-
- def _encode(self, x: Tensor) -> List[Tensor]:
- x_skips = []
- for block in self.encoder_blocks:
- x, x_skip = block(x)
- x_skips.append(x_skip)
- return x_skips
-
- def _decode(self, x_skips: List[Tensor]) -> Tensor:
- x = x_skips[-1]
- for i, block in enumerate(self.decoder_blocks):
- x = block(x, x_skips[-(i + 2)])
- return x
-
- def forward(self, x: Tensor) -> Tensor:
- """Forward pass with the UNet model."""
- if len(x.shape) < 4:
- x = x[(None,) * (4 - len(x.shape))]
- x = self.conv(x)
- x_skips = self._encode(x)
- x = self._decode(x_skips)
- return self.head(x)
diff --git a/text_recognizer/networks/vit.py b/text_recognizer/networks/vit.py
deleted file mode 100644
index efb3701..0000000
--- a/text_recognizer/networks/vit.py
+++ /dev/null
@@ -1,150 +0,0 @@
-"""A Vision Transformer.
-
-Inspired by:
-https://openreview.net/pdf?id=YicbFdNTTy
-
-"""
-from typing import Optional, Tuple
-
-from einops import rearrange, repeat
-import torch
-from torch import nn
-from torch import Tensor
-
-from text_recognizer.networks.transformer import Transformer
-
-
-class ViT(nn.Module):
- """Transfomer for image to sequence prediction."""
-
- def __init__(
- self,
- num_encoder_layers: int,
- num_decoder_layers: int,
- hidden_dim: int,
- vocab_size: int,
- num_heads: int,
- expansion_dim: int,
- patch_dim: Tuple[int, int],
- image_size: Tuple[int, int],
- dropout_rate: float,
- trg_pad_index: int,
- max_len: int,
- activation: str = "gelu",
- ) -> None:
- super().__init__()
-
- self.trg_pad_index = trg_pad_index
- self.patch_dim = patch_dim
- self.num_patches = image_size[-1] // self.patch_dim[1]
-
- # Encoder
- self.patch_to_embedding = nn.Linear(
- self.patch_dim[0] * self.patch_dim[1], hidden_dim
- )
- self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_dim))
- self.character_embedding = nn.Embedding(vocab_size, hidden_dim)
- self.pos_embedding = nn.Parameter(torch.randn(1, max_len, hidden_dim))
- self.dropout = nn.Dropout(dropout_rate)
- self._init()
-
- self.transformer = Transformer(
- num_encoder_layers,
- num_decoder_layers,
- hidden_dim,
- num_heads,
- expansion_dim,
- dropout_rate,
- activation,
- )
-
- self.head = nn.Sequential(nn.Linear(hidden_dim, vocab_size),)
-
- def _init(self) -> None:
- nn.init.normal_(self.character_embedding.weight, std=0.02)
- # nn.init.normal_(self.pos_embedding.weight, std=0.02)
-
- def _create_trg_mask(self, trg: Tensor) -> Tensor:
- # Move this outside the transformer.
- trg_pad_mask = (trg != self.trg_pad_index)[:, None, None]
- trg_len = trg.shape[1]
- trg_sub_mask = torch.tril(
- torch.ones((trg_len, trg_len), device=trg.device)
- ).bool()
- trg_mask = trg_pad_mask & trg_sub_mask
- return trg_mask
-
- def encoder(self, src: Tensor) -> Tensor:
- """Forward pass with the encoder of the transformer."""
- return self.transformer.encoder(src)
-
- def decoder(self, trg: Tensor, memory: Tensor, trg_mask: Tensor) -> Tensor:
- """Forward pass with the decoder of the transformer + classification head."""
- return self.head(
- self.transformer.decoder(trg=trg, memory=memory, trg_mask=trg_mask)
- )
-
- def extract_image_features(self, src: Tensor) -> Tensor:
- """Extracts image features with a backbone neural network.
-
- It seem like the winning idea was to swap channels and width dimension and collapse
- the height dimension. The transformer is learning like a baby with this implementation!!! :D
- Ohhhh, the joy I am experiencing right now!! Bring in the beers! :D :D :D
-
- Args:
- src (Tensor): Input tensor.
-
- Returns:
- Tensor: A input src to the transformer.
-
- """
- # If batch dimension is missing, it needs to be added.
- if len(src.shape) < 4:
- src = src[(None,) * (4 - len(src.shape))]
-
- patches = rearrange(
- src,
- "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
- p1=self.patch_dim[0],
- p2=self.patch_dim[1],
- )
-
- # From patches to encoded sequence.
- x = self.patch_to_embedding(patches)
- b, n, _ = x.shape
- cls_tokens = repeat(self.cls_token, "() n d -> b n d", b=b)
- x = torch.cat((cls_tokens, x), dim=1)
- x += self.pos_embedding[:, : (n + 1)]
- x = self.dropout(x)
-
- return x
-
- def target_embedding(self, trg: Tensor) -> Tuple[Tensor, Tensor]:
- """Encodes target tensor with embedding and postion.
-
- Args:
- trg (Tensor): Target tensor.
-
- Returns:
- Tuple[Tensor, Tensor]: Encoded target tensor and target mask.
-
- """
- _, n = trg.shape
- trg = self.character_embedding(trg.long())
- trg += self.pos_embedding[:, :n]
- return trg
-
- def decode_image_features(self, h: Tensor, trg: Optional[Tensor] = None) -> Tensor:
- """Takes images features from the backbone and decodes them with the transformer."""
- trg_mask = self._create_trg_mask(trg)
- trg = self.target_embedding(trg)
- out = self.transformer(h, trg, trg_mask=trg_mask)
-
- logits = self.head(out)
- return logits
-
- def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor:
- """Forward pass with CNN transfomer."""
- h = self.extract_image_features(x)
- logits = self.decode_image_features(h, trg)
- return logits
diff --git a/text_recognizer/paragraph_text_recognizer.py b/text_recognizer/paragraph_text_recognizer.py
deleted file mode 100644
index aa39662..0000000
--- a/text_recognizer/paragraph_text_recognizer.py
+++ /dev/null
@@ -1,153 +0,0 @@
-"""Full model.
-
-Takes an image and returns the text in the image, by first segmenting the image with a LineDetector, then extracting the
-each crop of the image corresponding to line regions, and feeding them to a LinePredictor model that outputs the text
-in each region.
-"""
-from typing import Dict, List, Tuple, Union
-
-import cv2
-import numpy as np
-import torch
-
-from text_recognizer.models import SegmentationModel, TransformerModel
-from text_recognizer.util import read_image
-
-
-class ParagraphTextRecognizor:
- """Given an image of a single handwritten character, recognizes it."""
-
- def __init__(self, line_predictor_args: Dict, line_detector_args: Dict) -> None:
- self._line_predictor = TransformerModel(**line_predictor_args)
- self._line_detector = SegmentationModel(**line_detector_args)
- self._line_detector.eval()
- self._line_predictor.eval()
-
- def predict(self, image_or_filename: Union[str, np.ndarray]) -> Tuple:
- """Takes an image and returns all text within it."""
- image = (
- read_image(image_or_filename)
- if isinstance(image_or_filename, str)
- else image_or_filename
- )
-
- line_region_crops = self._get_line_region_crops(image)
- processed_line_region_crops = [
- self._process_image_for_line_predictor(image=crop)
- for crop in line_region_crops
- ]
- line_region_strings = [
- self.line_predictor_model.predict_on_image(crop)[0]
- for crop in processed_line_region_crops
- ]
-
- return " ".join(line_region_strings), line_region_crops
-
- def _get_line_region_crops(
- self, image: np.ndarray, min_crop_len_factor: float = 0.02
- ) -> List[np.ndarray]:
- """Returns all the crops of text lines in a square image."""
- processed_image, scale_down_factor = self._process_image_for_line_detector(
- image
- )
- line_segmentation = self._line_detector.predict_on_image(processed_image)
- bounding_boxes = _find_line_bounding_boxes(line_segmentation)
-
- bounding_boxes = (bounding_boxes * scale_down_factor).astype(int)
-
- min_crop_len = int(min_crop_len_factor * min(image.shape[0], image.shape[1]))
- line_region_crops = [
- image[y : y + h, x : x + w]
- for x, y, w, h in bounding_boxes
- if w >= min_crop_len and h >= min_crop_len
- ]
- return line_region_crops
-
- def _process_image_for_line_detector(
- self, image: np.ndarray
- ) -> Tuple[np.ndarray, float]:
- """Convert uint8 image to float image with black background with shape self._line_detector.image_shape."""
- resized_image, scale_down_factor = _resize_image_for_line_detector(
- image=image, max_shape=self._line_detector.image_shape
- )
- resized_image = (1.0 - resized_image / 255).astype("float32")
- return resized_image, scale_down_factor
-
- def _process_image_for_line_predictor(self, image: np.ndarray) -> np.ndarray:
- """Preprocessing of image before feeding it to the LinePrediction model.
-
- Convert uint8 image to float image with black background with shape
- self._line_predictor.image_shape while maintaining the image aspect ratio.
-
- Args:
- image (np.ndarray): Crop of text line.
-
- Returns:
- np.ndarray: Processed crop for feeding line predictor.
- """
- expected_shape = self._line_detector.image_shape
- scale_factor = (np.array(expected_shape) / np.array(image.shape)).min()
- scaled_image = cv2.resize(
- image,
- dsize=None,
- fx=scale_factor,
- fy=scale_factor,
- interpolation=cv2.INTER_AREA,
- )
-
- pad_with = (
- (0, expected_shape[0] - scaled_image.shape[0]),
- (0, expected_shape[1] - scaled_image.shape[1]),
- )
-
- padded_image = np.pad(
- scaled_image, pad_with=pad_with, mode="constant", constant_values=255
- )
- return 1 - padded_image / 255
-
-
-def _find_line_bounding_boxes(line_segmentation: np.ndarray) -> np.ndarray:
- """Given a line segmentation, find bounding boxes for connected-component regions corresponding to non-0 labels."""
-
- def _find_line_bounding_boxes_in_channel(
- line_segmentation_channel: np.ndarray,
- ) -> np.ndarray:
- line_segmentation_image = cv2.dilate(
- line_segmentation_channel, kernel=np.ones((3, 3)), iterations=1
- )
- line_activation_image = (line_segmentation_image * 255).astype("uint8")
- line_activation_image = cv2.threshold(
- line_activation_image, 0.5, 1, cv2.THRESH_BINARY | cv2.THRESH_OTSU
- )[1]
-
- bounding_cnts, _ = cv2.findContours(
- line_segmentation_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
- )
- return np.array([cv2.boundingRect(cnt) for cnt in bounding_cnts])
-
- bounding_boxes = np.concatenate(
- [
- _find_line_bounding_boxes_in_channel(line_segmentation[:, :, i])
- for i in [1, 2]
- ],
- axis=0,
- )
-
- return bounding_boxes[np.argsort(bounding_boxes[:, 1])]
-
-
-def _resize_image_for_line_detector(
- image: np.ndarray, max_shape: Tuple[int, int]
-) -> Tuple[np.ndarray, float]:
- """Resize the image to less than the max_shape while maintaining the aspect ratio."""
- scale_down_factor = max(np.ndarray(image.shape) / np.ndarray(max_shape))
- if scale_down_factor == 1:
- return image.copy(), scale_down_factor
- resize_image = cv2.resize(
- image,
- dsize=None,
- fx=1 / scale_down_factor,
- fy=1 / scale_down_factor,
- interpolation=cv2.INTER_AREA,
- )
- return resize_image, scale_down_factor