diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-03-28 22:02:24 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-03-28 22:02:24 +0200 |
commit | 46a1472d33d3a4180798492e819f2ec02bc3b1a3 (patch) | |
tree | 22322ed0d8f9f803966ea745ec5bb8c759f8db64 /text_recognizer/data | |
parent | 8248f173132dfb7e47ec62b08e9235990c8626e3 (diff) |
Add refactor of iam lines
Diffstat (limited to 'text_recognizer/data')
-rw-r--r-- | text_recognizer/data/__init__.py | 3 | ||||
-rw-r--r-- | text_recognizer/data/base_data_module.py | 2 | ||||
-rw-r--r-- | text_recognizer/data/base_dataset.py | 13 | ||||
-rw-r--r-- | text_recognizer/data/emnist.py | 37 | ||||
-rw-r--r-- | text_recognizer/data/emnist_lines.py | 32 | ||||
-rw-r--r-- | text_recognizer/data/iam.py | 39 | ||||
-rw-r--r-- | text_recognizer/data/iam_dataset.py | 133 | ||||
-rw-r--r-- | text_recognizer/data/iam_lines.py | 255 | ||||
-rw-r--r-- | text_recognizer/data/iam_lines_dataset.py | 110 | ||||
-rw-r--r-- | text_recognizer/data/iam_preprocessor.py | 1 | ||||
-rw-r--r-- | text_recognizer/data/image_utils.py | 49 | ||||
-rw-r--r-- | text_recognizer/data/sentence_generator.py | 7 | ||||
-rw-r--r-- | text_recognizer/data/transforms.py | 160 |
13 files changed, 391 insertions, 450 deletions
diff --git a/text_recognizer/data/__init__.py b/text_recognizer/data/__init__.py index 2727b20..9a42fa9 100644 --- a/text_recognizer/data/__init__.py +++ b/text_recognizer/data/__init__.py @@ -1 +1,4 @@ """Dataset modules.""" +from .base_dataset import BaseDataset, convert_strings_to_labels, split_dataset +from .base_data_module import BaseDataModule, load_and_print_info +from .download_utils import download_dataset diff --git a/text_recognizer/data/base_data_module.py b/text_recognizer/data/base_data_module.py index f5e7300..8b5c188 100644 --- a/text_recognizer/data/base_data_module.py +++ b/text_recognizer/data/base_data_module.py @@ -7,7 +7,7 @@ from torch.utils.data import DataLoader def load_and_print_info(data_module_class: type) -> None: - """Load EMNISTLines and prints info.""" + """Load dataset and print dataset information.""" dataset = data_module_class() dataset.prepare_data() dataset.setup() diff --git a/text_recognizer/data/base_dataset.py b/text_recognizer/data/base_dataset.py index a9e9c24..d00daaf 100644 --- a/text_recognizer/data/base_dataset.py +++ b/text_recognizer/data/base_dataset.py @@ -71,3 +71,16 @@ def convert_strings_to_labels( for j, token in enumerate(tokens): labels[i, j] = mapping[token] return labels + + +def split_dataset( + dataset: BaseDataset, fraction: float, seed: int +) -> Tuple[BaseDataset, BaseDataset]: + """Split dataset into two parts with fraction * size and (1 - fraction) * size.""" + if fraction >= 1.0: + raise ValueError("Fraction cannot be larger greater or equal to 1.0.") + split_1 = int(fraction * len(dataset)) + split_2 = len(dataset) - split_1 + return torch.utils.data.random_split( + dataset, [split_1, split_2], generator=torch.Generator().manual_seed(seed) + ) diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py index 7f67893..3e10b5f 100644 --- a/text_recognizer/data/emnist.py +++ b/text_recognizer/data/emnist.py @@ -1,6 +1,6 @@ """EMNIST dataset: downloads it from FSDL aws url if not present.""" from pathlib import Path -from typing import Sequence, Tuple +from typing import Dict, List, Sequence, Tuple import json import os import shutil @@ -10,11 +10,9 @@ import h5py import numpy as np from loguru import logger import toml -import torch -from torch.utils.data import random_split from torchvision import transforms -from text_recognizer.data.base_dataset import BaseDataset +from text_recognizer.data.base_dataset import BaseDataset, split_dataset from text_recognizer.data.base_data_module import ( BaseDataModule, load_and_print_info, @@ -48,23 +46,18 @@ class EMNIST(BaseDataModule): self, batch_size: int = 128, num_workers: int = 0, train_fraction: float = 0.8 ) -> None: super().__init__(batch_size, num_workers) - if not ESSENTIALS_FILENAME.exists(): - _download_and_process_emnist() - with ESSENTIALS_FILENAME.open() as f: - essentials = json.load(f) self.train_fraction = train_fraction - self.mapping = list(essentials["characters"]) - self.inverse_mapping = {v: k for k, v in enumerate(self.mapping)} + self.mapping, self.inverse_mapping, self.input_shape = emnist_mapping() self.data_train = None self.data_val = None self.data_test = None self.transform = transforms.Compose([transforms.ToTensor()]) - self.dims = (1, *essentials["input_shape"]) + self.dims = (1, * self.input_shape) self.output_dims = (1,) def prepare_data(self) -> None: if not PROCESSED_DATA_FILENAME.exists(): - _download_and_process_emnist() + download_and_process_emnist() def setup(self, stage: str = None) -> None: if stage == "fit" or stage is None: @@ -75,10 +68,8 @@ class EMNIST(BaseDataModule): dataset_train = BaseDataset( self.x_train, self.y_train, transform=self.transform ) - train_size = int(self.train_fraction * len(dataset_train)) - val_size = len(dataset_train) - train_size - self.data_train, self.data_val = random_split( - dataset_train, [train_size, val_size], generator=torch.Generator() + self.data_train, self.data_val = split_dataset( + dataset_train, fraction=self.train_fraction, seed=SEED ) if stage == "test" or stage is None: @@ -104,7 +95,19 @@ class EMNIST(BaseDataModule): return basic + data -def _download_and_process_emnist() -> None: +def emnist_mapping() -> Tuple[List, Dict[str, int], List[int]]: + """Return the EMNIST mapping.""" + if not ESSENTIALS_FILENAME.exists(): + download_and_process_emnist() + with ESSENTIALS_FILENAME.open() as f: + essentials = json.load(f) + mapping = list(essentials["characters"]) + inverse_mapping = {v: k for k, v in enumerate(mapping)} + input_shape = essentials["input_shape"] + return mapping, inverse_mapping, input_shape + + +def download_and_process_emnist() -> None: metadata = toml.load(METADATA_FILENAME) download_dataset(metadata, DL_DATA_DIRNAME) _process_raw_dataset(metadata["filename"], DL_DATA_DIRNAME) diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py index 6c14add..72665d0 100644 --- a/text_recognizer/data/emnist_lines.py +++ b/text_recognizer/data/emnist_lines.py @@ -1,12 +1,11 @@ """Dataset of generated text from EMNIST characters.""" from collections import defaultdict from pathlib import Path -from typing import Callable, Dict, Tuple, Sequence +from typing import Callable, Dict, Tuple import h5py from loguru import logger import numpy as np -from PIL import Image import torch from torchvision import transforms from torchvision.transforms.functional import InterpolationMode @@ -58,6 +57,7 @@ class EMNISTLines(BaseDataModule): self.num_test = num_test self.emnist = EMNIST() + # TODO: fix mapping self.mapping = self.emnist.mapping max_width = ( int(self.emnist.dims[2] * (self.max_length + 1) * (1 - self.min_overlap)) @@ -66,32 +66,28 @@ class EMNISTLines(BaseDataModule): if max_width >= IMAGE_WIDTH: raise ValueError( - f"max_width {max_width} greater than IMAGE_WIDTH {IMAGE_WIDTH}" - ) + f"max_width {max_width} greater than IMAGE_WIDTH {IMAGE_WIDTH}" + ) - self.dims = ( - self.emnist.dims[0], - IMAGE_HEIGHT, - IMAGE_WIDTH - ) + self.dims = (self.emnist.dims[0], IMAGE_HEIGHT, IMAGE_WIDTH) if self.max_length >= MAX_OUTPUT_LENGTH: raise ValueError("max_length greater than MAX_OUTPUT_LENGTH") self.output_dims = (MAX_OUTPUT_LENGTH, 1) - self.data_train = None - self.data_val = None - self.data_test = None + self.data_train: BaseDataset = None + self.data_val: BaseDataset = None + self.data_test: BaseDataset = None @property def data_filename(self) -> Path: """Return name of dataset.""" - return ( - DATA_DIRNAME / (f"ml_{self.max_length}_" + return DATA_DIRNAME / ( + f"ml_{self.max_length}_" f"o{self.min_overlap:f}_{self.max_overlap:f}_" f"ntr{self.num_train}_" f"ntv{self.num_val}_" - f"nte{self.num_test}.h5") + f"nte{self.num_test}.h5" ) def prepare_data(self) -> None: @@ -144,7 +140,10 @@ class EMNISTLines(BaseDataModule): x, y = next(iter(self.train_dataloader())) data = ( - f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n" + "Train/val/test sizes: " + f"{len(self.data_train)}, " + f"{len(self.data_val)}, " + f"{len(self.data_test)}\n" f"Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n" f"Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n" ) @@ -223,7 +222,6 @@ def _construct_image_from_string( ) -> torch.Tensor: overlap = np.random.uniform(min_overlap, max_overlap) sampled_images = _select_letter_samples_for_string(string, samples_by_char) - N = len(sampled_images) H, W = sampled_images[0].shape next_overlap_width = W - int(overlap * W) concatenated_image = torch.zeros((H, width), dtype=torch.uint8) diff --git a/text_recognizer/data/iam.py b/text_recognizer/data/iam.py index fcfe9a7..01272ba 100644 --- a/text_recognizer/data/iam.py +++ b/text_recognizer/data/iam.py @@ -60,23 +60,36 @@ class IAM(BaseDataModule): @property def split_by_id(self) -> Dict[str, str]: - return {filename.stem: "test" if filename.stem in self.metadata["test_ids"] else "trainval" for filename in self.form_filenames} + return { + filename.stem: "test" + if filename.stem in self.metadata["test_ids"] + else "train" + for filename in self.form_filenames + } @cachedproperty def line_strings_by_id(self) -> Dict[str, List[str]]: """Return a dict from name of IAM form to list of line texts in it.""" - return {filename.stem: _get_line_strings_from_xml_file(filename) for filename in self.xml_filenames} + return { + filename.stem: _get_line_strings_from_xml_file(filename) + for filename in self.xml_filenames + } @cachedproperty def line_regions_by_id(self) -> Dict[str, List[Dict[str, int]]]: """Return a dict from name IAM form to list of (x1, x2, y1, y2) coordinates of all lines in it.""" - return {filename.stem: _get_line_regions_from_xml_file(filename) for filename in self.xml_filenames} + return { + filename.stem: _get_line_regions_from_xml_file(filename) + for filename in self.xml_filenames + } def __repr__(self) -> str: """Return info about the dataset.""" - return ("IAM Dataset\n" - f"Num forms total: {len(self.xml_filenames)}\n" - f"Num in test set: {len(self.metadata['test_ids'])}\n") + return ( + "IAM Dataset\n" + f"Num forms total: {len(self.xml_filenames)}\n" + f"Num in test set: {len(self.metadata['test_ids'])}\n" + ) def _extract_raw_dataset(filename: Path, dirname: Path) -> None: @@ -92,7 +105,7 @@ def _get_line_strings_from_xml_file(filename: str) -> List[str]: """Get the text content of each line. Note that we replace ": with ".""" xml_root_element = ElementTree.parse(filename).getroot() # nosec xml_line_elements = xml_root_element.findall("handwritten-part/line") - return [el.attrib["text"].replace(""", '"') for el in xml_line_elements] + return [el.attrib["text"].replace(""", '"') for el in xml_line_elements] def _get_line_regions_from_xml_file(filename: str) -> List[Dict[str, int]]: @@ -107,13 +120,13 @@ def _get_line_region_from_xml_file(xml_line: Any) -> Dict[str, int]: x1s = [int(el.attrib["x"]) for el in word_elements] y1s = [int(el.attrib["y"]) for el in word_elements] x2s = [int(el.attrib["x"]) + int(el.attrib["width"]) for el in word_elements] - y2s = [int(el.attrib["x"]) + int(el.attrib["height"]) for el in word_elements] + y2s = [int(el.attrib["y"]) + int(el.attrib["height"]) for el in word_elements] return { - "x1": min(x1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING, - "y1": min(y1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING, - "x2": min(x2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING, - "y2": min(y2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING, - } + "x1": min(x1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING, + "y1": min(y1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING, + "x2": max(x2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING, + "y2": max(y2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING, + } def download_iam() -> None: diff --git a/text_recognizer/data/iam_dataset.py b/text_recognizer/data/iam_dataset.py deleted file mode 100644 index a8998b9..0000000 --- a/text_recognizer/data/iam_dataset.py +++ /dev/null @@ -1,133 +0,0 @@ -"""Class for loading the IAM dataset, which encompasses both paragraphs and lines, with associated utilities.""" -import os -from typing import Any, Dict, List -import zipfile - -from boltons.cacheutils import cachedproperty -import defusedxml.ElementTree as ET -from loguru import logger -import toml - -from text_recognizer.datasets.util import _download_raw_dataset, DATA_DIRNAME - -RAW_DATA_DIRNAME = DATA_DIRNAME / "raw" / "iam" -METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml" -EXTRACTED_DATASET_DIRNAME = RAW_DATA_DIRNAME / "iamdb" -RAW_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) - -DOWNSAMPLE_FACTOR = 2 # If images were downsampled, the regions must also be. -LINE_REGION_PADDING = 0 # Add this many pixels around the exact coordinates. - - -class IamDataset: - """IAM dataset. - - "The IAM Lines dataset, first published at the ICDAR 1999, contains forms of unconstrained handwritten text, - which were scanned at a resolution of 300dpi and saved as PNG images with 256 gray levels." - From http://www.fki.inf.unibe.ch/databases/iam-handwriting-database - - The data split we will use is - IAM lines Large Writer Independent Text Line Recognition Task (lwitlrt): 9,862 text lines. - The validation set has been merged into the train set. - The train set has 7,101 lines from 326 writers. - The test set has 1,861 lines from 128 writers. - The text lines of all data sets are mutually exclusive, thus each writer has contributed to one set only. - - """ - - def __init__(self) -> None: - self.metadata = toml.load(METADATA_FILENAME) - - def load_or_generate_data(self) -> None: - """Downloads IAM dataset if xml files does not exist.""" - if not self.xml_filenames: - self._download_iam() - - @property - def xml_filenames(self) -> List: - """List of xml filenames.""" - return list((EXTRACTED_DATASET_DIRNAME / "xml").glob("*.xml")) - - @property - def form_filenames(self) -> List: - """List of forms filenames.""" - return list((EXTRACTED_DATASET_DIRNAME / "forms").glob("*.jpg")) - - def _download_iam(self) -> None: - curdir = os.getcwd() - os.chdir(RAW_DATA_DIRNAME) - _download_raw_dataset(self.metadata) - _extract_raw_dataset(self.metadata) - os.chdir(curdir) - - @property - def form_filenames_by_id(self) -> Dict: - """Creates a dictionary with filenames as keys and forms as values.""" - return {filename.stem: filename for filename in self.form_filenames} - - @cachedproperty - def line_strings_by_id(self) -> Dict: - """Return a dict from name of IAM form to a list of line texts in it.""" - return { - filename.stem: _get_line_strings_from_xml_file(filename) - for filename in self.xml_filenames - } - - @cachedproperty - def line_regions_by_id(self) -> Dict: - """Return a dict from name of IAM form to a list of (x1, x2, y1, y2) coordinates of all lines in it.""" - return { - filename.stem: _get_line_regions_from_xml_file(filename) - for filename in self.xml_filenames - } - - def __repr__(self) -> str: - """Print info about dataset.""" - return "IAM Dataset\n" f"Number of forms: {len(self.xml_filenames)}\n" - - -def _extract_raw_dataset(metadata: Dict) -> None: - logger.info("Extracting IAM data.") - with zipfile.ZipFile(metadata["filename"], "r") as zip_file: - zip_file.extractall() - - -def _get_line_strings_from_xml_file(filename: str) -> List[str]: - """Get the text content of each line. Note that we replace " with ".""" - xml_root_element = ET.parse(filename).getroot() # nosec - xml_line_elements = xml_root_element.findall("handwritten-part/line") - return [el.attrib["text"].replace(""", '"') for el in xml_line_elements] - - -def _get_line_regions_from_xml_file(filename: str) -> List[Dict[str, int]]: - """Get the line region dict for each line.""" - xml_root_element = ET.parse(filename).getroot() # nosec - xml_line_elements = xml_root_element.findall("handwritten-part/line") - return [_get_line_region_from_xml_element(el) for el in xml_line_elements] - - -def _get_line_region_from_xml_element(xml_line: Any) -> Dict[str, int]: - """Extracts coordinates for each line of text.""" - # TODO: fix input! - word_elements = xml_line.findall("word/cmp") - x1s = [int(el.attrib["x"]) for el in word_elements] - y1s = [int(el.attrib["y"]) for el in word_elements] - x2s = [int(el.attrib["x"]) + int(el.attrib["width"]) for el in word_elements] - y2s = [int(el.attrib["y"]) + int(el.attrib["height"]) for el in word_elements] - return { - "x1": min(x1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING, - "y1": min(y1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING, - "x2": max(x2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING, - "y2": max(y2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING, - } - - -def main() -> None: - """Initializes the dataset and print info about the dataset.""" - dataset = IamDataset() - dataset.load_or_generate_data() - print(dataset) - - -if __name__ == "__main__": - main() diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py new file mode 100644 index 0000000..391075a --- /dev/null +++ b/text_recognizer/data/iam_lines.py @@ -0,0 +1,255 @@ +"""Class for IAM Lines dataset. + +If not created, will generate a handwritten lines dataset from the IAM paragraphs +dataset. + +""" +import json +from pathlib import Path +import random +from typing import List, Sequence, Tuple + +from loguru import logger +from PIL import Image, ImageFile, ImageOps +import numpy as np +from torch import Tensor +from torchvision import transforms +from torchvision.transforms.functional import InterpolationMode + +from text_recognizer.data.base_dataset import ( + BaseDataset, + convert_strings_to_labels, + split_dataset, +) +from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info +from text_recognizer.data.emnist import emnist_mapping +from text_recognizer.data.iam import IAM +from text_recognizer.data import image_utils + + +ImageFile.LOAD_TRUNCATED_IMAGES = True + +SEED = 4711 +PROCESSED_DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "iam_lines" +IMAGE_HEIGHT = 56 +IMAGE_WIDTH = 1024 + + +class IAMLines(BaseDataModule): + """IAM handwritten lines dataset.""" + + def __init__( + self, + augment: bool = True, + fraction: float = 0.8, + batch_size: int = 128, + num_workers: int = 0, + ) -> None: + # TODO: add transforms + super().__init__(batch_size, num_workers) + self.augment = augment + self.fraction = fraction + self.mapping, self.inverse_mapping, _ = emnist_mapping() + self.dims = (1, IMAGE_HEIGHT, IMAGE_WIDTH) + self.output_dims = (89, 1) + self.data_train: BaseDataset = None + self.data_val: BaseDataset = None + self.data_test: BaseDataset = None + + def prepare_data(self) -> None: + """Creates the IAM lines dataset if not existing.""" + if PROCESSED_DATA_DIRNAME.exists(): + return + + logger.info("Cropping IAM lines regions...") + iam = IAM() + iam.prepare_data() + crops_train, labels_train = line_crops_and_labels(iam, "train") + crops_test, labels_test = line_crops_and_labels(iam, "test") + + shapes = np.array([crop.size for crop in crops_train + crops_test]) + aspect_ratios = shapes[:, 0] / shapes[:, 1] + + logger.info("Saving images, labels, and statistics...") + save_images_and_labels( + crops_train, labels_train, "train", PROCESSED_DATA_DIRNAME + ) + save_images_and_labels(crops_test, labels_test, "test", PROCESSED_DATA_DIRNAME) + + with (PROCESSED_DATA_DIRNAME / "_max_aspect_ratio.txt").open(mode="w") as f: + f.write(str(aspect_ratios.max())) + + def setup(self, stage: str = None) -> None: + """Load data for training/testing.""" + with (PROCESSED_DATA_DIRNAME / "_max_aspect_ratio.txt").open(mode="r") as f: + max_aspect_ratio = float(f.read()) + image_width = int(IMAGE_HEIGHT * max_aspect_ratio) + if image_width >= IMAGE_WIDTH: + raise ValueError("image_width equal or greater than IMAGE_WIDTH") + + if stage == "fit" or stage is None: + x_train, labels_train = load_line_crops_and_labels( + "train", PROCESSED_DATA_DIRNAME + ) + if self.output_dims[0] < max([len(l) for l in labels_train]) + 2: + raise ValueError("Target length longer than max output length.") + + y_train = convert_strings_to_labels( + labels_train, self.inverse_mapping, length=self.output_dims[0] + ) + data_train = BaseDataset( + x_train, y_train, transform=get_transform(IMAGE_WIDTH, self.augment) + ) + + self.data_train, self.data_val = split_dataset( + dataset=data_train, fraction=self.fraction, seed=SEED + ) + + if stage == "test" or stage is None: + x_test, labels_test = load_line_crops_and_labels( + "test", PROCESSED_DATA_DIRNAME + ) + + if self.output_dims[0] < max([len(l) for l in labels_test]) + 2: + raise ValueError("Taget length longer than max output length.") + + y_test = convert_strings_to_labels( + labels_test, self.inverse_mapping, length=self.output_dims[0] + ) + self.data_test = BaseDataset( + x_test, y_test, transform=get_transform(IMAGE_WIDTH) + ) + + if stage is None: + self._verify_output_dims(labels_train, labels_test) + + def _verify_output_dims(self, labels_train: Tensor, labels_test: Tensor) -> None: + max_label_length = max([len(label) for label in labels_train + labels_test]) + 2 + output_dims = (max_label_length, 1) + if output_dims != self.output_dims: + raise ValueError("Output dim does not match expected output dims.") + + def __repr__(self) -> str: + """Return information about the dataset.""" + basic = ( + "IAM Lines dataset\n" + f"Num classes: {len(self.mapping)}\n" + f"Input dims: {self.dims}\n" + f"Output dims: {self.output_dims}\n" + ) + + if not any([self.data_train, self.data_val, self.data_test]): + return basic + + x, y = next(iter(self.train_dataloader())) + xt, yt = next(iter(self.test_dataloader())) + data = ( + "Train/val/test sizes: " + f"{len(self.data_train)}, " + f"{len(self.data_val)}, " + f"{len(self.data_test)}\n" + f"Train Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n" + f"Train Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n" + f"Test Batch x stats: {(xt.shape, xt.dtype, xt.min(), xt.mean(), xt.std(), xt.max())}\n" + f"Test Batch y stats: {(yt.shape, yt.dtype, yt.min(), yt.max())}\n" + ) + return basic + data + + +def line_crops_and_labels(iam: IAM, split: str) -> Tuple[List, List]: + """Load IAM line labels and regions, and load image crops.""" + crops = [] + labels = [] + for filename in iam.form_filenames: + if not iam.split_by_id[filename.stem] == split: + continue + image = image_utils.read_image_pil(filename) + image = ImageOps.grayscale(image) + image = ImageOps.invert(image) + labels += iam.line_strings_by_id[filename.stem] + crops += [ + image.crop([region[box] for box in ["x1", "y1", "x2", "y2"]]) + for region in iam.line_regions_by_id[filename.stem] + ] + if len(crops) != len(labels): + raise ValueError("Length of crops does not match length of labels") + return crops, labels + + +def save_images_and_labels( + crops: Sequence[Image.Image], labels: Sequence[str], split: str, data_dirname: Path +) -> None: + (data_dirname / split).mkdir(parents=True, exist_ok=True) + + with (data_dirname / split / "_labels.json").open(mode="w") as f: + json.dump(labels, f) + + for index, crop in enumerate(crops): + crop.save(data_dirname / split / f"{index}.png") + + +def load_line_crops_and_labels(split: str, data_dirname: Path) -> Tuple[List, List]: + """Load line crops and labels for given split from processed directoru.""" + with (data_dirname / split / "_labels.json").open(mode="r") as f: + labels = json.load(f) + + crop_filenames = sorted( + (data_dirname / split).glob("*.png"), + key=lambda filename: int(Path(filename).stem), + ) + crops = [ + image_utils.read_image_pil(filename, grayscale=True) + for filename in crop_filenames + ] + + if len(crops) != len(labels): + raise ValueError("Length of crops does not match length of labels") + + return crops, labels + + +def get_transform(image_width: int, augment: bool = False) -> transforms.Compose: + """Augment with brigthness, sligth rotation, slant, translation, scale, and Gaussian noise.""" + + def embed_crop(crop: Image, augment: bool = augment, image_width: int = image_width) -> Image: + # Crop is PIL.Image of dtype="L" (so value range is [0, 255]) + image = Image.new("L", (image_width, IMAGE_HEIGHT)) + + # Resize crop. + crop_width, crop_height = crop.size + new_crop_height = IMAGE_HEIGHT + new_crop_width = int(new_crop_height * (crop_width / crop_height)) + + if augment: + # Add random stretching + new_crop_width = int(new_crop_width * random.uniform(0.9, 1.1)) + new_crop_width = min(new_crop_width, image_width) + crop_resized = crop.resize( + (new_crop_width, new_crop_height), resample=Image.BILINEAR + ) + + # Embed in image + x = min(28, image_width - new_crop_width) + y = IMAGE_HEIGHT - new_crop_height + image.paste(crop_resized, (x, y)) + + return image + + transfroms_list = [transforms.Lambda(embed_crop)] + + if augment: + transfroms_list += [ + transforms.ColorJitter(brightness=(0.8, 1.6)), + transforms.RandomAffine( + degrees=1, + shear=(-30, 20), + interpolation=InterpolationMode.BILINEAR, + fill=0, + ), + ] + transfroms_list.append(transforms.ToTensor()) + return transforms.Compose(transfroms_list) + + +def generate_iam_lines() -> None: + load_and_print_info(IAMLines) diff --git a/text_recognizer/data/iam_lines_dataset.py b/text_recognizer/data/iam_lines_dataset.py deleted file mode 100644 index 1cb84bd..0000000 --- a/text_recognizer/data/iam_lines_dataset.py +++ /dev/null @@ -1,110 +0,0 @@ -"""IamLinesDataset class.""" -from typing import Callable, Dict, List, Optional, Tuple, Union - -import h5py -from loguru import logger -import torch -from torch import Tensor -from torchvision.transforms import ToTensor - -from text_recognizer.datasets.dataset import Dataset -from text_recognizer.datasets.util import ( - compute_sha256, - DATA_DIRNAME, - download_url, - EmnistMapper, -) - - -PROCESSED_DATA_DIRNAME = DATA_DIRNAME / "processed" / "iam_lines" -PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "iam_lines.h5" -PROCESSED_DATA_URL = ( - "https://s3-us-west-2.amazonaws.com/fsdl-public-assets/iam_lines.h5" -) - - -class IamLinesDataset(Dataset): - """IAM lines datasets for handwritten text lines.""" - - def __init__( - self, - train: bool = False, - subsample_fraction: float = None, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - init_token: Optional[str] = None, - pad_token: Optional[str] = None, - eos_token: Optional[str] = None, - lower: bool = False, - ) -> None: - self.pad_token = "_" if pad_token is None else pad_token - - super().__init__( - train=train, - subsample_fraction=subsample_fraction, - transform=transform, - target_transform=target_transform, - init_token=init_token, - pad_token=pad_token, - eos_token=eos_token, - lower=lower, - ) - - @property - def input_shape(self) -> Tuple: - """Input shape of the data.""" - return self.data.shape[1:] if self.data is not None else None - - @property - def output_shape(self) -> Tuple: - """Output shape of the data.""" - return ( - self.targets.shape[1:] + (self.num_classes,) - if self.targets is not None - else None - ) - - def load_or_generate_data(self) -> None: - """Load or generate dataset data.""" - if not PROCESSED_DATA_FILENAME.exists(): - PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) - logger.info("Downloading IAM lines...") - download_url(PROCESSED_DATA_URL, PROCESSED_DATA_FILENAME) - with h5py.File(PROCESSED_DATA_FILENAME, "r") as f: - self._data = f[f"x_{self.split}"][:] - self._targets = f[f"y_{self.split}"][:] - self._subsample() - - def __repr__(self) -> str: - """Print info about the dataset.""" - return ( - "IAM Lines Dataset\n" # pylint: disable=no-member - f"Number classes: {self.num_classes}\n" - f"Mapping: {self.mapper.mapping}\n" - f"Data: {self.data.shape}\n" - f"Targets: {self.targets.shape}\n" - ) - - def __getitem__(self, index: Union[Tensor, int]) -> Tuple[Tensor, Tensor]: - """Fetches data, target pair of the dataset for a given and index or indices. - - Args: - index (Union[int, Tensor]): Either a list or int of indices/index. - - Returns: - Tuple[Tensor, Tensor]: Data target pair. - - """ - if torch.is_tensor(index): - index = index.tolist() - - data = self.data[index] - targets = self.targets[index] - - if self.transform: - data = self.transform(data) - - if self.target_transform: - targets = self.target_transform(targets) - - return data, targets diff --git a/text_recognizer/data/iam_preprocessor.py b/text_recognizer/data/iam_preprocessor.py index a93eb00..5d0fad6 100644 --- a/text_recognizer/data/iam_preprocessor.py +++ b/text_recognizer/data/iam_preprocessor.py @@ -5,7 +5,6 @@ The code is mostly stolen from: https://github.com/facebookresearch/gtn_applications/blob/master/datasets/iamdb.py """ - import collections import itertools from pathlib import Path diff --git a/text_recognizer/data/image_utils.py b/text_recognizer/data/image_utils.py new file mode 100644 index 0000000..c2b8915 --- /dev/null +++ b/text_recognizer/data/image_utils.py @@ -0,0 +1,49 @@ +"""Image util functions for loading and saving images.""" +from pathlib import Path +from typing import Union +from urllib.request import urlopen + +import cv2 +import numpy as np +from PIL import Image + + +def read_image_pil(image_uri: Union[Path, str], grayscale: bool = False) -> Image: + """Return PIL image.""" + image = Image.open(image_uri) + if grayscale: + image = image.convert("L") + return image + + +def read_image(image_uri: Union[Path, str], grayscale: bool = False) -> np.array: + """Read image_uri.""" + + if isinstance(image_uri, str): + image_uri = Path(image_uri) + + def read_image_from_filename(image_filename: Path, imread_flag: int) -> np.array: + return cv2.imread(str(image_filename), imread_flag) + + def read_image_from_url(image_url: Path, imread_flag: int) -> np.array: + url_response = urlopen(str(image_url)) # nosec + image_array = np.array(bytearray(url_response.read()), dtype=np.uint8) + return cv2.imdecode(image_array, imread_flag) + + imread_flag = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR + image = None + + if image_uri.exists(): + image = read_image_from_filename(image_uri, imread_flag) + else: + image = read_image_from_url(image_uri, imread_flag) + + if image is None: + raise ValueError(f"Could not load image at {image_uri}") + + return image + + +def write_image(image: np.ndarray, filename: Union[Path, str]) -> None: + """Write image to file.""" + cv2.imwrite(str(filename), image) diff --git a/text_recognizer/data/sentence_generator.py b/text_recognizer/data/sentence_generator.py index 53b781c..f09703b 100644 --- a/text_recognizer/data/sentence_generator.py +++ b/text_recognizer/data/sentence_generator.py @@ -1,5 +1,4 @@ """Downloading the Brown corpus with NLTK for sentence generating.""" - import itertools import re import string @@ -9,9 +8,9 @@ import nltk from nltk.corpus.reader.util import ConcatenatedCorpusView import numpy as np -from text_recognizer.datasets.util import DATA_DIRNAME +from text_recognizer.data.base_data_module import BaseDataModule -NLTK_DATA_DIRNAME = DATA_DIRNAME / "downloaded" / "nltk" +NLTK_DATA_DIRNAME = BaseDataModule.data_dirname() / "downloaded" / "nltk" class SentenceGenerator: @@ -47,7 +46,7 @@ class SentenceGenerator: raise ValueError( "Must provide max_length to this method or when making this object." ) - + for _ in range(10): try: index = np.random.randint(0, len(self.word_start_indices) - 1) diff --git a/text_recognizer/data/transforms.py b/text_recognizer/data/transforms.py index b6a48f5..2291eec 100644 --- a/text_recognizer/data/transforms.py +++ b/text_recognizer/data/transforms.py @@ -1,148 +1,15 @@ """Transforms for PyTorch datasets.""" from abc import abstractmethod from pathlib import Path -import random from typing import Any, Optional, Union from loguru import logger -import numpy as np -from PIL import Image import torch from torch import Tensor -import torch.nn.functional as F -from torchvision import transforms -from torchvision.transforms import ( - ColorJitter, - Compose, - Normalize, - RandomAffine, - RandomHorizontalFlip, - RandomRotation, - ToPILImage, - ToTensor, -) from text_recognizer.datasets.iam_preprocessor import Preprocessor -from text_recognizer.datasets.util import EmnistMapper - - -class RandomResizeCrop: - """Image transform with random resize and crop applied. - - Stolen from - - https://github.com/facebookresearch/gtn_applications/blob/master/datasets/iamdb.py - - """ - - def __init__(self, jitter: int = 10, ratio: float = 0.5) -> None: - self.jitter = jitter - self.ratio = ratio - - def __call__(self, img: np.ndarray) -> np.ndarray: - """Applies random crop and rotation to an image.""" - w, h = img.size - - # pad with white: - img = transforms.functional.pad(img, self.jitter, fill=255) - - # crop at random (x, y): - x = self.jitter + random.randint(-self.jitter, self.jitter) - y = self.jitter + random.randint(-self.jitter, self.jitter) - - # randomize aspect ratio: - size_w = w * random.uniform(1 - self.ratio, 1 + self.ratio) - size = (h, int(size_w)) - img = transforms.functional.resized_crop(img, y, x, h, w, size) - return img - - -class Transpose: - """Transposes the EMNIST image to the correct orientation.""" - - def __call__(self, image: Image) -> np.ndarray: - """Swaps axis.""" - return np.array(image).swapaxes(0, 1) - - -class Resize: - """Resizes a tensor to a specified width.""" - - def __init__(self, width: int = 952) -> None: - # The default is 952 because of the IAM dataset. - self.width = width - - def __call__(self, image: Tensor) -> Tensor: - """Resize tensor in the last dimension.""" - return F.interpolate(image, size=self.width, mode="nearest") - - -class AddTokens: - """Adds start of sequence and end of sequence tokens to target tensor.""" - - def __init__(self, pad_token: str, eos_token: str, init_token: str = None) -> None: - self.init_token = init_token - self.pad_token = pad_token - self.eos_token = eos_token - if self.init_token is not None: - self.emnist_mapper = EmnistMapper( - init_token=self.init_token, - pad_token=self.pad_token, - eos_token=self.eos_token, - ) - else: - self.emnist_mapper = EmnistMapper( - pad_token=self.pad_token, eos_token=self.eos_token, - ) - self.pad_value = self.emnist_mapper(self.pad_token) - self.eos_value = self.emnist_mapper(self.eos_token) - - def __call__(self, target: Tensor) -> Tensor: - """Adds a sos token to the begining and a eos token to the end of a target sequence.""" - dtype, device = target.dtype, target.device - - # Find the where padding starts. - pad_index = torch.nonzero(target == self.pad_value, as_tuple=False)[0].item() - - target[pad_index] = self.eos_value - - if self.init_token is not None: - self.sos_value = self.emnist_mapper(self.init_token) - sos = torch.tensor([self.sos_value], dtype=dtype, device=device) - target = torch.cat([sos, target], dim=0) - - return target - - -class ApplyContrast: - """Sets everything below a threshold to zero, i.e. increase contrast.""" - - def __init__(self, low: float = 0.0, high: float = 0.25) -> None: - self.low = low - self.high = high - - def __call__(self, x: Tensor) -> Tensor: - """Apply mask binary mask to input tensor.""" - mask = x > np.random.RandomState().uniform(low=self.low, high=self.high) - return x * mask - - -class Unsqueeze: - """Add a dimension to the tensor.""" - - def __call__(self, x: Tensor) -> Tensor: - """Adds dim.""" - return x.unsqueeze(0) - - -class Squeeze: - """Removes the first dimension of a tensor.""" - - def __call__(self, x: Tensor) -> Tensor: - """Removes first dim.""" - return x.squeeze(0) - - +from text_recognizer.data.emnist import emnist_mapping + class ToLower: """Converts target to lower case.""" @@ -155,29 +22,14 @@ class ToLower: class ToCharcters: """Converts integers to characters.""" - def __init__( - self, pad_token: str, eos_token: str, init_token: str = None, lower: bool = True - ) -> None: - self.init_token = init_token - self.pad_token = pad_token - self.eos_token = eos_token - if self.init_token is not None: - self.emnist_mapper = EmnistMapper( - init_token=self.init_token, - pad_token=self.pad_token, - eos_token=self.eos_token, - lower=lower, - ) - else: - self.emnist_mapper = EmnistMapper( - pad_token=self.pad_token, eos_token=self.eos_token, lower=lower - ) + def __init__(self) -> None: + self.mapping, _, _ = emnist_mapping() def __call__(self, y: Tensor) -> str: """Converts a Tensor to a str.""" return ( - "".join([self.emnist_mapper(int(i)) for i in y]) - .strip("_") + "".join([self.mapping(int(i)) for i in y]) + .strip("<p>") .replace(" ", "▁") ) |