28 files changed, 14 insertions, 3124 deletions
diff --git a/poetry.lock b/poetry.lock index fa374b6..6926f1f 100644 --- a/poetry.lock +++ b/poetry.lock @@ -853,6 +853,14 @@ win32-setctime = {version = ">=1.0.0", markers = "sys_platform == \"win32\""} dev = ["codecov (>=2.0.15)", "colorama (>=0.3.4)", "flake8 (>=3.7.7)", "tox (>=3.9.0)", "tox-travis (>=0.12)", "pytest (>=4.6.2)", "pytest-cov (>=2.7.1)", "Sphinx (>=2.2.1)", "sphinx-autobuild (>=0.7.1)", "sphinx-rtd-theme (>=0.4.3)", "black (>=19.10b0)", "isort (>=5.1.1)"] [[package]] +name = "madgrad" +version = "1.0" +description = "A general purpose PyTorch Optimizer" +category = "main" +optional = false +python-versions = ">=3.6" + +[[package]] name = "markdown" version = "3.3.4" description = "Python implementation of Markdown." @@ -2162,7 +2170,7 @@ multidict = ">=4.0" [metadata] lock-version = "1.1" python-versions = "^3.8" -content-hash = "cffb5a23a46f3be6be0b8ea8289cfb97c8ba9722f869bbfc2af75cb80a737877" +content-hash = "db4253add1258abaf637f127ba576b1ec5e0b415c8f7f93b18ecdca40bc6f042" [metadata.files] absl-py = [ @@ -2641,6 +2649,10 @@ loguru = [ {file = "loguru-0.5.3-py3-none-any.whl", hash = "sha256:f8087ac396b5ee5f67c963b495d615ebbceac2796379599820e324419d53667c"}, {file = "loguru-0.5.3.tar.gz", hash = "sha256:b28e72ac7a98be3d28ad28570299a393dfcd32e5e3f6a353dec94675767b6319"}, ] +madgrad = [ + {file = "madgrad-1.0-py3-none-any.whl", hash = "sha256:cd5239a1274ee025abec14c99d2af06b11783a379da32cbe2f4b07fc81ef20ea"}, + {file = "madgrad-1.0.tar.gz", hash = "sha256:5a34e1d295ebb2f85fbf9e09ed3b548e27908471bbe2506dda35de5a471c0cbe"}, +] markdown = [ {file = "Markdown-3.3.4-py3-none-any.whl", hash = "sha256:96c3ba1261de2f7547b46a00ea8463832c921d3f9d6aba3f255a6f71386db20c"}, {file = "Markdown-3.3.4.tar.gz", hash = "sha256:31b5b491868dcc87d6c24b7e3d19a0d730d59d3e46f4eea6430a321bed387a49"}, diff --git a/pyproject.toml b/pyproject.toml index e791dd9..32bdb0e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,7 @@ gtn = "^0.0.0" sentencepiece = "^0.1.95" pytorch-lightning = "^1.2.4" Pillow = "^8.1.2" +madgrad = "^1.0" [tool.poetry.dev-dependencies] pytest = "^5.4.2" diff --git a/text_recognizer/__init__.py b/text_recognizer/__init__.py index 3dc1f76..e69de29 100644 --- a/text_recognizer/__init__.py +++ b/text_recognizer/__init__.py @@ -1 +0,0 @@ -__version__ = "0.1.0" diff --git a/text_recognizer/character_predictor.py b/text_recognizer/character_predictor.py deleted file mode 100644 index ad71289..0000000 --- a/text_recognizer/character_predictor.py +++ /dev/null @@ -1,29 +0,0 @@ -"""CharacterPredictor class.""" -from typing import Dict, Tuple, Type, Union - -import numpy as np -from torch import nn - -from text_recognizer import datasets, networks -from text_recognizer.models import CharacterModel -from text_recognizer.util import read_image - - -class CharacterPredictor: - """Recognizes the character in handwritten character images.""" - - def __init__(self, network_fn: str, dataset: str) -> None: - """Intializes the CharacterModel and load the pretrained weights.""" - network_fn = getattr(networks, network_fn) - dataset = getattr(datasets, dataset) - self.model = CharacterModel(network_fn=network_fn, dataset=dataset) - self.model.eval() - self.model.use_swa_model() - - def predict(self, image_or_filename: Union[np.ndarray, str]) -> Tuple[str, float]: - """Predict on a single images contianing a handwritten character.""" - if isinstance(image_or_filename, str): - image = read_image(image_or_filename, grayscale=True) - else: - image = image_or_filename - return 
self.model.predict_on_image(image) diff --git a/text_recognizer/data/iam_paragraphs_dataset.py b/text_recognizer/data/iam_paragraphs_dataset.py deleted file mode 100644 index 8ba5142..0000000 --- a/text_recognizer/data/iam_paragraphs_dataset.py +++ /dev/null @@ -1,291 +0,0 @@ -"""IamParagraphsDataset class and functions for data processing.""" -import random -from typing import Callable, Dict, List, Optional, Tuple, Union - -import click -import cv2 -import h5py -from loguru import logger -import numpy as np -import torch -from torch import Tensor -from torchvision.transforms import ToTensor - -from text_recognizer import util -from text_recognizer.datasets.dataset import Dataset -from text_recognizer.datasets.iam_dataset import IamDataset -from text_recognizer.datasets.util import ( - compute_sha256, - DATA_DIRNAME, - download_url, - EmnistMapper, -) - -INTERIM_DATA_DIRNAME = DATA_DIRNAME / "interim" / "iam_paragraphs" -DEBUG_CROPS_DIRNAME = INTERIM_DATA_DIRNAME / "debug_crops" -PROCESSED_DATA_DIRNAME = DATA_DIRNAME / "processed" / "iam_paragraphs" -CROPS_DIRNAME = PROCESSED_DATA_DIRNAME / "crops" -GT_DIRNAME = PROCESSED_DATA_DIRNAME / "gt" - -PARAGRAPH_BUFFER = 50 # Pixels in the IAM form images to leave around the lines. -TEST_FRACTION = 0.2 -SEED = 4711 - - -class IamParagraphsDataset(Dataset): - """IAM Paragraphs dataset for paragraphs of handwritten text.""" - - def __init__( - self, - train: bool = False, - subsample_fraction: float = None, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - ) -> None: - super().__init__( - train=train, - subsample_fraction=subsample_fraction, - transform=transform, - target_transform=target_transform, - ) - # Load Iam dataset. - self.iam_dataset = IamDataset() - - self.num_classes = 3 - self._input_shape = (256, 256) - self._output_shape = self._input_shape + (self.num_classes,) - self._ids = None - - def __getitem__(self, index: Union[Tensor, int]) -> Tuple[Tensor, Tensor]: - """Fetches data, target pair of the dataset for a given and index or indices. - - Args: - index (Union[int, Tensor]): Either a list or int of indices/index. - - Returns: - Tuple[Tensor, Tensor]: Data target pair. 
- - """ - if torch.is_tensor(index): - index = index.tolist() - - data = self.data[index] - targets = self.targets[index] - - seed = np.random.randint(SEED) - random.seed(seed) # apply this seed to target tranfsorms - torch.manual_seed(seed) # needed for torchvision 0.7 - if self.transform: - data = self.transform(data) - - random.seed(seed) # apply this seed to target tranfsorms - torch.manual_seed(seed) # needed for torchvision 0.7 - if self.target_transform: - targets = self.target_transform(targets) - - return data, targets.long() - - @property - def ids(self) -> Tensor: - """Ids of the dataset.""" - return self._ids - - def get_data_and_target_from_id(self, id_: str) -> Tuple[Tensor, Tensor]: - """Get data target pair from id.""" - ind = self.ids.index(id_) - return self.data[ind], self.targets[ind] - - def load_or_generate_data(self) -> None: - """Load or generate dataset data.""" - num_actual = len(list(CROPS_DIRNAME.glob("*.jpg"))) - num_targets = len(self.iam_dataset.line_regions_by_id) - - if num_actual < num_targets - 2: - self._process_iam_paragraphs() - - self._data, self._targets, self._ids = _load_iam_paragraphs() - self._get_random_split() - self._subsample() - - def _get_random_split(self) -> None: - np.random.seed(SEED) - num_train = int((1 - TEST_FRACTION) * self.data.shape[0]) - indices = np.random.permutation(self.data.shape[0]) - train_indices, test_indices = indices[:num_train], indices[num_train:] - if self.train: - self._data = self.data[train_indices] - self._targets = self.targets[train_indices] - else: - self._data = self.data[test_indices] - self._targets = self.targets[test_indices] - - def _process_iam_paragraphs(self) -> None: - """Crop the part with the text. - - For each page, crop out the part of it that correspond to the paragraph of text, and make sure all crops are - self.input_shape. The ground truth data is the same size, with a one-hot vector at each pixel - corresponding to labels 0=background, 1=odd-numbered line, 2=even-numbered line - """ - crop_dims = self._decide_on_crop_dims() - CROPS_DIRNAME.mkdir(parents=True, exist_ok=True) - DEBUG_CROPS_DIRNAME.mkdir(parents=True, exist_ok=True) - GT_DIRNAME.mkdir(parents=True, exist_ok=True) - logger.info( - f"Cropping paragraphs, generating ground truth, and saving debugging images to {DEBUG_CROPS_DIRNAME}" - ) - for filename in self.iam_dataset.form_filenames: - id_ = filename.stem - line_region = self.iam_dataset.line_regions_by_id[id_] - _crop_paragraph_image(filename, line_region, crop_dims, self.input_shape) - - def _decide_on_crop_dims(self) -> Tuple[int, int]: - """Decide on the dimensions to crop out of the form image. - - Since image width is larger than a comfortable crop around the longest paragraph, - we will make the crop a square form factor. - And since the found dimensions 610x610 are pretty close to 512x512, - we might as well resize crops and make it exactly that, which lets us - do all kinds of power-of-2 pooling and upsampling should we choose to. - - Returns: - Tuple[int, int]: A tuple of crop dimensions. - - Raises: - RuntimeError: When max crop height is larger than max crop width. 
- - """ - - sample_form_filename = self.iam_dataset.form_filenames[0] - sample_image = util.read_image(sample_form_filename, grayscale=True) - max_crop_width = sample_image.shape[1] - max_crop_height = _get_max_paragraph_crop_height( - self.iam_dataset.line_regions_by_id - ) - if not max_crop_height <= max_crop_width: - raise RuntimeError( - f"Max crop height is larger then max crop width: {max_crop_height} >= {max_crop_width}" - ) - - crop_dims = (max_crop_width, max_crop_width) - logger.info( - f"Max crop width and height were found to be {max_crop_width}x{max_crop_height}." - ) - logger.info(f"Setting them to {max_crop_width}x{max_crop_width}") - return crop_dims - - def __repr__(self) -> str: - """Return info about the dataset.""" - return ( - "IAM Paragraph Dataset\n" # pylint: disable=no-member - f"Num classes: {self.num_classes}\n" - f"Data: {self.data.shape}\n" - f"Targets: {self.targets.shape}\n" - ) - - -def _get_max_paragraph_crop_height(line_regions_by_id: Dict) -> int: - heights = [] - for regions in line_regions_by_id.values(): - min_y1 = min(region["y1"] for region in regions) - PARAGRAPH_BUFFER - max_y2 = max(region["y2"] for region in regions) + PARAGRAPH_BUFFER - height = max_y2 - min_y1 - heights.append(height) - return max(heights) - - -def _crop_paragraph_image( - filename: str, line_regions: Dict, crop_dims: Tuple[int, int], final_dims: Tuple -) -> None: - image = util.read_image(filename, grayscale=True) - - min_y1 = min(region["y1"] for region in line_regions) - PARAGRAPH_BUFFER - max_y2 = max(region["y2"] for region in line_regions) + PARAGRAPH_BUFFER - height = max_y2 - min_y1 - crop_height = crop_dims[0] - buffer = (crop_height - height) // 2 - - # Generate image crop. - image_crop = 255 * np.ones(crop_dims, dtype=np.uint8) - try: - image_crop[buffer : buffer + height] = image[min_y1:max_y2] - except Exception as e: # pylint: disable=broad-except - logger.error(f"Rescued {filename}: {e}") - return - - # Generate ground truth. - gt_image = np.zeros_like(image_crop, dtype=np.uint8) - for index, region in enumerate(line_regions): - gt_image[ - (region["y1"] - min_y1 + buffer) : (region["y2"] - min_y1 + buffer), - region["x1"] : region["x2"], - ] = (index % 2 + 1) - - # Generate image for debugging. 
- import matplotlib.pyplot as plt - - cmap = plt.get_cmap("Set1") - image_crop_for_debug = np.dstack([image_crop, image_crop, image_crop]) - for index, region in enumerate(line_regions): - color = [255 * _ for _ in cmap(index)[:-1]] - cv2.rectangle( - image_crop_for_debug, - (region["x1"], region["y1"] - min_y1 + buffer), - (region["x2"], region["y2"] - min_y1 + buffer), - color, - 3, - ) - image_crop_for_debug = cv2.resize( - image_crop_for_debug, final_dims, interpolation=cv2.INTER_AREA - ) - util.write_image(image_crop_for_debug, DEBUG_CROPS_DIRNAME / f"{filename.stem}.jpg") - - image_crop = cv2.resize(image_crop, final_dims, interpolation=cv2.INTER_AREA) - util.write_image(image_crop, CROPS_DIRNAME / f"{filename.stem}.jpg") - - gt_image = cv2.resize(gt_image, final_dims, interpolation=cv2.INTER_NEAREST) - util.write_image(gt_image, GT_DIRNAME / f"{filename.stem}.png") - - -def _load_iam_paragraphs() -> None: - logger.info("Loading IAM paragraph crops and ground truth from image files...") - images = [] - gt_images = [] - ids = [] - for filename in CROPS_DIRNAME.glob("*.jpg"): - id_ = filename.stem - image = util.read_image(filename, grayscale=True) - image = 1.0 - image / 255 - - gt_filename = GT_DIRNAME / f"{id_}.png" - gt_image = util.read_image(gt_filename, grayscale=True) - - images.append(image) - gt_images.append(gt_image) - ids.append(id_) - images = np.array(images).astype(np.float32) - gt_images = np.array(gt_images).astype(np.uint8) - ids = np.array(ids) - return images, gt_images, ids - - -@click.command() -@click.option( - "--subsample_fraction", - type=float, - default=None, - help="The subsampling factor of the dataset.", -) -def main(subsample_fraction: float) -> None: - """Load dataset and print info.""" - logger.info("Creating train set...") - dataset = IamParagraphsDataset(train=True, subsample_fraction=subsample_fraction) - dataset.load_or_generate_data() - print(dataset) - logger.info("Creating test set...") - dataset = IamParagraphsDataset(subsample_fraction=subsample_fraction) - dataset.load_or_generate_data() - print(dataset) - - -if __name__ == "__main__": - main() diff --git a/text_recognizer/data/util.py b/text_recognizer/data/util.py deleted file mode 100644 index da87756..0000000 --- a/text_recognizer/data/util.py +++ /dev/null @@ -1,209 +0,0 @@ -"""Util functions for datasets.""" -import hashlib -import json -import os -from pathlib import Path -import string -from typing import Dict, List, Optional, Union -from urllib.request import urlretrieve - -from loguru import logger -import numpy as np -import torch -from torch import Tensor -from torchvision.datasets import EMNIST -from tqdm import tqdm - -DATA_DIRNAME = Path(__file__).resolve().parents[3] / "data" -ESSENTIALS_FILENAME = Path(__file__).resolve().parents[0] / "emnist_essentials.json" - - -def save_emnist_essentials(emnsit_dataset: EMNIST = EMNIST) -> None: - """Extract and saves EMNIST essentials.""" - labels = emnsit_dataset.classes - labels.sort() - mapping = [(i, str(label)) for i, label in enumerate(labels)] - essentials = { - "mapping": mapping, - "input_shape": tuple(np.array(emnsit_dataset[0][0]).shape[:]), - } - logger.info("Saving emnist essentials...") - with open(ESSENTIALS_FILENAME, "w") as f: - json.dump(essentials, f) - - -def download_emnist() -> None: - """Download the EMNIST dataset via the PyTorch class.""" - logger.info(f"Data directory is: {DATA_DIRNAME}") - dataset = EMNIST(root=DATA_DIRNAME, split="byclass", download=True) - save_emnist_essentials(dataset) - - -class 
EmnistMapper: - """Mapper between network output to Emnist character.""" - - def __init__( - self, - pad_token: str, - init_token: Optional[str] = None, - eos_token: Optional[str] = None, - lower: bool = False, - ) -> None: - """Loads the emnist essentials file with the mapping and input shape.""" - self.init_token = init_token - self.pad_token = pad_token - self.eos_token = eos_token - self.lower = lower - - self.essentials = self._load_emnist_essentials() - # Load dataset information. - self._mapping = dict(self.essentials["mapping"]) - self._augment_emnist_mapping() - self._inverse_mapping = {v: k for k, v in self.mapping.items()} - self._num_classes = len(self.mapping) - self._input_shape = self.essentials["input_shape"] - - def __call__(self, token: Union[str, int, np.uint8, Tensor]) -> Union[str, int]: - """Maps the token to emnist character or character index. - - If the token is an integer (index), the method will return the Emnist character corresponding to that index. - If the token is a str (Emnist character), the method will return the corresponding index for that character. - - Args: - token (Union[str, int, np.uint8, Tensor]): Either a string or index (integer). - - Returns: - Union[str, int]: The mapping result. - - Raises: - KeyError: If the index or string does not exist in the mapping. - - """ - if ( - (isinstance(token, np.uint8) or isinstance(token, int)) - or torch.is_tensor(token) - and int(token) in self.mapping - ): - return self.mapping[int(token)] - elif isinstance(token, str) and token in self._inverse_mapping: - return self._inverse_mapping[token] - else: - raise KeyError(f"Token {token} does not exist in the mappings.") - - @property - def mapping(self) -> Dict: - """Returns the mapping between index and character.""" - return self._mapping - - @property - def inverse_mapping(self) -> Dict: - """Returns the mapping between character and index.""" - return self._inverse_mapping - - @property - def num_classes(self) -> int: - """Returns the number of classes in the dataset.""" - return self._num_classes - - @property - def input_shape(self) -> List[int]: - """Returns the input shape of the Emnist characters.""" - return self._input_shape - - def _load_emnist_essentials(self) -> Dict: - """Load the EMNIST mapping.""" - with open(str(ESSENTIALS_FILENAME)) as f: - essentials = json.load(f) - return essentials - - def _augment_emnist_mapping(self) -> None: - """Augment the mapping with extra symbols.""" - # Extra symbols in IAM dataset - if self.lower: - self._mapping = { - k: str(v) - for k, v in enumerate(list(range(10)) + list(string.ascii_lowercase)) - } - - extra_symbols = [ - " ", - "!", - '"', - "#", - "&", - "'", - "(", - ")", - "*", - "+", - ",", - "-", - ".", - "/", - ":", - ";", - "?", - ] - - # padding symbol, and acts as blank symbol as well. - extra_symbols.append(self.pad_token) - - if self.init_token is not None: - extra_symbols.append(self.init_token) - - if self.eos_token is not None: - extra_symbols.append(self.eos_token) - - max_key = max(self.mapping.keys()) - extra_mapping = {} - for i, symbol in enumerate(extra_symbols): - extra_mapping[max_key + 1 + i] = symbol - - self._mapping = {**self.mapping, **extra_mapping} - - -def compute_sha256(filename: Union[Path, str]) -> str: - """Returns the SHA256 checksum of a file.""" - with open(filename, "rb") as f: - return hashlib.sha256(f.read()).hexdigest() - - -class TqdmUpTo(tqdm): - """TQDM progress bar when downloading files. 
- - From https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py - - """ - - def update_to( - self, blocks: int = 1, block_size: int = 1, total_size: Optional[int] = None - ) -> None: - """Updates the progress bar. - - Args: - blocks (int): Number of blocks transferred so far. Defaults to 1. - block_size (int): Size of each block, in tqdm units. Defaults to 1. - total_size (Optional[int]): Total size in tqdm units. Defaults to None. - """ - if total_size is not None: - self.total = total_size # pylint: disable=attribute-defined-outside-init - self.update(blocks * block_size - self.n) - - -def download_url(url: str, filename: str) -> None: - """Downloads a file from url to filename, with a progress bar.""" - with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t: - urlretrieve(url, filename, reporthook=t.update_to, data=None) # nosec - - -def _download_raw_dataset(metadata: Dict) -> None: - if os.path.exists(metadata["filename"]): - return - logger.info(f"Downloading raw dataset from {metadata['url']}...") - download_url(metadata["url"], metadata["filename"]) - logger.info("Computing SHA-256...") - sha256 = compute_sha256(metadata["filename"]) - if sha256 != metadata["sha256"]: - raise ValueError( - "Downloaded data file SHA-256 does not match that listed in metadata document." - ) diff --git a/text_recognizer/line_predictor.py b/text_recognizer/line_predictor.py deleted file mode 100644 index 8e348fe..0000000 --- a/text_recognizer/line_predictor.py +++ /dev/null @@ -1,28 +0,0 @@ -"""LinePredictor class.""" -import importlib -from typing import Tuple, Union - -import numpy as np -from torch import nn - -from text_recognizer import datasets, networks -from text_recognizer.models import TransformerModel -from text_recognizer.util import read_image - - -class LinePredictor: - """Given an image of a line of handwritten text, recognizes the text content.""" - - def __init__(self, dataset: str, network_fn: str) -> None: - network_fn = getattr(networks, network_fn) - dataset = getattr(datasets, dataset) - self.model = TransformerModel(network_fn=network_fn, dataset=dataset) - self.model.eval() - - def predict(self, image_or_filename: Union[np.ndarray, str]) -> Tuple[str, float]: - """Predict on a single images contianing a handwritten character.""" - if isinstance(image_or_filename, str): - image = read_image(image_or_filename, grayscale=True) - else: - image = image_or_filename - return self.model.predict_on_image(image) diff --git a/text_recognizer/models/__init__.py b/text_recognizer/models/__init__.py index 7647d7e..e69de29 100644 --- a/text_recognizer/models/__init__.py +++ b/text_recognizer/models/__init__.py @@ -1,18 +0,0 @@ -"""Model modules.""" -from .base import Model -from .character_model import CharacterModel -from .crnn_model import CRNNModel -from .ctc_transformer_model import CTCTransformerModel -from .segmentation_model import SegmentationModel -from .transformer_model import TransformerModel -from .vqvae_model import VQVAEModel - -__all__ = [ - "CharacterModel", - "CRNNModel", - "CTCTransformerModel", - "Model", - "SegmentationModel", - "TransformerModel", - "VQVAEModel", -] diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py deleted file mode 100644 index 70f4cdb..0000000 --- a/text_recognizer/models/base.py +++ /dev/null @@ -1,455 +0,0 @@ -"""Abstract Model class for PyTorch neural networks.""" - -from abc import ABC, abstractmethod -from glob import glob -import importlib -from pathlib import Path -import re -import shutil 
-from typing import Callable, Dict, List, Optional, Tuple, Type, Union - -from loguru import logger -import torch -from torch import nn -from torch import Tensor -from torch.optim.swa_utils import AveragedModel, SWALR -from torch.utils.data import DataLoader, Dataset, random_split -from torchsummary import summary - -from text_recognizer import datasets -from text_recognizer import networks -from text_recognizer.datasets import EmnistMapper - -WEIGHT_DIRNAME = Path(__file__).parents[1].resolve() / "weights" - - -class Model(ABC): - """Abstract Model class with composition of different parts defining a PyTorch neural network.""" - - def __init__( - self, - network_fn: str, - dataset: str, - network_args: Optional[Dict] = None, - dataset_args: Optional[Dict] = None, - metrics: Optional[Dict] = None, - criterion: Optional[Callable] = None, - criterion_args: Optional[Dict] = None, - optimizer: Optional[Callable] = None, - optimizer_args: Optional[Dict] = None, - lr_scheduler: Optional[Callable] = None, - lr_scheduler_args: Optional[Dict] = None, - swa_args: Optional[Dict] = None, - device: Optional[str] = None, - ) -> None: - """Base class, to be inherited by model for specific type of data. - - Args: - network_fn (str): The name of network. - dataset (str): The name dataset class. - network_args (Optional[Dict]): Arguments for the network. Defaults to None. - dataset_args (Optional[Dict]): Arguments for the dataset. - metrics (Optional[Dict]): Metrics to evaluate the performance with. Defaults to None. - criterion (Optional[Callable]): The criterion to evaluate the performance of the network. - Defaults to None. - criterion_args (Optional[Dict]): Dict of arguments for criterion. Defaults to None. - optimizer (Optional[Callable]): The optimizer for updating the weights. Defaults to None. - optimizer_args (Optional[Dict]): Dict of arguments for optimizer. Defaults to None. - lr_scheduler (Optional[Callable]): A PyTorch learning rate scheduler. Defaults to None. - lr_scheduler_args (Optional[Dict]): Dict of arguments for learning rate scheduler. Defaults to - None. - swa_args (Optional[Dict]): Dict of arguments for stochastic weight averaging. Defaults to - None. - device (Optional[str]): Name of the device to train on. Defaults to None. - - """ - self._name = f"{self.__class__.__name__}_{dataset}_{network_fn}" - # Has to be set in subclass. - self._mapper = None - - # Placeholder. - self._input_shape = None - - self.dataset_name = dataset - self.dataset = None - self.dataset_args = dataset_args - - # Placeholders for datasets. - self.train_dataset = None - self.val_dataset = None - self.test_dataset = None - - # Stochastic Weight Averaging placeholders. - self.swa_args = swa_args - self._swa_scheduler = None - self._swa_network = None - self._use_swa_model = False - - # Experiment directory. - self.model_dir = None - - # Flag for configured model. - self.is_configured = False - self.data_prepared = False - - # Flag for stopping training. - self.stop_training = False - - self._metrics = metrics if metrics is not None else None - - # Set the device. - self._device = ( - torch.device("cuda" if torch.cuda.is_available() else "cpu") - if device is None - else device - ) - - # Configure network. - self._network = None - self._network_args = network_args - self._configure_network(network_fn) - - # Place network on device (GPU). - self.to_device() - - # Loss and Optimizer placeholders for before loading. 
- self._criterion = criterion - self.criterion_args = criterion_args - - self._optimizer = optimizer - self.optimizer_args = optimizer_args - - self._lr_scheduler = lr_scheduler - self.lr_scheduler_args = lr_scheduler_args - - def configure_model(self) -> None: - """Configures criterion and optimizers.""" - if not self.is_configured: - self._configure_criterion() - self._configure_optimizers() - - # Set this flag to true to prevent the model from configuring again. - self.is_configured = True - - def prepare_data(self) -> None: - """Prepare data for training.""" - # TODO add downloading. - if not self.data_prepared: - # Load dataset module. - self.dataset = getattr(datasets, self.dataset_name) - - # Load train dataset. - train_dataset = self.dataset(train=True, **self.dataset_args["args"]) - train_dataset.load_or_generate_data() - - # Set input shape. - self._input_shape = train_dataset.input_shape - - # Split train dataset into a training and validation partition. - dataset_len = len(train_dataset) - train_len = int( - self.dataset_args["train_args"]["train_fraction"] * dataset_len - ) - val_len = dataset_len - train_len - self.train_dataset, self.val_dataset = random_split( - train_dataset, lengths=[train_len, val_len] - ) - - # Load test dataset. - self.test_dataset = self.dataset(train=False, **self.dataset_args["args"]) - self.test_dataset.load_or_generate_data() - - # Set the flag to true to disable ability to load data again. - self.data_prepared = True - - def train_dataloader(self) -> DataLoader: - """Returns data loader for training set.""" - return DataLoader( - self.train_dataset, - batch_size=self.dataset_args["train_args"]["batch_size"], - num_workers=self.dataset_args["train_args"]["num_workers"], - shuffle=True, - pin_memory=True, - ) - - def val_dataloader(self) -> DataLoader: - """Returns data loader for validation set.""" - return DataLoader( - self.val_dataset, - batch_size=self.dataset_args["train_args"]["batch_size"], - num_workers=self.dataset_args["train_args"]["num_workers"], - shuffle=True, - pin_memory=True, - ) - - def test_dataloader(self) -> DataLoader: - """Returns data loader for test set.""" - return DataLoader( - self.test_dataset, - batch_size=self.dataset_args["train_args"]["batch_size"], - num_workers=self.dataset_args["train_args"]["num_workers"], - shuffle=False, - pin_memory=True, - ) - - def _configure_network(self, network_fn: Type[nn.Module]) -> None: - """Loads the network.""" - # If no network arguments are given, load pretrained weights if they exist. - # Load network module. - network_fn = getattr(networks, network_fn) - if self._network_args is None: - self.load_weights(network_fn) - else: - self._network = network_fn(**self._network_args) - - def _configure_criterion(self) -> None: - """Loads the criterion.""" - self._criterion = ( - self._criterion(**self.criterion_args) - if self._criterion is not None - else None - ) - - def _configure_optimizers(self,) -> None: - """Loads the optimizers.""" - if self._optimizer is not None: - self._optimizer = self._optimizer( - self._network.parameters(), **self.optimizer_args - ) - else: - self._optimizer = None - - if self._optimizer and self._lr_scheduler is not None: - if "steps_per_epoch" in self.lr_scheduler_args: - self.lr_scheduler_args["steps_per_epoch"] = len(self.train_dataloader()) - - # Assume lr scheduler should update at each epoch if not specified. 
- if "interval" not in self.lr_scheduler_args: - interval = "epoch" - else: - interval = self.lr_scheduler_args.pop("interval") - self._lr_scheduler = { - "lr_scheduler": self._lr_scheduler( - self._optimizer, **self.lr_scheduler_args - ), - "interval": interval, - } - - if self.swa_args is not None: - self._swa_scheduler = { - "swa_scheduler": SWALR(self._optimizer, swa_lr=self.swa_args["lr"]), - "swa_start": self.swa_args["start"], - } - self._swa_network = AveragedModel(self._network).to(self.device) - - @property - def name(self) -> str: - """Returns the name of the model.""" - return self._name - - @property - def input_shape(self) -> Tuple[int, ...]: - """The input shape.""" - return self._input_shape - - @property - def mapper(self) -> EmnistMapper: - """Returns the mapper that maps between ints and chars.""" - return self._mapper - - @property - def mapping(self) -> Dict: - """Returns the mapping between network output and Emnist character.""" - return self._mapper.mapping if self._mapper is not None else None - - def eval(self) -> None: - """Sets the network to evaluation mode.""" - self._network.eval() - - def train(self) -> None: - """Sets the network to train mode.""" - self._network.train() - - @property - def device(self) -> str: - """Device where the weights are stored, i.e. cpu or cuda.""" - return self._device - - @property - def metrics(self) -> Optional[Dict]: - """Metrics.""" - return self._metrics - - @property - def criterion(self) -> Optional[Callable]: - """Criterion.""" - return self._criterion - - @property - def optimizer(self) -> Optional[Callable]: - """Optimizer.""" - return self._optimizer - - @property - def lr_scheduler(self) -> Optional[Dict]: - """Returns a directory with the learning rate scheduler.""" - return self._lr_scheduler - - @property - def swa_scheduler(self) -> Optional[Dict]: - """Returns a directory with the stochastic weight averaging scheduler.""" - return self._swa_scheduler - - @property - def swa_network(self) -> Optional[Callable]: - """Returns the stochastic weight averaging network.""" - return self._swa_network - - @property - def network(self) -> Type[nn.Module]: - """Neural network.""" - # Returns the SWA network if available. 
- return self._network - - @property - def weights_filename(self) -> str: - """Filepath to the network weights.""" - WEIGHT_DIRNAME.mkdir(parents=True, exist_ok=True) - return str(WEIGHT_DIRNAME / f"{self._name}_weights.pt") - - def use_swa_model(self) -> None: - """Set to use predictions from SWA model.""" - if self.swa_network is not None: - self._use_swa_model = True - - def forward(self, x: Tensor) -> Tensor: - """Feedforward pass with the network.""" - if self._use_swa_model: - return self.swa_network(x) - else: - return self.network(x) - - def summary( - self, - input_shape: Optional[Union[List, Tuple]] = None, - depth: int = 3, - device: Optional[str] = None, - ) -> None: - """Prints a summary of the network architecture.""" - device = self.device if device is None else device - - if input_shape is not None: - summary(self.network, input_shape, depth=depth, device=device) - elif self._input_shape is not None: - input_shape = tuple(self._input_shape) - summary(self.network, input_shape, depth=depth, device=device) - else: - logger.warning("Could not print summary as input shape is not set.") - - def to_device(self) -> None: - """Places the network on the device (GPU).""" - self._network.to(self._device) - - def _get_state_dict(self) -> Dict: - """Get the state dict of the model.""" - state = {"model_state": self._network.state_dict()} - - if self._optimizer is not None: - state["optimizer_state"] = self._optimizer.state_dict() - - if self._lr_scheduler is not None: - state["scheduler_state"] = self._lr_scheduler["lr_scheduler"].state_dict() - state["scheduler_interval"] = self._lr_scheduler["interval"] - - if self._swa_network is not None: - state["swa_network"] = self._swa_network.state_dict() - - return state - - def load_from_checkpoint(self, checkpoint_path: Union[str, Path]) -> None: - """Load a previously saved checkpoint. - - Args: - checkpoint_path (Path): Path to the experiment with the checkpoint. - - """ - checkpoint_path = Path(checkpoint_path) - self.prepare_data() - self.configure_model() - logger.debug("Loading checkpoint...") - if not checkpoint_path.exists(): - logger.debug("File does not exist {str(checkpoint_path)}") - - checkpoint = torch.load(str(checkpoint_path), map_location=self.device) - self._network.load_state_dict(checkpoint["model_state"]) - - if self._optimizer is not None: - self._optimizer.load_state_dict(checkpoint["optimizer_state"]) - - if self._lr_scheduler is not None: - # Does not work when loading from previous checkpoint and trying to train beyond the last max epochs - # with OneCycleLR. - if self._lr_scheduler["lr_scheduler"].__class__.__name__ != "OneCycleLR": - self._lr_scheduler["lr_scheduler"].load_state_dict( - checkpoint["scheduler_state"] - ) - self._lr_scheduler["interval"] = checkpoint["scheduler_interval"] - - if self._swa_network is not None: - self._swa_network.load_state_dict(checkpoint["swa_network"]) - - def save_checkpoint( - self, checkpoint_path: Path, is_best: bool, epoch: int, val_metric: str - ) -> None: - """Saves a checkpoint of the model. - - Args: - checkpoint_path (Path): Path to the experiment with the checkpoint. - is_best (bool): If it is the currently best model. - epoch (int): The epoch of the checkpoint. - val_metric (str): Validation metric. 
- - """ - state = self._get_state_dict() - state["is_best"] = is_best - state["epoch"] = epoch - state["network_args"] = self._network_args - - checkpoint_path.mkdir(parents=True, exist_ok=True) - - logger.debug("Saving checkpoint...") - filepath = str(checkpoint_path / "last.pt") - torch.save(state, filepath) - - if is_best: - logger.debug( - f"Found a new best {val_metric}. Saving best checkpoint and weights." - ) - shutil.copyfile(filepath, str(checkpoint_path / "best.pt")) - - def load_weights(self, network_fn: Optional[Type[nn.Module]] = None) -> None: - """Load the network weights.""" - logger.debug("Loading network with pretrained weights.") - filename = glob(self.weights_filename)[0] - if not filename: - raise FileNotFoundError( - f"Could not find any pretrained weights at {self.weights_filename}" - ) - # Loading state directory. - state_dict = torch.load(filename, map_location=torch.device(self._device)) - self._network_args = state_dict["network_args"] - weights = state_dict["model_state"] - - # Initializes the network with trained weights. - if network_fn is not None: - self._network = network_fn(**self._network_args) - self._network.load_state_dict(weights) - - if "swa_network" in state_dict: - self._swa_network = AveragedModel(self._network).to(self.device) - self._swa_network.load_state_dict(state_dict["swa_network"]) - - def save_weights(self, path: Path) -> None: - """Save the network weights.""" - logger.debug("Saving the best network weights.") - shutil.copyfile(str(path / "best.pt"), self.weights_filename) diff --git a/text_recognizer/models/character_model.py b/text_recognizer/models/character_model.py deleted file mode 100644 index f9944f3..0000000 --- a/text_recognizer/models/character_model.py +++ /dev/null @@ -1,88 +0,0 @@ -"""Defines the CharacterModel class.""" -from typing import Callable, Dict, Optional, Tuple, Type, Union - -import numpy as np -import torch -from torch import nn -from torch.utils.data import Dataset -from torchvision.transforms import ToTensor - -from text_recognizer.datasets import EmnistMapper -from text_recognizer.models.base import Model - - -class CharacterModel(Model): - """Model for predicting characters from images.""" - - def __init__( - self, - network_fn: Type[nn.Module], - dataset: Type[Dataset], - network_args: Optional[Dict] = None, - dataset_args: Optional[Dict] = None, - metrics: Optional[Dict] = None, - criterion: Optional[Callable] = None, - criterion_args: Optional[Dict] = None, - optimizer: Optional[Callable] = None, - optimizer_args: Optional[Dict] = None, - lr_scheduler: Optional[Callable] = None, - lr_scheduler_args: Optional[Dict] = None, - swa_args: Optional[Dict] = None, - device: Optional[str] = None, - ) -> None: - """Initializes the CharacterModel.""" - - super().__init__( - network_fn, - dataset, - network_args, - dataset_args, - metrics, - criterion, - criterion_args, - optimizer, - optimizer_args, - lr_scheduler, - lr_scheduler_args, - swa_args, - device, - ) - self.pad_token = dataset_args["args"]["pad_token"] - if self._mapper is None: - self._mapper = EmnistMapper(pad_token=self.pad_token,) - self.tensor_transform = ToTensor() - self.softmax = nn.Softmax(dim=0) - - @torch.no_grad() - def predict_on_image( - self, image: Union[np.ndarray, torch.Tensor] - ) -> Tuple[str, float]: - """Character prediction on an image. - - Args: - image (Union[np.ndarray, torch.Tensor]): An image containing a character. - - Returns: - Tuple[str, float]: The predicted character and the confidence in the prediction. 
- - """ - self.eval() - - if image.dtype == np.uint8: - # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1]. - image = self.tensor_transform(image) - if image.dtype == torch.uint8: - # If the image is an unscaled tensor. - image = image.type("torch.FloatTensor") / 255 - - # Put the image tensor on the device the model weights are on. - image = image.to(self.device) - logits = self.forward(image) - - prediction = self.softmax(logits.squeeze(0)) - - index = int(torch.argmax(prediction, dim=0)) - confidence_of_prediction = prediction[index] - predicted_character = self.mapper(index) - - return predicted_character, confidence_of_prediction diff --git a/text_recognizer/models/crnn_model.py b/text_recognizer/models/crnn_model.py deleted file mode 100644 index 1e01a83..0000000 --- a/text_recognizer/models/crnn_model.py +++ /dev/null @@ -1,119 +0,0 @@ -"""Defines the CRNNModel class.""" -from typing import Callable, Dict, Optional, Tuple, Type, Union - -import numpy as np -import torch -from torch import nn -from torch import Tensor -from torch.utils.data import Dataset -from torchvision.transforms import ToTensor - -from text_recognizer.datasets import EmnistMapper -from text_recognizer.models.base import Model -from text_recognizer.networks import greedy_decoder - - -class CRNNModel(Model): - """Model for predicting a sequence of characters from an image of a text line.""" - - def __init__( - self, - network_fn: Type[nn.Module], - dataset: Type[Dataset], - network_args: Optional[Dict] = None, - dataset_args: Optional[Dict] = None, - metrics: Optional[Dict] = None, - criterion: Optional[Callable] = None, - criterion_args: Optional[Dict] = None, - optimizer: Optional[Callable] = None, - optimizer_args: Optional[Dict] = None, - lr_scheduler: Optional[Callable] = None, - lr_scheduler_args: Optional[Dict] = None, - swa_args: Optional[Dict] = None, - device: Optional[str] = None, - ) -> None: - super().__init__( - network_fn, - dataset, - network_args, - dataset_args, - metrics, - criterion, - criterion_args, - optimizer, - optimizer_args, - lr_scheduler, - lr_scheduler_args, - swa_args, - device, - ) - - self.pad_token = dataset_args["args"]["pad_token"] - if self._mapper is None: - self._mapper = EmnistMapper(pad_token=self.pad_token,) - self.tensor_transform = ToTensor() - - def criterion(self, output: Tensor, targets: Tensor) -> Tensor: - """Computes the CTC loss. - - Args: - output (Tensor): Model predictions. - targets (Tensor): Correct output sequence. - - Returns: - Tensor: The CTC loss. - - """ - - # Input lengths on the form [T, B] - input_lengths = torch.full( - size=(output.shape[1],), fill_value=output.shape[0], dtype=torch.long, - ) - - # Configure target tensors for ctc loss. - targets_ = Tensor([]).to(self.device) - target_lengths = [] - for t in targets: - # Remove padding symbol as it acts as the blank symbol. - t = t[t < 79] - targets_ = torch.cat([targets_, t]) - target_lengths.append(len(t)) - - targets = targets_.type(dtype=torch.long) - target_lengths = ( - torch.Tensor(target_lengths).type(dtype=torch.long).to(self.device) - ) - - return self._criterion(output, targets, input_lengths, target_lengths) - - @torch.no_grad() - def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]: - """Predict on a single input.""" - self.eval() - - if image.dtype == np.uint8: - # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1]. - image = self.tensor_transform(image) - - # Rescale image between 0 and 1. 
- if image.dtype == torch.uint8: - # If the image is an unscaled tensor. - image = image.type("torch.FloatTensor") / 255 - - # Put the image tensor on the device the model weights are on. - image = image.to(self.device) - log_probs = self.forward(image) - - raw_pred, _ = greedy_decoder( - predictions=log_probs, - character_mapper=self.mapper, - blank_label=79, - collapse_repeated=True, - ) - - log_probs, _ = log_probs.max(dim=2) - - predicted_characters = "".join(raw_pred[0]) - confidence_of_prediction = log_probs.cumprod(dim=0)[-1].item() - - return predicted_characters, confidence_of_prediction diff --git a/text_recognizer/models/ctc_transformer_model.py b/text_recognizer/models/ctc_transformer_model.py deleted file mode 100644 index 25925f2..0000000 --- a/text_recognizer/models/ctc_transformer_model.py +++ /dev/null @@ -1,120 +0,0 @@ -"""Defines the CTC Transformer Model class.""" -from typing import Callable, Dict, Optional, Tuple, Type, Union - -import numpy as np -import torch -from torch import nn -from torch import Tensor -from torch.utils.data import Dataset -from torchvision.transforms import ToTensor - -from text_recognizer.datasets import EmnistMapper -from text_recognizer.models.base import Model -from text_recognizer.networks import greedy_decoder - - -class CTCTransformerModel(Model): - """Model for predicting a sequence of characters from an image of a text line.""" - - def __init__( - self, - network_fn: Type[nn.Module], - dataset: Type[Dataset], - network_args: Optional[Dict] = None, - dataset_args: Optional[Dict] = None, - metrics: Optional[Dict] = None, - criterion: Optional[Callable] = None, - criterion_args: Optional[Dict] = None, - optimizer: Optional[Callable] = None, - optimizer_args: Optional[Dict] = None, - lr_scheduler: Optional[Callable] = None, - lr_scheduler_args: Optional[Dict] = None, - swa_args: Optional[Dict] = None, - device: Optional[str] = None, - ) -> None: - super().__init__( - network_fn, - dataset, - network_args, - dataset_args, - metrics, - criterion, - criterion_args, - optimizer, - optimizer_args, - lr_scheduler, - lr_scheduler_args, - swa_args, - device, - ) - self.pad_token = dataset_args["args"]["pad_token"] - self.lower = dataset_args["args"]["lower"] - - if self._mapper is None: - self._mapper = EmnistMapper(pad_token=self.pad_token, lower=self.lower,) - - self.tensor_transform = ToTensor() - - def criterion(self, output: Tensor, targets: Tensor) -> Tensor: - """Computes the CTC loss. - - Args: - output (Tensor): Model predictions. - targets (Tensor): Correct output sequence. - - Returns: - Tensor: The CTC loss. - - """ - # Input lengths on the form [T, B] - input_lengths = torch.full( - size=(output.shape[1],), fill_value=output.shape[0], dtype=torch.long, - ) - - # Configure target tensors for ctc loss. - targets_ = Tensor([]).to(self.device) - target_lengths = [] - for t in targets: - # Remove padding symbol as it acts as the blank symbol. - t = t[t < 53] - targets_ = torch.cat([targets_, t]) - target_lengths.append(len(t)) - - targets = targets_.type(dtype=torch.long) - target_lengths = ( - torch.Tensor(target_lengths).type(dtype=torch.long).to(self.device) - ) - - return self._criterion(output, targets, input_lengths, target_lengths) - - @torch.no_grad() - def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]: - """Predict on a single input.""" - self.eval() - - if image.dtype == np.uint8: - # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1]. 
- image = self.tensor_transform(image) - - # Rescale image between 0 and 1. - if image.dtype == torch.uint8: - # If the image is an unscaled tensor. - image = image.type("torch.FloatTensor") / 255 - - # Put the image tensor on the device the model weights are on. - image = image.to(self.device) - log_probs = self.forward(image) - - raw_pred, _ = greedy_decoder( - predictions=log_probs, - character_mapper=self.mapper, - blank_label=53, - collapse_repeated=True, - ) - - log_probs, _ = log_probs.max(dim=2) - - predicted_characters = "".join(raw_pred[0]) - confidence_of_prediction = log_probs.cumprod(dim=0)[-1].item() - - return predicted_characters, confidence_of_prediction diff --git a/text_recognizer/models/segmentation_model.py b/text_recognizer/models/segmentation_model.py deleted file mode 100644 index 613108a..0000000 --- a/text_recognizer/models/segmentation_model.py +++ /dev/null @@ -1,75 +0,0 @@ -"""Segmentation model for detecting and segmenting lines.""" -from typing import Callable, Dict, Optional, Type, Union - -import numpy as np -import torch -from torch import nn -from torch import Tensor -from torch.utils.data import Dataset -from torchvision.transforms import ToTensor - -from text_recognizer.models.base import Model - - -class SegmentationModel(Model): - """Model for segmenting lines in an image.""" - - def __init__( - self, - network_fn: str, - dataset: str, - network_args: Optional[Dict] = None, - dataset_args: Optional[Dict] = None, - metrics: Optional[Dict] = None, - criterion: Optional[Callable] = None, - criterion_args: Optional[Dict] = None, - optimizer: Optional[Callable] = None, - optimizer_args: Optional[Dict] = None, - lr_scheduler: Optional[Callable] = None, - lr_scheduler_args: Optional[Dict] = None, - swa_args: Optional[Dict] = None, - device: Optional[str] = None, - ) -> None: - super().__init__( - network_fn, - dataset, - network_args, - dataset_args, - metrics, - criterion, - criterion_args, - optimizer, - optimizer_args, - lr_scheduler, - lr_scheduler_args, - swa_args, - device, - ) - self.tensor_transform = ToTensor() - self.softmax = nn.Softmax(dim=2) - - @torch.no_grad() - def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tensor: - """Predict on a single input.""" - self.eval() - - if image.dtype is np.uint8: - # Converts an image with range [0, 255] with to PyTorch Tensor with range [0, 1]. - image = self.tensor_transform(image) - - # Rescale image between 0 and 1. - if image.dtype is torch.uint8 or image.dtype is torch.int64: - # If the image is an unscaled tensor. - image = image.type("torch.FloatTensor") / 255 - - if not torch.is_tensor(image): - image = Tensor(image) - - # Put the image tensor on the device the model weights are on. 
- image = image.to(self.device) - - logits = self.forward(image) - - segmentation_mask = torch.argmax(logits, dim=1) - - return segmentation_mask diff --git a/text_recognizer/models/transformer_model.py b/text_recognizer/models/transformer_model.py deleted file mode 100644 index 3f63053..0000000 --- a/text_recognizer/models/transformer_model.py +++ /dev/null @@ -1,124 +0,0 @@ -"""Defines the CNN-Transformer class.""" -from typing import Callable, Dict, List, Optional, Tuple, Type, Union - -import numpy as np -import torch -from torch import nn -from torch import Tensor -from torch.utils.data import Dataset - -from text_recognizer.datasets import EmnistMapper -import text_recognizer.datasets.transforms as transforms -from text_recognizer.models.base import Model -from text_recognizer.networks import greedy_decoder - - -class TransformerModel(Model): - """Model for predicting a sequence of characters from an image of a text line with a cnn-transformer.""" - - def __init__( - self, - network_fn: str, - dataset: str, - network_args: Optional[Dict] = None, - dataset_args: Optional[Dict] = None, - metrics: Optional[Dict] = None, - criterion: Optional[Callable] = None, - criterion_args: Optional[Dict] = None, - optimizer: Optional[Callable] = None, - optimizer_args: Optional[Dict] = None, - lr_scheduler: Optional[Callable] = None, - lr_scheduler_args: Optional[Dict] = None, - swa_args: Optional[Dict] = None, - device: Optional[str] = None, - ) -> None: - super().__init__( - network_fn, - dataset, - network_args, - dataset_args, - metrics, - criterion, - criterion_args, - optimizer, - optimizer_args, - lr_scheduler, - lr_scheduler_args, - swa_args, - device, - ) - self.init_token = dataset_args["args"]["init_token"] - self.pad_token = dataset_args["args"]["pad_token"] - self.eos_token = dataset_args["args"]["eos_token"] - self.lower = dataset_args["args"]["lower"] - self.max_len = 100 - - if self._mapper is None: - self._mapper = EmnistMapper( - init_token=self.init_token, - pad_token=self.pad_token, - eos_token=self.eos_token, - lower=self.lower, - ) - self.tensor_transform = transforms.Compose( - [transforms.ToTensor(), transforms.Normalize(mean=[0.912], std=[0.168])] - ) - self.softmax = nn.Softmax(dim=2) - - @torch.no_grad() - def _generate_sentence(self, image: Tensor) -> Tuple[List, float]: - src = self.network.extract_image_features(image) - - # Added for vqvae transformer. - if isinstance(src, Tuple): - src = src[0] - - memory = self.network.encoder(src) - - confidence_of_predictions = [] - trg_indices = [self.mapper(self.init_token)] - - for _ in range(self.max_len - 1): - trg = torch.tensor(trg_indices, device=self.device)[None, :].long() - trg = self.network.target_embedding(trg) - logits = self.network.decoder(trg=trg, memory=memory, trg_mask=None) - - # Convert logits to probabilities. - probs = self.softmax(logits) - - pred_token = probs.argmax(2)[:, -1].item() - confidence = probs.max(2).values[:, -1].item() - - trg_indices.append(pred_token) - confidence_of_predictions.append(confidence) - - if pred_token == self.mapper(self.eos_token): - break - - confidence = np.min(confidence_of_predictions) - predicted_characters = "".join([self.mapper(x) for x in trg_indices[1:]]) - - return predicted_characters, confidence - - @torch.no_grad() - def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]: - """Predict on a single input.""" - self.eval() - - if image.dtype == np.uint8: - # Converts an image with range [0, 255] with to PyTorch Tensor with range [0, 1]. 
- image = self.tensor_transform(image) - - # Rescale image between 0 and 1. - if image.dtype == torch.uint8: - # If the image is an unscaled tensor. - image = image.type("torch.FloatTensor") / 255 - - # Put the image tensor on the device the model weights are on. - image = image.to(self.device) - - (predicted_characters, confidence_of_prediction,) = self._generate_sentence( - image - ) - - return predicted_characters, confidence_of_prediction diff --git a/text_recognizer/models/vqvae_model.py b/text_recognizer/models/vqvae_model.py deleted file mode 100644 index 70f6f1f..0000000 --- a/text_recognizer/models/vqvae_model.py +++ /dev/null @@ -1,80 +0,0 @@ -"""Defines the VQVAEModel class.""" -from typing import Callable, Dict, Optional, Tuple, Type, Union - -import numpy as np -import torch -from torch import nn -from torch.utils.data import Dataset -from torchvision.transforms import ToTensor - -from text_recognizer.datasets import EmnistMapper -from text_recognizer.models.base import Model - - -class VQVAEModel(Model): - """Model for reconstructing images from codebook.""" - - def __init__( - self, - network_fn: Type[nn.Module], - dataset: Type[Dataset], - network_args: Optional[Dict] = None, - dataset_args: Optional[Dict] = None, - metrics: Optional[Dict] = None, - criterion: Optional[Callable] = None, - criterion_args: Optional[Dict] = None, - optimizer: Optional[Callable] = None, - optimizer_args: Optional[Dict] = None, - lr_scheduler: Optional[Callable] = None, - lr_scheduler_args: Optional[Dict] = None, - swa_args: Optional[Dict] = None, - device: Optional[str] = None, - ) -> None: - """Initializes the CharacterModel.""" - - super().__init__( - network_fn, - dataset, - network_args, - dataset_args, - metrics, - criterion, - criterion_args, - optimizer, - optimizer_args, - lr_scheduler, - lr_scheduler_args, - swa_args, - device, - ) - self.pad_token = dataset_args["args"]["pad_token"] - if self._mapper is None: - self._mapper = EmnistMapper(pad_token=self.pad_token,) - self.tensor_transform = ToTensor() - self.softmax = nn.Softmax(dim=0) - - @torch.no_grad() - def predict_on_image(self, image: Union[np.ndarray, torch.Tensor]) -> torch.Tensor: - """Reconstruction of image. - - Args: - image (Union[np.ndarray, torch.Tensor]): An image containing a character. - - Returns: - Tuple[str, float]: The predicted character and the confidence in the prediction. - - """ - self.eval() - - if image.dtype == np.uint8: - # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1]. - image = self.tensor_transform(image) - if image.dtype == torch.uint8: - # If the image is an unscaled tensor. - image = image.type("torch.FloatTensor") / 255 - - # Put the image tensor on the device the model weights are on. 
- image = image.to(self.device) - image_reconstructed, _ = self.forward(image) - - return image_reconstructed diff --git a/text_recognizer/networks/__init__.py b/text_recognizer/networks/__init__.py index 1521355..e69de29 100644 --- a/text_recognizer/networks/__init__.py +++ b/text_recognizer/networks/__init__.py @@ -1,43 +0,0 @@ -"""Network modules.""" -from .cnn import CNN -from .cnn_transformer import CNNTransformer -from .crnn import ConvolutionalRecurrentNetwork -from .ctc import greedy_decoder -from .densenet import DenseNet -from .lenet import LeNet -from .metrics import accuracy, cer, wer -from .mlp import MLP -from .residual_network import ResidualNetwork, ResidualNetworkEncoder -from .transducer import load_transducer_loss, TDS2d -from .transformer import Transformer -from .unet import UNet -from .util import sliding_window -from .vit import ViT -from .vq_transformer import VQTransformer -from .vqvae import VQVAE -from .wide_resnet import WideResidualNetwork - -__all__ = [ - "accuracy", - "cer", - "CNN", - "CNNTransformer", - "ConvolutionalRecurrentNetwork", - "DenseNet", - "FCN", - "greedy_decoder", - "MLP", - "LeNet", - "load_transducer_loss", - "ResidualNetwork", - "ResidualNetworkEncoder", - "sliding_window", - "UNet", - "TDS2d", - "Transformer", - "ViT", - "VQTransformer", - "VQVAE", - "wer", - "WideResidualNetwork", -] diff --git a/text_recognizer/networks/beam.py b/text_recognizer/networks/beam.py deleted file mode 100644 index dccccdb..0000000 --- a/text_recognizer/networks/beam.py +++ /dev/null @@ -1,83 +0,0 @@ -"""Implementation of beam search decoder for a sequence to sequence network. - -Stolen from: https://github.com/budzianowski/PyTorch-Beam-Search-Decoding/blob/master/decode_beam.py - -""" -# from typing import List -# from Queue import PriorityQueue - -# from loguru import logger -# import torch -# from torch import nn -# from torch import Tensor -# import torch.nn.functional as F - - -# class Node: -# def __init__( -# self, parent: Node, target_index: int, log_prob: Tensor, length: int -# ) -> None: -# self.parent = parent -# self.target_index = target_index -# self.log_prob = log_prob -# self.length = length -# self.reward = 0.0 - -# def eval(self, alpha: float = 1.0) -> Tensor: -# return self.log_prob / (self.length - 1 + 1e-6) + alpha * self.reward - - -# @torch.no_grad() -# def beam_decoder( -# network, mapper, device, memory: Tensor = None, max_len: int = 97, -# ) -> Tensor: -# beam_width = 10 -# topk = 1 # How many sentences to generate. - -# trg_indices = [mapper(mapper.init_token)] - -# end_nodes = [] - -# node = Node(None, trg_indices, 0, 1) -# nodes = PriorityQueue() - -# nodes.put((node.eval(), node)) -# q_size = 1 - -# # Beam search -# for _ in range(max_len): -# if q_size > 2000: -# logger.warning("Could not decoder input") -# break - -# # Fetch the best node. -# score, n = nodes.get() -# decoder_input = n.target_index - -# if n.target_index == mapper(mapper.eos_token) and n.parent is not None: -# end_nodes.append((score, n)) - -# # If we reached the maximum number of sentences required. -# if len(end_nodes) >= 1: -# break -# else: -# continue - -# # Forward pass with transformer. 
-#         trg = torch.tensor(trg_indices, device=device)[None, :].long()
-#         trg = network.target_embedding(trg)
-#         logits = network.decoder(trg=trg, memory=memory, trg_mask=None)
-#         log_prob = F.log_softmax(logits, dim=2)
-
-#         log_prob, indices = torch.topk(log_prob, beam_width)
-
-#         for new_k in range(beam_width):
-#             # TODO: continue from here
-#             token_index = indices[0][new_k].view(1, -1)
-#             log_p = log_prob[0][new_k].item()
-
-#             node = Node()
-
-#         pass
-
-#     pass
diff --git a/text_recognizer/networks/cnn.py b/text_recognizer/networks/cnn.py
deleted file mode 100644
index 1807bb9..0000000
--- a/text_recognizer/networks/cnn.py
+++ /dev/null
@@ -1,101 +0,0 @@
-"""Implementation of a simple backbone cnn network."""
-from typing import Callable, Dict, Optional, Tuple
-
-from einops.layers.torch import Rearrange
-import torch
-from torch import nn
-
-from text_recognizer.networks.util import activation_function
-
-
-class CNN(nn.Module):
-    """Simple backbone CNN network for character prediction."""
-
-    def __init__(
-        self,
-        channels: Tuple[int, ...] = (1, 32, 64, 128),
-        kernel_sizes: Tuple[int, ...] = (4, 4, 4),
-        strides: Tuple[int, ...] = (2, 2, 2),
-        max_pool_kernel: int = 2,
-        dropout_rate: float = 0.2,
-        activation: Optional[str] = "relu",
-    ) -> None:
-        """Initialization of the CNN network.
-
-        Args:
-            channels (Tuple[int, ...]): Channels in the convolutional layers. Defaults to (1, 32, 64, 128).
-            kernel_sizes (Tuple[int, ...]): Kernel sizes in the convolutional layers. Defaults to (4, 4, 4).
-            strides (Tuple[int, ...]): Stride length of the convolutional filter. Defaults to (2, 2, 2).
-            max_pool_kernel (int): 2D max pooling kernel. Defaults to 2.
-            dropout_rate (float): The dropout rate. Defaults to 0.2.
-            activation (Optional[str]): The name of the non-linear activation function. Defaults to relu.
-
-        Raises:
-            RuntimeError: If the numbers of hyperparameters do not match in length.
-
-        """
-        super().__init__()
-
-        if len(channels) - 1 != len(kernel_sizes) or len(kernel_sizes) != len(strides):
-            raise RuntimeError("The number of the hyperparameters does not match.")
-
-        self.cnn = self._build_network(
-            channels, kernel_sizes, strides, max_pool_kernel, dropout_rate, activation,
-        )
-
-    def _build_network(
-        self,
-        channels: Tuple[int, ...],
-        kernel_sizes: Tuple[int, ...],
-        strides: Tuple[int, ...],
-        max_pool_kernel: int,
-        dropout_rate: float,
-        activation: str,
-    ) -> nn.Sequential:
-        # Load activation function.
-        activation_fn = activation_function(activation)
-
-        channels = list(channels)
-        in_channels = channels.pop(0)
-        configuration = zip(channels, kernel_sizes, strides)
-
-        modules = nn.ModuleList([])
-
-        for i, (out_channels, kernel_size, stride) in enumerate(configuration):
-            # Add max pool to reduce output size.
-            if i == len(channels) // 2:
-                modules.append(nn.MaxPool2d(max_pool_kernel))
-            if i == 0:
-                modules.append(
-                    nn.Conv2d(
-                        in_channels, out_channels, kernel_size, stride=stride, padding=1
-                    )
-                )
-            else:
-                modules.append(
-                    nn.Sequential(
-                        activation_fn,
-                        nn.BatchNorm2d(in_channels),
-                        nn.Conv2d(
-                            in_channels,
-                            out_channels,
-                            kernel_size,
-                            stride=stride,
-                            padding=1,
-                        ),
-                    )
-                )
-
-            if dropout_rate:
-                modules.append(nn.Dropout2d(p=dropout_rate))
-
-            in_channels = out_channels
-
-        return nn.Sequential(*modules)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """The feedforward pass."""
-        # If the batch dimension is missing, it needs to be added.
-        if len(x.shape) < 4:
-            x = x[(None,) * (4 - len(x.shape))]
-        return self.cnn(x)
diff --git a/text_recognizer/networks/crnn.py b/text_recognizer/networks/crnn.py
deleted file mode 100644
index 778e232..0000000
--- a/text_recognizer/networks/crnn.py
+++ /dev/null
@@ -1,110 +0,0 @@
-"""CRNN for handwritten text recognition."""
-from typing import Dict, Tuple
-
-from einops import rearrange, reduce
-from einops.layers.torch import Rearrange
-from loguru import logger
-from torch import nn
-from torch import Tensor
-
-from text_recognizer.networks.util import configure_backbone
-
-
-class ConvolutionalRecurrentNetwork(nn.Module):
-    """Network that takes an image of a text line and predicts the tokens that are in the image."""
-
-    def __init__(
-        self,
-        backbone: str,
-        backbone_args: Dict = None,
-        input_size: int = 128,
-        hidden_size: int = 128,
-        bidirectional: bool = False,
-        num_layers: int = 1,
-        num_classes: int = 80,
-        patch_size: Tuple[int, int] = (28, 28),
-        stride: Tuple[int, int] = (1, 14),
-        recurrent_cell: str = "lstm",
-        avg_pool: bool = False,
-        use_sliding_window: bool = True,
-    ) -> None:
-        super().__init__()
-        self.backbone_args = backbone_args or {}
-        self.patch_size = patch_size
-        self.stride = stride
-        self.sliding_window = (
-            self._configure_sliding_window() if use_sliding_window else None
-        )
-        self.input_size = input_size
-        self.hidden_size = hidden_size
-        self.backbone = configure_backbone(backbone, self.backbone_args)
-        self.bidirectional = bidirectional
-        self.avg_pool = avg_pool
-        # Pools the height dimension to one when the whole image is encoded at once.
-        self.adaptive_pool = (
-            nn.AdaptiveAvgPool2d((None, 1)) if not use_sliding_window else None
-        )
-
-        if recurrent_cell.upper() in ["LSTM", "GRU"]:
-            recurrent_cell = getattr(nn, recurrent_cell)
-        else:
-            logger.warning(
-                f"Option {recurrent_cell} not valid, defaulting to LSTM cell."
-            )
-            recurrent_cell = nn.LSTM
-
-        self.rnn = recurrent_cell(
-            input_size=self.input_size,
-            hidden_size=self.hidden_size,
-            bidirectional=bidirectional,
-            num_layers=num_layers,
-        )
-
-        decoder_size = self.hidden_size * 2 if self.bidirectional else self.hidden_size
-
-        self.decoder = nn.Sequential(
-            nn.Linear(in_features=decoder_size, out_features=num_classes),
-            nn.LogSoftmax(dim=2),
-        )
-
-    def _configure_sliding_window(self) -> nn.Sequential:
-        return nn.Sequential(
-            nn.Unfold(kernel_size=self.patch_size, stride=self.stride),
-            Rearrange(
-                "b (c h w) t -> b t c h w",
-                h=self.patch_size[0],
-                w=self.patch_size[1],
-                c=1,
-            ),
-        )
-
-    def forward(self, x: Tensor) -> Tensor:
-        """Converts images to a sequence of patches, feeds them to a CNN, then predictions are made with an LSTM."""
-        if len(x.shape) < 4:
-            x = x[(None,) * (4 - len(x.shape))]
-
-        if self.sliding_window is not None:
-            # Create image patches with a sliding window kernel.
-            x = self.sliding_window(x)
-
-            # Rearrange from a sequence of patches for feedforward network.
-            b, t = x.shape[:2]
-            x = rearrange(x, "b t c h w -> (b t) c h w", b=b, t=t)
-
-            x = self.backbone(x)
-
-            # Average pooling.
-            if self.avg_pool:
-                x = reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t)
-            else:
-                x = rearrange(x, "(b t) h -> t b h", b=b, t=t)
-        else:
-            # Encode the entire image with a CNN, and use the channels as temporal dimension.
-            x = self.backbone(x)
-            x = rearrange(x, "b c h w -> b w c h")
-            if self.adaptive_pool is not None:
-                x = self.adaptive_pool(x)
-            x = x.squeeze(3)
-
-        # Sequence predictions.
-        x, _ = self.rnn(x)
-
-        # Sequence to classification layer.
-        x = self.decoder(x)
-        return x
diff --git a/text_recognizer/networks/ctc.py b/text_recognizer/networks/ctc.py
deleted file mode 100644
index af9b700..0000000
--- a/text_recognizer/networks/ctc.py
+++ /dev/null
@@ -1,58 +0,0 @@
-"""Decodes the CTC output."""
-from typing import Callable, List, Optional, Tuple
-
-from einops import rearrange
-import torch
-from torch import Tensor
-
-from text_recognizer.datasets.util import EmnistMapper
-
-
-def greedy_decoder(
-    predictions: Tensor,
-    targets: Optional[Tensor] = None,
-    target_lengths: Optional[Tensor] = None,
-    character_mapper: Optional[Callable] = None,
-    blank_label: int = 79,
-    collapse_repeated: bool = True,
-) -> Tuple[List[str], List[str]]:
-    """Greedy CTC decoder.
-
-    Args:
-        predictions (Tensor): Tensor of network predictions, shape [time, batch, classes].
-        targets (Optional[Tensor]): Target tensor, shape is [batch, targets]. Defaults to None.
-        target_lengths (Optional[Tensor]): Length of each target tensor. Defaults to None.
-        character_mapper (Optional[Callable]): An emnist/character mapper for mapping integers to characters. Defaults
-            to None.
-        blank_label (int): The blank character to be ignored. Defaults to 79.
-        collapse_repeated (bool): Collapse consecutive predictions of the same character. Defaults to True.
-
-    Returns:
-        Tuple[List[str], List[str]]: Tuple of decoded predictions and decoded targets.
-
-    """
-
-    if character_mapper is None:
-        character_mapper = EmnistMapper(pad_token="_")  # noqa: S106
-
-    predictions = rearrange(torch.argmax(predictions, dim=2), "t b -> b t")
-    decoded_predictions = []
-    decoded_targets = []
-    for i, prediction in enumerate(predictions):
-        decoded_prediction = []
-        decoded_target = []
-        if targets is not None and target_lengths is not None:
-            for target_index in targets[i][: target_lengths[i]]:
-                if target_index == blank_label:
-                    continue
-                decoded_target.append(character_mapper(int(target_index)))
-            decoded_targets.append(decoded_target)
-        for j, index in enumerate(prediction):
-            if index != blank_label:
-                if collapse_repeated and j != 0 and index == prediction[j - 1]:
-                    continue
-                decoded_prediction.append(index.item())
-        decoded_predictions.append(
-            [character_mapper(int(pred_index)) for pred_index in decoded_prediction]
-        )
-    return decoded_predictions, decoded_targets
diff --git a/text_recognizer/networks/densenet.py b/text_recognizer/networks/densenet.py
deleted file mode 100644
index 7dc58d9..0000000
--- a/text_recognizer/networks/densenet.py
+++ /dev/null
@@ -1,225 +0,0 @@
-"""Defines a Densely Connected Convolutional Network (DenseNet) in PyTorch.
-
-Sources:
-https://arxiv.org/abs/1608.06993
-https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py
-
-"""
-from typing import List, Optional, Union
-
-from einops.layers.torch import Rearrange
-import torch
-from torch import nn
-from torch import Tensor
-
-from text_recognizer.networks.util import activation_function
-
-
-class _DenseLayer(nn.Module):
-    """A dense layer with pre-batch norm -> activation function -> Conv-layer x 2."""
-
-    def __init__(
-        self,
-        in_channels: int,
-        growth_rate: int,
-        bn_size: int,
-        dropout_rate: float,
-        activation: str = "relu",
-    ) -> None:
-        super().__init__()
-        activation_fn = activation_function(activation)
-        self.dense_layer = [
-            nn.BatchNorm2d(in_channels),
-            activation_fn,
-            nn.Conv2d(
-                in_channels=in_channels,
-                out_channels=bn_size * growth_rate,
-                kernel_size=1,
-                stride=1,
-                bias=False,
-            ),
-            nn.BatchNorm2d(bn_size * growth_rate),
-            activation_fn,
-            nn.Conv2d(
-                in_channels=bn_size * growth_rate,
-                out_channels=growth_rate,
-                kernel_size=3,
-                stride=1,
-                padding=1,
-                bias=False,
-            ),
-        ]
-        if dropout_rate:
-            self.dense_layer.append(nn.Dropout(p=dropout_rate))
-
-        self.dense_layer = nn.Sequential(*self.dense_layer)
-
-    def forward(self, x: Union[Tensor, List[Tensor]]) -> Tensor:
-        if isinstance(x, list):
-            x = torch.cat(x, 1)
-        return self.dense_layer(x)
-
-
-class _DenseBlock(nn.Module):
-    def __init__(
-        self,
-        num_layers: int,
-        in_channels: int,
-        bn_size: int,
-        growth_rate: int,
-        dropout_rate: float,
-        activation: str = "relu",
-    ) -> None:
-        super().__init__()
-        self.dense_block = self._build_dense_blocks(
-            num_layers, in_channels, bn_size, growth_rate, dropout_rate, activation,
-        )
-
-    def _build_dense_blocks(
-        self,
-        num_layers: int,
-        in_channels: int,
-        bn_size: int,
-        growth_rate: int,
-        dropout_rate: float,
-        activation: str = "relu",
-    ) -> nn.ModuleList:
-        dense_block = []
-        for i in range(num_layers):
-            dense_block.append(
-                _DenseLayer(
-                    in_channels=in_channels + i * growth_rate,
-                    growth_rate=growth_rate,
-                    bn_size=bn_size,
-                    dropout_rate=dropout_rate,
-                    activation=activation,
-                )
-            )
-        return nn.ModuleList(dense_block)
-
-    def forward(self, x: Tensor) -> Tensor:
-        feature_maps = [x]
-        for layer in self.dense_block:
-            x = layer(feature_maps)
-            feature_maps.append(x)
-        return torch.cat(feature_maps, 1)
-
-
-class _Transition(nn.Module):
-    def __init__(
-        self, in_channels: int, out_channels: int, activation: str = "relu",
-    ) -> None:
-        super().__init__()
-        activation_fn = activation_function(activation)
-        self.transition = nn.Sequential(
-            nn.BatchNorm2d(in_channels),
-            activation_fn,
-            nn.Conv2d(
-                in_channels=in_channels,
-                out_channels=out_channels,
-                kernel_size=1,
-                stride=1,
-                bias=False,
-            ),
-            nn.AvgPool2d(kernel_size=2, stride=2),
-        )
-
-    def forward(self, x: Tensor) -> Tensor:
-        return self.transition(x)
-
-
-class DenseNet(nn.Module):
-    """Implementation of DenseNet, a network architecture that concatenates previous layers for maximum information flow."""
-
-    def __init__(
-        self,
-        growth_rate: int = 32,
-        block_config: List[int] = (6, 12, 24, 16),
-        in_channels: int = 1,
-        base_channels: int = 64,
-        num_classes: int = 80,
-        bn_size: int = 4,
-        dropout_rate: float = 0,
-        classifier: bool = True,
-        activation: str = "relu",
-    ) -> None:
-        super().__init__()
-        self.densenet = self._configure_densenet(
-            in_channels,
-            base_channels,
-            num_classes,
-            growth_rate,
-            block_config,
-            bn_size,
-            dropout_rate,
-            classifier,
-            activation,
-        )
-
-    def _configure_densenet(
-        self,
-        in_channels: int,
-        base_channels: int,
-        num_classes: int,
-        growth_rate: int,
-        block_config: List[int],
-        bn_size: int,
-        dropout_rate: float,
-        classifier: bool,
-        activation: str,
-    ) -> nn.Sequential:
-        activation_fn = activation_function(activation)
-        densenet = [
-            nn.Conv2d(
-                in_channels=in_channels,
-                out_channels=base_channels,
-                kernel_size=3,
-                stride=1,
-                padding=1,
-                bias=False,
-            ),
-            nn.BatchNorm2d(base_channels),
-            activation_fn,
-        ]
-
-        num_features = base_channels
-
-        for i, num_layers in enumerate(block_config):
-            densenet.append(
-                _DenseBlock(
-                    num_layers=num_layers,
-                    in_channels=num_features,
-                    bn_size=bn_size,
-                    growth_rate=growth_rate,
-                    dropout_rate=dropout_rate,
-                    activation=activation,
-                )
-            )
-            num_features = num_features + num_layers * growth_rate
-            if i != len(block_config) - 1:
-                densenet.append(
-                    _Transition(
-                        in_channels=num_features,
-                        out_channels=num_features // 2,
-                        activation=activation,
-                    )
-                )
-                num_features = num_features // 2
-
-        densenet.append(activation_fn)
-
-        if classifier:
-            densenet.append(nn.AdaptiveAvgPool2d((1, 1)))
-            densenet.append(Rearrange("b c h w -> b (c h w)"))
-            densenet.append(
-                nn.Linear(in_features=num_features, out_features=num_classes)
-            )
-
-        return nn.Sequential(*densenet)
-
-    def forward(self, x: Tensor) -> Tensor:
-        """Forward pass of DenseNet."""
-        # If the batch dimension is missing, it will be added.
-        if len(x.shape) < 4:
-            x = x[(None,) * (4 - len(x.shape))]
-        return self.densenet(x)
diff --git a/text_recognizer/networks/lenet.py b/text_recognizer/networks/lenet.py
deleted file mode 100644
index 527e1a0..0000000
--- a/text_recognizer/networks/lenet.py
+++ /dev/null
@@ -1,68 +0,0 @@
-"""Implementation of the LeNet network."""
-from typing import Callable, Dict, Optional, Tuple
-
-from einops.layers.torch import Rearrange
-import torch
-from torch import nn
-
-from text_recognizer.networks.util import activation_function
-
-
-class LeNet(nn.Module):
-    """LeNet network for character prediction."""
-
-    def __init__(
-        self,
-        channels: Tuple[int, ...] = (1, 32, 64),
-        kernel_sizes: Tuple[int, ...] = (3, 3, 2),
-        hidden_size: Tuple[int, ...] = (9216, 128),
-        dropout_rate: float = 0.2,
-        num_classes: int = 10,
-        activation_fn: Optional[str] = "relu",
-    ) -> None:
-        """Initialization of the LeNet network.
-
-        Args:
-            channels (Tuple[int, ...]): Channels in the convolutional layers. Defaults to (1, 32, 64).
-            kernel_sizes (Tuple[int, ...]): Kernel sizes in the convolutional layers. Defaults to (3, 3, 2).
-            hidden_size (Tuple[int, ...]): Size of the flattened output from the convolutional layers.
-                Defaults to (9216, 128).
-            dropout_rate (float): The dropout rate. Defaults to 0.2.
-            num_classes (int): Number of classes. Defaults to 10.
-            activation_fn (Optional[str]): The name of the non-linear activation function. Defaults to relu.
- - """ - super().__init__() - - activation_fn = activation_function(activation_fn) - - self.layers = [ - nn.Conv2d( - in_channels=channels[0], - out_channels=channels[1], - kernel_size=kernel_sizes[0], - ), - activation_fn, - nn.Conv2d( - in_channels=channels[1], - out_channels=channels[2], - kernel_size=kernel_sizes[1], - ), - activation_fn, - nn.MaxPool2d(kernel_sizes[2]), - nn.Dropout(p=dropout_rate), - Rearrange("b c h w -> b (c h w)"), - nn.Linear(in_features=hidden_size[0], out_features=hidden_size[1]), - activation_fn, - nn.Dropout(p=dropout_rate), - nn.Linear(in_features=hidden_size[1], out_features=num_classes), - ] - - self.layers = nn.Sequential(*self.layers) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """The feedforward pass.""" - # If batch dimenstion is missing, it needs to be added. - if len(x.shape) < 4: - x = x[(None,) * (4 - len(x.shape))] - return self.layers(x) diff --git a/text_recognizer/networks/metrics.py b/text_recognizer/networks/metrics.py deleted file mode 100644 index 2605731..0000000 --- a/text_recognizer/networks/metrics.py +++ /dev/null @@ -1,123 +0,0 @@ -"""Utility functions for models.""" -from typing import Optional - -from einops import rearrange -import Levenshtein as Lev -import torch -from torch import Tensor - -from text_recognizer.networks import greedy_decoder - - -def accuracy(outputs: Tensor, labels: Tensor, pad_index: int = 53) -> float: - """Computes the accuracy. - - Args: - outputs (Tensor): The output from the network. - labels (Tensor): Ground truth labels. - pad_index (int): Padding index. - - Returns: - float: The accuracy for the batch. - - """ - - _, predicted = torch.max(outputs, dim=-1) - - # Mask out the pad tokens - mask = labels != pad_index - - predicted *= mask - labels *= mask - - acc = (predicted == labels).sum().float() / labels.shape[0] - acc = acc.item() - return acc - - -def cer( - outputs: Tensor, - targets: Tensor, - batch_size: Optional[int] = None, - blank_label: Optional[int] = int, -) -> float: - """Computes the character error rate. - - Args: - outputs (Tensor): The output from the network. - targets (Tensor): Ground truth labels. - batch_size (Optional[int]): Batch size if target and output has been flattend. - blank_label (Optional[int]): The blank character to be ignored. Defaults to 79. - - Returns: - float: The cer for the batch. - - """ - if len(outputs.shape) == 2 and len(targets.shape) == 1 and batch_size is not None: - targets = rearrange(targets, "(b t) -> b t", b=batch_size) - outputs = rearrange(outputs, "(b t) v -> t b v", b=batch_size) - - target_lengths = torch.full( - size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long, - ) - decoded_predictions, decoded_targets = greedy_decoder( - outputs, targets, target_lengths, blank_label=blank_label, - ) - - lev_dist = 0 - - for prediction, target in zip(decoded_predictions, decoded_targets): - prediction = "".join(prediction) - target = "".join(target) - prediction, target = ( - prediction.replace(" ", ""), - target.replace(" ", ""), - ) - lev_dist += Lev.distance(prediction, target) - return lev_dist / len(decoded_predictions) - - -def wer( - outputs: Tensor, - targets: Tensor, - batch_size: Optional[int] = None, - blank_label: Optional[int] = int, -) -> float: - """Computes the Word error rate. - - Args: - outputs (Tensor): The output from the network. - targets (Tensor): Ground truth labels. - batch_size (optional[int]): Batch size if target and output has been flattend. 
-        blank_label (Optional[int]): The blank character to be ignored. Defaults to 79.
-
-    Returns:
-        float: The wer for the batch.
-
-    """
-    if len(outputs.shape) == 2 and len(targets.shape) == 1 and batch_size is not None:
-        targets = rearrange(targets, "(b t) -> b t", b=batch_size)
-        outputs = rearrange(outputs, "(b t) v -> t b v", b=batch_size)
-
-    target_lengths = torch.full(
-        size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long,
-    )
-    decoded_predictions, decoded_targets = greedy_decoder(
-        outputs, targets, target_lengths, blank_label=blank_label,
-    )
-
-    lev_dist = 0
-
-    for prediction, target in zip(decoded_predictions, decoded_targets):
-        prediction = "".join(prediction)
-        target = "".join(target)
-
-        b = set(prediction.split() + target.split())
-        word2char = dict(zip(b, range(len(b))))
-
-        w1 = [chr(word2char[w]) for w in prediction.split()]
-        w2 = [chr(word2char[w]) for w in target.split()]
-
-        lev_dist += Lev.distance("".join(w1), "".join(w2))
-
-    return lev_dist / len(decoded_predictions)
diff --git a/text_recognizer/networks/mlp.py b/text_recognizer/networks/mlp.py
deleted file mode 100644
index 1101912..0000000
--- a/text_recognizer/networks/mlp.py
+++ /dev/null
@@ -1,73 +0,0 @@
-"""Defines the MLP network."""
-from typing import Callable, Dict, List, Optional, Union
-
-from einops.layers.torch import Rearrange
-import torch
-from torch import nn
-
-from text_recognizer.networks.util import activation_function
-
-
-class MLP(nn.Module):
-    """Multi-layered perceptron network."""
-
-    def __init__(
-        self,
-        input_size: int = 784,
-        num_classes: int = 10,
-        hidden_size: Union[int, List] = 128,
-        num_layers: int = 3,
-        dropout_rate: float = 0.2,
-        activation_fn: str = "relu",
-    ) -> None:
-        """Initialization of the MLP network.
-
-        Args:
-            input_size (int): The input size of the network. Defaults to 784.
-            num_classes (int): Number of classes in the dataset. Defaults to 10.
-            hidden_size (Union[int, List]): The number of `neurons` in each hidden layer. Defaults to 128.
-            num_layers (int): The number of hidden layers. Defaults to 3.
-            dropout_rate (float): The dropout rate at each layer. Defaults to 0.2.
-            activation_fn (str): Name of the activation function in the hidden layers. Defaults to
-                relu.
-
-        """
-        super().__init__()
-
-        activation_fn = activation_function(activation_fn)
-
-        if isinstance(hidden_size, int):
-            hidden_size = [hidden_size] * num_layers
-
-        self.layers = [
-            Rearrange("b c h w -> b (c h w)"),
-            nn.Linear(in_features=input_size, out_features=hidden_size[0]),
-            activation_fn,
-        ]
-
-        for i in range(num_layers - 1):
-            self.layers += [
-                nn.Linear(in_features=hidden_size[i], out_features=hidden_size[i + 1]),
-                activation_fn,
-            ]
-
-            if dropout_rate:
-                self.layers.append(nn.Dropout(p=dropout_rate))
-
-        self.layers.append(
-            nn.Linear(in_features=hidden_size[-1], out_features=num_classes)
-        )
-
-        self.layers = nn.Sequential(*self.layers)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """The feedforward pass."""
-        # If the batch dimension is missing, it needs to be added.
-        if len(x.shape) < 4:
-            x = x[(None,) * (4 - len(x.shape))]
-        return self.layers(x)
-
-    @property
-    def __name__(self) -> str:
-        """Returns the name of the network."""
-        return "mlp"
diff --git a/text_recognizer/networks/stn.py b/text_recognizer/networks/stn.py
deleted file mode 100644
index e9d216f..0000000
--- a/text_recognizer/networks/stn.py
+++ /dev/null
@@ -1,44 +0,0 @@
-"""Spatial Transformer Network."""
-
-from einops.layers.torch import Rearrange
-import torch
-from torch import nn
-from torch import Tensor
-import torch.nn.functional as F
-
-
-class SpatialTransformerNetwork(nn.Module):
-    """A network with differentiable attention.
-
-    Network that learns how to perform spatial transformations on the input image in order to enhance the
-    geometric invariance of the model.
-
-    # TODO: add arguments to make it more general.
-
-    """
-
-    def __init__(self) -> None:
-        super().__init__()
-        # Initialize the identity transformation and its weights and biases.
-        linear = nn.Linear(32, 3 * 2)
-        linear.weight.data.zero_()
-        linear.bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))
-
-        self.theta = nn.Sequential(
-            nn.Conv2d(in_channels=1, out_channels=8, kernel_size=7),
-            nn.MaxPool2d(kernel_size=2, stride=2),
-            nn.ReLU(inplace=True),
-            nn.Conv2d(in_channels=8, out_channels=10, kernel_size=5),
-            nn.MaxPool2d(kernel_size=2, stride=2),
-            nn.ReLU(inplace=True),
-            Rearrange("b c h w -> b (c h w)", h=3, w=3),
-            nn.Linear(in_features=10 * 3 * 3, out_features=32),
-            nn.ReLU(inplace=True),
-            linear,
-            Rearrange("b (row col) -> b row col", row=2, col=3),
-        )
-
-    def forward(self, x: Tensor) -> Tensor:
-        """The spatial transformation."""
-        grid = F.affine_grid(self.theta(x), x.shape)
-        return F.grid_sample(x, grid, align_corners=False)
diff --git a/text_recognizer/networks/unet.py b/text_recognizer/networks/unet.py
deleted file mode 100644
index 510910f..0000000
--- a/text_recognizer/networks/unet.py
+++ /dev/null
@@ -1,255 +0,0 @@
-"""UNet for segmentation."""
-from typing import List, Optional, Tuple, Union
-
-import torch
-from torch import nn
-from torch import Tensor
-
-from text_recognizer.networks.util import activation_function
-
-
-class _ConvBlock(nn.Module):
-    """Modified UNet convolutional block with dilation."""
-
-    def __init__(
-        self,
-        channels: List[int],
-        activation: str,
-        num_groups: int,
-        dropout_rate: float = 0.1,
-        kernel_size: int = 3,
-        dilation: int = 1,
-        padding: int = 0,
-    ) -> None:
-        super().__init__()
-        self.channels = channels
-        self.dropout_rate = dropout_rate
-        self.kernel_size = kernel_size
-        self.dilation = dilation
-        self.padding = padding
-        self.num_groups = num_groups
-        self.activation = activation_function(activation)
-        self.block = self._configure_block()
-        self.residual_conv = nn.Sequential(
-            nn.Conv2d(
-                self.channels[0], self.channels[-1], kernel_size=3, stride=1, padding=1
-            ),
-            self.activation,
-        )
-
-    def _configure_block(self) -> nn.Sequential:
-        block = []
-        for i in range(len(self.channels) - 1):
-            block += [
-                nn.Dropout(p=self.dropout_rate),
-                nn.GroupNorm(self.num_groups, self.channels[i]),
-                self.activation,
-                nn.Conv2d(
-                    self.channels[i],
-                    self.channels[i + 1],
-                    kernel_size=self.kernel_size,
-                    padding=self.padding,
-                    stride=1,
-                    dilation=self.dilation,
-                ),
-            ]
-
-        return nn.Sequential(*block)
-
-    def forward(self, x: Tensor) -> Tensor:
-        """Apply the convolutional block."""
-        residual = self.residual_conv(x)
-        return self.block(x) + residual
-
-
-class _DownSamplingBlock(nn.Module):
-    """Basic down sampling block."""
-
-    def __init__(
-        self,
-        channels: List[int],
-        activation: str,
-        num_groups: int,
-        pooling_kernel: Union[int, bool] = 2,
-        dropout_rate: float = 0.1,
-        kernel_size: int = 3,
-        dilation: int = 1,
-        padding: int = 0,
-    ) -> None:
-        super().__init__()
-        self.conv_block = _ConvBlock(
-            channels,
-            activation,
-            num_groups,
-            dropout_rate,
-            kernel_size,
-            dilation,
-            padding,
-        )
-        self.down_sampling = nn.MaxPool2d(pooling_kernel) if pooling_kernel else None
-
-    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
-        """Return the convolutional block output and a down sampled tensor."""
-        x = self.conv_block(x)
-        x_down = self.down_sampling(x) if self.down_sampling is not None else x
-
-        return x_down, x
-
-
-class _UpSamplingBlock(nn.Module):
-    """The upsampling block of the UNet."""
-
-    def __init__(
-        self,
-        channels: List[int],
-        activation: str,
-        num_groups: int,
-        scale_factor: int = 2,
-        dropout_rate: float = 0.1,
-        kernel_size: int = 3,
-        dilation: int = 1,
-        padding: int = 0,
-    ) -> None:
-        super().__init__()
-        self.conv_block = _ConvBlock(
-            channels,
-            activation,
-            num_groups,
-            dropout_rate,
-            kernel_size,
-            dilation,
-            padding,
-        )
-        self.up_sampling = nn.Upsample(
-            scale_factor=scale_factor, mode="bilinear", align_corners=True
-        )
-
-    def forward(self, x: Tensor, x_skip: Optional[Tensor] = None) -> Tensor:
-        """Apply the up sampling and convolutional block."""
-        x = self.up_sampling(x)
-        if x_skip is not None:
-            x = torch.cat((x, x_skip), dim=1)
-        return self.conv_block(x)
-
-
-class UNet(nn.Module):
-    """UNet architecture."""
-
-    def __init__(
-        self,
-        in_channels: int = 1,
-        base_channels: int = 64,
-        num_classes: int = 3,
-        depth: int = 4,
-        activation: str = "relu",
-        num_groups: int = 8,
-        dropout_rate: float = 0.1,
-        pooling_kernel: int = 2,
-        scale_factor: int = 2,
-        kernel_size: Optional[List[int]] = None,
-        dilation: Optional[List[int]] = None,
-        padding: Optional[List[int]] = None,
-    ) -> None:
-        super().__init__()
-        self.depth = depth
-        self.num_groups = num_groups
-
-        if kernel_size is not None and dilation is not None and padding is not None:
-            if (
-                len(kernel_size) != depth
-                or len(dilation) != depth
-                or len(padding) != depth
-            ):
-                raise RuntimeError(
-                    "Length of convolutional parameters does not match the depth."
-                )
-            self.kernel_size = kernel_size
-            self.padding = padding
-            self.dilation = dilation
-
-        else:
-            self.kernel_size = [3] * depth
-            self.padding = [1] * depth
-            self.dilation = [1] * depth
-
-        self.dropout_rate = dropout_rate
-        self.conv = nn.Conv2d(
-            in_channels, base_channels, kernel_size=3, stride=1, padding=1
-        )
-
-        channels = [base_channels] + [base_channels * 2 ** i for i in range(depth)]
-        self.encoder_blocks = self._configure_down_sampling_blocks(
-            channels, activation, pooling_kernel
-        )
-        self.decoder_blocks = self._configure_up_sampling_blocks(
-            channels, activation, scale_factor
-        )
-
-        self.head = nn.Conv2d(base_channels, num_classes, kernel_size=1)
-
-    def _configure_down_sampling_blocks(
-        self, channels: List[int], activation: str, pooling_kernel: int
-    ) -> nn.ModuleList:
-        blocks = nn.ModuleList([])
-        for i in range(len(channels) - 1):
-            pooling_kernel = pooling_kernel if i < self.depth - 1 else False
-            dropout_rate = self.dropout_rate if i < 0 else 0
-            blocks += [
-                _DownSamplingBlock(
-                    [channels[i], channels[i + 1], channels[i + 1]],
-                    activation,
-                    self.num_groups,
-                    pooling_kernel,
-                    dropout_rate,
-                    self.kernel_size[i],
-                    self.dilation[i],
-                    self.padding[i],
-                )
-            ]
-
-        return blocks
-
-    def _configure_up_sampling_blocks(
-        self, channels: List[int], activation: str, scale_factor: int,
-    ) -> nn.ModuleList:
-        channels.reverse()
-        self.kernel_size.reverse()
-        self.dilation.reverse()
-        self.padding.reverse()
-        return nn.ModuleList(
-            [
-                _UpSamplingBlock(
-                    [channels[i] + channels[i + 1], channels[i + 1], channels[i + 1]],
-                    activation,
-                    self.num_groups,
-                    scale_factor,
-                    self.dropout_rate,
-                    self.kernel_size[i],
-                    self.dilation[i],
-                    self.padding[i],
-                )
-                for i in range(len(channels) - 2)
-            ]
-        )
-
-    def _encode(self, x: Tensor) -> List[Tensor]:
-        x_skips = []
-        for block in self.encoder_blocks:
-            x, x_skip = block(x)
-            x_skips.append(x_skip)
-        return x_skips
-
-    def _decode(self, x_skips: List[Tensor]) -> Tensor:
-        x = x_skips[-1]
-        for i, block in enumerate(self.decoder_blocks):
-            x = block(x, x_skips[-(i + 2)])
-        return x
-
-    def forward(self, x: Tensor) -> Tensor:
-        """Forward pass with the UNet model."""
-        if len(x.shape) < 4:
-            x = x[(None,) * (4 - len(x.shape))]
-        x = self.conv(x)
-        x_skips = self._encode(x)
-        x = self._decode(x_skips)
-        return self.head(x)
diff --git a/text_recognizer/networks/vit.py b/text_recognizer/networks/vit.py
deleted file mode 100644
index efb3701..0000000
--- a/text_recognizer/networks/vit.py
+++ /dev/null
@@ -1,150 +0,0 @@
-"""A Vision Transformer.
-
-Inspired by:
-https://openreview.net/pdf?id=YicbFdNTTy
-
-"""
-from typing import Optional, Tuple
-
-from einops import rearrange, repeat
-import torch
-from torch import nn
-from torch import Tensor
-
-from text_recognizer.networks.transformer import Transformer
-
-
-class ViT(nn.Module):
-    """Transformer for image to sequence prediction."""
-
-    def __init__(
-        self,
-        num_encoder_layers: int,
-        num_decoder_layers: int,
-        hidden_dim: int,
-        vocab_size: int,
-        num_heads: int,
-        expansion_dim: int,
-        patch_dim: Tuple[int, int],
-        image_size: Tuple[int, int],
-        dropout_rate: float,
-        trg_pad_index: int,
-        max_len: int,
-        activation: str = "gelu",
-    ) -> None:
-        super().__init__()
-
-        self.trg_pad_index = trg_pad_index
-        self.patch_dim = patch_dim
-        self.num_patches = image_size[-1] // self.patch_dim[1]
-
-        # Encoder
-        self.patch_to_embedding = nn.Linear(
-            self.patch_dim[0] * self.patch_dim[1], hidden_dim
-        )
-        self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_dim))
-        self.character_embedding = nn.Embedding(vocab_size, hidden_dim)
-        self.pos_embedding = nn.Parameter(torch.randn(1, max_len, hidden_dim))
-        self.dropout = nn.Dropout(dropout_rate)
-        self._init()
-
-        self.transformer = Transformer(
-            num_encoder_layers,
-            num_decoder_layers,
-            hidden_dim,
-            num_heads,
-            expansion_dim,
-            dropout_rate,
-            activation,
-        )
-
-        self.head = nn.Sequential(nn.Linear(hidden_dim, vocab_size),)
-
-    def _init(self) -> None:
-        nn.init.normal_(self.character_embedding.weight, std=0.02)
-        # nn.init.normal_(self.pos_embedding.weight, std=0.02)
-
-    def _create_trg_mask(self, trg: Tensor) -> Tensor:
-        # Move this outside the transformer.
-        trg_pad_mask = (trg != self.trg_pad_index)[:, None, None]
-        trg_len = trg.shape[1]
-        trg_sub_mask = torch.tril(
-            torch.ones((trg_len, trg_len), device=trg.device)
-        ).bool()
-        trg_mask = trg_pad_mask & trg_sub_mask
-        return trg_mask
-
-    def encoder(self, src: Tensor) -> Tensor:
-        """Forward pass with the encoder of the transformer."""
-        return self.transformer.encoder(src)
-
-    def decoder(self, trg: Tensor, memory: Tensor, trg_mask: Tensor) -> Tensor:
-        """Forward pass with the decoder of the transformer + classification head."""
-        return self.head(
-            self.transformer.decoder(trg=trg, memory=memory, trg_mask=trg_mask)
-        )
-
-    def extract_image_features(self, src: Tensor) -> Tensor:
-        """Extracts image features with a backbone neural network.
-
-        It seems like the winning idea was to swap the channel and width dimensions and collapse
-        the height dimension. The transformer is learning like a baby with this implementation!!! :D
-        Ohhhh, the joy I am experiencing right now!! Bring in the beers! :D :D :D
-
-        Args:
-            src (Tensor): Input tensor.
-
-        Returns:
-            Tensor: An input src to the transformer.
-
-        """
-        # If the batch dimension is missing, it needs to be added.
-        if len(src.shape) < 4:
-            src = src[(None,) * (4 - len(src.shape))]
-
-        patches = rearrange(
-            src,
-            "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
-            p1=self.patch_dim[0],
-            p2=self.patch_dim[1],
-        )
-
-        # From patches to encoded sequence.
-        x = self.patch_to_embedding(patches)
-        b, n, _ = x.shape
-        cls_tokens = repeat(self.cls_token, "() n d -> b n d", b=b)
-        x = torch.cat((cls_tokens, x), dim=1)
-        x += self.pos_embedding[:, : (n + 1)]
-        x = self.dropout(x)
-
-        return x
-
-    def target_embedding(self, trg: Tensor) -> Tensor:
-        """Encodes the target tensor with embedding and position.
-
-        Args:
-            trg (Tensor): Target tensor.
-
-        Returns:
-            Tensor: The encoded target tensor.
- - """ - _, n = trg.shape - trg = self.character_embedding(trg.long()) - trg += self.pos_embedding[:, :n] - return trg - - def decode_image_features(self, h: Tensor, trg: Optional[Tensor] = None) -> Tensor: - """Takes images features from the backbone and decodes them with the transformer.""" - trg_mask = self._create_trg_mask(trg) - trg = self.target_embedding(trg) - out = self.transformer(h, trg, trg_mask=trg_mask) - - logits = self.head(out) - return logits - - def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor: - """Forward pass with CNN transfomer.""" - h = self.extract_image_features(x) - logits = self.decode_image_features(h, trg) - return logits diff --git a/text_recognizer/paragraph_text_recognizer.py b/text_recognizer/paragraph_text_recognizer.py deleted file mode 100644 index aa39662..0000000 --- a/text_recognizer/paragraph_text_recognizer.py +++ /dev/null @@ -1,153 +0,0 @@ -"""Full model. - -Takes an image and returns the text in the image, by first segmenting the image with a LineDetector, then extracting the -each crop of the image corresponding to line regions, and feeding them to a LinePredictor model that outputs the text -in each region. -""" -from typing import Dict, List, Tuple, Union - -import cv2 -import numpy as np -import torch - -from text_recognizer.models import SegmentationModel, TransformerModel -from text_recognizer.util import read_image - - -class ParagraphTextRecognizor: - """Given an image of a single handwritten character, recognizes it.""" - - def __init__(self, line_predictor_args: Dict, line_detector_args: Dict) -> None: - self._line_predictor = TransformerModel(**line_predictor_args) - self._line_detector = SegmentationModel(**line_detector_args) - self._line_detector.eval() - self._line_predictor.eval() - - def predict(self, image_or_filename: Union[str, np.ndarray]) -> Tuple: - """Takes an image and returns all text within it.""" - image = ( - read_image(image_or_filename) - if isinstance(image_or_filename, str) - else image_or_filename - ) - - line_region_crops = self._get_line_region_crops(image) - processed_line_region_crops = [ - self._process_image_for_line_predictor(image=crop) - for crop in line_region_crops - ] - line_region_strings = [ - self.line_predictor_model.predict_on_image(crop)[0] - for crop in processed_line_region_crops - ] - - return " ".join(line_region_strings), line_region_crops - - def _get_line_region_crops( - self, image: np.ndarray, min_crop_len_factor: float = 0.02 - ) -> List[np.ndarray]: - """Returns all the crops of text lines in a square image.""" - processed_image, scale_down_factor = self._process_image_for_line_detector( - image - ) - line_segmentation = self._line_detector.predict_on_image(processed_image) - bounding_boxes = _find_line_bounding_boxes(line_segmentation) - - bounding_boxes = (bounding_boxes * scale_down_factor).astype(int) - - min_crop_len = int(min_crop_len_factor * min(image.shape[0], image.shape[1])) - line_region_crops = [ - image[y : y + h, x : x + w] - for x, y, w, h in bounding_boxes - if w >= min_crop_len and h >= min_crop_len - ] - return line_region_crops - - def _process_image_for_line_detector( - self, image: np.ndarray - ) -> Tuple[np.ndarray, float]: - """Convert uint8 image to float image with black background with shape self._line_detector.image_shape.""" - resized_image, scale_down_factor = _resize_image_for_line_detector( - image=image, max_shape=self._line_detector.image_shape - ) - resized_image = (1.0 - resized_image / 255).astype("float32") - return 
resized_image, scale_down_factor - - def _process_image_for_line_predictor(self, image: np.ndarray) -> np.ndarray: - """Preprocessing of image before feeding it to the LinePrediction model. - - Convert uint8 image to float image with black background with shape - self._line_predictor.image_shape while maintaining the image aspect ratio. - - Args: - image (np.ndarray): Crop of text line. - - Returns: - np.ndarray: Processed crop for feeding line predictor. - """ - expected_shape = self._line_detector.image_shape - scale_factor = (np.array(expected_shape) / np.array(image.shape)).min() - scaled_image = cv2.resize( - image, - dsize=None, - fx=scale_factor, - fy=scale_factor, - interpolation=cv2.INTER_AREA, - ) - - pad_with = ( - (0, expected_shape[0] - scaled_image.shape[0]), - (0, expected_shape[1] - scaled_image.shape[1]), - ) - - padded_image = np.pad( - scaled_image, pad_with=pad_with, mode="constant", constant_values=255 - ) - return 1 - padded_image / 255 - - -def _find_line_bounding_boxes(line_segmentation: np.ndarray) -> np.ndarray: - """Given a line segmentation, find bounding boxes for connected-component regions corresponding to non-0 labels.""" - - def _find_line_bounding_boxes_in_channel( - line_segmentation_channel: np.ndarray, - ) -> np.ndarray: - line_segmentation_image = cv2.dilate( - line_segmentation_channel, kernel=np.ones((3, 3)), iterations=1 - ) - line_activation_image = (line_segmentation_image * 255).astype("uint8") - line_activation_image = cv2.threshold( - line_activation_image, 0.5, 1, cv2.THRESH_BINARY | cv2.THRESH_OTSU - )[1] - - bounding_cnts, _ = cv2.findContours( - line_segmentation_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE - ) - return np.array([cv2.boundingRect(cnt) for cnt in bounding_cnts]) - - bounding_boxes = np.concatenate( - [ - _find_line_bounding_boxes_in_channel(line_segmentation[:, :, i]) - for i in [1, 2] - ], - axis=0, - ) - - return bounding_boxes[np.argsort(bounding_boxes[:, 1])] - - -def _resize_image_for_line_detector( - image: np.ndarray, max_shape: Tuple[int, int] -) -> Tuple[np.ndarray, float]: - """Resize the image to less than the max_shape while maintaining the aspect ratio.""" - scale_down_factor = max(np.ndarray(image.shape) / np.ndarray(max_shape)) - if scale_down_factor == 1: - return image.copy(), scale_down_factor - resize_image = cv2.resize( - image, - dsize=None, - fx=1 / scale_down_factor, - fy=1 / scale_down_factor, - interpolation=cv2.INTER_AREA, - ) - return resize_image, scale_down_factor |
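A few reference sketches follow for the techniques used in the modules deleted above; each is a minimal, self-contained illustration with made-up sizes, not code taken from this repository. First, the sliding-window patching used by ConvolutionalRecurrentNetwork: nn.Unfold cuts a line image into overlapping windows, and an einops Rearrange restores the patch shape.

    import torch
    from einops.layers.torch import Rearrange
    from torch import nn

    patch_size, stride = (28, 28), (1, 14)
    sliding_window = nn.Sequential(
        nn.Unfold(kernel_size=patch_size, stride=stride),
        # nn.Unfold flattens each window to (c * h * w); restore the patch shape.
        Rearrange("b (c h w) t -> b t c h w", h=patch_size[0], w=patch_size[1], c=1),
    )

    line_image = torch.randn(1, 1, 28, 952)  # [batch, channel, height, width]
    patches = sliding_window(line_image)     # [1, 67, 1, 28, 28]: 67 overlapping windows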
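The greedy CTC decoder in the deleted ctc.py reduces to three steps: argmax over classes, collapse adjacent repeats, then drop blanks. A standalone sketch, using the same blank id (79) as the deleted code:

    import torch

    def ctc_greedy_decode(log_probs: torch.Tensor, blank: int = 79) -> list:
        """Decode CTC output of shape [time, batch, classes] into index sequences."""
        best = log_probs.argmax(dim=2).transpose(0, 1)  # [batch, time]
        decoded = []
        for sequence in best.tolist():
            indices, previous = [], None
            for index in sequence:
                # Collapse adjacent repeats, then skip the blank label.
                if index != previous and index != blank:
                    indices.append(index)
                previous = index
            decoded.append(indices)
        return decoded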
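The wer metric in the deleted metrics.py relies on a trick worth spelling out: Levenshtein libraries compare strings character by character, so each distinct word is mapped to a unique character, turning word-level edit distance into character-level edit distance. Just that trick, in isolation (assumes the python-Levenshtein package):

    import Levenshtein as Lev

    def word_error_distance(prediction: str, target: str) -> int:
        """Word-level edit distance via a word-to-character encoding."""
        vocabulary = set(prediction.split() + target.split())
        word2char = {word: chr(i) for i, word in enumerate(vocabulary)}
        encoded_prediction = "".join(word2char[word] for word in prediction.split())
        encoded_target = "".join(word2char[word] for word in target.split())
        return Lev.distance(encoded_prediction, encoded_target)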
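The spatial transformer in the deleted stn.py comes down to two functional calls: a small network predicts a 2x3 affine matrix, F.affine_grid turns it into a sampling grid, and F.grid_sample resamples the input. A minimal sketch with a hard-coded identity transform in place of the predicting network:

    import torch
    import torch.nn.functional as F

    # A batch of one image and the identity affine transform (normally predicted by a network).
    image = torch.randn(1, 1, 28, 28)
    theta = torch.tensor([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]])  # [batch, 2, 3]

    grid = F.affine_grid(theta, image.shape, align_corners=False)
    warped = F.grid_sample(image, grid, align_corners=False)  # equals `image` up to interpolation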
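The core of the deleted vit.py is turning an image into a token sequence with a single einops rearrange followed by a linear projection; everything after that is a standard transformer. A sketch with illustrative sizes (a 56x952 line image cut into 28x28 patches):

    import torch
    from einops import rearrange
    from torch import nn

    patch_h, patch_w, hidden_dim = 28, 28, 256
    to_embedding = nn.Linear(patch_h * patch_w * 1, hidden_dim)

    images = torch.randn(8, 1, 56, 952)  # [batch, channel, height, width]
    patches = rearrange(
        images, "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=patch_h, p2=patch_w
    )                                    # [8, 68, 784]: 2 x 34 patches per image
    tokens = to_embedding(patches)       # [8, 68, 256], ready for a transformer encoder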
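Finally, the decoder mask built by _create_trg_mask combines a padding mask with a causal (lower-triangular) mask by broadcasting. A standalone version of the same computation:

    import torch

    def make_trg_mask(trg: torch.Tensor, pad_index: int) -> torch.Tensor:
        """Padding mask AND causal mask, broadcast to [batch, 1, time, time]."""
        pad_mask = (trg != pad_index)[:, None, None]  # [batch, 1, 1, time]
        t = trg.shape[1]
        causal = torch.tril(torch.ones((t, t), device=trg.device)).bool()
        return pad_mask & causal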