From 8248f173132dfb7e47ec62b08e9235990c8626e3 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Wed, 24 Mar 2021 22:15:54 +0100 Subject: renamed datasets to data, added iam refactor --- text_recognizer/data/__init__.py | 1 + text_recognizer/data/base_data_module.py | 89 ++++++++ text_recognizer/data/base_dataset.py | 73 +++++++ text_recognizer/data/download_utils.py | 73 +++++++ text_recognizer/data/emnist.py | 210 ++++++++++++++++++ text_recognizer/data/emnist_essentials.json | 1 + text_recognizer/data/emnist_lines.py | 280 ++++++++++++++++++++++++ text_recognizer/data/iam.py | 120 ++++++++++ text_recognizer/data/iam_dataset.py | 133 +++++++++++ text_recognizer/data/iam_lines_dataset.py | 110 ++++++++++ text_recognizer/data/iam_paragraphs_dataset.py | 291 +++++++++++++++++++++++++ text_recognizer/data/iam_preprocessor.py | 196 +++++++++++++++++ text_recognizer/data/sentence_generator.py | 85 ++++++++ text_recognizer/data/transforms.py | 266 ++++++++++++++++++++++ text_recognizer/data/util.py | 209 ++++++++++++++++++ 15 files changed, 2137 insertions(+) create mode 100644 text_recognizer/data/__init__.py create mode 100644 text_recognizer/data/base_data_module.py create mode 100644 text_recognizer/data/base_dataset.py create mode 100644 text_recognizer/data/download_utils.py create mode 100644 text_recognizer/data/emnist.py create mode 100644 text_recognizer/data/emnist_essentials.json create mode 100644 text_recognizer/data/emnist_lines.py create mode 100644 text_recognizer/data/iam.py create mode 100644 text_recognizer/data/iam_dataset.py create mode 100644 text_recognizer/data/iam_lines_dataset.py create mode 100644 text_recognizer/data/iam_paragraphs_dataset.py create mode 100644 text_recognizer/data/iam_preprocessor.py create mode 100644 text_recognizer/data/sentence_generator.py create mode 100644 text_recognizer/data/transforms.py create mode 100644 text_recognizer/data/util.py (limited to 'text_recognizer/data') diff --git a/text_recognizer/data/__init__.py b/text_recognizer/data/__init__.py new file mode 100644 index 0000000..2727b20 --- /dev/null +++ b/text_recognizer/data/__init__.py @@ -0,0 +1 @@ +"""Dataset modules.""" diff --git a/text_recognizer/data/base_data_module.py b/text_recognizer/data/base_data_module.py new file mode 100644 index 0000000..f5e7300 --- /dev/null +++ b/text_recognizer/data/base_data_module.py @@ -0,0 +1,89 @@ +"""Base lightning DataModule class.""" +from pathlib import Path +from typing import Dict + +import pytorch_lightning as pl +from torch.utils.data import DataLoader + + +def load_and_print_info(data_module_class: type) -> None: + """Load EMNISTLines and prints info.""" + dataset = data_module_class() + dataset.prepare_data() + dataset.setup() + print(dataset) + + +class BaseDataModule(pl.LightningDataModule): + """Base PyTorch Lightning DataModule.""" + + def __init__(self, batch_size: int = 128, num_workers: int = 0) -> None: + super().__init__() + self.batch_size = batch_size + self.num_workers = num_workers + + # Placeholders for subclasses. + self.dims = None + self.output_dims = None + self.mapping = None + + @classmethod + def data_dirname(cls) -> Path: + """Return the path to the base data directory.""" + return Path(__file__).resolve().parents[2] / "data" + + def config(self) -> Dict: + """Return important settings of the dataset.""" + return { + "input_dim": self.dims, + "output_dims": self.output_dims, + "mapping": self.mapping, + } + + def prepare_data(self) -> None: + """Prepare data for training.""" + pass + + def setup(self, stage: str = None) -> None: + """Split into train, val, test, and set dims. + + Should assign `torch Dataset` objects to self.data_train, self.data_val, and + optionally self.data_test. + + Args: + stage (Any): Variable to set splits. + + """ + self.data_train = None + self.data_val = None + self.data_test = None + + def train_dataloader(self) -> DataLoader: + """Retun DataLoader for train data.""" + return DataLoader( + self.data_train, + shuffle=True, + batch_size=self.batch_size, + num_workers=self.num_workers, + pin_memory=True, + ) + + def val_dataloader(self) -> DataLoader: + """Return DataLoader for val data.""" + return DataLoader( + self.data_val, + shuffle=False, + batch_size=self.batch_size, + num_workers=self.num_workers, + pin_memory=True, + ) + + def test_dataloader(self) -> DataLoader: + """Return DataLoader for val data.""" + return DataLoader( + self.data_test, + shuffle=False, + batch_size=self.batch_size, + num_workers=self.num_workers, + pin_memory=True, + ) diff --git a/text_recognizer/data/base_dataset.py b/text_recognizer/data/base_dataset.py new file mode 100644 index 0000000..a9e9c24 --- /dev/null +++ b/text_recognizer/data/base_dataset.py @@ -0,0 +1,73 @@ +"""Base PyTorch Dataset class.""" +from typing import Any, Callable, Dict, Sequence, Tuple, Union + +import torch +from torch import Tensor +from torch.utils.data import Dataset + + +class BaseDataset(Dataset): + """ + Base Dataset class that processes data and targets through optional transfroms. + + Args: + data (Union[Sequence, Tensor]): Torch tensors, numpy arrays, or PIL images. + targets (Union[Sequence, Tensor]): Torch tensors or numpy arrays. + tranform (Callable): Function that takes a datum and applies transforms. + target_transform (Callable): Fucntion that takes a target and applies + target transforms. + """ + + def __init__( + self, + data: Union[Sequence, Tensor], + targets: Union[Sequence, Tensor], + transform: Callable = None, + target_transform: Callable = None, + ) -> None: + if len(data) != len(targets): + raise ValueError("Data and targets must be of equal length.") + self.data = data + self.targets = targets + self.transform = transform + self.target_transform = target_transform + + def __len__(self) -> int: + """Return the length of the dataset.""" + return len(self.data) + + def __getitem__(self, index: int) -> Tuple[Any, Any]: + """Return a datum and its target, after processing by transforms. + + Args: + index (int): Index of a datum in the dataset. + + Returns: + Tuple[Any, Any]: Datum and target pair. + + """ + datum, target = self.data[index], self.targets[index] + + if self.transform is not None: + datum = self.transform(datum) + + if self.target_transform is not None: + target = self.target_transform(target) + + return datum, target + + +def convert_strings_to_labels( + strings: Sequence[str], mapping: Dict[str, int], length: int +) -> Tensor: + """ + Convert a sequence of N strings to (N, length) ndarray, with each string wrapped with and tokens, + and padded wiht the

token. + """ + labels = torch.ones((len(strings), length), dtype=torch.long) * mapping["

"] + for i, string in enumerate(strings): + tokens = list(string) + tokens = ["", *tokens, ""] + for j, token in enumerate(tokens): + labels[i, j] = mapping[token] + return labels diff --git a/text_recognizer/data/download_utils.py b/text_recognizer/data/download_utils.py new file mode 100644 index 0000000..e3dc68c --- /dev/null +++ b/text_recognizer/data/download_utils.py @@ -0,0 +1,73 @@ +"""Util functions for downloading datasets.""" +import hashlib +from pathlib import Path +from typing import Dict, List, Optional +from urllib.request import urlretrieve + +from loguru import logger +from tqdm import tqdm + + +def _compute_sha256(filename: Path) -> str: + """Returns the SHA256 checksum of a file.""" + with filename.open(mode="rb") as f: + return hashlib.sha256(f.read()).hexdigest() + + +class TqdmUpTo(tqdm): + """TQDM progress bar when downloading files. + + From https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py + + """ + + def update_to( + self, blocks: int = 1, block_size: int = 1, total_size: Optional[int] = None + ) -> None: + """Updates the progress bar. + + Args: + blocks (int): Number of blocks transferred so far. Defaults to 1. + block_size (int): Size of each block, in tqdm units. Defaults to 1. + total_size (Optional[int]): Total size in tqdm units. Defaults to None. + """ + if total_size is not None: + self.total = total_size # pylint: disable=attribute-defined-outside-init + self.update(blocks * block_size - self.n) + + +def _download_url(url: str, filename: str) -> None: + """Downloads a file from url to filename, with a progress bar.""" + with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t: + urlretrieve(url, filename, reporthook=t.update_to, data=None) # nosec + + +def download_dataset(metadata: Dict, dl_dir: Path) -> Optional[Path]: + """Downloads dataset using a metadata file. + + Args: + metadata (Dict): A metadata file of the dataset. + dl_dir (Path): Download directory for the dataset. + + Returns: + Optional[Path]: Returns filename if dataset is downloaded, None if it already + exists. + + Raises: + ValueError: If the SHA-256 value is not the same between the dataset and + the metadata file. + + """ + dl_dir.mkdir(parents=True, exist_ok=True) + filename = dl_dir / metadata["filename"] + if filename.exists(): + return + logger.info(f"Downloading raw dataset from {metadata['url']} to {filename}...") + _download_url(metadata["url"], filename) + logger.info("Computing the SHA-256...") + sha256 = _compute_sha256(filename) + if sha256 != metadata["sha256"]: + raise ValueError( + "Downloaded data file SHA-256 does not match that listed in metadata document." + ) + return filename diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py new file mode 100644 index 0000000..7f67893 --- /dev/null +++ b/text_recognizer/data/emnist.py @@ -0,0 +1,210 @@ +"""EMNIST dataset: downloads it from FSDL aws url if not present.""" +from pathlib import Path +from typing import Sequence, Tuple +import json +import os +import shutil +import zipfile + +import h5py +import numpy as np +from loguru import logger +import toml +import torch +from torch.utils.data import random_split +from torchvision import transforms + +from text_recognizer.data.base_dataset import BaseDataset +from text_recognizer.data.base_data_module import ( + BaseDataModule, + load_and_print_info, +) +from text_recognizer.data.download_utils import download_dataset + + +SEED = 4711 +NUM_SPECIAL_TOKENS = 4 +SAMPLE_TO_BALANCE = True + +RAW_DATA_DIRNAME = BaseDataModule.data_dirname() / "raw" / "emnist" +METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml" +DL_DATA_DIRNAME = BaseDataModule.data_dirname() / "downloaded" / "emnist" +PROCESSED_DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "emnist" +PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "byclass.h5" +ESSENTIALS_FILENAME = Path(__file__).parents[0].resolve() / "emnist_essentials.json" + + +class EMNIST(BaseDataModule): + """ + "The EMNIST dataset is a set of handwritten character digits derived from the NIST Special Database 19 + and converted to a 28x28 pixel image format and dataset structure that directly matches the MNIST dataset." + From https://www.nist.gov/itl/iad/image-group/emnist-dataset + + The data split we will use is + EMNIST ByClass: 814,255 characters. 62 unbalanced classes. + """ + + def __init__( + self, batch_size: int = 128, num_workers: int = 0, train_fraction: float = 0.8 + ) -> None: + super().__init__(batch_size, num_workers) + if not ESSENTIALS_FILENAME.exists(): + _download_and_process_emnist() + with ESSENTIALS_FILENAME.open() as f: + essentials = json.load(f) + self.train_fraction = train_fraction + self.mapping = list(essentials["characters"]) + self.inverse_mapping = {v: k for k, v in enumerate(self.mapping)} + self.data_train = None + self.data_val = None + self.data_test = None + self.transform = transforms.Compose([transforms.ToTensor()]) + self.dims = (1, *essentials["input_shape"]) + self.output_dims = (1,) + + def prepare_data(self) -> None: + if not PROCESSED_DATA_FILENAME.exists(): + _download_and_process_emnist() + + def setup(self, stage: str = None) -> None: + if stage == "fit" or stage is None: + with h5py.File(PROCESSED_DATA_FILENAME, "r") as f: + self.x_train = f["x_train"][:] + self.y_train = f["y_train"][:].squeeze().astype(int) + + dataset_train = BaseDataset( + self.x_train, self.y_train, transform=self.transform + ) + train_size = int(self.train_fraction * len(dataset_train)) + val_size = len(dataset_train) - train_size + self.data_train, self.data_val = random_split( + dataset_train, [train_size, val_size], generator=torch.Generator() + ) + + if stage == "test" or stage is None: + with h5py.File(PROCESSED_DATA_FILENAME, "r") as f: + self.x_test = f["x_test"][:] + self.y_test = f["y_test"][:].squeeze().astype(int) + self.data_test = BaseDataset( + self.x_test, self.y_test, transform=self.transform + ) + + def __repr__(self) -> str: + basic = f"EMNIST Dataset\nNum classes: {len(self.mapping)}\nMapping: {self.mapping}\nDims: {self.dims}\n" + if not any([self.data_train, self.data_val, self.data_test]): + return basic + + datum, target = next(iter(self.train_dataloader())) + data = ( + f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n" + f"Batch x stats: {(datum.shape, datum.dtype, datum.min(), datum.mean(), datum.std(), datum.max())}\n" + f"Batch y stats: {(target.shape, target.dtype, target.min(), target.max())}\n" + ) + + return basic + data + + +def _download_and_process_emnist() -> None: + metadata = toml.load(METADATA_FILENAME) + download_dataset(metadata, DL_DATA_DIRNAME) + _process_raw_dataset(metadata["filename"], DL_DATA_DIRNAME) + + +def _process_raw_dataset(filename: str, dirname: Path) -> None: + logger.info("Unzipping EMNIST...") + curdir = os.getcwd() + os.chdir(dirname) + content = zipfile.ZipFile(filename, "r") + content.extract("matlab/emnist-byclass.mat") + + from scipy.io import loadmat + + logger.info("Loading training data from .mat file") + data = loadmat("matlab/emnist-byclass.mat") + x_train = ( + data["dataset"]["train"][0, 0]["images"][0, 0] + .reshape(-1, 28, 28) + .swapaxes(1, 2) + ) + y_train = data["dataset"]["train"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS + x_test = ( + data["dataset"]["test"][0, 0]["images"][0, 0].reshape(-1, 28, 28).swapaxes(1, 2) + ) + y_test = data["dataset"]["test"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS + + if SAMPLE_TO_BALANCE: + logger.info("Balancing classes to reduce amount of data") + x_train, y_train = _sample_to_balance(x_train, y_train) + x_test, y_test = _sample_to_balance(x_test, y_test) + + logger.info("Saving to HDF5 in a compressed format...") + PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) + with h5py.File(PROCESSED_DATA_FILENAME, "w") as f: + f.create_dataset("x_train", data=x_train, dtype="u1", compression="lzf") + f.create_dataset("y_train", data=y_train, dtype="u1", compression="lzf") + f.create_dataset("x_test", data=x_test, dtype="u1", compression="lzf") + f.create_dataset("y_test", data=y_test, dtype="u1", compression="lzf") + + logger.info("Saving essential dataset parameters to text_recognizer/datasets...") + mapping = {int(k): chr(v) for k, v in data["dataset"]["mapping"][0, 0]} + characters = _augment_emnist_characters(mapping.values()) + essentials = {"characters": characters, "input_shape": list(x_train.shape[1:])} + + with ESSENTIALS_FILENAME.open(mode="w") as f: + json.dump(essentials, f) + + logger.info("Cleaning up...") + shutil.rmtree("matlab") + os.chdir(curdir) + + +def _sample_to_balance(x: np.ndarray, y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """Balances the dataset by taking the mean number of instances per class.""" + np.random.seed(SEED) + num_to_sample = int(np.bincount(y.flatten()).mean()) + all_sampled_indices = [] + for label in np.unique(y.flatten()): + indices = np.where(y == label)[0] + sampled_indices = np.unique(np.random.choice(indices, num_to_sample)) + all_sampled_indices.append(sampled_indices) + indices = np.concatenate(all_sampled_indices) + x_sampled = x[indices] + y_sampled = y[indices] + return x_sampled, y_sampled + + +def _augment_emnist_characters(characters: Sequence[str]) -> Sequence[str]: + """Augment the mapping with extra symbols.""" + # Extra characters from the IAM dataset. + iam_characters = [ + " ", + "!", + '"', + "#", + "&", + "'", + "(", + ")", + "*", + "+", + ",", + "-", + ".", + "/", + ":", + ";", + "?", + ] + + # Also add special tokens for: + # - CTC blank token at index 0 + # - Start token at index 1 + # - End token at index 2 + # - Padding token at index 3 + # Note: Do not forget to update NUM_SPECIAL_TOKENS if changing this! + return ["", "", "", "

", *characters, *iam_characters] + + +def download_emnist() -> None: + """Download dataset from internet, if it does not exists, and displays info.""" + load_and_print_info(EMNIST) diff --git a/text_recognizer/data/emnist_essentials.json b/text_recognizer/data/emnist_essentials.json new file mode 100644 index 0000000..3f46a73 --- /dev/null +++ b/text_recognizer/data/emnist_essentials.json @@ -0,0 +1 @@ +{"characters": ["", "", "", "

", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", " ", "!", "\"", "#", "&", "'", "(", ")", "*", "+", ",", "-", ".", "/", ":", ";", "?"], "input_shape": [28, 28]} diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py new file mode 100644 index 0000000..6c14add --- /dev/null +++ b/text_recognizer/data/emnist_lines.py @@ -0,0 +1,280 @@ +"""Dataset of generated text from EMNIST characters.""" +from collections import defaultdict +from pathlib import Path +from typing import Callable, Dict, Tuple, Sequence + +import h5py +from loguru import logger +import numpy as np +from PIL import Image +import torch +from torchvision import transforms +from torchvision.transforms.functional import InterpolationMode + +from text_recognizer.data.base_dataset import BaseDataset, convert_strings_to_labels +from text_recognizer.data.base_data_module import ( + BaseDataModule, + load_and_print_info, +) +from text_recognizer.data.emnist import EMNIST +from text_recognizer.data.sentence_generator import SentenceGenerator + + +DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "emnist_lines" +ESSENTIALS_FILENAME = ( + Path(__file__).parents[0].resolve() / "emnist_lines_essentials.json" +) + +SEED = 4711 +IMAGE_HEIGHT = 56 +IMAGE_WIDTH = 1024 +IMAGE_X_PADDING = 28 +MAX_OUTPUT_LENGTH = 89 # Same as IAMLines + + +class EMNISTLines(BaseDataModule): + """EMNIST Lines dataset: synthetic handwritten lines dataset made from EMNIST,""" + + def __init__( + self, + augment: bool = True, + batch_size: int = 128, + num_workers: int = 0, + max_length: int = 32, + min_overlap: float = 0.0, + max_overlap: float = 0.33, + num_train: int = 10_000, + num_val: int = 2_000, + num_test: int = 2_000, + ) -> None: + super().__init__(batch_size, num_workers) + + self.augment = augment + self.max_length = max_length + self.min_overlap = min_overlap + self.max_overlap = max_overlap + self.num_train = num_train + self.num_val = num_val + self.num_test = num_test + + self.emnist = EMNIST() + self.mapping = self.emnist.mapping + max_width = ( + int(self.emnist.dims[2] * (self.max_length + 1) * (1 - self.min_overlap)) + + IMAGE_X_PADDING + ) + + if max_width >= IMAGE_WIDTH: + raise ValueError( + f"max_width {max_width} greater than IMAGE_WIDTH {IMAGE_WIDTH}" + ) + + self.dims = ( + self.emnist.dims[0], + IMAGE_HEIGHT, + IMAGE_WIDTH + ) + + if self.max_length >= MAX_OUTPUT_LENGTH: + raise ValueError("max_length greater than MAX_OUTPUT_LENGTH") + + self.output_dims = (MAX_OUTPUT_LENGTH, 1) + self.data_train = None + self.data_val = None + self.data_test = None + + @property + def data_filename(self) -> Path: + """Return name of dataset.""" + return ( + DATA_DIRNAME / (f"ml_{self.max_length}_" + f"o{self.min_overlap:f}_{self.max_overlap:f}_" + f"ntr{self.num_train}_" + f"ntv{self.num_val}_" + f"nte{self.num_test}.h5") + ) + + def prepare_data(self) -> None: + if self.data_filename.exists(): + return + np.random.seed(SEED) + self._generate_data("train") + self._generate_data("val") + self._generate_data("test") + + def setup(self, stage: str = None) -> None: + logger.info("EMNISTLinesDataset loading data from HDF5...") + if stage == "fit" or stage is None: + print(self.data_filename) + with h5py.File(self.data_filename, "r") as f: + x_train = f["x_train"][:] + y_train = torch.LongTensor(f["y_train"][:]) + x_val = f["x_val"][:] + y_val = torch.LongTensor(f["y_val"][:]) + + self.data_train = BaseDataset( + x_train, y_train, transform=_get_transform(augment=self.augment) + ) + self.data_val = BaseDataset( + x_val, y_val, transform=_get_transform(augment=self.augment) + ) + + if stage == "test" or stage is None: + with h5py.File(self.data_filename, "r") as f: + x_test = f["x_test"][:] + y_test = torch.LongTensor(f["y_test"][:]) + + self.data_test = BaseDataset( + x_test, y_test, transform=_get_transform(augment=False) + ) + + def __repr__(self) -> str: + """Return str about dataset.""" + basic = ( + "EMNISTLines2 Dataset\n" # pylint: disable=no-member + f"Min overlap: {self.min_overlap}\n" + f"Max overlap: {self.max_overlap}\n" + f"Num classes: {len(self.mapping)}\n" + f"Dims: {self.dims}\n" + f"Output dims: {self.output_dims}\n" + ) + + if not any([self.data_train, self.data_val, self.data_test]): + return basic + + x, y = next(iter(self.train_dataloader())) + data = ( + f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n" + f"Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n" + f"Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n" + ) + return basic + data + + def _generate_data(self, split: str) -> None: + logger.info(f"EMNISTLines generating data for {split}...") + sentence_generator = SentenceGenerator( + self.max_length - 2 + ) # Subtract by 2 because start/end token + + emnist = self.emnist + emnist.prepare_data() + emnist.setup() + + if split == "train": + samples_by_char = _get_samples_by_char( + emnist.x_train, emnist.y_train, emnist.mapping + ) + num = self.num_train + elif split == "val": + samples_by_char = _get_samples_by_char( + emnist.x_train, emnist.y_train, emnist.mapping + ) + num = self.num_val + else: + samples_by_char = _get_samples_by_char( + emnist.x_test, emnist.y_test, emnist.mapping + ) + num = self.num_test + + DATA_DIRNAME.mkdir(parents=True, exist_ok=True) + with h5py.File(self.data_filename, "a") as f: + x, y = _create_dataset_of_images( + num, + samples_by_char, + sentence_generator, + self.min_overlap, + self.max_overlap, + self.dims, + ) + y = convert_strings_to_labels( + y, emnist.inverse_mapping, length=MAX_OUTPUT_LENGTH + ) + f.create_dataset(f"x_{split}", data=x, dtype="u1", compression="lzf") + f.create_dataset(f"y_{split}", data=y, dtype="u1", compression="lzf") + + +def _get_samples_by_char( + samples: np.ndarray, labels: np.ndarray, mapping: Dict +) -> defaultdict: + samples_by_char = defaultdict(list) + for sample, label in zip(samples, labels): + samples_by_char[mapping[label]].append(sample) + return samples_by_char + + +def _select_letter_samples_for_string(string: str, samples_by_char: defaultdict): + null_image = torch.zeros((28, 28), dtype=torch.uint8) + sample_image_by_char = {} + for char in string: + if char in sample_image_by_char: + continue + samples = samples_by_char[char] + sample = samples[np.random.choice(len(samples))] if samples else null_image + sample_image_by_char[char] = sample.reshape(28, 28) + return [sample_image_by_char[char] for char in string] + + +def _construct_image_from_string( + string: str, + samples_by_char: defaultdict, + min_overlap: float, + max_overlap: float, + width: int, +) -> torch.Tensor: + overlap = np.random.uniform(min_overlap, max_overlap) + sampled_images = _select_letter_samples_for_string(string, samples_by_char) + N = len(sampled_images) + H, W = sampled_images[0].shape + next_overlap_width = W - int(overlap * W) + concatenated_image = torch.zeros((H, width), dtype=torch.uint8) + x = IMAGE_X_PADDING + for image in sampled_images: + concatenated_image[:, x : (x + W)] += image + x += next_overlap_width + return torch.minimum(torch.Tensor([255]), concatenated_image) + + +def _create_dataset_of_images( + num_samples: int, + samples_by_char: defaultdict, + sentence_generator: SentenceGenerator, + min_overlap: float, + max_overlap: float, + dims: Tuple, +) -> Tuple[torch.Tensor, torch.Tensor]: + images = torch.zeros((num_samples, IMAGE_HEIGHT, dims[2])) + labels = [] + for n in range(num_samples): + label = sentence_generator.generate() + crop = _construct_image_from_string( + label, samples_by_char, min_overlap, max_overlap, dims[-1] + ) + height = crop.shape[0] + y = (IMAGE_HEIGHT - height) // 2 + images[n, y : (y + height), :] = crop + labels.append(label) + return images, labels + + +def _get_transform(augment: bool = False) -> Callable: + if not augment: + return transforms.Compose([transforms.ToTensor()]) + return transforms.Compose( + [ + transforms.ToTensor(), + transforms.ColorJitter(brightness=(0.5, 1.0)), + transforms.RandomAffine( + degrees=3, + translate=(0.0, 0.05), + scale=(0.4, 1.1), + shear=(-40, 50), + interpolation=InterpolationMode.BILINEAR, + fill=0, + ), + ] + ) + + +def generate_emnist_lines() -> None: + """Generates a synthetic handwritten dataset and displays info,""" + load_and_print_info(EMNISTLines) diff --git a/text_recognizer/data/iam.py b/text_recognizer/data/iam.py new file mode 100644 index 0000000..fcfe9a7 --- /dev/null +++ b/text_recognizer/data/iam.py @@ -0,0 +1,120 @@ +"""Class for loading the IAM dataset, which encompasses both paragraphs and lines, with associated utilities.""" +import os +from pathlib import Path +from typing import Any, Dict, List +import xml.etree.ElementTree as ElementTree +import zipfile + +from boltons.cacheutils import cachedproperty +from loguru import logger +from PIL import Image +import toml + +from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info +from text_recognizer.data.download_utils import download_dataset + + +RAW_DATA_DIRNAME = BaseDataModule.data_dirname() / "raw" / "iam" +METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml" +DL_DATA_DIRNAME = BaseDataModule.data_dirname() / "downloaded" / "iam" +EXTRACTED_DATASET_DIRNAME = DL_DATA_DIRNAME / "iamdb" + +DOWNSAMPLE_FACTOR = 2 # If images were downsampled, the regions must also be. +LINE_REGION_PADDING = 16 # Add this many pixels around the exact coordinates. + + +class IAM(BaseDataModule): + """ + "The IAM Lines dataset, first published at the ICDAR 1999, contains forms of unconstrained handwritten text, + which were scanned at a resolution of 300dpi and saved as PNG images with 256 gray levels. + From http://www.fki.inf.unibe.ch/databases/iam-handwriting-database + The data split we will use is + IAM lines Large Writer Independent Text Line Recognition Task (lwitlrt): 9,862 text lines. + The validation set has been merged into the train set. + The train set has 7,101 lines from 326 writers. + The test set has 1,861 lines from 128 writers. + The text lines of all data sets are mutually exclusive, thus each writer has contributed to one set only. + """ + + def __init__(self, batch_size: int = 128, num_workers: int = 0) -> None: + super().__init__(batch_size, num_workers) + self.metadata = toml.load(METADATA_FILENAME) + + def prepare_data(self) -> None: + if self.xml_filenames: + return + filename = download_dataset(self.metadata, DL_DATA_DIRNAME) + _extract_raw_dataset(filename, DL_DATA_DIRNAME) + + @property + def xml_filenames(self) -> List[Path]: + return list((EXTRACTED_DATASET_DIRNAME / "xml").glob("*.xml")) + + @property + def form_filenames(self) -> List[Path]: + return list((EXTRACTED_DATASET_DIRNAME / "forms").glob("*.jpg")) + + @property + def form_filenames_by_id(self) -> Dict[str, Path]: + return {filename.stem: filename for filename in self.form_filenames} + + @property + def split_by_id(self) -> Dict[str, str]: + return {filename.stem: "test" if filename.stem in self.metadata["test_ids"] else "trainval" for filename in self.form_filenames} + + @cachedproperty + def line_strings_by_id(self) -> Dict[str, List[str]]: + """Return a dict from name of IAM form to list of line texts in it.""" + return {filename.stem: _get_line_strings_from_xml_file(filename) for filename in self.xml_filenames} + + @cachedproperty + def line_regions_by_id(self) -> Dict[str, List[Dict[str, int]]]: + """Return a dict from name IAM form to list of (x1, x2, y1, y2) coordinates of all lines in it.""" + return {filename.stem: _get_line_regions_from_xml_file(filename) for filename in self.xml_filenames} + + def __repr__(self) -> str: + """Return info about the dataset.""" + return ("IAM Dataset\n" + f"Num forms total: {len(self.xml_filenames)}\n" + f"Num in test set: {len(self.metadata['test_ids'])}\n") + + +def _extract_raw_dataset(filename: Path, dirname: Path) -> None: + logger.info("Extracting IAM data...") + curdir = os.getcwd() + os.chdir(dirname) + with zipfile.ZipFile(filename, "r") as f: + f.extractall() + os.chdir(curdir) + + +def _get_line_strings_from_xml_file(filename: str) -> List[str]: + """Get the text content of each line. Note that we replace ": with ".""" + xml_root_element = ElementTree.parse(filename).getroot() # nosec + xml_line_elements = xml_root_element.findall("handwritten-part/line") + return [el.attrib["text"].replace(""", '"') for el in xml_line_elements] + + +def _get_line_regions_from_xml_file(filename: str) -> List[Dict[str, int]]: + """Get line region dict for each line.""" + xml_root_element = ElementTree.parse(filename).getroot() # nosec + xml_line_elements = xml_root_element.findall("handwritten-part/line") + return [_get_line_region_from_xml_file(el) for el in xml_line_elements] + + +def _get_line_region_from_xml_file(xml_line: Any) -> Dict[str, int]: + word_elements = xml_line.findall("word/cmp") + x1s = [int(el.attrib["x"]) for el in word_elements] + y1s = [int(el.attrib["y"]) for el in word_elements] + x2s = [int(el.attrib["x"]) + int(el.attrib["width"]) for el in word_elements] + y2s = [int(el.attrib["x"]) + int(el.attrib["height"]) for el in word_elements] + return { + "x1": min(x1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING, + "y1": min(y1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING, + "x2": min(x2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING, + "y2": min(y2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING, + } + + +def download_iam() -> None: + load_and_print_info(IAM) diff --git a/text_recognizer/data/iam_dataset.py b/text_recognizer/data/iam_dataset.py new file mode 100644 index 0000000..a8998b9 --- /dev/null +++ b/text_recognizer/data/iam_dataset.py @@ -0,0 +1,133 @@ +"""Class for loading the IAM dataset, which encompasses both paragraphs and lines, with associated utilities.""" +import os +from typing import Any, Dict, List +import zipfile + +from boltons.cacheutils import cachedproperty +import defusedxml.ElementTree as ET +from loguru import logger +import toml + +from text_recognizer.datasets.util import _download_raw_dataset, DATA_DIRNAME + +RAW_DATA_DIRNAME = DATA_DIRNAME / "raw" / "iam" +METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml" +EXTRACTED_DATASET_DIRNAME = RAW_DATA_DIRNAME / "iamdb" +RAW_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) + +DOWNSAMPLE_FACTOR = 2 # If images were downsampled, the regions must also be. +LINE_REGION_PADDING = 0 # Add this many pixels around the exact coordinates. + + +class IamDataset: + """IAM dataset. + + "The IAM Lines dataset, first published at the ICDAR 1999, contains forms of unconstrained handwritten text, + which were scanned at a resolution of 300dpi and saved as PNG images with 256 gray levels." + From http://www.fki.inf.unibe.ch/databases/iam-handwriting-database + + The data split we will use is + IAM lines Large Writer Independent Text Line Recognition Task (lwitlrt): 9,862 text lines. + The validation set has been merged into the train set. + The train set has 7,101 lines from 326 writers. + The test set has 1,861 lines from 128 writers. + The text lines of all data sets are mutually exclusive, thus each writer has contributed to one set only. + + """ + + def __init__(self) -> None: + self.metadata = toml.load(METADATA_FILENAME) + + def load_or_generate_data(self) -> None: + """Downloads IAM dataset if xml files does not exist.""" + if not self.xml_filenames: + self._download_iam() + + @property + def xml_filenames(self) -> List: + """List of xml filenames.""" + return list((EXTRACTED_DATASET_DIRNAME / "xml").glob("*.xml")) + + @property + def form_filenames(self) -> List: + """List of forms filenames.""" + return list((EXTRACTED_DATASET_DIRNAME / "forms").glob("*.jpg")) + + def _download_iam(self) -> None: + curdir = os.getcwd() + os.chdir(RAW_DATA_DIRNAME) + _download_raw_dataset(self.metadata) + _extract_raw_dataset(self.metadata) + os.chdir(curdir) + + @property + def form_filenames_by_id(self) -> Dict: + """Creates a dictionary with filenames as keys and forms as values.""" + return {filename.stem: filename for filename in self.form_filenames} + + @cachedproperty + def line_strings_by_id(self) -> Dict: + """Return a dict from name of IAM form to a list of line texts in it.""" + return { + filename.stem: _get_line_strings_from_xml_file(filename) + for filename in self.xml_filenames + } + + @cachedproperty + def line_regions_by_id(self) -> Dict: + """Return a dict from name of IAM form to a list of (x1, x2, y1, y2) coordinates of all lines in it.""" + return { + filename.stem: _get_line_regions_from_xml_file(filename) + for filename in self.xml_filenames + } + + def __repr__(self) -> str: + """Print info about dataset.""" + return "IAM Dataset\n" f"Number of forms: {len(self.xml_filenames)}\n" + + +def _extract_raw_dataset(metadata: Dict) -> None: + logger.info("Extracting IAM data.") + with zipfile.ZipFile(metadata["filename"], "r") as zip_file: + zip_file.extractall() + + +def _get_line_strings_from_xml_file(filename: str) -> List[str]: + """Get the text content of each line. Note that we replace " with ".""" + xml_root_element = ET.parse(filename).getroot() # nosec + xml_line_elements = xml_root_element.findall("handwritten-part/line") + return [el.attrib["text"].replace(""", '"') for el in xml_line_elements] + + +def _get_line_regions_from_xml_file(filename: str) -> List[Dict[str, int]]: + """Get the line region dict for each line.""" + xml_root_element = ET.parse(filename).getroot() # nosec + xml_line_elements = xml_root_element.findall("handwritten-part/line") + return [_get_line_region_from_xml_element(el) for el in xml_line_elements] + + +def _get_line_region_from_xml_element(xml_line: Any) -> Dict[str, int]: + """Extracts coordinates for each line of text.""" + # TODO: fix input! + word_elements = xml_line.findall("word/cmp") + x1s = [int(el.attrib["x"]) for el in word_elements] + y1s = [int(el.attrib["y"]) for el in word_elements] + x2s = [int(el.attrib["x"]) + int(el.attrib["width"]) for el in word_elements] + y2s = [int(el.attrib["y"]) + int(el.attrib["height"]) for el in word_elements] + return { + "x1": min(x1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING, + "y1": min(y1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING, + "x2": max(x2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING, + "y2": max(y2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING, + } + + +def main() -> None: + """Initializes the dataset and print info about the dataset.""" + dataset = IamDataset() + dataset.load_or_generate_data() + print(dataset) + + +if __name__ == "__main__": + main() diff --git a/text_recognizer/data/iam_lines_dataset.py b/text_recognizer/data/iam_lines_dataset.py new file mode 100644 index 0000000..1cb84bd --- /dev/null +++ b/text_recognizer/data/iam_lines_dataset.py @@ -0,0 +1,110 @@ +"""IamLinesDataset class.""" +from typing import Callable, Dict, List, Optional, Tuple, Union + +import h5py +from loguru import logger +import torch +from torch import Tensor +from torchvision.transforms import ToTensor + +from text_recognizer.datasets.dataset import Dataset +from text_recognizer.datasets.util import ( + compute_sha256, + DATA_DIRNAME, + download_url, + EmnistMapper, +) + + +PROCESSED_DATA_DIRNAME = DATA_DIRNAME / "processed" / "iam_lines" +PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "iam_lines.h5" +PROCESSED_DATA_URL = ( + "https://s3-us-west-2.amazonaws.com/fsdl-public-assets/iam_lines.h5" +) + + +class IamLinesDataset(Dataset): + """IAM lines datasets for handwritten text lines.""" + + def __init__( + self, + train: bool = False, + subsample_fraction: float = None, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + init_token: Optional[str] = None, + pad_token: Optional[str] = None, + eos_token: Optional[str] = None, + lower: bool = False, + ) -> None: + self.pad_token = "_" if pad_token is None else pad_token + + super().__init__( + train=train, + subsample_fraction=subsample_fraction, + transform=transform, + target_transform=target_transform, + init_token=init_token, + pad_token=pad_token, + eos_token=eos_token, + lower=lower, + ) + + @property + def input_shape(self) -> Tuple: + """Input shape of the data.""" + return self.data.shape[1:] if self.data is not None else None + + @property + def output_shape(self) -> Tuple: + """Output shape of the data.""" + return ( + self.targets.shape[1:] + (self.num_classes,) + if self.targets is not None + else None + ) + + def load_or_generate_data(self) -> None: + """Load or generate dataset data.""" + if not PROCESSED_DATA_FILENAME.exists(): + PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) + logger.info("Downloading IAM lines...") + download_url(PROCESSED_DATA_URL, PROCESSED_DATA_FILENAME) + with h5py.File(PROCESSED_DATA_FILENAME, "r") as f: + self._data = f[f"x_{self.split}"][:] + self._targets = f[f"y_{self.split}"][:] + self._subsample() + + def __repr__(self) -> str: + """Print info about the dataset.""" + return ( + "IAM Lines Dataset\n" # pylint: disable=no-member + f"Number classes: {self.num_classes}\n" + f"Mapping: {self.mapper.mapping}\n" + f"Data: {self.data.shape}\n" + f"Targets: {self.targets.shape}\n" + ) + + def __getitem__(self, index: Union[Tensor, int]) -> Tuple[Tensor, Tensor]: + """Fetches data, target pair of the dataset for a given and index or indices. + + Args: + index (Union[int, Tensor]): Either a list or int of indices/index. + + Returns: + Tuple[Tensor, Tensor]: Data target pair. + + """ + if torch.is_tensor(index): + index = index.tolist() + + data = self.data[index] + targets = self.targets[index] + + if self.transform: + data = self.transform(data) + + if self.target_transform: + targets = self.target_transform(targets) + + return data, targets diff --git a/text_recognizer/data/iam_paragraphs_dataset.py b/text_recognizer/data/iam_paragraphs_dataset.py new file mode 100644 index 0000000..8ba5142 --- /dev/null +++ b/text_recognizer/data/iam_paragraphs_dataset.py @@ -0,0 +1,291 @@ +"""IamParagraphsDataset class and functions for data processing.""" +import random +from typing import Callable, Dict, List, Optional, Tuple, Union + +import click +import cv2 +import h5py +from loguru import logger +import numpy as np +import torch +from torch import Tensor +from torchvision.transforms import ToTensor + +from text_recognizer import util +from text_recognizer.datasets.dataset import Dataset +from text_recognizer.datasets.iam_dataset import IamDataset +from text_recognizer.datasets.util import ( + compute_sha256, + DATA_DIRNAME, + download_url, + EmnistMapper, +) + +INTERIM_DATA_DIRNAME = DATA_DIRNAME / "interim" / "iam_paragraphs" +DEBUG_CROPS_DIRNAME = INTERIM_DATA_DIRNAME / "debug_crops" +PROCESSED_DATA_DIRNAME = DATA_DIRNAME / "processed" / "iam_paragraphs" +CROPS_DIRNAME = PROCESSED_DATA_DIRNAME / "crops" +GT_DIRNAME = PROCESSED_DATA_DIRNAME / "gt" + +PARAGRAPH_BUFFER = 50 # Pixels in the IAM form images to leave around the lines. +TEST_FRACTION = 0.2 +SEED = 4711 + + +class IamParagraphsDataset(Dataset): + """IAM Paragraphs dataset for paragraphs of handwritten text.""" + + def __init__( + self, + train: bool = False, + subsample_fraction: float = None, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + ) -> None: + super().__init__( + train=train, + subsample_fraction=subsample_fraction, + transform=transform, + target_transform=target_transform, + ) + # Load Iam dataset. + self.iam_dataset = IamDataset() + + self.num_classes = 3 + self._input_shape = (256, 256) + self._output_shape = self._input_shape + (self.num_classes,) + self._ids = None + + def __getitem__(self, index: Union[Tensor, int]) -> Tuple[Tensor, Tensor]: + """Fetches data, target pair of the dataset for a given and index or indices. + + Args: + index (Union[int, Tensor]): Either a list or int of indices/index. + + Returns: + Tuple[Tensor, Tensor]: Data target pair. + + """ + if torch.is_tensor(index): + index = index.tolist() + + data = self.data[index] + targets = self.targets[index] + + seed = np.random.randint(SEED) + random.seed(seed) # apply this seed to target tranfsorms + torch.manual_seed(seed) # needed for torchvision 0.7 + if self.transform: + data = self.transform(data) + + random.seed(seed) # apply this seed to target tranfsorms + torch.manual_seed(seed) # needed for torchvision 0.7 + if self.target_transform: + targets = self.target_transform(targets) + + return data, targets.long() + + @property + def ids(self) -> Tensor: + """Ids of the dataset.""" + return self._ids + + def get_data_and_target_from_id(self, id_: str) -> Tuple[Tensor, Tensor]: + """Get data target pair from id.""" + ind = self.ids.index(id_) + return self.data[ind], self.targets[ind] + + def load_or_generate_data(self) -> None: + """Load or generate dataset data.""" + num_actual = len(list(CROPS_DIRNAME.glob("*.jpg"))) + num_targets = len(self.iam_dataset.line_regions_by_id) + + if num_actual < num_targets - 2: + self._process_iam_paragraphs() + + self._data, self._targets, self._ids = _load_iam_paragraphs() + self._get_random_split() + self._subsample() + + def _get_random_split(self) -> None: + np.random.seed(SEED) + num_train = int((1 - TEST_FRACTION) * self.data.shape[0]) + indices = np.random.permutation(self.data.shape[0]) + train_indices, test_indices = indices[:num_train], indices[num_train:] + if self.train: + self._data = self.data[train_indices] + self._targets = self.targets[train_indices] + else: + self._data = self.data[test_indices] + self._targets = self.targets[test_indices] + + def _process_iam_paragraphs(self) -> None: + """Crop the part with the text. + + For each page, crop out the part of it that correspond to the paragraph of text, and make sure all crops are + self.input_shape. The ground truth data is the same size, with a one-hot vector at each pixel + corresponding to labels 0=background, 1=odd-numbered line, 2=even-numbered line + """ + crop_dims = self._decide_on_crop_dims() + CROPS_DIRNAME.mkdir(parents=True, exist_ok=True) + DEBUG_CROPS_DIRNAME.mkdir(parents=True, exist_ok=True) + GT_DIRNAME.mkdir(parents=True, exist_ok=True) + logger.info( + f"Cropping paragraphs, generating ground truth, and saving debugging images to {DEBUG_CROPS_DIRNAME}" + ) + for filename in self.iam_dataset.form_filenames: + id_ = filename.stem + line_region = self.iam_dataset.line_regions_by_id[id_] + _crop_paragraph_image(filename, line_region, crop_dims, self.input_shape) + + def _decide_on_crop_dims(self) -> Tuple[int, int]: + """Decide on the dimensions to crop out of the form image. + + Since image width is larger than a comfortable crop around the longest paragraph, + we will make the crop a square form factor. + And since the found dimensions 610x610 are pretty close to 512x512, + we might as well resize crops and make it exactly that, which lets us + do all kinds of power-of-2 pooling and upsampling should we choose to. + + Returns: + Tuple[int, int]: A tuple of crop dimensions. + + Raises: + RuntimeError: When max crop height is larger than max crop width. + + """ + + sample_form_filename = self.iam_dataset.form_filenames[0] + sample_image = util.read_image(sample_form_filename, grayscale=True) + max_crop_width = sample_image.shape[1] + max_crop_height = _get_max_paragraph_crop_height( + self.iam_dataset.line_regions_by_id + ) + if not max_crop_height <= max_crop_width: + raise RuntimeError( + f"Max crop height is larger then max crop width: {max_crop_height} >= {max_crop_width}" + ) + + crop_dims = (max_crop_width, max_crop_width) + logger.info( + f"Max crop width and height were found to be {max_crop_width}x{max_crop_height}." + ) + logger.info(f"Setting them to {max_crop_width}x{max_crop_width}") + return crop_dims + + def __repr__(self) -> str: + """Return info about the dataset.""" + return ( + "IAM Paragraph Dataset\n" # pylint: disable=no-member + f"Num classes: {self.num_classes}\n" + f"Data: {self.data.shape}\n" + f"Targets: {self.targets.shape}\n" + ) + + +def _get_max_paragraph_crop_height(line_regions_by_id: Dict) -> int: + heights = [] + for regions in line_regions_by_id.values(): + min_y1 = min(region["y1"] for region in regions) - PARAGRAPH_BUFFER + max_y2 = max(region["y2"] for region in regions) + PARAGRAPH_BUFFER + height = max_y2 - min_y1 + heights.append(height) + return max(heights) + + +def _crop_paragraph_image( + filename: str, line_regions: Dict, crop_dims: Tuple[int, int], final_dims: Tuple +) -> None: + image = util.read_image(filename, grayscale=True) + + min_y1 = min(region["y1"] for region in line_regions) - PARAGRAPH_BUFFER + max_y2 = max(region["y2"] for region in line_regions) + PARAGRAPH_BUFFER + height = max_y2 - min_y1 + crop_height = crop_dims[0] + buffer = (crop_height - height) // 2 + + # Generate image crop. + image_crop = 255 * np.ones(crop_dims, dtype=np.uint8) + try: + image_crop[buffer : buffer + height] = image[min_y1:max_y2] + except Exception as e: # pylint: disable=broad-except + logger.error(f"Rescued {filename}: {e}") + return + + # Generate ground truth. + gt_image = np.zeros_like(image_crop, dtype=np.uint8) + for index, region in enumerate(line_regions): + gt_image[ + (region["y1"] - min_y1 + buffer) : (region["y2"] - min_y1 + buffer), + region["x1"] : region["x2"], + ] = (index % 2 + 1) + + # Generate image for debugging. + import matplotlib.pyplot as plt + + cmap = plt.get_cmap("Set1") + image_crop_for_debug = np.dstack([image_crop, image_crop, image_crop]) + for index, region in enumerate(line_regions): + color = [255 * _ for _ in cmap(index)[:-1]] + cv2.rectangle( + image_crop_for_debug, + (region["x1"], region["y1"] - min_y1 + buffer), + (region["x2"], region["y2"] - min_y1 + buffer), + color, + 3, + ) + image_crop_for_debug = cv2.resize( + image_crop_for_debug, final_dims, interpolation=cv2.INTER_AREA + ) + util.write_image(image_crop_for_debug, DEBUG_CROPS_DIRNAME / f"{filename.stem}.jpg") + + image_crop = cv2.resize(image_crop, final_dims, interpolation=cv2.INTER_AREA) + util.write_image(image_crop, CROPS_DIRNAME / f"{filename.stem}.jpg") + + gt_image = cv2.resize(gt_image, final_dims, interpolation=cv2.INTER_NEAREST) + util.write_image(gt_image, GT_DIRNAME / f"{filename.stem}.png") + + +def _load_iam_paragraphs() -> None: + logger.info("Loading IAM paragraph crops and ground truth from image files...") + images = [] + gt_images = [] + ids = [] + for filename in CROPS_DIRNAME.glob("*.jpg"): + id_ = filename.stem + image = util.read_image(filename, grayscale=True) + image = 1.0 - image / 255 + + gt_filename = GT_DIRNAME / f"{id_}.png" + gt_image = util.read_image(gt_filename, grayscale=True) + + images.append(image) + gt_images.append(gt_image) + ids.append(id_) + images = np.array(images).astype(np.float32) + gt_images = np.array(gt_images).astype(np.uint8) + ids = np.array(ids) + return images, gt_images, ids + + +@click.command() +@click.option( + "--subsample_fraction", + type=float, + default=None, + help="The subsampling factor of the dataset.", +) +def main(subsample_fraction: float) -> None: + """Load dataset and print info.""" + logger.info("Creating train set...") + dataset = IamParagraphsDataset(train=True, subsample_fraction=subsample_fraction) + dataset.load_or_generate_data() + print(dataset) + logger.info("Creating test set...") + dataset = IamParagraphsDataset(subsample_fraction=subsample_fraction) + dataset.load_or_generate_data() + print(dataset) + + +if __name__ == "__main__": + main() diff --git a/text_recognizer/data/iam_preprocessor.py b/text_recognizer/data/iam_preprocessor.py new file mode 100644 index 0000000..a93eb00 --- /dev/null +++ b/text_recognizer/data/iam_preprocessor.py @@ -0,0 +1,196 @@ +"""Preprocessor for extracting word letters from the IAM dataset. + +The code is mostly stolen from: + + https://github.com/facebookresearch/gtn_applications/blob/master/datasets/iamdb.py + +""" + +import collections +import itertools +from pathlib import Path +import re +from typing import List, Optional, Union + +import click +from loguru import logger +import torch + + +def load_metadata( + data_dir: Path, wordsep: str, use_words: bool = False +) -> collections.defaultdict: + """Loads IAM metadata and returns it as a dictionary.""" + forms = collections.defaultdict(list) + filename = "words.txt" if use_words else "lines.txt" + + with open(data_dir / "ascii" / filename, "r") as f: + lines = (line.strip().split() for line in f if line[0] != "#") + for line in lines: + # Skip word segmentation errors. + if use_words and line[1] == "err": + continue + text = " ".join(line[8:]) + + # Remove garbage tokens: + text = text.replace("#", "") + + # Swap word sep form | to wordsep + text = re.sub(r"\|+|\s", wordsep, text).strip(wordsep) + form_key = "-".join(line[0].split("-")[:2]) + line_key = "-".join(line[0].split("-")[:3]) + box_idx = 4 - use_words + box = tuple(int(val) for val in line[box_idx : box_idx + 4]) + forms[form_key].append({"key": line_key, "box": box, "text": text}) + return forms + + +class Preprocessor: + """A preprocessor for the IAM dataset.""" + + # TODO: add lower case only to when generating... + + def __init__( + self, + data_dir: Union[str, Path], + num_features: int, + tokens_path: Optional[Union[str, Path]] = None, + lexicon_path: Optional[Union[str, Path]] = None, + use_words: bool = False, + prepend_wordsep: bool = False, + ) -> None: + self.wordsep = "▁" + self._use_word = use_words + self._prepend_wordsep = prepend_wordsep + + self.data_dir = Path(data_dir) + + self.forms = load_metadata(self.data_dir, self.wordsep, use_words=use_words) + + # Load the set of graphemes: + graphemes = set() + for _, form in self.forms.items(): + for line in form: + graphemes.update(line["text"].lower()) + self.graphemes = sorted(graphemes) + + # Build the token-to-index and index-to-token maps. + if tokens_path is not None: + with open(tokens_path, "r") as f: + self.tokens = [line.strip() for line in f] + else: + self.tokens = self.graphemes + + if lexicon_path is not None: + with open(lexicon_path, "r") as f: + lexicon = (line.strip().split() for line in f) + lexicon = {line[0]: line[1:] for line in lexicon} + self.lexicon = lexicon + else: + self.lexicon = None + + self.graphemes_to_index = {t: i for i, t in enumerate(self.graphemes)} + self.tokens_to_index = {t: i for i, t in enumerate(self.tokens)} + self.num_features = num_features + self.text = [] + + @property + def num_tokens(self) -> int: + """Returns the number or tokens.""" + return len(self.tokens) + + @property + def use_words(self) -> bool: + """If words are used.""" + return self._use_word + + def extract_train_text(self) -> None: + """Extracts training text.""" + keys = [] + with open(self.data_dir / "task" / "trainset.txt") as f: + keys.extend((line.strip() for line in f)) + + for _, examples in self.forms.items(): + for example in examples: + if example["key"] not in keys: + continue + self.text.append(example["text"].lower()) + + def to_index(self, line: str) -> torch.LongTensor: + """Converts text to a tensor of indices.""" + token_to_index = self.graphemes_to_index + if self.lexicon is not None: + if len(line) > 0: + # If the word is not found in the lexicon, fall back to letters. + line = [ + t + for w in line.split(self.wordsep) + for t in self.lexicon.get(w, self.wordsep + w) + ] + token_to_index = self.tokens_to_index + if self._prepend_wordsep: + line = itertools.chain([self.wordsep], line) + return torch.LongTensor([token_to_index[t] for t in line]) + + def to_text(self, indices: List[int]) -> str: + """Converts indices to text.""" + # Roughly the inverse of `to_index` + encoding = self.graphemes + if self.lexicon is not None: + encoding = self.tokens + return self._post_process(encoding[i] for i in indices) + + def tokens_to_text(self, indices: List[int]) -> str: + """Converts tokens to text.""" + return self._post_process(self.tokens[i] for i in indices) + + def _post_process(self, indices: List[int]) -> str: + """A list join.""" + return "".join(indices).strip(self.wordsep) + + +@click.command() +@click.option("--data_dir", type=str, default=None, help="Path to iam dataset") +@click.option( + "--use_words", is_flag=True, help="Load word segmented dataset instead of lines" +) +@click.option( + "--save_text", type=str, default=None, help="Path to save parsed train text" +) +@click.option("--save_tokens", type=str, default=None, help="Path to save tokens") +def cli( + data_dir: Optional[str], + use_words: bool, + save_text: Optional[str], + save_tokens: Optional[str], +) -> None: + """CLI for extracting text data from the iam dataset.""" + if data_dir is None: + data_dir = ( + Path(__file__).resolve().parents[3] / "data" / "raw" / "iam" / "iamdb" + ) + logger.debug(f"Using data dir: {data_dir}") + if not data_dir.exists(): + raise RuntimeError(f"Could not locate iamdb directory at {data_dir}") + else: + data_dir = Path(data_dir) + + preprocessor = Preprocessor(data_dir, 64, use_words=use_words) + preprocessor.extract_train_text() + + processed_dir = data_dir.parents[2] / "processed" / "iam_lines" + logger.debug(f"Saving processed files at: {processed_dir}") + + if save_text is not None: + logger.info("Saving training text") + with open(processed_dir / save_text, "w") as f: + f.write("\n".join(t for t in preprocessor.text)) + + if save_tokens is not None: + logger.info("Saving tokens") + with open(processed_dir / save_tokens, "w") as f: + f.write("\n".join(preprocessor.tokens)) + + +if __name__ == "__main__": + cli() diff --git a/text_recognizer/data/sentence_generator.py b/text_recognizer/data/sentence_generator.py new file mode 100644 index 0000000..53b781c --- /dev/null +++ b/text_recognizer/data/sentence_generator.py @@ -0,0 +1,85 @@ +"""Downloading the Brown corpus with NLTK for sentence generating.""" + +import itertools +import re +import string +from typing import Optional + +import nltk +from nltk.corpus.reader.util import ConcatenatedCorpusView +import numpy as np + +from text_recognizer.datasets.util import DATA_DIRNAME + +NLTK_DATA_DIRNAME = DATA_DIRNAME / "downloaded" / "nltk" + + +class SentenceGenerator: + """Generates text sentences using the Brown corpus.""" + + def __init__(self, max_length: Optional[int] = None) -> None: + """Loads the corpus and sets word start indices.""" + self.corpus = brown_corpus() + self.word_start_indices = [0] + [ + _.start(0) + 1 for _ in re.finditer(" ", self.corpus) + ] + self.max_length = max_length + + def generate(self, max_length: Optional[int] = None) -> str: + """Generates a word or sentences from the Brown corpus. + + Sample a string from the Brown corpus of length at least one word and at most max_length, padding to + max_length with the '_' characters if sentence is shorter. + + Args: + max_length (Optional[int]): The maximum number of characters in the sentence. Defaults to None. + + Returns: + str: A sentence from the Brown corpus. + + Raises: + ValueError: If max_length was not specified at initialization and not given as an argument. + + """ + if max_length is None: + max_length = self.max_length + if max_length is None: + raise ValueError( + "Must provide max_length to this method or when making this object." + ) + + for _ in range(10): + try: + index = np.random.randint(0, len(self.word_start_indices) - 1) + start_index = self.word_start_indices[index] + end_index_candidates = [] + for index in range(index + 1, len(self.word_start_indices)): + if self.word_start_indices[index] - start_index > max_length: + break + end_index_candidates.append(self.word_start_indices[index]) + end_index = np.random.choice(end_index_candidates) + sampled_text = self.corpus[start_index:end_index].strip() + return sampled_text + except Exception: + pass + raise RuntimeError("Was not able to generate a valid string") + + +def brown_corpus() -> str: + """Returns a single string with the Brown corpus with all punctuations stripped.""" + sentences = load_nltk_brown_corpus() + corpus = " ".join(itertools.chain.from_iterable(sentences)) + corpus = corpus.translate({ord(c): None for c in string.punctuation}) + corpus = re.sub(" +", " ", corpus) + return corpus + + +def load_nltk_brown_corpus() -> ConcatenatedCorpusView: + """Load the Brown corpus using the NLTK library.""" + nltk.data.path.append(NLTK_DATA_DIRNAME) + try: + nltk.corpus.brown.sents() + except LookupError: + NLTK_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) + nltk.download("brown", download_dir=NLTK_DATA_DIRNAME) + return nltk.corpus.brown.sents() diff --git a/text_recognizer/data/transforms.py b/text_recognizer/data/transforms.py new file mode 100644 index 0000000..b6a48f5 --- /dev/null +++ b/text_recognizer/data/transforms.py @@ -0,0 +1,266 @@ +"""Transforms for PyTorch datasets.""" +from abc import abstractmethod +from pathlib import Path +import random +from typing import Any, Optional, Union + +from loguru import logger +import numpy as np +from PIL import Image +import torch +from torch import Tensor +import torch.nn.functional as F +from torchvision import transforms +from torchvision.transforms import ( + ColorJitter, + Compose, + Normalize, + RandomAffine, + RandomHorizontalFlip, + RandomRotation, + ToPILImage, + ToTensor, +) + +from text_recognizer.datasets.iam_preprocessor import Preprocessor +from text_recognizer.datasets.util import EmnistMapper + + +class RandomResizeCrop: + """Image transform with random resize and crop applied. + + Stolen from + + https://github.com/facebookresearch/gtn_applications/blob/master/datasets/iamdb.py + + """ + + def __init__(self, jitter: int = 10, ratio: float = 0.5) -> None: + self.jitter = jitter + self.ratio = ratio + + def __call__(self, img: np.ndarray) -> np.ndarray: + """Applies random crop and rotation to an image.""" + w, h = img.size + + # pad with white: + img = transforms.functional.pad(img, self.jitter, fill=255) + + # crop at random (x, y): + x = self.jitter + random.randint(-self.jitter, self.jitter) + y = self.jitter + random.randint(-self.jitter, self.jitter) + + # randomize aspect ratio: + size_w = w * random.uniform(1 - self.ratio, 1 + self.ratio) + size = (h, int(size_w)) + img = transforms.functional.resized_crop(img, y, x, h, w, size) + return img + + +class Transpose: + """Transposes the EMNIST image to the correct orientation.""" + + def __call__(self, image: Image) -> np.ndarray: + """Swaps axis.""" + return np.array(image).swapaxes(0, 1) + + +class Resize: + """Resizes a tensor to a specified width.""" + + def __init__(self, width: int = 952) -> None: + # The default is 952 because of the IAM dataset. + self.width = width + + def __call__(self, image: Tensor) -> Tensor: + """Resize tensor in the last dimension.""" + return F.interpolate(image, size=self.width, mode="nearest") + + +class AddTokens: + """Adds start of sequence and end of sequence tokens to target tensor.""" + + def __init__(self, pad_token: str, eos_token: str, init_token: str = None) -> None: + self.init_token = init_token + self.pad_token = pad_token + self.eos_token = eos_token + if self.init_token is not None: + self.emnist_mapper = EmnistMapper( + init_token=self.init_token, + pad_token=self.pad_token, + eos_token=self.eos_token, + ) + else: + self.emnist_mapper = EmnistMapper( + pad_token=self.pad_token, eos_token=self.eos_token, + ) + self.pad_value = self.emnist_mapper(self.pad_token) + self.eos_value = self.emnist_mapper(self.eos_token) + + def __call__(self, target: Tensor) -> Tensor: + """Adds a sos token to the begining and a eos token to the end of a target sequence.""" + dtype, device = target.dtype, target.device + + # Find the where padding starts. + pad_index = torch.nonzero(target == self.pad_value, as_tuple=False)[0].item() + + target[pad_index] = self.eos_value + + if self.init_token is not None: + self.sos_value = self.emnist_mapper(self.init_token) + sos = torch.tensor([self.sos_value], dtype=dtype, device=device) + target = torch.cat([sos, target], dim=0) + + return target + + +class ApplyContrast: + """Sets everything below a threshold to zero, i.e. increase contrast.""" + + def __init__(self, low: float = 0.0, high: float = 0.25) -> None: + self.low = low + self.high = high + + def __call__(self, x: Tensor) -> Tensor: + """Apply mask binary mask to input tensor.""" + mask = x > np.random.RandomState().uniform(low=self.low, high=self.high) + return x * mask + + +class Unsqueeze: + """Add a dimension to the tensor.""" + + def __call__(self, x: Tensor) -> Tensor: + """Adds dim.""" + return x.unsqueeze(0) + + +class Squeeze: + """Removes the first dimension of a tensor.""" + + def __call__(self, x: Tensor) -> Tensor: + """Removes first dim.""" + return x.squeeze(0) + + +class ToLower: + """Converts target to lower case.""" + + def __call__(self, target: Tensor) -> Tensor: + """Corrects index value in target tensor.""" + device = target.device + return torch.stack([x - 26 if x > 35 else x for x in target]).to(device) + + +class ToCharcters: + """Converts integers to characters.""" + + def __init__( + self, pad_token: str, eos_token: str, init_token: str = None, lower: bool = True + ) -> None: + self.init_token = init_token + self.pad_token = pad_token + self.eos_token = eos_token + if self.init_token is not None: + self.emnist_mapper = EmnistMapper( + init_token=self.init_token, + pad_token=self.pad_token, + eos_token=self.eos_token, + lower=lower, + ) + else: + self.emnist_mapper = EmnistMapper( + pad_token=self.pad_token, eos_token=self.eos_token, lower=lower + ) + + def __call__(self, y: Tensor) -> str: + """Converts a Tensor to a str.""" + return ( + "".join([self.emnist_mapper(int(i)) for i in y]) + .strip("_") + .replace(" ", "▁") + ) + + +class WordPieces: + """Abstract transform for word pieces.""" + + def __init__( + self, + num_features: int, + data_dir: Optional[Union[str, Path]] = None, + tokens: Optional[Union[str, Path]] = None, + lexicon: Optional[Union[str, Path]] = None, + use_words: bool = False, + prepend_wordsep: bool = False, + ) -> None: + if data_dir is None: + data_dir = ( + Path(__file__).resolve().parents[3] / "data" / "raw" / "iam" / "iamdb" + ) + logger.debug(f"Using data dir: {data_dir}") + if not data_dir.exists(): + raise RuntimeError(f"Could not locate iamdb directory at {data_dir}") + else: + data_dir = Path(data_dir) + processed_path = ( + Path(__file__).resolve().parents[3] / "data" / "processed" / "iam_lines" + ) + tokens_path = processed_path / tokens + lexicon_path = processed_path / lexicon + + self.preprocessor = Preprocessor( + data_dir, + num_features, + tokens_path, + lexicon_path, + use_words, + prepend_wordsep, + ) + + @abstractmethod + def __call__(self, *args, **kwargs) -> Any: + """Transforms input.""" + ... + + +class ToWordPieces(WordPieces): + """Transforms str to word pieces.""" + + def __init__( + self, + num_features: int, + data_dir: Optional[Union[str, Path]] = None, + tokens: Optional[Union[str, Path]] = None, + lexicon: Optional[Union[str, Path]] = None, + use_words: bool = False, + prepend_wordsep: bool = False, + ) -> None: + super().__init__( + num_features, data_dir, tokens, lexicon, use_words, prepend_wordsep + ) + + def __call__(self, line: str) -> Tensor: + """Transforms str to word pieces.""" + return self.preprocessor.to_index(line) + + +class ToText(WordPieces): + """Takes word pieces and converts them to text.""" + + def __init__( + self, + num_features: int, + data_dir: Optional[Union[str, Path]] = None, + tokens: Optional[Union[str, Path]] = None, + lexicon: Optional[Union[str, Path]] = None, + use_words: bool = False, + prepend_wordsep: bool = False, + ) -> None: + super().__init__( + num_features, data_dir, tokens, lexicon, use_words, prepend_wordsep + ) + + def __call__(self, x: Tensor) -> str: + """Converts tensor to text.""" + return self.preprocessor.to_text(x.tolist()) diff --git a/text_recognizer/data/util.py b/text_recognizer/data/util.py new file mode 100644 index 0000000..da87756 --- /dev/null +++ b/text_recognizer/data/util.py @@ -0,0 +1,209 @@ +"""Util functions for datasets.""" +import hashlib +import json +import os +from pathlib import Path +import string +from typing import Dict, List, Optional, Union +from urllib.request import urlretrieve + +from loguru import logger +import numpy as np +import torch +from torch import Tensor +from torchvision.datasets import EMNIST +from tqdm import tqdm + +DATA_DIRNAME = Path(__file__).resolve().parents[3] / "data" +ESSENTIALS_FILENAME = Path(__file__).resolve().parents[0] / "emnist_essentials.json" + + +def save_emnist_essentials(emnsit_dataset: EMNIST = EMNIST) -> None: + """Extract and saves EMNIST essentials.""" + labels = emnsit_dataset.classes + labels.sort() + mapping = [(i, str(label)) for i, label in enumerate(labels)] + essentials = { + "mapping": mapping, + "input_shape": tuple(np.array(emnsit_dataset[0][0]).shape[:]), + } + logger.info("Saving emnist essentials...") + with open(ESSENTIALS_FILENAME, "w") as f: + json.dump(essentials, f) + + +def download_emnist() -> None: + """Download the EMNIST dataset via the PyTorch class.""" + logger.info(f"Data directory is: {DATA_DIRNAME}") + dataset = EMNIST(root=DATA_DIRNAME, split="byclass", download=True) + save_emnist_essentials(dataset) + + +class EmnistMapper: + """Mapper between network output to Emnist character.""" + + def __init__( + self, + pad_token: str, + init_token: Optional[str] = None, + eos_token: Optional[str] = None, + lower: bool = False, + ) -> None: + """Loads the emnist essentials file with the mapping and input shape.""" + self.init_token = init_token + self.pad_token = pad_token + self.eos_token = eos_token + self.lower = lower + + self.essentials = self._load_emnist_essentials() + # Load dataset information. + self._mapping = dict(self.essentials["mapping"]) + self._augment_emnist_mapping() + self._inverse_mapping = {v: k for k, v in self.mapping.items()} + self._num_classes = len(self.mapping) + self._input_shape = self.essentials["input_shape"] + + def __call__(self, token: Union[str, int, np.uint8, Tensor]) -> Union[str, int]: + """Maps the token to emnist character or character index. + + If the token is an integer (index), the method will return the Emnist character corresponding to that index. + If the token is a str (Emnist character), the method will return the corresponding index for that character. + + Args: + token (Union[str, int, np.uint8, Tensor]): Either a string or index (integer). + + Returns: + Union[str, int]: The mapping result. + + Raises: + KeyError: If the index or string does not exist in the mapping. + + """ + if ( + (isinstance(token, np.uint8) or isinstance(token, int)) + or torch.is_tensor(token) + and int(token) in self.mapping + ): + return self.mapping[int(token)] + elif isinstance(token, str) and token in self._inverse_mapping: + return self._inverse_mapping[token] + else: + raise KeyError(f"Token {token} does not exist in the mappings.") + + @property + def mapping(self) -> Dict: + """Returns the mapping between index and character.""" + return self._mapping + + @property + def inverse_mapping(self) -> Dict: + """Returns the mapping between character and index.""" + return self._inverse_mapping + + @property + def num_classes(self) -> int: + """Returns the number of classes in the dataset.""" + return self._num_classes + + @property + def input_shape(self) -> List[int]: + """Returns the input shape of the Emnist characters.""" + return self._input_shape + + def _load_emnist_essentials(self) -> Dict: + """Load the EMNIST mapping.""" + with open(str(ESSENTIALS_FILENAME)) as f: + essentials = json.load(f) + return essentials + + def _augment_emnist_mapping(self) -> None: + """Augment the mapping with extra symbols.""" + # Extra symbols in IAM dataset + if self.lower: + self._mapping = { + k: str(v) + for k, v in enumerate(list(range(10)) + list(string.ascii_lowercase)) + } + + extra_symbols = [ + " ", + "!", + '"', + "#", + "&", + "'", + "(", + ")", + "*", + "+", + ",", + "-", + ".", + "/", + ":", + ";", + "?", + ] + + # padding symbol, and acts as blank symbol as well. + extra_symbols.append(self.pad_token) + + if self.init_token is not None: + extra_symbols.append(self.init_token) + + if self.eos_token is not None: + extra_symbols.append(self.eos_token) + + max_key = max(self.mapping.keys()) + extra_mapping = {} + for i, symbol in enumerate(extra_symbols): + extra_mapping[max_key + 1 + i] = symbol + + self._mapping = {**self.mapping, **extra_mapping} + + +def compute_sha256(filename: Union[Path, str]) -> str: + """Returns the SHA256 checksum of a file.""" + with open(filename, "rb") as f: + return hashlib.sha256(f.read()).hexdigest() + + +class TqdmUpTo(tqdm): + """TQDM progress bar when downloading files. + + From https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py + + """ + + def update_to( + self, blocks: int = 1, block_size: int = 1, total_size: Optional[int] = None + ) -> None: + """Updates the progress bar. + + Args: + blocks (int): Number of blocks transferred so far. Defaults to 1. + block_size (int): Size of each block, in tqdm units. Defaults to 1. + total_size (Optional[int]): Total size in tqdm units. Defaults to None. + """ + if total_size is not None: + self.total = total_size # pylint: disable=attribute-defined-outside-init + self.update(blocks * block_size - self.n) + + +def download_url(url: str, filename: str) -> None: + """Downloads a file from url to filename, with a progress bar.""" + with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t: + urlretrieve(url, filename, reporthook=t.update_to, data=None) # nosec + + +def _download_raw_dataset(metadata: Dict) -> None: + if os.path.exists(metadata["filename"]): + return + logger.info(f"Downloading raw dataset from {metadata['url']}...") + download_url(metadata["url"], metadata["filename"]) + logger.info("Computing SHA-256...") + sha256 = compute_sha256(metadata["filename"]) + if sha256 != metadata["sha256"]: + raise ValueError( + "Downloaded data file SHA-256 does not match that listed in metadata document." + ) -- cgit v1.2.3-70-g09d2