summaryrefslogtreecommitdiff
path: root/text_recognizer
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-09-19 21:03:38 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-09-19 21:03:38 +0200
commit45c584098ce2e01741300cf92c21ece47f54fbe3 (patch)
tree2add7ac020c93eee401e3d43695ceecfa1218f20 /text_recognizer
parent9b2ecf296b196432a45eca14300e00b78972e44f (diff)
Linting of eminst and iam lines
Diffstat (limited to 'text_recognizer')
-rw-r--r--text_recognizer/data/emnist.py3
-rw-r--r--text_recognizer/data/emnist_mapping.py7
-rw-r--r--text_recognizer/data/iam_lines.py26
3 files changed, 25 insertions, 11 deletions
diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py
index c6be123..9ec6efe 100644
--- a/text_recognizer/data/emnist.py
+++ b/text_recognizer/data/emnist.py
@@ -3,7 +3,7 @@ import json
import os
from pathlib import Path
import shutil
-from typing import Callable, Dict, List, Optional, Set, Sequence, Tuple
+from typing import Callable, Dict, List, Optional, Sequence, Set, Tuple
import zipfile
import attr
@@ -50,6 +50,7 @@ class EMNIST(BaseDataModule):
transform: Callable = attr.ib(init=False, default=T.Compose([T.ToTensor()]))
def __attrs_post_init__(self) -> None:
+ """Post init configuration."""
self.dims = (1, *self.mapping.input_size)
def prepare_data(self) -> None:
diff --git a/text_recognizer/data/emnist_mapping.py b/text_recognizer/data/emnist_mapping.py
index 4406db7..12ac809 100644
--- a/text_recognizer/data/emnist_mapping.py
+++ b/text_recognizer/data/emnist_mapping.py
@@ -9,6 +9,8 @@ from text_recognizer.data.emnist import emnist_mapping
class EmnistMapping(AbstractMapping):
+ """Mapping for EMNIST labels."""
+
def __init__(
self, extra_symbols: Optional[Set[str]] = None, lower: bool = True
) -> None:
@@ -32,22 +34,27 @@ class EmnistMapping(AbstractMapping):
self.mapping = [c for c in self.mapping if not c.isupper()]
def get_token(self, index: Union[int, Tensor]) -> str:
+ """Returns token for index value."""
if (index := int(index)) <= len(self.mapping):
return self.mapping[index]
raise KeyError(f"Index ({index}) not in mapping.")
def get_index(self, token: str) -> Tensor:
+ """Returns index value of token."""
if token in self.inverse_mapping:
return torch.LongTensor([self.inverse_mapping[token]])
raise KeyError(f"Token ({token}) not found in inverse mapping.")
def get_text(self, indices: Union[List[int], Tensor]) -> str:
+ """Returns the text from a list of indices."""
if isinstance(indices, Tensor):
indices = indices.tolist()
return "".join([self.mapping[index] for index in indices])
def get_indices(self, text: str) -> Tensor:
+ """Returns tensor of indices for a string."""
return Tensor([self.inverse_mapping[token] for token in text])
def __getitem__(self, x: Union[int, Tensor]) -> str:
+ """Returns text for a list of indices."""
return self.get_token(x)
diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py
index aba38f9..d456e64 100644
--- a/text_recognizer/data/iam_lines.py
+++ b/text_recognizer/data/iam_lines.py
@@ -10,21 +10,21 @@ from typing import List, Sequence, Tuple
import attr
from loguru import logger as log
-from PIL import Image, ImageFile, ImageOps
import numpy as np
+from PIL import Image, ImageFile, ImageOps
from torch import Tensor
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
+from text_recognizer.data import image_utils
+from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
from text_recognizer.data.base_dataset import (
BaseDataset,
convert_strings_to_labels,
split_dataset,
)
-from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
from text_recognizer.data.emnist_mapping import EmnistMapping
from text_recognizer.data.iam import IAM
-from text_recognizer.data import image_utils
ImageFile.LOAD_TRUNCATED_IMAGES = True
@@ -82,7 +82,7 @@ class IAMLines(BaseDataModule):
x_train, labels_train = load_line_crops_and_labels(
"train", PROCESSED_DATA_DIRNAME
)
- if self.output_dims[0] < max([len(l) for l in labels_train]) + 2:
+ if self.output_dims[0] < max([len(labels) for labels in labels_train]) + 2:
raise ValueError("Target length longer than max output length.")
y_train = convert_strings_to_labels(
@@ -101,7 +101,7 @@ class IAMLines(BaseDataModule):
"test", PROCESSED_DATA_DIRNAME
)
- if self.output_dims[0] < max([len(l) for l in labels_test]) + 2:
+ if self.output_dims[0] < max([len(labels) for labels in labels_test]) + 2:
raise ValueError("Taget length longer than max output length.")
y_test = convert_strings_to_labels(
@@ -139,10 +139,14 @@ class IAMLines(BaseDataModule):
f"{len(self.data_train)}, "
f"{len(self.data_val)}, "
f"{len(self.data_test)}\n"
- f"Train Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n"
- f"Train Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n"
- f"Test Batch x stats: {(xt.shape, xt.dtype, xt.min(), xt.mean(), xt.std(), xt.max())}\n"
- f"Test Batch y stats: {(yt.shape, yt.dtype, yt.min(), yt.max())}\n"
+ "Train Batch x stats: "
+ f"{(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n"
+ "Train Batch y stats: "
+ f"{(y.shape, y.dtype, y.min(), y.max())}\n"
+ "Test Batch x stats: "
+ f"{(xt.shape, xt.dtype, xt.min(), xt.mean(), xt.std(), xt.max())}\n"
+ "Test Batch y stats: "
+ f"{(yt.shape, yt.dtype, yt.min(), yt.max())}\n"
)
return basic + data
@@ -170,6 +174,7 @@ def line_crops_and_labels(iam: IAM, split: str) -> Tuple[List, List]:
def save_images_and_labels(
crops: Sequence[Image.Image], labels: Sequence[str], split: str, data_dirname: Path
) -> None:
+ """Saves generated images and labels to disk."""
(data_dirname / split).mkdir(parents=True, exist_ok=True)
with (data_dirname / split / "_labels.json").open(mode="w") as f:
@@ -200,7 +205,7 @@ def load_line_crops_and_labels(split: str, data_dirname: Path) -> Tuple[List, Li
def get_transform(image_width: int, augment: bool = False) -> T.Compose:
- """Augment with brigthness, sligth rotation, slant, translation, scale, and Gaussian noise."""
+ """Augment with brigthness, rotation, slant, translation, scale, and noise."""
def embed_crop(
crop: Image, augment: bool = augment, image_width: int = image_width
@@ -245,4 +250,5 @@ def get_transform(image_width: int, augment: bool = False) -> T.Compose:
def generate_iam_lines() -> None:
+ """Displays Iam Lines dataset statistics."""
load_and_print_info(IAMLines)