From 46a1472d33d3a4180798492e819f2ec02bc3b1a3 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm
Date: Sun, 28 Mar 2021 22:02:24 +0200
Subject: Add refactor of iam lines

---
 text_recognizer/data/emnist_lines.py | 32 +++++++++++++++-----------------
 1 file changed, 15 insertions(+), 17 deletions(-)

(limited to 'text_recognizer/data/emnist_lines.py')

diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py
index 6c14add..72665d0 100644
--- a/text_recognizer/data/emnist_lines.py
+++ b/text_recognizer/data/emnist_lines.py
@@ -1,12 +1,11 @@
 """Dataset of generated text from EMNIST characters."""
 from collections import defaultdict
 from pathlib import Path
-from typing import Callable, Dict, Tuple, Sequence
+from typing import Callable, Dict, Tuple

 import h5py
 from loguru import logger
 import numpy as np
-from PIL import Image
 import torch
 from torchvision import transforms
 from torchvision.transforms.functional import InterpolationMode
@@ -58,6 +57,7 @@
         self.num_test = num_test

         self.emnist = EMNIST()
+        # TODO: fix mapping
         self.mapping = self.emnist.mapping
         max_width = (
             int(self.emnist.dims[2] * (self.max_length + 1) * (1 - self.min_overlap))
@@ -66,32 +66,28 @@
         if max_width >= IMAGE_WIDTH:
             raise ValueError(
-                    f"max_width {max_width} greater than IMAGE_WIDTH {IMAGE_WIDTH}"
-                    )
+                f"max_width {max_width} greater than IMAGE_WIDTH {IMAGE_WIDTH}"
+            )

-        self.dims = (
-            self.emnist.dims[0],
-            IMAGE_HEIGHT,
-            IMAGE_WIDTH
-        )
+        self.dims = (self.emnist.dims[0], IMAGE_HEIGHT, IMAGE_WIDTH)

         if self.max_length >= MAX_OUTPUT_LENGTH:
             raise ValueError("max_length greater than MAX_OUTPUT_LENGTH")

         self.output_dims = (MAX_OUTPUT_LENGTH, 1)

-        self.data_train = None
-        self.data_val = None
-        self.data_test = None
+        self.data_train: BaseDataset = None
+        self.data_val: BaseDataset = None
+        self.data_test: BaseDataset = None

     @property
     def data_filename(self) -> Path:
         """Return name of dataset."""
-        return (
-            DATA_DIRNAME / (f"ml_{self.max_length}_"
+        return DATA_DIRNAME / (
+            f"ml_{self.max_length}_"
             f"o{self.min_overlap:f}_{self.max_overlap:f}_"
             f"ntr{self.num_train}_"
             f"ntv{self.num_val}_"
-            f"nte{self.num_test}.h5")
+            f"nte{self.num_test}.h5"
         )

     def prepare_data(self) -> None:
@@ -144,7 +140,10 @@

         x, y = next(iter(self.train_dataloader()))
         data = (
-            f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n"
+            "Train/val/test sizes: "
+            f"{len(self.data_train)}, "
+            f"{len(self.data_val)}, "
+            f"{len(self.data_test)}\n"
             f"Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n"
             f"Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n"
         )
@@ -223,7 +222,6 @@
 ) -> torch.Tensor:
     overlap = np.random.uniform(min_overlap, max_overlap)
     sampled_images = _select_letter_samples_for_string(string, samples_by_char)
-    N = len(sampled_images)
     H, W = sampled_images[0].shape
     next_overlap_width = W - int(overlap * W)
     concatenated_image = torch.zeros((H, width), dtype=torch.uint8)
-- 
cgit v1.2.3-70-g09d2
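
A note on the new dataset-split annotations in the constructor hunk: `self.data_train: BaseDataset = None` assigns `None` to a variable annotated with a non-optional type, which a strict checker such as mypy will reject. A minimal sketch of the `Optional` form, using a hypothetical stand-in for the project's `BaseDataset` class:

    from typing import Optional


    class BaseDataset:
        """Hypothetical stand-in for the project's BaseDataset."""


    class EMNISTLines:
        def __init__(self) -> None:
            # Optional[...] tells the checker that None is a legal initial
            # value until the datamodule populates the splits.
            self.data_train: Optional[BaseDataset] = None
            self.data_val: Optional[BaseDataset] = None
            self.data_test: Optional[BaseDataset] = None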
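
The refactored `data_filename` property relies on implicit f-string concatenation inside a single parenthesized operand of the `/` path operator. A quick sketch of the filename it resolves to, with hypothetical hyperparameter values (the actual value of `DATA_DIRNAME` is not shown in this patch):

    from pathlib import Path

    DATA_DIRNAME = Path("data/processed/emnist_lines")  # assumed location

    max_length, min_overlap, max_overlap = 32, 0.0, 0.33
    num_train, num_val, num_test = 10_000, 2_000, 2_000

    data_filename = DATA_DIRNAME / (
        f"ml_{max_length}_"
        f"o{min_overlap:f}_{max_overlap:f}_"
        f"ntr{num_train}_"
        f"ntv{num_val}_"
        f"nte{num_test}.h5"
    )
    print(data_filename)
    # data/processed/emnist_lines/ml_32_o0.000000_0.330000_ntr10000_ntv2000_nte2000.h5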
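
For context on the last hunk: the surrounding lines suggest `_construct_image_from_string` pastes each H x W glyph `next_overlap_width` pixels after the previous one, so neighbouring glyphs overlap by `int(overlap * W)` pixels. A self-contained sketch of that stitching with illustrative shapes; the paste operation itself is an assumption, since the patch does not show it:

    import torch

    H, W, width, overlap = 28, 28, 952, 0.25  # illustrative values
    glyphs = [torch.randint(0, 256, (H, W), dtype=torch.uint8) for _ in range(4)]

    next_overlap_width = W - int(overlap * W)
    concatenated_image = torch.zeros((H, width), dtype=torch.uint8)
    x = 0
    for glyph in glyphs:
        # Keep the brighter pixel wherever two glyphs overlap.
        concatenated_image[:, x : x + W] = torch.maximum(
            concatenated_image[:, x : x + W], glyph
        )
        x += next_overlap_width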