diff options
Diffstat (limited to 'src/text_recognizer/datasets')
-rw-r--r-- | src/text_recognizer/datasets/__init__.py | 12 | ||||
-rw-r--r-- | src/text_recognizer/datasets/emnist_dataset.py | 32 | ||||
-rw-r--r-- | src/text_recognizer/datasets/emnist_lines_dataset.py | 38 | ||||
-rw-r--r-- | src/text_recognizer/datasets/iam_dataset.py | 134 | ||||
-rw-r--r-- | src/text_recognizer/datasets/iam_lines_dataset.py | 126 | ||||
-rw-r--r-- | src/text_recognizer/datasets/iam_paragraphs_dataset.py | 322 | ||||
-rw-r--r-- | src/text_recognizer/datasets/util.py | 99 |
7 files changed, 685 insertions, 78 deletions
diff --git a/src/text_recognizer/datasets/__init__.py b/src/text_recognizer/datasets/__init__.py index 05f74f6..ede4541 100644 --- a/src/text_recognizer/datasets/__init__.py +++ b/src/text_recognizer/datasets/__init__.py @@ -10,15 +10,23 @@ from .emnist_lines_dataset import ( EmnistLinesDataset, get_samples_by_character, ) -from .util import fetch_data_loaders, Transpose +from .iam_dataset import IamDataset +from .iam_lines_dataset import IamLinesDataset +from .iam_paragraphs_dataset import IamParagraphsDataset +from .util import _download_raw_dataset, compute_sha256, download_url, Transpose __all__ = [ + "_download_raw_dataset", + "compute_sha256", "construct_image_from_string", "DATA_DIRNAME", + "download_url", "EmnistDataset", "EmnistMapper", "EmnistLinesDataset", - "fetch_data_loaders", "get_samples_by_character", + "IamDataset", + "IamLinesDataset", + "IamParagraphsDataset", "Transpose", ] diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py index 49ebad3..0715aae 100644 --- a/src/text_recognizer/datasets/emnist_dataset.py +++ b/src/text_recognizer/datasets/emnist_dataset.py @@ -152,8 +152,7 @@ class EmnistDataset(Dataset): """Loads the dataset and the mappings. Args: - train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to - False. + train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to False. sample_to_balance (bool): Resamples the dataset to make it balanced. Defaults to False. subsample_fraction (float): Description of parameter `subsample_fraction`. Defaults to None. transform (Optional[Callable]): Transform(s) for input data. Defaults to None. @@ -181,17 +180,37 @@ class EmnistDataset(Dataset): self.seed = seed self._mapper = EmnistMapper() - self.input_shape = self._mapper.input_shape + self._input_shape = self._mapper.input_shape self.num_classes = self._mapper.num_classes # Load dataset. - self.data, self.targets = self.load_emnist_dataset() + self._data, self._targets = self.load_emnist_dataset() + + @property + def data(self) -> Tensor: + """The input data.""" + return self._data + + @property + def targets(self) -> Tensor: + """The target data.""" + return self._targets + + @property + def input_shape(self) -> Tuple: + """Input shape of the data.""" + return self._input_shape @property def mapper(self) -> EmnistMapper: """Returns the EmnistMapper.""" return self._mapper + @property + def inverse_mapping(self) -> Dict: + """Returns the inverse mapping from character to index.""" + return self.mapper.inverse_mapping + def __len__(self) -> int: """Returns the length of the dataset.""" return len(self.data) @@ -220,11 +239,6 @@ class EmnistDataset(Dataset): return data, targets - @property - def __name__(self) -> str: - """Returns the name of the dataset.""" - return "EmnistDataset" - def __repr__(self) -> str: """Returns information about the dataset.""" return ( diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py index b0617f5..656131a 100644 --- a/src/text_recognizer/datasets/emnist_lines_dataset.py +++ b/src/text_recognizer/datasets/emnist_lines_dataset.py @@ -9,8 +9,8 @@ from loguru import logger import numpy as np import torch from torch import Tensor -from torch.utils.data import DataLoader, Dataset -from torchvision.transforms import Compose, Normalize, ToTensor +from torch.utils.data import Dataset +from torchvision.transforms import ToTensor from text_recognizer.datasets import ( DATA_DIRNAME, @@ -20,6 +20,7 @@ from text_recognizer.datasets import ( ) from text_recognizer.datasets.sentence_generator import SentenceGenerator from text_recognizer.datasets.util import Transpose +from text_recognizer.networks import sliding_window DATA_DIRNAME = DATA_DIRNAME / "processed" / "emnist_lines" @@ -55,7 +56,7 @@ class EmnistLinesDataset(Dataset): self.transform = transform if self.transform is None: - self.transform = Compose([ToTensor()]) + self.transform = ToTensor() self.target_transform = target_transform if self.target_transform is None: @@ -63,14 +64,14 @@ class EmnistLinesDataset(Dataset): # Extract dataset information. self._mapper = EmnistMapper() - self.input_shape = self._mapper.input_shape + self._input_shape = self._mapper.input_shape self.num_classes = self._mapper.num_classes self.max_length = max_length self.min_overlap = min_overlap self.max_overlap = max_overlap self.num_samples = num_samples - self.input_shape = ( + self._input_shape = ( self.input_shape[0], self.input_shape[1] * self.max_length, ) @@ -84,6 +85,11 @@ class EmnistLinesDataset(Dataset): # Load dataset. self._load_or_generate_data() + @property + def input_shape(self) -> Tuple: + """Input shape of the data.""" + return self._input_shape + def __len__(self) -> int: """Returns the length of the dataset.""" return len(self.data) @@ -112,11 +118,6 @@ class EmnistLinesDataset(Dataset): return data, targets - @property - def __name__(self) -> str: - """Returns the name of the dataset.""" - return "EmnistLinesDataset" - def __repr__(self) -> str: """Returns information about the dataset.""" return ( @@ -136,13 +137,18 @@ class EmnistLinesDataset(Dataset): return self._mapper @property + def mapping(self) -> Dict: + """Return EMNIST mapping from index to character.""" + return self._mapper.mapping + + @property def data_filename(self) -> Path: """Path to the h5 file.""" filename = f"ml_{self.max_length}_o{self.min_overlap}_{self.max_overlap}_n{self.num_samples}.pt" if self.train: filename = "train_" + filename else: - filename = "val_" + filename + filename = "test_" + filename return DATA_DIRNAME / filename def _load_or_generate_data(self) -> None: @@ -184,7 +190,7 @@ class EmnistLinesDataset(Dataset): ) targets = convert_strings_to_categorical_labels( - targets, self.emnist.inverse_mapping + targets, emnist.inverse_mapping ) f.create_dataset("data", data=data, dtype="u1", compression="lzf") @@ -322,13 +328,13 @@ def create_datasets( min_overlap: float = 0, max_overlap: float = 0.33, num_train: int = 10000, - num_val: int = 1000, + num_test: int = 1000, ) -> None: """Creates a training an validation dataset of Emnist lines.""" emnist_train = EmnistDataset(train=True, sample_to_balance=True) - emnist_val = EmnistDataset(train=False, sample_to_balance=True) - datasets = [emnist_train, emnist_val] - num_samples = [num_train, num_val] + emnist_test = EmnistDataset(train=False, sample_to_balance=True) + datasets = [emnist_train, emnist_test] + num_samples = [num_train, num_test] for num, train, dataset in zip(num_samples, [True, False], datasets): emnist_lines = EmnistLinesDataset( train=train, diff --git a/src/text_recognizer/datasets/iam_dataset.py b/src/text_recognizer/datasets/iam_dataset.py new file mode 100644 index 0000000..5e47350 --- /dev/null +++ b/src/text_recognizer/datasets/iam_dataset.py @@ -0,0 +1,134 @@ +"""Class for loading the IAM dataset, which encompasses both paragraphs and lines, with associated utilities.""" +import os +from typing import Any, Dict, List +import zipfile + +from boltons.cacheutils import cachedproperty +import defusedxml.ElementTree as ET +from loguru import logger +import toml +from torch.utils.data import Dataset + +from text_recognizer.datasets import DATA_DIRNAME +from text_recognizer.datasets.util import _download_raw_dataset + +RAW_DATA_DIRNAME = DATA_DIRNAME / "raw" / "iam" +METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml" +EXTRACTED_DATASET_DIRNAME = RAW_DATA_DIRNAME / "iamdb" + +DOWNSAMPLE_FACTOR = 2 # If images were downsampled, the regions must also be. +LINE_REGION_PADDING = 0 # Add this many pixels around the exact coordinates. + + +class IamDataset(Dataset): + """IAM dataset. + + "The IAM Lines dataset, first published at the ICDAR 1999, contains forms of unconstrained handwritten text, + which were scanned at a resolution of 300dpi and saved as PNG images with 256 gray levels." + From http://www.fki.inf.unibe.ch/databases/iam-handwriting-database + + The data split we will use is + IAM lines Large Writer Independent Text Line Recognition Task (lwitlrt): 9,862 text lines. + The validation set has been merged into the train set. + The train set has 7,101 lines from 326 writers. + The test set has 1,861 lines from 128 writers. + The text lines of all data sets are mutually exclusive, thus each writer has contributed to one set only. + + """ + + def __init__(self) -> None: + self.metadata = toml.load(METADATA_FILENAME) + + def load_or_generate_data(self) -> None: + """Downloads IAM dataset if xml files does not exist.""" + if not self.xml_filenames: + self._download_iam() + + @property + def xml_filenames(self) -> List: + """List of xml filenames.""" + return list((EXTRACTED_DATASET_DIRNAME / "xml").glob("*.xml")) + + @property + def form_filenames(self) -> List: + """List of forms filenames.""" + return list((EXTRACTED_DATASET_DIRNAME / "forms").glob("*.jpg")) + + def _download_iam(self) -> None: + curdir = os.getcwd() + os.chdir(RAW_DATA_DIRNAME) + _download_raw_dataset(self.metadata) + _extract_raw_dataset(self.metadata) + os.chdir(curdir) + + @property + def form_filenames_by_id(self) -> Dict: + """Creates a dictionary with filenames as keys and forms as values.""" + return {filename.stem: filename for filename in self.form_filenames} + + @cachedproperty + def line_strings_by_id(self) -> Dict: + """Return a dict from name of IAM form to a list of line texts in it.""" + return { + filename.stem: _get_line_strings_from_xml_file(filename) + for filename in self.xml_filenames + } + + @cachedproperty + def line_regions_by_id(self) -> Dict: + """Return a dict from name of IAM form to a list of (x1, x2, y1, y2) coordinates of all lines in it.""" + return { + filename.stem: _get_line_regions_from_xml_file(filename) + for filename in self.xml_filenames + } + + def __repr__(self) -> str: + """Print info about dataset.""" + return "IAM Dataset\n" f"Number of forms: {len(self.xml_filenames)}\n" + + +def _extract_raw_dataset(metadata: Dict) -> None: + logger.info("Extracting IAM data.") + with zipfile.ZipFile(metadata["filename"], "r") as zip_file: + zip_file.extractall() + + +def _get_line_strings_from_xml_file(filename: str) -> List[str]: + """Get the text content of each line. Note that we replace " with ".""" + xml_root_element = ET.parse(filename).getroot() # nosec + xml_line_elements = xml_root_element.findall("handwritten-part/line") + return [el.attrib["text"].replace(""", '"') for el in xml_line_elements] + + +def _get_line_regions_from_xml_file(filename: str) -> List[Dict[str, int]]: + """Get the line region dict for each line.""" + xml_root_element = ET.parse(filename).getroot() # nosec + xml_line_elements = xml_root_element.findall("handwritten-part/line") + return [_get_line_region_from_xml_element(el) for el in xml_line_elements] + + +def _get_line_region_from_xml_element(xml_line: Any) -> Dict[str, int]: + """Extracts coordinates for each line of text.""" + # TODO: fix input! + word_elements = xml_line.findall("word/cmp") + x1s = [int(el.attrib["x"]) for el in word_elements] + y1s = [int(el.attrib["y"]) for el in word_elements] + x2s = [int(el.attrib["x"]) + int(el.attrib["width"]) for el in word_elements] + y2s = [int(el.attrib["y"]) + int(el.attrib["height"]) for el in word_elements] + return { + "x1": min(x1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING, + "y1": min(y1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING, + "x2": max(x2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING, + "y2": max(y2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING, + } + + +def main() -> None: + """Initializes the dataset and print info about the dataset.""" + dataset = IamDataset() + dataset.load_or_generate_data() + print(dataset) + + +if __name__ == "__main__": + main() diff --git a/src/text_recognizer/datasets/iam_lines_dataset.py b/src/text_recognizer/datasets/iam_lines_dataset.py new file mode 100644 index 0000000..477f500 --- /dev/null +++ b/src/text_recognizer/datasets/iam_lines_dataset.py @@ -0,0 +1,126 @@ +"""IamLinesDataset class.""" +from typing import Callable, Dict, List, Optional, Tuple, Union + +import h5py +from loguru import logger +import torch +from torch import Tensor +from torch.utils.data import Dataset +from torchvision.transforms import ToTensor + +from text_recognizer.datasets.emnist_dataset import DATA_DIRNAME, EmnistMapper +from text_recognizer.datasets.util import compute_sha256, download_url + + +PROCESSED_DATA_DIRNAME = DATA_DIRNAME / "processed" / "iam_lines" +PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "iam_lines.h5" +PROCESSED_DATA_URL = ( + "https://s3-us-west-2.amazonaws.com/fsdl-public-assets/iam_lines.h5" +) + + +class IamLinesDataset(Dataset): + """IAM lines datasets for handwritten text lines.""" + + def __init__( + self, + train: bool = False, + subsample_fraction: float = None, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + ) -> None: + self.train = train + self.split = "train" if self.train else "test" + self._mapper = EmnistMapper() + self.num_classes = self.mapper.num_classes + + # Set transforms. + self.transform = transform + if self.transform is None: + self.transform = ToTensor() + + self.target_transform = target_transform + if self.target_transform is None: + self.target_transform = torch.tensor + + self.subsample_fraction = subsample_fraction + self.data = None + self.targets = None + + @property + def mapper(self) -> EmnistMapper: + """Returns the EmnistMapper.""" + return self._mapper + + @property + def mapping(self) -> Dict: + """Return EMNIST mapping from index to character.""" + return self._mapper.mapping + + @property + def input_shape(self) -> Tuple: + """Input shape of the data.""" + return self.data.shape[1:] + + @property + def output_shape(self) -> Tuple: + """Output shape of the data.""" + return self.targets.shape[1:] + (self.num_classes,) + + def __len__(self) -> int: + """Returns the length of the dataset.""" + return len(self.data) + + def load_or_generate_data(self) -> None: + """Load or generate dataset data.""" + if not PROCESSED_DATA_FILENAME.exists(): + PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) + logger.info("Downloading IAM lines...") + download_url(PROCESSED_DATA_URL, PROCESSED_DATA_FILENAME) + with h5py.File(PROCESSED_DATA_FILENAME, "r") as f: + self.data = f[f"x_{self.split}"][:] + self.targets = f[f"y_{self.split}"][:] + self._subsample() + + def _subsample(self) -> None: + """Only a fraction of the data will be loaded.""" + if self.subsample_fraction is None: + return + + num_samples = int(self.data.shape[0] * self.subsample_fraction) + self.data = self.data[:num_samples] + self.targets = self.targets[:num_samples] + + def __repr__(self) -> str: + """Print info about the dataset.""" + return ( + "IAM Lines Dataset\n" # pylint: disable=no-member + f"Number classes: {self.num_classes}\n" + f"Mapping: {self.mapper.mapping}\n" + f"Data: {self.data.shape}\n" + f"Targets: {self.targets.shape}\n" + ) + + def __getitem__(self, index: Union[Tensor, int]) -> Tuple[Tensor, Tensor]: + """Fetches data, target pair of the dataset for a given and index or indices. + + Args: + index (Union[int, Tensor]): Either a list or int of indices/index. + + Returns: + Tuple[Tensor, Tensor]: Data target pair. + + """ + if torch.is_tensor(index): + index = index.tolist() + + data = self.data[index] + targets = self.targets[index] + + if self.transform: + data = self.transform(data) + + if self.target_transform: + targets = self.target_transform(targets) + + return data, targets diff --git a/src/text_recognizer/datasets/iam_paragraphs_dataset.py b/src/text_recognizer/datasets/iam_paragraphs_dataset.py new file mode 100644 index 0000000..d65b346 --- /dev/null +++ b/src/text_recognizer/datasets/iam_paragraphs_dataset.py @@ -0,0 +1,322 @@ +"""IamParagraphsDataset class and functions for data processing.""" +from typing import Callable, Dict, List, Optional, Tuple, Union + +import click +import cv2 +import h5py +from loguru import logger +import numpy as np +import torch +from torch import Tensor +from torch.utils.data import Dataset +from torchvision.transforms import ToTensor + +from text_recognizer import util +from text_recognizer.datasets.emnist_dataset import DATA_DIRNAME, EmnistMapper +from text_recognizer.datasets.iam_dataset import IamDataset +from text_recognizer.datasets.util import compute_sha256, download_url + +INTERIM_DATA_DIRNAME = DATA_DIRNAME / "interim" / "iam_paragraphs" +DEBUG_CROPS_DIRNAME = INTERIM_DATA_DIRNAME / "debug_crops" +PROCESSED_DATA_DIRNAME = DATA_DIRNAME / "processed" / "iam_paragraphs" +CROPS_DIRNAME = PROCESSED_DATA_DIRNAME / "crops" +GT_DIRNAME = PROCESSED_DATA_DIRNAME / "gt" + +PARAGRAPH_BUFFER = 50 # Pixels in the IAM form images to leave around the lines. +TEST_FRACTION = 0.2 +SEED = 4711 + + +class IamParagraphsDataset(Dataset): + """IAM Paragraphs dataset for paragraphs of handwritten text. + + TODO: __getitem__, __len__, get_data_target_from_id + + """ + + def __init__( + self, + train: bool = False, + subsample_fraction: float = None, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + ) -> None: + + # Load Iam dataset. + self.iam_dataset = IamDataset() + + self.train = train + self.split = "train" if self.train else "test" + self.num_classes = 3 + self._input_shape = (256, 256) + self._output_shape = self._input_shape + (self.num_classes,) + self.subsample_fraction = subsample_fraction + + # Set transforms. + self.transform = transform + if self.transform is None: + self.transform = ToTensor() + + self.target_transform = target_transform + if self.target_transform is None: + self.target_transform = torch.tensor + + self._data = None + self._targets = None + self._ids = None + + def __len__(self) -> int: + """Returns the length of the dataset.""" + return len(self.data) + + def __getitem__(self, index: Union[Tensor, int]) -> Tuple[Tensor, Tensor]: + """Fetches data, target pair of the dataset for a given and index or indices. + + Args: + index (Union[int, Tensor]): Either a list or int of indices/index. + + Returns: + Tuple[Tensor, Tensor]: Data target pair. + + """ + if torch.is_tensor(index): + index = index.tolist() + + data = self.data[index] + targets = self.targets[index] + + if self.transform: + data = self.transform(data) + + if self.target_transform: + targets = self.target_transform(targets) + + return data, targets + + @property + def input_shape(self) -> Tuple: + """Input shape of the data.""" + return self._input_shape + + @property + def output_shape(self) -> Tuple: + """Output shape of the data.""" + return self._output_shape + + @property + def data(self) -> Tensor: + """The input data.""" + return self._data + + @property + def targets(self) -> Tensor: + """The target data.""" + return self._targets + + @property + def ids(self) -> Tensor: + """Ids of the dataset.""" + return self._ids + + def get_data_and_target_from_id(self, id_: str) -> Tuple[Tensor, Tensor]: + """Get data target pair from id.""" + ind = self.ids.index(id_) + return self.data[ind], self.targets[ind] + + def load_or_generate_data(self) -> None: + """Load or generate dataset data.""" + num_actual = len(list(CROPS_DIRNAME.glob("*.jpg"))) + num_targets = len(self.iam_dataset.line_regions_by_id) + + if num_actual < num_targets - 2: + self._process_iam_paragraphs() + + self._data, self._targets, self._ids = _load_iam_paragraphs() + self._get_random_split() + self._subsample() + + def _get_random_split(self) -> None: + np.random.seed(SEED) + num_train = int((1 - TEST_FRACTION) * self.data.shape[0]) + indices = np.random.permutation(self.data.shape[0]) + train_indices, test_indices = indices[:num_train], indices[num_train:] + if self.train: + self._data = self.data[train_indices] + self._targets = self.targets[train_indices] + else: + self._data = self.data[test_indices] + self._targets = self.targets[test_indices] + + def _process_iam_paragraphs(self) -> None: + """Crop the part with the text. + + For each page, crop out the part of it that correspond to the paragraph of text, and make sure all crops are + self.input_shape. The ground truth data is the same size, with a one-hot vector at each pixel + corresponding to labels 0=background, 1=odd-numbered line, 2=even-numbered line + """ + crop_dims = self._decide_on_crop_dims() + CROPS_DIRNAME.mkdir(parents=True, exist_ok=True) + DEBUG_CROPS_DIRNAME.mkdir(parents=True, exist_ok=True) + GT_DIRNAME.mkdir(parents=True, exist_ok=True) + logger.info( + f"Cropping paragraphs, generating ground truth, and saving debugging images to {DEBUG_CROPS_DIRNAME}" + ) + for filename in self.iam_dataset.form_filenames: + id_ = filename.stem + line_region = self.iam_dataset.line_regions_by_id[id_] + _crop_paragraph_image(filename, line_region, crop_dims, self.input_shape) + + def _decide_on_crop_dims(self) -> Tuple[int, int]: + """Decide on the dimensions to crop out of the form image. + + Since image width is larger than a comfortable crop around the longest paragraph, + we will make the crop a square form factor. + And since the found dimensions 610x610 are pretty close to 512x512, + we might as well resize crops and make it exactly that, which lets us + do all kinds of power-of-2 pooling and upsampling should we choose to. + + Returns: + Tuple[int, int]: A tuple of crop dimensions. + + Raises: + RuntimeError: When max crop height is larger than max crop width. + + """ + + sample_form_filename = self.iam_dataset.form_filenames[0] + sample_image = util.read_image(sample_form_filename, grayscale=True) + max_crop_width = sample_image.shape[1] + max_crop_height = _get_max_paragraph_crop_height( + self.iam_dataset.line_regions_by_id + ) + if not max_crop_height <= max_crop_width: + raise RuntimeError( + f"Max crop height is larger then max crop width: {max_crop_height} >= {max_crop_width}" + ) + + crop_dims = (max_crop_width, max_crop_width) + logger.info( + f"Max crop width and height were found to be {max_crop_width}x{max_crop_height}." + ) + logger.info(f"Setting them to {max_crop_width}x{max_crop_width}") + return crop_dims + + def _subsample(self) -> None: + """Only this fraction of the data will be loaded.""" + if self.subsample_fraction is None: + return + num_subsample = int(self.data.shape[0] * self.subsample_fraction) + self.data = self.data[:num_subsample] + self.targets = self.targets[:num_subsample] + + def __repr__(self) -> str: + """Return info about the dataset.""" + return ( + "IAM Paragraph Dataset\n" # pylint: disable=no-member + f"Num classes: {self.num_classes}\n" + f"Data: {self.data.shape}\n" + f"Targets: {self.targets.shape}\n" + ) + + +def _get_max_paragraph_crop_height(line_regions_by_id: Dict) -> int: + heights = [] + for regions in line_regions_by_id.values(): + min_y1 = min(region["y1"] for region in regions) - PARAGRAPH_BUFFER + max_y2 = max(region["y2"] for region in regions) + PARAGRAPH_BUFFER + height = max_y2 - min_y1 + heights.append(height) + return max(heights) + + +def _crop_paragraph_image( + filename: str, line_regions: Dict, crop_dims: Tuple[int, int], final_dims: Tuple +) -> None: + image = util.read_image(filename, grayscale=True) + + min_y1 = min(region["y1"] for region in line_regions) - PARAGRAPH_BUFFER + max_y2 = max(region["y2"] for region in line_regions) + PARAGRAPH_BUFFER + height = max_y2 - min_y1 + crop_height = crop_dims[0] + buffer = (crop_height - height) // 2 + + # Generate image crop. + image_crop = 255 * np.ones(crop_dims, dtype=np.uint8) + try: + image_crop[buffer : buffer + height] = image[min_y1:max_y2] + except Exception as e: # pylint: disable=broad-except + logger.error(f"Rescued {filename}: {e}") + return + + # Generate ground truth. + gt_image = np.zeros_like(image_crop, dtype=np.uint8) + for index, region in enumerate(line_regions): + gt_image[ + (region["y1"] - min_y1 + buffer) : (region["y2"] - min_y1 + buffer), + region["x1"] : region["x2"], + ] = (index % 2 + 1) + + # Generate image for debugging. + import matplotlib.pyplot as plt + + cmap = plt.get_cmap("Set1") + image_crop_for_debug = np.dstack([image_crop, image_crop, image_crop]) + for index, region in enumerate(line_regions): + color = [255 * _ for _ in cmap(index)[:-1]] + cv2.rectangle( + image_crop_for_debug, + (region["x1"], region["y1"] - min_y1 + buffer), + (region["x2"], region["y2"] - min_y1 + buffer), + color, + 3, + ) + image_crop_for_debug = cv2.resize( + image_crop_for_debug, final_dims, interpolation=cv2.INTER_AREA + ) + util.write_image(image_crop_for_debug, DEBUG_CROPS_DIRNAME / f"{filename.stem}.jpg") + + image_crop = cv2.resize(image_crop, final_dims, interpolation=cv2.INTER_AREA) + util.write_image(image_crop, CROPS_DIRNAME / f"{filename.stem}.jpg") + + gt_image = cv2.resize(gt_image, final_dims, interpolation=cv2.INTER_NEAREST) + util.write_image(gt_image, GT_DIRNAME / f"{filename.stem}.png") + + +def _load_iam_paragraphs() -> None: + logger.info("Loading IAM paragraph crops and ground truth from image files...") + images = [] + gt_images = [] + ids = [] + for filename in CROPS_DIRNAME.glob("*.jpg"): + id_ = filename.stem + image = util.read_image(filename, grayscale=True) + image = 1.0 - image / 255 + + gt_filename = GT_DIRNAME / f"{id_}.png" + gt_image = util.read_image(gt_filename, grayscale=True) + + images.append(image) + gt_images.append(gt_image) + ids.append(id_) + images = np.array(images).astype(np.float32) + gt_images = np.array(gt_images).astype(np.uint8) + ids = np.array(ids) + return images, gt_images, ids + + +@click.command() +@click.option( + "--subsample_fraction", + type=float, + default=0.0, + help="The subsampling factor of the dataset.", +) +def main(subsample_fraction: float) -> None: + """Load dataset and print info.""" + dataset = IamParagraphsDataset(subsample_fraction=subsample_fraction) + dataset.load_or_generate_data() + print(dataset) + + +if __name__ == "__main__": + main() diff --git a/src/text_recognizer/datasets/util.py b/src/text_recognizer/datasets/util.py index 76bd85f..dd16bed 100644 --- a/src/text_recognizer/datasets/util.py +++ b/src/text_recognizer/datasets/util.py @@ -1,10 +1,17 @@ """Util functions for datasets.""" +import hashlib import importlib -from typing import Callable, Dict, List, Type +import os +from pathlib import Path +from typing import Callable, Dict, List, Optional, Type, Union +from urllib.request import urlopen, urlretrieve +import cv2 +from loguru import logger import numpy as np from PIL import Image from torch.utils.data import DataLoader, Dataset +from tqdm import tqdm class Transpose: @@ -15,58 +22,48 @@ class Transpose: return np.array(image).swapaxes(0, 1) -def fetch_data_loaders( - splits: List[str], - dataset: str, - dataset_args: Dict, - batch_size: int = 128, - shuffle: bool = False, - num_workers: int = 0, - cuda: bool = True, -) -> Dict[str, DataLoader]: - """Fetches DataLoaders for given split(s) as a dictionary. - - Loads the dataset class given, and loads it with the dataset arguments, for the number of splits specified. Then - calls the DataLoader. Added to a dictionary with the split as key and DataLoader as value. - - Args: - splits (List[str]): One or both of the dataset splits "train" and "val". - dataset (str): The name of the dataset. - dataset_args (Dict): The dataset arguments. - batch_size (int): How many samples per batch to load. Defaults to 128. - shuffle (bool): Set to True to have the data reshuffled at every epoch. Defaults to False. - num_workers (int): How many subprocesses to use for data loading. 0 means that the data will be - loaded in the main process. Defaults to 0. - cuda (bool): If True, the data loader will copy Tensors into CUDA pinned memory before returning - them. Defaults to True. - - Returns: - Dict[str, DataLoader]: Dictionary with split as key and PyTorch DataLoader as value. +def compute_sha256(filename: Union[Path, str]) -> str: + """Returns the SHA256 checksum of a file.""" + with open(filename, "rb") as f: + return hashlib.sha256(f.read()).hexdigest() - """ - - def check_dataset_args(args: Dict, split: str) -> Dict: - """Adds train flag to the dataset args.""" - args["train"] = True if split == "train" else False - return args - - # Import dataset module. - datasets_module = importlib.import_module("text_recognizer.datasets") - dataset_ = getattr(datasets_module, dataset) - data_loaders = {} +class TqdmUpTo(tqdm): + """TQDM progress bar when downloading files. - for split in ["train", "val"]: - if split in splits: + From https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py - data_loader = DataLoader( - dataset=dataset_(**check_dataset_args(dataset_args, split)), - batch_size=batch_size, - shuffle=shuffle, - num_workers=num_workers, - pin_memory=cuda, - ) - - data_loaders[split] = data_loader + """ - return data_loaders + def update_to( + self, blocks: int = 1, block_size: int = 1, total_size: Optional[int] = None + ) -> None: + """Updates the progress bar. + + Args: + blocks (int): Number of blocks transferred so far. Defaults to 1. + block_size (int): Size of each block, in tqdm units. Defaults to 1. + total_size (Optional[int]): Total size in tqdm units. Defaults to None. + """ + if total_size is not None: + self.total = total_size # pylint: disable=attribute-defined-outside-init + self.update(blocks * block_size - self.n) + + +def download_url(url: str, filename: str) -> None: + """Downloads a file from url to filename, with a progress bar.""" + with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t: + urlretrieve(url, filename, reporthook=t.update_to, data=None) # nosec + + +def _download_raw_dataset(metadata: Dict) -> None: + if os.path.exists(metadata["filename"]): + return + logger.info(f"Downloading raw dataset from {metadata['url']}...") + download_url(metadata["url"], metadata["filename"]) + logger.info("Computing SHA-256...") + sha256 = compute_sha256(metadata["filename"]) + if sha256 != metadata["sha256"]: + raise ValueError( + "Downloaded data file SHA-256 does not match that listed in metadata document." + ) |