summaryrefslogtreecommitdiff
path: root/text_recognizer/data
diff options
context:
space:
mode:
author Gustaf Rydholm <gustaf.rydholm@gmail.com> 2021-04-03 21:59:07 +0200
committer Gustaf Rydholm <gustaf.rydholm@gmail.com> 2021-04-03 21:59:07 +0200
commit 07f2cc3665a1a60efe8ed8073cad6ac4f344b2c2 (patch)
tree d24ae8e3b9b39bfcfb3b850b30cb966eb3b064a7 /text_recognizer/data
parent 3196144ec99e803cef218295ddea592748931c57 (diff)
Add IAM paragraphs dataset
Diffstat (limited to 'text_recognizer/data')
-rw-r--r-- text_recognizer/data/emnist.py 10
-rw-r--r-- text_recognizer/data/iam_lines.py 4
-rw-r--r-- text_recognizer/data/iam_paragraphs.py 310
-rw-r--r-- text_recognizer/data/iam_preprocessor.py 6
-rw-r--r-- text_recognizer/data/make_wordpieces.py 2
-rw-r--r-- text_recognizer/data/transforms.py 11
6 files changed, 330 insertions, 13 deletions
diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py
index 3e10b5f..eda490a 100644
--- a/text_recognizer/data/emnist.py
+++ b/text_recognizer/data/emnist.py
@@ -1,6 +1,6 @@
"""EMNIST dataset: downloads it from FSDL aws url if not present."""
from pathlib import Path
-from typing import Dict, List, Sequence, Tuple
+from typing import Dict, List, Optional, Sequence, Tuple
import json
import os
import shutil
@@ -52,7 +52,7 @@ class EMNIST(BaseDataModule):
self.data_val = None
self.data_test = None
self.transform = transforms.Compose([transforms.ToTensor()])
- self.dims = (1, * self.input_shape)
+ self.dims = (1, *self.input_shape)
self.output_dims = (1,)
def prepare_data(self) -> None:
@@ -95,13 +95,17 @@ class EMNIST(BaseDataModule):
return basic + data
-def emnist_mapping() -> Tuple[List, Dict[str, int], List[int]]:
+def emnist_mapping(
+    extra_symbols: Optional[List[str]] = None,
+) -> Tuple[List, Dict[str, int], List[int]]:
-    """Return the EMNIST mapping."""
+    """Return the EMNIST mapping.
+
+    The essentials file is downloaded and processed first if it does not exist.
+
+    Args:
+        extra_symbols: Optional symbols appended after the EMNIST characters.
+            Defaults to None so that existing zero-argument callers keep
+            working.
+
+    Returns:
+        A tuple of the list of characters, the inverse mapping from character
+        to index, and the input shape stored in the essentials file.
+    """
     if not ESSENTIALS_FILENAME.exists():
         download_and_process_emnist()
     with ESSENTIALS_FILENAME.open() as f:
         essentials = json.load(f)
     mapping = list(essentials["characters"])
+    if extra_symbols is not None:
+        mapping += extra_symbols
     inverse_mapping = {v: k for k, v in enumerate(mapping)}
     input_shape = essentials["input_shape"]
     return mapping, inverse_mapping, input_shape
diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py
index 391075a..78bc8e1 100644
--- a/text_recognizer/data/iam_lines.py
+++ b/text_recognizer/data/iam_lines.py
@@ -211,7 +211,9 @@ def load_line_crops_and_labels(split: str, data_dirname: Path) -> Tuple[List, Li
def get_transform(image_width: int, augment: bool = False) -> transforms.Compose:
"""Augment with brigthness, sligth rotation, slant, translation, scale, and Gaussian noise."""
- def embed_crop(crop: Image, augment: bool = augment, image_width: int = image_width) -> Image:
+ def embed_crop(
+ crop: Image, augment: bool = augment, image_width: int = image_width
+ ) -> Image:
# Crop is PIL.Image of dtype="L" (so value range is [0, 255])
image = Image.new("L", (image_width, IMAGE_HEIGHT))
diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py
new file mode 100644
index 0000000..402a8d4
--- /dev/null
+++ b/text_recognizer/data/iam_paragraphs.py
@@ -0,0 +1,310 @@
+"""IAM Paragraphs Dataset class."""
+import json
+from pathlib import Path
+from typing import Dict, List, Optional, Sequence, Tuple
+
+from loguru import logger
+import numpy as np
+from PIL import Image, ImageFile, ImageOps
+import torch
+import torchvision.transforms as transforms
+from torchvision.transforms.functional import InterpolationMode
+from tqdm import tqdm
+
+from text_recognizer.data.base_dataset import (
+ BaseDataset,
+ convert_strings_to_labels,
+ split_dataset,
+)
+from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
+from text_recognizer.data.emnist import emnist_mapping
+from text_recognizer.data.iam import IAM
+
+
+# Directory where the paragraph crops, labels, and properties are saved.
+PROCESSED_DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "iam_paragraphs"
+
+# Separator used when the line strings of a paragraph are joined into a label.
+NEW_LINE_TOKEN = "\n"
+
+# Seed handed to split_dataset for the train/validation split.
+SEED = 4711
+# Factor by which the saved crops are resized when they are loaded in setup.
+IMAGE_SCALE_FACTOR = 2
+# Input image size: 1152 x 1280 divided by IMAGE_SCALE_FACTOR.
+IMAGE_HEIGHT = 1152 // IMAGE_SCALE_FACTOR
+IMAGE_WIDTH = 1280 // IMAGE_SCALE_FACTOR
+# Target sequence length; _validate_data_dims checks it against the longest
+# label length plus 2.
+MAX_LABEL_LENGTH = 682
+
+
+class IAMParagraphs(BaseDataModule):
+    """IAM handwriting database paragraphs.
+
+    prepare_data crops the paragraph region out of the IAM forms and saves the
+    crops and labels under PROCESSED_DATA_DIRNAME; setup loads them as train,
+    validation, and test datasets.
+
+    Args:
+        batch_size: Passed on to BaseDataModule.
+        num_workers: Passed on to BaseDataModule.
+        train_fraction: Fraction handed to split_dataset when the train split
+            is divided into training and validation data.
+        augment: Whether the training data uses the augmenting transform.
+    """
+
+    def __init__(
+        self,
+        batch_size: int = 128,
+        num_workers: int = 0,
+        train_fraction: float = 0.8,
+        augment: bool = True,
+    ) -> None:
+        super().__init__(batch_size, num_workers)
+        # TODO: pass in transform and target transform
+        # TODO: pass in mapping
+        self.augment = augment
+        # EMNIST characters extended with the token that joins paragraph lines.
+        self.mapping, self.inverse_mapping, _ = emnist_mapping(
+            extra_symbols=[NEW_LINE_TOKEN]
+        )
+        self.train_fraction = train_fraction
+
+        self.dims = (1, IMAGE_HEIGHT, IMAGE_WIDTH)
+        self.output_dims = (MAX_LABEL_LENGTH, 1)
+        # The datasets stay None until setup has run for the matching stage.
+        self.data_train: Optional[BaseDataset] = None
+        self.data_val: Optional[BaseDataset] = None
+        self.data_test: Optional[BaseDataset] = None
+
+    def prepare_data(self) -> None:
+        """Create data for training/testing.
+
+        Does nothing when PROCESSED_DATA_DIRNAME already exists. Otherwise the
+        paragraph crops and labels of the train and test splits are saved,
+        followed by a _properties.json file with the crop shape, label length,
+        and number of lines of every paragraph.
+        """
+        if PROCESSED_DATA_DIRNAME.exists():
+            return
+
+        logger.info(
+            "Cropping IAM paragraph regions and saving them along with labels..."
+        )
+
+        iam = IAM()
+        iam.prepare_data()
+
+        properties = {}
+        for split in ["train", "test"]:
+            crops, labels = _get_paragraph_crops_and_labels(iam=iam, split=split)
+            _save_crops_and_labels(crops=crops, labels=labels, split=split)
+
+            properties.update(
+                {
+                    id_: {
+                        # PIL sizes are (width, height); stored as (height, width).
+                        "crop_shape": crops[id_].size[::-1],
+                        "label_length": len(label),
+                        "num_lines": _num_lines(label),
+                    }
+                    for id_, label in labels.items()
+                }
+            )
+
+        with (PROCESSED_DATA_DIRNAME / "_properties.json").open("w") as f:
+            json.dump(properties, f, indent=4)
+
+    def setup(self, stage: Optional[str] = None) -> None:
+        """Loads the data for training/testing.
+
+        Args:
+            stage: "fit" loads the train and validation data, "test" loads the
+                test data, and None loads all of them.
+
+        Raises:
+            ValueError: If the dimensions fail _validate_data_dims.
+        """
+
+        def _load_dataset(split: str, augment: bool) -> BaseDataset:
+            crops, labels = _load_processed_crops_and_labels(split)
+            data = [_resize_image(crop, IMAGE_SCALE_FACTOR) for crop in crops]
+            targets = convert_strings_to_labels(
+                strings=labels, mapping=self.inverse_mapping, length=self.output_dims[0]
+            )
+            return BaseDataset(
+                data,
+                targets,
+                transform=_get_transform(image_shape=self.dims[1:], augment=augment),
+            )
+
+        logger.info(f"Loading IAM paragraph regions and lines for {stage}...")
+        _validate_data_dims(input_dims=self.dims, output_dims=self.output_dims)
+
+        if stage == "fit" or stage is None:
+            data_train = _load_dataset(split="train", augment=self.augment)
+            self.data_train, self.data_val = split_dataset(
+                dataset=data_train, fraction=self.train_fraction, seed=SEED
+            )
+
+        if stage == "test" or stage is None:
+            # The test data is never augmented.
+            self.data_test = _load_dataset(split="test", augment=False)
+
+    def __repr__(self) -> str:
+        """Return information about the dataset.
+
+        Split sizes and batch statistics are only included once the train,
+        validation, and test data have all been loaded.
+        """
+        basic = (
+            "IAM Paragraphs Dataset\n"
+            f"Num classes: {len(self.mapping)}\n"
+            f"Input dims: {self.dims}\n"
+            f"Output dims: {self.output_dims}\n"
+        )
+
+        splits = [self.data_train, self.data_val, self.data_test]
+        # The statistics below need every split; without the None check,
+        # len(None) raises TypeError when only one stage has been set up.
+        if not any(splits) or any(split is None for split in splits):
+            return basic
+
+        x, y = next(iter(self.train_dataloader()))
+        xt, yt = next(iter(self.test_dataloader()))
+        data = (
+            "Train/val/test sizes: "
+            f"{len(self.data_train)}, "
+            f"{len(self.data_val)}, "
+            f"{len(self.data_test)}\n"
+            f"Train Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n"
+            f"Train Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n"
+            f"Test Batch x stats: {(xt.shape, xt.dtype, xt.min(), xt.mean(), xt.std(), xt.max())}\n"
+            f"Test Batch y stats: {(yt.shape, yt.dtype, yt.min(), yt.max())}\n"
+        )
+        return basic + data
+
+
+def _get_dataset_properties() -> Dict:
+    """Summarize the per-paragraph properties stored in _properties.json.
+
+    Returns:
+        A dict holding the minimum and maximum of the label length, number of
+        lines, crop shape, and aspect ratio over all stored paragraphs.
+    """
+    properties_filename = PROCESSED_DATA_DIRNAME / "_properties.json"
+    with properties_filename.open("r") as f:
+        entries = list(json.load(f).values())
+
+    def _column(key: str) -> List:
+        return [entry[key] for entry in entries]
+
+    # Crop shapes are rows of two values; the ratio is the second over the first.
+    crop_shapes = np.array(_column("crop_shape"))
+    aspect_ratio = crop_shapes[:, 1] / crop_shapes[:, 0]
+
+    label_lengths = _column("label_length")
+    line_counts = _column("num_lines")
+
+    summary = {}
+    summary["label_length"] = {"min": min(label_lengths), "max": max(label_lengths)}
+    summary["num_lines"] = {"min": min(line_counts), "max": max(line_counts)}
+    summary["crop_shape"] = {
+        "min": crop_shapes.min(axis=0),
+        "max": crop_shapes.max(axis=0),
+    }
+    summary["aspect_ratio"] = {
+        "min": aspect_ratio.min(axis=0),
+        "max": aspect_ratio.max(axis=0),
+    }
+    return summary
+
+
+def _validate_data_dims(
+    input_dims: Optional[Tuple[int, ...]], output_dims: Optional[Tuple[int, ...]]
+) -> None:
+    """Validates input and output dimensions against the properties of the dataset.
+
+    Args:
+        input_dims: Image dimensions; index 1 is compared with the first and
+            index 2 with the second entry of the largest crop shape. None
+            skips the check.
+        output_dims: Target dimensions; index 0 is the target length. None
+            skips the check.
+
+    Raises:
+        ValueError: If the image is smaller than the largest rescaled crop in
+            either dimension, or if the target length is shorter than the
+            longest label plus 2.
+    """
+    properties = _get_dataset_properties()
+
+    max_image_shape = properties["crop_shape"]["max"] / IMAGE_SCALE_FACTOR
+    # The image has to cover the largest crop in both dimensions, so falling
+    # short in either one is an error.
+    if input_dims is not None and (
+        input_dims[1] < max_image_shape[0] or input_dims[2] < max_image_shape[1]
+    ):
+        raise ValueError(f"{input_dims} less than {max_image_shape}")
+
+    # NOTE(review): the extra 2 presumably leaves room for start and end
+    # tokens - confirm against convert_strings_to_labels.
+    min_output_length = properties["label_length"]["max"] + 2
+    if output_dims is not None and output_dims[0] < min_output_length:
+        raise ValueError(f"{output_dims} less than {min_output_length}")
+
+
+def _resize_image(image: Image.Image, scale_factor: int) -> Image.Image:
+    """Resize an image by dividing its width and height by scale_factor.
+
+    The image is returned untouched when scale_factor is 1; otherwise bilinear
+    resampling is used.
+    """
+    if scale_factor != 1:
+        target_size = (image.width // scale_factor, image.height // scale_factor)
+        return image.resize(target_size, resample=Image.BILINEAR)
+    return image
+
+
+def _get_paragraph_crops_and_labels(
+    iam: IAM, split: str
+) -> Tuple[Dict[str, Image.Image], Dict[str, str]]:
+    """Crop the paragraph and build the label of every form in a split.
+
+    Args:
+        iam: IAM data providing the form filenames, split ids, line regions,
+            and line strings.
+        split: Name of the split; forms belonging to other splits are skipped.
+
+    Returns:
+        Two dicts keyed by form id: the paragraph crops and the labels.
+
+    Raises:
+        ValueError: If the number of crops and labels differ.
+    """
+    crops = {}
+    labels = {}
+    progress = tqdm(iam.form_filenames, desc=f"Processing {split} paragraphs")
+    for form_filename in progress:
+        id_ = form_filename.stem
+        if iam.split_by_id[id_] != split:
+            continue
+
+        # The form is turned into an inverted grayscale image before cropping.
+        image = ImageOps.invert(ImageOps.grayscale(Image.open(form_filename)))
+
+        # The paragraph box is the bounding box around all line regions.
+        line_regions = iam.line_regions_by_id[id_]
+        left = min(region["x1"] for region in line_regions)
+        upper = min(region["y1"] for region in line_regions)
+        right = max(region["x2"] for region in line_regions)
+        lower = max(region["y2"] for region in line_regions)
+        paragraph_box = [left, upper, right, lower]
+        lines = iam.line_strings_by_id[id_]
+
+        crops[id_] = image.crop(paragraph_box)
+        labels[id_] = NEW_LINE_TOKEN.join(lines)
+
+    num_crops, num_labels = len(crops), len(labels)
+    if num_crops != num_labels:
+        raise ValueError(f"Crops ({num_crops}) does not match labels ({num_labels})")
+
+    return crops, labels
+
+
+def _save_crops_and_labels(
+    crops: Dict[str, Image.Image], labels: Dict[str, str], split: str
+) -> None:
+    """Write the labels file and one image file per crop for a split.
+
+    The split directory is created first if it does not exist yet.
+    """
+    split_dirname = PROCESSED_DATA_DIRNAME / split
+    split_dirname.mkdir(parents=True, exist_ok=True)
+
+    labels_filename = _labels_filename(split)
+    with labels_filename.open("w") as f:
+        json.dump(labels, f, indent=4)
+
+    for id_ in crops:
+        crops[id_].save(_crop_filename(id_, split))
+
+
+def _load_processed_crops_and_labels(
+    split: str,
+) -> Tuple[Sequence[Image.Image], Sequence[str]]:
+    """Read the saved crops and labels of a split, ordered by sorted id.
+
+    Returns:
+        The crops, converted to mode "L", and the labels in matching order.
+
+    Raises:
+        ValueError: If the number of crops and labels differ.
+    """
+    with _labels_filename(split).open("r") as f:
+        labels = json.load(f)
+
+    ordered_crops = []
+    ordered_labels = []
+    for id_ in sorted(labels.keys()):
+        crop = Image.open(_crop_filename(id_, split)).convert("L")
+        ordered_crops.append(crop)
+        ordered_labels.append(labels[id_])
+
+    if len(ordered_crops) == len(ordered_labels):
+        return ordered_crops, ordered_labels
+    raise ValueError(
+        f"Crops ({len(ordered_crops)}) does not match labels ({len(ordered_labels)})"
+    )
+
+
+def _get_transform(image_shape: Tuple[int, int], augment: bool) -> transforms.Compose:
+    """Build the image transform pipeline.
+
+    With augment, the image is randomly cropped to image_shape (padded with
+    zeros if needed), its brightness is jittered, and a small random rotation
+    and shear are applied. Without it, the image is only center cropped. Both
+    pipelines end with a conversion to a tensor.
+    """
+    if not augment:
+        pipeline = [transforms.CenterCrop(image_shape)]
+    else:
+        random_crop = transforms.RandomCrop(
+            size=image_shape,
+            padding=None,
+            pad_if_needed=True,
+            fill=0,
+            padding_mode="constant",
+        )
+        brightness = transforms.ColorJitter(brightness=(0.8, 1.6))
+        affine = transforms.RandomAffine(
+            degrees=1, shear=(-10, 10), interpolation=InterpolationMode.BILINEAR,
+        )
+        pipeline = [random_crop, brightness, affine]
+    return transforms.Compose([*pipeline, transforms.ToTensor()])
+
+
+def _labels_filename(split: str) -> Path:
+    """Build the path of the labels file inside the directory of a split."""
+    split_dirname = PROCESSED_DATA_DIRNAME / split
+    return split_dirname / "_labels.json"
+
+
+def _crop_filename(id: str, split: str) -> Path:
+    """Build the path of the PNG file holding the crop with the given id."""
+    split_dirname = PROCESSED_DATA_DIRNAME / split
+    return split_dirname / f"{id}.png"
+
+
+def _num_lines(label: str) -> int:
+    """Count the lines of text in label; a label without newlines has one."""
+    return len(label.split("\n"))
+
+
+def create_iam_paragraphs() -> None:
+    """Pass the IAMParagraphs data module class to load_and_print_info."""
+    load_and_print_info(IAMParagraphs)
diff --git a/text_recognizer/data/iam_preprocessor.py b/text_recognizer/data/iam_preprocessor.py
index a47aeed..3844419 100644
--- a/text_recognizer/data/iam_preprocessor.py
+++ b/text_recognizer/data/iam_preprocessor.py
@@ -166,7 +166,11 @@ def cli(
"""CLI for extracting text data from the iam dataset."""
if data_dir is None:
data_dir = (
- Path(__file__).resolve().parents[2] / "data" / "downloaded" / "iam" / "iamdb"
+ Path(__file__).resolve().parents[2]
+ / "data"
+ / "downloaded"
+ / "iam"
+ / "iamdb"
)
logger.debug(f"Using data dir: {data_dir}")
if not data_dir.exists():
diff --git a/text_recognizer/data/make_wordpieces.py b/text_recognizer/data/make_wordpieces.py
index e062c4c..ef9eb1b 100644
--- a/text_recognizer/data/make_wordpieces.py
+++ b/text_recognizer/data/make_wordpieces.py
@@ -99,7 +99,7 @@ def cli(
"""CLI for training the sentence piece model."""
if data_dir is None:
data_dir = (
- Path(__file__).resolve().parents[2] / "data" / "processed" / "iam_lines"
+ Path(__file__).resolve().parents[2] / "data" / "processed" / "iam_lines"
)
logger.debug(f"Using data dir: {data_dir}")
if not data_dir.exists():
diff --git a/text_recognizer/data/transforms.py b/text_recognizer/data/transforms.py
index 2291eec..616e236 100644
--- a/text_recognizer/data/transforms.py
+++ b/text_recognizer/data/transforms.py
@@ -9,7 +9,8 @@ from torch import Tensor
from text_recognizer.datasets.iam_preprocessor import Preprocessor
from text_recognizer.data.emnist import emnist_mapping
-
+
+
class ToLower:
"""Converts target to lower case."""
@@ -23,15 +24,11 @@ class ToCharcters:
     """Converts integers to characters."""
     def __init__(self) -> None:
-        self.mapping, _, _ = emnist_mapping()
+        # No extra symbols are needed; passed explicitly because
+        # emnist_mapping takes an extra_symbols parameter.
+        self.mapping, _, _ = emnist_mapping(extra_symbols=None)
     def __call__(self, y: Tensor) -> str:
         """Converts a Tensor to a str."""
-        return (
-            "".join([self.mapping(int(i)) for i in y])
-            .strip("<p>")
-            .replace(" ", "▁")
-        )
+        # The mapping is a list of characters, so it is indexed, not called.
+        # NOTE(review): strip("<p>") removes any leading or trailing "<", "p",
+        # and ">" characters, not only whole "<p>" tokens - confirm intent.
+        return "".join([self.mapping[int(i)] for i in y]).strip("<p>").replace(" ", "▁")
class WordPieces: