author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-03-21 22:33:58 +0100
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-03-21 22:33:58 +0100
commit | e3741de333a3a43a7968241b6eccaaac66dd7b20 (patch)
tree | 7c50aee4ca61f77e95f1b038030292c64bbb86c2 /text_recognizer/datasets
parent | aac452a2dc008338cb543549652da293c14b6b4e (diff)
Working on EMNIST Lines dataset
Diffstat (limited to 'text_recognizer/datasets')
-rw-r--r-- | text_recognizer/datasets/__init__.py | 38
-rw-r--r-- | text_recognizer/datasets/base_data_module.py | 36
-rw-r--r-- | text_recognizer/datasets/base_dataset.py | 21
-rw-r--r-- | text_recognizer/datasets/dataset.py | 152
-rw-r--r-- | text_recognizer/datasets/download_utils.py | 6
-rw-r--r-- | text_recognizer/datasets/emnist.py | 88
-rw-r--r-- | text_recognizer/datasets/emnist_essentials.json | 1
-rw-r--r-- | text_recognizer/datasets/emnist_lines.py | 184 |
8 files changed, 278 insertions, 248 deletions
diff --git a/text_recognizer/datasets/__init__.py b/text_recognizer/datasets/__init__.py index a6c1c59..2727b20 100644 --- a/text_recognizer/datasets/__init__.py +++ b/text_recognizer/datasets/__init__.py @@ -1,39 +1 @@ """Dataset modules.""" -from .emnist_dataset import EmnistDataset -from .emnist_lines_dataset import ( - construct_image_from_string, - EmnistLinesDataset, - get_samples_by_character, -) -from .iam_dataset import IamDataset -from .iam_lines_dataset import IamLinesDataset -from .iam_paragraphs_dataset import IamParagraphsDataset -from .iam_preprocessor import load_metadata, Preprocessor -from .transforms import AddTokens, Transpose -from .util import ( - _download_raw_dataset, - compute_sha256, - DATA_DIRNAME, - download_url, - EmnistMapper, - ESSENTIALS_FILENAME, -) - -__all__ = [ - "_download_raw_dataset", - "AddTokens", - "compute_sha256", - "construct_image_from_string", - "DATA_DIRNAME", - "download_url", - "EmnistDataset", - "EmnistMapper", - "EmnistLinesDataset", - "get_samples_by_character", - "load_metadata", - "IamDataset", - "IamLinesDataset", - "IamParagraphsDataset", - "Preprocessor", - "Transpose", -] diff --git a/text_recognizer/datasets/base_data_module.py b/text_recognizer/datasets/base_data_module.py index 09a0a43..830b39b 100644 --- a/text_recognizer/datasets/base_data_module.py +++ b/text_recognizer/datasets/base_data_module.py @@ -16,7 +16,7 @@ def load_and_print_info(data_module_class: type) -> None: class BaseDataModule(pl.LightningDataModule): """Base PyTorch Lightning DataModule.""" - + def __init__(self, batch_size: int = 128, num_workers: int = 0) -> None: super().__init__() self.batch_size = batch_size @@ -34,13 +34,17 @@ class BaseDataModule(pl.LightningDataModule): def config(self) -> Dict: """Return important settings of the dataset.""" - return {"input_dim": self.dims, "output_dims": self.output_dims, "mapping": self.mapping} + return { + "input_dim": self.dims, + "output_dims": self.output_dims, + "mapping": self.mapping, + } def prepare_data(self) -> None: """Prepare data for training.""" pass - def setup(self, stage: Any = None) -> None: + def setup(self, stage: str = None) -> None: """Split into train, val, test, and set dims. 
Should assign `torch Dataset` objects to self.data_train, self.data_val, and @@ -54,16 +58,32 @@ class BaseDataModule(pl.LightningDataModule): self.data_val = None self.data_test = None - def train_dataloader(self) -> DataLoader: """Retun DataLoader for train data.""" - return DataLoader(self.data_train, shuffle=True, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True) + return DataLoader( + self.data_train, + shuffle=True, + batch_size=self.batch_size, + num_workers=self.num_workers, + pin_memory=True, + ) def val_dataloader(self) -> DataLoader: """Return DataLoader for val data.""" - return DataLoader(self.data_val, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True) + return DataLoader( + self.data_val, + shuffle=False, + batch_size=self.batch_size, + num_workers=self.num_workers, + pin_memory=True, + ) def test_dataloader(self) -> DataLoader: """Return DataLoader for val data.""" - return DataLoader(self.data_test, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True) - + return DataLoader( + self.data_test, + shuffle=False, + batch_size=self.batch_size, + num_workers=self.num_workers, + pin_memory=True, + ) diff --git a/text_recognizer/datasets/base_dataset.py b/text_recognizer/datasets/base_dataset.py index 7322d7f..a004b8d 100644 --- a/text_recognizer/datasets/base_dataset.py +++ b/text_recognizer/datasets/base_dataset.py @@ -17,12 +17,14 @@ class BaseDataset(Dataset): target_transform (Callable): Fucntion that takes a target and applies target transforms. """ - def __init__(self, - data: Union[Sequence, Tensor], - targets: Union[Sequence, Tensor], - transform: Callable = None, - target_transform: Callable = None, - ) -> None: + + def __init__( + self, + data: Union[Sequence, Tensor], + targets: Union[Sequence, Tensor], + transform: Callable = None, + target_transform: Callable = None, + ) -> None: if len(data) != len(targets): raise ValueError("Data and targets must be of equal length.") self.data = data @@ -30,11 +32,10 @@ class BaseDataset(Dataset): self.transform = transform self.target_transform = target_transform - def __len__(self) -> int: """Return the length of the dataset.""" return len(self.data) - + def __getitem__(self, index: int) -> Tuple[Any, Any]: """Return a datum and its target, after processing by transforms. @@ -56,7 +57,9 @@ class BaseDataset(Dataset): return datum, target -def convert_strings_to_labels(strings: Sequence[str], mapping: Dict[str, int], length: int) -> Tensor: +def convert_strings_to_labels( + strings: Sequence[str], mapping: Dict[str, int], length: int +) -> Tensor: """ Convert a sequence of N strings to (N, length) ndarray, with each string wrapped with <S> and </S> tokens, and padded wiht the <P> token. 
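
The body of convert_strings_to_labels is cut off in this view, so for orientation only, a minimal sketch of what such a helper could look like follows. It is an assumption based purely on the docstring (wrap each string with start/end tokens, pad to a fixed length, map characters to indices); the token spellings <S>, </S>, and <P> are taken from the docstring and may differ from the repository's actual mapping keys.

import torch
from torch import Tensor
from typing import Dict, Sequence

def convert_strings_to_labels_sketch(
    strings: Sequence[str], mapping: Dict[str, int], length: int
) -> Tensor:
    """Map each string to a fixed-length row of label indices (illustrative sketch only)."""
    # Fill every position with the padding token first.
    labels = torch.full((len(strings), length), mapping["<P>"], dtype=torch.long)
    for i, string in enumerate(strings):
        # Character-level tokens wrapped with start/end-of-sequence markers.
        tokens = ["<S>", *string, "</S>"]
        for j, token in enumerate(tokens):
            labels[i, j] = mapping[token]
    return labels
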
diff --git a/text_recognizer/datasets/dataset.py b/text_recognizer/datasets/dataset.py deleted file mode 100644 index e794605..0000000 --- a/text_recognizer/datasets/dataset.py +++ /dev/null @@ -1,152 +0,0 @@ -"""Abstract dataset class.""" -from typing import Callable, Dict, List, Optional, Tuple, Union - -import torch -from torch import Tensor -from torch.utils import data -from torchvision.transforms import ToTensor - -import text_recognizer.datasets.transforms as transforms -from text_recognizer.datasets.util import EmnistMapper - - -class Dataset(data.Dataset): - """Abstract class for with common methods for all datasets.""" - - def __init__( - self, - train: bool, - subsample_fraction: float = None, - transform: Optional[List[Dict]] = None, - target_transform: Optional[List[Dict]] = None, - init_token: Optional[str] = None, - pad_token: Optional[str] = None, - eos_token: Optional[str] = None, - lower: bool = False, - ) -> None: - """Initialization of Dataset class. - - Args: - train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to False. - subsample_fraction (float): The fraction of the dataset to use for training. Defaults to None. - transform (Optional[List[Dict]]): List of Transform types and args for input data. Defaults to None. - target_transform (Optional[List[Dict]]): List of Transform types and args for output data. Defaults to None. - init_token (Optional[str]): String representing the start of sequence token. Defaults to None. - pad_token (Optional[str]): String representing the pad token. Defaults to None. - eos_token (Optional[str]): String representing the end of sequence token. Defaults to None. - lower (bool): Only use lower case letters. Defaults to False. - - Raises: - ValueError: If subsample_fraction is not None and outside the range (0, 1). - - """ - self.train = train - self.split = "train" if self.train else "test" - - if subsample_fraction is not None: - if not 0.0 < subsample_fraction < 1.0: - raise ValueError("The subsample fraction must be in (0, 1).") - self.subsample_fraction = subsample_fraction - - self._mapper = EmnistMapper( - init_token=init_token, eos_token=eos_token, pad_token=pad_token, lower=lower - ) - self._input_shape = self._mapper.input_shape - self._output_shape = self._mapper._num_classes - self.num_classes = self.mapper.num_classes - - # Set transforms. 
- self.transform = self._configure_transform(transform) - self.target_transform = self._configure_target_transform(target_transform) - - self._data = None - self._targets = None - - def _configure_transform(self, transform: List[Dict]) -> transforms.Compose: - transform_list = [] - if transform is not None: - for t in transform: - t_type = t["type"] - t_args = t["args"] or {} - transform_list.append(getattr(transforms, t_type)(**t_args)) - else: - transform_list.append(ToTensor()) - return transforms.Compose(transform_list) - - def _configure_target_transform( - self, target_transform: List[Dict] - ) -> transforms.Compose: - target_transform_list = [torch.tensor] - if target_transform is not None: - for t in target_transform: - t_type = t["type"] - t_args = t["args"] or {} - target_transform_list.append(getattr(transforms, t_type)(**t_args)) - return transforms.Compose(target_transform_list) - - @property - def data(self) -> Tensor: - """The input data.""" - return self._data - - @property - def targets(self) -> Tensor: - """The target data.""" - return self._targets - - @property - def input_shape(self) -> Tuple: - """Input shape of the data.""" - return self._input_shape - - @property - def output_shape(self) -> Tuple: - """Output shape of the data.""" - return self._output_shape - - @property - def mapper(self) -> EmnistMapper: - """Returns the EmnistMapper.""" - return self._mapper - - @property - def mapping(self) -> Dict: - """Return EMNIST mapping from index to character.""" - return self._mapper.mapping - - @property - def inverse_mapping(self) -> Dict: - """Returns the inverse mapping from character to index.""" - return self.mapper.inverse_mapping - - def _subsample(self) -> None: - """Only this fraction of the data will be loaded.""" - if self.subsample_fraction is None: - return - num_subsample = int(self.data.shape[0] * self.subsample_fraction) - self._data = self.data[:num_subsample] - self._targets = self.targets[:num_subsample] - - def __len__(self) -> int: - """Returns the length of the dataset.""" - return len(self.data) - - def load_or_generate_data(self) -> None: - """Load or generate dataset data.""" - raise NotImplementedError - - def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]: - """Fetches samples from the dataset. - - Args: - index (Union[int, torch.Tensor]): The indices of the samples to fetch. - - Raises: - NotImplementedError: If the method is not implemented in child class. - - """ - raise NotImplementedError - - def __repr__(self) -> str: - """Returns information about the dataset.""" - raise NotImplementedError diff --git a/text_recognizer/datasets/download_utils.py b/text_recognizer/datasets/download_utils.py index 7a2cab8..e3dc68c 100644 --- a/text_recognizer/datasets/download_utils.py +++ b/text_recognizer/datasets/download_utils.py @@ -63,11 +63,11 @@ def download_dataset(metadata: Dict, dl_dir: Path) -> Optional[Path]: if filename.exists(): return logger.info(f"Downloading raw dataset from {metadata['url']} to {filename}...") - _download_url(metadata["url"], filename) + _download_url(metadata["url"], filename) logger.info("Computing the SHA-256...") sha256 = _compute_sha256(filename) if sha256 != metadata["sha256"]: raise ValueError( - "Downloaded data file SHA-256 does not match that listed in metadata document." - ) + "Downloaded data file SHA-256 does not match that listed in metadata document." 
+ ) return filename diff --git a/text_recognizer/datasets/emnist.py b/text_recognizer/datasets/emnist.py index e99dbfd..7c208c4 100644 --- a/text_recognizer/datasets/emnist.py +++ b/text_recognizer/datasets/emnist.py @@ -15,20 +15,23 @@ from torch.utils.data import random_split from torchvision import transforms from text_recognizer.datasets.base_dataset import BaseDataset -from text_recognizer.datasets.base_data_module import BaseDataModule, load_print_info +from text_recognizer.datasets.base_data_module import ( + BaseDataModule, + load_and_print_info, +) from text_recognizer.datasets.download_utils import download_dataset SEED = 4711 NUM_SPECIAL_TOKENS = 4 -SAMPLE_TO_BALANCE = True +SAMPLE_TO_BALANCE = True RAW_DATA_DIRNAME = BaseDataModule.data_dirname() / "raw" / "emnist" METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml" DL_DATA_DIRNAME = BaseDataModule.data_dirname() / "downloaded" / "emnist" -PROCESSED_DATA_DIRNAME = BaseDataset.data_dirname() / "processed" / "emnist" +PROCESSED_DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "emnist" PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "byclass.h5" -ESSENTIALS_FILENAME = Path(__file__).parents[0].resolve() / "emnsit_essentials.json" +ESSENTIALS_FILENAME = Path(__file__).parents[0].resolve() / "emnist_essentials.json" class EMNIST(BaseDataModule): @@ -41,7 +44,9 @@ class EMNIST(BaseDataModule): EMNIST ByClass: 814,255 characters. 62 unbalanced classes. """ - def __init__(self, batch_size: int = 128, num_workers: int = 0, train_fraction: float = 0.8) -> None: + def __init__( + self, batch_size: int = 128, num_workers: int = 0, train_fraction: float = 0.8 + ) -> None: super().__init__(batch_size, num_workers) if not ESSENTIALS_FILENAME.exists(): _download_and_process_emnist() @@ -64,20 +69,21 @@ class EMNIST(BaseDataModule): def setup(self, stage: str = None) -> None: if stage == "fit" or stage is None: with h5py.File(PROCESSED_DATA_FILENAME, "r") as f: - data = f["x_train"][:] - targets = f["y_train"][:] - - dataset_train = BaseDataset(data, targets, transform=self.transform) + self.x_train = f["x_train"][:] + self.y_train = f["y_train"][:] + + dataset_train = BaseDataset(self.x_train, self.y_train, transform=self.transform) train_size = int(self.train_fraction * len(dataset_train)) val_size = len(dataset_train) - train_size - self.data_train, self.data_val = random_split(dataset_train, [train_size, val_size], generator=torch.Generator()) + self.data_train, self.data_val = random_split( + dataset_train, [train_size, val_size], generator=torch.Generator() + ) if stage == "test" or stage is None: with h5py.File(PROCESSED_DATA_FILENAME, "r") as f: - data = f["x_test"][:] - targets = f["y_test"][:] - self.data_test = BaseDataset(data, targets, transform=self.transform) - + self.x_test = f["x_test"][:] + self.y_test = f["y_test"][:] + self.data_test = BaseDataset(self.x_test, self.y_test, transform=self.transform) def __repr__(self) -> str: basic = f"EMNIST Dataset\nNum classes: {len(self.mapping)}\nMapping: {self.mapping}\nDims: {self.dims}\n" @@ -111,9 +117,15 @@ def _process_raw_dataset(filename: str, dirname: Path) -> None: logger.info("Loading training data from .mat file") data = loadmat("matlab/emnist-byclass.mat") - x_train = data["dataset"]["train"][0, 0]["images"][0, 0].reshape(-1, 28, 28).swapaxes(1, 2) + x_train = ( + data["dataset"]["train"][0, 0]["images"][0, 0] + .reshape(-1, 28, 28) + .swapaxes(1, 2) + ) y_train = data["dataset"]["train"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS - x_test = 
data["dataset"]["test"][0, 0]["images"][0, 0].reshape(-1, 28, 28).swapaxes(1, 2) + x_test = ( + data["dataset"]["test"][0, 0]["images"][0, 0].reshape(-1, 28, 28).swapaxes(1, 2) + ) y_test = data["dataset"]["test"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS if SAMPLE_TO_BALANCE: @@ -121,7 +133,6 @@ def _process_raw_dataset(filename: str, dirname: Path) -> None: x_train, y_train = _sample_to_balance(x_train, y_train) x_test, y_test = _sample_to_balance(x_test, y_test) - logger.info("Saving to HDF5 in a compressed format...") PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) with h5py.File(PROCESSED_DATA_FILENAME, "w") as f: @@ -154,7 +165,7 @@ def _sample_to_balance(x: np.ndarray, y: np.ndarray) -> Tuple[np.ndarray, np.nda all_sampled_indices.append(sampled_indices) indices = np.concatenate(all_sampled_indices) x_sampled = x[indices] - y_sampled= y[indices] + y_sampled = y[indices] return x_sampled, y_sampled @@ -162,24 +173,24 @@ def _augment_emnist_characters(characters: Sequence[str]) -> Sequence[str]: """Augment the mapping with extra symbols.""" # Extra characters from the IAM dataset. iam_characters = [ - " ", - "!", - '"', - "#", - "&", - "'", - "(", - ")", - "*", - "+", - ",", - "-", - ".", - "/", - ":", - ";", - "?", - ] + " ", + "!", + '"', + "#", + "&", + "'", + "(", + ")", + "*", + "+", + ",", + "-", + ".", + "/", + ":", + ";", + "?", + ] # Also add special tokens for: # - CTC blank token at index 0 @@ -190,5 +201,6 @@ def _augment_emnist_characters(characters: Sequence[str]) -> Sequence[str]: return ["<b>", "<s>", "</s>", "<p>", *characters, *iam_characters] -if __name__ == "__main__": - load_print_info(EMNIST) +def download_emnist() -> None: + """Download dataset from internet, if it does not exists, and displays info.""" + load_and_print_info(EMNIST) diff --git a/text_recognizer/datasets/emnist_essentials.json b/text_recognizer/datasets/emnist_essentials.json new file mode 100644 index 0000000..100b36a --- /dev/null +++ b/text_recognizer/datasets/emnist_essentials.json @@ -0,0 +1 @@ +{"characters": ["<b>", "<s>", "</s>", "<p>", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", " ", "!", "\"", "#", "&", "'", "(", ")", "*", "+", ",", "-", ".", "/", ":", ";", "?"], "input_shape": [28, 28]}
\ No newline at end of file diff --git a/text_recognizer/datasets/emnist_lines.py b/text_recognizer/datasets/emnist_lines.py new file mode 100644 index 0000000..ae23feb --- /dev/null +++ b/text_recognizer/datasets/emnist_lines.py @@ -0,0 +1,184 @@ +"""Dataset of generated text from EMNIST characters.""" +from collections import defaultdict +from pathlib import Path +from typing import Dict, Sequence + +import h5py +from loguru import logger +import numpy as np +import torch +from torchvision import transforms + +from text_recognizer.datasets.base_dataset import BaseDataset +from text_recognizer.datasets.base_data_module import BaseDataModule +from text_recognizer.datasets.emnist import EMNIST +from text_recognizer.datasets.sentence_generator import SentenceGenerator + + +DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "emnist_lines" +ESSENTIALS_FILENAME = ( + Path(__file__).parents[0].resolve() / "emnist_lines_essentials.json" +) + +SEED = 4711 +IMAGE_HEIGHT = 56 +IMAGE_WIDTH = 1024 +IMAGE_X_PADDING = 28 +MAX_OUTPUT_LENGTH = 89 # Same as IAMLines + + +class EMNISTLines(BaseDataModule): + """EMNIST Lines dataset: synthetic handwritten lines dataset made from EMNIST,""" + + def __init__( + self, + augment: bool = True, + batch_size: int = 128, + num_workers: int = 0, + max_length: int = 32, + min_overlap: float = 0.0, + max_overlap: float = 0.33, + num_train: int = 10_000, + num_val: int = 2_000, + num_test: int = 2_000, + ) -> None: + super().__init__(batch_size, num_workers) + + self.augment = augment + self.max_length = max_length + self.min_overlap = min_overlap + self.max_overlap = max_overlap + self.num_train = num_train + self.num_val = num_val + self.num_test = num_test + + self.emnist = EMNIST() + self.mapping = self.emnist.mapping + max_width = int(self.emnist.dims[2] * (self.max_length + 1) * (1 - self.min_overlap)) + IMAGE_X_PADDING + + if max_width <= IMAGE_WIDTH: + raise ValueError("max_width greater than IMAGE_WIDTH") + + self.dims = ( + self.emnist.dims[0], + self.emnist.dims[1], + self.emnist.dims[2] * self.max_length, + ) + + if self.max_length <= MAX_OUTPUT_LENGTH: + raise ValueError("max_length greater than MAX_OUTPUT_LENGTH") + + self.output_dims = (MAX_OUTPUT_LENGTH, 1) + self.data_train = None + self.data_val = None + self.data_test = None + + @property + def data_filename(self) -> Path: + """Return name of dataset.""" + return ( + DATA_DIRNAME + / f"ml_{self.max_length}_o{self.min_overlap:f}_{self.max_overlap:f}_ntr{self.num_train}_ntv{self.num_val}_nte{self.num_test}_{self.with_start_end_tokens}.h5" + ) + + def prepare_data(self) -> None: + if self.data_filename.exists(): + return + np.random.seed(SEED) + self._generate_data("train") + self._generate_data("val") + self._generate_data("test") + + def setup(self, stage: str = None) -> None: + logger.info("EMNISTLinesDataset loading data from HDF5...") + if stage == "fit" or stage is None: + with h5py.File(self.data_filename, "r") as f: + x_train = f["x_train"][:] + y_train = torch.LongTensor(f["y_train"][:]) + x_val = f["x_val"][:] + y_val = torch.LongTensor(f["y_val"][:]) + + self.data_train = BaseDataset(x_train, y_train, transform=_get_transform(augment=self.augment)) + self.data_val = BaseDataset(x_val, y_val, transform=_get_transform(augment=self.augment)) + + if stage == "test" or stage is None: + with h5py.File(self.data_filename, "r") as f: + x_test = f["x_test"][:] + y_test = torch.LongTensor(f["y_test"][:]) + + self.data_train = BaseDataset(x_test, y_test, transform=_get_transform(augment=False)) 
+ + def __repr__(self) -> str: + """Return str about dataset.""" + basic = ( + "EMNISTLines2 Dataset\n" # pylint: disable=no-member + f"Min overlap: {self.min_overlap}\n" + f"Max overlap: {self.max_overlap}\n" + f"Num classes: {len(self.mapping)}\n" + f"Dims: {self.dims}\n" + f"Output dims: {self.output_dims}\n" + ) + + if not any([self.data_train, self.data_val, self.data_test]): + return basic + + x, y = next(iter(self.train_dataloader())) + data = ( + f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n" + f"Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n" + f"Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n" + ) + return basic + data + + def _generate_data(self, split: str) -> None: + logger.info(f"EMNISTLines generating data for {split}...") + sentence_generator = SentenceGenerator(self.max_length - 2) # Subtract by 2 because start/end token + + emnist = self.emnist + emnist.prepare_data() + emnist.setup() + + if split == "train": + samples_by_char = _get_samples_by_char(emnist.x_train, emnist.y_train, emnist.mapping) + num = self.num_train + elif split == "val": + samples_by_char = _get_samples_by_char(emnist.x_train, emnist.y_train, emnist.mapping) + num = self.num_val + elif split == "test": + samples_by_char = _get_samples_by_char(emnist.x_test, emnist.y_test, emnist.mapping) + num = self.num_test + + DATA_DIRNAME.mkdir(parents=True, exist_ok=True) + with h5py.File(self.data_filename, "w") as f: + x, y = _create_dataset_of_images( + num, samples_by_char, sentence_generator, self.min_overlap, self.max_overlap, self.dims + ) + y = _convert_strings_to_labels( + y, + emnist.inverse_mapping, + length=MAX_OUTPUT_LENGTH + ) + f.create_dataset(f"x_{split}", data=x, dtype="u1", compression="lzf") + f.create_dataset(f"y_{split}", data=y, dtype="u1", compression="lzf") + +def _get_samples_by_char(samples: np.ndarray, labels: np.ndarray, mapping: Dict) -> defaultdict: + samples_by_char = defaultdict(list) + for sample, label in zip(samples, labels): + samples_by_char[mapping[label]].append(sample) + return samples_by_char + + +def _construct_image_from_string(): + pass + + +def _select_letter_samples_for_string(string: str, samples_by_char: defaultdict): + pass + + +def _create_dataset_of_images(num_samples: int, samples_by_char: defaultdict, sentence_generator: SentenceGenerator, min_overlap: float, max_overlap: float, dims: Tuple) -> Tuple[torch.Tensor, torch.Tensor]: + images = torch.zeros((num_samples, IMAGE_HEIGHT, dims[2])) + labels = [] + for n in range(num_samples): + label = sentence_generator.generate() + crop = _construct_image_from_string() |
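
The diff is truncated here by the web view, and several generation helpers (_construct_image_from_string, _select_letter_samples_for_string) are still stubs in this commit. For orientation, the intended usage of the new data module, inferred from the BaseDataModule interface above, would be roughly the sketch below; it is not expected to run against this work-in-progress commit, and the argument values simply mirror the constructor defaults.

from text_recognizer.datasets.emnist_lines import EMNISTLines

# Sketch of the intended API; prepare_data() cannot yet produce a usable HDF5
# file in this commit because the image-construction helpers are still stubs.
data_module = EMNISTLines(
    augment=True,      # training-time augmentations
    batch_size=128,
    max_length=32,     # maximum characters per synthetic line
    min_overlap=0.0,
    max_overlap=0.33,
    num_train=10_000,
)
data_module.prepare_data()   # generate and cache the synthetic lines as HDF5
data_module.setup("fit")     # load the train/val splits
x, y = next(iter(data_module.train_dataloader()))  # x: line images, y: padded label sequences
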