summaryrefslogtreecommitdiff
path: root/text_recognizer/datasets
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/datasets')
-rw-r--r--text_recognizer/datasets/__init__.py38
-rw-r--r--text_recognizer/datasets/base_data_module.py36
-rw-r--r--text_recognizer/datasets/base_dataset.py21
-rw-r--r--text_recognizer/datasets/dataset.py152
-rw-r--r--text_recognizer/datasets/download_utils.py6
-rw-r--r--text_recognizer/datasets/emnist.py88
-rw-r--r--text_recognizer/datasets/emnist_essentials.json1
-rw-r--r--text_recognizer/datasets/emnist_lines.py184
8 files changed, 278 insertions, 248 deletions
diff --git a/text_recognizer/datasets/__init__.py b/text_recognizer/datasets/__init__.py
index a6c1c59..2727b20 100644
--- a/text_recognizer/datasets/__init__.py
+++ b/text_recognizer/datasets/__init__.py
@@ -1,39 +1 @@
"""Dataset modules."""
-from .emnist_dataset import EmnistDataset
-from .emnist_lines_dataset import (
- construct_image_from_string,
- EmnistLinesDataset,
- get_samples_by_character,
-)
-from .iam_dataset import IamDataset
-from .iam_lines_dataset import IamLinesDataset
-from .iam_paragraphs_dataset import IamParagraphsDataset
-from .iam_preprocessor import load_metadata, Preprocessor
-from .transforms import AddTokens, Transpose
-from .util import (
- _download_raw_dataset,
- compute_sha256,
- DATA_DIRNAME,
- download_url,
- EmnistMapper,
- ESSENTIALS_FILENAME,
-)
-
-__all__ = [
- "_download_raw_dataset",
- "AddTokens",
- "compute_sha256",
- "construct_image_from_string",
- "DATA_DIRNAME",
- "download_url",
- "EmnistDataset",
- "EmnistMapper",
- "EmnistLinesDataset",
- "get_samples_by_character",
- "load_metadata",
- "IamDataset",
- "IamLinesDataset",
- "IamParagraphsDataset",
- "Preprocessor",
- "Transpose",
-]
diff --git a/text_recognizer/datasets/base_data_module.py b/text_recognizer/datasets/base_data_module.py
index 09a0a43..830b39b 100644
--- a/text_recognizer/datasets/base_data_module.py
+++ b/text_recognizer/datasets/base_data_module.py
@@ -16,7 +16,7 @@ def load_and_print_info(data_module_class: type) -> None:
class BaseDataModule(pl.LightningDataModule):
"""Base PyTorch Lightning DataModule."""
-
+
def __init__(self, batch_size: int = 128, num_workers: int = 0) -> None:
super().__init__()
self.batch_size = batch_size
@@ -34,13 +34,17 @@ class BaseDataModule(pl.LightningDataModule):
def config(self) -> Dict:
"""Return important settings of the dataset."""
- return {"input_dim": self.dims, "output_dims": self.output_dims, "mapping": self.mapping}
+ return {
+ "input_dim": self.dims,
+ "output_dims": self.output_dims,
+ "mapping": self.mapping,
+ }
def prepare_data(self) -> None:
"""Prepare data for training."""
pass
- def setup(self, stage: Any = None) -> None:
+ def setup(self, stage: str = None) -> None:
"""Split into train, val, test, and set dims.
Should assign `torch Dataset` objects to self.data_train, self.data_val, and
@@ -54,16 +58,32 @@ class BaseDataModule(pl.LightningDataModule):
self.data_val = None
self.data_test = None
-
def train_dataloader(self) -> DataLoader:
"""Retun DataLoader for train data."""
- return DataLoader(self.data_train, shuffle=True, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True)
+ return DataLoader(
+ self.data_train,
+ shuffle=True,
+ batch_size=self.batch_size,
+ num_workers=self.num_workers,
+ pin_memory=True,
+ )
def val_dataloader(self) -> DataLoader:
"""Return DataLoader for val data."""
- return DataLoader(self.data_val, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True)
+ return DataLoader(
+ self.data_val,
+ shuffle=False,
+ batch_size=self.batch_size,
+ num_workers=self.num_workers,
+ pin_memory=True,
+ )
def test_dataloader(self) -> DataLoader:
"""Return DataLoader for val data."""
- return DataLoader(self.data_test, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True)
-
+ return DataLoader(
+ self.data_test,
+ shuffle=False,
+ batch_size=self.batch_size,
+ num_workers=self.num_workers,
+ pin_memory=True,
+ )
diff --git a/text_recognizer/datasets/base_dataset.py b/text_recognizer/datasets/base_dataset.py
index 7322d7f..a004b8d 100644
--- a/text_recognizer/datasets/base_dataset.py
+++ b/text_recognizer/datasets/base_dataset.py
@@ -17,12 +17,14 @@ class BaseDataset(Dataset):
target_transform (Callable): Fucntion that takes a target and applies
target transforms.
"""
- def __init__(self,
- data: Union[Sequence, Tensor],
- targets: Union[Sequence, Tensor],
- transform: Callable = None,
- target_transform: Callable = None,
- ) -> None:
+
+ def __init__(
+ self,
+ data: Union[Sequence, Tensor],
+ targets: Union[Sequence, Tensor],
+ transform: Callable = None,
+ target_transform: Callable = None,
+ ) -> None:
if len(data) != len(targets):
raise ValueError("Data and targets must be of equal length.")
self.data = data
@@ -30,11 +32,10 @@ class BaseDataset(Dataset):
self.transform = transform
self.target_transform = target_transform
-
def __len__(self) -> int:
"""Return the length of the dataset."""
return len(self.data)
-
+
def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""Return a datum and its target, after processing by transforms.
@@ -56,7 +57,9 @@ class BaseDataset(Dataset):
return datum, target
-def convert_strings_to_labels(strings: Sequence[str], mapping: Dict[str, int], length: int) -> Tensor:
+def convert_strings_to_labels(
+ strings: Sequence[str], mapping: Dict[str, int], length: int
+) -> Tensor:
"""
Convert a sequence of N strings to (N, length) ndarray, with each string wrapped with <S> and </S> tokens,
and padded wiht the <P> token.
diff --git a/text_recognizer/datasets/dataset.py b/text_recognizer/datasets/dataset.py
deleted file mode 100644
index e794605..0000000
--- a/text_recognizer/datasets/dataset.py
+++ /dev/null
@@ -1,152 +0,0 @@
-"""Abstract dataset class."""
-from typing import Callable, Dict, List, Optional, Tuple, Union
-
-import torch
-from torch import Tensor
-from torch.utils import data
-from torchvision.transforms import ToTensor
-
-import text_recognizer.datasets.transforms as transforms
-from text_recognizer.datasets.util import EmnistMapper
-
-
-class Dataset(data.Dataset):
- """Abstract class for with common methods for all datasets."""
-
- def __init__(
- self,
- train: bool,
- subsample_fraction: float = None,
- transform: Optional[List[Dict]] = None,
- target_transform: Optional[List[Dict]] = None,
- init_token: Optional[str] = None,
- pad_token: Optional[str] = None,
- eos_token: Optional[str] = None,
- lower: bool = False,
- ) -> None:
- """Initialization of Dataset class.
-
- Args:
- train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to False.
- subsample_fraction (float): The fraction of the dataset to use for training. Defaults to None.
- transform (Optional[List[Dict]]): List of Transform types and args for input data. Defaults to None.
- target_transform (Optional[List[Dict]]): List of Transform types and args for output data. Defaults to None.
- init_token (Optional[str]): String representing the start of sequence token. Defaults to None.
- pad_token (Optional[str]): String representing the pad token. Defaults to None.
- eos_token (Optional[str]): String representing the end of sequence token. Defaults to None.
- lower (bool): Only use lower case letters. Defaults to False.
-
- Raises:
- ValueError: If subsample_fraction is not None and outside the range (0, 1).
-
- """
- self.train = train
- self.split = "train" if self.train else "test"
-
- if subsample_fraction is not None:
- if not 0.0 < subsample_fraction < 1.0:
- raise ValueError("The subsample fraction must be in (0, 1).")
- self.subsample_fraction = subsample_fraction
-
- self._mapper = EmnistMapper(
- init_token=init_token, eos_token=eos_token, pad_token=pad_token, lower=lower
- )
- self._input_shape = self._mapper.input_shape
- self._output_shape = self._mapper._num_classes
- self.num_classes = self.mapper.num_classes
-
- # Set transforms.
- self.transform = self._configure_transform(transform)
- self.target_transform = self._configure_target_transform(target_transform)
-
- self._data = None
- self._targets = None
-
- def _configure_transform(self, transform: List[Dict]) -> transforms.Compose:
- transform_list = []
- if transform is not None:
- for t in transform:
- t_type = t["type"]
- t_args = t["args"] or {}
- transform_list.append(getattr(transforms, t_type)(**t_args))
- else:
- transform_list.append(ToTensor())
- return transforms.Compose(transform_list)
-
- def _configure_target_transform(
- self, target_transform: List[Dict]
- ) -> transforms.Compose:
- target_transform_list = [torch.tensor]
- if target_transform is not None:
- for t in target_transform:
- t_type = t["type"]
- t_args = t["args"] or {}
- target_transform_list.append(getattr(transforms, t_type)(**t_args))
- return transforms.Compose(target_transform_list)
-
- @property
- def data(self) -> Tensor:
- """The input data."""
- return self._data
-
- @property
- def targets(self) -> Tensor:
- """The target data."""
- return self._targets
-
- @property
- def input_shape(self) -> Tuple:
- """Input shape of the data."""
- return self._input_shape
-
- @property
- def output_shape(self) -> Tuple:
- """Output shape of the data."""
- return self._output_shape
-
- @property
- def mapper(self) -> EmnistMapper:
- """Returns the EmnistMapper."""
- return self._mapper
-
- @property
- def mapping(self) -> Dict:
- """Return EMNIST mapping from index to character."""
- return self._mapper.mapping
-
- @property
- def inverse_mapping(self) -> Dict:
- """Returns the inverse mapping from character to index."""
- return self.mapper.inverse_mapping
-
- def _subsample(self) -> None:
- """Only this fraction of the data will be loaded."""
- if self.subsample_fraction is None:
- return
- num_subsample = int(self.data.shape[0] * self.subsample_fraction)
- self._data = self.data[:num_subsample]
- self._targets = self.targets[:num_subsample]
-
- def __len__(self) -> int:
- """Returns the length of the dataset."""
- return len(self.data)
-
- def load_or_generate_data(self) -> None:
- """Load or generate dataset data."""
- raise NotImplementedError
-
- def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]:
- """Fetches samples from the dataset.
-
- Args:
- index (Union[int, torch.Tensor]): The indices of the samples to fetch.
-
- Raises:
- NotImplementedError: If the method is not implemented in child class.
-
- """
- raise NotImplementedError
-
- def __repr__(self) -> str:
- """Returns information about the dataset."""
- raise NotImplementedError
diff --git a/text_recognizer/datasets/download_utils.py b/text_recognizer/datasets/download_utils.py
index 7a2cab8..e3dc68c 100644
--- a/text_recognizer/datasets/download_utils.py
+++ b/text_recognizer/datasets/download_utils.py
@@ -63,11 +63,11 @@ def download_dataset(metadata: Dict, dl_dir: Path) -> Optional[Path]:
if filename.exists():
return
logger.info(f"Downloading raw dataset from {metadata['url']} to {filename}...")
- _download_url(metadata["url"], filename)
+ _download_url(metadata["url"], filename)
logger.info("Computing the SHA-256...")
sha256 = _compute_sha256(filename)
if sha256 != metadata["sha256"]:
raise ValueError(
- "Downloaded data file SHA-256 does not match that listed in metadata document."
- )
+ "Downloaded data file SHA-256 does not match that listed in metadata document."
+ )
return filename
diff --git a/text_recognizer/datasets/emnist.py b/text_recognizer/datasets/emnist.py
index e99dbfd..7c208c4 100644
--- a/text_recognizer/datasets/emnist.py
+++ b/text_recognizer/datasets/emnist.py
@@ -15,20 +15,23 @@ from torch.utils.data import random_split
from torchvision import transforms
from text_recognizer.datasets.base_dataset import BaseDataset
-from text_recognizer.datasets.base_data_module import BaseDataModule, load_print_info
+from text_recognizer.datasets.base_data_module import (
+ BaseDataModule,
+ load_and_print_info,
+)
from text_recognizer.datasets.download_utils import download_dataset
SEED = 4711
NUM_SPECIAL_TOKENS = 4
-SAMPLE_TO_BALANCE = True
+SAMPLE_TO_BALANCE = True
RAW_DATA_DIRNAME = BaseDataModule.data_dirname() / "raw" / "emnist"
METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml"
DL_DATA_DIRNAME = BaseDataModule.data_dirname() / "downloaded" / "emnist"
-PROCESSED_DATA_DIRNAME = BaseDataset.data_dirname() / "processed" / "emnist"
+PROCESSED_DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "emnist"
PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "byclass.h5"
-ESSENTIALS_FILENAME = Path(__file__).parents[0].resolve() / "emnsit_essentials.json"
+ESSENTIALS_FILENAME = Path(__file__).parents[0].resolve() / "emnist_essentials.json"
class EMNIST(BaseDataModule):
@@ -41,7 +44,9 @@ class EMNIST(BaseDataModule):
EMNIST ByClass: 814,255 characters. 62 unbalanced classes.
"""
- def __init__(self, batch_size: int = 128, num_workers: int = 0, train_fraction: float = 0.8) -> None:
+ def __init__(
+ self, batch_size: int = 128, num_workers: int = 0, train_fraction: float = 0.8
+ ) -> None:
super().__init__(batch_size, num_workers)
if not ESSENTIALS_FILENAME.exists():
_download_and_process_emnist()
@@ -64,20 +69,21 @@ class EMNIST(BaseDataModule):
def setup(self, stage: str = None) -> None:
if stage == "fit" or stage is None:
with h5py.File(PROCESSED_DATA_FILENAME, "r") as f:
- data = f["x_train"][:]
- targets = f["y_train"][:]
-
- dataset_train = BaseDataset(data, targets, transform=self.transform)
+ self.x_train = f["x_train"][:]
+ self.y_train = f["y_train"][:]
+
+ dataset_train = BaseDataset(self.x_train, self.y_train, transform=self.transform)
train_size = int(self.train_fraction * len(dataset_train))
val_size = len(dataset_train) - train_size
- self.data_train, self.data_val = random_split(dataset_train, [train_size, val_size], generator=torch.Generator())
+ self.data_train, self.data_val = random_split(
+ dataset_train, [train_size, val_size], generator=torch.Generator()
+ )
if stage == "test" or stage is None:
with h5py.File(PROCESSED_DATA_FILENAME, "r") as f:
- data = f["x_test"][:]
- targets = f["y_test"][:]
- self.data_test = BaseDataset(data, targets, transform=self.transform)
-
+ self.x_test = f["x_test"][:]
+ self.y_test = f["y_test"][:]
+ self.data_test = BaseDataset(self.x_test, self.y_test, transform=self.transform)
def __repr__(self) -> str:
basic = f"EMNIST Dataset\nNum classes: {len(self.mapping)}\nMapping: {self.mapping}\nDims: {self.dims}\n"
@@ -111,9 +117,15 @@ def _process_raw_dataset(filename: str, dirname: Path) -> None:
logger.info("Loading training data from .mat file")
data = loadmat("matlab/emnist-byclass.mat")
- x_train = data["dataset"]["train"][0, 0]["images"][0, 0].reshape(-1, 28, 28).swapaxes(1, 2)
+ x_train = (
+ data["dataset"]["train"][0, 0]["images"][0, 0]
+ .reshape(-1, 28, 28)
+ .swapaxes(1, 2)
+ )
y_train = data["dataset"]["train"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS
- x_test = data["dataset"]["test"][0, 0]["images"][0, 0].reshape(-1, 28, 28).swapaxes(1, 2)
+ x_test = (
+ data["dataset"]["test"][0, 0]["images"][0, 0].reshape(-1, 28, 28).swapaxes(1, 2)
+ )
y_test = data["dataset"]["test"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS
if SAMPLE_TO_BALANCE:
@@ -121,7 +133,6 @@ def _process_raw_dataset(filename: str, dirname: Path) -> None:
x_train, y_train = _sample_to_balance(x_train, y_train)
x_test, y_test = _sample_to_balance(x_test, y_test)
-
logger.info("Saving to HDF5 in a compressed format...")
PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
with h5py.File(PROCESSED_DATA_FILENAME, "w") as f:
@@ -154,7 +165,7 @@ def _sample_to_balance(x: np.ndarray, y: np.ndarray) -> Tuple[np.ndarray, np.nda
all_sampled_indices.append(sampled_indices)
indices = np.concatenate(all_sampled_indices)
x_sampled = x[indices]
- y_sampled= y[indices]
+ y_sampled = y[indices]
return x_sampled, y_sampled
@@ -162,24 +173,24 @@ def _augment_emnist_characters(characters: Sequence[str]) -> Sequence[str]:
"""Augment the mapping with extra symbols."""
# Extra characters from the IAM dataset.
iam_characters = [
- " ",
- "!",
- '"',
- "#",
- "&",
- "'",
- "(",
- ")",
- "*",
- "+",
- ",",
- "-",
- ".",
- "/",
- ":",
- ";",
- "?",
- ]
+ " ",
+ "!",
+ '"',
+ "#",
+ "&",
+ "'",
+ "(",
+ ")",
+ "*",
+ "+",
+ ",",
+ "-",
+ ".",
+ "/",
+ ":",
+ ";",
+ "?",
+ ]
# Also add special tokens for:
# - CTC blank token at index 0
@@ -190,5 +201,6 @@ def _augment_emnist_characters(characters: Sequence[str]) -> Sequence[str]:
return ["<b>", "<s>", "</s>", "<p>", *characters, *iam_characters]
-if __name__ == "__main__":
- load_print_info(EMNIST)
+def download_emnist() -> None:
+ """Download dataset from internet, if it does not exists, and displays info."""
+ load_and_print_info(EMNIST)
diff --git a/text_recognizer/datasets/emnist_essentials.json b/text_recognizer/datasets/emnist_essentials.json
new file mode 100644
index 0000000..100b36a
--- /dev/null
+++ b/text_recognizer/datasets/emnist_essentials.json
@@ -0,0 +1 @@
+{"characters": ["<b>", "<s>", "</s>", "<p>", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", " ", "!", "\"", "#", "&", "'", "(", ")", "*", "+", ",", "-", ".", "/", ":", ";", "?"], "input_shape": [28, 28]} \ No newline at end of file
diff --git a/text_recognizer/datasets/emnist_lines.py b/text_recognizer/datasets/emnist_lines.py
new file mode 100644
index 0000000..ae23feb
--- /dev/null
+++ b/text_recognizer/datasets/emnist_lines.py
@@ -0,0 +1,184 @@
+"""Dataset of generated text from EMNIST characters."""
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, Sequence
+
+import h5py
+from loguru import logger
+import numpy as np
+import torch
+from torchvision import transforms
+
+from text_recognizer.datasets.base_dataset import BaseDataset
+from text_recognizer.datasets.base_data_module import BaseDataModule
+from text_recognizer.datasets.emnist import EMNIST
+from text_recognizer.datasets.sentence_generator import SentenceGenerator
+
+
+DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "emnist_lines"
+ESSENTIALS_FILENAME = (
+ Path(__file__).parents[0].resolve() / "emnist_lines_essentials.json"
+)
+
+SEED = 4711
+IMAGE_HEIGHT = 56
+IMAGE_WIDTH = 1024
+IMAGE_X_PADDING = 28
+MAX_OUTPUT_LENGTH = 89 # Same as IAMLines
+
+
+class EMNISTLines(BaseDataModule):
+ """EMNIST Lines dataset: synthetic handwritten lines dataset made from EMNIST,"""
+
+ def __init__(
+ self,
+ augment: bool = True,
+ batch_size: int = 128,
+ num_workers: int = 0,
+ max_length: int = 32,
+ min_overlap: float = 0.0,
+ max_overlap: float = 0.33,
+ num_train: int = 10_000,
+ num_val: int = 2_000,
+ num_test: int = 2_000,
+ ) -> None:
+ super().__init__(batch_size, num_workers)
+
+ self.augment = augment
+ self.max_length = max_length
+ self.min_overlap = min_overlap
+ self.max_overlap = max_overlap
+ self.num_train = num_train
+ self.num_val = num_val
+ self.num_test = num_test
+
+ self.emnist = EMNIST()
+ self.mapping = self.emnist.mapping
+ max_width = int(self.emnist.dims[2] * (self.max_length + 1) * (1 - self.min_overlap)) + IMAGE_X_PADDING
+
+ if max_width <= IMAGE_WIDTH:
+ raise ValueError("max_width greater than IMAGE_WIDTH")
+
+ self.dims = (
+ self.emnist.dims[0],
+ self.emnist.dims[1],
+ self.emnist.dims[2] * self.max_length,
+ )
+
+ if self.max_length <= MAX_OUTPUT_LENGTH:
+ raise ValueError("max_length greater than MAX_OUTPUT_LENGTH")
+
+ self.output_dims = (MAX_OUTPUT_LENGTH, 1)
+ self.data_train = None
+ self.data_val = None
+ self.data_test = None
+
+ @property
+ def data_filename(self) -> Path:
+ """Return name of dataset."""
+ return (
+ DATA_DIRNAME
+ / f"ml_{self.max_length}_o{self.min_overlap:f}_{self.max_overlap:f}_ntr{self.num_train}_ntv{self.num_val}_nte{self.num_test}_{self.with_start_end_tokens}.h5"
+ )
+
+ def prepare_data(self) -> None:
+ if self.data_filename.exists():
+ return
+ np.random.seed(SEED)
+ self._generate_data("train")
+ self._generate_data("val")
+ self._generate_data("test")
+
+ def setup(self, stage: str = None) -> None:
+ logger.info("EMNISTLinesDataset loading data from HDF5...")
+ if stage == "fit" or stage is None:
+ with h5py.File(self.data_filename, "r") as f:
+ x_train = f["x_train"][:]
+ y_train = torch.LongTensor(f["y_train"][:])
+ x_val = f["x_val"][:]
+ y_val = torch.LongTensor(f["y_val"][:])
+
+ self.data_train = BaseDataset(x_train, y_train, transform=_get_transform(augment=self.augment))
+ self.data_val = BaseDataset(x_val, y_val, transform=_get_transform(augment=self.augment))
+
+ if stage == "test" or stage is None:
+ with h5py.File(self.data_filename, "r") as f:
+ x_test = f["x_test"][:]
+ y_test = torch.LongTensor(f["y_test"][:])
+
+ self.data_train = BaseDataset(x_test, y_test, transform=_get_transform(augment=False))
+
+ def __repr__(self) -> str:
+ """Return str about dataset."""
+ basic = (
+ "EMNISTLines2 Dataset\n" # pylint: disable=no-member
+ f"Min overlap: {self.min_overlap}\n"
+ f"Max overlap: {self.max_overlap}\n"
+ f"Num classes: {len(self.mapping)}\n"
+ f"Dims: {self.dims}\n"
+ f"Output dims: {self.output_dims}\n"
+ )
+
+ if not any([self.data_train, self.data_val, self.data_test]):
+ return basic
+
+ x, y = next(iter(self.train_dataloader()))
+ data = (
+ f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n"
+ f"Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n"
+ f"Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n"
+ )
+ return basic + data
+
+ def _generate_data(self, split: str) -> None:
+ logger.info(f"EMNISTLines generating data for {split}...")
+ sentence_generator = SentenceGenerator(self.max_length - 2) # Subtract by 2 because start/end token
+
+ emnist = self.emnist
+ emnist.prepare_data()
+ emnist.setup()
+
+ if split == "train":
+ samples_by_char = _get_samples_by_char(emnist.x_train, emnist.y_train, emnist.mapping)
+ num = self.num_train
+ elif split == "val":
+ samples_by_char = _get_samples_by_char(emnist.x_train, emnist.y_train, emnist.mapping)
+ num = self.num_val
+ elif split == "test":
+ samples_by_char = _get_samples_by_char(emnist.x_test, emnist.y_test, emnist.mapping)
+ num = self.num_test
+
+ DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
+ with h5py.File(self.data_filename, "w") as f:
+ x, y = _create_dataset_of_images(
+ num, samples_by_char, sentence_generator, self.min_overlap, self.max_overlap, self.dims
+ )
+ y = _convert_strings_to_labels(
+ y,
+ emnist.inverse_mapping,
+ length=MAX_OUTPUT_LENGTH
+ )
+ f.create_dataset(f"x_{split}", data=x, dtype="u1", compression="lzf")
+ f.create_dataset(f"y_{split}", data=y, dtype="u1", compression="lzf")
+
+def _get_samples_by_char(samples: np.ndarray, labels: np.ndarray, mapping: Dict) -> defaultdict:
+ samples_by_char = defaultdict(list)
+ for sample, label in zip(samples, labels):
+ samples_by_char[mapping[label]].append(sample)
+ return samples_by_char
+
+
+def _construct_image_from_string():
+ pass
+
+
+def _select_letter_samples_for_string(string: str, samples_by_char: defaultdict):
+ pass
+
+
+def _create_dataset_of_images(num_samples: int, samples_by_char: defaultdict, sentence_generator: SentenceGenerator, min_overlap: float, max_overlap: float, dims: Tuple) -> Tuple[torch.Tensor, torch.Tensor]:
+ images = torch.zeros((num_samples, IMAGE_HEIGHT, dims[2]))
+ labels = []
+ for n in range(num_samples):
+ label = sentence_generator.generate()
+ crop = _construct_image_from_string()