Diffstat (limited to 'text_recognizer/datasets/emnist_lines.py')
-rw-r--r-- | text_recognizer/datasets/emnist_lines.py | 172
1 file changed, 134 insertions, 38 deletions
diff --git a/text_recognizer/datasets/emnist_lines.py b/text_recognizer/datasets/emnist_lines.py
index ae23feb..9ebad22 100644
--- a/text_recognizer/datasets/emnist_lines.py
+++ b/text_recognizer/datasets/emnist_lines.py
@@ -1,16 +1,21 @@
 """Dataset of generated text from EMNIST characters."""
 from collections import defaultdict
 from pathlib import Path
-from typing import Dict, Sequence
+from typing import Callable, Dict, Tuple, Sequence
 
 import h5py
 from loguru import logger
 import numpy as np
+from PIL import Image
 import torch
 from torchvision import transforms
+from torchvision.transforms.functional import InterpolationMode
 
-from text_recognizer.datasets.base_dataset import BaseDataset
-from text_recognizer.datasets.base_data_module import BaseDataModule
+from text_recognizer.datasets.base_dataset import BaseDataset, convert_strings_to_labels
+from text_recognizer.datasets.base_data_module import (
+    BaseDataModule,
+    load_and_print_info,
+)
 
 from text_recognizer.datasets.emnist import EMNIST
 from text_recognizer.datasets.sentence_generator import SentenceGenerator
@@ -54,18 +59,23 @@ class EMNISTLines(BaseDataModule):
         self.emnist = EMNIST()
         self.mapping = self.emnist.mapping
 
-        max_width = int(self.emnist.dims[2] * (self.max_length + 1) * (1 - self.min_overlap)) + IMAGE_X_PADDING
-
-        if max_width <= IMAGE_WIDTH:
-            raise ValueError("max_width greater than IMAGE_WIDTH")
+        max_width = (
+            int(self.emnist.dims[2] * (self.max_length + 1) * (1 - self.min_overlap))
+            + IMAGE_X_PADDING
+        )
+
+        if max_width >= IMAGE_WIDTH:
+            raise ValueError(
+                f"max_width {max_width} greater than IMAGE_WIDTH {IMAGE_WIDTH}"
+            )
 
         self.dims = (
             self.emnist.dims[0],
-            self.emnist.dims[1],
-            self.emnist.dims[2] * self.max_length,
+            IMAGE_HEIGHT,
+            IMAGE_WIDTH,
         )
 
-        if self.max_length <= MAX_OUTPUT_LENGTH:
+        if self.max_length >= MAX_OUTPUT_LENGTH:
             raise ValueError("max_length greater than MAX_OUTPUT_LENGTH")
 
         self.output_dims = (MAX_OUTPUT_LENGTH, 1)
@@ -77,8 +87,11 @@ class EMNISTLines(BaseDataModule):
     def data_filename(self) -> Path:
         """Return name of dataset."""
         return (
-            DATA_DIRNAME
-            / f"ml_{self.max_length}_o{self.min_overlap:f}_{self.max_overlap:f}_ntr{self.num_train}_ntv{self.num_val}_nte{self.num_test}_{self.with_start_end_tokens}.h5"
+            DATA_DIRNAME / (f"ml_{self.max_length}_"
+            f"o{self.min_overlap:f}_{self.max_overlap:f}_"
+            f"ntr{self.num_train}_"
+            f"ntv{self.num_val}_"
+            f"nte{self.num_test}.h5")
         )
 
     def prepare_data(self) -> None:
@@ -92,21 +105,28 @@ class EMNISTLines(BaseDataModule):
     def setup(self, stage: str = None) -> None:
         logger.info("EMNISTLinesDataset loading data from HDF5...")
         if stage == "fit" or stage is None:
+            print(self.data_filename)
             with h5py.File(self.data_filename, "r") as f:
                 x_train = f["x_train"][:]
                 y_train = torch.LongTensor(f["y_train"][:])
                 x_val = f["x_val"][:]
                 y_val = torch.LongTensor(f["y_val"][:])
 
-            self.data_train = BaseDataset(x_train, y_train, transform=_get_transform(augment=self.augment))
-            self.data_val  = BaseDataset(x_val, y_val, transform=_get_transform(augment=self.augment))
+            self.data_train = BaseDataset(
+                x_train, y_train, transform=_get_transform(augment=self.augment)
+            )
+            self.data_val = BaseDataset(
+                x_val, y_val, transform=_get_transform(augment=self.augment)
+            )
 
         if stage == "test" or stage is None:
             with h5py.File(self.data_filename, "r") as f:
                 x_test = f["x_test"][:]
                 y_test = torch.LongTensor(f["y_test"][:])
 
-            self.data_train = BaseDataset(x_test, y_test, transform=_get_transform(augment=False))
+            self.data_test = BaseDataset(
+                x_test, y_test, transform=_get_transform(augment=False)
+            )
 
     def __repr__(self) -> str:
         """Return str about dataset."""
@@ -132,53 +152,129 @@ class EMNISTLines(BaseDataModule):
     def _generate_data(self, split: str) -> None:
         logger.info(f"EMNISTLines generating data for {split}...")
-        sentence_generator = SentenceGenerator(self.max_length - 2)  # Subtract by 2 because start/end token
+        sentence_generator = SentenceGenerator(
+            self.max_length - 2
+        )  # Subtract by 2 because start/end token
 
         emnist = self.emnist
         emnist.prepare_data()
         emnist.setup()
 
         if split == "train":
-            samples_by_char = _get_samples_by_char(emnist.x_train, emnist.y_train, emnist.mapping)
+            samples_by_char = _get_samples_by_char(
+                emnist.x_train, emnist.y_train, emnist.mapping
+            )
             num = self.num_train
         elif split == "val":
-            samples_by_char = _get_samples_by_char(emnist.x_train, emnist.y_train, emnist.mapping)
+            samples_by_char = _get_samples_by_char(
+                emnist.x_train, emnist.y_train, emnist.mapping
+            )
             num = self.num_val
-        elif split == "test":
-            samples_by_char = _get_samples_by_char(emnist.x_test, emnist.y_test, emnist.mapping)
+        else:
+            samples_by_char = _get_samples_by_char(
+                emnist.x_test, emnist.y_test, emnist.mapping
+            )
             num = self.num_test
 
         DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
-        with h5py.File(self.data_filename, "w") as f:
+        with h5py.File(self.data_filename, "a") as f:
             x, y = _create_dataset_of_images(
-                    num, samples_by_char, sentence_generator, self.min_overlap, self.max_overlap, self.dims
-                    )
-            y = _convert_strings_to_labels(
-                    y,
-                    emnist.inverse_mapping,
-                    length=MAX_OUTPUT_LENGTH
-                    )
+                num,
+                samples_by_char,
+                sentence_generator,
+                self.min_overlap,
+                self.max_overlap,
+                self.dims,
+            )
+            y = convert_strings_to_labels(
+                y, emnist.inverse_mapping, length=MAX_OUTPUT_LENGTH
+            )
             f.create_dataset(f"x_{split}", data=x, dtype="u1", compression="lzf")
             f.create_dataset(f"y_{split}", data=y, dtype="u1", compression="lzf")
 
 
-def _get_samples_by_char(samples: np.ndarray, labels: np.ndarray, mapping: Dict) -> defaultdict:
+def _get_samples_by_char(
+    samples: np.ndarray, labels: np.ndarray, mapping: Dict
+) -> defaultdict:
     samples_by_char = defaultdict(list)
     for sample, label in zip(samples, labels):
         samples_by_char[mapping[label]].append(sample)
     return samples_by_char
 
 
-def _construct_image_from_string():
-    pass
-
-
 def _select_letter_samples_for_string(string: str, samples_by_char: defaultdict):
-    pass
-
-
-def _create_dataset_of_images(num_samples: int, samples_by_char: defaultdict, sentence_generator: SentenceGenerator, min_overlap: float, max_overlap: float, dims: Tuple) -> Tuple[torch.Tensor, torch.Tensor]:
+    null_image = torch.zeros((28, 28), dtype=torch.uint8)
+    sample_image_by_char = {}
+    for char in string:
+        if char in sample_image_by_char:
+            continue
+        samples = samples_by_char[char]
+        sample = samples[np.random.choice(len(samples))] if samples else null_image
+        sample_image_by_char[char] = sample.reshape(28, 28)
+    return [sample_image_by_char[char] for char in string]
+
+
+def _construct_image_from_string(
+    string: str,
+    samples_by_char: defaultdict,
+    min_overlap: float,
+    max_overlap: float,
+    width: int,
+) -> torch.Tensor:
+    overlap = np.random.uniform(min_overlap, max_overlap)
+    sampled_images = _select_letter_samples_for_string(string, samples_by_char)
+    N = len(sampled_images)
+    H, W = sampled_images[0].shape
+    next_overlap_width = W - int(overlap * W)
+    concatenated_image = torch.zeros((H, width), dtype=torch.uint8)
+    x = IMAGE_X_PADDING
+    for image in sampled_images:
+        concatenated_image[:, x : (x + W)] += image
+        x += next_overlap_width
+    return torch.minimum(torch.Tensor([255]), concatenated_image)
+
+
+def _create_dataset_of_images(
+    num_samples: int,
+    samples_by_char: defaultdict,
+    sentence_generator: SentenceGenerator,
+    min_overlap: float,
+    max_overlap: float,
+    dims: Tuple,
+) -> Tuple[torch.Tensor, torch.Tensor]:
     images = torch.zeros((num_samples, IMAGE_HEIGHT, dims[2]))
     labels = []
     for n in range(num_samples):
         label = sentence_generator.generate()
-        crop = _construct_image_from_string()
+        crop = _construct_image_from_string(
+            label, samples_by_char, min_overlap, max_overlap, dims[-1]
+        )
+        height = crop.shape[0]
+        y = (IMAGE_HEIGHT - height) // 2
+        images[n, y : (y + height), :] = crop
+        labels.append(label)
+    return images, labels
+
+
+def _get_transform(augment: bool = False) -> Callable:
+    if not augment:
+        return transforms.Compose([transforms.ToTensor()])
+    return transforms.Compose(
+        [
+            transforms.ToTensor(),
+            transforms.ColorJitter(brightness=(0.5, 1.0)),
+            transforms.RandomAffine(
+                degrees=3,
+                translate=(0.0, 0.05),
+                scale=(0.4, 1.1),
+                shear=(-40, 50),
+                interpolation=InterpolationMode.BILINEAR,
+                fill=0,
+            ),
+        ]
+    )
+
+
+def generate_emnist_lines() -> None:
+    """Generate a synthetic handwritten dataset and display info."""
+    load_and_print_info(EMNISTLines)
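
A quick way to sanity-check the rewritten max_width guard in __init__ is to plug in numbers. The sketch below uses illustrative values; IMAGE_X_PADDING, IMAGE_WIDTH, and the real defaults are defined elsewhere in the module, so the figures here are assumptions, not the project's settings:

# Worked example of the max_width guard, with assumed values:
glyph_width = 28          # self.emnist.dims[2] for EMNIST crops
max_length = 32           # illustrative, not the module default
min_overlap = 0.0
IMAGE_X_PADDING = 2       # assumed; the real constant lives in this module

max_width = int(glyph_width * (max_length + 1) * (1 - min_overlap)) + IMAGE_X_PADDING
print(max_width)          # 28 * 33 + 2 = 926; must be < IMAGE_WIDTH or __init__ raises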
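The switch from mode "w" to mode "a" in _generate_data matters because prepare_data calls it once per split against the same HDF5 file: write mode would truncate the file on every call, leaving only the last split. A minimal standalone illustration with a throwaway file name:

import h5py
import numpy as np

# Mode "a" accumulates datasets across opens; mode "w" would truncate each time.
for split in ("train", "val", "test"):
    with h5py.File("demo.h5", "a") as f:
        f.create_dataset(f"x_{split}", data=np.zeros((4, 28, 28), dtype="u1"),
                         dtype="u1", compression="lzf")

with h5py.File("demo.h5", "r") as f:
    print(list(f.keys()))  # ['x_test', 'x_train', 'x_val'] -- all three splits survive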
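_construct_image_from_string advances the paste position by W - int(overlap * W) pixels per glyph, so adjacent characters share int(overlap * W) columns and the += accumulates ink in the shared region before the final clamp to 255. A standalone sketch of that placement arithmetic, with a fixed overlap instead of the sampled one, an assumed padding value, and an int32 canvas to avoid uint8 wraparound before clamping:

import torch

# Placement arithmetic from _construct_image_from_string, simplified:
W = 28                              # EMNIST glyph width
overlap = 0.3                       # fixed here; the module samples uniform(min, max)
step = W - int(overlap * W)         # 28 - 8 = 20 px advance per glyph
pad = 2                             # stand-in for IMAGE_X_PADDING (assumed value)
num_chars = 5

canvas = torch.zeros((28, pad + num_chars * W), dtype=torch.int32)
x = pad
for _ in range(num_chars):
    glyph = torch.randint(0, 129, (28, 28), dtype=torch.int32)  # stand-in glyph
    canvas[:, x : x + W] += glyph   # overlapping columns accumulate
    x += step
canvas = canvas.clamp(max=255).to(torch.uint8)  # same effect as the torch.minimum call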
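convert_strings_to_labels is imported from base_dataset and not shown in this diff; the call site only pins down its contract (strings in, fixed-length integer arrays of length MAX_OUTPUT_LENGTH out, indexed via emnist.inverse_mapping). A purely illustrative sketch of that contract, with hypothetical <S>/<E>/<P> token names that may not match the real mapping:

from typing import Dict, Sequence
import numpy as np

def convert_strings_to_labels_sketch(
    strings: Sequence[str], mapping: Dict[str, int], length: int
) -> np.ndarray:
    # Hypothetical reconstruction: wrap each string in start/end tokens, pad the rest.
    labels = np.full((len(strings), length), mapping["<P>"], dtype=np.uint8)
    for i, string in enumerate(strings):
        tokens = ["<S>", *string, "<E>"]
        for j, token in enumerate(tokens):
            labels[i, j] = mapping[token]
    return labels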
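The augmenting branch of _get_transform runs ToTensor first, so ColorJitter and RandomAffine operate on a float tensor in [0, 1] (both accept tensor input in recent torchvision). A usage sketch on a stand-in line image; the 56x1024 shape is an assumption for IMAGE_HEIGHT x IMAGE_WIDTH, not the module's actual constants:

import numpy as np
from text_recognizer.datasets.emnist_lines import _get_transform  # assumes this module path

line = (np.random.rand(56, 1024) * 255).astype("uint8")  # stand-in for one stored x row
transform = _get_transform(augment=True)
augmented = transform(line)              # ToTensor -> ColorJitter -> RandomAffine
print(augmented.shape, augmented.dtype)  # torch.Size([1, 56, 1024]) torch.float32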