author     Gustaf Rydholm <gustaf.rydholm@gmail.com>  2021-03-28 22:02:24 +0200
committer  Gustaf Rydholm <gustaf.rydholm@gmail.com>  2021-03-28 22:02:24 +0200
commit     46a1472d33d3a4180798492e819f2ec02bc3b1a3 (patch)
tree       22322ed0d8f9f803966ea745ec5bb8c759f8db64 /text_recognizer
parent     8248f173132dfb7e47ec62b08e9235990c8626e3 (diff)
Refactor IAM lines dataset
Diffstat (limited to 'text_recognizer')
-rw-r--r--  text_recognizer/data/__init__.py           |   3
-rw-r--r--  text_recognizer/data/base_data_module.py   |   2
-rw-r--r--  text_recognizer/data/base_dataset.py       |  13
-rw-r--r--  text_recognizer/data/emnist.py             |  37
-rw-r--r--  text_recognizer/data/emnist_lines.py       |  32
-rw-r--r--  text_recognizer/data/iam.py                |  39
-rw-r--r--  text_recognizer/data/iam_dataset.py        | 133
-rw-r--r--  text_recognizer/data/iam_lines.py          | 255
-rw-r--r--  text_recognizer/data/iam_lines_dataset.py  | 110
-rw-r--r--  text_recognizer/data/iam_preprocessor.py   |   1
-rw-r--r--  text_recognizer/data/image_utils.py        |  49
-rw-r--r--  text_recognizer/data/sentence_generator.py |   7
-rw-r--r--  text_recognizer/data/transforms.py         | 160
-rw-r--r--  text_recognizer/util.py                    |  52
14 files changed, 391 insertions(+), 502 deletions(-)
diff --git a/text_recognizer/data/__init__.py b/text_recognizer/data/__init__.py
index 2727b20..9a42fa9 100644
--- a/text_recognizer/data/__init__.py
+++ b/text_recognizer/data/__init__.py
@@ -1 +1,4 @@
"""Dataset modules."""
+from .base_dataset import BaseDataset, convert_strings_to_labels, split_dataset
+from .base_data_module import BaseDataModule, load_and_print_info
+from .download_utils import download_dataset
diff --git a/text_recognizer/data/base_data_module.py b/text_recognizer/data/base_data_module.py
index f5e7300..8b5c188 100644
--- a/text_recognizer/data/base_data_module.py
+++ b/text_recognizer/data/base_data_module.py
@@ -7,7 +7,7 @@ from torch.utils.data import DataLoader
def load_and_print_info(data_module_class: type) -> None:
- """Load EMNISTLines and prints info."""
+ """Load dataset and print dataset information."""
dataset = data_module_class()
dataset.prepare_data()
dataset.setup()
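
A one-line usage sketch of the renamed helper, using EMNIST (defined later in this diff) as the data module class; per the function body it instantiates the class, prepares and sets up the data, and prints the module's info:

from text_recognizer.data.base_data_module import load_and_print_info
from text_recognizer.data.emnist import EMNIST

# Instantiate EMNIST(), run prepare_data() and setup(), and print its summary.
load_and_print_info(EMNIST)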
diff --git a/text_recognizer/data/base_dataset.py b/text_recognizer/data/base_dataset.py
index a9e9c24..d00daaf 100644
--- a/text_recognizer/data/base_dataset.py
+++ b/text_recognizer/data/base_dataset.py
@@ -71,3 +71,16 @@ def convert_strings_to_labels(
for j, token in enumerate(tokens):
labels[i, j] = mapping[token]
return labels
+
+
+def split_dataset(
+ dataset: BaseDataset, fraction: float, seed: int
+) -> Tuple[BaseDataset, BaseDataset]:
+ """Split dataset into two parts with fraction * size and (1 - fraction) * size."""
+ if fraction >= 1.0:
+ raise ValueError("Fraction cannot be larger greater or equal to 1.0.")
+ split_1 = int(fraction * len(dataset))
+ split_2 = len(dataset) - split_1
+ return torch.utils.data.random_split(
+ dataset, [split_1, split_2], generator=torch.Generator().manual_seed(seed)
+ )
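
A minimal usage sketch of the new split_dataset helper; since it delegates to torch.utils.data.random_split, any torch Dataset (a toy TensorDataset here) can stand in for BaseDataset:

import torch
from torch.utils.data import TensorDataset

from text_recognizer.data.base_dataset import split_dataset

# A toy 100-item dataset: fraction=0.8 gives an 80/20 split, and the
# fixed seed makes the split reproducible across calls.
dataset = TensorDataset(torch.arange(100))
data_train, data_val = split_dataset(dataset, fraction=0.8, seed=4711)
assert len(data_train) == 80 and len(data_val) == 20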
diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py
index 7f67893..3e10b5f 100644
--- a/text_recognizer/data/emnist.py
+++ b/text_recognizer/data/emnist.py
@@ -1,6 +1,6 @@
"""EMNIST dataset: downloads it from FSDL aws url if not present."""
from pathlib import Path
-from typing import Sequence, Tuple
+from typing import Dict, List, Sequence, Tuple
import json
import os
import shutil
@@ -10,11 +10,9 @@ import h5py
import numpy as np
from loguru import logger
import toml
-import torch
-from torch.utils.data import random_split
from torchvision import transforms
-from text_recognizer.data.base_dataset import BaseDataset
+from text_recognizer.data.base_dataset import BaseDataset, split_dataset
from text_recognizer.data.base_data_module import (
BaseDataModule,
load_and_print_info,
@@ -48,23 +46,18 @@ class EMNIST(BaseDataModule):
self, batch_size: int = 128, num_workers: int = 0, train_fraction: float = 0.8
) -> None:
super().__init__(batch_size, num_workers)
- if not ESSENTIALS_FILENAME.exists():
- _download_and_process_emnist()
- with ESSENTIALS_FILENAME.open() as f:
- essentials = json.load(f)
self.train_fraction = train_fraction
- self.mapping = list(essentials["characters"])
- self.inverse_mapping = {v: k for k, v in enumerate(self.mapping)}
+ self.mapping, self.inverse_mapping, self.input_shape = emnist_mapping()
self.data_train = None
self.data_val = None
self.data_test = None
self.transform = transforms.Compose([transforms.ToTensor()])
- self.dims = (1, *essentials["input_shape"])
+ self.dims = (1, *self.input_shape)
self.output_dims = (1,)
def prepare_data(self) -> None:
if not PROCESSED_DATA_FILENAME.exists():
- _download_and_process_emnist()
+ download_and_process_emnist()
def setup(self, stage: str = None) -> None:
if stage == "fit" or stage is None:
@@ -75,10 +68,8 @@ class EMNIST(BaseDataModule):
dataset_train = BaseDataset(
self.x_train, self.y_train, transform=self.transform
)
- train_size = int(self.train_fraction * len(dataset_train))
- val_size = len(dataset_train) - train_size
- self.data_train, self.data_val = random_split(
- dataset_train, [train_size, val_size], generator=torch.Generator()
+ self.data_train, self.data_val = split_dataset(
+ dataset_train, fraction=self.train_fraction, seed=SEED
)
if stage == "test" or stage is None:
@@ -104,7 +95,19 @@ class EMNIST(BaseDataModule):
return basic + data
-def _download_and_process_emnist() -> None:
+def emnist_mapping() -> Tuple[List, Dict[str, int], List[int]]:
+ """Return the EMNIST mapping."""
+ if not ESSENTIALS_FILENAME.exists():
+ download_and_process_emnist()
+ with ESSENTIALS_FILENAME.open() as f:
+ essentials = json.load(f)
+ mapping = list(essentials["characters"])
+ inverse_mapping = {v: k for k, v in enumerate(mapping)}
+ input_shape = essentials["input_shape"]
+ return mapping, inverse_mapping, input_shape
+
+
+def download_and_process_emnist() -> None:
metadata = toml.load(METADATA_FILENAME)
download_dataset(metadata, DL_DATA_DIRNAME)
_process_raw_dataset(metadata["filename"], DL_DATA_DIRNAME)
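
A short sketch of how the extracted emnist_mapping() helper is consumed; note that it triggers download_and_process_emnist() on first call if the essentials file is missing:

from text_recognizer.data.emnist import emnist_mapping

mapping, inverse_mapping, input_shape = emnist_mapping()
# The inverse mapping takes each character back to its index, which is
# what convert_strings_to_labels relies on for label encoding.
assert all(inverse_mapping[char] == idx for idx, char in enumerate(mapping))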
diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py
index 6c14add..72665d0 100644
--- a/text_recognizer/data/emnist_lines.py
+++ b/text_recognizer/data/emnist_lines.py
@@ -1,12 +1,11 @@
"""Dataset of generated text from EMNIST characters."""
from collections import defaultdict
from pathlib import Path
-from typing import Callable, Dict, Tuple, Sequence
+from typing import Callable, Dict, Tuple
import h5py
from loguru import logger
import numpy as np
-from PIL import Image
import torch
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
@@ -58,6 +57,7 @@ class EMNISTLines(BaseDataModule):
self.num_test = num_test
self.emnist = EMNIST()
+ # TODO: fix mapping
self.mapping = self.emnist.mapping
max_width = (
int(self.emnist.dims[2] * (self.max_length + 1) * (1 - self.min_overlap))
@@ -66,32 +66,28 @@ class EMNISTLines(BaseDataModule):
if max_width >= IMAGE_WIDTH:
raise ValueError(
- f"max_width {max_width} greater than IMAGE_WIDTH {IMAGE_WIDTH}"
- )
+ f"max_width {max_width} greater than IMAGE_WIDTH {IMAGE_WIDTH}"
+ )
- self.dims = (
- self.emnist.dims[0],
- IMAGE_HEIGHT,
- IMAGE_WIDTH
- )
+ self.dims = (self.emnist.dims[0], IMAGE_HEIGHT, IMAGE_WIDTH)
if self.max_length >= MAX_OUTPUT_LENGTH:
raise ValueError("max_length greater than MAX_OUTPUT_LENGTH")
self.output_dims = (MAX_OUTPUT_LENGTH, 1)
- self.data_train = None
- self.data_val = None
- self.data_test = None
+ self.data_train: BaseDataset = None
+ self.data_val: BaseDataset = None
+ self.data_test: BaseDataset = None
@property
def data_filename(self) -> Path:
"""Return name of dataset."""
- return (
- DATA_DIRNAME / (f"ml_{self.max_length}_"
+ return DATA_DIRNAME / (
+ f"ml_{self.max_length}_"
f"o{self.min_overlap:f}_{self.max_overlap:f}_"
f"ntr{self.num_train}_"
f"ntv{self.num_val}_"
- f"nte{self.num_test}.h5")
+ f"nte{self.num_test}.h5"
)
def prepare_data(self) -> None:
@@ -144,7 +140,10 @@ class EMNISTLines(BaseDataModule):
x, y = next(iter(self.train_dataloader()))
data = (
- f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n"
+ "Train/val/test sizes: "
+ f"{len(self.data_train)}, "
+ f"{len(self.data_val)}, "
+ f"{len(self.data_test)}\n"
f"Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n"
f"Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n"
)
@@ -223,7 +222,6 @@ def _construct_image_from_string(
) -> torch.Tensor:
overlap = np.random.uniform(min_overlap, max_overlap)
sampled_images = _select_letter_samples_for_string(string, samples_by_char)
- N = len(sampled_images)
H, W = sampled_images[0].shape
next_overlap_width = W - int(overlap * W)
concatenated_image = torch.zeros((H, width), dtype=torch.uint8)
diff --git a/text_recognizer/data/iam.py b/text_recognizer/data/iam.py
index fcfe9a7..01272ba 100644
--- a/text_recognizer/data/iam.py
+++ b/text_recognizer/data/iam.py
@@ -60,23 +60,36 @@ class IAM(BaseDataModule):
@property
def split_by_id(self) -> Dict[str, str]:
- return {filename.stem: "test" if filename.stem in self.metadata["test_ids"] else "trainval" for filename in self.form_filenames}
+ return {
+ filename.stem: "test"
+ if filename.stem in self.metadata["test_ids"]
+ else "train"
+ for filename in self.form_filenames
+ }
@cachedproperty
def line_strings_by_id(self) -> Dict[str, List[str]]:
"""Return a dict from name of IAM form to list of line texts in it."""
- return {filename.stem: _get_line_strings_from_xml_file(filename) for filename in self.xml_filenames}
+ return {
+ filename.stem: _get_line_strings_from_xml_file(filename)
+ for filename in self.xml_filenames
+ }
@cachedproperty
def line_regions_by_id(self) -> Dict[str, List[Dict[str, int]]]:
"""Return a dict from name IAM form to list of (x1, x2, y1, y2) coordinates of all lines in it."""
- return {filename.stem: _get_line_regions_from_xml_file(filename) for filename in self.xml_filenames}
+ return {
+ filename.stem: _get_line_regions_from_xml_file(filename)
+ for filename in self.xml_filenames
+ }
def __repr__(self) -> str:
"""Return info about the dataset."""
- return ("IAM Dataset\n"
- f"Num forms total: {len(self.xml_filenames)}\n"
- f"Num in test set: {len(self.metadata['test_ids'])}\n")
+ return (
+ "IAM Dataset\n"
+ f"Num forms total: {len(self.xml_filenames)}\n"
+ f"Num in test set: {len(self.metadata['test_ids'])}\n"
+ )
def _extract_raw_dataset(filename: Path, dirname: Path) -> None:
@@ -92,7 +105,7 @@ def _get_line_strings_from_xml_file(filename: str) -> List[str]:
"""Get the text content of each line. Note that we replace &quot: with "."""
xml_root_element = ElementTree.parse(filename).getroot() # nosec
xml_line_elements = xml_root_element.findall("handwritten-part/line")
- return [el.attrib["text"].replace("&quot", '"') for el in xml_line_elements]
+ return [el.attrib["text"].replace("&quot;", '"') for el in xml_line_elements]
def _get_line_regions_from_xml_file(filename: str) -> List[Dict[str, int]]:
@@ -107,13 +120,13 @@ def _get_line_region_from_xml_file(xml_line: Any) -> Dict[str, int]:
x1s = [int(el.attrib["x"]) for el in word_elements]
y1s = [int(el.attrib["y"]) for el in word_elements]
x2s = [int(el.attrib["x"]) + int(el.attrib["width"]) for el in word_elements]
- y2s = [int(el.attrib["x"]) + int(el.attrib["height"]) for el in word_elements]
+ y2s = [int(el.attrib["y"]) + int(el.attrib["height"]) for el in word_elements]
return {
- "x1": min(x1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING,
- "y1": min(y1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING,
- "x2": min(x2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING,
- "y2": min(y2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING,
- }
+ "x1": min(x1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING,
+ "y1": min(y1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING,
+ "x2": max(x2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING,
+ "y2": max(y2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING,
+ }
def download_iam() -> None:
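
A tiny worked example, with hypothetical word boxes, of why the lower-right corner above must use max rather than min:

# Two words on one line, with x-spans [10, 50] and [60, 120].
x2s = [10 + 40, 60 + 60]  # x + width for each word element
# The line's right edge is the largest x2; min(x2s) == 50 would crop
# away the second word entirely.
assert max(x2s) == 120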
diff --git a/text_recognizer/data/iam_dataset.py b/text_recognizer/data/iam_dataset.py
deleted file mode 100644
index a8998b9..0000000
--- a/text_recognizer/data/iam_dataset.py
+++ /dev/null
@@ -1,133 +0,0 @@
-"""Class for loading the IAM dataset, which encompasses both paragraphs and lines, with associated utilities."""
-import os
-from typing import Any, Dict, List
-import zipfile
-
-from boltons.cacheutils import cachedproperty
-import defusedxml.ElementTree as ET
-from loguru import logger
-import toml
-
-from text_recognizer.datasets.util import _download_raw_dataset, DATA_DIRNAME
-
-RAW_DATA_DIRNAME = DATA_DIRNAME / "raw" / "iam"
-METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml"
-EXTRACTED_DATASET_DIRNAME = RAW_DATA_DIRNAME / "iamdb"
-RAW_DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
-
-DOWNSAMPLE_FACTOR = 2 # If images were downsampled, the regions must also be.
-LINE_REGION_PADDING = 0 # Add this many pixels around the exact coordinates.
-
-
-class IamDataset:
- """IAM dataset.
-
- "The IAM Lines dataset, first published at the ICDAR 1999, contains forms of unconstrained handwritten text,
- which were scanned at a resolution of 300dpi and saved as PNG images with 256 gray levels."
- From http://www.fki.inf.unibe.ch/databases/iam-handwriting-database
-
- The data split we will use is
- IAM lines Large Writer Independent Text Line Recognition Task (lwitlrt): 9,862 text lines.
- The validation set has been merged into the train set.
- The train set has 7,101 lines from 326 writers.
- The test set has 1,861 lines from 128 writers.
- The text lines of all data sets are mutually exclusive, thus each writer has contributed to one set only.
-
- """
-
- def __init__(self) -> None:
- self.metadata = toml.load(METADATA_FILENAME)
-
- def load_or_generate_data(self) -> None:
- """Downloads IAM dataset if xml files does not exist."""
- if not self.xml_filenames:
- self._download_iam()
-
- @property
- def xml_filenames(self) -> List:
- """List of xml filenames."""
- return list((EXTRACTED_DATASET_DIRNAME / "xml").glob("*.xml"))
-
- @property
- def form_filenames(self) -> List:
- """List of forms filenames."""
- return list((EXTRACTED_DATASET_DIRNAME / "forms").glob("*.jpg"))
-
- def _download_iam(self) -> None:
- curdir = os.getcwd()
- os.chdir(RAW_DATA_DIRNAME)
- _download_raw_dataset(self.metadata)
- _extract_raw_dataset(self.metadata)
- os.chdir(curdir)
-
- @property
- def form_filenames_by_id(self) -> Dict:
- """Creates a dictionary with filenames as keys and forms as values."""
- return {filename.stem: filename for filename in self.form_filenames}
-
- @cachedproperty
- def line_strings_by_id(self) -> Dict:
- """Return a dict from name of IAM form to a list of line texts in it."""
- return {
- filename.stem: _get_line_strings_from_xml_file(filename)
- for filename in self.xml_filenames
- }
-
- @cachedproperty
- def line_regions_by_id(self) -> Dict:
- """Return a dict from name of IAM form to a list of (x1, x2, y1, y2) coordinates of all lines in it."""
- return {
- filename.stem: _get_line_regions_from_xml_file(filename)
- for filename in self.xml_filenames
- }
-
- def __repr__(self) -> str:
- """Print info about dataset."""
- return "IAM Dataset\n" f"Number of forms: {len(self.xml_filenames)}\n"
-
-
-def _extract_raw_dataset(metadata: Dict) -> None:
- logger.info("Extracting IAM data.")
- with zipfile.ZipFile(metadata["filename"], "r") as zip_file:
- zip_file.extractall()
-
-
-def _get_line_strings_from_xml_file(filename: str) -> List[str]:
- """Get the text content of each line. Note that we replace &quot; with "."""
- xml_root_element = ET.parse(filename).getroot() # nosec
- xml_line_elements = xml_root_element.findall("handwritten-part/line")
- return [el.attrib["text"].replace("&quot;", '"') for el in xml_line_elements]
-
-
-def _get_line_regions_from_xml_file(filename: str) -> List[Dict[str, int]]:
- """Get the line region dict for each line."""
- xml_root_element = ET.parse(filename).getroot() # nosec
- xml_line_elements = xml_root_element.findall("handwritten-part/line")
- return [_get_line_region_from_xml_element(el) for el in xml_line_elements]
-
-
-def _get_line_region_from_xml_element(xml_line: Any) -> Dict[str, int]:
- """Extracts coordinates for each line of text."""
- # TODO: fix input!
- word_elements = xml_line.findall("word/cmp")
- x1s = [int(el.attrib["x"]) for el in word_elements]
- y1s = [int(el.attrib["y"]) for el in word_elements]
- x2s = [int(el.attrib["x"]) + int(el.attrib["width"]) for el in word_elements]
- y2s = [int(el.attrib["y"]) + int(el.attrib["height"]) for el in word_elements]
- return {
- "x1": min(x1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING,
- "y1": min(y1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING,
- "x2": max(x2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING,
- "y2": max(y2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING,
- }
-
-
-def main() -> None:
- """Initializes the dataset and print info about the dataset."""
- dataset = IamDataset()
- dataset.load_or_generate_data()
- print(dataset)
-
-
-if __name__ == "__main__":
- main()
diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py
new file mode 100644
index 0000000..391075a
--- /dev/null
+++ b/text_recognizer/data/iam_lines.py
@@ -0,0 +1,255 @@
+"""Class for IAM Lines dataset.
+
+If the processed dataset does not already exist, it is generated by cropping
+handwritten line regions out of the IAM forms dataset.
+
+"""
+import json
+from pathlib import Path
+import random
+from typing import List, Sequence, Tuple
+
+from loguru import logger
+from PIL import Image, ImageFile, ImageOps
+import numpy as np
+from torch import Tensor
+from torchvision import transforms
+from torchvision.transforms.functional import InterpolationMode
+
+from text_recognizer.data.base_dataset import (
+ BaseDataset,
+ convert_strings_to_labels,
+ split_dataset,
+)
+from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
+from text_recognizer.data.emnist import emnist_mapping
+from text_recognizer.data.iam import IAM
+from text_recognizer.data import image_utils
+
+
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+
+SEED = 4711
+PROCESSED_DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "iam_lines"
+IMAGE_HEIGHT = 56
+IMAGE_WIDTH = 1024
+
+
+class IAMLines(BaseDataModule):
+ """IAM handwritten lines dataset."""
+
+ def __init__(
+ self,
+ augment: bool = True,
+ fraction: float = 0.8,
+ batch_size: int = 128,
+ num_workers: int = 0,
+ ) -> None:
+ # TODO: add transforms
+ super().__init__(batch_size, num_workers)
+ self.augment = augment
+ self.fraction = fraction
+ self.mapping, self.inverse_mapping, _ = emnist_mapping()
+ self.dims = (1, IMAGE_HEIGHT, IMAGE_WIDTH)
+ self.output_dims = (89, 1)
+ self.data_train: BaseDataset = None
+ self.data_val: BaseDataset = None
+ self.data_test: BaseDataset = None
+
+ def prepare_data(self) -> None:
+ """Creates the IAM lines dataset if not existing."""
+ if PROCESSED_DATA_DIRNAME.exists():
+ return
+
+ logger.info("Cropping IAM lines regions...")
+ iam = IAM()
+ iam.prepare_data()
+ crops_train, labels_train = line_crops_and_labels(iam, "train")
+ crops_test, labels_test = line_crops_and_labels(iam, "test")
+
+ shapes = np.array([crop.size for crop in crops_train + crops_test])
+ aspect_ratios = shapes[:, 0] / shapes[:, 1]
+
+ logger.info("Saving images, labels, and statistics...")
+ save_images_and_labels(
+ crops_train, labels_train, "train", PROCESSED_DATA_DIRNAME
+ )
+ save_images_and_labels(crops_test, labels_test, "test", PROCESSED_DATA_DIRNAME)
+
+ with (PROCESSED_DATA_DIRNAME / "_max_aspect_ratio.txt").open(mode="w") as f:
+ f.write(str(aspect_ratios.max()))
+
+ def setup(self, stage: str = None) -> None:
+ """Load data for training/testing."""
+ with (PROCESSED_DATA_DIRNAME / "_max_aspect_ratio.txt").open(mode="r") as f:
+ max_aspect_ratio = float(f.read())
+ image_width = int(IMAGE_HEIGHT * max_aspect_ratio)
+ if image_width >= IMAGE_WIDTH:
+ raise ValueError("image_width equal or greater than IMAGE_WIDTH")
+
+ if stage == "fit" or stage is None:
+ x_train, labels_train = load_line_crops_and_labels(
+ "train", PROCESSED_DATA_DIRNAME
+ )
+ if self.output_dims[0] < max([len(l) for l in labels_train]) + 2:
+ raise ValueError("Target length longer than max output length.")
+
+ y_train = convert_strings_to_labels(
+ labels_train, self.inverse_mapping, length=self.output_dims[0]
+ )
+ data_train = BaseDataset(
+ x_train, y_train, transform=get_transform(IMAGE_WIDTH, self.augment)
+ )
+
+ self.data_train, self.data_val = split_dataset(
+ dataset=data_train, fraction=self.fraction, seed=SEED
+ )
+
+ if stage == "test" or stage is None:
+ x_test, labels_test = load_line_crops_and_labels(
+ "test", PROCESSED_DATA_DIRNAME
+ )
+
+ if self.output_dims[0] < max([len(l) for l in labels_test]) + 2:
+ raise ValueError("Taget length longer than max output length.")
+
+ y_test = convert_strings_to_labels(
+ labels_test, self.inverse_mapping, length=self.output_dims[0]
+ )
+ self.data_test = BaseDataset(
+ x_test, y_test, transform=get_transform(IMAGE_WIDTH)
+ )
+
+ if stage is None:
+ self._verify_output_dims(labels_train, labels_test)
+
+ def _verify_output_dims(self, labels_train: Sequence[str], labels_test: Sequence[str]) -> None:
+ max_label_length = max([len(label) for label in labels_train + labels_test]) + 2
+ output_dims = (max_label_length, 1)
+ if output_dims != self.output_dims:
+ raise ValueError("Output dim does not match expected output dims.")
+
+ def __repr__(self) -> str:
+ """Return information about the dataset."""
+ basic = (
+ "IAM Lines dataset\n"
+ f"Num classes: {len(self.mapping)}\n"
+ f"Input dims: {self.dims}\n"
+ f"Output dims: {self.output_dims}\n"
+ )
+
+ if not any([self.data_train, self.data_val, self.data_test]):
+ return basic
+
+ x, y = next(iter(self.train_dataloader()))
+ xt, yt = next(iter(self.test_dataloader()))
+ data = (
+ "Train/val/test sizes: "
+ f"{len(self.data_train)}, "
+ f"{len(self.data_val)}, "
+ f"{len(self.data_test)}\n"
+ f"Train Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n"
+ f"Train Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n"
+ f"Test Batch x stats: {(xt.shape, xt.dtype, xt.min(), xt.mean(), xt.std(), xt.max())}\n"
+ f"Test Batch y stats: {(yt.shape, yt.dtype, yt.min(), yt.max())}\n"
+ )
+ return basic + data
+
+
+def line_crops_and_labels(iam: IAM, split: str) -> Tuple[List, List]:
+ """Load IAM line labels and regions, and load image crops."""
+ crops = []
+ labels = []
+ for filename in iam.form_filenames:
+ if not iam.split_by_id[filename.stem] == split:
+ continue
+ image = image_utils.read_image_pil(filename)
+ image = ImageOps.grayscale(image)
+ image = ImageOps.invert(image)
+ labels += iam.line_strings_by_id[filename.stem]
+ crops += [
+ image.crop([region[box] for box in ["x1", "y1", "x2", "y2"]])
+ for region in iam.line_regions_by_id[filename.stem]
+ ]
+ if len(crops) != len(labels):
+ raise ValueError("Length of crops does not match length of labels")
+ return crops, labels
+
+
+def save_images_and_labels(
+ crops: Sequence[Image.Image], labels: Sequence[str], split: str, data_dirname: Path
+) -> None:
+ (data_dirname / split).mkdir(parents=True, exist_ok=True)
+
+ with (data_dirname / split / "_labels.json").open(mode="w") as f:
+ json.dump(labels, f)
+
+ for index, crop in enumerate(crops):
+ crop.save(data_dirname / split / f"{index}.png")
+
+
+def load_line_crops_and_labels(split: str, data_dirname: Path) -> Tuple[List, List]:
+ """Load line crops and labels for given split from processed directoru."""
+ with (data_dirname / split / "_labels.json").open(mode="r") as f:
+ labels = json.load(f)
+
+ crop_filenames = sorted(
+ (data_dirname / split).glob("*.png"),
+ key=lambda filename: int(Path(filename).stem),
+ )
+ crops = [
+ image_utils.read_image_pil(filename, grayscale=True)
+ for filename in crop_filenames
+ ]
+
+ if len(crops) != len(labels):
+ raise ValueError("Length of crops does not match length of labels")
+
+ return crops, labels
+
+
+def get_transform(image_width: int, augment: bool = False) -> transforms.Compose:
+ """Augment with brigthness, sligth rotation, slant, translation, scale, and Gaussian noise."""
+
+ def embed_crop(crop: Image.Image, augment: bool = augment, image_width: int = image_width) -> Image.Image:
+ # Crop is PIL.Image of dtype="L" (so value range is [0, 255])
+ image = Image.new("L", (image_width, IMAGE_HEIGHT))
+
+ # Resize crop.
+ crop_width, crop_height = crop.size
+ new_crop_height = IMAGE_HEIGHT
+ new_crop_width = int(new_crop_height * (crop_width / crop_height))
+
+ if augment:
+ # Add random stretching
+ new_crop_width = int(new_crop_width * random.uniform(0.9, 1.1))
+ new_crop_width = min(new_crop_width, image_width)
+ crop_resized = crop.resize(
+ (new_crop_width, new_crop_height), resample=Image.BILINEAR
+ )
+
+ # Embed in image
+ x = min(28, image_width - new_crop_width)
+ y = IMAGE_HEIGHT - new_crop_height
+ image.paste(crop_resized, (x, y))
+
+ return image
+
+ transforms_list = [transforms.Lambda(embed_crop)]
+
+ if augment:
+ transforms_list += [
+ transforms.ColorJitter(brightness=(0.8, 1.6)),
+ transforms.RandomAffine(
+ degrees=1,
+ shear=(-30, 20),
+ interpolation=InterpolationMode.BILINEAR,
+ fill=0,
+ ),
+ ]
+ transforms_list.append(transforms.ToTensor())
+ return transforms.Compose(transforms_list)
+
+
+def generate_iam_lines() -> None:
+ load_and_print_info(IAMLines)
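
A minimal end-to-end sketch of the new data module, assuming the raw IAM data can be downloaded; prepare_data() crops and caches the line images once, setup() builds the splits:

from text_recognizer.data.iam_lines import IAMLines

data_module = IAMLines(augment=True, fraction=0.8)
data_module.prepare_data()  # one-time crop and save under processed/iam_lines
data_module.setup(stage="fit")
x, y = next(iter(data_module.train_dataloader()))
# With the dims above, x should be (B, 1, 56, 1024) and y should be (B, 89).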
diff --git a/text_recognizer/data/iam_lines_dataset.py b/text_recognizer/data/iam_lines_dataset.py
deleted file mode 100644
index 1cb84bd..0000000
--- a/text_recognizer/data/iam_lines_dataset.py
+++ /dev/null
@@ -1,110 +0,0 @@
-"""IamLinesDataset class."""
-from typing import Callable, Dict, List, Optional, Tuple, Union
-
-import h5py
-from loguru import logger
-import torch
-from torch import Tensor
-from torchvision.transforms import ToTensor
-
-from text_recognizer.datasets.dataset import Dataset
-from text_recognizer.datasets.util import (
- compute_sha256,
- DATA_DIRNAME,
- download_url,
- EmnistMapper,
-)
-
-
-PROCESSED_DATA_DIRNAME = DATA_DIRNAME / "processed" / "iam_lines"
-PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "iam_lines.h5"
-PROCESSED_DATA_URL = (
- "https://s3-us-west-2.amazonaws.com/fsdl-public-assets/iam_lines.h5"
-)
-
-
-class IamLinesDataset(Dataset):
- """IAM lines datasets for handwritten text lines."""
-
- def __init__(
- self,
- train: bool = False,
- subsample_fraction: float = None,
- transform: Optional[Callable] = None,
- target_transform: Optional[Callable] = None,
- init_token: Optional[str] = None,
- pad_token: Optional[str] = None,
- eos_token: Optional[str] = None,
- lower: bool = False,
- ) -> None:
- self.pad_token = "_" if pad_token is None else pad_token
-
- super().__init__(
- train=train,
- subsample_fraction=subsample_fraction,
- transform=transform,
- target_transform=target_transform,
- init_token=init_token,
- pad_token=pad_token,
- eos_token=eos_token,
- lower=lower,
- )
-
- @property
- def input_shape(self) -> Tuple:
- """Input shape of the data."""
- return self.data.shape[1:] if self.data is not None else None
-
- @property
- def output_shape(self) -> Tuple:
- """Output shape of the data."""
- return (
- self.targets.shape[1:] + (self.num_classes,)
- if self.targets is not None
- else None
- )
-
- def load_or_generate_data(self) -> None:
- """Load or generate dataset data."""
- if not PROCESSED_DATA_FILENAME.exists():
- PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
- logger.info("Downloading IAM lines...")
- download_url(PROCESSED_DATA_URL, PROCESSED_DATA_FILENAME)
- with h5py.File(PROCESSED_DATA_FILENAME, "r") as f:
- self._data = f[f"x_{self.split}"][:]
- self._targets = f[f"y_{self.split}"][:]
- self._subsample()
-
- def __repr__(self) -> str:
- """Print info about the dataset."""
- return (
- "IAM Lines Dataset\n" # pylint: disable=no-member
- f"Number classes: {self.num_classes}\n"
- f"Mapping: {self.mapper.mapping}\n"
- f"Data: {self.data.shape}\n"
- f"Targets: {self.targets.shape}\n"
- )
-
- def __getitem__(self, index: Union[Tensor, int]) -> Tuple[Tensor, Tensor]:
- """Fetches data, target pair of the dataset for a given and index or indices.
-
- Args:
- index (Union[int, Tensor]): Either a list or int of indices/index.
-
- Returns:
- Tuple[Tensor, Tensor]: Data target pair.
-
- """
- if torch.is_tensor(index):
- index = index.tolist()
-
- data = self.data[index]
- targets = self.targets[index]
-
- if self.transform:
- data = self.transform(data)
-
- if self.target_transform:
- targets = self.target_transform(targets)
-
- return data, targets
diff --git a/text_recognizer/data/iam_preprocessor.py b/text_recognizer/data/iam_preprocessor.py
index a93eb00..5d0fad6 100644
--- a/text_recognizer/data/iam_preprocessor.py
+++ b/text_recognizer/data/iam_preprocessor.py
@@ -5,7 +5,6 @@ The code is mostly stolen from:
https://github.com/facebookresearch/gtn_applications/blob/master/datasets/iamdb.py
"""
-
import collections
import itertools
from pathlib import Path
diff --git a/text_recognizer/data/image_utils.py b/text_recognizer/data/image_utils.py
new file mode 100644
index 0000000..c2b8915
--- /dev/null
+++ b/text_recognizer/data/image_utils.py
@@ -0,0 +1,49 @@
+"""Image util functions for loading and saving images."""
+from pathlib import Path
+from typing import Union
+from urllib.request import urlopen
+
+import cv2
+import numpy as np
+from PIL import Image
+
+
+def read_image_pil(image_uri: Union[Path, str], grayscale: bool = False) -> Image.Image:
+ """Return PIL image."""
+ image = Image.open(image_uri)
+ if grayscale:
+ image = image.convert("L")
+ return image
+
+
+def read_image(image_uri: Union[Path, str], grayscale: bool = False) -> np.ndarray:
+ """Read image_uri."""
+
+ if isinstance(image_uri, str):
+ image_uri = Path(image_uri)
+
+ def read_image_from_filename(image_filename: Path, imread_flag: int) -> np.ndarray:
+ return cv2.imread(str(image_filename), imread_flag)
+
+ def read_image_from_url(image_url: Path, imread_flag: int) -> np.ndarray:
+ url_response = urlopen(str(image_url)) # nosec
+ image_array = np.array(bytearray(url_response.read()), dtype=np.uint8)
+ return cv2.imdecode(image_array, imread_flag)
+
+ imread_flag = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR
+ image = None
+
+ if image_uri.exists():
+ image = read_image_from_filename(image_uri, imread_flag)
+ else:
+ image = read_image_from_url(image_uri, imread_flag)
+
+ if image is None:
+ raise ValueError(f"Could not load image at {image_uri}")
+
+ return image
+
+
+def write_image(image: np.ndarray, filename: Union[Path, str]) -> None:
+ """Write image to file."""
+ cv2.imwrite(str(filename), image)
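
Hypothetical usage of the new image utilities; the paths below are illustrative only and assume the processed crops already exist on disk:

from text_recognizer.data.image_utils import read_image, read_image_pil, write_image

# Existing local paths are read with OpenCV; anything else is treated as a URL.
image = read_image("data/processed/iam_lines/train/0.png", grayscale=True)
write_image(image, "/tmp/crop_copy.png")
pil_image = read_image_pil("data/processed/iam_lines/train/0.png", grayscale=True)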
diff --git a/text_recognizer/data/sentence_generator.py b/text_recognizer/data/sentence_generator.py
index 53b781c..f09703b 100644
--- a/text_recognizer/data/sentence_generator.py
+++ b/text_recognizer/data/sentence_generator.py
@@ -1,5 +1,4 @@
"""Downloading the Brown corpus with NLTK for sentence generating."""
-
import itertools
import re
import string
@@ -9,9 +8,9 @@ import nltk
from nltk.corpus.reader.util import ConcatenatedCorpusView
import numpy as np
-from text_recognizer.datasets.util import DATA_DIRNAME
+from text_recognizer.data.base_data_module import BaseDataModule
-NLTK_DATA_DIRNAME = DATA_DIRNAME / "downloaded" / "nltk"
+NLTK_DATA_DIRNAME = BaseDataModule.data_dirname() / "downloaded" / "nltk"
class SentenceGenerator:
@@ -47,7 +46,7 @@ class SentenceGenerator:
raise ValueError(
"Must provide max_length to this method or when making this object."
)
-
+
for _ in range(10):
try:
index = np.random.randint(0, len(self.word_start_indices) - 1)
diff --git a/text_recognizer/data/transforms.py b/text_recognizer/data/transforms.py
index b6a48f5..2291eec 100644
--- a/text_recognizer/data/transforms.py
+++ b/text_recognizer/data/transforms.py
@@ -1,148 +1,15 @@
"""Transforms for PyTorch datasets."""
from abc import abstractmethod
from pathlib import Path
-import random
from typing import Any, Optional, Union
from loguru import logger
-import numpy as np
-from PIL import Image
import torch
from torch import Tensor
-import torch.nn.functional as F
-from torchvision import transforms
-from torchvision.transforms import (
- ColorJitter,
- Compose,
- Normalize,
- RandomAffine,
- RandomHorizontalFlip,
- RandomRotation,
- ToPILImage,
- ToTensor,
-)
from text_recognizer.datasets.iam_preprocessor import Preprocessor
-from text_recognizer.datasets.util import EmnistMapper
-
-
-class RandomResizeCrop:
- """Image transform with random resize and crop applied.
-
- Stolen from
-
- https://github.com/facebookresearch/gtn_applications/blob/master/datasets/iamdb.py
-
- """
-
- def __init__(self, jitter: int = 10, ratio: float = 0.5) -> None:
- self.jitter = jitter
- self.ratio = ratio
-
- def __call__(self, img: np.ndarray) -> np.ndarray:
- """Applies random crop and rotation to an image."""
- w, h = img.size
-
- # pad with white:
- img = transforms.functional.pad(img, self.jitter, fill=255)
-
- # crop at random (x, y):
- x = self.jitter + random.randint(-self.jitter, self.jitter)
- y = self.jitter + random.randint(-self.jitter, self.jitter)
-
- # randomize aspect ratio:
- size_w = w * random.uniform(1 - self.ratio, 1 + self.ratio)
- size = (h, int(size_w))
- img = transforms.functional.resized_crop(img, y, x, h, w, size)
- return img
-
-
-class Transpose:
- """Transposes the EMNIST image to the correct orientation."""
-
- def __call__(self, image: Image) -> np.ndarray:
- """Swaps axis."""
- return np.array(image).swapaxes(0, 1)
-
-
-class Resize:
- """Resizes a tensor to a specified width."""
-
- def __init__(self, width: int = 952) -> None:
- # The default is 952 because of the IAM dataset.
- self.width = width
-
- def __call__(self, image: Tensor) -> Tensor:
- """Resize tensor in the last dimension."""
- return F.interpolate(image, size=self.width, mode="nearest")
-
-
-class AddTokens:
- """Adds start of sequence and end of sequence tokens to target tensor."""
-
- def __init__(self, pad_token: str, eos_token: str, init_token: str = None) -> None:
- self.init_token = init_token
- self.pad_token = pad_token
- self.eos_token = eos_token
- if self.init_token is not None:
- self.emnist_mapper = EmnistMapper(
- init_token=self.init_token,
- pad_token=self.pad_token,
- eos_token=self.eos_token,
- )
- else:
- self.emnist_mapper = EmnistMapper(
- pad_token=self.pad_token, eos_token=self.eos_token,
- )
- self.pad_value = self.emnist_mapper(self.pad_token)
- self.eos_value = self.emnist_mapper(self.eos_token)
-
- def __call__(self, target: Tensor) -> Tensor:
- """Adds a sos token to the begining and a eos token to the end of a target sequence."""
- dtype, device = target.dtype, target.device
-
- # Find the where padding starts.
- pad_index = torch.nonzero(target == self.pad_value, as_tuple=False)[0].item()
-
- target[pad_index] = self.eos_value
-
- if self.init_token is not None:
- self.sos_value = self.emnist_mapper(self.init_token)
- sos = torch.tensor([self.sos_value], dtype=dtype, device=device)
- target = torch.cat([sos, target], dim=0)
-
- return target
-
-
-class ApplyContrast:
- """Sets everything below a threshold to zero, i.e. increase contrast."""
-
- def __init__(self, low: float = 0.0, high: float = 0.25) -> None:
- self.low = low
- self.high = high
-
- def __call__(self, x: Tensor) -> Tensor:
- """Apply mask binary mask to input tensor."""
- mask = x > np.random.RandomState().uniform(low=self.low, high=self.high)
- return x * mask
-
-
-class Unsqueeze:
- """Add a dimension to the tensor."""
-
- def __call__(self, x: Tensor) -> Tensor:
- """Adds dim."""
- return x.unsqueeze(0)
-
-
-class Squeeze:
- """Removes the first dimension of a tensor."""
-
- def __call__(self, x: Tensor) -> Tensor:
- """Removes first dim."""
- return x.squeeze(0)
-
-
+from text_recognizer.data.emnist import emnist_mapping
+
class ToLower:
"""Converts target to lower case."""
@@ -155,29 +22,14 @@ class ToLower:
class ToCharcters:
"""Converts integers to characters."""
- def __init__(
- self, pad_token: str, eos_token: str, init_token: str = None, lower: bool = True
- ) -> None:
- self.init_token = init_token
- self.pad_token = pad_token
- self.eos_token = eos_token
- if self.init_token is not None:
- self.emnist_mapper = EmnistMapper(
- init_token=self.init_token,
- pad_token=self.pad_token,
- eos_token=self.eos_token,
- lower=lower,
- )
- else:
- self.emnist_mapper = EmnistMapper(
- pad_token=self.pad_token, eos_token=self.eos_token, lower=lower
- )
+ def __init__(self) -> None:
+ self.mapping, _, _ = emnist_mapping()
def __call__(self, y: Tensor) -> str:
"""Converts a Tensor to a str."""
return (
- "".join([self.emnist_mapper(int(i)) for i in y])
- .strip("_")
+ "".join([self.mapping(int(i)) for i in y])
+ .strip("<p>")
.replace(" ", "▁")
)
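
A hedged sketch of the rewritten decode path, assuming the EMNIST mapping contains the plain letters used below and that padding indices render as "<p>":

import torch

from text_recognizer.data.transforms import ToCharcters

decode = ToCharcters()
# Build a tiny target directly from the mapping so the indices are valid.
y = torch.tensor([decode.mapping.index("H"), decode.mapping.index("i")])
print(decode(y))  # -> "Hi"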
diff --git a/text_recognizer/util.py b/text_recognizer/util.py
deleted file mode 100644
index b431e22..0000000
--- a/text_recognizer/util.py
+++ /dev/null
@@ -1,52 +0,0 @@
-"""Utility functions for text_recognizer module."""
-import os
-from pathlib import Path
-from typing import Union
-from urllib.request import urlopen
-
-import cv2
-import numpy as np
-
-
-def read_image(image_uri: Union[Path, str], grayscale: bool = False) -> np.ndarray:
- """Read image_uri."""
-
- def read_image_from_filename(image_filename: str, imread_flag: int) -> np.ndarray:
- return cv2.imread(str(image_filename), imread_flag)
-
- def read_image_from_url(image_url: str, imread_flag: int) -> np.ndarray:
- if image_url.lower().startswith("http"):
- url_response = urlopen(str(image_url))
- image_array = np.array(bytearray(url_response.read()), dtype=np.uint8)
- return cv2.imdecode(image_array, imread_flag)
- else:
- raise ValueError(
- "Url does not start with http, therefore not safe to open..."
- ) from None
-
- imread_flag = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR
- local_file = os.path.exists(image_uri)
- image = None
-
- if local_file:
- image = read_image_from_filename(image_uri, imread_flag)
- else:
- image = read_image_from_url(image_uri, imread_flag)
-
- if image is None:
- raise ValueError(f"Could not load image at {image_uri}")
-
- return image
-
-
-def rescale_image(image: np.ndarray) -> np.ndarray:
- """Rescale image from [0, 1] to [0, 255]."""
- if image.max() <= 1.0:
- image = 255 * (image - image.min()) / (image.max() - image.min())
- return image
-
-
-def write_image(image: np.ndarray, filename: Union[Path, str]) -> None:
- """Write image to file."""
- image = rescale_image(image)
- cv2.imwrite(str(filename), image)