summaryrefslogtreecommitdiff
path: root/src/text_recognizer/datasets
diff options
context:
space:
mode:
authoraktersnurra <gustaf.rydholm@gmail.com>2020-08-03 23:33:34 +0200
committeraktersnurra <gustaf.rydholm@gmail.com>2020-08-03 23:33:34 +0200
commit07dd14116fe1d8148fb614b160245287533620fc (patch)
tree63395d88b17a14ad453c52889fcf541e6cbbdd3e /src/text_recognizer/datasets
parent704451318eb6b0b600ab314cb5aabfac82416bda (diff)
Working Emnist lines dataset.
Diffstat (limited to 'src/text_recognizer/datasets')
-rw-r--r--src/text_recognizer/datasets/__init__.py24
-rw-r--r--src/text_recognizer/datasets/emnist_dataset.py279
-rw-r--r--src/text_recognizer/datasets/emnist_lines_dataset.py326
-rw-r--r--src/text_recognizer/datasets/sentence_generator.py81
-rw-r--r--src/text_recognizer/datasets/util.py11
5 files changed, 645 insertions, 76 deletions
diff --git a/src/text_recognizer/datasets/__init__.py b/src/text_recognizer/datasets/__init__.py
index 795be90..bfa6a02 100644
--- a/src/text_recognizer/datasets/__init__.py
+++ b/src/text_recognizer/datasets/__init__.py
@@ -1,4 +1,24 @@
"""Dataset modules."""
-from .emnist_dataset import EmnistDataLoader
+from .emnist_dataset import (
+ DATA_DIRNAME,
+ EmnistDataLoaders,
+ EmnistDataset,
+)
+from .emnist_lines_dataset import (
+ construct_image_from_string,
+ EmnistLinesDataset,
+ get_samples_by_character,
+)
+from .sentence_generator import SentenceGenerator
+from .util import Transpose
-__all__ = ["EmnistDataLoader"]
+__all__ = [
+ "construct_image_from_string",
+ "DATA_DIRNAME",
+ "EmnistDataset",
+ "EmnistDataLoaders",
+ "EmnistLinesDataset",
+ "get_samples_by_character",
+ "SentenceGenerator",
+ "Transpose",
+]
diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py
index b92b57d..525df95 100644
--- a/src/text_recognizer/datasets/emnist_dataset.py
+++ b/src/text_recognizer/datasets/emnist_dataset.py
@@ -1,29 +1,23 @@
-"""Fetches a PyTorch DataLoader with the EMNIST dataset."""
+"""Emnist dataset: black and white images of handwritten characters (Aa-Zz) and digits (0-9)."""
import json
from pathlib import Path
-from typing import Callable, Dict, List, Optional, Type
+from typing import Callable, Dict, List, Optional, Tuple, Type, Union
from loguru import logger
import numpy as np
from PIL import Image
-from torch.utils.data import DataLoader
+import torch
+from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import EMNIST
-from torchvision.transforms import Compose, ToTensor
+from torchvision.transforms import Compose, Normalize, ToTensor
+from text_recognizer.datasets.util import Transpose
DATA_DIRNAME = Path(__file__).resolve().parents[3] / "data"
ESSENTIALS_FILENAME = Path(__file__).resolve().parents[0] / "emnist_essentials.json"
-class Transpose:
- """Transposes the EMNIST image to the correct orientation."""
-
- def __call__(self, image: Image) -> np.ndarray:
- """Swaps axis."""
- return np.array(image).swapaxes(0, 1)
-
-
def save_emnist_essentials(emnsit_dataset: type = EMNIST) -> None:
"""Extract and saves EMNIST essentials."""
labels = emnsit_dataset.classes
@@ -45,14 +39,187 @@ def download_emnist() -> None:
save_emnist_essentials(dataset)
-def load_emnist_mapping() -> Dict[int, str]:
+def _load_emnist_essentials() -> Dict:
"""Load the EMNIST mapping."""
with open(str(ESSENTIALS_FILENAME)) as f:
essentials = json.load(f)
- return dict(essentials["mapping"])
+ return essentials
+
+
+def _augment_emnist_mapping(mapping: Dict) -> Dict:
+ """Augment the mapping with extra symbols."""
+ # Extra symbols in IAM dataset
+ extra_symbols = [
+ " ",
+ "!",
+ '"',
+ "#",
+ "&",
+ "'",
+ "(",
+ ")",
+ "*",
+ "+",
+ ",",
+ "-",
+ ".",
+ "/",
+ ":",
+ ";",
+ "?",
+ ]
+
+ # padding symbol
+ extra_symbols.append("_")
+
+ max_key = max(mapping.keys())
+ extra_mapping = {}
+ for i, symbol in enumerate(extra_symbols):
+ extra_mapping[max_key + 1 + i] = symbol
+
+ return {**mapping, **extra_mapping}
+
+
+class EmnistDataset(Dataset):
+ """This is a class for resampling and subsampling the PyTorch EMNIST dataset."""
+
+ def __init__(
+ self,
+ train: bool = False,
+ sample_to_balance: bool = False,
+ subsample_fraction: float = None,
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ seed: int = 4711,
+ ) -> None:
+ """Loads the dataset and the mappings.
+
+ Args:
+ train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to
+ False.
+ sample_to_balance (bool): Resamples the dataset to make it balanced. Defaults to False.
+ subsample_fraction (float): Description of parameter `subsample_fraction`. Defaults to None.
+ transform (Optional[Callable]): Transform(s) for input data. Defaults to None.
+ target_transform (Optional[Callable]): Transform(s) for output data. Defaults to None.
+ seed (int): Seed number. Defaults to 4711.
+
+ Raises:
+ ValueError: If subsample_fraction is not None and outside the range (0, 1).
+
+ """
+
+ self.train = train
+ self.sample_to_balance = sample_to_balance
+ if subsample_fraction is not None:
+ if not 0.0 < subsample_fraction < 1.0:
+ raise ValueError("The subsample fraction must be in (0, 1).")
+ self.subsample_fraction = subsample_fraction
+ self.transform = transform
+ if self.transform is None:
+ self.transform = Compose([Transpose(), ToTensor()])
+
+ self.target_transform = target_transform
+ self.seed = seed
+
+ # Load dataset infromation.
+ essentials = _load_emnist_essentials()
+ self.mapping = _augment_emnist_mapping(dict(essentials["mapping"]))
+ self.inverse_mapping = {v: k for k, v in self.mapping.items()}
+ self.num_classes = len(self.mapping)
+ self.input_shape = essentials["input_shape"]
+
+ # Placeholders
+ self.data = None
+ self.targets = None
+
+ def __len__(self) -> int:
+ """Returns the length of the dataset."""
+ return len(self.data)
+
+ def __getitem__(
+ self, index: Union[int, torch.Tensor]
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Fetches samples from the dataset.
+
+ Args:
+ index (Union[int, torch.Tensor]): The indices of the samples to fetch.
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor]: Data target tuple.
+
+ """
+ if torch.is_tensor(index):
+ index = index.tolist()
+
+ data = self.data[index]
+ targets = self.targets[index]
+
+ if self.transform:
+ data = self.transform(data)
+
+ if self.target_transform:
+ targets = self.target_transform(targets)
+
+ return data, targets
+
+ def __repr__(self) -> str:
+ """Returns information about the dataset."""
+ return (
+ "EMNIST Dataset\n"
+ f"Num classes: {self.num_classes}\n"
+ f"Mapping: {self.mapping}\n"
+ f"Input shape: {self.input_shape}\n"
+ )
+
+ def _sample_to_balance(self) -> None:
+ """Because the dataset is not balanced, we take at most the mean number of instances per class."""
+ np.random.seed(self.seed)
+ x = self.data
+ y = self.targets
+ num_to_sample = int(np.bincount(y.flatten()).mean())
+ all_sampled_indices = []
+ for label in np.unique(y.flatten()):
+ inds = np.where(y == label)[0]
+ sampled_indices = np.unique(np.random.choice(inds, num_to_sample))
+ all_sampled_indices.append(sampled_indices)
+ indices = np.concatenate(all_sampled_indices)
+ x_sampled = x[indices]
+ y_sampled = y[indices]
+ self.data = x_sampled
+ self.targets = y_sampled
+
+ def _subsample(self) -> None:
+ """Subsamples the dataset to the specified fraction."""
+ x = self.data
+ y = self.targets
+ num_samples = int(x.shape[0] * self.subsample_fraction)
+ x_sampled = x[:num_samples]
+ y_sampled = y[:num_samples]
+ self.data = x_sampled
+ self.targets = y_sampled
+
+ def load_emnist_dataset(self) -> None:
+ """Fetch the EMNIST dataset."""
+ dataset = EMNIST(
+ root=DATA_DIRNAME,
+ split="byclass",
+ train=self.train,
+ download=False,
+ transform=None,
+ target_transform=None,
+ )
+
+ self.data = dataset.data
+ self.targets = dataset.targets
+
+ if self.sample_to_balance:
+ self._sample_to_balance()
+
+ if self.subsample_fraction is not None:
+ self._subsample()
-class EmnistDataLoader:
+class EmnistDataLoaders:
"""Class for Emnist DataLoaders."""
def __init__(
@@ -68,7 +235,7 @@ class EmnistDataLoader:
cuda: bool = True,
seed: int = 4711,
) -> None:
- """Fetches DataLoaders.
+ """Fetches DataLoaders for given split(s).
Args:
splits (List[str]): One or both of the dataset splits "train" and "val".
@@ -88,13 +255,17 @@ class EmnistDataLoader:
them. Defaults to True.
seed (int): Seed for sampling.
+ Raises:
+ ValueError: If subsample_fraction is not None and outside the range (0, 1).
+
"""
self.splits = splits
self.sample_to_balance = sample_to_balance
+
if subsample_fraction is not None:
- assert (
- 0.0 < subsample_fraction < 1.0
- ), " The subsample fraction must be in (0, 1)."
+ if not 0.0 < subsample_fraction < 1.0:
+ raise ValueError("The subsample fraction must be in (0, 1).")
+
self.subsample_fraction = subsample_fraction
self.transform = transform
self.target_transform = target_transform
@@ -105,6 +276,10 @@ class EmnistDataLoader:
self.seed = seed
self._data_loaders = self._fetch_emnist_data_loaders()
+ def __repr__(self) -> str:
+ """Returns information about the dataset."""
+ return self._data_loaders[self.splits[0]].dataset.__repr__()
+
@property
def __name__(self) -> str:
"""Returns the name of the dataset."""
@@ -128,59 +303,6 @@ class EmnistDataLoader:
except KeyError:
raise ValueError(f"Split {split} does not exist.")
- def _sample_to_balance(self, dataset: type = EMNIST) -> EMNIST:
- """Because the dataset is not balanced, we take at most the mean number of instances per class."""
- np.random.seed(self.seed)
- x = dataset.data
- y = dataset.targets
- num_to_sample = int(np.bincount(y.flatten()).mean())
- all_sampled_indices = []
- for label in np.unique(y.flatten()):
- inds = np.where(y == label)[0]
- sampled_indices = np.unique(np.random.choice(inds, num_to_sample))
- all_sampled_indices.append(sampled_indices)
- indices = np.concatenate(all_sampled_indices)
- x_sampled = x[indices]
- y_sampled = y[indices]
- dataset.data = x_sampled
- dataset.targets = y_sampled
-
- return dataset
-
- def _subsample(self, dataset: type = EMNIST) -> EMNIST:
- """Subsamples the dataset to the specified fraction."""
- x = dataset.data
- y = dataset.targets
- num_samples = int(x.shape[0] * self.subsample_fraction)
- x_sampled = x[:num_samples]
- y_sampled = y[:num_samples]
- dataset.data = x_sampled
- dataset.targets = y_sampled
-
- return dataset
-
- def _fetch_emnist_dataset(self, train: bool) -> EMNIST:
- """Fetch the EMNIST dataset."""
- if self.transform is None:
- transform = Compose([Transpose(), ToTensor()])
-
- dataset = EMNIST(
- root=DATA_DIRNAME,
- split="byclass",
- train=train,
- download=False,
- transform=transform,
- target_transform=self.target_transform,
- )
-
- if self.sample_to_balance:
- dataset = self._sample_to_balance(dataset)
-
- if self.subsample_fraction is not None:
- dataset = self._subsample(dataset)
-
- return dataset
-
def _fetch_emnist_data_loaders(self) -> Dict[str, DataLoader]:
"""Fetches the EMNIST dataset and return a Dict of PyTorch DataLoaders."""
data_loaders = {}
@@ -193,10 +315,19 @@ class EmnistDataLoader:
else:
train = False
- dataset = self._fetch_emnist_dataset(train)
+ emnist_dataset = EmnistDataset(
+ train=train,
+ sample_to_balance=self.sample_to_balance,
+ subsample_fraction=self.subsample_fraction,
+ transform=self.transform,
+ target_transform=self.target_transform,
+ seed=self.seed,
+ )
+
+ emnist_dataset.load_emnist_dataset()
data_loader = DataLoader(
- dataset=dataset,
+ dataset=emnist_dataset,
batch_size=self.batch_size,
shuffle=self.shuffle,
num_workers=self.num_workers,
diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py
new file mode 100644
index 0000000..d49319f
--- /dev/null
+++ b/src/text_recognizer/datasets/emnist_lines_dataset.py
@@ -0,0 +1,326 @@
+"""Emnist Lines dataset: synthetic handwritten lines dataset made from Emnist characters."""
+
+from collections import defaultdict
+from pathlib import Path
+from typing import Callable, Dict, List, Optional, Tuple, Union
+
+import h5py
+from loguru import logger
+import numpy as np
+import torch
+from torch.utils.data import Dataset
+from torchvision.transforms import Compose, Normalize, ToTensor
+
+from text_recognizer.datasets import DATA_DIRNAME, EmnistDataset, SentenceGenerator
+from text_recognizer.datasets.util import Transpose
+
+DATA_DIRNAME = DATA_DIRNAME / "processed" / "emnist_lines"
+ESSENTIALS_FILENAME = (
+ Path(__file__).resolve().parents[0] / "emnist_lines_essentials.json"
+)
+
+
+class EmnistLinesDataset(Dataset):
+ """Synthetic dataset of lines from the Brown corpus with Emnist characters."""
+
+ def __init__(
+ self,
+ emnist: EmnistDataset,
+ train: bool = False,
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ max_length: int = 34,
+ min_overlap: float = 0,
+ max_overlap: float = 0.33,
+ num_samples: int = 10000,
+ seed: int = 4711,
+ ) -> None:
+ """Short summary.
+
+ Args:
+ emnist (EmnistDataset): A EmnistDataset object.
+ train (bool): Flag for the filename. Defaults to False.
+ transform (Optional[Callable]): The transform of the data. Defaults to None.
+ target_transform (Optional[Callable]): The transform of the target. Defaults to None.
+ max_length (int): The maximum number of characters. Defaults to 34.
+ min_overlap (float): The minimum overlap between concatenated images. Defaults to 0.
+ max_overlap (float): The maximum overlap between concatenated images. Defaults to 0.33.
+ num_samples (int): Number of samples to generate. Defaults to 10000.
+ seed (int): Seed number. Defaults to 4711.
+
+ """
+ self.train = train
+ self.emnist = emnist
+
+ self.transform = transform
+ if self.transform is None:
+ self.transform = Compose([ToTensor()])
+
+ self.target_transform = target_transform
+ if self.target_transform is None:
+ self.target_transform = torch.tensor
+
+ self.mapping = self.emnist.mapping
+ self.num_classes = self.emnist.num_classes
+ self.max_length = max_length
+ self.min_overlap = min_overlap
+ self.max_overlap = max_overlap
+ self.num_samples = num_samples
+ self.input_shape = (
+ self.emnist.input_shape[0],
+ self.emnist.input_shape[1] * self.max_length,
+ )
+ self.output_shape = (self.max_length, self.num_classes)
+ self.seed = seed
+
+ # Placeholders for the generated dataset.
+ self.data = None
+ self.target = None
+
+ def __len__(self) -> int:
+ """Returns the length of the dataset."""
+ return len(self.data)
+
+ def __getitem__(
+ self, index: Union[int, torch.Tensor]
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Fetches data, target pair of the dataset for a given and index or indices.
+
+ Args:
+ index (Union[int, torch.Tensor]): Either a list or int of indices/index.
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor]: Data target pair.
+
+ """
+ if torch.is_tensor(index):
+ index = index.tolist()
+
+ # data = np.array([self.data[index]])
+ data = self.data[index]
+ targets = self.targets[index]
+
+ if self.transform:
+ data = self.transform(data)
+
+ if self.target_transform:
+ targets = self.target_transform(targets)
+
+ return data, targets
+
+ def __repr__(self) -> str:
+ """Returns information about the dataset."""
+ return (
+ "EMNIST Lines Dataset\n" # pylint: disable=no-member
+ f"Max length: {self.max_length}\n"
+ f"Min overlap: {self.min_overlap}\n"
+ f"Max overlap: {self.max_overlap}\n"
+ f"Num classes: {self.num_classes}\n"
+ f"Input shape: {self.input_shape}\n"
+ f"Data: {self.data.shape}\n"
+ f"Tagets: {self.targets.shape}\n"
+ )
+
+ @property
+ def data_filename(self) -> Path:
+ """Path to the h5 file."""
+ filename = f"ml_{self.max_length}_o{self.min_overlap}_{self.max_overlap}_n{self.num_samples}.pt"
+ if self.train:
+ filename = "train_" + filename
+ else:
+ filename = "val_" + filename
+ return DATA_DIRNAME / filename
+
+ def _load_or_generate_data(self) -> None:
+ """Loads the dataset, if it does not exist a new dataset is generated before loading it."""
+ np.random.seed(self.seed)
+
+ if not self.data_filename.exists():
+ self._generate_data()
+ self._load_data()
+
+ def _load_data(self) -> None:
+ """Loads the dataset from the h5 file."""
+ logger.debug("EmnistLinesDataset loading data from HDF5...")
+ with h5py.File(self.data_filename, "r") as f:
+ self.data = f["data"][:]
+ self.targets = f["targets"][:]
+
+ def _generate_data(self) -> str:
+ """Generates a dataset with the Brown corpus and Emnist characters."""
+ logger.debug("Generating data...")
+
+ sentence_generator = SentenceGenerator(self.max_length)
+
+ # Load emnist dataset.
+ self.emnist.load_emnist_dataset()
+ samples_by_character = get_samples_by_character(
+ self.emnist.data.numpy(), self.emnist.targets.numpy(), self.emnist.mapping,
+ )
+
+ DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
+ with h5py.File(self.data_filename, "a") as f:
+ data, targets = create_dataset_of_images(
+ self.num_samples,
+ samples_by_character,
+ sentence_generator,
+ self.min_overlap,
+ self.max_overlap,
+ )
+
+ targets = convert_strings_to_categorical_labels(
+ targets, self.emnist.inverse_mapping
+ )
+
+ f.create_dataset("data", data=data, dtype="u1", compression="lzf")
+ f.create_dataset("targets", data=targets, dtype="u1", compression="lzf")
+
+
+def get_samples_by_character(
+ samples: np.ndarray, labels: np.ndarray, mapping: Dict
+) -> defaultdict:
+ """Creates a dictionary with character as key and value as the list of images of that character.
+
+ Args:
+ samples (np.ndarray): Dataset of images of characters.
+ labels (np.ndarray): The labels for each image.
+ mapping (Dict): The Emnist mapping dictionary.
+
+ Returns:
+ defaultdict: A dictionary with characters as keys and list of images as values.
+
+ """
+ samples_by_character = defaultdict(list)
+ for sample, label in zip(samples, labels.flatten()):
+ samples_by_character[mapping[label]].append(sample)
+ return samples_by_character
+
+
+def select_letter_samples_for_string(
+ string: str, samples_by_character: Dict
+) -> List[np.ndarray]:
+ """Randomly selects Emnist characters to use for the senetence.
+
+ Args:
+ string (str): The word or sentence.
+ samples_by_character (Dict): The dictionary of emnist images of each character.
+
+ Returns:
+ List[np.ndarray]: A list of emnist images of the string.
+
+ """
+ zero_image = np.zeros((28, 28), np.uint8)
+ sample_image_by_character = {}
+ for character in string:
+ if character in sample_image_by_character:
+ continue
+ samples = samples_by_character[character]
+ sample = samples[np.random.choice(len(samples))] if samples else zero_image
+ sample_image_by_character[character] = sample.reshape(28, 28).swapaxes(0, 1)
+ return [sample_image_by_character[character] for character in string]
+
+
+def construct_image_from_string(
+ string: str, samples_by_character: Dict, min_overlap: float, max_overlap: float
+) -> np.ndarray:
+ """Concatenates images of the characters in the string.
+
+ The concatination is made with randomly selected overlap so that some portion of the character will overlap.
+
+ Args:
+ string (str): The word or sentence.
+ samples_by_character (Dict): The dictionary of emnist images of each character.
+ min_overlap (float): Minimum amount of overlap between Emnist images.
+ max_overlap (float): Maximum amount of overlap between Emnist images.
+
+ Returns:
+ np.ndarray: The Emnist image of the string.
+
+ """
+ overlap = np.random.uniform(min_overlap, max_overlap)
+ sampled_images = select_letter_samples_for_string(string, samples_by_character)
+ length = len(sampled_images)
+ height, width = sampled_images[0].shape
+ next_overlap_width = width - int(overlap * width)
+ concatenated_image = np.zeros((height, width * length), np.uint8)
+ x = 0
+ for image in sampled_images:
+ concatenated_image[:, x : (x + width)] += image
+ x += next_overlap_width
+ return np.minimum(255, concatenated_image)
+
+
+def create_dataset_of_images(
+ length: int,
+ samples_by_character: Dict,
+ sentence_generator: SentenceGenerator,
+ min_overlap: float,
+ max_overlap: float,
+) -> Tuple[np.ndarray, List[str]]:
+ """Creates a dataset with images and labels from strings generated from the SentenceGenerator.
+
+ Args:
+ length (int): The number of characters for each string.
+ samples_by_character (Dict): The dictionary of emnist images of each character.
+ sentence_generator (SentenceGenerator): A SentenceGenerator objest.
+ min_overlap (float): Minimum amount of overlap between Emnist images.
+ max_overlap (float): Maximum amount of overlap between Emnist images.
+
+ Returns:
+ Tuple[np.ndarray, List[str]]: A list of Emnist images and a list of the strings (labels).
+
+ Raises:
+ RuntimeError: If the sentence generator is not able to generate a string.
+
+ """
+ sample_label = sentence_generator.generate()
+ sample_image = construct_image_from_string(sample_label, samples_by_character, 0, 0)
+ images = np.zeros((length, sample_image.shape[0], sample_image.shape[1]), np.uint8)
+ labels = []
+ for n in range(length):
+ label = None
+ # Try several times to generate before actually throwing an error.
+ for _ in range(10):
+ try:
+ label = sentence_generator.generate()
+ break
+ except Exception: # pylint: disable=broad-except
+ pass
+ if label is None:
+ raise RuntimeError("Was not able to generate a valid string.")
+ images[n] = construct_image_from_string(
+ label, samples_by_character, min_overlap, max_overlap
+ )
+ labels.append(label)
+ return images, labels
+
+
+def convert_strings_to_categorical_labels(
+ labels: List[str], mapping: Dict
+) -> np.ndarray:
+ """Translates a string of characters in to a target array of class int."""
+ return np.array([[mapping[c] for c in label] for label in labels])
+
+
+def create_datasets(
+ max_length: int = 34,
+ min_overlap: float = 0,
+ max_overlap: float = 0.33,
+ num_train: int = 10000,
+ num_val: int = 1000,
+) -> None:
+ """Creates a training an validation dataset of Emnist lines."""
+ emnist_train = EmnistDataset(train=True, sample_to_balance=True)
+ emnist_val = EmnistDataset(train=False, sample_to_balance=True)
+ datasets = [emnist_train, emnist_val]
+ num_samples = [num_train, num_val]
+ for num, train, dataset in zip(num_samples, [True, False], datasets):
+ emnist_lines = EmnistLinesDataset(
+ train=train,
+ emnist=dataset,
+ max_length=max_length,
+ min_overlap=min_overlap,
+ max_overlap=max_overlap,
+ num_samples=num,
+ )
+ emnist_lines._load_or_generate_data()
diff --git a/src/text_recognizer/datasets/sentence_generator.py b/src/text_recognizer/datasets/sentence_generator.py
new file mode 100644
index 0000000..ee86bd4
--- /dev/null
+++ b/src/text_recognizer/datasets/sentence_generator.py
@@ -0,0 +1,81 @@
+"""Downloading the Brown corpus with NLTK for sentence generating."""
+
+import itertools
+import re
+import string
+from typing import Optional
+
+import nltk
+from nltk.corpus.reader.util import ConcatenatedCorpusView
+import numpy as np
+
+from text_recognizer.datasets import DATA_DIRNAME
+
+NLTK_DATA_DIRNAME = DATA_DIRNAME / "raw" / "nltk"
+
+
+class SentenceGenerator:
+ """Generates text sentences using the Brown corpus."""
+
+ def __init__(self, max_length: Optional[int] = None) -> None:
+ """Loads the corpus and sets word start indices."""
+ self.corpus = brown_corpus()
+ self.word_start_indices = [0] + [
+ _.start(0) + 1 for _ in re.finditer(" ", self.corpus)
+ ]
+ self.max_length = max_length
+
+ def generate(self, max_length: Optional[int] = None) -> str:
+ """Generates a word or sentences from the Brown corpus.
+
+ Sample a string from the Brown corpus of length at least one word and at most max_length, padding to
+ max_length with the '_' characters if sentence is shorter.
+
+ Args:
+ max_length (Optional[int]): The maximum number of characters in the sentence. Defaults to None.
+
+ Returns:
+ str: A sentence from the Brown corpus.
+
+ Raises:
+ ValueError: If max_length was not specified at initialization and not given as an argument.
+
+ """
+ if max_length is None:
+ max_length = self.max_length
+ if max_length is None:
+ raise ValueError(
+ "Must provide max_length to this method or when making this object."
+ )
+
+ index = np.random.randint(0, len(self.word_start_indices) - 1)
+ start_index = self.word_start_indices[index]
+ end_index_candidates = []
+ for index in range(index + 1, len(self.word_start_indices)):
+ if self.word_start_indices[index] - start_index > max_length:
+ break
+ end_index_candidates.append(self.word_start_indices[index])
+ end_index = np.random.choice(end_index_candidates)
+ sampled_text = self.corpus[start_index:end_index].strip()
+ padding = "_" * (max_length - len(sampled_text))
+ return sampled_text + padding
+
+
+def brown_corpus() -> str:
+ """Returns a single string with the Brown corpus with all punctuations stripped."""
+ sentences = load_nltk_brown_corpus()
+ corpus = " ".join(itertools.chain.from_iterable(sentences))
+ corpus = corpus.translate({ord(c): None for c in string.punctuation})
+ corpus = re.sub(" +", " ", corpus)
+ return corpus
+
+
+def load_nltk_brown_corpus() -> ConcatenatedCorpusView:
+ """Load the Brown corpus using the NLTK library."""
+ nltk.data.path.append(NLTK_DATA_DIRNAME)
+ try:
+ nltk.corpus.brown.sents()
+ except LookupError:
+ NLTK_DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
+ nltk.download("brown", download_dir=NLTK_DATA_DIRNAME)
+ return nltk.corpus.brown.sents()
diff --git a/src/text_recognizer/datasets/util.py b/src/text_recognizer/datasets/util.py
new file mode 100644
index 0000000..6668eef
--- /dev/null
+++ b/src/text_recognizer/datasets/util.py
@@ -0,0 +1,11 @@
+"""Util functions for datasets."""
+import numpy as np
+from PIL import Image
+
+
+class Transpose:
+ """Transposes the EMNIST image to the correct orientation."""
+
+ def __call__(self, image: Image) -> np.ndarray:
+ """Swaps axis."""
+ return np.array(image).swapaxes(0, 1)