Diffstat (limited to 'text_recognizer/data/emnist_lines.py')
-rw-r--r-- | text_recognizer/data/emnist_lines.py | 44 |
1 file changed, 17 insertions, 27 deletions
diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py
index c36132e..63c9f22 100644
--- a/text_recognizer/data/emnist_lines.py
+++ b/text_recognizer/data/emnist_lines.py
@@ -1,12 +1,11 @@
 """Dataset of generated text from EMNIST characters."""
 from collections import defaultdict
 from pathlib import Path
-from typing import Callable, DefaultDict, List, Optional, Tuple, Type
+from typing import Callable, DefaultDict, List, Optional, Tuple

 import h5py
 import numpy as np
 import torch
-import torchvision.transforms as T
 from loguru import logger as log
 from torch import Tensor

@@ -16,17 +15,7 @@ from text_recognizer.data.emnist import EMNIST
 from text_recognizer.data.mappings import EmnistMapping
 from text_recognizer.data.transforms.load_transform import load_transform_from_file
 from text_recognizer.data.utils.sentence_generator import SentenceGenerator
-
-DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "emnist_lines"
-ESSENTIALS_FILENAME = (
-    Path(__file__).parents[0].resolve() / "mappings" / "emnist_lines_essentials.json"
-)
-
-SEED = 4711
-IMAGE_HEIGHT = 56
-IMAGE_WIDTH = 1024
-IMAGE_X_PADDING = 28
-MAX_OUTPUT_LENGTH = 89  # Same as IAMLines
+from text_recognizer.metadata import emnist_lines as metadata


 class EMNISTLines(BaseDataModule):
@@ -70,25 +59,25 @@ class EMNISTLines(BaseDataModule):
         self.emnist = EMNIST()
         max_width = (
             int(self.emnist.dims[2] * (self.max_length + 1) * (1 - self.min_overlap))
-            + IMAGE_X_PADDING
+            + metadata.IMAGE_X_PADDING
         )

-        if max_width >= IMAGE_WIDTH:
+        if max_width >= metadata.IMAGE_WIDTH:
             raise ValueError(
-                f"max_width {max_width} greater than IMAGE_WIDTH {IMAGE_WIDTH}"
+                f"max_width {max_width} greater than IMAGE_WIDTH {metadata.IMAGE_WIDTH}"
             )

-        self.dims = (self.emnist.dims[0], IMAGE_HEIGHT, IMAGE_WIDTH)
+        self.dims = (self.emnist.dims[0], metadata.IMAGE_HEIGHT, metadata.IMAGE_WIDTH)

-        if self.max_length >= MAX_OUTPUT_LENGTH:
+        if self.max_length >= metadata.MAX_OUTPUT_LENGTH:
             raise ValueError("max_length greater than MAX_OUTPUT_LENGTH")

-        self.output_dims = (MAX_OUTPUT_LENGTH, 1)
+        self.output_dims = (metadata.MAX_OUTPUT_LENGTH, 1)

     @property
     def data_filename(self) -> Path:
         """Return name of dataset."""
-        return DATA_DIRNAME / (
+        return metadata.DATA_DIRNAME / (
             f"ml_{self.max_length}_"
             f"o{self.min_overlap:f}_{self.max_overlap:f}_"
             f"ntr{self.num_train}_"
@@ -100,7 +89,7 @@
         """Prepare the dataset."""
         if self.data_filename.exists():
             return
-        np.random.seed(SEED)
+        np.random.seed(metadata.SEED)
         self._generate_data("train")
         self._generate_data("val")
         self._generate_data("test")
@@ -146,7 +135,8 @@
             f"{len(self.data_train)}, "
             f"{len(self.data_val)}, "
             f"{len(self.data_test)}\n"
-            f"Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n"
+            "Batch x stats: "
+            f"{(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n"
             f"Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n"
         )
         return basic + data
@@ -177,7 +167,7 @@
             )
             num = self.num_test

-        DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
+        metadata.PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
         with h5py.File(self.data_filename, "a") as f:
             x, y = _create_dataset_of_images(
                 num,
@@ -188,7 +178,7 @@
                 self.dims,
             )
             y = convert_strings_to_labels(
-                y, self.mapping.inverse_mapping, length=MAX_OUTPUT_LENGTH
+                y, self.mapping.inverse_mapping, length=metadata.MAX_OUTPUT_LENGTH
             )
             f.create_dataset(f"x_{split}", data=x, dtype="u1", compression="lzf")
             f.create_dataset(f"y_{split}", data=y, dtype="u1", compression="lzf")
@@ -229,7 +219,7 @@ def _construct_image_from_string(
     H, W = sampled_images[0].shape
     next_overlap_width = W - int(overlap * W)
     concatenated_image = torch.zeros((H, width), dtype=torch.uint8)
-    x = IMAGE_X_PADDING
+    x = metadata.IMAGE_X_PADDING
     for image in sampled_images:
         concatenated_image[:, x : (x + W)] += image
         x += next_overlap_width
@@ -244,7 +234,7 @@ def _create_dataset_of_images(
     max_overlap: float,
     dims: Tuple,
 ) -> Tuple[Tensor, Tensor]:
-    images = torch.zeros((num_samples, IMAGE_HEIGHT, dims[2]))
+    images = torch.zeros((num_samples, metadata.IMAGE_HEIGHT, dims[2]))
     labels = []
     for n in range(num_samples):
         label = sentence_generator.generate()
@@ -252,7 +242,7 @@
             label, samples_by_char, min_overlap, max_overlap, dims[-1]
         )
         height = crop.shape[0]
-        y = (IMAGE_HEIGHT - height) // 2
+        y = (metadata.IMAGE_HEIGHT - height) // 2
         images[n, y : (y + height), :] = crop
         labels.append(label)
     return images, labels