author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-03-24 22:15:54 +0100
---|---|---
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-03-24 22:15:54 +0100
commit | 8248f173132dfb7e47ec62b08e9235990c8626e3 (patch) |
tree | 2f3ff85602cbc08b7168bf4f0d3924d32a689852 /text_recognizer/datasets/emnist_lines.py |
parent | 74c907a17379688967dc4b3f41a44ba83034f5e0 (diff) |
renamed datasets to data, added iam refactor
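
This commit deletes the module from its old location as part of the `datasets` to `data` package rename. As a minimal sketch of what the change looks like from a caller's side, assuming the module was recreated unchanged under the new package name (the new import path is an assumption based on the commit message and is not shown in this diff):

```python
# Old location, deleted by this commit:
# from text_recognizer.datasets.emnist_lines import EMNISTLines

# Assumed new location after the "datasets" -> "data" rename
# (hypothetical; the new file is not part of this diff):
from text_recognizer.data.emnist_lines import EMNISTLines

dm = EMNISTLines(batch_size=128, num_workers=0)
dm.prepare_data()   # generate the synthetic HDF5 splits if missing
dm.setup("fit")     # load the train/val splits
print(dm)           # __repr__ reports overlaps, dims, and batch stats
```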
Diffstat (limited to 'text_recognizer/datasets/emnist_lines.py')
-rw-r--r-- | text_recognizer/datasets/emnist_lines.py | 280 |
1 file changed, 0 insertions, 280 deletions
```diff
diff --git a/text_recognizer/datasets/emnist_lines.py b/text_recognizer/datasets/emnist_lines.py
deleted file mode 100644
index 9ebad22..0000000
--- a/text_recognizer/datasets/emnist_lines.py
+++ /dev/null
@@ -1,280 +0,0 @@
-"""Dataset of generated text from EMNIST characters."""
-from collections import defaultdict
-from pathlib import Path
-from typing import Callable, Dict, Tuple, Sequence
-
-import h5py
-from loguru import logger
-import numpy as np
-from PIL import Image
-import torch
-from torchvision import transforms
-from torchvision.transforms.functional import InterpolationMode
-
-from text_recognizer.datasets.base_dataset import BaseDataset, convert_strings_to_labels
-from text_recognizer.datasets.base_data_module import (
-    BaseDataModule,
-    load_and_print_info,
-)
-from text_recognizer.datasets.emnist import EMNIST
-from text_recognizer.datasets.sentence_generator import SentenceGenerator
-
-
-DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "emnist_lines"
-ESSENTIALS_FILENAME = (
-    Path(__file__).parents[0].resolve() / "emnist_lines_essentials.json"
-)
-
-SEED = 4711
-IMAGE_HEIGHT = 56
-IMAGE_WIDTH = 1024
-IMAGE_X_PADDING = 28
-MAX_OUTPUT_LENGTH = 89  # Same as IAMLines
-
-
-class EMNISTLines(BaseDataModule):
-    """EMNIST Lines dataset: synthetic handwritten lines dataset made from EMNIST,"""
-
-    def __init__(
-        self,
-        augment: bool = True,
-        batch_size: int = 128,
-        num_workers: int = 0,
-        max_length: int = 32,
-        min_overlap: float = 0.0,
-        max_overlap: float = 0.33,
-        num_train: int = 10_000,
-        num_val: int = 2_000,
-        num_test: int = 2_000,
-    ) -> None:
-        super().__init__(batch_size, num_workers)
-
-        self.augment = augment
-        self.max_length = max_length
-        self.min_overlap = min_overlap
-        self.max_overlap = max_overlap
-        self.num_train = num_train
-        self.num_val = num_val
-        self.num_test = num_test
-
-        self.emnist = EMNIST()
-        self.mapping = self.emnist.mapping
-        max_width = (
-            int(self.emnist.dims[2] * (self.max_length + 1) * (1 - self.min_overlap))
-            + IMAGE_X_PADDING
-        )
-
-        if max_width >= IMAGE_WIDTH:
-            raise ValueError(
-                f"max_width {max_width} greater than IMAGE_WIDTH {IMAGE_WIDTH}"
-            )
-
-        self.dims = (
-            self.emnist.dims[0],
-            IMAGE_HEIGHT,
-            IMAGE_WIDTH
-        )
-
-        if self.max_length >= MAX_OUTPUT_LENGTH:
-            raise ValueError("max_length greater than MAX_OUTPUT_LENGTH")
-
-        self.output_dims = (MAX_OUTPUT_LENGTH, 1)
-        self.data_train = None
-        self.data_val = None
-        self.data_test = None
-
-    @property
-    def data_filename(self) -> Path:
-        """Return name of dataset."""
-        return (
-            DATA_DIRNAME / (f"ml_{self.max_length}_"
-            f"o{self.min_overlap:f}_{self.max_overlap:f}_"
-            f"ntr{self.num_train}_"
-            f"ntv{self.num_val}_"
-            f"nte{self.num_test}.h5")
-        )
-
-    def prepare_data(self) -> None:
-        if self.data_filename.exists():
-            return
-        np.random.seed(SEED)
-        self._generate_data("train")
-        self._generate_data("val")
-        self._generate_data("test")
-
-    def setup(self, stage: str = None) -> None:
-        logger.info("EMNISTLinesDataset loading data from HDF5...")
-        if stage == "fit" or stage is None:
-            print(self.data_filename)
-            with h5py.File(self.data_filename, "r") as f:
-                x_train = f["x_train"][:]
-                y_train = torch.LongTensor(f["y_train"][:])
-                x_val = f["x_val"][:]
-                y_val = torch.LongTensor(f["y_val"][:])
-
-            self.data_train = BaseDataset(
-                x_train, y_train, transform=_get_transform(augment=self.augment)
-            )
-            self.data_val = BaseDataset(
-                x_val, y_val, transform=_get_transform(augment=self.augment)
-            )
-
-        if stage == "test" or stage is None:
-            with h5py.File(self.data_filename, "r") as f:
-                x_test = f["x_test"][:]
-                y_test = torch.LongTensor(f["y_test"][:])
-
-            self.data_test = BaseDataset(
-                x_test, y_test, transform=_get_transform(augment=False)
-            )
-
-    def __repr__(self) -> str:
-        """Return str about dataset."""
-        basic = (
-            "EMNISTLines2 Dataset\n"  # pylint: disable=no-member
-            f"Min overlap: {self.min_overlap}\n"
-            f"Max overlap: {self.max_overlap}\n"
-            f"Num classes: {len(self.mapping)}\n"
-            f"Dims: {self.dims}\n"
-            f"Output dims: {self.output_dims}\n"
-        )
-
-        if not any([self.data_train, self.data_val, self.data_test]):
-            return basic
-
-        x, y = next(iter(self.train_dataloader()))
-        data = (
-            f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n"
-            f"Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n"
-            f"Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n"
-        )
-        return basic + data
-
-    def _generate_data(self, split: str) -> None:
-        logger.info(f"EMNISTLines generating data for {split}...")
-        sentence_generator = SentenceGenerator(
-            self.max_length - 2
-        )  # Subtract by 2 because start/end token
-
-        emnist = self.emnist
-        emnist.prepare_data()
-        emnist.setup()
-
-        if split == "train":
-            samples_by_char = _get_samples_by_char(
-                emnist.x_train, emnist.y_train, emnist.mapping
-            )
-            num = self.num_train
-        elif split == "val":
-            samples_by_char = _get_samples_by_char(
-                emnist.x_train, emnist.y_train, emnist.mapping
-            )
-            num = self.num_val
-        else:
-            samples_by_char = _get_samples_by_char(
-                emnist.x_test, emnist.y_test, emnist.mapping
-            )
-            num = self.num_test
-
-        DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
-        with h5py.File(self.data_filename, "a") as f:
-            x, y = _create_dataset_of_images(
-                num,
-                samples_by_char,
-                sentence_generator,
-                self.min_overlap,
-                self.max_overlap,
-                self.dims,
-            )
-            y = convert_strings_to_labels(
-                y, emnist.inverse_mapping, length=MAX_OUTPUT_LENGTH
-            )
-            f.create_dataset(f"x_{split}", data=x, dtype="u1", compression="lzf")
-            f.create_dataset(f"y_{split}", data=y, dtype="u1", compression="lzf")
-
-
-def _get_samples_by_char(
-    samples: np.ndarray, labels: np.ndarray, mapping: Dict
-) -> defaultdict:
-    samples_by_char = defaultdict(list)
-    for sample, label in zip(samples, labels):
-        samples_by_char[mapping[label]].append(sample)
-    return samples_by_char
-
-
-def _select_letter_samples_for_string(string: str, samples_by_char: defaultdict):
-    null_image = torch.zeros((28, 28), dtype=torch.uint8)
-    sample_image_by_char = {}
-    for char in string:
-        if char in sample_image_by_char:
-            continue
-        samples = samples_by_char[char]
-        sample = samples[np.random.choice(len(samples))] if samples else null_image
-        sample_image_by_char[char] = sample.reshape(28, 28)
-    return [sample_image_by_char[char] for char in string]
-
-
-def _construct_image_from_string(
-    string: str,
-    samples_by_char: defaultdict,
-    min_overlap: float,
-    max_overlap: float,
-    width: int,
-) -> torch.Tensor:
-    overlap = np.random.uniform(min_overlap, max_overlap)
-    sampled_images = _select_letter_samples_for_string(string, samples_by_char)
-    N = len(sampled_images)
-    H, W = sampled_images[0].shape
-    next_overlap_width = W - int(overlap * W)
-    concatenated_image = torch.zeros((H, width), dtype=torch.uint8)
-    x = IMAGE_X_PADDING
-    for image in sampled_images:
-        concatenated_image[:, x : (x + W)] += image
-        x += next_overlap_width
-    return torch.minimum(torch.Tensor([255]), concatenated_image)
-
-
-def _create_dataset_of_images(
-    num_samples: int,
-    samples_by_char: defaultdict,
-    sentence_generator: SentenceGenerator,
-    min_overlap: float,
-    max_overlap: float,
-    dims: Tuple,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    images = torch.zeros((num_samples, IMAGE_HEIGHT, dims[2]))
-    labels = []
-    for n in range(num_samples):
-        label = sentence_generator.generate()
-        crop = _construct_image_from_string(
-            label, samples_by_char, min_overlap, max_overlap, dims[-1]
-        )
-        height = crop.shape[0]
-        y = (IMAGE_HEIGHT - height) // 2
-        images[n, y : (y + height), :] = crop
-        labels.append(label)
-    return images, labels
-
-
-def _get_transform(augment: bool = False) -> Callable:
-    if not augment:
-        return transforms.Compose([transforms.ToTensor()])
-    return transforms.Compose(
-        [
-            transforms.ToTensor(),
-            transforms.ColorJitter(brightness=(0.5, 1.0)),
-            transforms.RandomAffine(
-                degrees=3,
-                translate=(0.0, 0.05),
-                scale=(0.4, 1.1),
-                shear=(-40, 50),
-                interpolation=InterpolationMode.BILINEAR,
-                fill=0,
-            ),
-        ]
-    )
-
-
-def generate_emnist_lines() -> None:
-    """Generates a synthetic handwritten dataset and displays info,"""
-    load_and_print_info(EMNISTLines)
```
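
For reference, the core of the deleted generator is the overlapped pasting in `_construct_image_from_string`: each 28x28 EMNIST glyph is added onto a line-wide canvas, and the x-cursor advances by only `W - int(overlap * W)` pixels per character, so adjacent glyphs overlap by a random fraction drawn from `[min_overlap, max_overlap]`. A minimal standalone sketch of that step, using random tensors in place of real EMNIST crops (the glyph data below is fabricated for illustration):

```python
import numpy as np
import torch

# Random 28x28 "glyphs" stand in for real EMNIST character crops.
glyphs = [torch.randint(0, 256, (28, 28), dtype=torch.uint8) for _ in range(5)]

overlap = np.random.uniform(0.0, 0.33)  # same range as the min/max_overlap defaults
H, W = glyphs[0].shape
step = W - int(overlap * W)             # advance less than W so glyphs overlap

# Accumulate in int32 so the additive paste cannot wrap around in uint8.
canvas = torch.zeros((H, 1024), dtype=torch.int32)
x = 28  # IMAGE_X_PADDING in the original
for glyph in glyphs:
    canvas[:, x : x + W] += glyph.to(torch.int32)  # additive paste, as in the original
    x += step

# Clamp to the valid pixel range and return to uint8; the original used
# torch.minimum against 255 for the same purpose.
canvas = canvas.clamp(max=255).to(torch.uint8)
```

Accumulating in `int32` before clamping sidesteps the `uint8` overflow that additive pasting can otherwise produce in the overlap regions; aside from that, the composition logic mirrors the deleted function.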