From 45c584098ce2e01741300cf92c21ece47f54fbe3 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sun, 19 Sep 2021 21:03:38 +0200 Subject: Linting of eminst and iam lines --- text_recognizer/data/emnist.py | 3 ++- text_recognizer/data/emnist_mapping.py | 7 +++++++ text_recognizer/data/iam_lines.py | 26 ++++++++++++++++---------- 3 files changed, 25 insertions(+), 11 deletions(-) (limited to 'text_recognizer/data') diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py index c6be123..9ec6efe 100644 --- a/text_recognizer/data/emnist.py +++ b/text_recognizer/data/emnist.py @@ -3,7 +3,7 @@ import json import os from pathlib import Path import shutil -from typing import Callable, Dict, List, Optional, Set, Sequence, Tuple +from typing import Callable, Dict, List, Optional, Sequence, Set, Tuple import zipfile import attr @@ -50,6 +50,7 @@ class EMNIST(BaseDataModule): transform: Callable = attr.ib(init=False, default=T.Compose([T.ToTensor()])) def __attrs_post_init__(self) -> None: + """Post init configuration.""" self.dims = (1, *self.mapping.input_size) def prepare_data(self) -> None: diff --git a/text_recognizer/data/emnist_mapping.py b/text_recognizer/data/emnist_mapping.py index 4406db7..12ac809 100644 --- a/text_recognizer/data/emnist_mapping.py +++ b/text_recognizer/data/emnist_mapping.py @@ -9,6 +9,8 @@ from text_recognizer.data.emnist import emnist_mapping class EmnistMapping(AbstractMapping): + """Mapping for EMNIST labels.""" + def __init__( self, extra_symbols: Optional[Set[str]] = None, lower: bool = True ) -> None: @@ -32,22 +34,27 @@ class EmnistMapping(AbstractMapping): self.mapping = [c for c in self.mapping if not c.isupper()] def get_token(self, index: Union[int, Tensor]) -> str: + """Returns token for index value.""" if (index := int(index)) <= len(self.mapping): return self.mapping[index] raise KeyError(f"Index ({index}) not in mapping.") def get_index(self, token: str) -> Tensor: + """Returns index value of token.""" if token in self.inverse_mapping: return torch.LongTensor([self.inverse_mapping[token]]) raise KeyError(f"Token ({token}) not found in inverse mapping.") def get_text(self, indices: Union[List[int], Tensor]) -> str: + """Returns the text from a list of indices.""" if isinstance(indices, Tensor): indices = indices.tolist() return "".join([self.mapping[index] for index in indices]) def get_indices(self, text: str) -> Tensor: + """Returns tensor of indices for a string.""" return Tensor([self.inverse_mapping[token] for token in text]) def __getitem__(self, x: Union[int, Tensor]) -> str: + """Returns text for a list of indices.""" return self.get_token(x) diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py index aba38f9..d456e64 100644 --- a/text_recognizer/data/iam_lines.py +++ b/text_recognizer/data/iam_lines.py @@ -10,21 +10,21 @@ from typing import List, Sequence, Tuple import attr from loguru import logger as log -from PIL import Image, ImageFile, ImageOps import numpy as np +from PIL import Image, ImageFile, ImageOps from torch import Tensor import torchvision.transforms as T from torchvision.transforms.functional import InterpolationMode +from text_recognizer.data import image_utils +from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info from text_recognizer.data.base_dataset import ( BaseDataset, convert_strings_to_labels, split_dataset, ) -from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info from text_recognizer.data.emnist_mapping import EmnistMapping from text_recognizer.data.iam import IAM -from text_recognizer.data import image_utils ImageFile.LOAD_TRUNCATED_IMAGES = True @@ -82,7 +82,7 @@ class IAMLines(BaseDataModule): x_train, labels_train = load_line_crops_and_labels( "train", PROCESSED_DATA_DIRNAME ) - if self.output_dims[0] < max([len(l) for l in labels_train]) + 2: + if self.output_dims[0] < max([len(labels) for labels in labels_train]) + 2: raise ValueError("Target length longer than max output length.") y_train = convert_strings_to_labels( @@ -101,7 +101,7 @@ class IAMLines(BaseDataModule): "test", PROCESSED_DATA_DIRNAME ) - if self.output_dims[0] < max([len(l) for l in labels_test]) + 2: + if self.output_dims[0] < max([len(labels) for labels in labels_test]) + 2: raise ValueError("Taget length longer than max output length.") y_test = convert_strings_to_labels( @@ -139,10 +139,14 @@ class IAMLines(BaseDataModule): f"{len(self.data_train)}, " f"{len(self.data_val)}, " f"{len(self.data_test)}\n" - f"Train Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n" - f"Train Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n" - f"Test Batch x stats: {(xt.shape, xt.dtype, xt.min(), xt.mean(), xt.std(), xt.max())}\n" - f"Test Batch y stats: {(yt.shape, yt.dtype, yt.min(), yt.max())}\n" + "Train Batch x stats: " + f"{(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n" + "Train Batch y stats: " + f"{(y.shape, y.dtype, y.min(), y.max())}\n" + "Test Batch x stats: " + f"{(xt.shape, xt.dtype, xt.min(), xt.mean(), xt.std(), xt.max())}\n" + "Test Batch y stats: " + f"{(yt.shape, yt.dtype, yt.min(), yt.max())}\n" ) return basic + data @@ -170,6 +174,7 @@ def line_crops_and_labels(iam: IAM, split: str) -> Tuple[List, List]: def save_images_and_labels( crops: Sequence[Image.Image], labels: Sequence[str], split: str, data_dirname: Path ) -> None: + """Saves generated images and labels to disk.""" (data_dirname / split).mkdir(parents=True, exist_ok=True) with (data_dirname / split / "_labels.json").open(mode="w") as f: @@ -200,7 +205,7 @@ def load_line_crops_and_labels(split: str, data_dirname: Path) -> Tuple[List, Li def get_transform(image_width: int, augment: bool = False) -> T.Compose: - """Augment with brigthness, sligth rotation, slant, translation, scale, and Gaussian noise.""" + """Augment with brigthness, rotation, slant, translation, scale, and noise.""" def embed_crop( crop: Image, augment: bool = augment, image_width: int = image_width @@ -245,4 +250,5 @@ def get_transform(image_width: int, augment: bool = False) -> T.Compose: def generate_iam_lines() -> None: + """Displays Iam Lines dataset statistics.""" load_and_print_info(IAMLines) -- cgit v1.2.3-70-g09d2