diff options
Diffstat (limited to 'text_recognizer/data/iam_lines.py')
-rw-r--r-- | text_recognizer/data/iam_lines.py | 26 |
1 files changed, 16 insertions, 10 deletions
diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py index aba38f9..d456e64 100644 --- a/text_recognizer/data/iam_lines.py +++ b/text_recognizer/data/iam_lines.py @@ -10,21 +10,21 @@ from typing import List, Sequence, Tuple import attr from loguru import logger as log -from PIL import Image, ImageFile, ImageOps import numpy as np +from PIL import Image, ImageFile, ImageOps from torch import Tensor import torchvision.transforms as T from torchvision.transforms.functional import InterpolationMode +from text_recognizer.data import image_utils +from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info from text_recognizer.data.base_dataset import ( BaseDataset, convert_strings_to_labels, split_dataset, ) -from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info from text_recognizer.data.emnist_mapping import EmnistMapping from text_recognizer.data.iam import IAM -from text_recognizer.data import image_utils ImageFile.LOAD_TRUNCATED_IMAGES = True @@ -82,7 +82,7 @@ class IAMLines(BaseDataModule): x_train, labels_train = load_line_crops_and_labels( "train", PROCESSED_DATA_DIRNAME ) - if self.output_dims[0] < max([len(l) for l in labels_train]) + 2: + if self.output_dims[0] < max([len(labels) for labels in labels_train]) + 2: raise ValueError("Target length longer than max output length.") y_train = convert_strings_to_labels( @@ -101,7 +101,7 @@ class IAMLines(BaseDataModule): "test", PROCESSED_DATA_DIRNAME ) - if self.output_dims[0] < max([len(l) for l in labels_test]) + 2: + if self.output_dims[0] < max([len(labels) for labels in labels_test]) + 2: raise ValueError("Taget length longer than max output length.") y_test = convert_strings_to_labels( @@ -139,10 +139,14 @@ class IAMLines(BaseDataModule): f"{len(self.data_train)}, " f"{len(self.data_val)}, " f"{len(self.data_test)}\n" - f"Train Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n" - f"Train Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n" - f"Test Batch x stats: {(xt.shape, xt.dtype, xt.min(), xt.mean(), xt.std(), xt.max())}\n" - f"Test Batch y stats: {(yt.shape, yt.dtype, yt.min(), yt.max())}\n" + "Train Batch x stats: " + f"{(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n" + "Train Batch y stats: " + f"{(y.shape, y.dtype, y.min(), y.max())}\n" + "Test Batch x stats: " + f"{(xt.shape, xt.dtype, xt.min(), xt.mean(), xt.std(), xt.max())}\n" + "Test Batch y stats: " + f"{(yt.shape, yt.dtype, yt.min(), yt.max())}\n" ) return basic + data @@ -170,6 +174,7 @@ def line_crops_and_labels(iam: IAM, split: str) -> Tuple[List, List]: def save_images_and_labels( crops: Sequence[Image.Image], labels: Sequence[str], split: str, data_dirname: Path ) -> None: + """Saves generated images and labels to disk.""" (data_dirname / split).mkdir(parents=True, exist_ok=True) with (data_dirname / split / "_labels.json").open(mode="w") as f: @@ -200,7 +205,7 @@ def load_line_crops_and_labels(split: str, data_dirname: Path) -> Tuple[List, Li def get_transform(image_width: int, augment: bool = False) -> T.Compose: - """Augment with brigthness, sligth rotation, slant, translation, scale, and Gaussian noise.""" + """Augment with brigthness, rotation, slant, translation, scale, and noise.""" def embed_crop( crop: Image, augment: bool = augment, image_width: int = image_width @@ -245,4 +250,5 @@ def get_transform(image_width: int, augment: bool = False) -> T.Compose: def generate_iam_lines() -> None: + """Displays Iam Lines dataset statistics.""" load_and_print_info(IAMLines) |