Diffstat (limited to 'text_recognizer/data/emnist_lines.py')
-rw-r--r-- | text_recognizer/data/emnist_lines.py | 55
1 file changed, 16 insertions(+), 39 deletions(-)
diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py
index 3ff8a54..1a64931 100644
--- a/text_recognizer/data/emnist_lines.py
+++ b/text_recognizer/data/emnist_lines.py
@@ -1,7 +1,7 @@
 """Dataset of generated text from EMNIST characters."""
 from collections import defaultdict
 from pathlib import Path
-from typing import Callable, List, Tuple
+from typing import DefaultDict, List, Tuple
 
 import attr
 import h5py
@@ -9,8 +9,7 @@ from loguru import logger as log
 import numpy as np
 import torch
 from torch import Tensor
-from torchvision import transforms
-from torchvision.transforms.functional import InterpolationMode
+import torchvision.transforms as T
 
 from text_recognizer.data.base_data_module import (
     BaseDataModule,
@@ -18,12 +17,13 @@ from text_recognizer.data.base_data_module import (
 )
 from text_recognizer.data.base_dataset import BaseDataset, convert_strings_to_labels
 from text_recognizer.data.emnist import EMNIST
-from text_recognizer.data.sentence_generator import SentenceGenerator
+from text_recognizer.data.utils.sentence_generator import SentenceGenerator
+from text_recognizer.data.transforms.load_transform import load_transform_from_file
 
 DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "emnist_lines"
 ESSENTIALS_FILENAME = (
-    Path(__file__).parents[0].resolve() / "emnist_lines_essentials.json"
+    Path(__file__).parents[0].resolve() / "mappings" / "emnist_lines_essentials.json"
 )
 
 SEED = 4711
@@ -37,7 +37,6 @@ MAX_OUTPUT_LENGTH = 89  # Same as IAMLines
 class EMNISTLines(BaseDataModule):
     """EMNIST Lines dataset: synthetic handwritten lines dataset made from EMNIST."""
 
-    augment: bool = attr.ib(default=True)
     max_length: int = attr.ib(default=128)
     min_overlap: float = attr.ib(default=0.0)
     max_overlap: float = attr.ib(default=0.33)
@@ -98,21 +97,15 @@ class EMNISTLines(BaseDataModule):
                 x_val = f["x_val"][:]
                 y_val = torch.LongTensor(f["y_val"][:])
 
-            self.data_train = BaseDataset(
-                x_train, y_train, transform=_get_transform(augment=self.augment)
-            )
-            self.data_val = BaseDataset(
-                x_val, y_val, transform=_get_transform(augment=self.augment)
-            )
+            self.data_train = BaseDataset(x_train, y_train, transform=self.transform)
+            self.data_val = BaseDataset(x_val, y_val, transform=self.transform)
 
         if stage == "test" or stage is None:
             with h5py.File(self.data_filename, "r") as f:
                 x_test = f["x_test"][:]
                 y_test = torch.LongTensor(f["y_test"][:])
 
-            self.data_test = BaseDataset(
-                x_test, y_test, transform=_get_transform(augment=False)
-            )
+            self.data_test = BaseDataset(x_test, y_test, transform=self.test_transform)
 
     def __repr__(self) -> str:
         """Return str about dataset."""
@@ -129,6 +122,7 @@ class EMNISTLines(BaseDataModule):
             return basic
 
         x, y = next(iter(self.train_dataloader()))
+        x = x[0] if isinstance(x, list) else x
         data = (
             "Train/val/test sizes: "
             f"{len(self.data_train)}, "
@@ -184,7 +178,7 @@ def _get_samples_by_char(
     samples: np.ndarray, labels: np.ndarray, mapping: List
-) -> defaultdict:
+) -> DefaultDict:
     samples_by_char = defaultdict(list)
     for sample, label in zip(samples, labels):
         samples_by_char[mapping[label]].append(sample)
@@ -192,7 +186,7 @@ def _select_letter_samples_for_string(
-    string: str, samples_by_char: defaultdict
+    string: str, samples_by_char: DefaultDict
 ) -> List[Tensor]:
     null_image = torch.zeros((28, 28), dtype=torch.uint8)
     sample_image_by_char = {}
@@ -207,7 +201,7 @@ def _construct_image_from_string(
     string: str,
-    samples_by_char: defaultdict,
+    samples_by_char: DefaultDict,
     min_overlap: float,
     max_overlap: float,
     width: int,
@@ -226,7 +220,7 @@ def _construct_image_from_string(
 
 def _create_dataset_of_images(
     num_samples: int,
-    samples_by_char: defaultdict,
+    samples_by_char: DefaultDict,
     sentence_generator: SentenceGenerator,
     min_overlap: float,
     max_overlap: float,
@@ -246,25 +240,8 @@ def _create_dataset_of_images(
     return images, labels
 
 
-def _get_transform(augment: bool = False) -> Callable:
-    if not augment:
-        return transforms.Compose([transforms.ToTensor()])
-    return transforms.Compose(
-        [
-            transforms.ToTensor(),
-            transforms.ColorJitter(brightness=(0.5, 1.0)),
-            transforms.RandomAffine(
-                degrees=3,
-                translate=(0.0, 0.05),
-                scale=(0.4, 1.1),
-                shear=(-40, 50),
-                interpolation=InterpolationMode.BILINEAR,
-                fill=0,
-            ),
-        ]
-    )
-
-
 def generate_emnist_lines() -> None:
     """Generates a synthetic handwritten dataset and displays info."""
-    load_and_print_info(EMNISTLines)
+    transform = load_transform_from_file("transform/emnist_lines.yaml")
+    test_transform = load_transform_from_file("test_transform/default.yaml")
+    load_and_print_info(EMNISTLines(transform=transform, test_transform=test_transform))
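Note on the new transform loading: this diff drops the hard-coded _get_transform helper and instead passes transform/test_transform objects built by load_transform_from_file, whose implementation is not part of this change. The sketch below is only a guess at what such a loader could look like, assuming a PyYAML-backed config that maps torchvision transform class names to their keyword arguments; the TRANSFORM_DIRNAME constant and the example config contents are assumptions for illustration, not taken from the repository.

    # Hypothetical sketch of a YAML-driven transform loader; the real
    # text_recognizer.data.transforms.load_transform.load_transform_from_file
    # may differ.
    from pathlib import Path
    from typing import Callable

    import torchvision.transforms as T
    import yaml

    # Assumed location of the config files referenced as
    # "transform/emnist_lines.yaml" and "test_transform/default.yaml".
    TRANSFORM_DIRNAME = Path(__file__).parent / "transforms"


    def load_transform_from_file(filename: str) -> Callable:
        """Compose torchvision transforms described in a YAML file."""
        with (TRANSFORM_DIRNAME / filename).open("r") as f:
            config = yaml.safe_load(f)
        # Example config, roughly equivalent to the deleted augmentation in
        # _get_transform (ToTensor, ColorJitter, small RandomAffine):
        #   ToTensor: {}
        #   ColorJitter: {brightness: [0.5, 1.0]}
        #   RandomAffine: {degrees: 3, translate: [0.0, 0.05],
        #                  scale: [0.4, 1.1], shear: [-40, 50], fill: 0}
        transforms = [
            getattr(T, name)(**(kwargs or {})) for name, kwargs in config.items()
        ]
        return T.Compose(transforms)

Under that assumption, the augmentations previously gated by the removed augment flag would live in transform/emnist_lines.yaml, while test_transform/default.yaml would reduce to a bare ToTensor, which matches how setup() now applies self.transform to train/val splits and self.test_transform to the test split.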