path: root/text_recognizer/data
author     Gustaf Rydholm <gustaf.rydholm@gmail.com>  2021-03-31 21:55:10 +0200
committer  Gustaf Rydholm <gustaf.rydholm@gmail.com>  2021-03-31 21:55:10 +0200
commit     3196144ec99e803cef218295ddea592748931c57 (patch)
tree       867d38ed08c78b8186fdd9a8abab4257f14d05c7 /text_recognizer/data
parent     d21594211e29c40c135b753e33b248b0737cd76f (diff)
Removing legacy code
Diffstat (limited to 'text_recognizer/data')
-rw-r--r--  text_recognizer/data/iam_paragraphs_dataset.py  291
-rw-r--r--  text_recognizer/data/util.py                     209
2 files changed, 0 insertions, 500 deletions
diff --git a/text_recognizer/data/iam_paragraphs_dataset.py b/text_recognizer/data/iam_paragraphs_dataset.py
deleted file mode 100644
index 8ba5142..0000000
--- a/text_recognizer/data/iam_paragraphs_dataset.py
+++ /dev/null
@@ -1,291 +0,0 @@
-"""IamParagraphsDataset class and functions for data processing."""
-import random
-from pathlib import Path
-from typing import Callable, Dict, List, Optional, Tuple, Union
-
-import click
-import cv2
-import h5py
-from loguru import logger
-import numpy as np
-import torch
-from torch import Tensor
-from torchvision.transforms import ToTensor
-
-from text_recognizer import util
-from text_recognizer.datasets.dataset import Dataset
-from text_recognizer.datasets.iam_dataset import IamDataset
-from text_recognizer.datasets.util import (
- compute_sha256,
- DATA_DIRNAME,
- download_url,
- EmnistMapper,
-)
-
-INTERIM_DATA_DIRNAME = DATA_DIRNAME / "interim" / "iam_paragraphs"
-DEBUG_CROPS_DIRNAME = INTERIM_DATA_DIRNAME / "debug_crops"
-PROCESSED_DATA_DIRNAME = DATA_DIRNAME / "processed" / "iam_paragraphs"
-CROPS_DIRNAME = PROCESSED_DATA_DIRNAME / "crops"
-GT_DIRNAME = PROCESSED_DATA_DIRNAME / "gt"
-
-PARAGRAPH_BUFFER = 50 # Pixels in the IAM form images to leave around the lines.
-TEST_FRACTION = 0.2
-SEED = 4711
-
-
-class IamParagraphsDataset(Dataset):
- """IAM Paragraphs dataset for paragraphs of handwritten text."""
-
- def __init__(
- self,
- train: bool = False,
-        subsample_fraction: Optional[float] = None,
- transform: Optional[Callable] = None,
- target_transform: Optional[Callable] = None,
- ) -> None:
- super().__init__(
- train=train,
- subsample_fraction=subsample_fraction,
- transform=transform,
- target_transform=target_transform,
- )
-        # Load the IAM dataset.
- self.iam_dataset = IamDataset()
-
- self.num_classes = 3
- self._input_shape = (256, 256)
- self._output_shape = self._input_shape + (self.num_classes,)
- self._ids = None
-
- def __getitem__(self, index: Union[Tensor, int]) -> Tuple[Tensor, Tensor]:
- """Fetches data, target pair of the dataset for a given and index or indices.
-
- Args:
- index (Union[int, Tensor]): Either a list or int of indices/index.
-
- Returns:
- Tuple[Tensor, Tensor]: Data target pair.
-
- """
- if torch.is_tensor(index):
- index = index.tolist()
-
- data = self.data[index]
- targets = self.targets[index]
-
-        seed = np.random.randint(SEED)
-        random.seed(seed)  # Apply this seed to the data transforms.
-        torch.manual_seed(seed)  # Needed for torchvision 0.7.
-        if self.transform:
-            data = self.transform(data)
-
-        random.seed(seed)  # Apply the same seed to the target transforms.
-        torch.manual_seed(seed)  # Needed for torchvision 0.7.
- if self.target_transform:
- targets = self.target_transform(targets)
-
- return data, targets.long()
-
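The twin random.seed/torch.manual_seed calls in __getitem__ above are a workaround for torchvision 0.7, which has no joint image/target transforms: re-seeding both RNGs before each transform call makes the sampled augmentation parameters identical for data and mask. A minimal sketch of the idea, assuming a torchvision version whose transforms accept tensors (the RandomCrop choice and sizes are illustrative, not from this repo):

    import random

    import torch
    from torchvision import transforms

    crop = transforms.RandomCrop(224)

    def paired_transform(image: torch.Tensor, mask: torch.Tensor):
        # Re-seed both RNGs so the crop draws the same parameters twice.
        seed = random.randint(0, 2 ** 32 - 1)
        random.seed(seed)
        torch.manual_seed(seed)
        image = crop(image)
        random.seed(seed)
        torch.manual_seed(seed)
        mask = crop(mask)
        return image, mask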
- @property
-    def ids(self) -> np.ndarray:
-        """Ids of the dataset."""
-        return self._ids
-
-    def get_data_and_target_from_id(self, id_: str) -> Tuple[Tensor, Tensor]:
-        """Get data-target pair from id."""
-        # self.ids is a numpy array of strings, so find the index by comparison.
-        ind = int(np.flatnonzero(self.ids == id_)[0])
-        return self.data[ind], self.targets[ind]
-
- def load_or_generate_data(self) -> None:
- """Load or generate dataset data."""
- num_actual = len(list(CROPS_DIRNAME.glob("*.jpg")))
- num_targets = len(self.iam_dataset.line_regions_by_id)
-
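-        # Tolerate a couple of missing crops: some forms fail to process and are skipped (logged as "Rescued" in _crop_paragraph_image).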
- if num_actual < num_targets - 2:
- self._process_iam_paragraphs()
-
- self._data, self._targets, self._ids = _load_iam_paragraphs()
- self._get_random_split()
- self._subsample()
-
-    def _get_random_split(self) -> None:
-        np.random.seed(SEED)
-        num_train = int((1 - TEST_FRACTION) * self.data.shape[0])
-        indices = np.random.permutation(self.data.shape[0])
-        train_indices, test_indices = indices[:num_train], indices[num_train:]
-        split_indices = train_indices if self.train else test_indices
-        self._data = self.data[split_indices]
-        self._targets = self.targets[split_indices]
-        # Reindex the ids as well, so id-based lookup stays aligned with the split.
-        self._ids = self._ids[split_indices]
-
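_get_random_split gets away without storing indices because the train and test instances both seed NumPy with the same constant, draw the same permutation, and take complementary slices of it. A toy sketch of that invariant, with made-up sizes:

    import numpy as np

    SEED, TEST_FRACTION, num_samples = 4711, 0.2, 10

    np.random.seed(SEED)
    perm_a = np.random.permutation(num_samples)
    np.random.seed(SEED)
    perm_b = np.random.permutation(num_samples)
    assert (perm_a == perm_b).all()  # Same seed, same permutation, same split.

    num_train = int((1 - TEST_FRACTION) * num_samples)
    train_indices, test_indices = perm_a[:num_train], perm_a[num_train:]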
- def _process_iam_paragraphs(self) -> None:
- """Crop the part with the text.
-
- For each page, crop out the part of it that correspond to the paragraph of text, and make sure all crops are
- self.input_shape. The ground truth data is the same size, with a one-hot vector at each pixel
- corresponding to labels 0=background, 1=odd-numbered line, 2=even-numbered line
- """
- crop_dims = self._decide_on_crop_dims()
- CROPS_DIRNAME.mkdir(parents=True, exist_ok=True)
- DEBUG_CROPS_DIRNAME.mkdir(parents=True, exist_ok=True)
- GT_DIRNAME.mkdir(parents=True, exist_ok=True)
- logger.info(
- f"Cropping paragraphs, generating ground truth, and saving debugging images to {DEBUG_CROPS_DIRNAME}"
- )
- for filename in self.iam_dataset.form_filenames:
- id_ = filename.stem
- line_region = self.iam_dataset.line_regions_by_id[id_]
- _crop_paragraph_image(filename, line_region, crop_dims, self.input_shape)
-
- def _decide_on_crop_dims(self) -> Tuple[int, int]:
- """Decide on the dimensions to crop out of the form image.
-
- Since image width is larger than a comfortable crop around the longest paragraph,
- we will make the crop a square form factor.
- And since the found dimensions 610x610 are pretty close to 512x512,
- we might as well resize crops and make it exactly that, which lets us
- do all kinds of power-of-2 pooling and upsampling should we choose to.
-
- Returns:
- Tuple[int, int]: A tuple of crop dimensions.
-
- Raises:
- RuntimeError: When max crop height is larger than max crop width.
-
- """
-
- sample_form_filename = self.iam_dataset.form_filenames[0]
- sample_image = util.read_image(sample_form_filename, grayscale=True)
- max_crop_width = sample_image.shape[1]
- max_crop_height = _get_max_paragraph_crop_height(
- self.iam_dataset.line_regions_by_id
- )
-        if not max_crop_height <= max_crop_width:
-            raise RuntimeError(
-                f"Max crop height is larger than max crop width: {max_crop_height} > {max_crop_width}"
-            )
-
- crop_dims = (max_crop_width, max_crop_width)
- logger.info(
- f"Max crop width and height were found to be {max_crop_width}x{max_crop_height}."
- )
- logger.info(f"Setting them to {max_crop_width}x{max_crop_width}")
- return crop_dims
-
- def __repr__(self) -> str:
- """Return info about the dataset."""
- return (
- "IAM Paragraph Dataset\n" # pylint: disable=no-member
- f"Num classes: {self.num_classes}\n"
- f"Data: {self.data.shape}\n"
- f"Targets: {self.targets.shape}\n"
- )
-
-
-def _get_max_paragraph_crop_height(line_regions_by_id: Dict) -> int:
- heights = []
- for regions in line_regions_by_id.values():
- min_y1 = min(region["y1"] for region in regions) - PARAGRAPH_BUFFER
- max_y2 = max(region["y2"] for region in regions) + PARAGRAPH_BUFFER
- height = max_y2 - min_y1
- heights.append(height)
- return max(heights)
-
-
-def _crop_paragraph_image(
-    filename: Path, line_regions: Dict, crop_dims: Tuple[int, int], final_dims: Tuple
-) -> None:
- image = util.read_image(filename, grayscale=True)
-
- min_y1 = min(region["y1"] for region in line_regions) - PARAGRAPH_BUFFER
- max_y2 = max(region["y2"] for region in line_regions) + PARAGRAPH_BUFFER
- height = max_y2 - min_y1
- crop_height = crop_dims[0]
- buffer = (crop_height - height) // 2
-
- # Generate image crop.
- image_crop = 255 * np.ones(crop_dims, dtype=np.uint8)
- try:
- image_crop[buffer : buffer + height] = image[min_y1:max_y2]
- except Exception as e: # pylint: disable=broad-except
- logger.error(f"Rescued {filename}: {e}")
- return
-
- # Generate ground truth.
- gt_image = np.zeros_like(image_crop, dtype=np.uint8)
- for index, region in enumerate(line_regions):
- gt_image[
- (region["y1"] - min_y1 + buffer) : (region["y2"] - min_y1 + buffer),
- region["x1"] : region["x2"],
- ] = (index % 2 + 1)
-
- # Generate image for debugging.
- import matplotlib.pyplot as plt
-
- cmap = plt.get_cmap("Set1")
- image_crop_for_debug = np.dstack([image_crop, image_crop, image_crop])
- for index, region in enumerate(line_regions):
- color = [255 * _ for _ in cmap(index)[:-1]]
- cv2.rectangle(
- image_crop_for_debug,
- (region["x1"], region["y1"] - min_y1 + buffer),
- (region["x2"], region["y2"] - min_y1 + buffer),
- color,
- 3,
- )
- image_crop_for_debug = cv2.resize(
- image_crop_for_debug, final_dims, interpolation=cv2.INTER_AREA
- )
- util.write_image(image_crop_for_debug, DEBUG_CROPS_DIRNAME / f"{filename.stem}.jpg")
-
- image_crop = cv2.resize(image_crop, final_dims, interpolation=cv2.INTER_AREA)
- util.write_image(image_crop, CROPS_DIRNAME / f"{filename.stem}.jpg")
-
- gt_image = cv2.resize(gt_image, final_dims, interpolation=cv2.INTER_NEAREST)
- util.write_image(gt_image, GT_DIRNAME / f"{filename.stem}.png")
-
-
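The ground truth above paints each line's bounding box with a label determined by line parity, and is resized with cv2.INTER_NEAREST so interpolation cannot invent fractional labels. A tiny self-contained sketch of the same encoding, with made-up regions:

    import numpy as np

    regions = [
        {"x1": 1, "x2": 5, "y1": 0, "y2": 2},  # line 1 (odd)  -> label 1
        {"x1": 1, "x2": 5, "y1": 3, "y2": 5},  # line 2 (even) -> label 2
    ]
    gt = np.zeros((6, 6), dtype=np.uint8)
    for index, region in enumerate(regions):
        gt[region["y1"] : region["y2"], region["x1"] : region["x2"]] = index % 2 + 1
    # gt now holds 0 = background, 1 = odd-numbered lines, 2 = even-numbered lines.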
-def _load_iam_paragraphs() -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
- logger.info("Loading IAM paragraph crops and ground truth from image files...")
- images = []
- gt_images = []
- ids = []
- for filename in CROPS_DIRNAME.glob("*.jpg"):
- id_ = filename.stem
- image = util.read_image(filename, grayscale=True)
- image = 1.0 - image / 255
-
- gt_filename = GT_DIRNAME / f"{id_}.png"
- gt_image = util.read_image(gt_filename, grayscale=True)
-
- images.append(image)
- gt_images.append(gt_image)
- ids.append(id_)
- images = np.array(images).astype(np.float32)
- gt_images = np.array(gt_images).astype(np.uint8)
- ids = np.array(ids)
- return images, gt_images, ids
-
-
-@click.command()
-@click.option(
- "--subsample_fraction",
- type=float,
- default=None,
- help="The subsampling factor of the dataset.",
-)
-def main(subsample_fraction: Optional[float]) -> None:
- """Load dataset and print info."""
- logger.info("Creating train set...")
- dataset = IamParagraphsDataset(train=True, subsample_fraction=subsample_fraction)
- dataset.load_or_generate_data()
- print(dataset)
- logger.info("Creating test set...")
- dataset = IamParagraphsDataset(subsample_fraction=subsample_fraction)
- dataset.load_or_generate_data()
- print(dataset)
-
-
-if __name__ == "__main__":
- main()
diff --git a/text_recognizer/data/util.py b/text_recognizer/data/util.py
deleted file mode 100644
index da87756..0000000
--- a/text_recognizer/data/util.py
+++ /dev/null
@@ -1,209 +0,0 @@
-"""Util functions for datasets."""
-import hashlib
-import json
-import os
-from pathlib import Path
-import string
-from typing import Dict, List, Optional, Union
-from urllib.request import urlretrieve
-
-from loguru import logger
-import numpy as np
-import torch
-from torch import Tensor
-from torchvision.datasets import EMNIST
-from tqdm import tqdm
-
-DATA_DIRNAME = Path(__file__).resolve().parents[3] / "data"
-ESSENTIALS_FILENAME = Path(__file__).resolve().parents[0] / "emnist_essentials.json"
-
-
-def save_emnist_essentials(emnist_dataset: EMNIST) -> None:
-    """Extract and save EMNIST essentials."""
-    labels = emnist_dataset.classes
-    labels.sort()
-    mapping = [(i, str(label)) for i, label in enumerate(labels)]
-    essentials = {
-        "mapping": mapping,
-        "input_shape": tuple(np.array(emnist_dataset[0][0]).shape),
-    }
- logger.info("Saving emnist essentials...")
- with open(ESSENTIALS_FILENAME, "w") as f:
- json.dump(essentials, f)
-
-
-def download_emnist() -> None:
- """Download the EMNIST dataset via the PyTorch class."""
- logger.info(f"Data directory is: {DATA_DIRNAME}")
- dataset = EMNIST(root=DATA_DIRNAME, split="byclass", download=True)
- save_emnist_essentials(dataset)
-
-
-class EmnistMapper:
- """Mapper between network output to Emnist character."""
-
- def __init__(
- self,
- pad_token: str,
- init_token: Optional[str] = None,
- eos_token: Optional[str] = None,
- lower: bool = False,
- ) -> None:
- """Loads the emnist essentials file with the mapping and input shape."""
- self.init_token = init_token
- self.pad_token = pad_token
- self.eos_token = eos_token
- self.lower = lower
-
- self.essentials = self._load_emnist_essentials()
- # Load dataset information.
- self._mapping = dict(self.essentials["mapping"])
- self._augment_emnist_mapping()
- self._inverse_mapping = {v: k for k, v in self.mapping.items()}
- self._num_classes = len(self.mapping)
- self._input_shape = self.essentials["input_shape"]
-
- def __call__(self, token: Union[str, int, np.uint8, Tensor]) -> Union[str, int]:
- """Maps the token to emnist character or character index.
-
- If the token is an integer (index), the method will return the Emnist character corresponding to that index.
- If the token is a str (Emnist character), the method will return the corresponding index for that character.
-
- Args:
- token (Union[str, int, np.uint8, Tensor]): Either a string or index (integer).
-
- Returns:
- Union[str, int]: The mapping result.
-
- Raises:
- KeyError: If the index or string does not exist in the mapping.
-
- """
-        if (
-            isinstance(token, (int, np.uint8)) or torch.is_tensor(token)
-        ) and int(token) in self.mapping:
- return self.mapping[int(token)]
- elif isinstance(token, str) and token in self._inverse_mapping:
- return self._inverse_mapping[token]
- else:
- raise KeyError(f"Token {token} does not exist in the mappings.")
-
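__call__ makes the mapper symmetric: an integer index returns the character and a character returns its index. A hedged usage sketch (the concrete index assumes the EMNIST byclass ordering of digits, then uppercase, then lowercase; the real values come from emnist_essentials.json):

    mapper = EmnistMapper(pad_token="_")
    char = mapper(36)     # index -> character, "a" under byclass ordering
    index = mapper("a")   # character -> index, 36 again
    assert mapper(index) == char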
- @property
- def mapping(self) -> Dict:
- """Returns the mapping between index and character."""
- return self._mapping
-
- @property
- def inverse_mapping(self) -> Dict:
- """Returns the mapping between character and index."""
- return self._inverse_mapping
-
- @property
- def num_classes(self) -> int:
- """Returns the number of classes in the dataset."""
- return self._num_classes
-
- @property
- def input_shape(self) -> List[int]:
- """Returns the input shape of the Emnist characters."""
- return self._input_shape
-
- def _load_emnist_essentials(self) -> Dict:
- """Load the EMNIST mapping."""
- with open(str(ESSENTIALS_FILENAME)) as f:
- essentials = json.load(f)
- return essentials
-
- def _augment_emnist_mapping(self) -> None:
- """Augment the mapping with extra symbols."""
- # Extra symbols in IAM dataset
- if self.lower:
- self._mapping = {
- k: str(v)
- for k, v in enumerate(list(range(10)) + list(string.ascii_lowercase))
- }
-
- extra_symbols = [
- " ",
- "!",
- '"',
- "#",
- "&",
- "'",
- "(",
- ")",
- "*",
- "+",
- ",",
- "-",
- ".",
- "/",
- ":",
- ";",
- "?",
- ]
-
-        # The padding symbol also acts as the blank symbol.
- extra_symbols.append(self.pad_token)
-
- if self.init_token is not None:
- extra_symbols.append(self.init_token)
-
- if self.eos_token is not None:
- extra_symbols.append(self.eos_token)
-
- max_key = max(self.mapping.keys())
- extra_mapping = {}
- for i, symbol in enumerate(extra_symbols):
- extra_mapping[max_key + 1 + i] = symbol
-
- self._mapping = {**self.mapping, **extra_mapping}
-
-
-def compute_sha256(filename: Union[Path, str]) -> str:
- """Returns the SHA256 checksum of a file."""
- with open(filename, "rb") as f:
- return hashlib.sha256(f.read()).hexdigest()
-
-
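compute_sha256 reads the whole file into memory, which is fine for small archives but wasteful for multi-gigabyte datasets. A streaming variant with the same interface (a sketch, not part of this module) digests the file in fixed-size chunks:

    import hashlib
    from pathlib import Path
    from typing import Union

    def compute_sha256_chunked(filename: Union[Path, str], chunk_size: int = 1 << 20) -> str:
        # Stream the file through SHA-256 one 1 MiB chunk at a time.
        sha256 = hashlib.sha256()
        with open(filename, "rb") as f:
            for chunk in iter(lambda: f.read(chunk_size), b""):
                sha256.update(chunk)
        return sha256.hexdigest()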
-class TqdmUpTo(tqdm):
- """TQDM progress bar when downloading files.
-
- From https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py
-
- """
-
- def update_to(
- self, blocks: int = 1, block_size: int = 1, total_size: Optional[int] = None
- ) -> None:
- """Updates the progress bar.
-
- Args:
- blocks (int): Number of blocks transferred so far. Defaults to 1.
- block_size (int): Size of each block, in tqdm units. Defaults to 1.
- total_size (Optional[int]): Total size in tqdm units. Defaults to None.
- """
- if total_size is not None:
- self.total = total_size # pylint: disable=attribute-defined-outside-init
- self.update(blocks * block_size - self.n)
-
-
-def download_url(url: str, filename: str) -> None:
- """Downloads a file from url to filename, with a progress bar."""
- with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t:
- urlretrieve(url, filename, reporthook=t.update_to, data=None) # nosec
-
-
-def _download_raw_dataset(metadata: Dict) -> None:
- if os.path.exists(metadata["filename"]):
- return
- logger.info(f"Downloading raw dataset from {metadata['url']}...")
- download_url(metadata["url"], metadata["filename"])
- logger.info("Computing SHA-256...")
- sha256 = compute_sha256(metadata["filename"])
- if sha256 != metadata["sha256"]:
- raise ValueError(
- "Downloaded data file SHA-256 does not match that listed in metadata document."
- )
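For illustration, _download_raw_dataset expects a metadata dict with the keys used above; the values here are placeholders, not real dataset locations:

    metadata = {
        "filename": "data/raw/dataset.tgz",  # hypothetical local path
        "url": "https://example.com/dataset.tgz",  # placeholder URL
        "sha256": "0" * 64,  # expected checksum of the real file
    }
    _download_raw_dataset(metadata)  # Returns early if the file already exists.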