28 files changed, 14 insertions, 3124 deletions
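The only insertion in this commit is the new `madgrad` dependency, pinned at `^1.0` in pyproject.toml below. The diff itself does not show the optimizer being wired in, so the following is a minimal, hypothetical sketch of MADGRAD dropped into a plain PyTorch training step; the toy network and hyperparameters are assumptions, and only the `MADGRAD` constructor comes from the library:

```python
# Hedged sketch: MADGRAD as a drop-in torch optimizer. The network and
# hyperparameters below are illustrative assumptions, not from this commit.
import torch
from torch import nn
from madgrad import MADGRAD

network = nn.Linear(28 * 28, 62)  # toy stand-in for a text-recognizer network
optimizer = MADGRAD(network.parameters(), lr=1e-2, momentum=0.9, weight_decay=0.0)

x = torch.randn(8, 28 * 28)
targets = torch.randint(0, 62, (8,))
loss = nn.functional.cross_entropy(network(x), targets)

optimizer.zero_grad()
loss.backward()
optimizer.step()
```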
diff --git a/poetry.lock b/poetry.lock index fa374b6..6926f1f 100644 --- a/poetry.lock +++ b/poetry.lock @@ -853,6 +853,14 @@ win32-setctime = {version = ">=1.0.0", markers = "sys_platform == \"win32\""}  dev = ["codecov (>=2.0.15)", "colorama (>=0.3.4)", "flake8 (>=3.7.7)", "tox (>=3.9.0)", "tox-travis (>=0.12)", "pytest (>=4.6.2)", "pytest-cov (>=2.7.1)", "Sphinx (>=2.2.1)", "sphinx-autobuild (>=0.7.1)", "sphinx-rtd-theme (>=0.4.3)", "black (>=19.10b0)", "isort (>=5.1.1)"]  [[package]] +name = "madgrad" +version = "1.0" +description = "A general purpose PyTorch Optimizer" +category = "main" +optional = false +python-versions = ">=3.6" + +[[package]]  name = "markdown"  version = "3.3.4"  description = "Python implementation of Markdown." @@ -2162,7 +2170,7 @@ multidict = ">=4.0"  [metadata]  lock-version = "1.1"  python-versions = "^3.8" -content-hash = "cffb5a23a46f3be6be0b8ea8289cfb97c8ba9722f869bbfc2af75cb80a737877" +content-hash = "db4253add1258abaf637f127ba576b1ec5e0b415c8f7f93b18ecdca40bc6f042"  [metadata.files]  absl-py = [ @@ -2641,6 +2649,10 @@ loguru = [      {file = "loguru-0.5.3-py3-none-any.whl", hash = "sha256:f8087ac396b5ee5f67c963b495d615ebbceac2796379599820e324419d53667c"},      {file = "loguru-0.5.3.tar.gz", hash = "sha256:b28e72ac7a98be3d28ad28570299a393dfcd32e5e3f6a353dec94675767b6319"},  ] +madgrad = [ +    {file = "madgrad-1.0-py3-none-any.whl", hash = "sha256:cd5239a1274ee025abec14c99d2af06b11783a379da32cbe2f4b07fc81ef20ea"}, +    {file = "madgrad-1.0.tar.gz", hash = "sha256:5a34e1d295ebb2f85fbf9e09ed3b548e27908471bbe2506dda35de5a471c0cbe"}, +]  markdown = [      {file = "Markdown-3.3.4-py3-none-any.whl", hash = "sha256:96c3ba1261de2f7547b46a00ea8463832c921d3f9d6aba3f255a6f71386db20c"},      {file = "Markdown-3.3.4.tar.gz", hash = "sha256:31b5b491868dcc87d6c24b7e3d19a0d730d59d3e46f4eea6430a321bed387a49"}, diff --git a/pyproject.toml b/pyproject.toml index e791dd9..32bdb0e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,7 @@ gtn = "^0.0.0"  sentencepiece = "^0.1.95"  pytorch-lightning = "^1.2.4"  Pillow = "^8.1.2" +madgrad = "^1.0"  [tool.poetry.dev-dependencies]  pytest = "^5.4.2" diff --git a/text_recognizer/__init__.py b/text_recognizer/__init__.py index 3dc1f76..e69de29 100644 --- a/text_recognizer/__init__.py +++ b/text_recognizer/__init__.py @@ -1 +0,0 @@ -__version__ = "0.1.0" diff --git a/text_recognizer/character_predictor.py b/text_recognizer/character_predictor.py deleted file mode 100644 index ad71289..0000000 --- a/text_recognizer/character_predictor.py +++ /dev/null @@ -1,29 +0,0 @@ -"""CharacterPredictor class.""" -from typing import Dict, Tuple, Type, Union - -import numpy as np -from torch import nn - -from text_recognizer import datasets, networks -from text_recognizer.models import CharacterModel -from text_recognizer.util import read_image - - -class CharacterPredictor: -    """Recognizes the character in handwritten character images.""" - -    def __init__(self, network_fn: str, dataset: str) -> None: -        """Intializes the CharacterModel and load the pretrained weights.""" -        network_fn = getattr(networks, network_fn) -        dataset = getattr(datasets, dataset) -        self.model = CharacterModel(network_fn=network_fn, dataset=dataset) -        self.model.eval() -        self.model.use_swa_model() - -    def predict(self, image_or_filename: Union[np.ndarray, str]) -> Tuple[str, float]: -        """Predict on a single images contianing a handwritten character.""" -        if isinstance(image_or_filename, str): -        
    image = read_image(image_or_filename, grayscale=True) -        else: -            image = image_or_filename -        return self.model.predict_on_image(image) diff --git a/text_recognizer/data/iam_paragraphs_dataset.py b/text_recognizer/data/iam_paragraphs_dataset.py deleted file mode 100644 index 8ba5142..0000000 --- a/text_recognizer/data/iam_paragraphs_dataset.py +++ /dev/null @@ -1,291 +0,0 @@ -"""IamParagraphsDataset class and functions for data processing.""" -import random -from typing import Callable, Dict, List, Optional, Tuple, Union - -import click -import cv2 -import h5py -from loguru import logger -import numpy as np -import torch -from torch import Tensor -from torchvision.transforms import ToTensor - -from text_recognizer import util -from text_recognizer.datasets.dataset import Dataset -from text_recognizer.datasets.iam_dataset import IamDataset -from text_recognizer.datasets.util import ( -    compute_sha256, -    DATA_DIRNAME, -    download_url, -    EmnistMapper, -) - -INTERIM_DATA_DIRNAME = DATA_DIRNAME / "interim" / "iam_paragraphs" -DEBUG_CROPS_DIRNAME = INTERIM_DATA_DIRNAME / "debug_crops" -PROCESSED_DATA_DIRNAME = DATA_DIRNAME / "processed" / "iam_paragraphs" -CROPS_DIRNAME = PROCESSED_DATA_DIRNAME / "crops" -GT_DIRNAME = PROCESSED_DATA_DIRNAME / "gt" - -PARAGRAPH_BUFFER = 50  # Pixels in the IAM form images to leave around the lines. -TEST_FRACTION = 0.2 -SEED = 4711 - - -class IamParagraphsDataset(Dataset): -    """IAM Paragraphs dataset for paragraphs of handwritten text.""" - -    def __init__( -        self, -        train: bool = False, -        subsample_fraction: float = None, -        transform: Optional[Callable] = None, -        target_transform: Optional[Callable] = None, -    ) -> None: -        super().__init__( -            train=train, -            subsample_fraction=subsample_fraction, -            transform=transform, -            target_transform=target_transform, -        ) -        # Load Iam dataset. -        self.iam_dataset = IamDataset() - -        self.num_classes = 3 -        self._input_shape = (256, 256) -        self._output_shape = self._input_shape + (self.num_classes,) -        self._ids = None - -    def __getitem__(self, index: Union[Tensor, int]) -> Tuple[Tensor, Tensor]: -        """Fetches data, target pair of the dataset for a given and index or indices. - -        Args: -            index (Union[int, Tensor]): Either a list or int of indices/index. - -        Returns: -            Tuple[Tensor, Tensor]: Data target pair. 
- -        """ -        if torch.is_tensor(index): -            index = index.tolist() - -        data = self.data[index] -        targets = self.targets[index] - -        seed = np.random.randint(SEED) -        random.seed(seed)  # apply this seed to target tranfsorms -        torch.manual_seed(seed)  # needed for torchvision 0.7 -        if self.transform: -            data = self.transform(data) - -        random.seed(seed)  # apply this seed to target tranfsorms -        torch.manual_seed(seed)  # needed for torchvision 0.7 -        if self.target_transform: -            targets = self.target_transform(targets) - -        return data, targets.long() - -    @property -    def ids(self) -> Tensor: -        """Ids of the dataset.""" -        return self._ids - -    def get_data_and_target_from_id(self, id_: str) -> Tuple[Tensor, Tensor]: -        """Get data target pair from id.""" -        ind = self.ids.index(id_) -        return self.data[ind], self.targets[ind] - -    def load_or_generate_data(self) -> None: -        """Load or generate dataset data.""" -        num_actual = len(list(CROPS_DIRNAME.glob("*.jpg"))) -        num_targets = len(self.iam_dataset.line_regions_by_id) - -        if num_actual < num_targets - 2: -            self._process_iam_paragraphs() - -        self._data, self._targets, self._ids = _load_iam_paragraphs() -        self._get_random_split() -        self._subsample() - -    def _get_random_split(self) -> None: -        np.random.seed(SEED) -        num_train = int((1 - TEST_FRACTION) * self.data.shape[0]) -        indices = np.random.permutation(self.data.shape[0]) -        train_indices, test_indices = indices[:num_train], indices[num_train:] -        if self.train: -            self._data = self.data[train_indices] -            self._targets = self.targets[train_indices] -        else: -            self._data = self.data[test_indices] -            self._targets = self.targets[test_indices] - -    def _process_iam_paragraphs(self) -> None: -        """Crop the part with the text. - -        For each page, crop out the part of it that correspond to the paragraph of text, and make sure all crops are -        self.input_shape. The ground truth data is the same size, with a one-hot vector at each pixel -        corresponding to labels 0=background, 1=odd-numbered line, 2=even-numbered line -        """ -        crop_dims = self._decide_on_crop_dims() -        CROPS_DIRNAME.mkdir(parents=True, exist_ok=True) -        DEBUG_CROPS_DIRNAME.mkdir(parents=True, exist_ok=True) -        GT_DIRNAME.mkdir(parents=True, exist_ok=True) -        logger.info( -            f"Cropping paragraphs, generating ground truth, and saving debugging images to {DEBUG_CROPS_DIRNAME}" -        ) -        for filename in self.iam_dataset.form_filenames: -            id_ = filename.stem -            line_region = self.iam_dataset.line_regions_by_id[id_] -            _crop_paragraph_image(filename, line_region, crop_dims, self.input_shape) - -    def _decide_on_crop_dims(self) -> Tuple[int, int]: -        """Decide on the dimensions to crop out of the form image. - -        Since image width is larger than a comfortable crop around the longest paragraph, -        we will make the crop a square form factor. -        And since the found dimensions 610x610 are pretty close to 512x512, -        we might as well resize crops and make it exactly that, which lets us -        do all kinds of power-of-2 pooling and upsampling should we choose to. 
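`__getitem__` above re-seeds both `random` and `torch` immediately before the input transform and again before the target transform, so that stochastic augmentations hit the image and its segmentation target identically. A standalone sketch of that shared-seed trick, assuming a torchvision-style random transform:

```python
# Hedged sketch of the shared-seed trick from __getitem__ above: re-seed
# before each transform so image and mask receive identical randomness.
import random
import numpy as np
import torch
from torchvision import transforms

augment = transforms.RandomCrop(224)

image = torch.rand(1, 256, 256)
mask = torch.randint(0, 3, (1, 256, 256))

seed = np.random.randint(4711)

random.seed(seed)
torch.manual_seed(seed)  # torchvision samples crop offsets from these RNGs
image = augment(image)

random.seed(seed)
torch.manual_seed(seed)  # same seed -> same crop offsets for the mask
mask = augment(mask)
```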
- -        Returns: -            Tuple[int, int]: A tuple of crop dimensions. - -        Raises: -            RuntimeError: When max crop height is larger than max crop width. - -        """ - -        sample_form_filename = self.iam_dataset.form_filenames[0] -        sample_image = util.read_image(sample_form_filename, grayscale=True) -        max_crop_width = sample_image.shape[1] -        max_crop_height = _get_max_paragraph_crop_height( -            self.iam_dataset.line_regions_by_id -        ) -        if not max_crop_height <= max_crop_width: -            raise RuntimeError( -                f"Max crop height is larger then max crop width: {max_crop_height} >= {max_crop_width}" -            ) - -        crop_dims = (max_crop_width, max_crop_width) -        logger.info( -            f"Max crop width and height were found to be {max_crop_width}x{max_crop_height}." -        ) -        logger.info(f"Setting them to {max_crop_width}x{max_crop_width}") -        return crop_dims - -    def __repr__(self) -> str: -        """Return info about the dataset.""" -        return ( -            "IAM Paragraph Dataset\n"  # pylint: disable=no-member -            f"Num classes: {self.num_classes}\n" -            f"Data: {self.data.shape}\n" -            f"Targets: {self.targets.shape}\n" -        ) - - -def _get_max_paragraph_crop_height(line_regions_by_id: Dict) -> int: -    heights = [] -    for regions in line_regions_by_id.values(): -        min_y1 = min(region["y1"] for region in regions) - PARAGRAPH_BUFFER -        max_y2 = max(region["y2"] for region in regions) + PARAGRAPH_BUFFER -        height = max_y2 - min_y1 -        heights.append(height) -    return max(heights) - - -def _crop_paragraph_image( -    filename: str, line_regions: Dict, crop_dims: Tuple[int, int], final_dims: Tuple -) -> None: -    image = util.read_image(filename, grayscale=True) - -    min_y1 = min(region["y1"] for region in line_regions) - PARAGRAPH_BUFFER -    max_y2 = max(region["y2"] for region in line_regions) + PARAGRAPH_BUFFER -    height = max_y2 - min_y1 -    crop_height = crop_dims[0] -    buffer = (crop_height - height) // 2 - -    # Generate image crop. -    image_crop = 255 * np.ones(crop_dims, dtype=np.uint8) -    try: -        image_crop[buffer : buffer + height] = image[min_y1:max_y2] -    except Exception as e:  # pylint: disable=broad-except -        logger.error(f"Rescued {filename}: {e}") -        return - -    # Generate ground truth. -    gt_image = np.zeros_like(image_crop, dtype=np.uint8) -    for index, region in enumerate(line_regions): -        gt_image[ -            (region["y1"] - min_y1 + buffer) : (region["y2"] - min_y1 + buffer), -            region["x1"] : region["x2"], -        ] = (index % 2 + 1) - -    # Generate image for debugging. 
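The ground-truth generation in `_crop_paragraph_image` above labels every pixel 0 for background and alternates line regions between 1 and 2 via `index % 2 + 1`, which keeps adjacent lines separable in the mask. A minimal NumPy sketch of that labeling scheme, with made-up region coordinates:

```python
# Hedged sketch of the alternating line-label scheme (0=background,
# 1=odd-numbered line, 2=even-numbered line); regions are made-up examples.
import numpy as np

line_regions = [
    {"x1": 10, "y1": 5, "x2": 90, "y2": 20},
    {"x1": 10, "y1": 25, "x2": 85, "y2": 40},
]

gt_image = np.zeros((64, 100), dtype=np.uint8)
for index, region in enumerate(line_regions):
    gt_image[region["y1"] : region["y2"], region["x1"] : region["x2"]] = index % 2 + 1

print(np.unique(gt_image))  # [0 1 2]
```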
-    import matplotlib.pyplot as plt - -    cmap = plt.get_cmap("Set1") -    image_crop_for_debug = np.dstack([image_crop, image_crop, image_crop]) -    for index, region in enumerate(line_regions): -        color = [255 * _ for _ in cmap(index)[:-1]] -        cv2.rectangle( -            image_crop_for_debug, -            (region["x1"], region["y1"] - min_y1 + buffer), -            (region["x2"], region["y2"] - min_y1 + buffer), -            color, -            3, -        ) -    image_crop_for_debug = cv2.resize( -        image_crop_for_debug, final_dims, interpolation=cv2.INTER_AREA -    ) -    util.write_image(image_crop_for_debug, DEBUG_CROPS_DIRNAME / f"{filename.stem}.jpg") - -    image_crop = cv2.resize(image_crop, final_dims, interpolation=cv2.INTER_AREA) -    util.write_image(image_crop, CROPS_DIRNAME / f"{filename.stem}.jpg") - -    gt_image = cv2.resize(gt_image, final_dims, interpolation=cv2.INTER_NEAREST) -    util.write_image(gt_image, GT_DIRNAME / f"{filename.stem}.png") - - -def _load_iam_paragraphs() -> None: -    logger.info("Loading IAM paragraph crops and ground truth from image files...") -    images = [] -    gt_images = [] -    ids = [] -    for filename in CROPS_DIRNAME.glob("*.jpg"): -        id_ = filename.stem -        image = util.read_image(filename, grayscale=True) -        image = 1.0 - image / 255 - -        gt_filename = GT_DIRNAME / f"{id_}.png" -        gt_image = util.read_image(gt_filename, grayscale=True) - -        images.append(image) -        gt_images.append(gt_image) -        ids.append(id_) -    images = np.array(images).astype(np.float32) -    gt_images = np.array(gt_images).astype(np.uint8) -    ids = np.array(ids) -    return images, gt_images, ids - - -@click.command() -@click.option( -    "--subsample_fraction", -    type=float, -    default=None, -    help="The subsampling factor of the dataset.", -) -def main(subsample_fraction: float) -> None: -    """Load dataset and print info.""" -    logger.info("Creating train set...") -    dataset = IamParagraphsDataset(train=True, subsample_fraction=subsample_fraction) -    dataset.load_or_generate_data() -    print(dataset) -    logger.info("Creating test set...") -    dataset = IamParagraphsDataset(subsample_fraction=subsample_fraction) -    dataset.load_or_generate_data() -    print(dataset) - - -if __name__ == "__main__": -    main() diff --git a/text_recognizer/data/util.py b/text_recognizer/data/util.py deleted file mode 100644 index da87756..0000000 --- a/text_recognizer/data/util.py +++ /dev/null @@ -1,209 +0,0 @@ -"""Util functions for datasets.""" -import hashlib -import json -import os -from pathlib import Path -import string -from typing import Dict, List, Optional, Union -from urllib.request import urlretrieve - -from loguru import logger -import numpy as np -import torch -from torch import Tensor -from torchvision.datasets import EMNIST -from tqdm import tqdm - -DATA_DIRNAME = Path(__file__).resolve().parents[3] / "data" -ESSENTIALS_FILENAME = Path(__file__).resolve().parents[0] / "emnist_essentials.json" - - -def save_emnist_essentials(emnsit_dataset: EMNIST = EMNIST) -> None: -    """Extract and saves EMNIST essentials.""" -    labels = emnsit_dataset.classes -    labels.sort() -    mapping = [(i, str(label)) for i, label in enumerate(labels)] -    essentials = { -        "mapping": mapping, -        "input_shape": tuple(np.array(emnsit_dataset[0][0]).shape[:]), -    } -    logger.info("Saving emnist essentials...") -    with open(ESSENTIALS_FILENAME, "w") as f: -        
json.dump(essentials, f) - - -def download_emnist() -> None: -    """Download the EMNIST dataset via the PyTorch class.""" -    logger.info(f"Data directory is: {DATA_DIRNAME}") -    dataset = EMNIST(root=DATA_DIRNAME, split="byclass", download=True) -    save_emnist_essentials(dataset) - - -class EmnistMapper: -    """Mapper between network output to Emnist character.""" - -    def __init__( -        self, -        pad_token: str, -        init_token: Optional[str] = None, -        eos_token: Optional[str] = None, -        lower: bool = False, -    ) -> None: -        """Loads the emnist essentials file with the mapping and input shape.""" -        self.init_token = init_token -        self.pad_token = pad_token -        self.eos_token = eos_token -        self.lower = lower - -        self.essentials = self._load_emnist_essentials() -        # Load dataset information. -        self._mapping = dict(self.essentials["mapping"]) -        self._augment_emnist_mapping() -        self._inverse_mapping = {v: k for k, v in self.mapping.items()} -        self._num_classes = len(self.mapping) -        self._input_shape = self.essentials["input_shape"] - -    def __call__(self, token: Union[str, int, np.uint8, Tensor]) -> Union[str, int]: -        """Maps the token to emnist character or character index. - -        If the token is an integer (index), the method will return the Emnist character corresponding to that index. -        If the token is a str (Emnist character), the method will return the corresponding index for that character. - -        Args: -            token (Union[str, int, np.uint8, Tensor]): Either a string or index (integer). - -        Returns: -            Union[str, int]: The mapping result. - -        Raises: -            KeyError: If the index or string does not exist in the mapping. 
- -        """ -        if ( -            (isinstance(token, np.uint8) or isinstance(token, int)) -            or torch.is_tensor(token) -            and int(token) in self.mapping -        ): -            return self.mapping[int(token)] -        elif isinstance(token, str) and token in self._inverse_mapping: -            return self._inverse_mapping[token] -        else: -            raise KeyError(f"Token {token} does not exist in the mappings.") - -    @property -    def mapping(self) -> Dict: -        """Returns the mapping between index and character.""" -        return self._mapping - -    @property -    def inverse_mapping(self) -> Dict: -        """Returns the mapping between character and index.""" -        return self._inverse_mapping - -    @property -    def num_classes(self) -> int: -        """Returns the number of classes in the dataset.""" -        return self._num_classes - -    @property -    def input_shape(self) -> List[int]: -        """Returns the input shape of the Emnist characters.""" -        return self._input_shape - -    def _load_emnist_essentials(self) -> Dict: -        """Load the EMNIST mapping.""" -        with open(str(ESSENTIALS_FILENAME)) as f: -            essentials = json.load(f) -        return essentials - -    def _augment_emnist_mapping(self) -> None: -        """Augment the mapping with extra symbols.""" -        # Extra symbols in IAM dataset -        if self.lower: -            self._mapping = { -                k: str(v) -                for k, v in enumerate(list(range(10)) + list(string.ascii_lowercase)) -            } - -        extra_symbols = [ -            " ", -            "!", -            '"', -            "#", -            "&", -            "'", -            "(", -            ")", -            "*", -            "+", -            ",", -            "-", -            ".", -            "/", -            ":", -            ";", -            "?", -        ] - -        # padding symbol, and acts as blank symbol as well. -        extra_symbols.append(self.pad_token) - -        if self.init_token is not None: -            extra_symbols.append(self.init_token) - -        if self.eos_token is not None: -            extra_symbols.append(self.eos_token) - -        max_key = max(self.mapping.keys()) -        extra_mapping = {} -        for i, symbol in enumerate(extra_symbols): -            extra_mapping[max_key + 1 + i] = symbol - -        self._mapping = {**self.mapping, **extra_mapping} - - -def compute_sha256(filename: Union[Path, str]) -> str: -    """Returns the SHA256 checksum of a file.""" -    with open(filename, "rb") as f: -        return hashlib.sha256(f.read()).hexdigest() - - -class TqdmUpTo(tqdm): -    """TQDM progress bar when downloading files. - -    From https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py - -    """ - -    def update_to( -        self, blocks: int = 1, block_size: int = 1, total_size: Optional[int] = None -    ) -> None: -        """Updates the progress bar. - -        Args: -            blocks (int): Number of blocks transferred so far. Defaults to 1. -            block_size (int): Size of each block, in tqdm units. Defaults to 1. -            total_size (Optional[int]): Total size in tqdm units. Defaults to None. 
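`EmnistMapper` above is callable in both directions: an integer index yields its character and a character yields its index, with pad/init/eos symbols appended after the base EMNIST classes. A cut-down sketch of that bidirectional mapping — a simplification for illustration, not the deleted implementation:

```python
# Hedged, cut-down sketch of a bidirectional token mapper in the spirit of
# the deleted EmnistMapper; the vocabulary below is a toy assumption.
from typing import Union

class Mapper:
    def __init__(self, tokens: list, pad_token: str = "_") -> None:
        tokens = tokens + [pad_token]  # pad doubles as the blank symbol
        self.mapping = dict(enumerate(tokens))
        self.inverse_mapping = {v: k for k, v in self.mapping.items()}

    def __call__(self, token: Union[int, str]) -> Union[str, int]:
        if isinstance(token, int) and token in self.mapping:
            return self.mapping[token]
        if isinstance(token, str) and token in self.inverse_mapping:
            return self.inverse_mapping[token]
        raise KeyError(f"Token {token} does not exist in the mappings.")

mapper = Mapper(list("abc"))
assert mapper(0) == "a" and mapper("a") == 0
```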
-        """ -        if total_size is not None: -            self.total = total_size  # pylint: disable=attribute-defined-outside-init -        self.update(blocks * block_size - self.n) - - -def download_url(url: str, filename: str) -> None: -    """Downloads a file from url to filename, with a progress bar.""" -    with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t: -        urlretrieve(url, filename, reporthook=t.update_to, data=None)  # nosec - - -def _download_raw_dataset(metadata: Dict) -> None: -    if os.path.exists(metadata["filename"]): -        return -    logger.info(f"Downloading raw dataset from {metadata['url']}...") -    download_url(metadata["url"], metadata["filename"]) -    logger.info("Computing SHA-256...") -    sha256 = compute_sha256(metadata["filename"]) -    if sha256 != metadata["sha256"]: -        raise ValueError( -            "Downloaded data file SHA-256 does not match that listed in metadata document." -        ) diff --git a/text_recognizer/line_predictor.py b/text_recognizer/line_predictor.py deleted file mode 100644 index 8e348fe..0000000 --- a/text_recognizer/line_predictor.py +++ /dev/null @@ -1,28 +0,0 @@ -"""LinePredictor class.""" -import importlib -from typing import Tuple, Union - -import numpy as np -from torch import nn - -from text_recognizer import datasets, networks -from text_recognizer.models import TransformerModel -from text_recognizer.util import read_image - - -class LinePredictor: -    """Given an image of a line of handwritten text, recognizes the text content.""" - -    def __init__(self, dataset: str, network_fn: str) -> None: -        network_fn = getattr(networks, network_fn) -        dataset = getattr(datasets, dataset) -        self.model = TransformerModel(network_fn=network_fn, dataset=dataset) -        self.model.eval() - -    def predict(self, image_or_filename: Union[np.ndarray, str]) -> Tuple[str, float]: -        """Predict on a single images contianing a handwritten character.""" -        if isinstance(image_or_filename, str): -            image = read_image(image_or_filename, grayscale=True) -        else: -            image = image_or_filename -        return self.model.predict_on_image(image) diff --git a/text_recognizer/models/__init__.py b/text_recognizer/models/__init__.py index 7647d7e..e69de29 100644 --- a/text_recognizer/models/__init__.py +++ b/text_recognizer/models/__init__.py @@ -1,18 +0,0 @@ -"""Model modules.""" -from .base import Model -from .character_model import CharacterModel -from .crnn_model import CRNNModel -from .ctc_transformer_model import CTCTransformerModel -from .segmentation_model import SegmentationModel -from .transformer_model import TransformerModel -from .vqvae_model import VQVAEModel - -__all__ = [ -    "CharacterModel", -    "CRNNModel", -    "CTCTransformerModel", -    "Model", -    "SegmentationModel", -    "TransformerModel", -    "VQVAEModel", -] diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py deleted file mode 100644 index 70f4cdb..0000000 --- a/text_recognizer/models/base.py +++ /dev/null @@ -1,455 +0,0 @@ -"""Abstract Model class for PyTorch neural networks.""" - -from abc import ABC, abstractmethod -from glob import glob -import importlib -from pathlib import Path -import re -import shutil -from typing import Callable, Dict, List, Optional, Tuple, Type, Union - -from loguru import logger -import torch -from torch import nn -from torch import Tensor -from torch.optim.swa_utils import AveragedModel, SWALR -from 
torch.utils.data import DataLoader, Dataset, random_split -from torchsummary import summary - -from text_recognizer import datasets -from text_recognizer import networks -from text_recognizer.datasets import EmnistMapper - -WEIGHT_DIRNAME = Path(__file__).parents[1].resolve() / "weights" - - -class Model(ABC): -    """Abstract Model class with composition of different parts defining a PyTorch neural network.""" - -    def __init__( -        self, -        network_fn: str, -        dataset: str, -        network_args: Optional[Dict] = None, -        dataset_args: Optional[Dict] = None, -        metrics: Optional[Dict] = None, -        criterion: Optional[Callable] = None, -        criterion_args: Optional[Dict] = None, -        optimizer: Optional[Callable] = None, -        optimizer_args: Optional[Dict] = None, -        lr_scheduler: Optional[Callable] = None, -        lr_scheduler_args: Optional[Dict] = None, -        swa_args: Optional[Dict] = None, -        device: Optional[str] = None, -    ) -> None: -        """Base class, to be inherited by model for specific type of data. - -        Args: -            network_fn (str): The name of network. -            dataset (str): The name dataset class. -            network_args (Optional[Dict]): Arguments for the network. Defaults to None. -            dataset_args (Optional[Dict]):  Arguments for the dataset. -            metrics (Optional[Dict]): Metrics to evaluate the performance with. Defaults to None. -            criterion (Optional[Callable]): The criterion to evaluate the performance of the network. -                Defaults to None. -            criterion_args (Optional[Dict]): Dict of arguments for criterion. Defaults to None. -            optimizer (Optional[Callable]): The optimizer for updating the weights. Defaults to None. -            optimizer_args (Optional[Dict]): Dict of arguments for optimizer. Defaults to None. -            lr_scheduler (Optional[Callable]): A PyTorch learning rate scheduler. Defaults to None. -            lr_scheduler_args (Optional[Dict]): Dict of arguments for learning rate scheduler. Defaults to -                None. -            swa_args (Optional[Dict]): Dict of arguments for stochastic weight averaging. Defaults to -                None. -            device (Optional[str]): Name of the device to train on. Defaults to None. - -        """ -        self._name = f"{self.__class__.__name__}_{dataset}_{network_fn}" -        # Has to be set in subclass. -        self._mapper = None - -        # Placeholder. -        self._input_shape = None - -        self.dataset_name = dataset -        self.dataset = None -        self.dataset_args = dataset_args - -        # Placeholders for datasets. -        self.train_dataset = None -        self.val_dataset = None -        self.test_dataset = None - -        # Stochastic Weight Averaging placeholders. -        self.swa_args = swa_args -        self._swa_scheduler = None -        self._swa_network = None -        self._use_swa_model = False - -        # Experiment directory. -        self.model_dir = None - -        # Flag for configured model. -        self.is_configured = False -        self.data_prepared = False - -        # Flag for stopping training. -        self.stop_training = False - -        self._metrics = metrics if metrics is not None else None - -        # Set the device. 
-        self._device = ( -            torch.device("cuda" if torch.cuda.is_available() else "cpu") -            if device is None -            else device -        ) - -        # Configure network. -        self._network = None -        self._network_args = network_args -        self._configure_network(network_fn) - -        # Place network on device (GPU). -        self.to_device() - -        # Loss and Optimizer placeholders for before loading. -        self._criterion = criterion -        self.criterion_args = criterion_args - -        self._optimizer = optimizer -        self.optimizer_args = optimizer_args - -        self._lr_scheduler = lr_scheduler -        self.lr_scheduler_args = lr_scheduler_args - -    def configure_model(self) -> None: -        """Configures criterion and optimizers.""" -        if not self.is_configured: -            self._configure_criterion() -            self._configure_optimizers() - -            # Set this flag to true to prevent the model from configuring again. -            self.is_configured = True - -    def prepare_data(self) -> None: -        """Prepare data for training.""" -        # TODO add downloading. -        if not self.data_prepared: -            # Load dataset module. -            self.dataset = getattr(datasets, self.dataset_name) - -            # Load train dataset. -            train_dataset = self.dataset(train=True, **self.dataset_args["args"]) -            train_dataset.load_or_generate_data() - -            # Set input shape. -            self._input_shape = train_dataset.input_shape - -            # Split train dataset into a training and validation partition. -            dataset_len = len(train_dataset) -            train_len = int( -                self.dataset_args["train_args"]["train_fraction"] * dataset_len -            ) -            val_len = dataset_len - train_len -            self.train_dataset, self.val_dataset = random_split( -                train_dataset, lengths=[train_len, val_len] -            ) - -            # Load test dataset. -            self.test_dataset = self.dataset(train=False, **self.dataset_args["args"]) -            self.test_dataset.load_or_generate_data() - -            # Set the flag to true to disable ability to load data again. 
-            self.data_prepared = True - -    def train_dataloader(self) -> DataLoader: -        """Returns data loader for training set.""" -        return DataLoader( -            self.train_dataset, -            batch_size=self.dataset_args["train_args"]["batch_size"], -            num_workers=self.dataset_args["train_args"]["num_workers"], -            shuffle=True, -            pin_memory=True, -        ) - -    def val_dataloader(self) -> DataLoader: -        """Returns data loader for validation set.""" -        return DataLoader( -            self.val_dataset, -            batch_size=self.dataset_args["train_args"]["batch_size"], -            num_workers=self.dataset_args["train_args"]["num_workers"], -            shuffle=True, -            pin_memory=True, -        ) - -    def test_dataloader(self) -> DataLoader: -        """Returns data loader for test set.""" -        return DataLoader( -            self.test_dataset, -            batch_size=self.dataset_args["train_args"]["batch_size"], -            num_workers=self.dataset_args["train_args"]["num_workers"], -            shuffle=False, -            pin_memory=True, -        ) - -    def _configure_network(self, network_fn: Type[nn.Module]) -> None: -        """Loads the network.""" -        # If no network arguments are given, load pretrained weights if they exist. -        # Load network module. -        network_fn = getattr(networks, network_fn) -        if self._network_args is None: -            self.load_weights(network_fn) -        else: -            self._network = network_fn(**self._network_args) - -    def _configure_criterion(self) -> None: -        """Loads the criterion.""" -        self._criterion = ( -            self._criterion(**self.criterion_args) -            if self._criterion is not None -            else None -        ) - -    def _configure_optimizers(self,) -> None: -        """Loads the optimizers.""" -        if self._optimizer is not None: -            self._optimizer = self._optimizer( -                self._network.parameters(), **self.optimizer_args -            ) -        else: -            self._optimizer = None - -        if self._optimizer and self._lr_scheduler is not None: -            if "steps_per_epoch" in self.lr_scheduler_args: -                self.lr_scheduler_args["steps_per_epoch"] = len(self.train_dataloader()) - -            # Assume lr scheduler should update at each epoch if not specified. 
-            if "interval" not in self.lr_scheduler_args: -                interval = "epoch" -            else: -                interval = self.lr_scheduler_args.pop("interval") -            self._lr_scheduler = { -                "lr_scheduler": self._lr_scheduler( -                    self._optimizer, **self.lr_scheduler_args -                ), -                "interval": interval, -            } - -        if self.swa_args is not None: -            self._swa_scheduler = { -                "swa_scheduler": SWALR(self._optimizer, swa_lr=self.swa_args["lr"]), -                "swa_start": self.swa_args["start"], -            } -            self._swa_network = AveragedModel(self._network).to(self.device) - -    @property -    def name(self) -> str: -        """Returns the name of the model.""" -        return self._name - -    @property -    def input_shape(self) -> Tuple[int, ...]: -        """The input shape.""" -        return self._input_shape - -    @property -    def mapper(self) -> EmnistMapper: -        """Returns the mapper that maps between ints and chars.""" -        return self._mapper - -    @property -    def mapping(self) -> Dict: -        """Returns the mapping between network output and Emnist character.""" -        return self._mapper.mapping if self._mapper is not None else None - -    def eval(self) -> None: -        """Sets the network to evaluation mode.""" -        self._network.eval() - -    def train(self) -> None: -        """Sets the network to train mode.""" -        self._network.train() - -    @property -    def device(self) -> str: -        """Device where the weights are stored, i.e. cpu or cuda.""" -        return self._device - -    @property -    def metrics(self) -> Optional[Dict]: -        """Metrics.""" -        return self._metrics - -    @property -    def criterion(self) -> Optional[Callable]: -        """Criterion.""" -        return self._criterion - -    @property -    def optimizer(self) -> Optional[Callable]: -        """Optimizer.""" -        return self._optimizer - -    @property -    def lr_scheduler(self) -> Optional[Dict]: -        """Returns a directory with the learning rate scheduler.""" -        return self._lr_scheduler - -    @property -    def swa_scheduler(self) -> Optional[Dict]: -        """Returns a directory with the stochastic weight averaging scheduler.""" -        return self._swa_scheduler - -    @property -    def swa_network(self) -> Optional[Callable]: -        """Returns the stochastic weight averaging network.""" -        return self._swa_network - -    @property -    def network(self) -> Type[nn.Module]: -        """Neural network.""" -        # Returns the SWA network if available. 
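`_configure_optimizers` above pairs an `AveragedModel` copy of the network with an `SWALR` scheduler and a `swa_start` epoch. Outside the class, the same machinery follows the standard PyTorch SWA recipe; the epoch counts, learning rates, and stand-in training step below are assumptions:

```python
# Hedged sketch of the SWA recipe the deleted model wraps: AveragedModel
# accumulates weights, SWALR anneals the lr once averaging starts.
import torch
from torch import nn
from torch.optim.swa_utils import AveragedModel, SWALR

network = nn.Linear(10, 2)
optimizer = torch.optim.SGD(network.parameters(), lr=0.1)
swa_network = AveragedModel(network)
swa_scheduler = SWALR(optimizer, swa_lr=0.05)
swa_start = 5  # assumed epoch at which averaging begins

for epoch in range(10):
    loss = network(torch.randn(4, 10)).sum()  # stand-in training step
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if epoch >= swa_start:
        swa_network.update_parameters(network)
        swa_scheduler.step()
# The usual final step, given a train loader:
# torch.optim.swa_utils.update_bn(train_loader, swa_network)
```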
-        return self._network - -    @property -    def weights_filename(self) -> str: -        """Filepath to the network weights.""" -        WEIGHT_DIRNAME.mkdir(parents=True, exist_ok=True) -        return str(WEIGHT_DIRNAME / f"{self._name}_weights.pt") - -    def use_swa_model(self) -> None: -        """Set to use predictions from SWA model.""" -        if self.swa_network is not None: -            self._use_swa_model = True - -    def forward(self, x: Tensor) -> Tensor: -        """Feedforward pass with the network.""" -        if self._use_swa_model: -            return self.swa_network(x) -        else: -            return self.network(x) - -    def summary( -        self, -        input_shape: Optional[Union[List, Tuple]] = None, -        depth: int = 3, -        device: Optional[str] = None, -    ) -> None: -        """Prints a summary of the network architecture.""" -        device = self.device if device is None else device - -        if input_shape is not None: -            summary(self.network, input_shape, depth=depth, device=device) -        elif self._input_shape is not None: -            input_shape = tuple(self._input_shape) -            summary(self.network, input_shape, depth=depth, device=device) -        else: -            logger.warning("Could not print summary as input shape is not set.") - -    def to_device(self) -> None: -        """Places the network on the device (GPU).""" -        self._network.to(self._device) - -    def _get_state_dict(self) -> Dict: -        """Get the state dict of the model.""" -        state = {"model_state": self._network.state_dict()} - -        if self._optimizer is not None: -            state["optimizer_state"] = self._optimizer.state_dict() - -        if self._lr_scheduler is not None: -            state["scheduler_state"] = self._lr_scheduler["lr_scheduler"].state_dict() -            state["scheduler_interval"] = self._lr_scheduler["interval"] - -        if self._swa_network is not None: -            state["swa_network"] = self._swa_network.state_dict() - -        return state - -    def load_from_checkpoint(self, checkpoint_path: Union[str, Path]) -> None: -        """Load a previously saved checkpoint. - -        Args: -            checkpoint_path (Path): Path to the experiment with the checkpoint. - -        """ -        checkpoint_path = Path(checkpoint_path) -        self.prepare_data() -        self.configure_model() -        logger.debug("Loading checkpoint...") -        if not checkpoint_path.exists(): -            logger.debug("File does not exist {str(checkpoint_path)}") - -        checkpoint = torch.load(str(checkpoint_path), map_location=self.device) -        self._network.load_state_dict(checkpoint["model_state"]) - -        if self._optimizer is not None: -            self._optimizer.load_state_dict(checkpoint["optimizer_state"]) - -        if self._lr_scheduler is not None: -            # Does not work when loading from previous checkpoint and trying to train beyond the last max epochs -            # with OneCycleLR. 
-            if self._lr_scheduler["lr_scheduler"].__class__.__name__ != "OneCycleLR": -                self._lr_scheduler["lr_scheduler"].load_state_dict( -                    checkpoint["scheduler_state"] -                ) -                self._lr_scheduler["interval"] = checkpoint["scheduler_interval"] - -        if self._swa_network is not None: -            self._swa_network.load_state_dict(checkpoint["swa_network"]) - -    def save_checkpoint( -        self, checkpoint_path: Path, is_best: bool, epoch: int, val_metric: str -    ) -> None: -        """Saves a checkpoint of the model. - -        Args: -            checkpoint_path (Path): Path to the experiment with the checkpoint. -            is_best (bool): If it is the currently best model. -            epoch (int): The epoch of the checkpoint. -            val_metric (str): Validation metric. - -        """ -        state = self._get_state_dict() -        state["is_best"] = is_best -        state["epoch"] = epoch -        state["network_args"] = self._network_args - -        checkpoint_path.mkdir(parents=True, exist_ok=True) - -        logger.debug("Saving checkpoint...") -        filepath = str(checkpoint_path / "last.pt") -        torch.save(state, filepath) - -        if is_best: -            logger.debug( -                f"Found a new best {val_metric}. Saving best checkpoint and weights." -            ) -            shutil.copyfile(filepath, str(checkpoint_path / "best.pt")) - -    def load_weights(self, network_fn: Optional[Type[nn.Module]] = None) -> None: -        """Load the network weights.""" -        logger.debug("Loading network with pretrained weights.") -        filename = glob(self.weights_filename)[0] -        if not filename: -            raise FileNotFoundError( -                f"Could not find any pretrained weights at {self.weights_filename}" -            ) -        # Loading state directory. -        state_dict = torch.load(filename, map_location=torch.device(self._device)) -        self._network_args = state_dict["network_args"] -        weights = state_dict["model_state"] - -        # Initializes the network with trained weights. 
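The checkpoint helpers above persist the network, optimizer, scheduler, and SWA weights in a single dict, and deliberately skip restoring `OneCycleLR` state. A minimal round-trip sketch of that layout; the model is a toy, while the keys mirror the deleted code:

```python
# Hedged sketch of the checkpoint layout used above: toy model, real keys.
import torch
from torch import nn

network = nn.Linear(10, 2)
optimizer = torch.optim.SGD(network.parameters(), lr=0.1)

state = {
    "model_state": network.state_dict(),
    "optimizer_state": optimizer.state_dict(),
    "epoch": 3,
    "is_best": True,
}
torch.save(state, "last.pt")

checkpoint = torch.load("last.pt", map_location="cpu")
network.load_state_dict(checkpoint["model_state"])
optimizer.load_state_dict(checkpoint["optimizer_state"])
```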
-        if network_fn is not None: -            self._network = network_fn(**self._network_args) -        self._network.load_state_dict(weights) - -        if "swa_network" in state_dict: -            self._swa_network = AveragedModel(self._network).to(self.device) -            self._swa_network.load_state_dict(state_dict["swa_network"]) - -    def save_weights(self, path: Path) -> None: -        """Save the network weights.""" -        logger.debug("Saving the best network weights.") -        shutil.copyfile(str(path / "best.pt"), self.weights_filename) diff --git a/text_recognizer/models/character_model.py b/text_recognizer/models/character_model.py deleted file mode 100644 index f9944f3..0000000 --- a/text_recognizer/models/character_model.py +++ /dev/null @@ -1,88 +0,0 @@ -"""Defines the CharacterModel class.""" -from typing import Callable, Dict, Optional, Tuple, Type, Union - -import numpy as np -import torch -from torch import nn -from torch.utils.data import Dataset -from torchvision.transforms import ToTensor - -from text_recognizer.datasets import EmnistMapper -from text_recognizer.models.base import Model - - -class CharacterModel(Model): -    """Model for predicting characters from images.""" - -    def __init__( -        self, -        network_fn: Type[nn.Module], -        dataset: Type[Dataset], -        network_args: Optional[Dict] = None, -        dataset_args: Optional[Dict] = None, -        metrics: Optional[Dict] = None, -        criterion: Optional[Callable] = None, -        criterion_args: Optional[Dict] = None, -        optimizer: Optional[Callable] = None, -        optimizer_args: Optional[Dict] = None, -        lr_scheduler: Optional[Callable] = None, -        lr_scheduler_args: Optional[Dict] = None, -        swa_args: Optional[Dict] = None, -        device: Optional[str] = None, -    ) -> None: -        """Initializes the CharacterModel.""" - -        super().__init__( -            network_fn, -            dataset, -            network_args, -            dataset_args, -            metrics, -            criterion, -            criterion_args, -            optimizer, -            optimizer_args, -            lr_scheduler, -            lr_scheduler_args, -            swa_args, -            device, -        ) -        self.pad_token = dataset_args["args"]["pad_token"] -        if self._mapper is None: -            self._mapper = EmnistMapper(pad_token=self.pad_token,) -        self.tensor_transform = ToTensor() -        self.softmax = nn.Softmax(dim=0) - -    @torch.no_grad() -    def predict_on_image( -        self, image: Union[np.ndarray, torch.Tensor] -    ) -> Tuple[str, float]: -        """Character prediction on an image. - -        Args: -            image (Union[np.ndarray, torch.Tensor]): An image containing a character. - -        Returns: -            Tuple[str, float]: The predicted character and the confidence in the prediction. - -        """ -        self.eval() - -        if image.dtype == np.uint8: -            # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1]. -            image = self.tensor_transform(image) -        if image.dtype == torch.uint8: -            # If the image is an unscaled tensor. -            image = image.type("torch.FloatTensor") / 255 - -        # Put the image tensor on the device the model weights are on. 
-        image = image.to(self.device) -        logits = self.forward(image) - -        prediction = self.softmax(logits.squeeze(0)) - -        index = int(torch.argmax(prediction, dim=0)) -        confidence_of_prediction = prediction[index] -        predicted_character = self.mapper(index) - -        return predicted_character, confidence_of_prediction diff --git a/text_recognizer/models/crnn_model.py b/text_recognizer/models/crnn_model.py deleted file mode 100644 index 1e01a83..0000000 --- a/text_recognizer/models/crnn_model.py +++ /dev/null @@ -1,119 +0,0 @@ -"""Defines the CRNNModel class.""" -from typing import Callable, Dict, Optional, Tuple, Type, Union - -import numpy as np -import torch -from torch import nn -from torch import Tensor -from torch.utils.data import Dataset -from torchvision.transforms import ToTensor - -from text_recognizer.datasets import EmnistMapper -from text_recognizer.models.base import Model -from text_recognizer.networks import greedy_decoder - - -class CRNNModel(Model): -    """Model for predicting a sequence of characters from an image of a text line.""" - -    def __init__( -        self, -        network_fn: Type[nn.Module], -        dataset: Type[Dataset], -        network_args: Optional[Dict] = None, -        dataset_args: Optional[Dict] = None, -        metrics: Optional[Dict] = None, -        criterion: Optional[Callable] = None, -        criterion_args: Optional[Dict] = None, -        optimizer: Optional[Callable] = None, -        optimizer_args: Optional[Dict] = None, -        lr_scheduler: Optional[Callable] = None, -        lr_scheduler_args: Optional[Dict] = None, -        swa_args: Optional[Dict] = None, -        device: Optional[str] = None, -    ) -> None: -        super().__init__( -            network_fn, -            dataset, -            network_args, -            dataset_args, -            metrics, -            criterion, -            criterion_args, -            optimizer, -            optimizer_args, -            lr_scheduler, -            lr_scheduler_args, -            swa_args, -            device, -        ) - -        self.pad_token = dataset_args["args"]["pad_token"] -        if self._mapper is None: -            self._mapper = EmnistMapper(pad_token=self.pad_token,) -        self.tensor_transform = ToTensor() - -    def criterion(self, output: Tensor, targets: Tensor) -> Tensor: -        """Computes the CTC loss. - -        Args: -            output (Tensor): Model predictions. -            targets (Tensor): Correct output sequence. - -        Returns: -            Tensor: The CTC loss. - -        """ - -        # Input lengths on the form [T, B] -        input_lengths = torch.full( -            size=(output.shape[1],), fill_value=output.shape[0], dtype=torch.long, -        ) - -        # Configure target tensors for ctc loss. -        targets_ = Tensor([]).to(self.device) -        target_lengths = [] -        for t in targets: -            # Remove padding symbol as it acts as the blank symbol. 
-            t = t[t < 79] -            targets_ = torch.cat([targets_, t]) -            target_lengths.append(len(t)) - -        targets = targets_.type(dtype=torch.long) -        target_lengths = ( -            torch.Tensor(target_lengths).type(dtype=torch.long).to(self.device) -        ) - -        return self._criterion(output, targets, input_lengths, target_lengths) - -    @torch.no_grad() -    def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]: -        """Predict on a single input.""" -        self.eval() - -        if image.dtype == np.uint8: -            # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1]. -            image = self.tensor_transform(image) - -        # Rescale image between 0 and 1. -        if image.dtype == torch.uint8: -            # If the image is an unscaled tensor. -            image = image.type("torch.FloatTensor") / 255 - -        # Put the image tensor on the device the model weights are on. -        image = image.to(self.device) -        log_probs = self.forward(image) - -        raw_pred, _ = greedy_decoder( -            predictions=log_probs, -            character_mapper=self.mapper, -            blank_label=79, -            collapse_repeated=True, -        ) - -        log_probs, _ = log_probs.max(dim=2) - -        predicted_characters = "".join(raw_pred[0]) -        confidence_of_prediction = log_probs.cumprod(dim=0)[-1].item() - -        return predicted_characters, confidence_of_prediction diff --git a/text_recognizer/models/ctc_transformer_model.py b/text_recognizer/models/ctc_transformer_model.py deleted file mode 100644 index 25925f2..0000000 --- a/text_recognizer/models/ctc_transformer_model.py +++ /dev/null @@ -1,120 +0,0 @@ -"""Defines the CTC Transformer Model class.""" -from typing import Callable, Dict, Optional, Tuple, Type, Union - -import numpy as np -import torch -from torch import nn -from torch import Tensor -from torch.utils.data import Dataset -from torchvision.transforms import ToTensor - -from text_recognizer.datasets import EmnistMapper -from text_recognizer.models.base import Model -from text_recognizer.networks import greedy_decoder - - -class CTCTransformerModel(Model): -    """Model for predicting a sequence of characters from an image of a text line.""" - -    def __init__( -        self, -        network_fn: Type[nn.Module], -        dataset: Type[Dataset], -        network_args: Optional[Dict] = None, -        dataset_args: Optional[Dict] = None, -        metrics: Optional[Dict] = None, -        criterion: Optional[Callable] = None, -        criterion_args: Optional[Dict] = None, -        optimizer: Optional[Callable] = None, -        optimizer_args: Optional[Dict] = None, -        lr_scheduler: Optional[Callable] = None, -        lr_scheduler_args: Optional[Dict] = None, -        swa_args: Optional[Dict] = None, -        device: Optional[str] = None, -    ) -> None: -        super().__init__( -            network_fn, -            dataset, -            network_args, -            dataset_args, -            metrics, -            criterion, -            criterion_args, -            optimizer, -            optimizer_args, -            lr_scheduler, -            lr_scheduler_args, -            swa_args, -            device, -        ) -        self.pad_token = dataset_args["args"]["pad_token"] -        self.lower = dataset_args["args"]["lower"] - -        if self._mapper is None: -            self._mapper = EmnistMapper(pad_token=self.pad_token, lower=self.lower,) - 
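The `criterion` methods in `CRNNModel` above and `CTCTransformerModel` here share one convention: predictions come as `[T, B, C]` log-probabilities, padded targets are flattened with per-sample lengths, and the padding index doubles as the CTC blank. A self-contained `nn.CTCLoss` sketch of that setup, using the CRNN blank index 79 and random tensors:

```python
# Hedged sketch of the CTC setup used above: log-probs are [T, B, C],
# targets are flattened with per-sample lengths, and the padding index
# (79, per the deleted CRNN code) doubles as the blank. Tensors are random.
import torch
from torch import nn

T, B, C = 50, 4, 80  # time steps, batch size, classes (blank included)
log_probs = torch.randn(T, B, C).log_softmax(2)

padded_targets = torch.randint(0, 79, (B, 20))
padded_targets[:, 15:] = 79  # simulate padding at the end of each sequence

targets, target_lengths = [], []
for t in padded_targets:
    t = t[t < 79]  # drop padding, which acts as the blank symbol
    targets.append(t)
    target_lengths.append(len(t))

targets = torch.cat(targets).long()
target_lengths = torch.tensor(target_lengths, dtype=torch.long)
input_lengths = torch.full((B,), T, dtype=torch.long)

ctc = nn.CTCLoss(blank=79)
loss = ctc(log_probs, targets, input_lengths, target_lengths)
```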
-        self.tensor_transform = ToTensor() - -    def criterion(self, output: Tensor, targets: Tensor) -> Tensor: -        """Computes the CTC loss. - -        Args: -            output (Tensor): Model predictions. -            targets (Tensor): Correct output sequence. - -        Returns: -            Tensor: The CTC loss. - -        """ -        # Input lengths on the form [T, B] -        input_lengths = torch.full( -            size=(output.shape[1],), fill_value=output.shape[0], dtype=torch.long, -        ) - -        # Configure target tensors for ctc loss. -        targets_ = Tensor([]).to(self.device) -        target_lengths = [] -        for t in targets: -            # Remove padding symbol as it acts as the blank symbol. -            t = t[t < 53] -            targets_ = torch.cat([targets_, t]) -            target_lengths.append(len(t)) - -        targets = targets_.type(dtype=torch.long) -        target_lengths = ( -            torch.Tensor(target_lengths).type(dtype=torch.long).to(self.device) -        ) - -        return self._criterion(output, targets, input_lengths, target_lengths) - -    @torch.no_grad() -    def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]: -        """Predict on a single input.""" -        self.eval() - -        if image.dtype == np.uint8: -            # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1]. -            image = self.tensor_transform(image) - -        # Rescale image between 0 and 1. -        if image.dtype == torch.uint8: -            # If the image is an unscaled tensor. -            image = image.type("torch.FloatTensor") / 255 - -        # Put the image tensor on the device the model weights are on. -        image = image.to(self.device) -        log_probs = self.forward(image) - -        raw_pred, _ = greedy_decoder( -            predictions=log_probs, -            character_mapper=self.mapper, -            blank_label=53, -            collapse_repeated=True, -        ) - -        log_probs, _ = log_probs.max(dim=2) - -        predicted_characters = "".join(raw_pred[0]) -        confidence_of_prediction = log_probs.cumprod(dim=0)[-1].item() - -        return predicted_characters, confidence_of_prediction diff --git a/text_recognizer/models/segmentation_model.py b/text_recognizer/models/segmentation_model.py deleted file mode 100644 index 613108a..0000000 --- a/text_recognizer/models/segmentation_model.py +++ /dev/null @@ -1,75 +0,0 @@ -"""Segmentation model for detecting and segmenting lines.""" -from typing import Callable, Dict, Optional, Type, Union - -import numpy as np -import torch -from torch import nn -from torch import Tensor -from torch.utils.data import Dataset -from torchvision.transforms import ToTensor - -from text_recognizer.models.base import Model - - -class SegmentationModel(Model): -    """Model for segmenting lines in an image.""" - -    def __init__( -        self, -        network_fn: str, -        dataset: str, -        network_args: Optional[Dict] = None, -        dataset_args: Optional[Dict] = None, -        metrics: Optional[Dict] = None, -        criterion: Optional[Callable] = None, -        criterion_args: Optional[Dict] = None, -        optimizer: Optional[Callable] = None, -        optimizer_args: Optional[Dict] = None, -        lr_scheduler: Optional[Callable] = None, -        lr_scheduler_args: Optional[Dict] = None, -        swa_args: Optional[Dict] = None, -        device: Optional[str] = None, -    ) -> None: -        super().__init__( -   
         network_fn, -            dataset, -            network_args, -            dataset_args, -            metrics, -            criterion, -            criterion_args, -            optimizer, -            optimizer_args, -            lr_scheduler, -            lr_scheduler_args, -            swa_args, -            device, -        ) -        self.tensor_transform = ToTensor() -        self.softmax = nn.Softmax(dim=2) - -    @torch.no_grad() -    def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tensor: -        """Predict on a single input.""" -        self.eval() - -        if image.dtype is np.uint8: -            # Converts an image with range [0, 255] with to PyTorch Tensor with range [0, 1]. -            image = self.tensor_transform(image) - -        # Rescale image between 0 and 1. -        if image.dtype is torch.uint8 or image.dtype is torch.int64: -            # If the image is an unscaled tensor. -            image = image.type("torch.FloatTensor") / 255 - -        if not torch.is_tensor(image): -            image = Tensor(image) - -        # Put the image tensor on the device the model weights are on. -        image = image.to(self.device) - -        logits = self.forward(image) - -        segmentation_mask = torch.argmax(logits, dim=1) - -        return segmentation_mask diff --git a/text_recognizer/models/transformer_model.py b/text_recognizer/models/transformer_model.py deleted file mode 100644 index 3f63053..0000000 --- a/text_recognizer/models/transformer_model.py +++ /dev/null @@ -1,124 +0,0 @@ -"""Defines the CNN-Transformer class.""" -from typing import Callable, Dict, List, Optional, Tuple, Type, Union - -import numpy as np -import torch -from torch import nn -from torch import Tensor -from torch.utils.data import Dataset - -from text_recognizer.datasets import EmnistMapper -import text_recognizer.datasets.transforms as transforms -from text_recognizer.models.base import Model -from text_recognizer.networks import greedy_decoder - - -class TransformerModel(Model): -    """Model for predicting a sequence of characters from an image of a text line with a cnn-transformer.""" - -    def __init__( -        self, -        network_fn: str, -        dataset: str, -        network_args: Optional[Dict] = None, -        dataset_args: Optional[Dict] = None, -        metrics: Optional[Dict] = None, -        criterion: Optional[Callable] = None, -        criterion_args: Optional[Dict] = None, -        optimizer: Optional[Callable] = None, -        optimizer_args: Optional[Dict] = None, -        lr_scheduler: Optional[Callable] = None, -        lr_scheduler_args: Optional[Dict] = None, -        swa_args: Optional[Dict] = None, -        device: Optional[str] = None, -    ) -> None: -        super().__init__( -            network_fn, -            dataset, -            network_args, -            dataset_args, -            metrics, -            criterion, -            criterion_args, -            optimizer, -            optimizer_args, -            lr_scheduler, -            lr_scheduler_args, -            swa_args, -            device, -        ) -        self.init_token = dataset_args["args"]["init_token"] -        self.pad_token = dataset_args["args"]["pad_token"] -        self.eos_token = dataset_args["args"]["eos_token"] -        self.lower = dataset_args["args"]["lower"] -        self.max_len = 100 - -        if self._mapper is None: -            self._mapper = EmnistMapper( -                init_token=self.init_token, -                pad_token=self.pad_token, -   
             eos_token=self.eos_token, -                lower=self.lower, -            ) -        self.tensor_transform = transforms.Compose( -            [transforms.ToTensor(), transforms.Normalize(mean=[0.912], std=[0.168])] -        ) -        self.softmax = nn.Softmax(dim=2) - -    @torch.no_grad() -    def _generate_sentence(self, image: Tensor) -> Tuple[List, float]: -        src = self.network.extract_image_features(image) - -        # Added for vqvae transformer. -        if isinstance(src, tuple): -            src = src[0] - -        memory = self.network.encoder(src) - -        confidence_of_predictions = [] -        trg_indices = [self.mapper(self.init_token)] - -        for _ in range(self.max_len - 1): -            trg = torch.tensor(trg_indices, device=self.device)[None, :].long() -            trg = self.network.target_embedding(trg) -            logits = self.network.decoder(trg=trg, memory=memory, trg_mask=None) - -            # Convert logits to probabilities. -            probs = self.softmax(logits) - -            pred_token = probs.argmax(2)[:, -1].item() -            confidence = probs.max(2).values[:, -1].item() - -            trg_indices.append(pred_token) -            confidence_of_predictions.append(confidence) - -            if pred_token == self.mapper(self.eos_token): -                break - -        confidence = np.min(confidence_of_predictions) -        predicted_characters = "".join([self.mapper(x) for x in trg_indices[1:]]) - -        return predicted_characters, confidence - -    @torch.no_grad() -    def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]: -        """Predict on a single input.""" -        self.eval() - -        if image.dtype == np.uint8: -            # Converts an image with range [0, 255] to a PyTorch Tensor with range [0, 1]. -            image = self.tensor_transform(image) - -        # Rescale image between 0 and 1. -        if image.dtype == torch.uint8: -            # If the image is an unscaled tensor. -            image = image.type("torch.FloatTensor") / 255 - -        # Put the image tensor on the device the model weights are on. 
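# The `_generate_sentence` loop above is plain greedy autoregressive decoding: feed the
# tokens generated so far, take the argmax at the last position, and stop at the EOS token.
# A minimal standalone sketch of the same idea, assuming a hypothetical
# `decode_step(tokens) -> logits` callable that wraps the network's embedding and decoder:
import torch

@torch.no_grad()
def greedy_generate(decode_step, init_index, eos_index, device, max_len=100):
    """Greedily decode token indices; decode_step maps a [1, t] tensor to [1, t, vocab] logits."""
    indices = [init_index]
    confidences = []
    for _ in range(max_len - 1):
        tokens = torch.tensor(indices, device=device)[None, :].long()
        probs = decode_step(tokens).softmax(dim=2)
        # Only the distribution at the last position matters for the next token.
        confidences.append(probs[0, -1].max().item())
        indices.append(probs[0, -1].argmax().item())
        if indices[-1] == eos_index:
            break
    # Report the least confident step, as `_generate_sentence` does with np.min.
    return indices[1:], min(confidences)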
-        image = image.to(self.device) - -        (predicted_characters, confidence_of_prediction,) = self._generate_sentence( -            image -        ) - -        return predicted_characters, confidence_of_prediction diff --git a/text_recognizer/models/vqvae_model.py b/text_recognizer/models/vqvae_model.py deleted file mode 100644 index 70f6f1f..0000000 --- a/text_recognizer/models/vqvae_model.py +++ /dev/null @@ -1,80 +0,0 @@ -"""Defines the VQVAEModel class.""" -from typing import Callable, Dict, Optional, Tuple, Type, Union - -import numpy as np -import torch -from torch import nn -from torch.utils.data import Dataset -from torchvision.transforms import ToTensor - -from text_recognizer.datasets import EmnistMapper -from text_recognizer.models.base import Model - - -class VQVAEModel(Model): -    """Model for reconstructing images from codebook.""" - -    def __init__( -        self, -        network_fn: Type[nn.Module], -        dataset: Type[Dataset], -        network_args: Optional[Dict] = None, -        dataset_args: Optional[Dict] = None, -        metrics: Optional[Dict] = None, -        criterion: Optional[Callable] = None, -        criterion_args: Optional[Dict] = None, -        optimizer: Optional[Callable] = None, -        optimizer_args: Optional[Dict] = None, -        lr_scheduler: Optional[Callable] = None, -        lr_scheduler_args: Optional[Dict] = None, -        swa_args: Optional[Dict] = None, -        device: Optional[str] = None, -    ) -> None: -        """Initializes the VQVAEModel.""" - -        super().__init__( -            network_fn, -            dataset, -            network_args, -            dataset_args, -            metrics, -            criterion, -            criterion_args, -            optimizer, -            optimizer_args, -            lr_scheduler, -            lr_scheduler_args, -            swa_args, -            device, -        ) -        self.pad_token = dataset_args["args"]["pad_token"] -        if self._mapper is None: -            self._mapper = EmnistMapper(pad_token=self.pad_token,) -        self.tensor_transform = ToTensor() -        self.softmax = nn.Softmax(dim=0) - -    @torch.no_grad() -    def predict_on_image(self, image: Union[np.ndarray, torch.Tensor]) -> torch.Tensor: -        """Reconstruction of image. - -        Args: -            image (Union[np.ndarray, torch.Tensor]): An image containing a character. - -        Returns: -            torch.Tensor: The reconstructed image. - -        """ -        self.eval() - -        if image.dtype == np.uint8: -            # Converts an image with range [0, 255] to a PyTorch Tensor with range [0, 1]. -            image = self.tensor_transform(image) -        if image.dtype == torch.uint8: -            # If the image is an unscaled tensor. -            image = image.type("torch.FloatTensor") / 255 - -        # Put the image tensor on the device the model weights are on. 
-        image = image.to(self.device) -        image_reconstructed, _ = self.forward(image) - -        return image_reconstructed diff --git a/text_recognizer/networks/__init__.py b/text_recognizer/networks/__init__.py index 1521355..e69de29 100644 --- a/text_recognizer/networks/__init__.py +++ b/text_recognizer/networks/__init__.py @@ -1,43 +0,0 @@ -"""Network modules.""" -from .cnn import CNN -from .cnn_transformer import CNNTransformer -from .crnn import ConvolutionalRecurrentNetwork -from .ctc import greedy_decoder -from .densenet import DenseNet -from .lenet import LeNet -from .metrics import accuracy, cer, wer -from .mlp import MLP -from .residual_network import ResidualNetwork, ResidualNetworkEncoder -from .transducer import load_transducer_loss, TDS2d -from .transformer import Transformer -from .unet import UNet -from .util import sliding_window -from .vit import ViT -from .vq_transformer import VQTransformer -from .vqvae import VQVAE -from .wide_resnet import WideResidualNetwork - -__all__ = [ -    "accuracy", -    "cer", -    "CNN", -    "CNNTransformer", -    "ConvolutionalRecurrentNetwork", -    "DenseNet", -    "greedy_decoder", -    "MLP", -    "LeNet", -    "load_transducer_loss", -    "ResidualNetwork", -    "ResidualNetworkEncoder", -    "sliding_window", -    "UNet", -    "TDS2d", -    "Transformer", -    "ViT", -    "VQTransformer", -    "VQVAE", -    "wer", -    "WideResidualNetwork", -] diff --git a/text_recognizer/networks/beam.py b/text_recognizer/networks/beam.py deleted file mode 100644 index dccccdb..0000000 --- a/text_recognizer/networks/beam.py +++ /dev/null @@ -1,83 +0,0 @@ -"""Implementation of beam search decoder for a sequence to sequence network. - -Stolen from: https://github.com/budzianowski/PyTorch-Beam-Search-Decoding/blob/master/decode_beam.py - -""" -# from typing import List -# from queue import PriorityQueue - -# from loguru import logger -# import torch -# from torch import nn -# from torch import Tensor -# import torch.nn.functional as F - - -# class Node: -#     def __init__( -#         self, parent: Node, target_index: int, log_prob: Tensor, length: int -#     ) -> None: -#         self.parent = parent -#         self.target_index = target_index -#         self.log_prob = log_prob -#         self.length = length -#         self.reward = 0.0 - -#     def eval(self, alpha: float = 1.0) -> Tensor: -#         return self.log_prob / (self.length - 1 + 1e-6) + alpha * self.reward - - -# @torch.no_grad() -# def beam_decoder( -#     network, mapper, device, memory: Tensor = None, max_len: int = 97, -# ) -> Tensor: -#     beam_width = 10 -#     topk = 1  # How many sentences to generate. - -#     trg_indices = [mapper(mapper.init_token)] - -#     end_nodes = [] - -#     node = Node(None, trg_indices, 0, 1) -#     nodes = PriorityQueue() - -#     nodes.put((node.eval(), node)) -#     q_size = 1 - -#     # Beam search -#     for _ in range(max_len): -#         if q_size > 2000: -#             logger.warning("Could not decode input") -#             break - -#         # Fetch the best node. -#         score, n = nodes.get() -#         decoder_input = n.target_index - -#         if n.target_index == mapper(mapper.eos_token) and n.parent is not None: -#             end_nodes.append((score, n)) - -#             # If we reached the maximum number of sentences required. -#             if len(end_nodes) >= 1: -#                 break -#             else: -#                 continue - -#         # Forward pass with transformer. 
-#         trg = torch.tensor(trg_indices, device=device)[None, :].long() -#         trg = network.target_embedding(trg) -#         logits = network.decoder(trg=trg, memory=memory, trg_mask=None) -#         log_prob = F.log_softmax(logits, dim=2) - -#         log_prob, indices = torch.topk(log_prob, beam_width) - -#         for new_k in range(beam_width): -#             # TODO: continue from here -#             token_index = indices[0][new_k].view(1, -1) -#             log_p = log_prob[0][new_k].item() - -#             node = Node() - -#             pass - -#     pass diff --git a/text_recognizer/networks/cnn.py b/text_recognizer/networks/cnn.py deleted file mode 100644 index 1807bb9..0000000 --- a/text_recognizer/networks/cnn.py +++ /dev/null @@ -1,101 +0,0 @@ -"""Implementation of a simple backbone cnn network.""" -from typing import Callable, Dict, Optional, Tuple - -from einops.layers.torch import Rearrange -import torch -from torch import nn - -from text_recognizer.networks.util import activation_function - - -class CNN(nn.Module): -    """A simple backbone CNN network.""" - -    def __init__( -        self, -        channels: Tuple[int, ...] = (1, 32, 64, 128), -        kernel_sizes: Tuple[int, ...] = (4, 4, 4), -        strides: Tuple[int, ...] = (2, 2, 2), -        max_pool_kernel: int = 2, -        dropout_rate: float = 0.2, -        activation: Optional[str] = "relu", -    ) -> None: -        """Initialization of the CNN network. - -        Args: -            channels (Tuple[int, ...]): Channels in the convolutional layers. Defaults to (1, 32, 64, 128). -            kernel_sizes (Tuple[int, ...]): Kernel sizes in the convolutional layers. Defaults to (4, 4, 4). -            strides (Tuple[int, ...]): Stride length of the convolutional filter. Defaults to (2, 2, 2). -            max_pool_kernel (int): 2D max pooling kernel. Defaults to 2. -            dropout_rate (float): The dropout rate. Defaults to 0.2. -            activation (Optional[str]): The name of non-linear activation function. Defaults to relu. - -        Raises: -            RuntimeError: if the number of hyperparameters does not match in length. - -        """ -        super().__init__() - -        if len(channels) - 1 != len(kernel_sizes) or len(kernel_sizes) != len(strides): -            raise RuntimeError("The number of the hyperparameters does not match.") - -        self.cnn = self._build_network( -            channels, kernel_sizes, strides, max_pool_kernel, dropout_rate, activation, -        ) - -    def _build_network( -        self, -        channels: Tuple[int, ...], -        kernel_sizes: Tuple[int, ...], -        strides: Tuple[int, ...], -        max_pool_kernel: int, -        dropout_rate: float, -        activation: str, -    ) -> nn.Sequential: -        # Load activation function. -        activation_fn = activation_function(activation) - -        channels = list(channels) -        in_channels = channels.pop(0) -        configuration = zip(channels, kernel_sizes, strides) - -        modules = nn.ModuleList([]) - -        for i, (out_channels, kernel_size, stride) in enumerate(configuration): -            # Add max pool to reduce output size. 
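# The commented-out beam.py above never got past the Node bookkeeping. For reference, a
# minimal beam search over the same kind of step function (a hypothetical
# `decode_step(tokens) -> logits` callable), with the length normalization that Node.eval sketches:
import torch

@torch.no_grad()
def beam_decode(decode_step, init_index, eos_index, device, beam_width=10, max_len=97):
    """Keep the beam_width highest log-probability sequences at every step."""
    beams = [([init_index], 0.0)]  # (token indices, cumulative log-probability)
    for _ in range(max_len):
        candidates = []
        for indices, score in beams:
            if indices[-1] == eos_index:
                candidates.append((indices, score))  # Finished beams carry over unchanged.
                continue
            tokens = torch.tensor(indices, device=device)[None, :].long()
            log_probs = decode_step(tokens).log_softmax(dim=2)[0, -1]
            top_log_probs, top_indices = log_probs.topk(beam_width)
            for log_p, index in zip(top_log_probs.tolist(), top_indices.tolist()):
                candidates.append((indices + [index], score + log_p))
        beams = sorted(candidates, key=lambda beam: beam[1], reverse=True)[:beam_width]
        if all(indices[-1] == eos_index for indices, _ in beams):
            break
    # Length-normalized score, as in Node.eval above.
    return max(beams, key=lambda beam: beam[1] / (len(beam[0]) - 1 + 1e-6))[0]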
-            if i == len(channels) // 2: -                modules.append(nn.MaxPool2d(max_pool_kernel)) -            if i == 0: -                modules.append( -                    nn.Conv2d( -                        in_channels, out_channels, kernel_size, stride=stride, padding=1 -                    ) -                ) -            else: -                modules.append( -                    nn.Sequential( -                        activation_fn, -                        nn.BatchNorm2d(in_channels), -                        nn.Conv2d( -                            in_channels, -                            out_channels, -                            kernel_size, -                            stride=stride, -                            padding=1, -                        ), -                    ) -                ) - -            if dropout_rate: -                modules.append(nn.Dropout2d(p=dropout_rate)) - -            in_channels = out_channels - -        return nn.Sequential(*modules) - -    def forward(self, x: torch.Tensor) -> torch.Tensor: -        """The feedforward pass.""" -        # If batch dimension is missing, it needs to be added. -        if len(x.shape) < 4: -            x = x[(None,) * (4 - len(x.shape))] -        return self.cnn(x) diff --git a/text_recognizer/networks/crnn.py b/text_recognizer/networks/crnn.py deleted file mode 100644 index 778e232..0000000 --- a/text_recognizer/networks/crnn.py +++ /dev/null @@ -1,110 +0,0 @@ -"""CRNN for handwritten text recognition.""" -from typing import Dict, Tuple - -from einops import rearrange, reduce -from einops.layers.torch import Rearrange -from loguru import logger -from torch import nn -from torch import Tensor - -from text_recognizer.networks.util import configure_backbone - - -class ConvolutionalRecurrentNetwork(nn.Module): -    """Network that takes an image of a text line and predicts tokens that are in the image.""" - -    def __init__( -        self, -        backbone: str, -        backbone_args: Dict = None, -        input_size: int = 128, -        hidden_size: int = 128, -        bidirectional: bool = False, -        num_layers: int = 1, -        num_classes: int = 80, -        patch_size: Tuple[int, int] = (28, 28), -        stride: Tuple[int, int] = (1, 14), -        recurrent_cell: str = "lstm", -        avg_pool: bool = False, -        use_sliding_window: bool = True, -    ) -> None: -        super().__init__() -        self.backbone_args = backbone_args or {} -        self.patch_size = patch_size -        self.stride = stride -        self.sliding_window = ( -            self._configure_sliding_window() if use_sliding_window else None -        ) -        self.input_size = input_size -        self.hidden_size = hidden_size -        self.backbone = configure_backbone(backbone, self.backbone_args) -        self.bidirectional = bidirectional -        self.avg_pool = avg_pool - -        if recurrent_cell.upper() in ["LSTM", "GRU"]: -            recurrent_cell = getattr(nn, recurrent_cell.upper()) -        else: -            logger.warning( -                f"Option {recurrent_cell} not valid, defaulting to LSTM cell." 
-            ) -            recurrent_cell = nn.LSTM - -        self.rnn = recurrent_cell( -            input_size=self.input_size, -            hidden_size=self.hidden_size, -            bidirectional=bidirectional, -            num_layers=num_layers, -        ) - -        decoder_size = self.hidden_size * 2 if self.bidirectional else self.hidden_size - -        self.decoder = nn.Sequential( -            nn.Linear(in_features=decoder_size, out_features=num_classes), -            nn.LogSoftmax(dim=2), -        ) - -    def _configure_sliding_window(self) -> nn.Sequential: -        return nn.Sequential( -            nn.Unfold(kernel_size=self.patch_size, stride=self.stride), -            Rearrange( -                "b (c h w) t -> b t c h w", -                h=self.patch_size[0], -                w=self.patch_size[1], -                c=1, -            ), -        ) - -    def forward(self, x: Tensor) -> Tensor: -        """Converts images to sequence of patches, feeds them to a CNN, then predictions are made with an LSTM.""" -        if len(x.shape) < 4: -            x = x[(None,) * (4 - len(x.shape))] - -        if self.sliding_window is not None: -            # Create image patches with a sliding window kernel. -            x = self.sliding_window(x) - -            # Rearrange from a sequence of patches for feedforward network. -            b, t = x.shape[:2] -            x = rearrange(x, "b t c h w -> (b t) c h w", b=b, t=t) - -            x = self.backbone(x) - -            # Average pooling. -            if self.avg_pool: -                x = reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t) -            else: -                x = rearrange(x, "(b t) h -> t b h", b=b, t=t) -        else: -            # Encode the entire image with a CNN, and use the channels as temporal dimension. -            x = self.backbone(x) -            x = rearrange(x, "b c h w -> b w c h") -            if getattr(self, "adaptive_pool", None) is not None: -                x = self.adaptive_pool(x) -            x = x.squeeze(3) - -        # Sequence predictions. -        x, _ = self.rnn(x) - -        # Sequence to classification layer. -        x = self.decoder(x) -        return x diff --git a/text_recognizer/networks/ctc.py b/text_recognizer/networks/ctc.py deleted file mode 100644 index af9b700..0000000 --- a/text_recognizer/networks/ctc.py +++ /dev/null @@ -1,58 +0,0 @@ -"""Decodes the CTC output.""" -from typing import Callable, List, Optional, Tuple - -from einops import rearrange -import torch -from torch import Tensor - -from text_recognizer.datasets.util import EmnistMapper - - -def greedy_decoder( -    predictions: Tensor, -    targets: Optional[Tensor] = None, -    target_lengths: Optional[Tensor] = None, -    character_mapper: Optional[Callable] = None, -    blank_label: int = 79, -    collapse_repeated: bool = True, -) -> Tuple[List[str], List[str]]: -    """Greedy CTC decoder. - -    Args: -        predictions (Tensor): Tensor of network predictions, shape [time, batch, classes]. -        targets (Optional[Tensor]): Target tensor, shape is [batch, targets]. Defaults to None. -        target_lengths (Optional[Tensor]): Length of each target tensor. Defaults to None. -        character_mapper (Optional[Callable]): An emnist/character mapper for mapping integers to characters. Defaults -            to None. -        blank_label (int): The blank character to be ignored. Defaults to 79. -        collapse_repeated (bool): Collapse consecutive predictions of the same character. Defaults to True. 
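# The `_configure_sliding_window` above turns a line image into a sequence of patches with
# nn.Unfold. A standalone shape check with the class defaults (the 952-pixel line width is
# an assumption for illustration):
import torch
from torch import nn
from einops.layers.torch import Rearrange

patch_size, stride = (28, 28), (1, 14)
sliding_window = nn.Sequential(
    nn.Unfold(kernel_size=patch_size, stride=stride),
    Rearrange("b (c h w) t -> b t c h w", h=patch_size[0], w=patch_size[1], c=1),
)

x = torch.rand(1, 1, 28, 952)  # One grayscale line image.
patches = sliding_window(x)
print(patches.shape)  # torch.Size([1, 67, 1, 28, 28]): 67 patches of 28 x 28, 14 pixels apart.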
- -    Returns: -        Tuple[List[str], List[str]]: Tuple of decoded predictions and decoded targets. - -    """ - -    if character_mapper is None: -        character_mapper = EmnistMapper(pad_token="_")  # noqa: S106 - -    predictions = rearrange(torch.argmax(predictions, dim=2), "t b -> b t") -    decoded_predictions = [] -    decoded_targets = [] -    for i, prediction in enumerate(predictions): -        decoded_prediction = [] -        decoded_target = [] -        if targets is not None and target_lengths is not None: -            for target_index in targets[i][: target_lengths[i]]: -                if target_index == blank_label: -                    continue -                decoded_target.append(character_mapper(int(target_index))) -            decoded_targets.append(decoded_target) -        for j, index in enumerate(prediction): -            if index != blank_label: -                if collapse_repeated and j != 0 and index == prediction[j - 1]: -                    continue -                decoded_prediction.append(index.item()) -        decoded_predictions.append( -            [character_mapper(int(pred_index)) for pred_index in decoded_prediction] -        ) -    return decoded_predictions, decoded_targets diff --git a/text_recognizer/networks/densenet.py b/text_recognizer/networks/densenet.py deleted file mode 100644 index 7dc58d9..0000000 --- a/text_recognizer/networks/densenet.py +++ /dev/null @@ -1,225 +0,0 @@ -"""Defines a Densely Connected Convolutional Networks in PyTorch. - -Sources: -https://arxiv.org/abs/1608.06993 -https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py - -""" -from typing import List, Optional, Union - -from einops.layers.torch import Rearrange -import torch -from torch import nn -from torch import Tensor - -from text_recognizer.networks.util import activation_function - - -class _DenseLayer(nn.Module): -    """A dense layer with pre-batch norm -> activation function -> Conv-layer x 2.""" - -    def __init__( -        self, -        in_channels: int, -        growth_rate: int, -        bn_size: int, -        dropout_rate: float, -        activation: str = "relu", -    ) -> None: -        super().__init__() -        activation_fn = activation_function(activation) -        self.dense_layer = [ -            nn.BatchNorm2d(in_channels), -            activation_fn, -            nn.Conv2d( -                in_channels=in_channels, -                out_channels=bn_size * growth_rate, -                kernel_size=1, -                stride=1, -                bias=False, -            ), -            nn.BatchNorm2d(bn_size * growth_rate), -            activation_fn, -            nn.Conv2d( -                in_channels=bn_size * growth_rate, -                out_channels=growth_rate, -                kernel_size=3, -                stride=1, -                padding=1, -                bias=False, -            ), -        ] -        if dropout_rate: -            self.dense_layer.append(nn.Dropout(p=dropout_rate)) - -        self.dense_layer = nn.Sequential(*self.dense_layer) - -    def forward(self, x: Union[Tensor, List[Tensor]]) -> Tensor: -        if isinstance(x, list): -            x = torch.cat(x, 1) -        return self.dense_layer(x) - - -class _DenseBlock(nn.Module): -    def __init__( -        self, -        num_layers: int, -        in_channels: int, -        bn_size: int, -        growth_rate: int, -        dropout_rate: float, -        activation: str = "relu", -    ) -> None: -        super().__init__() -        
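# The greedy_decoder above applies the standard CTC decoding rule per batch element:
# argmax every timestep, drop blanks, and collapse repeats unless a blank separates them.
# The collapse rule in isolation, with the repository's blank index of 79:
def ctc_collapse(indices, blank_label=79, collapse_repeated=True):
    """Collapse a per-timestep argmax sequence into a label sequence."""
    decoded = []
    previous = None
    for index in indices:
        if index != blank_label and not (collapse_repeated and index == previous):
            decoded.append(index)
        previous = index
    return decoded

print(ctc_collapse([5, 5, 79, 5, 3, 3, 79]))  # [5, 5, 3]: the blank keeps the two 5s apart.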
 self.dense_block = self._build_dense_blocks( -            num_layers, in_channels, bn_size, growth_rate, dropout_rate, activation, -        ) - -    def _build_dense_blocks( -        self, -        num_layers: int, -        in_channels: int, -        bn_size: int, -        growth_rate: int, -        dropout_rate: float, -        activation: str = "relu", -    ) -> nn.ModuleList: -        dense_block = [] -        for i in range(num_layers): -            dense_block.append( -                _DenseLayer( -                    in_channels=in_channels + i * growth_rate, -                    growth_rate=growth_rate, -                    bn_size=bn_size, -                    dropout_rate=dropout_rate, -                    activation=activation, -                ) -            ) -        return nn.ModuleList(dense_block) - -    def forward(self, x: Tensor) -> Tensor: -        feature_maps = [x] -        for layer in self.dense_block: -            x = layer(feature_maps) -            feature_maps.append(x) -        return torch.cat(feature_maps, 1) - - -class _Transition(nn.Module): -    def __init__( -        self, in_channels: int, out_channels: int, activation: str = "relu", -    ) -> None: -        super().__init__() -        activation_fn = activation_function(activation) -        self.transition = nn.Sequential( -            nn.BatchNorm2d(in_channels), -            activation_fn, -            nn.Conv2d( -                in_channels=in_channels, -                out_channels=out_channels, -                kernel_size=1, -                stride=1, -                bias=False, -            ), -            nn.AvgPool2d(kernel_size=2, stride=2), -        ) - -    def forward(self, x: Tensor) -> Tensor: -        return self.transition(x) - - -class DenseNet(nn.Module): -    """Implementation of DenseNet, a network architecture that concatenates previous layers for maximum information flow.""" - -    def __init__( -        self, -        growth_rate: int = 32, -        block_config: List[int] = (6, 12, 24, 16), -        in_channels: int = 1, -        base_channels: int = 64, -        num_classes: int = 80, -        bn_size: int = 4, -        dropout_rate: float = 0, -        classifier: bool = True, -        activation: str = "relu", -    ) -> None: -        super().__init__() -        self.densenet = self._configure_densenet( -            in_channels, -            base_channels, -            num_classes, -            growth_rate, -            block_config, -            bn_size, -            dropout_rate, -            classifier, -            activation, -        ) - -    def _configure_densenet( -        self, -        in_channels: int, -        base_channels: int, -        num_classes: int, -        growth_rate: int, -        block_config: List[int], -        bn_size: int, -        dropout_rate: float, -        classifier: bool, -        activation: str, -    ) -> nn.Sequential: -        activation_fn = activation_function(activation) -        densenet = [ -            nn.Conv2d( -                in_channels=in_channels, -                out_channels=base_channels, -                kernel_size=3, -                stride=1, -                padding=1, -                bias=False, -            ), -            nn.BatchNorm2d(base_channels), -            activation_fn, -        ] - -        num_features = base_channels - -        for i, num_layers in enumerate(block_config): -            densenet.append( -                _DenseBlock( -                    num_layers=num_layers, -                    
 in_channels=num_features, -                    bn_size=bn_size, -                    growth_rate=growth_rate, -                    dropout_rate=dropout_rate, -                    activation=activation, -                ) -            ) -            num_features = num_features + num_layers * growth_rate -            if i != len(block_config) - 1: -                densenet.append( -                    _Transition( -                        in_channels=num_features, -                        out_channels=num_features // 2, -                        activation=activation, -                    ) -                ) -                num_features = num_features // 2 - -        densenet.append(activation_fn) - -        if classifier: -            densenet.append(nn.AdaptiveAvgPool2d((1, 1))) -            densenet.append(Rearrange("b c h w -> b (c h w)")) -            densenet.append( -                nn.Linear(in_features=num_features, out_features=num_classes) -            ) - -        return nn.Sequential(*densenet) - -    def forward(self, x: Tensor) -> Tensor: -        """Forward pass of DenseNet.""" -        # If batch dimension is missing, it will be added. -        if len(x.shape) < 4: -            x = x[(None,) * (4 - len(x.shape))] -        return self.densenet(x) diff --git a/text_recognizer/networks/lenet.py b/text_recognizer/networks/lenet.py deleted file mode 100644 index 527e1a0..0000000 --- a/text_recognizer/networks/lenet.py +++ /dev/null @@ -1,68 +0,0 @@ -"""Implementation of the LeNet network.""" -from typing import Callable, Dict, Optional, Tuple - -from einops.layers.torch import Rearrange -import torch -from torch import nn - -from text_recognizer.networks.util import activation_function - - -class LeNet(nn.Module): -    """LeNet network for character prediction.""" - -    def __init__( -        self, -        channels: Tuple[int, ...] = (1, 32, 64), -        kernel_sizes: Tuple[int, ...] = (3, 3, 2), -        hidden_size: Tuple[int, ...] = (9216, 128), -        dropout_rate: float = 0.2, -        num_classes: int = 10, -        activation_fn: Optional[str] = "relu", -    ) -> None: -        """Initialization of the LeNet network. - -        Args: -            channels (Tuple[int, ...]): Channels in the convolutional layers. Defaults to (1, 32, 64). -            kernel_sizes (Tuple[int, ...]): Kernel sizes in the convolutional layers. Defaults to (3, 3, 2). -            hidden_size (Tuple[int, ...]): Size of the flattened output from the convolutional layers. -                Defaults to (9216, 128). -            dropout_rate (float): The dropout rate. Defaults to 0.2. -            num_classes (int): Number of classes. Defaults to 10. -            activation_fn (Optional[str]): The name of non-linear activation function. Defaults to relu. 
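# The dense connectivity above means layer i receives in_channels + i * growth_rate channels,
# and the block output is the concatenation of all feature maps. A minimal sketch of that
# pattern (simplified to single convolutions, unlike the BN-activation-Conv x 2 _DenseLayer):
import torch
from torch import nn

class DenseBlockSketch(nn.Module):
    """Every layer sees the concatenation of all earlier feature maps."""
    def __init__(self, in_channels: int, growth_rate: int, num_layers: int) -> None:
        super().__init__()
        self.layers = nn.ModuleList(
            nn.Conv2d(in_channels + i * growth_rate, growth_rate, kernel_size=3, padding=1)
            for i in range(num_layers)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        feature_maps = [x]
        for layer in self.layers:
            feature_maps.append(layer(torch.cat(feature_maps, dim=1)))
        return torch.cat(feature_maps, dim=1)

block = DenseBlockSketch(in_channels=64, growth_rate=32, num_layers=6)
print(block(torch.rand(1, 64, 28, 28)).shape)  # torch.Size([1, 256, 28, 28]): 64 + 6 * 32.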
- -        """ -        super().__init__() - -        activation_fn = activation_function(activation_fn) - -        self.layers = [ -            nn.Conv2d( -                in_channels=channels[0], -                out_channels=channels[1], -                kernel_size=kernel_sizes[0], -            ), -            activation_fn, -            nn.Conv2d( -                in_channels=channels[1], -                out_channels=channels[2], -                kernel_size=kernel_sizes[1], -            ), -            activation_fn, -            nn.MaxPool2d(kernel_sizes[2]), -            nn.Dropout(p=dropout_rate), -            Rearrange("b c h w -> b (c h w)"), -            nn.Linear(in_features=hidden_size[0], out_features=hidden_size[1]), -            activation_fn, -            nn.Dropout(p=dropout_rate), -            nn.Linear(in_features=hidden_size[1], out_features=num_classes), -        ] - -        self.layers = nn.Sequential(*self.layers) - -    def forward(self, x: torch.Tensor) -> torch.Tensor: -        """The feedforward pass.""" -        # If batch dimenstion is missing, it needs to be added. -        if len(x.shape) < 4: -            x = x[(None,) * (4 - len(x.shape))] -        return self.layers(x) diff --git a/text_recognizer/networks/metrics.py b/text_recognizer/networks/metrics.py deleted file mode 100644 index 2605731..0000000 --- a/text_recognizer/networks/metrics.py +++ /dev/null @@ -1,123 +0,0 @@ -"""Utility functions for models.""" -from typing import Optional - -from einops import rearrange -import Levenshtein as Lev -import torch -from torch import Tensor - -from text_recognizer.networks import greedy_decoder - - -def accuracy(outputs: Tensor, labels: Tensor, pad_index: int = 53) -> float: -    """Computes the accuracy. - -    Args: -        outputs (Tensor): The output from the network. -        labels (Tensor): Ground truth labels. -        pad_index (int): Padding index. - -    Returns: -        float: The accuracy for the batch. - -    """ - -    _, predicted = torch.max(outputs, dim=-1) - -    # Mask out the pad tokens -    mask = labels != pad_index - -    predicted *= mask -    labels *= mask - -    acc = (predicted == labels).sum().float() / labels.shape[0] -    acc = acc.item() -    return acc - - -def cer( -    outputs: Tensor, -    targets: Tensor, -    batch_size: Optional[int] = None, -    blank_label: Optional[int] = int, -) -> float: -    """Computes the character error rate. - -    Args: -        outputs (Tensor): The output from the network. -        targets (Tensor): Ground truth labels. -        batch_size (Optional[int]): Batch size if target and output has been flattend. -        blank_label (Optional[int]): The blank character to be ignored. Defaults to 79. - -    Returns: -        float: The cer for the batch. 
- -    """ -    if len(outputs.shape) == 2 and len(targets.shape) == 1 and batch_size is not None: -        targets = rearrange(targets, "(b t) -> b t", b=batch_size) -        outputs = rearrange(outputs, "(b t) v -> t b v", b=batch_size) - -    target_lengths = torch.full( -        size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long, -    ) -    decoded_predictions, decoded_targets = greedy_decoder( -        outputs, targets, target_lengths, blank_label=blank_label, -    ) - -    lev_dist = 0 - -    for prediction, target in zip(decoded_predictions, decoded_targets): -        prediction = "".join(prediction) -        target = "".join(target) -        prediction, target = ( -            prediction.replace(" ", ""), -            target.replace(" ", ""), -        ) -        lev_dist += Lev.distance(prediction, target) -    return lev_dist / len(decoded_predictions) - - -def wer( -    outputs: Tensor, -    targets: Tensor, -    batch_size: Optional[int] = None, -    blank_label: Optional[int] = int, -) -> float: -    """Computes the Word error rate. - -    Args: -        outputs (Tensor): The output from the network. -        targets (Tensor): Ground truth labels. -        batch_size (optional[int]): Batch size if target and output has been flattend. -        blank_label (Optional[int]): The blank character to be ignored. Defaults to 79. - -    Returns: -        float: The wer for the batch. - -    """ -    if len(outputs.shape) == 2 and len(targets.shape) == 1 and batch_size is not None: -        targets = rearrange(targets, "(b t) -> b t", b=batch_size) -        outputs = rearrange(outputs, "(b t) v -> t b v", b=batch_size) - -    target_lengths = torch.full( -        size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long, -    ) -    decoded_predictions, decoded_targets = greedy_decoder( -        outputs, targets, target_lengths, blank_label=blank_label, -    ) - -    lev_dist = 0 - -    for prediction, target in zip(decoded_predictions, decoded_targets): -        prediction = "".join(prediction) -        target = "".join(target) - -        b = set(prediction.split() + target.split()) -        word2char = dict(zip(b, range(len(b)))) - -        w1 = [chr(word2char[w]) for w in prediction.split()] -        w2 = [chr(word2char[w]) for w in target.split()] - -        lev_dist += Lev.distance("".join(w1), "".join(w2)) - -    return lev_dist / len(decoded_predictions) diff --git a/text_recognizer/networks/mlp.py b/text_recognizer/networks/mlp.py deleted file mode 100644 index 1101912..0000000 --- a/text_recognizer/networks/mlp.py +++ /dev/null @@ -1,73 +0,0 @@ -"""Defines the MLP network.""" -from typing import Callable, Dict, List, Optional, Union - -from einops.layers.torch import Rearrange -import torch -from torch import nn - -from text_recognizer.networks.util import activation_function - - -class MLP(nn.Module): -    """Multi layered perceptron network.""" - -    def __init__( -        self, -        input_size: int = 784, -        num_classes: int = 10, -        hidden_size: Union[int, List] = 128, -        num_layers: int = 3, -        dropout_rate: float = 0.2, -        activation_fn: str = "relu", -    ) -> None: -        """Initialization of the MLP network. - -        Args: -            input_size (int): The input shape of the network. Defaults to 784. -            num_classes (int): Number of classes in the dataset. Defaults to 10. -            hidden_size (Union[int, List]): The number of `neurons` in each hidden layer. Defaults to 128. 
-            num_layers (int): The number of hidden layers. Defaults to 3. -            dropout_rate (float): The dropout rate at each layer. Defaults to 0.2. -            activation_fn (str): Name of the activation function in the hidden layers. Defaults to -                relu. - -        """ -        super().__init__() - -        activation_fn = activation_function(activation_fn) - -        if isinstance(hidden_size, int): -            hidden_size = [hidden_size] * num_layers - -        self.layers = [ -            Rearrange("b c h w -> b (c h w)"), -            nn.Linear(in_features=input_size, out_features=hidden_size[0]), -            activation_fn, -        ] - -        for i in range(num_layers - 1): -            self.layers += [ -                nn.Linear(in_features=hidden_size[i], out_features=hidden_size[i + 1]), -                activation_fn, -            ] - -            if dropout_rate: -                self.layers.append(nn.Dropout(p=dropout_rate)) - -        self.layers.append( -            nn.Linear(in_features=hidden_size[-1], out_features=num_classes) -        ) - -        self.layers = nn.Sequential(*self.layers) - -    def forward(self, x: torch.Tensor) -> torch.Tensor: -        """The feedforward pass.""" -        # If batch dimension is missing, it needs to be added. -        if len(x.shape) < 4: -            x = x[(None,) * (4 - len(x.shape))] -        return self.layers(x) - -    @property -    def __name__(self) -> str: -        """Returns the name of the network.""" -        return "mlp" diff --git a/text_recognizer/networks/stn.py b/text_recognizer/networks/stn.py deleted file mode 100644 index e9d216f..0000000 --- a/text_recognizer/networks/stn.py +++ /dev/null @@ -1,44 +0,0 @@ -"""Spatial Transformer Network.""" - -from einops.layers.torch import Rearrange -import torch -from torch import nn -from torch import Tensor -import torch.nn.functional as F - - -class SpatialTransformerNetwork(nn.Module): -    """A network with differentiable attention. - -    Network that learns how to perform spatial transformations on the input image in order to enhance the -    geometric invariance of the model. - -    # TODO: add arguments to make it more general. - -    """ - -    def __init__(self) -> None: -        super().__init__() -        # Initialize the identity transformation and its weights and biases. 
-        linear = nn.Linear(32, 3 * 2) -        linear.weight.data.zero_() -        linear.bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float)) - -        self.theta = nn.Sequential( -            nn.Conv2d(in_channels=1, out_channels=8, kernel_size=7), -            nn.MaxPool2d(kernel_size=2, stride=2), -            nn.ReLU(inplace=True), -            nn.Conv2d(in_channels=8, out_channels=10, kernel_size=5), -            nn.MaxPool2d(kernel_size=2, stride=2), -            nn.ReLU(inplace=True), -            Rearrange("b c h w -> b (c h w)", h=3, w=3), -            nn.Linear(in_features=10 * 3 * 3, out_features=32), -            nn.ReLU(inplace=True), -            linear, -            Rearrange("b (row col) -> b row col", row=2, col=3), -        ) - -    def forward(self, x: Tensor) -> Tensor: -        """The spatial transformation.""" -        grid = F.affine_grid(self.theta(x), x.shape) -        return F.grid_sample(x, grid, align_corners=False) diff --git a/text_recognizer/networks/unet.py b/text_recognizer/networks/unet.py deleted file mode 100644 index 510910f..0000000 --- a/text_recognizer/networks/unet.py +++ /dev/null @@ -1,255 +0,0 @@ -"""UNet for segmentation.""" -from typing import List, Optional, Tuple, Union - -import torch -from torch import nn -from torch import Tensor - -from text_recognizer.networks.util import activation_function - - -class _ConvBlock(nn.Module): -    """Modified UNet convolutional block with dilation.""" - -    def __init__( -        self, -        channels: List[int], -        activation: str, -        num_groups: int, -        dropout_rate: float = 0.1, -        kernel_size: int = 3, -        dilation: int = 1, -        padding: int = 0, -    ) -> None: -        super().__init__() -        self.channels = channels -        self.dropout_rate = dropout_rate -        self.kernel_size = kernel_size -        self.dilation = dilation -        self.padding = padding -        self.num_groups = num_groups -        self.activation = activation_function(activation) -        self.block = self._configure_block() -        self.residual_conv = nn.Sequential( -            nn.Conv2d( -                self.channels[0], self.channels[-1], kernel_size=3, stride=1, padding=1 -            ), -            self.activation, -        ) - -    def _configure_block(self) -> nn.Sequential: -        block = [] -        for i in range(len(self.channels) - 1): -            block += [ -                nn.Dropout(p=self.dropout_rate), -                nn.GroupNorm(self.num_groups, self.channels[i]), -                self.activation, -                nn.Conv2d( -                    self.channels[i], -                    self.channels[i + 1], -                    kernel_size=self.kernel_size, -                    padding=self.padding, -                    stride=1, -                    dilation=self.dilation, -                ), -            ] - -        return nn.Sequential(*block) - -    def forward(self, x: Tensor) -> Tensor: -        """Apply the convolutional block.""" -        residual = self.residual_conv(x) -        return self.block(x) + residual - - -class _DownSamplingBlock(nn.Module): -    """Basic down sampling block.""" - -    def __init__( -        self, -        channels: List[int], -        activation: str, -        num_groups: int, -        pooling_kernel: Union[int, bool] = 2, -        dropout_rate: float = 0.1, -        kernel_size: int = 3, -        dilation: int = 1, -        padding: int = 0, -    ) -> None: -        super().__init__() -        
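# The SpatialTransformerNetwork above regresses six affine parameters (initialized to the
# identity) and warps the input with affine_grid/grid_sample. The warp in isolation, with a
# fixed identity theta standing in for the learned localization head:
import torch
import torch.nn.functional as F

x = torch.rand(1, 1, 28, 28)
theta = torch.tensor([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]])  # Identity affine parameters.
grid = F.affine_grid(theta, x.shape, align_corners=False)
warped = F.grid_sample(x, grid, align_corners=False)
print(torch.allclose(warped, x, atol=1e-6))  # True: the identity warp returns the input.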
 self.conv_block = _ConvBlock( -            channels, -            activation, -            num_groups, -            dropout_rate, -            kernel_size, -            dilation, -            padding, -        ) -        self.down_sampling = nn.MaxPool2d(pooling_kernel) if pooling_kernel else None - -    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]: -        """Return the convolutional block output and a down sampled tensor.""" -        x = self.conv_block(x) -        x_down = self.down_sampling(x) if self.down_sampling is not None else x - -        return x_down, x - - -class _UpSamplingBlock(nn.Module): -    """The upsampling block of the UNet.""" - -    def __init__( -        self, -        channels: List[int], -        activation: str, -        num_groups: int, -        scale_factor: int = 2, -        dropout_rate: float = 0.1, -        kernel_size: int = 3, -        dilation: int = 1, -        padding: int = 0, -    ) -> None: -        super().__init__() -        self.conv_block = _ConvBlock( -            channels, -            activation, -            num_groups, -            dropout_rate, -            kernel_size, -            dilation, -            padding, -        ) -        self.up_sampling = nn.Upsample( -            scale_factor=scale_factor, mode="bilinear", align_corners=True -        ) - -    def forward(self, x: Tensor, x_skip: Optional[Tensor] = None) -> Tensor: -        """Apply the up sampling and convolutional block.""" -        x = self.up_sampling(x) -        if x_skip is not None: -            x = torch.cat((x, x_skip), dim=1) -        return self.conv_block(x) - - -class UNet(nn.Module): -    """UNet architecture.""" - -    def __init__( -        self, -        in_channels: int = 1, -        base_channels: int = 64, -        num_classes: int = 3, -        depth: int = 4, -        activation: str = "relu", -        num_groups: int = 8, -        dropout_rate: float = 0.1, -        pooling_kernel: int = 2, -        scale_factor: int = 2, -        kernel_size: Optional[List[int]] = None, -        dilation: Optional[List[int]] = None, -        padding: Optional[List[int]] = None, -    ) -> None: -        super().__init__() -        self.depth = depth -        self.num_groups = num_groups - -        if kernel_size is not None and dilation is not None and padding is not None: -            if ( -                len(kernel_size) != depth -                or len(dilation) != depth -                or len(padding) != depth -            ): -                raise RuntimeError( -                    "Length of convolutional parameters does not match the depth." 
-                ) -            self.kernel_size = kernel_size -            self.padding = padding -            self.dilation = dilation - -        else: -            self.kernel_size = [3] * depth -            self.padding = [1] * depth -            self.dilation = [1] * depth - -        self.dropout_rate = dropout_rate -        self.conv = nn.Conv2d( -            in_channels, base_channels, kernel_size=3, stride=1, padding=1 -        ) - -        channels = [base_channels] + [base_channels * 2 ** i for i in range(depth)] -        self.encoder_blocks = self._configure_down_sampling_blocks( -            channels, activation, pooling_kernel -        ) -        self.decoder_blocks = self._configure_up_sampling_blocks( -            channels, activation, scale_factor -        ) - -        self.head = nn.Conv2d(base_channels, num_classes, kernel_size=1) - -    def _configure_down_sampling_blocks( -        self, channels: List[int], activation: str, pooling_kernel: int -    ) -> nn.ModuleList: -        blocks = nn.ModuleList([]) -        for i in range(len(channels) - 1): -            pooling_kernel = pooling_kernel if i < self.depth - 1 else False -            dropout_rate = self.dropout_rate if i < 0 else 0 -            blocks += [ -                _DownSamplingBlock( -                    [channels[i], channels[i + 1], channels[i + 1]], -                    activation, -                    self.num_groups, -                    pooling_kernel, -                    dropout_rate, -                    self.kernel_size[i], -                    self.dilation[i], -                    self.padding[i], -                ) -            ] - -        return blocks - -    def _configure_up_sampling_blocks( -        self, channels: List[int], activation: str, scale_factor: int, -    ) -> nn.ModuleList: -        channels.reverse() -        self.kernel_size.reverse() -        self.dilation.reverse() -        self.padding.reverse() -        return nn.ModuleList( -            [ -                _UpSamplingBlock( -                    [channels[i] + channels[i + 1], channels[i + 1], channels[i + 1]], -                    activation, -                    self.num_groups, -                    scale_factor, -                    self.dropout_rate, -                    self.kernel_size[i], -                    self.dilation[i], -                    self.padding[i], -                ) -                for i in range(len(channels) - 2) -            ] -        ) - -    def _encode(self, x: Tensor) -> List[Tensor]: -        x_skips = [] -        for block in self.encoder_blocks: -            x, x_skip = block(x) -            x_skips.append(x_skip) -        return x_skips - -    def _decode(self, x_skips: List[Tensor]) -> Tensor: -        x = x_skips[-1] -        for i, block in enumerate(self.decoder_blocks): -            x = block(x, x_skips[-(i + 2)]) -        return x - -    def forward(self, x: Tensor) -> Tensor: -        """Forward pass with the UNet model.""" -        if len(x.shape) < 4: -            x = x[(None,) * (4 - len(x.shape))] -        x = self.conv(x) -        x_skips = self._encode(x) -        x = self._decode(x_skips) -        return self.head(x) diff --git a/text_recognizer/networks/vit.py b/text_recognizer/networks/vit.py deleted file mode 100644 index efb3701..0000000 --- a/text_recognizer/networks/vit.py +++ /dev/null @@ -1,150 +0,0 @@ -"""A Vision Transformer. 
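# The UNet above pairs encoder and decoder blocks through the reversed channel list: each
# up-sampling block consumes the upsampled features concatenated with the matching skip.
# The channel bookkeeping from _configure_up_sampling_blocks, spelled out for the defaults:
base_channels, depth = 64, 4
channels = [base_channels] + [base_channels * 2 ** i for i in range(depth)]
print(channels)  # [64, 64, 128, 256, 512]

channels.reverse()
decoder_in = [channels[i] + channels[i + 1] for i in range(len(channels) - 2)]
print(decoder_in)  # [768, 384, 192]: upsampled features plus the skip connection channels.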
- -Inspired by: -https://openreview.net/pdf?id=YicbFdNTTy - -""" -from typing import Optional, Tuple - -from einops import rearrange, repeat -import torch -from torch import nn -from torch import Tensor - -from text_recognizer.networks.transformer import Transformer - - -class ViT(nn.Module): -    """Transformer for image to sequence prediction.""" - -    def __init__( -        self, -        num_encoder_layers: int, -        num_decoder_layers: int, -        hidden_dim: int, -        vocab_size: int, -        num_heads: int, -        expansion_dim: int, -        patch_dim: Tuple[int, int], -        image_size: Tuple[int, int], -        dropout_rate: float, -        trg_pad_index: int, -        max_len: int, -        activation: str = "gelu", -    ) -> None: -        super().__init__() - -        self.trg_pad_index = trg_pad_index -        self.patch_dim = patch_dim -        self.num_patches = image_size[-1] // self.patch_dim[1] - -        # Encoder -        self.patch_to_embedding = nn.Linear( -            self.patch_dim[0] * self.patch_dim[1], hidden_dim -        ) -        self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_dim)) -        self.character_embedding = nn.Embedding(vocab_size, hidden_dim) -        self.pos_embedding = nn.Parameter(torch.randn(1, max_len, hidden_dim)) -        self.dropout = nn.Dropout(dropout_rate) -        self._init() - -        self.transformer = Transformer( -            num_encoder_layers, -            num_decoder_layers, -            hidden_dim, -            num_heads, -            expansion_dim, -            dropout_rate, -            activation, -        ) - -        self.head = nn.Sequential(nn.Linear(hidden_dim, vocab_size),) - -    def _init(self) -> None: -        nn.init.normal_(self.character_embedding.weight, std=0.02) -        # nn.init.normal_(self.pos_embedding.weight, std=0.02) - -    def _create_trg_mask(self, trg: Tensor) -> Tensor: -        # Move this outside the transformer. -        trg_pad_mask = (trg != self.trg_pad_index)[:, None, None] -        trg_len = trg.shape[1] -        trg_sub_mask = torch.tril( -            torch.ones((trg_len, trg_len), device=trg.device) -        ).bool() -        trg_mask = trg_pad_mask & trg_sub_mask -        return trg_mask - -    def encoder(self, src: Tensor) -> Tensor: -        """Forward pass with the encoder of the transformer.""" -        return self.transformer.encoder(src) - -    def decoder(self, trg: Tensor, memory: Tensor, trg_mask: Tensor) -> Tensor: -        """Forward pass with the decoder of the transformer + classification head.""" -        return self.head( -            self.transformer.decoder(trg=trg, memory=memory, trg_mask=trg_mask) -        ) - -    def extract_image_features(self, src: Tensor) -> Tensor: -        """Extracts image features with a backbone neural network. - -        It seems like the winning idea was to swap the channel and width dimensions and collapse -        the height dimension. The transformer is learning like a baby with this implementation!!! :D -        Ohhhh, the joy I am experiencing right now!! Bring in the beers! :D :D :D - -        Args: -            src (Tensor): Input tensor. - -        Returns: -            Tensor: An input src for the transformer. - -        """ -        # If batch dimension is missing, it needs to be added. 
-        if len(src.shape) < 4: -            src = src[(None,) * (4 - len(src.shape))] - -        patches = rearrange( -            src, -            "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", -            p1=self.patch_dim[0], -            p2=self.patch_dim[1], -        ) - -        # From patches to encoded sequence. -        x = self.patch_to_embedding(patches) -        b, n, _ = x.shape -        cls_tokens = repeat(self.cls_token, "() n d -> b n d", b=b) -        x = torch.cat((cls_tokens, x), dim=1) -        x += self.pos_embedding[:, : (n + 1)] -        x = self.dropout(x) - -        return x - -    def target_embedding(self, trg: Tensor) -> Tensor: -        """Encodes target tensor with embedding and position. - -        Args: -            trg (Tensor): Target tensor. - -        Returns: -            Tensor: Encoded target tensor. - -        """ -        _, n = trg.shape -        trg = self.character_embedding(trg.long()) -        trg += self.pos_embedding[:, :n] -        return trg - -    def decode_image_features(self, h: Tensor, trg: Optional[Tensor] = None) -> Tensor: -        """Takes image features from the backbone and decodes them with the transformer.""" -        trg_mask = self._create_trg_mask(trg) -        trg = self.target_embedding(trg) -        out = self.transformer(h, trg, trg_mask=trg_mask) - -        logits = self.head(out) -        return logits - -    def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor: -        """Forward pass with CNN transformer.""" -        h = self.extract_image_features(x) -        logits = self.decode_image_features(h, trg) -        return logits diff --git a/text_recognizer/paragraph_text_recognizer.py b/text_recognizer/paragraph_text_recognizer.py deleted file mode 100644 index aa39662..0000000 --- a/text_recognizer/paragraph_text_recognizer.py +++ /dev/null @@ -1,153 +0,0 @@ -"""Full model. - -Takes an image and returns the text in the image, by first segmenting the image with a LineDetector, then extracting -each crop of the image corresponding to line regions, and feeding them to a LinePredictor model that outputs the text -in each region. 
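# extract_image_features above flattens the image into patch vectors, projects them to
# hidden_dim, prepends a learned cls token, and adds positional embeddings. A shape
# walk-through of that pipeline; hidden_dim=256 and patch_dim=(28, 14) are assumptions:
import torch
from torch import nn
from einops import rearrange, repeat

hidden_dim, patch_dim = 256, (28, 14)
src = torch.rand(1, 1, 28, 952)  # One grayscale line image.

patches = rearrange(
    src, "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=patch_dim[0], p2=patch_dim[1]
)
print(patches.shape)  # torch.Size([1, 68, 392]): 68 patches of 28 * 14 pixels.

patch_to_embedding = nn.Linear(patch_dim[0] * patch_dim[1], hidden_dim)
cls_token = nn.Parameter(torch.randn(1, 1, hidden_dim))

x = patch_to_embedding(patches)
cls_tokens = repeat(cls_token, "() n d -> b n d", b=x.shape[0])
x = torch.cat((cls_tokens, x), dim=1)
print(x.shape)  # torch.Size([1, 69, 256]): the patch sequence plus the cls token.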
-""" -from typing import Dict, List, Tuple, Union - -import cv2 -import numpy as np -import torch - -from text_recognizer.models import SegmentationModel, TransformerModel -from text_recognizer.util import read_image - - -class ParagraphTextRecognizor: -    """Given an image of a single handwritten character, recognizes it.""" - -    def __init__(self, line_predictor_args: Dict, line_detector_args: Dict) -> None: -        self._line_predictor = TransformerModel(**line_predictor_args) -        self._line_detector = SegmentationModel(**line_detector_args) -        self._line_detector.eval() -        self._line_predictor.eval() - -    def predict(self, image_or_filename: Union[str, np.ndarray]) -> Tuple: -        """Takes an image and returns all text within it.""" -        image = ( -            read_image(image_or_filename) -            if isinstance(image_or_filename, str) -            else image_or_filename -        ) - -        line_region_crops = self._get_line_region_crops(image) -        processed_line_region_crops = [ -            self._process_image_for_line_predictor(image=crop) -            for crop in line_region_crops -        ] -        line_region_strings = [ -            self.line_predictor_model.predict_on_image(crop)[0] -            for crop in processed_line_region_crops -        ] - -        return " ".join(line_region_strings), line_region_crops - -    def _get_line_region_crops( -        self, image: np.ndarray, min_crop_len_factor: float = 0.02 -    ) -> List[np.ndarray]: -        """Returns all the crops of text lines in a square image.""" -        processed_image, scale_down_factor = self._process_image_for_line_detector( -            image -        ) -        line_segmentation = self._line_detector.predict_on_image(processed_image) -        bounding_boxes = _find_line_bounding_boxes(line_segmentation) - -        bounding_boxes = (bounding_boxes * scale_down_factor).astype(int) - -        min_crop_len = int(min_crop_len_factor * min(image.shape[0], image.shape[1])) -        line_region_crops = [ -            image[y : y + h, x : x + w] -            for x, y, w, h in bounding_boxes -            if w >= min_crop_len and h >= min_crop_len -        ] -        return line_region_crops - -    def _process_image_for_line_detector( -        self, image: np.ndarray -    ) -> Tuple[np.ndarray, float]: -        """Convert uint8 image to float image with black background with shape self._line_detector.image_shape.""" -        resized_image, scale_down_factor = _resize_image_for_line_detector( -            image=image, max_shape=self._line_detector.image_shape -        ) -        resized_image = (1.0 - resized_image / 255).astype("float32") -        return resized_image, scale_down_factor - -    def _process_image_for_line_predictor(self, image: np.ndarray) -> np.ndarray: -        """Preprocessing of image before feeding it to the LinePrediction model. - -        Convert uint8 image to float image with black background with shape -        self._line_predictor.image_shape while maintaining the image aspect ratio. - -        Args: -            image (np.ndarray): Crop of text line. - -        Returns: -            np.ndarray: Processed crop for feeding line predictor. 
-        """ -        expected_shape = self._line_detector.image_shape -        scale_factor = (np.array(expected_shape) / np.array(image.shape)).min() -        scaled_image = cv2.resize( -            image, -            dsize=None, -            fx=scale_factor, -            fy=scale_factor, -            interpolation=cv2.INTER_AREA, -        ) - -        pad_with = ( -            (0, expected_shape[0] - scaled_image.shape[0]), -            (0, expected_shape[1] - scaled_image.shape[1]), -        ) - -        padded_image = np.pad( -            scaled_image, pad_with=pad_with, mode="constant", constant_values=255 -        ) -        return 1 - padded_image / 255 - - -def _find_line_bounding_boxes(line_segmentation: np.ndarray) -> np.ndarray: -    """Given a line segmentation, find bounding boxes for connected-component regions corresponding to non-0 labels.""" - -    def _find_line_bounding_boxes_in_channel( -        line_segmentation_channel: np.ndarray, -    ) -> np.ndarray: -        line_segmentation_image = cv2.dilate( -            line_segmentation_channel, kernel=np.ones((3, 3)), iterations=1 -        ) -        line_activation_image = (line_segmentation_image * 255).astype("uint8") -        line_activation_image = cv2.threshold( -            line_activation_image, 0.5, 1, cv2.THRESH_BINARY | cv2.THRESH_OTSU -        )[1] - -        bounding_cnts, _ = cv2.findContours( -            line_segmentation_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE -        ) -        return np.array([cv2.boundingRect(cnt) for cnt in bounding_cnts]) - -    bounding_boxes = np.concatenate( -        [ -            _find_line_bounding_boxes_in_channel(line_segmentation[:, :, i]) -            for i in [1, 2] -        ], -        axis=0, -    ) - -    return bounding_boxes[np.argsort(bounding_boxes[:, 1])] - - -def _resize_image_for_line_detector( -    image: np.ndarray, max_shape: Tuple[int, int] -) -> Tuple[np.ndarray, float]: -    """Resize the image to less than the max_shape while maintaining the aspect ratio.""" -    scale_down_factor = max(np.ndarray(image.shape) / np.ndarray(max_shape)) -    if scale_down_factor == 1: -        return image.copy(), scale_down_factor -    resize_image = cv2.resize( -        image, -        dsize=None, -        fx=1 / scale_down_factor, -        fy=1 / scale_down_factor, -        interpolation=cv2.INTER_AREA, -    ) -    return resize_image, scale_down_factor  |