diff options
Diffstat (limited to 'text_recognizer/datasets/emnist_lines.py')
-rw-r--r-- | text_recognizer/datasets/emnist_lines.py | 172 |
1 files changed, 134 insertions, 38 deletions
diff --git a/text_recognizer/datasets/emnist_lines.py b/text_recognizer/datasets/emnist_lines.py index ae23feb..9ebad22 100644 --- a/text_recognizer/datasets/emnist_lines.py +++ b/text_recognizer/datasets/emnist_lines.py @@ -1,16 +1,21 @@ """Dataset of generated text from EMNIST characters.""" from collections import defaultdict from pathlib import Path -from typing import Dict, Sequence +from typing import Callable, Dict, Tuple, Sequence import h5py from loguru import logger import numpy as np +from PIL import Image import torch from torchvision import transforms +from torchvision.transforms.functional import InterpolationMode -from text_recognizer.datasets.base_dataset import BaseDataset -from text_recognizer.datasets.base_data_module import BaseDataModule +from text_recognizer.datasets.base_dataset import BaseDataset, convert_strings_to_labels +from text_recognizer.datasets.base_data_module import ( + BaseDataModule, + load_and_print_info, +) from text_recognizer.datasets.emnist import EMNIST from text_recognizer.datasets.sentence_generator import SentenceGenerator @@ -54,18 +59,23 @@ class EMNISTLines(BaseDataModule): self.emnist = EMNIST() self.mapping = self.emnist.mapping - max_width = int(self.emnist.dims[2] * (self.max_length + 1) * (1 - self.min_overlap)) + IMAGE_X_PADDING - - if max_width <= IMAGE_WIDTH: - raise ValueError("max_width greater than IMAGE_WIDTH") + max_width = ( + int(self.emnist.dims[2] * (self.max_length + 1) * (1 - self.min_overlap)) + + IMAGE_X_PADDING + ) + + if max_width >= IMAGE_WIDTH: + raise ValueError( + f"max_width {max_width} greater than IMAGE_WIDTH {IMAGE_WIDTH}" + ) self.dims = ( self.emnist.dims[0], - self.emnist.dims[1], - self.emnist.dims[2] * self.max_length, + IMAGE_HEIGHT, + IMAGE_WIDTH ) - if self.max_length <= MAX_OUTPUT_LENGTH: + if self.max_length >= MAX_OUTPUT_LENGTH: raise ValueError("max_length greater than MAX_OUTPUT_LENGTH") self.output_dims = (MAX_OUTPUT_LENGTH, 1) @@ -77,8 +87,11 @@ class EMNISTLines(BaseDataModule): def data_filename(self) -> Path: """Return name of dataset.""" return ( - DATA_DIRNAME - / f"ml_{self.max_length}_o{self.min_overlap:f}_{self.max_overlap:f}_ntr{self.num_train}_ntv{self.num_val}_nte{self.num_test}_{self.with_start_end_tokens}.h5" + DATA_DIRNAME / (f"ml_{self.max_length}_" + f"o{self.min_overlap:f}_{self.max_overlap:f}_" + f"ntr{self.num_train}_" + f"ntv{self.num_val}_" + f"nte{self.num_test}.h5") ) def prepare_data(self) -> None: @@ -92,21 +105,28 @@ class EMNISTLines(BaseDataModule): def setup(self, stage: str = None) -> None: logger.info("EMNISTLinesDataset loading data from HDF5...") if stage == "fit" or stage is None: + print(self.data_filename) with h5py.File(self.data_filename, "r") as f: x_train = f["x_train"][:] y_train = torch.LongTensor(f["y_train"][:]) x_val = f["x_val"][:] y_val = torch.LongTensor(f["y_val"][:]) - self.data_train = BaseDataset(x_train, y_train, transform=_get_transform(augment=self.augment)) - self.data_val = BaseDataset(x_val, y_val, transform=_get_transform(augment=self.augment)) + self.data_train = BaseDataset( + x_train, y_train, transform=_get_transform(augment=self.augment) + ) + self.data_val = BaseDataset( + x_val, y_val, transform=_get_transform(augment=self.augment) + ) if stage == "test" or stage is None: with h5py.File(self.data_filename, "r") as f: x_test = f["x_test"][:] y_test = torch.LongTensor(f["y_test"][:]) - self.data_train = BaseDataset(x_test, y_test, transform=_get_transform(augment=False)) + self.data_test = BaseDataset( + x_test, y_test, transform=_get_transform(augment=False) + ) def __repr__(self) -> str: """Return str about dataset.""" @@ -132,53 +152,129 @@ class EMNISTLines(BaseDataModule): def _generate_data(self, split: str) -> None: logger.info(f"EMNISTLines generating data for {split}...") - sentence_generator = SentenceGenerator(self.max_length - 2) # Subtract by 2 because start/end token + sentence_generator = SentenceGenerator( + self.max_length - 2 + ) # Subtract by 2 because start/end token emnist = self.emnist emnist.prepare_data() emnist.setup() if split == "train": - samples_by_char = _get_samples_by_char(emnist.x_train, emnist.y_train, emnist.mapping) + samples_by_char = _get_samples_by_char( + emnist.x_train, emnist.y_train, emnist.mapping + ) num = self.num_train elif split == "val": - samples_by_char = _get_samples_by_char(emnist.x_train, emnist.y_train, emnist.mapping) + samples_by_char = _get_samples_by_char( + emnist.x_train, emnist.y_train, emnist.mapping + ) num = self.num_val - elif split == "test": - samples_by_char = _get_samples_by_char(emnist.x_test, emnist.y_test, emnist.mapping) + else: + samples_by_char = _get_samples_by_char( + emnist.x_test, emnist.y_test, emnist.mapping + ) num = self.num_test DATA_DIRNAME.mkdir(parents=True, exist_ok=True) - with h5py.File(self.data_filename, "w") as f: + with h5py.File(self.data_filename, "a") as f: x, y = _create_dataset_of_images( - num, samples_by_char, sentence_generator, self.min_overlap, self.max_overlap, self.dims - ) - y = _convert_strings_to_labels( - y, - emnist.inverse_mapping, - length=MAX_OUTPUT_LENGTH - ) + num, + samples_by_char, + sentence_generator, + self.min_overlap, + self.max_overlap, + self.dims, + ) + y = convert_strings_to_labels( + y, emnist.inverse_mapping, length=MAX_OUTPUT_LENGTH + ) f.create_dataset(f"x_{split}", data=x, dtype="u1", compression="lzf") f.create_dataset(f"y_{split}", data=y, dtype="u1", compression="lzf") -def _get_samples_by_char(samples: np.ndarray, labels: np.ndarray, mapping: Dict) -> defaultdict: + +def _get_samples_by_char( + samples: np.ndarray, labels: np.ndarray, mapping: Dict +) -> defaultdict: samples_by_char = defaultdict(list) for sample, label in zip(samples, labels): samples_by_char[mapping[label]].append(sample) return samples_by_char -def _construct_image_from_string(): - pass - - def _select_letter_samples_for_string(string: str, samples_by_char: defaultdict): - pass - - -def _create_dataset_of_images(num_samples: int, samples_by_char: defaultdict, sentence_generator: SentenceGenerator, min_overlap: float, max_overlap: float, dims: Tuple) -> Tuple[torch.Tensor, torch.Tensor]: + null_image = torch.zeros((28, 28), dtype=torch.uint8) + sample_image_by_char = {} + for char in string: + if char in sample_image_by_char: + continue + samples = samples_by_char[char] + sample = samples[np.random.choice(len(samples))] if samples else null_image + sample_image_by_char[char] = sample.reshape(28, 28) + return [sample_image_by_char[char] for char in string] + + +def _construct_image_from_string( + string: str, + samples_by_char: defaultdict, + min_overlap: float, + max_overlap: float, + width: int, +) -> torch.Tensor: + overlap = np.random.uniform(min_overlap, max_overlap) + sampled_images = _select_letter_samples_for_string(string, samples_by_char) + N = len(sampled_images) + H, W = sampled_images[0].shape + next_overlap_width = W - int(overlap * W) + concatenated_image = torch.zeros((H, width), dtype=torch.uint8) + x = IMAGE_X_PADDING + for image in sampled_images: + concatenated_image[:, x : (x + W)] += image + x += next_overlap_width + return torch.minimum(torch.Tensor([255]), concatenated_image) + + +def _create_dataset_of_images( + num_samples: int, + samples_by_char: defaultdict, + sentence_generator: SentenceGenerator, + min_overlap: float, + max_overlap: float, + dims: Tuple, +) -> Tuple[torch.Tensor, torch.Tensor]: images = torch.zeros((num_samples, IMAGE_HEIGHT, dims[2])) labels = [] for n in range(num_samples): label = sentence_generator.generate() - crop = _construct_image_from_string() + crop = _construct_image_from_string( + label, samples_by_char, min_overlap, max_overlap, dims[-1] + ) + height = crop.shape[0] + y = (IMAGE_HEIGHT - height) // 2 + images[n, y : (y + height), :] = crop + labels.append(label) + return images, labels + + +def _get_transform(augment: bool = False) -> Callable: + if not augment: + return transforms.Compose([transforms.ToTensor()]) + return transforms.Compose( + [ + transforms.ToTensor(), + transforms.ColorJitter(brightness=(0.5, 1.0)), + transforms.RandomAffine( + degrees=3, + translate=(0.0, 0.05), + scale=(0.4, 1.1), + shear=(-40, 50), + interpolation=InterpolationMode.BILINEAR, + fill=0, + ), + ] + ) + + +def generate_emnist_lines() -> None: + """Generates a synthetic handwritten dataset and displays info,""" + load_and_print_info(EMNISTLines) |