summary | refs | log | tree | commit | diff
path: root/text_recognizer/data/iam_lines.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/data/iam_lines.py')
-rw-r--r-- text_recognizer/data/iam_lines.py | 26
1 files changed, 16 insertions, 10 deletions
diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py
index aba38f9..d456e64 100644
--- a/text_recognizer/data/iam_lines.py
+++ b/text_recognizer/data/iam_lines.py
@@ -10,21 +10,21 @@ from typing import List, Sequence, Tuple
import attr
from loguru import logger as log
-from PIL import Image, ImageFile, ImageOps
import numpy as np
+from PIL import Image, ImageFile, ImageOps
from torch import Tensor
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
+from text_recognizer.data import image_utils
+from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
from text_recognizer.data.base_dataset import (
BaseDataset,
convert_strings_to_labels,
split_dataset,
)
-from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
from text_recognizer.data.emnist_mapping import EmnistMapping
from text_recognizer.data.iam import IAM
-from text_recognizer.data import image_utils
ImageFile.LOAD_TRUNCATED_IMAGES = True
@@ -82,7 +82,7 @@ class IAMLines(BaseDataModule):
x_train, labels_train = load_line_crops_and_labels(
"train", PROCESSED_DATA_DIRNAME
)
- if self.output_dims[0] < max([len(l) for l in labels_train]) + 2:
+ if self.output_dims[0] < max([len(labels) for labels in labels_train]) + 2:
raise ValueError("Target length longer than max output length.")
y_train = convert_strings_to_labels(
@@ -101,7 +101,7 @@ class IAMLines(BaseDataModule):
"test", PROCESSED_DATA_DIRNAME
)
- if self.output_dims[0] < max([len(l) for l in labels_test]) + 2:
+ if self.output_dims[0] < max([len(labels) for labels in labels_test]) + 2:
raise ValueError("Taget length longer than max output length.")
y_test = convert_strings_to_labels(
@@ -139,10 +139,14 @@ class IAMLines(BaseDataModule):
f"{len(self.data_train)}, "
f"{len(self.data_val)}, "
f"{len(self.data_test)}\n"
- f"Train Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n"
- f"Train Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n"
- f"Test Batch x stats: {(xt.shape, xt.dtype, xt.min(), xt.mean(), xt.std(), xt.max())}\n"
- f"Test Batch y stats: {(yt.shape, yt.dtype, yt.min(), yt.max())}\n"
+ "Train Batch x stats: "
+ f"{(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n"
+ "Train Batch y stats: "
+ f"{(y.shape, y.dtype, y.min(), y.max())}\n"
+ "Test Batch x stats: "
+ f"{(xt.shape, xt.dtype, xt.min(), xt.mean(), xt.std(), xt.max())}\n"
+ "Test Batch y stats: "
+ f"{(yt.shape, yt.dtype, yt.min(), yt.max())}\n"
)
return basic + data
@@ -170,6 +174,7 @@ def line_crops_and_labels(iam: IAM, split: str) -> Tuple[List, List]:
def save_images_and_labels(
crops: Sequence[Image.Image], labels: Sequence[str], split: str, data_dirname: Path
) -> None:
+ """Saves generated images and labels to disk."""
(data_dirname / split).mkdir(parents=True, exist_ok=True)
with (data_dirname / split / "_labels.json").open(mode="w") as f:
@@ -200,7 +205,7 @@ def load_line_crops_and_labels(split: str, data_dirname: Path) -> Tuple[List, Li
def get_transform(image_width: int, augment: bool = False) -> T.Compose:
- """Augment with brigthness, sligth rotation, slant, translation, scale, and Gaussian noise."""
+ """Augment with brigthness, rotation, slant, translation, scale, and noise."""
def embed_crop(
crop: Image, augment: bool = augment, image_width: int = image_width
@@ -245,4 +250,5 @@ def get_transform(image_width: int, augment: bool = False) -> T.Compose:
def generate_iam_lines() -> None:
+ """Displays Iam Lines dataset statistics."""
load_and_print_info(IAMLines)