Diffstat (limited to 'src/text_recognizer/datasets')
-rw-r--r--  src/text_recognizer/datasets/__init__.py                12
-rw-r--r--  src/text_recognizer/datasets/emnist_dataset.py          32
-rw-r--r--  src/text_recognizer/datasets/emnist_lines_dataset.py    38
-rw-r--r--  src/text_recognizer/datasets/iam_dataset.py            134
-rw-r--r--  src/text_recognizer/datasets/iam_lines_dataset.py      126
-rw-r--r--  src/text_recognizer/datasets/iam_paragraphs_dataset.py 322
-rw-r--r--  src/text_recognizer/datasets/util.py                    99
7 files changed, 685 insertions, 78 deletions
diff --git a/src/text_recognizer/datasets/__init__.py b/src/text_recognizer/datasets/__init__.py
index 05f74f6..ede4541 100644
--- a/src/text_recognizer/datasets/__init__.py
+++ b/src/text_recognizer/datasets/__init__.py
@@ -10,15 +10,23 @@ from .emnist_lines_dataset import (
EmnistLinesDataset,
get_samples_by_character,
)
-from .util import fetch_data_loaders, Transpose
+from .iam_dataset import IamDataset
+from .iam_lines_dataset import IamLinesDataset
+from .iam_paragraphs_dataset import IamParagraphsDataset
+from .util import _download_raw_dataset, compute_sha256, download_url, Transpose
__all__ = [
+ "_download_raw_dataset",
+ "compute_sha256",
"construct_image_from_string",
"DATA_DIRNAME",
+ "download_url",
"EmnistDataset",
"EmnistMapper",
"EmnistLinesDataset",
- "fetch_data_loaders",
"get_samples_by_character",
+ "IamDataset",
+ "IamLinesDataset",
+ "IamParagraphsDataset",
"Transpose",
]
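With fetch_data_loaders gone, the package now re-exports the dataset classes and download helpers directly. A minimal import sketch, assuming the package is installed and the new modules are importable:

    from text_recognizer.datasets import (
        IamDataset,
        IamLinesDataset,
        IamParagraphsDataset,
        compute_sha256,
        download_url,
    )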
diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py
index 49ebad3..0715aae 100644
--- a/src/text_recognizer/datasets/emnist_dataset.py
+++ b/src/text_recognizer/datasets/emnist_dataset.py
@@ -152,8 +152,7 @@ class EmnistDataset(Dataset):
"""Loads the dataset and the mappings.
Args:
- train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to
- False.
+ train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to False.
sample_to_balance (bool): Resamples the dataset to make it balanced. Defaults to False.
subsample_fraction (float): Fraction of the dataset to load. Defaults to None.
transform (Optional[Callable]): Transform(s) for input data. Defaults to None.
@@ -181,17 +180,37 @@ class EmnistDataset(Dataset):
self.seed = seed
self._mapper = EmnistMapper()
- self.input_shape = self._mapper.input_shape
+ self._input_shape = self._mapper.input_shape
self.num_classes = self._mapper.num_classes
# Load dataset.
- self.data, self.targets = self.load_emnist_dataset()
+ self._data, self._targets = self.load_emnist_dataset()
+
+ @property
+ def data(self) -> Tensor:
+ """The input data."""
+ return self._data
+
+ @property
+ def targets(self) -> Tensor:
+ """The target data."""
+ return self._targets
+
+ @property
+ def input_shape(self) -> Tuple:
+ """Input shape of the data."""
+ return self._input_shape
@property
def mapper(self) -> EmnistMapper:
"""Returns the EmnistMapper."""
return self._mapper
+ @property
+ def inverse_mapping(self) -> Dict:
+ """Returns the inverse mapping from character to index."""
+ return self.mapper.inverse_mapping
+
def __len__(self) -> int:
"""Returns the length of the dataset."""
return len(self.data)
@@ -220,11 +239,6 @@ class EmnistDataset(Dataset):
return data, targets
- @property
- def __name__(self) -> str:
- """Returns the name of the dataset."""
- return "EmnistDataset"
-
def __repr__(self) -> str:
"""Returns information about the dataset."""
return (
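In this commit, data, targets, and input_shape on EmnistDataset become read-only properties backed by underscore attributes. A short usage sketch; the constructor arguments are illustrative:

    # data, targets, and input_shape are now read-only properties;
    # assigning to them would raise AttributeError since no setters are defined.
    dataset = EmnistDataset(train=True)
    print(dataset.input_shape, dataset.num_classes)
    image, label = dataset[0]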
diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py
index b0617f5..656131a 100644
--- a/src/text_recognizer/datasets/emnist_lines_dataset.py
+++ b/src/text_recognizer/datasets/emnist_lines_dataset.py
@@ -9,8 +9,8 @@ from loguru import logger
import numpy as np
import torch
from torch import Tensor
-from torch.utils.data import DataLoader, Dataset
-from torchvision.transforms import Compose, Normalize, ToTensor
+from torch.utils.data import Dataset
+from torchvision.transforms import ToTensor
from text_recognizer.datasets import (
DATA_DIRNAME,
@@ -20,6 +20,7 @@ from text_recognizer.datasets import (
)
from text_recognizer.datasets.sentence_generator import SentenceGenerator
from text_recognizer.datasets.util import Transpose
+from text_recognizer.networks import sliding_window
DATA_DIRNAME = DATA_DIRNAME / "processed" / "emnist_lines"
@@ -55,7 +56,7 @@ class EmnistLinesDataset(Dataset):
self.transform = transform
if self.transform is None:
- self.transform = Compose([ToTensor()])
+ self.transform = ToTensor()
self.target_transform = target_transform
if self.target_transform is None:
@@ -63,14 +64,14 @@ class EmnistLinesDataset(Dataset):
# Extract dataset information.
self._mapper = EmnistMapper()
- self.input_shape = self._mapper.input_shape
+ self._input_shape = self._mapper.input_shape
self.num_classes = self._mapper.num_classes
self.max_length = max_length
self.min_overlap = min_overlap
self.max_overlap = max_overlap
self.num_samples = num_samples
- self.input_shape = (
+ self._input_shape = (
self.input_shape[0],
self.input_shape[1] * self.max_length,
)
@@ -84,6 +85,11 @@ class EmnistLinesDataset(Dataset):
# Load dataset.
self._load_or_generate_data()
+ @property
+ def input_shape(self) -> Tuple:
+ """Input shape of the data."""
+ return self._input_shape
+
def __len__(self) -> int:
"""Returns the length of the dataset."""
return len(self.data)
@@ -112,11 +118,6 @@ class EmnistLinesDataset(Dataset):
return data, targets
- @property
- def __name__(self) -> str:
- """Returns the name of the dataset."""
- return "EmnistLinesDataset"
-
def __repr__(self) -> str:
"""Returns information about the dataset."""
return (
@@ -136,13 +137,18 @@ class EmnistLinesDataset(Dataset):
return self._mapper
@property
+ def mapping(self) -> Dict:
+ """Return EMNIST mapping from index to character."""
+ return self._mapper.mapping
+
+ @property
def data_filename(self) -> Path:
"""Path to the h5 file."""
filename = f"ml_{self.max_length}_o{self.min_overlap}_{self.max_overlap}_n{self.num_samples}.pt"
if self.train:
filename = "train_" + filename
else:
- filename = "val_" + filename
+ filename = "test_" + filename
return DATA_DIRNAME / filename
def _load_or_generate_data(self) -> None:
@@ -184,7 +190,7 @@ class EmnistLinesDataset(Dataset):
)
targets = convert_strings_to_categorical_labels(
- targets, self.emnist.inverse_mapping
+ targets, emnist.inverse_mapping
)
f.create_dataset("data", data=data, dtype="u1", compression="lzf")
@@ -322,13 +328,13 @@ def create_datasets(
min_overlap: float = 0,
max_overlap: float = 0.33,
num_train: int = 10000,
- num_val: int = 1000,
+ num_test: int = 1000,
) -> None:
"""Creates a training an validation dataset of Emnist lines."""
emnist_train = EmnistDataset(train=True, sample_to_balance=True)
- emnist_val = EmnistDataset(train=False, sample_to_balance=True)
- datasets = [emnist_train, emnist_val]
- num_samples = [num_train, num_val]
+ emnist_test = EmnistDataset(train=False, sample_to_balance=True)
+ datasets = [emnist_train, emnist_test]
+ num_samples = [num_train, num_test]
for num, train, dataset in zip(num_samples, [True, False], datasets):
emnist_lines = EmnistLinesDataset(
train=train,
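The create_datasets helper now builds a train and a test split (num_val is renamed to num_test). A hedged call sketch using the defaults shown in the diff; any parameters not visible in the hunk are left at their defaults:

    create_datasets(
        min_overlap=0,
        max_overlap=0.33,
        num_train=10000,
        num_test=1000,
    )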
diff --git a/src/text_recognizer/datasets/iam_dataset.py b/src/text_recognizer/datasets/iam_dataset.py
new file mode 100644
index 0000000..5e47350
--- /dev/null
+++ b/src/text_recognizer/datasets/iam_dataset.py
@@ -0,0 +1,134 @@
+"""Class for loading the IAM dataset, which encompasses both paragraphs and lines, with associated utilities."""
+import os
+from typing import Any, Dict, List
+import zipfile
+
+from boltons.cacheutils import cachedproperty
+import defusedxml.ElementTree as ET
+from loguru import logger
+import toml
+from torch.utils.data import Dataset
+
+from text_recognizer.datasets import DATA_DIRNAME
+from text_recognizer.datasets.util import _download_raw_dataset
+
+RAW_DATA_DIRNAME = DATA_DIRNAME / "raw" / "iam"
+METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml"
+EXTRACTED_DATASET_DIRNAME = RAW_DATA_DIRNAME / "iamdb"
+
+DOWNSAMPLE_FACTOR = 2 # If images were downsampled, the regions must also be.
+LINE_REGION_PADDING = 0 # Add this many pixels around the exact coordinates.
+
+
+class IamDataset(Dataset):
+ """IAM dataset.
+
+ "The IAM Lines dataset, first published at the ICDAR 1999, contains forms of unconstrained handwritten text,
+ which were scanned at a resolution of 300dpi and saved as PNG images with 256 gray levels."
+ From http://www.fki.inf.unibe.ch/databases/iam-handwriting-database
+
+ The data split we will use is
+ IAM lines Large Writer Independent Text Line Recognition Task (lwitlrt): 9,862 text lines.
+ The validation set has been merged into the train set.
+ The train set has 7,101 lines from 326 writers.
+ The test set has 1,861 lines from 128 writers.
+ The text lines of all data sets are mutually exclusive, thus each writer has contributed to one set only.
+
+ """
+
+ def __init__(self) -> None:
+ self.metadata = toml.load(METADATA_FILENAME)
+
+ def load_or_generate_data(self) -> None:
+ """Downloads IAM dataset if xml files does not exist."""
+ if not self.xml_filenames:
+ self._download_iam()
+
+ @property
+ def xml_filenames(self) -> List:
+ """List of xml filenames."""
+ return list((EXTRACTED_DATASET_DIRNAME / "xml").glob("*.xml"))
+
+ @property
+ def form_filenames(self) -> List:
+ """List of forms filenames."""
+ return list((EXTRACTED_DATASET_DIRNAME / "forms").glob("*.jpg"))
+
+ def _download_iam(self) -> None:
+ curdir = os.getcwd()
+ os.chdir(RAW_DATA_DIRNAME)
+ _download_raw_dataset(self.metadata)
+ _extract_raw_dataset(self.metadata)
+ os.chdir(curdir)
+
+ @property
+ def form_filenames_by_id(self) -> Dict:
+ """Creates a dictionary with filenames as keys and forms as values."""
+ return {filename.stem: filename for filename in self.form_filenames}
+
+ @cachedproperty
+ def line_strings_by_id(self) -> Dict:
+ """Return a dict from name of IAM form to a list of line texts in it."""
+ return {
+ filename.stem: _get_line_strings_from_xml_file(filename)
+ for filename in self.xml_filenames
+ }
+
+ @cachedproperty
+ def line_regions_by_id(self) -> Dict:
+ """Return a dict from name of IAM form to a list of (x1, x2, y1, y2) coordinates of all lines in it."""
+ return {
+ filename.stem: _get_line_regions_from_xml_file(filename)
+ for filename in self.xml_filenames
+ }
+
+ def __repr__(self) -> str:
+ """Print info about dataset."""
+ return "IAM Dataset\n" f"Number of forms: {len(self.xml_filenames)}\n"
+
+
+def _extract_raw_dataset(metadata: Dict) -> None:
+ logger.info("Extracting IAM data.")
+ with zipfile.ZipFile(metadata["filename"], "r") as zip_file:
+ zip_file.extractall()
+
+
+def _get_line_strings_from_xml_file(filename: str) -> List[str]:
+ """Get the text content of each line. Note that we replace " with "."""
+ xml_root_element = ET.parse(filename).getroot() # nosec
+ xml_line_elements = xml_root_element.findall("handwritten-part/line")
+    return [el.attrib["text"].replace("&quot;", '"') for el in xml_line_elements]
+
+
+def _get_line_regions_from_xml_file(filename: str) -> List[Dict[str, int]]:
+ """Get the line region dict for each line."""
+ xml_root_element = ET.parse(filename).getroot() # nosec
+ xml_line_elements = xml_root_element.findall("handwritten-part/line")
+ return [_get_line_region_from_xml_element(el) for el in xml_line_elements]
+
+
+def _get_line_region_from_xml_element(xml_line: Any) -> Dict[str, int]:
+ """Extracts coordinates for each line of text."""
+ # TODO: fix input!
+ word_elements = xml_line.findall("word/cmp")
+ x1s = [int(el.attrib["x"]) for el in word_elements]
+ y1s = [int(el.attrib["y"]) for el in word_elements]
+ x2s = [int(el.attrib["x"]) + int(el.attrib["width"]) for el in word_elements]
+ y2s = [int(el.attrib["y"]) + int(el.attrib["height"]) for el in word_elements]
+ return {
+ "x1": min(x1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING,
+ "y1": min(y1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING,
+ "x2": max(x2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING,
+ "y2": max(y2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING,
+ }
+
+
+def main() -> None:
+ """Initializes the dataset and print info about the dataset."""
+ dataset = IamDataset()
+ dataset.load_or_generate_data()
+ print(dataset)
+
+
+if __name__ == "__main__":
+ main()
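line_strings_by_id and line_regions_by_id are both keyed by the form id (the filename stem), so the text and region of each line in a form can be paired up. A minimal sketch, assuming the dataset has already been downloaded:

    dataset = IamDataset()
    dataset.load_or_generate_data()
    form_id = dataset.form_filenames[0].stem
    for text, region in zip(
        dataset.line_strings_by_id[form_id], dataset.line_regions_by_id[form_id]
    ):
        print(text, region["x1"], region["y1"], region["x2"], region["y2"])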
diff --git a/src/text_recognizer/datasets/iam_lines_dataset.py b/src/text_recognizer/datasets/iam_lines_dataset.py
new file mode 100644
index 0000000..477f500
--- /dev/null
+++ b/src/text_recognizer/datasets/iam_lines_dataset.py
@@ -0,0 +1,126 @@
+"""IamLinesDataset class."""
+from typing import Callable, Dict, List, Optional, Tuple, Union
+
+import h5py
+from loguru import logger
+import torch
+from torch import Tensor
+from torch.utils.data import Dataset
+from torchvision.transforms import ToTensor
+
+from text_recognizer.datasets.emnist_dataset import DATA_DIRNAME, EmnistMapper
+from text_recognizer.datasets.util import compute_sha256, download_url
+
+
+PROCESSED_DATA_DIRNAME = DATA_DIRNAME / "processed" / "iam_lines"
+PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "iam_lines.h5"
+PROCESSED_DATA_URL = (
+ "https://s3-us-west-2.amazonaws.com/fsdl-public-assets/iam_lines.h5"
+)
+
+
+class IamLinesDataset(Dataset):
+ """IAM lines datasets for handwritten text lines."""
+
+ def __init__(
+ self,
+ train: bool = False,
+ subsample_fraction: float = None,
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ ) -> None:
+ self.train = train
+ self.split = "train" if self.train else "test"
+ self._mapper = EmnistMapper()
+ self.num_classes = self.mapper.num_classes
+
+ # Set transforms.
+ self.transform = transform
+ if self.transform is None:
+ self.transform = ToTensor()
+
+ self.target_transform = target_transform
+ if self.target_transform is None:
+ self.target_transform = torch.tensor
+
+ self.subsample_fraction = subsample_fraction
+ self.data = None
+ self.targets = None
+
+ @property
+ def mapper(self) -> EmnistMapper:
+ """Returns the EmnistMapper."""
+ return self._mapper
+
+ @property
+ def mapping(self) -> Dict:
+ """Return EMNIST mapping from index to character."""
+ return self._mapper.mapping
+
+ @property
+ def input_shape(self) -> Tuple:
+ """Input shape of the data."""
+ return self.data.shape[1:]
+
+ @property
+ def output_shape(self) -> Tuple:
+ """Output shape of the data."""
+ return self.targets.shape[1:] + (self.num_classes,)
+
+ def __len__(self) -> int:
+ """Returns the length of the dataset."""
+ return len(self.data)
+
+ def load_or_generate_data(self) -> None:
+ """Load or generate dataset data."""
+ if not PROCESSED_DATA_FILENAME.exists():
+ PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
+ logger.info("Downloading IAM lines...")
+ download_url(PROCESSED_DATA_URL, PROCESSED_DATA_FILENAME)
+ with h5py.File(PROCESSED_DATA_FILENAME, "r") as f:
+ self.data = f[f"x_{self.split}"][:]
+ self.targets = f[f"y_{self.split}"][:]
+ self._subsample()
+
+ def _subsample(self) -> None:
+ """Only a fraction of the data will be loaded."""
+ if self.subsample_fraction is None:
+ return
+
+ num_samples = int(self.data.shape[0] * self.subsample_fraction)
+ self.data = self.data[:num_samples]
+ self.targets = self.targets[:num_samples]
+
+ def __repr__(self) -> str:
+ """Print info about the dataset."""
+ return (
+ "IAM Lines Dataset\n" # pylint: disable=no-member
+ f"Number classes: {self.num_classes}\n"
+ f"Mapping: {self.mapper.mapping}\n"
+ f"Data: {self.data.shape}\n"
+ f"Targets: {self.targets.shape}\n"
+ )
+
+ def __getitem__(self, index: Union[Tensor, int]) -> Tuple[Tensor, Tensor]:
+ """Fetches data, target pair of the dataset for a given and index or indices.
+
+ Args:
+ index (Union[int, Tensor]): Either a list or int of indices/index.
+
+ Returns:
+ Tuple[Tensor, Tensor]: Data target pair.
+
+ """
+ if torch.is_tensor(index):
+ index = index.tolist()
+
+ data = self.data[index]
+ targets = self.targets[index]
+
+ if self.transform:
+ data = self.transform(data)
+
+ if self.target_transform:
+ targets = self.target_transform(targets)
+
+ return data, targets
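With fetch_data_loaders removed, a dataset is now wrapped in a DataLoader by hand. A minimal sketch; the batch size, shuffle flag, and subsample fraction are illustrative:

    from torch.utils.data import DataLoader

    dataset = IamLinesDataset(train=True, subsample_fraction=0.1)
    dataset.load_or_generate_data()
    loader = DataLoader(dataset, batch_size=128, shuffle=True)
    images, targets = next(iter(loader))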
diff --git a/src/text_recognizer/datasets/iam_paragraphs_dataset.py b/src/text_recognizer/datasets/iam_paragraphs_dataset.py
new file mode 100644
index 0000000..d65b346
--- /dev/null
+++ b/src/text_recognizer/datasets/iam_paragraphs_dataset.py
@@ -0,0 +1,322 @@
+"""IamParagraphsDataset class and functions for data processing."""
+from typing import Callable, Dict, List, Optional, Tuple, Union
+
+import click
+import cv2
+import h5py
+from loguru import logger
+import numpy as np
+import torch
+from torch import Tensor
+from torch.utils.data import Dataset
+from torchvision.transforms import ToTensor
+
+from text_recognizer import util
+from text_recognizer.datasets.emnist_dataset import DATA_DIRNAME, EmnistMapper
+from text_recognizer.datasets.iam_dataset import IamDataset
+from text_recognizer.datasets.util import compute_sha256, download_url
+
+INTERIM_DATA_DIRNAME = DATA_DIRNAME / "interim" / "iam_paragraphs"
+DEBUG_CROPS_DIRNAME = INTERIM_DATA_DIRNAME / "debug_crops"
+PROCESSED_DATA_DIRNAME = DATA_DIRNAME / "processed" / "iam_paragraphs"
+CROPS_DIRNAME = PROCESSED_DATA_DIRNAME / "crops"
+GT_DIRNAME = PROCESSED_DATA_DIRNAME / "gt"
+
+PARAGRAPH_BUFFER = 50 # Pixels in the IAM form images to leave around the lines.
+TEST_FRACTION = 0.2
+SEED = 4711
+
+
+class IamParagraphsDataset(Dataset):
+ """IAM Paragraphs dataset for paragraphs of handwritten text.
+
+ TODO: __getitem__, __len__, get_data_target_from_id
+
+ """
+
+ def __init__(
+ self,
+ train: bool = False,
+ subsample_fraction: float = None,
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ ) -> None:
+
+ # Load Iam dataset.
+ self.iam_dataset = IamDataset()
+
+ self.train = train
+ self.split = "train" if self.train else "test"
+ self.num_classes = 3
+ self._input_shape = (256, 256)
+ self._output_shape = self._input_shape + (self.num_classes,)
+ self.subsample_fraction = subsample_fraction
+
+ # Set transforms.
+ self.transform = transform
+ if self.transform is None:
+ self.transform = ToTensor()
+
+ self.target_transform = target_transform
+ if self.target_transform is None:
+ self.target_transform = torch.tensor
+
+ self._data = None
+ self._targets = None
+ self._ids = None
+
+ def __len__(self) -> int:
+ """Returns the length of the dataset."""
+ return len(self.data)
+
+ def __getitem__(self, index: Union[Tensor, int]) -> Tuple[Tensor, Tensor]:
+ """Fetches data, target pair of the dataset for a given and index or indices.
+
+ Args:
+ index (Union[int, Tensor]): Either a list or int of indices/index.
+
+ Returns:
+ Tuple[Tensor, Tensor]: Data target pair.
+
+ """
+ if torch.is_tensor(index):
+ index = index.tolist()
+
+ data = self.data[index]
+ targets = self.targets[index]
+
+ if self.transform:
+ data = self.transform(data)
+
+ if self.target_transform:
+ targets = self.target_transform(targets)
+
+ return data, targets
+
+ @property
+ def input_shape(self) -> Tuple:
+ """Input shape of the data."""
+ return self._input_shape
+
+ @property
+ def output_shape(self) -> Tuple:
+ """Output shape of the data."""
+ return self._output_shape
+
+ @property
+ def data(self) -> Tensor:
+ """The input data."""
+ return self._data
+
+ @property
+ def targets(self) -> Tensor:
+ """The target data."""
+ return self._targets
+
+ @property
+ def ids(self) -> Tensor:
+ """Ids of the dataset."""
+ return self._ids
+
+ def get_data_and_target_from_id(self, id_: str) -> Tuple[Tensor, Tensor]:
+ """Get data target pair from id."""
+ ind = self.ids.index(id_)
+ return self.data[ind], self.targets[ind]
+
+ def load_or_generate_data(self) -> None:
+ """Load or generate dataset data."""
+ num_actual = len(list(CROPS_DIRNAME.glob("*.jpg")))
+ num_targets = len(self.iam_dataset.line_regions_by_id)
+
+ if num_actual < num_targets - 2:
+ self._process_iam_paragraphs()
+
+ self._data, self._targets, self._ids = _load_iam_paragraphs()
+ self._get_random_split()
+ self._subsample()
+
+ def _get_random_split(self) -> None:
+ np.random.seed(SEED)
+ num_train = int((1 - TEST_FRACTION) * self.data.shape[0])
+ indices = np.random.permutation(self.data.shape[0])
+ train_indices, test_indices = indices[:num_train], indices[num_train:]
+ if self.train:
+ self._data = self.data[train_indices]
+ self._targets = self.targets[train_indices]
+ else:
+ self._data = self.data[test_indices]
+ self._targets = self.targets[test_indices]
+
+ def _process_iam_paragraphs(self) -> None:
+ """Crop the part with the text.
+
+        For each page, crop out the part of it that corresponds to the paragraph of text, and make sure all crops are
+        self.input_shape. The ground truth data is the same size, with a one-hot vector at each pixel
+        corresponding to labels 0=background, 1=odd-numbered line, 2=even-numbered line.
+ """
+ crop_dims = self._decide_on_crop_dims()
+ CROPS_DIRNAME.mkdir(parents=True, exist_ok=True)
+ DEBUG_CROPS_DIRNAME.mkdir(parents=True, exist_ok=True)
+ GT_DIRNAME.mkdir(parents=True, exist_ok=True)
+ logger.info(
+ f"Cropping paragraphs, generating ground truth, and saving debugging images to {DEBUG_CROPS_DIRNAME}"
+ )
+ for filename in self.iam_dataset.form_filenames:
+ id_ = filename.stem
+ line_region = self.iam_dataset.line_regions_by_id[id_]
+ _crop_paragraph_image(filename, line_region, crop_dims, self.input_shape)
+
+ def _decide_on_crop_dims(self) -> Tuple[int, int]:
+ """Decide on the dimensions to crop out of the form image.
+
+ Since image width is larger than a comfortable crop around the longest paragraph,
+ we will make the crop a square form factor.
+ And since the found dimensions 610x610 are pretty close to 512x512,
+ we might as well resize crops and make it exactly that, which lets us
+ do all kinds of power-of-2 pooling and upsampling should we choose to.
+
+ Returns:
+ Tuple[int, int]: A tuple of crop dimensions.
+
+ Raises:
+ RuntimeError: When max crop height is larger than max crop width.
+
+ """
+
+ sample_form_filename = self.iam_dataset.form_filenames[0]
+ sample_image = util.read_image(sample_form_filename, grayscale=True)
+ max_crop_width = sample_image.shape[1]
+ max_crop_height = _get_max_paragraph_crop_height(
+ self.iam_dataset.line_regions_by_id
+ )
+ if not max_crop_height <= max_crop_width:
+ raise RuntimeError(
+                f"Max crop height is larger than max crop width: {max_crop_height} > {max_crop_width}"
+ )
+
+ crop_dims = (max_crop_width, max_crop_width)
+ logger.info(
+ f"Max crop width and height were found to be {max_crop_width}x{max_crop_height}."
+ )
+ logger.info(f"Setting them to {max_crop_width}x{max_crop_width}")
+ return crop_dims
+
+ def _subsample(self) -> None:
+ """Only this fraction of the data will be loaded."""
+ if self.subsample_fraction is None:
+ return
+ num_subsample = int(self.data.shape[0] * self.subsample_fraction)
+        self._data = self.data[:num_subsample]
+        self._targets = self.targets[:num_subsample]
+
+ def __repr__(self) -> str:
+ """Return info about the dataset."""
+ return (
+ "IAM Paragraph Dataset\n" # pylint: disable=no-member
+ f"Num classes: {self.num_classes}\n"
+ f"Data: {self.data.shape}\n"
+ f"Targets: {self.targets.shape}\n"
+ )
+
+
+def _get_max_paragraph_crop_height(line_regions_by_id: Dict) -> int:
+ heights = []
+ for regions in line_regions_by_id.values():
+ min_y1 = min(region["y1"] for region in regions) - PARAGRAPH_BUFFER
+ max_y2 = max(region["y2"] for region in regions) + PARAGRAPH_BUFFER
+ height = max_y2 - min_y1
+ heights.append(height)
+ return max(heights)
+
+
+def _crop_paragraph_image(
+ filename: str, line_regions: Dict, crop_dims: Tuple[int, int], final_dims: Tuple
+) -> None:
+ image = util.read_image(filename, grayscale=True)
+
+ min_y1 = min(region["y1"] for region in line_regions) - PARAGRAPH_BUFFER
+ max_y2 = max(region["y2"] for region in line_regions) + PARAGRAPH_BUFFER
+ height = max_y2 - min_y1
+ crop_height = crop_dims[0]
+ buffer = (crop_height - height) // 2
+
+ # Generate image crop.
+ image_crop = 255 * np.ones(crop_dims, dtype=np.uint8)
+ try:
+ image_crop[buffer : buffer + height] = image[min_y1:max_y2]
+ except Exception as e: # pylint: disable=broad-except
+ logger.error(f"Rescued {filename}: {e}")
+ return
+
+ # Generate ground truth.
+ gt_image = np.zeros_like(image_crop, dtype=np.uint8)
+ for index, region in enumerate(line_regions):
+ gt_image[
+ (region["y1"] - min_y1 + buffer) : (region["y2"] - min_y1 + buffer),
+ region["x1"] : region["x2"],
+ ] = (index % 2 + 1)
+
+ # Generate image for debugging.
+ import matplotlib.pyplot as plt
+
+ cmap = plt.get_cmap("Set1")
+ image_crop_for_debug = np.dstack([image_crop, image_crop, image_crop])
+ for index, region in enumerate(line_regions):
+ color = [255 * _ for _ in cmap(index)[:-1]]
+ cv2.rectangle(
+ image_crop_for_debug,
+ (region["x1"], region["y1"] - min_y1 + buffer),
+ (region["x2"], region["y2"] - min_y1 + buffer),
+ color,
+ 3,
+ )
+ image_crop_for_debug = cv2.resize(
+ image_crop_for_debug, final_dims, interpolation=cv2.INTER_AREA
+ )
+ util.write_image(image_crop_for_debug, DEBUG_CROPS_DIRNAME / f"{filename.stem}.jpg")
+
+ image_crop = cv2.resize(image_crop, final_dims, interpolation=cv2.INTER_AREA)
+ util.write_image(image_crop, CROPS_DIRNAME / f"{filename.stem}.jpg")
+
+ gt_image = cv2.resize(gt_image, final_dims, interpolation=cv2.INTER_NEAREST)
+ util.write_image(gt_image, GT_DIRNAME / f"{filename.stem}.png")
+
+
+def _load_iam_paragraphs() -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+ logger.info("Loading IAM paragraph crops and ground truth from image files...")
+ images = []
+ gt_images = []
+ ids = []
+ for filename in CROPS_DIRNAME.glob("*.jpg"):
+ id_ = filename.stem
+ image = util.read_image(filename, grayscale=True)
+ image = 1.0 - image / 255
+
+ gt_filename = GT_DIRNAME / f"{id_}.png"
+ gt_image = util.read_image(gt_filename, grayscale=True)
+
+ images.append(image)
+ gt_images.append(gt_image)
+ ids.append(id_)
+ images = np.array(images).astype(np.float32)
+ gt_images = np.array(gt_images).astype(np.uint8)
+ ids = np.array(ids)
+ return images, gt_images, ids
+
+
+@click.command()
+@click.option(
+ "--subsample_fraction",
+ type=float,
+ default=0.0,
+ help="The subsampling factor of the dataset.",
+)
+def main(subsample_fraction: float) -> None:
+ """Load dataset and print info."""
+ dataset = IamParagraphsDataset(subsample_fraction=subsample_fraction)
+ dataset.load_or_generate_data()
+ print(dataset)
+
+
+if __name__ == "__main__":
+ main()
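The ground truth images store per-pixel class labels (0=background, 1=odd-numbered line, 2=even-numbered line), while output_shape advertises a trailing channel per class. A hedged sketch of expanding a target to that shape with torch.nn.functional.one_hot:

    import torch.nn.functional as F

    dataset = IamParagraphsDataset(train=True)
    dataset.load_or_generate_data()
    image, target = dataset[0]
    # target holds integer labels; one_hot adds the num_classes dimension,
    # matching dataset.output_shape == (256, 256, 3).
    one_hot = F.one_hot(target.long(), num_classes=dataset.num_classes)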
diff --git a/src/text_recognizer/datasets/util.py b/src/text_recognizer/datasets/util.py
index 76bd85f..dd16bed 100644
--- a/src/text_recognizer/datasets/util.py
+++ b/src/text_recognizer/datasets/util.py
@@ -1,10 +1,17 @@
"""Util functions for datasets."""
+import hashlib
import importlib
-from typing import Callable, Dict, List, Type
+import os
+from pathlib import Path
+from typing import Callable, Dict, List, Optional, Type, Union
+from urllib.request import urlopen, urlretrieve
+import cv2
+from loguru import logger
import numpy as np
from PIL import Image
from torch.utils.data import DataLoader, Dataset
+from tqdm import tqdm
class Transpose:
@@ -15,58 +22,48 @@ class Transpose:
return np.array(image).swapaxes(0, 1)
-def fetch_data_loaders(
- splits: List[str],
- dataset: str,
- dataset_args: Dict,
- batch_size: int = 128,
- shuffle: bool = False,
- num_workers: int = 0,
- cuda: bool = True,
-) -> Dict[str, DataLoader]:
- """Fetches DataLoaders for given split(s) as a dictionary.
-
- Loads the dataset class given, and loads it with the dataset arguments, for the number of splits specified. Then
- calls the DataLoader. Added to a dictionary with the split as key and DataLoader as value.
-
- Args:
- splits (List[str]): One or both of the dataset splits "train" and "val".
- dataset (str): The name of the dataset.
- dataset_args (Dict): The dataset arguments.
- batch_size (int): How many samples per batch to load. Defaults to 128.
- shuffle (bool): Set to True to have the data reshuffled at every epoch. Defaults to False.
- num_workers (int): How many subprocesses to use for data loading. 0 means that the data will be
- loaded in the main process. Defaults to 0.
- cuda (bool): If True, the data loader will copy Tensors into CUDA pinned memory before returning
- them. Defaults to True.
-
- Returns:
- Dict[str, DataLoader]: Dictionary with split as key and PyTorch DataLoader as value.
+def compute_sha256(filename: Union[Path, str]) -> str:
+ """Returns the SHA256 checksum of a file."""
+ with open(filename, "rb") as f:
+ return hashlib.sha256(f.read()).hexdigest()
- """
-
- def check_dataset_args(args: Dict, split: str) -> Dict:
- """Adds train flag to the dataset args."""
- args["train"] = True if split == "train" else False
- return args
-
- # Import dataset module.
- datasets_module = importlib.import_module("text_recognizer.datasets")
- dataset_ = getattr(datasets_module, dataset)
- data_loaders = {}
+class TqdmUpTo(tqdm):
+ """TQDM progress bar when downloading files.
- for split in ["train", "val"]:
- if split in splits:
+ From https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py
- data_loader = DataLoader(
- dataset=dataset_(**check_dataset_args(dataset_args, split)),
- batch_size=batch_size,
- shuffle=shuffle,
- num_workers=num_workers,
- pin_memory=cuda,
- )
-
- data_loaders[split] = data_loader
+ """
- return data_loaders
+ def update_to(
+ self, blocks: int = 1, block_size: int = 1, total_size: Optional[int] = None
+ ) -> None:
+ """Updates the progress bar.
+
+ Args:
+ blocks (int): Number of blocks transferred so far. Defaults to 1.
+ block_size (int): Size of each block, in tqdm units. Defaults to 1.
+ total_size (Optional[int]): Total size in tqdm units. Defaults to None.
+ """
+ if total_size is not None:
+ self.total = total_size # pylint: disable=attribute-defined-outside-init
+ self.update(blocks * block_size - self.n)
+
+
+def download_url(url: str, filename: str) -> None:
+ """Downloads a file from url to filename, with a progress bar."""
+ with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t:
+ urlretrieve(url, filename, reporthook=t.update_to, data=None) # nosec
+
+
+def _download_raw_dataset(metadata: Dict) -> None:
+ if os.path.exists(metadata["filename"]):
+ return
+ logger.info(f"Downloading raw dataset from {metadata['url']}...")
+ download_url(metadata["url"], metadata["filename"])
+ logger.info("Computing SHA-256...")
+ sha256 = compute_sha256(metadata["filename"])
+ if sha256 != metadata["sha256"]:
+ raise ValueError(
+ "Downloaded data file SHA-256 does not match that listed in metadata document."
+ )