Diffstat (limited to 'text_recognizer/data')
-rw-r--r--   text_recognizer/data/iam_paragraphs_dataset.py | 291 ----
-rw-r--r--   text_recognizer/data/util.py                    | 209 ----
2 files changed, 0 insertions, 500 deletions
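This commit deletes the legacy IAM paragraphs dataset module and the dataset utilities backing it. For reference, the deleted `IamParagraphsDataset` was driven roughly as sketched below; this mirrors the removed `main()` entry point further down in this diff, and the import path is the deleted module's own:

```python
# Sketch of programmatic use of the (now deleted) IamParagraphsDataset,
# mirroring the removed main() entry point in the diff below.
from text_recognizer.data.iam_paragraphs_dataset import IamParagraphsDataset

dataset = IamParagraphsDataset(train=True, subsample_fraction=None)
dataset.load_or_generate_data()  # crops paragraphs on the first run, then loads
print(dataset)                   # reports class count and data/target shapes
```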
diff --git a/text_recognizer/data/iam_paragraphs_dataset.py b/text_recognizer/data/iam_paragraphs_dataset.py
deleted file mode 100644
index 8ba5142..0000000
--- a/text_recognizer/data/iam_paragraphs_dataset.py
+++ /dev/null
@@ -1,291 +0,0 @@
-"""IamParagraphsDataset class and functions for data processing."""
-import random
-from typing import Callable, Dict, List, Optional, Tuple, Union
-
-import click
-import cv2
-import h5py
-from loguru import logger
-import numpy as np
-import torch
-from torch import Tensor
-from torchvision.transforms import ToTensor
-
-from text_recognizer import util
-from text_recognizer.datasets.dataset import Dataset
-from text_recognizer.datasets.iam_dataset import IamDataset
-from text_recognizer.datasets.util import (
-    compute_sha256,
-    DATA_DIRNAME,
-    download_url,
-    EmnistMapper,
-)
-
-INTERIM_DATA_DIRNAME = DATA_DIRNAME / "interim" / "iam_paragraphs"
-DEBUG_CROPS_DIRNAME = INTERIM_DATA_DIRNAME / "debug_crops"
-PROCESSED_DATA_DIRNAME = DATA_DIRNAME / "processed" / "iam_paragraphs"
-CROPS_DIRNAME = PROCESSED_DATA_DIRNAME / "crops"
-GT_DIRNAME = PROCESSED_DATA_DIRNAME / "gt"
-
-PARAGRAPH_BUFFER = 50  # Pixels in the IAM form images to leave around the lines.
-TEST_FRACTION = 0.2
-SEED = 4711
-
-
-class IamParagraphsDataset(Dataset):
-    """IAM Paragraphs dataset for paragraphs of handwritten text."""
-
-    def __init__(
-        self,
-        train: bool = False,
-        subsample_fraction: float = None,
-        transform: Optional[Callable] = None,
-        target_transform: Optional[Callable] = None,
-    ) -> None:
-        super().__init__(
-            train=train,
-            subsample_fraction=subsample_fraction,
-            transform=transform,
-            target_transform=target_transform,
-        )
-        # Load the IAM dataset.
-        self.iam_dataset = IamDataset()
-
-        self.num_classes = 3
-        self._input_shape = (256, 256)
-        self._output_shape = self._input_shape + (self.num_classes,)
-        self._ids = None
-
-    def __getitem__(self, index: Union[Tensor, int]) -> Tuple[Tensor, Tensor]:
-        """Fetches the data-target pair of the dataset for a given index or indices.
-
-        Args:
-            index (Union[Tensor, int]): Either an int index or a tensor of indices.
-
-        Returns:
-            Tuple[Tensor, Tensor]: Data-target pair.
-
-        """
-        if torch.is_tensor(index):
-            index = index.tolist()
-
-        data = self.data[index]
-        targets = self.targets[index]
-
-        seed = np.random.randint(SEED)
-        random.seed(seed)  # Apply this seed to the data transforms.
-        torch.manual_seed(seed)  # Needed for torchvision 0.7.
-        if self.transform:
-            data = self.transform(data)
-
-        random.seed(seed)  # Apply the same seed to the target transforms.
-        torch.manual_seed(seed)  # Needed for torchvision 0.7.
-        if self.target_transform:
-            targets = self.target_transform(targets)
-
-        return data, targets.long()
-
-    @property
-    def ids(self) -> Tensor:
-        """Ids of the dataset."""
-        return self._ids
-
-    def get_data_and_target_from_id(self, id_: str) -> Tuple[Tensor, Tensor]:
-        """Get a data-target pair from an id."""
-        ind = self.ids.index(id_)
-        return self.data[ind], self.targets[ind]
-
-    def load_or_generate_data(self) -> None:
-        """Load or generate dataset data."""
-        num_actual = len(list(CROPS_DIRNAME.glob("*.jpg")))
-        num_targets = len(self.iam_dataset.line_regions_by_id)
-
-        if num_actual < num_targets - 2:
-            self._process_iam_paragraphs()
-
-        self._data, self._targets, self._ids = _load_iam_paragraphs()
-        self._get_random_split()
-        self._subsample()
-
-    def _get_random_split(self) -> None:
-        np.random.seed(SEED)
-        num_train = int((1 - TEST_FRACTION) * self.data.shape[0])
-        indices = np.random.permutation(self.data.shape[0])
-        train_indices, test_indices = indices[:num_train], indices[num_train:]
-        if self.train:
-            self._data = self.data[train_indices]
-            self._targets = self.targets[train_indices]
-        else:
-            self._data = self.data[test_indices]
-            self._targets = self.targets[test_indices]
-
-    def _process_iam_paragraphs(self) -> None:
-        """Crop out the part of each page that contains the text.
-
-        For each page, crop out the part that corresponds to the paragraph of text, and make sure all
-        crops are self.input_shape. The ground truth data is the same size, with each pixel labeled
-        0=background, 1=odd-numbered line, or 2=even-numbered line.
-        """
-        crop_dims = self._decide_on_crop_dims()
-        CROPS_DIRNAME.mkdir(parents=True, exist_ok=True)
-        DEBUG_CROPS_DIRNAME.mkdir(parents=True, exist_ok=True)
-        GT_DIRNAME.mkdir(parents=True, exist_ok=True)
-        logger.info(
-            f"Cropping paragraphs, generating ground truth, and saving debugging images to {DEBUG_CROPS_DIRNAME}"
-        )
-        for filename in self.iam_dataset.form_filenames:
-            id_ = filename.stem
-            line_region = self.iam_dataset.line_regions_by_id[id_]
-            _crop_paragraph_image(filename, line_region, crop_dims, self.input_shape)
-
-    def _decide_on_crop_dims(self) -> Tuple[int, int]:
-        """Decide on the dimensions to crop out of the form image.
-
-        Since the image width is larger than a comfortable crop around the longest paragraph,
-        we will make the crop a square form factor.
-        And since the found dimensions 610x610 are pretty close to 512x512, we might as well
-        resize crops to exactly that, which lets us do all kinds of power-of-2 pooling and
-        upsampling should we choose to.
-
-        Returns:
-            Tuple[int, int]: A tuple of crop dimensions.
-
-        Raises:
-            RuntimeError: When the max crop height is larger than the max crop width.
-
-        """
-        sample_form_filename = self.iam_dataset.form_filenames[0]
-        sample_image = util.read_image(sample_form_filename, grayscale=True)
-        max_crop_width = sample_image.shape[1]
-        max_crop_height = _get_max_paragraph_crop_height(
-            self.iam_dataset.line_regions_by_id
-        )
-        if not max_crop_height <= max_crop_width:
-            raise RuntimeError(
-                f"Max crop height is larger than max crop width: {max_crop_height} > {max_crop_width}"
-            )
-
-        crop_dims = (max_crop_width, max_crop_width)
-        logger.info(
-            f"Max crop width and height were found to be {max_crop_width}x{max_crop_height}."
-        )
-        logger.info(f"Setting them to {max_crop_width}x{max_crop_width}")
-        return crop_dims
-
-    def __repr__(self) -> str:
-        """Return info about the dataset."""
-        return (
-            "IAM Paragraph Dataset\n"  # pylint: disable=no-member
-            f"Num classes: {self.num_classes}\n"
-            f"Data: {self.data.shape}\n"
-            f"Targets: {self.targets.shape}\n"
-        )
-
-
-def _get_max_paragraph_crop_height(line_regions_by_id: Dict) -> int:
-    heights = []
-    for regions in line_regions_by_id.values():
-        min_y1 = min(region["y1"] for region in regions) - PARAGRAPH_BUFFER
-        max_y2 = max(region["y2"] for region in regions) + PARAGRAPH_BUFFER
-        height = max_y2 - min_y1
-        heights.append(height)
-    return max(heights)
-
-
-def _crop_paragraph_image(
-    filename: str, line_regions: Dict, crop_dims: Tuple[int, int], final_dims: Tuple
-) -> None:
-    image = util.read_image(filename, grayscale=True)
-
-    min_y1 = min(region["y1"] for region in line_regions) - PARAGRAPH_BUFFER
-    max_y2 = max(region["y2"] for region in line_regions) + PARAGRAPH_BUFFER
-    height = max_y2 - min_y1
-    crop_height = crop_dims[0]
-    buffer = (crop_height - height) // 2
-
-    # Generate the image crop.
-    image_crop = 255 * np.ones(crop_dims, dtype=np.uint8)
-    try:
-        image_crop[buffer : buffer + height] = image[min_y1:max_y2]
-    except Exception as e:  # pylint: disable=broad-except
-        logger.error(f"Rescued {filename}: {e}")
-        return
-
-    # Generate the ground truth.
-    gt_image = np.zeros_like(image_crop, dtype=np.uint8)
-    for index, region in enumerate(line_regions):
-        gt_image[
-            (region["y1"] - min_y1 + buffer) : (region["y2"] - min_y1 + buffer),
-            region["x1"] : region["x2"],
-        ] = (index % 2 + 1)
-
-    # Generate an image for debugging.
-    import matplotlib.pyplot as plt
-
-    cmap = plt.get_cmap("Set1")
-    image_crop_for_debug = np.dstack([image_crop, image_crop, image_crop])
-    for index, region in enumerate(line_regions):
-        color = [255 * _ for _ in cmap(index)[:-1]]
-        cv2.rectangle(
-            image_crop_for_debug,
-            (region["x1"], region["y1"] - min_y1 + buffer),
-            (region["x2"], region["y2"] - min_y1 + buffer),
-            color,
-            3,
-        )
-    image_crop_for_debug = cv2.resize(
-        image_crop_for_debug, final_dims, interpolation=cv2.INTER_AREA
-    )
-    util.write_image(image_crop_for_debug, DEBUG_CROPS_DIRNAME / f"{filename.stem}.jpg")
-
-    image_crop = cv2.resize(image_crop, final_dims, interpolation=cv2.INTER_AREA)
-    util.write_image(image_crop, CROPS_DIRNAME / f"{filename.stem}.jpg")
-
-    gt_image = cv2.resize(gt_image, final_dims, interpolation=cv2.INTER_NEAREST)
-    util.write_image(gt_image, GT_DIRNAME / f"{filename.stem}.png")
-
-
-def _load_iam_paragraphs() -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
-    logger.info("Loading IAM paragraph crops and ground truth from image files...")
-    images = []
-    gt_images = []
-    ids = []
-    for filename in CROPS_DIRNAME.glob("*.jpg"):
-        id_ = filename.stem
-        image = util.read_image(filename, grayscale=True)
-        image = 1.0 - image / 255
-
-        gt_filename = GT_DIRNAME / f"{id_}.png"
-        gt_image = util.read_image(gt_filename, grayscale=True)
-
-        images.append(image)
-        gt_images.append(gt_image)
-        ids.append(id_)
-    images = np.array(images).astype(np.float32)
-    gt_images = np.array(gt_images).astype(np.uint8)
-    ids = np.array(ids)
-    return images, gt_images, ids
-
-
-@click.command()
-@click.option(
-    "--subsample_fraction",
-    type=float,
-    default=None,
-    help="The subsampling factor of the dataset.",
-)
-def main(subsample_fraction: float) -> None:
-    """Load dataset and print info."""
-    logger.info("Creating train set...")
-    dataset = IamParagraphsDataset(train=True, subsample_fraction=subsample_fraction)
-    dataset.load_or_generate_data()
-    print(dataset)
-    logger.info("Creating test set...")
-    dataset = IamParagraphsDataset(subsample_fraction=subsample_fraction)
-    dataset.load_or_generate_data()
-    print(dataset)
-
-
-if __name__ == "__main__":
-    main()
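The `__getitem__` above re-seeds both `random` and `torch` with the same value immediately before the data transform and again before the target transform, so random augmentations of an image and its segmentation mask sample identical parameters. A minimal standalone sketch of that pattern follows; the transform choice and function names here are illustrative, not from the deleted code:

```python
# Paired-augmentation seeding, as in the deleted __getitem__: replaying one
# seed before each transform keeps image and mask augmentations in sync.
import random

import numpy as np
import torch
from PIL import Image
from torchvision import transforms

augment = transforms.RandomRotation(10)  # illustrative augmentation

def paired_augment(image: Image.Image, mask: Image.Image):
    seed = np.random.randint(2 ** 31)
    random.seed(seed)        # torchvision 0.7 draws parameters from `random`
    torch.manual_seed(seed)  # newer torchvision draws from torch's RNG
    image = augment(image)
    random.seed(seed)        # replay the same seed for the target transform
    torch.manual_seed(seed)
    mask = augment(mask)
    return image, mask
```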
diff --git a/text_recognizer/data/util.py b/text_recognizer/data/util.py
deleted file mode 100644
index da87756..0000000
--- a/text_recognizer/data/util.py
+++ /dev/null
@@ -1,209 +0,0 @@
-"""Util functions for datasets."""
-import hashlib
-import json
-import os
-from pathlib import Path
-import string
-from typing import Dict, List, Optional, Union
-from urllib.request import urlretrieve
-
-from loguru import logger
-import numpy as np
-import torch
-from torch import Tensor
-from torchvision.datasets import EMNIST
-from tqdm import tqdm
-
-DATA_DIRNAME = Path(__file__).resolve().parents[3] / "data"
-ESSENTIALS_FILENAME = Path(__file__).resolve().parents[0] / "emnist_essentials.json"
-
-
-def save_emnist_essentials(emnist_dataset: EMNIST) -> None:
-    """Extracts and saves the EMNIST essentials."""
-    labels = emnist_dataset.classes
-    labels.sort()
-    mapping = [(i, str(label)) for i, label in enumerate(labels)]
-    essentials = {
-        "mapping": mapping,
-        "input_shape": tuple(np.array(emnist_dataset[0][0]).shape[:]),
-    }
-    logger.info("Saving EMNIST essentials...")
-    with open(ESSENTIALS_FILENAME, "w") as f:
-        json.dump(essentials, f)
-
-
-def download_emnist() -> None:
-    """Download the EMNIST dataset via the PyTorch class."""
-    logger.info(f"Data directory is: {DATA_DIRNAME}")
-    dataset = EMNIST(root=DATA_DIRNAME, split="byclass", download=True)
-    save_emnist_essentials(dataset)
-
-
-class EmnistMapper:
-    """Maps between the network output and EMNIST characters."""
-
-    def __init__(
-        self,
-        pad_token: str,
-        init_token: Optional[str] = None,
-        eos_token: Optional[str] = None,
-        lower: bool = False,
-    ) -> None:
-        """Loads the EMNIST essentials file with the mapping and input shape."""
-        self.init_token = init_token
-        self.pad_token = pad_token
-        self.eos_token = eos_token
-        self.lower = lower
-
-        self.essentials = self._load_emnist_essentials()
-        # Load dataset information.
-        self._mapping = dict(self.essentials["mapping"])
-        self._augment_emnist_mapping()
-        self._inverse_mapping = {v: k for k, v in self.mapping.items()}
-        self._num_classes = len(self.mapping)
-        self._input_shape = self.essentials["input_shape"]
-
-    def __call__(self, token: Union[str, int, np.uint8, Tensor]) -> Union[str, int]:
-        """Maps the token to an EMNIST character or character index.
-
-        If the token is an integer (index), the method returns the EMNIST character corresponding to that index.
-        If the token is a string (EMNIST character), the method returns the corresponding index for that character.
-
-        Args:
-            token (Union[str, int, np.uint8, Tensor]): Either a string or an index (integer).
-
-        Returns:
-            Union[str, int]: The mapping result.
-
-        Raises:
-            KeyError: If the index or string does not exist in the mapping.
-
-        """
-        if (
-            isinstance(token, (int, np.uint8)) or torch.is_tensor(token)
-        ) and int(token) in self.mapping:
-            return self.mapping[int(token)]
-        if isinstance(token, str) and token in self._inverse_mapping:
-            return self._inverse_mapping[token]
-        raise KeyError(f"Token {token} does not exist in the mappings.")
-
-    @property
-    def mapping(self) -> Dict:
-        """Returns the mapping between index and character."""
-        return self._mapping
-
-    @property
-    def inverse_mapping(self) -> Dict:
-        """Returns the mapping between character and index."""
-        return self._inverse_mapping
-
-    @property
-    def num_classes(self) -> int:
-        """Returns the number of classes in the dataset."""
-        return self._num_classes
-
-    @property
-    def input_shape(self) -> List[int]:
-        """Returns the input shape of the EMNIST characters."""
-        return self._input_shape
-
-    def _load_emnist_essentials(self) -> Dict:
-        """Load the EMNIST mapping."""
-        with open(str(ESSENTIALS_FILENAME)) as f:
-            essentials = json.load(f)
-        return essentials
-
-    def _augment_emnist_mapping(self) -> None:
-        """Augment the mapping with extra symbols."""
-        # Extra symbols found in the IAM dataset.
-        if self.lower:
-            self._mapping = {
-                k: str(v)
-                for k, v in enumerate(list(range(10)) + list(string.ascii_lowercase))
-            }
-
-        extra_symbols = [
-            " ",
-            "!",
-            '"',
-            "#",
-            "&",
-            "'",
-            "(",
-            ")",
-            "*",
-            "+",
-            ",",
-            "-",
-            ".",
-            "/",
-            ":",
-            ";",
-            "?",
-        ]
-
-        # The padding symbol also acts as the blank symbol.
-        extra_symbols.append(self.pad_token)
-
-        if self.init_token is not None:
-            extra_symbols.append(self.init_token)
-
-        if self.eos_token is not None:
-            extra_symbols.append(self.eos_token)
-
-        max_key = max(self.mapping.keys())
-        extra_mapping = {}
-        for i, symbol in enumerate(extra_symbols):
-            extra_mapping[max_key + 1 + i] = symbol
-
-        self._mapping = {**self.mapping, **extra_mapping}
-
-
-def compute_sha256(filename: Union[Path, str]) -> str:
-    """Returns the SHA-256 checksum of a file."""
-    with open(filename, "rb") as f:
-        return hashlib.sha256(f.read()).hexdigest()
-
-
-class TqdmUpTo(tqdm):
-    """TQDM progress bar for downloading files.
-
-    From https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py
-
-    """
-
-    def update_to(
-        self, blocks: int = 1, block_size: int = 1, total_size: Optional[int] = None
-    ) -> None:
-        """Updates the progress bar.
-
-        Args:
-            blocks (int): Number of blocks transferred so far. Defaults to 1.
-            block_size (int): Size of each block, in tqdm units. Defaults to 1.
-            total_size (Optional[int]): Total size, in tqdm units. Defaults to None.
-
-        """
-        if total_size is not None:
-            self.total = total_size  # pylint: disable=attribute-defined-outside-init
-        self.update(blocks * block_size - self.n)
-
-
-def download_url(url: str, filename: str) -> None:
-    """Downloads a file from a url to filename, with a progress bar."""
-    with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t:
-        urlretrieve(url, filename, reporthook=t.update_to, data=None)  # nosec
-
-
-def _download_raw_dataset(metadata: Dict) -> None:
-    if os.path.exists(metadata["filename"]):
-        return
-    logger.info(f"Downloading raw dataset from {metadata['url']}...")
-    download_url(metadata["url"], metadata["filename"])
-    logger.info("Computing SHA-256...")
-    sha256 = compute_sha256(metadata["filename"])
-    if sha256 != metadata["sha256"]:
-        raise ValueError(
-            "Downloaded data file SHA-256 does not match that listed in metadata document."
-        )