diff options
Diffstat (limited to 'src/text_recognizer/datasets')
-rw-r--r-- | src/text_recognizer/datasets/__init__.py | 24 | ||||
-rw-r--r-- | src/text_recognizer/datasets/emnist_dataset.py | 279 | ||||
-rw-r--r-- | src/text_recognizer/datasets/emnist_lines_dataset.py | 326 | ||||
-rw-r--r-- | src/text_recognizer/datasets/sentence_generator.py | 81 | ||||
-rw-r--r-- | src/text_recognizer/datasets/util.py | 11 |
5 files changed, 645 insertions, 76 deletions
diff --git a/src/text_recognizer/datasets/__init__.py b/src/text_recognizer/datasets/__init__.py index 795be90..bfa6a02 100644 --- a/src/text_recognizer/datasets/__init__.py +++ b/src/text_recognizer/datasets/__init__.py @@ -1,4 +1,24 @@ """Dataset modules.""" -from .emnist_dataset import EmnistDataLoader +from .emnist_dataset import ( + DATA_DIRNAME, + EmnistDataLoaders, + EmnistDataset, +) +from .emnist_lines_dataset import ( + construct_image_from_string, + EmnistLinesDataset, + get_samples_by_character, +) +from .sentence_generator import SentenceGenerator +from .util import Transpose -__all__ = ["EmnistDataLoader"] +__all__ = [ + "construct_image_from_string", + "DATA_DIRNAME", + "EmnistDataset", + "EmnistDataLoaders", + "EmnistLinesDataset", + "get_samples_by_character", + "SentenceGenerator", + "Transpose", +] diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py index b92b57d..525df95 100644 --- a/src/text_recognizer/datasets/emnist_dataset.py +++ b/src/text_recognizer/datasets/emnist_dataset.py @@ -1,29 +1,23 @@ -"""Fetches a PyTorch DataLoader with the EMNIST dataset.""" +"""Emnist dataset: black and white images of handwritten characters (Aa-Zz) and digits (0-9).""" import json from pathlib import Path -from typing import Callable, Dict, List, Optional, Type +from typing import Callable, Dict, List, Optional, Tuple, Type, Union from loguru import logger import numpy as np from PIL import Image -from torch.utils.data import DataLoader +import torch +from torch.utils.data import DataLoader, Dataset from torchvision.datasets import EMNIST -from torchvision.transforms import Compose, ToTensor +from torchvision.transforms import Compose, Normalize, ToTensor +from text_recognizer.datasets.util import Transpose DATA_DIRNAME = Path(__file__).resolve().parents[3] / "data" ESSENTIALS_FILENAME = Path(__file__).resolve().parents[0] / "emnist_essentials.json" -class Transpose: - """Transposes the EMNIST image to the correct orientation.""" - - def __call__(self, image: Image) -> np.ndarray: - """Swaps axis.""" - return np.array(image).swapaxes(0, 1) - - def save_emnist_essentials(emnsit_dataset: type = EMNIST) -> None: """Extract and saves EMNIST essentials.""" labels = emnsit_dataset.classes @@ -45,14 +39,187 @@ def download_emnist() -> None: save_emnist_essentials(dataset) -def load_emnist_mapping() -> Dict[int, str]: +def _load_emnist_essentials() -> Dict: """Load the EMNIST mapping.""" with open(str(ESSENTIALS_FILENAME)) as f: essentials = json.load(f) - return dict(essentials["mapping"]) + return essentials + + +def _augment_emnist_mapping(mapping: Dict) -> Dict: + """Augment the mapping with extra symbols.""" + # Extra symbols in IAM dataset + extra_symbols = [ + " ", + "!", + '"', + "#", + "&", + "'", + "(", + ")", + "*", + "+", + ",", + "-", + ".", + "/", + ":", + ";", + "?", + ] + + # padding symbol + extra_symbols.append("_") + + max_key = max(mapping.keys()) + extra_mapping = {} + for i, symbol in enumerate(extra_symbols): + extra_mapping[max_key + 1 + i] = symbol + + return {**mapping, **extra_mapping} + + +class EmnistDataset(Dataset): + """This is a class for resampling and subsampling the PyTorch EMNIST dataset.""" + + def __init__( + self, + train: bool = False, + sample_to_balance: bool = False, + subsample_fraction: float = None, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + seed: int = 4711, + ) -> None: + """Loads the dataset and the mappings. + + Args: + train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to + False. + sample_to_balance (bool): Resamples the dataset to make it balanced. Defaults to False. + subsample_fraction (float): Description of parameter `subsample_fraction`. Defaults to None. + transform (Optional[Callable]): Transform(s) for input data. Defaults to None. + target_transform (Optional[Callable]): Transform(s) for output data. Defaults to None. + seed (int): Seed number. Defaults to 4711. + + Raises: + ValueError: If subsample_fraction is not None and outside the range (0, 1). + + """ + + self.train = train + self.sample_to_balance = sample_to_balance + if subsample_fraction is not None: + if not 0.0 < subsample_fraction < 1.0: + raise ValueError("The subsample fraction must be in (0, 1).") + self.subsample_fraction = subsample_fraction + self.transform = transform + if self.transform is None: + self.transform = Compose([Transpose(), ToTensor()]) + + self.target_transform = target_transform + self.seed = seed + + # Load dataset infromation. + essentials = _load_emnist_essentials() + self.mapping = _augment_emnist_mapping(dict(essentials["mapping"])) + self.inverse_mapping = {v: k for k, v in self.mapping.items()} + self.num_classes = len(self.mapping) + self.input_shape = essentials["input_shape"] + + # Placeholders + self.data = None + self.targets = None + + def __len__(self) -> int: + """Returns the length of the dataset.""" + return len(self.data) + + def __getitem__( + self, index: Union[int, torch.Tensor] + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Fetches samples from the dataset. + + Args: + index (Union[int, torch.Tensor]): The indices of the samples to fetch. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Data target tuple. + + """ + if torch.is_tensor(index): + index = index.tolist() + + data = self.data[index] + targets = self.targets[index] + + if self.transform: + data = self.transform(data) + + if self.target_transform: + targets = self.target_transform(targets) + + return data, targets + + def __repr__(self) -> str: + """Returns information about the dataset.""" + return ( + "EMNIST Dataset\n" + f"Num classes: {self.num_classes}\n" + f"Mapping: {self.mapping}\n" + f"Input shape: {self.input_shape}\n" + ) + + def _sample_to_balance(self) -> None: + """Because the dataset is not balanced, we take at most the mean number of instances per class.""" + np.random.seed(self.seed) + x = self.data + y = self.targets + num_to_sample = int(np.bincount(y.flatten()).mean()) + all_sampled_indices = [] + for label in np.unique(y.flatten()): + inds = np.where(y == label)[0] + sampled_indices = np.unique(np.random.choice(inds, num_to_sample)) + all_sampled_indices.append(sampled_indices) + indices = np.concatenate(all_sampled_indices) + x_sampled = x[indices] + y_sampled = y[indices] + self.data = x_sampled + self.targets = y_sampled + + def _subsample(self) -> None: + """Subsamples the dataset to the specified fraction.""" + x = self.data + y = self.targets + num_samples = int(x.shape[0] * self.subsample_fraction) + x_sampled = x[:num_samples] + y_sampled = y[:num_samples] + self.data = x_sampled + self.targets = y_sampled + + def load_emnist_dataset(self) -> None: + """Fetch the EMNIST dataset.""" + dataset = EMNIST( + root=DATA_DIRNAME, + split="byclass", + train=self.train, + download=False, + transform=None, + target_transform=None, + ) + + self.data = dataset.data + self.targets = dataset.targets + + if self.sample_to_balance: + self._sample_to_balance() + + if self.subsample_fraction is not None: + self._subsample() -class EmnistDataLoader: +class EmnistDataLoaders: """Class for Emnist DataLoaders.""" def __init__( @@ -68,7 +235,7 @@ class EmnistDataLoader: cuda: bool = True, seed: int = 4711, ) -> None: - """Fetches DataLoaders. + """Fetches DataLoaders for given split(s). Args: splits (List[str]): One or both of the dataset splits "train" and "val". @@ -88,13 +255,17 @@ class EmnistDataLoader: them. Defaults to True. seed (int): Seed for sampling. + Raises: + ValueError: If subsample_fraction is not None and outside the range (0, 1). + """ self.splits = splits self.sample_to_balance = sample_to_balance + if subsample_fraction is not None: - assert ( - 0.0 < subsample_fraction < 1.0 - ), " The subsample fraction must be in (0, 1)." + if not 0.0 < subsample_fraction < 1.0: + raise ValueError("The subsample fraction must be in (0, 1).") + self.subsample_fraction = subsample_fraction self.transform = transform self.target_transform = target_transform @@ -105,6 +276,10 @@ class EmnistDataLoader: self.seed = seed self._data_loaders = self._fetch_emnist_data_loaders() + def __repr__(self) -> str: + """Returns information about the dataset.""" + return self._data_loaders[self.splits[0]].dataset.__repr__() + @property def __name__(self) -> str: """Returns the name of the dataset.""" @@ -128,59 +303,6 @@ class EmnistDataLoader: except KeyError: raise ValueError(f"Split {split} does not exist.") - def _sample_to_balance(self, dataset: type = EMNIST) -> EMNIST: - """Because the dataset is not balanced, we take at most the mean number of instances per class.""" - np.random.seed(self.seed) - x = dataset.data - y = dataset.targets - num_to_sample = int(np.bincount(y.flatten()).mean()) - all_sampled_indices = [] - for label in np.unique(y.flatten()): - inds = np.where(y == label)[0] - sampled_indices = np.unique(np.random.choice(inds, num_to_sample)) - all_sampled_indices.append(sampled_indices) - indices = np.concatenate(all_sampled_indices) - x_sampled = x[indices] - y_sampled = y[indices] - dataset.data = x_sampled - dataset.targets = y_sampled - - return dataset - - def _subsample(self, dataset: type = EMNIST) -> EMNIST: - """Subsamples the dataset to the specified fraction.""" - x = dataset.data - y = dataset.targets - num_samples = int(x.shape[0] * self.subsample_fraction) - x_sampled = x[:num_samples] - y_sampled = y[:num_samples] - dataset.data = x_sampled - dataset.targets = y_sampled - - return dataset - - def _fetch_emnist_dataset(self, train: bool) -> EMNIST: - """Fetch the EMNIST dataset.""" - if self.transform is None: - transform = Compose([Transpose(), ToTensor()]) - - dataset = EMNIST( - root=DATA_DIRNAME, - split="byclass", - train=train, - download=False, - transform=transform, - target_transform=self.target_transform, - ) - - if self.sample_to_balance: - dataset = self._sample_to_balance(dataset) - - if self.subsample_fraction is not None: - dataset = self._subsample(dataset) - - return dataset - def _fetch_emnist_data_loaders(self) -> Dict[str, DataLoader]: """Fetches the EMNIST dataset and return a Dict of PyTorch DataLoaders.""" data_loaders = {} @@ -193,10 +315,19 @@ class EmnistDataLoader: else: train = False - dataset = self._fetch_emnist_dataset(train) + emnist_dataset = EmnistDataset( + train=train, + sample_to_balance=self.sample_to_balance, + subsample_fraction=self.subsample_fraction, + transform=self.transform, + target_transform=self.target_transform, + seed=self.seed, + ) + + emnist_dataset.load_emnist_dataset() data_loader = DataLoader( - dataset=dataset, + dataset=emnist_dataset, batch_size=self.batch_size, shuffle=self.shuffle, num_workers=self.num_workers, diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py new file mode 100644 index 0000000..d49319f --- /dev/null +++ b/src/text_recognizer/datasets/emnist_lines_dataset.py @@ -0,0 +1,326 @@ +"""Emnist Lines dataset: synthetic handwritten lines dataset made from Emnist characters.""" + +from collections import defaultdict +from pathlib import Path +from typing import Callable, Dict, List, Optional, Tuple, Union + +import h5py +from loguru import logger +import numpy as np +import torch +from torch.utils.data import Dataset +from torchvision.transforms import Compose, Normalize, ToTensor + +from text_recognizer.datasets import DATA_DIRNAME, EmnistDataset, SentenceGenerator +from text_recognizer.datasets.util import Transpose + +DATA_DIRNAME = DATA_DIRNAME / "processed" / "emnist_lines" +ESSENTIALS_FILENAME = ( + Path(__file__).resolve().parents[0] / "emnist_lines_essentials.json" +) + + +class EmnistLinesDataset(Dataset): + """Synthetic dataset of lines from the Brown corpus with Emnist characters.""" + + def __init__( + self, + emnist: EmnistDataset, + train: bool = False, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + max_length: int = 34, + min_overlap: float = 0, + max_overlap: float = 0.33, + num_samples: int = 10000, + seed: int = 4711, + ) -> None: + """Short summary. + + Args: + emnist (EmnistDataset): A EmnistDataset object. + train (bool): Flag for the filename. Defaults to False. + transform (Optional[Callable]): The transform of the data. Defaults to None. + target_transform (Optional[Callable]): The transform of the target. Defaults to None. + max_length (int): The maximum number of characters. Defaults to 34. + min_overlap (float): The minimum overlap between concatenated images. Defaults to 0. + max_overlap (float): The maximum overlap between concatenated images. Defaults to 0.33. + num_samples (int): Number of samples to generate. Defaults to 10000. + seed (int): Seed number. Defaults to 4711. + + """ + self.train = train + self.emnist = emnist + + self.transform = transform + if self.transform is None: + self.transform = Compose([ToTensor()]) + + self.target_transform = target_transform + if self.target_transform is None: + self.target_transform = torch.tensor + + self.mapping = self.emnist.mapping + self.num_classes = self.emnist.num_classes + self.max_length = max_length + self.min_overlap = min_overlap + self.max_overlap = max_overlap + self.num_samples = num_samples + self.input_shape = ( + self.emnist.input_shape[0], + self.emnist.input_shape[1] * self.max_length, + ) + self.output_shape = (self.max_length, self.num_classes) + self.seed = seed + + # Placeholders for the generated dataset. + self.data = None + self.target = None + + def __len__(self) -> int: + """Returns the length of the dataset.""" + return len(self.data) + + def __getitem__( + self, index: Union[int, torch.Tensor] + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Fetches data, target pair of the dataset for a given and index or indices. + + Args: + index (Union[int, torch.Tensor]): Either a list or int of indices/index. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Data target pair. + + """ + if torch.is_tensor(index): + index = index.tolist() + + # data = np.array([self.data[index]]) + data = self.data[index] + targets = self.targets[index] + + if self.transform: + data = self.transform(data) + + if self.target_transform: + targets = self.target_transform(targets) + + return data, targets + + def __repr__(self) -> str: + """Returns information about the dataset.""" + return ( + "EMNIST Lines Dataset\n" # pylint: disable=no-member + f"Max length: {self.max_length}\n" + f"Min overlap: {self.min_overlap}\n" + f"Max overlap: {self.max_overlap}\n" + f"Num classes: {self.num_classes}\n" + f"Input shape: {self.input_shape}\n" + f"Data: {self.data.shape}\n" + f"Tagets: {self.targets.shape}\n" + ) + + @property + def data_filename(self) -> Path: + """Path to the h5 file.""" + filename = f"ml_{self.max_length}_o{self.min_overlap}_{self.max_overlap}_n{self.num_samples}.pt" + if self.train: + filename = "train_" + filename + else: + filename = "val_" + filename + return DATA_DIRNAME / filename + + def _load_or_generate_data(self) -> None: + """Loads the dataset, if it does not exist a new dataset is generated before loading it.""" + np.random.seed(self.seed) + + if not self.data_filename.exists(): + self._generate_data() + self._load_data() + + def _load_data(self) -> None: + """Loads the dataset from the h5 file.""" + logger.debug("EmnistLinesDataset loading data from HDF5...") + with h5py.File(self.data_filename, "r") as f: + self.data = f["data"][:] + self.targets = f["targets"][:] + + def _generate_data(self) -> str: + """Generates a dataset with the Brown corpus and Emnist characters.""" + logger.debug("Generating data...") + + sentence_generator = SentenceGenerator(self.max_length) + + # Load emnist dataset. + self.emnist.load_emnist_dataset() + samples_by_character = get_samples_by_character( + self.emnist.data.numpy(), self.emnist.targets.numpy(), self.emnist.mapping, + ) + + DATA_DIRNAME.mkdir(parents=True, exist_ok=True) + with h5py.File(self.data_filename, "a") as f: + data, targets = create_dataset_of_images( + self.num_samples, + samples_by_character, + sentence_generator, + self.min_overlap, + self.max_overlap, + ) + + targets = convert_strings_to_categorical_labels( + targets, self.emnist.inverse_mapping + ) + + f.create_dataset("data", data=data, dtype="u1", compression="lzf") + f.create_dataset("targets", data=targets, dtype="u1", compression="lzf") + + +def get_samples_by_character( + samples: np.ndarray, labels: np.ndarray, mapping: Dict +) -> defaultdict: + """Creates a dictionary with character as key and value as the list of images of that character. + + Args: + samples (np.ndarray): Dataset of images of characters. + labels (np.ndarray): The labels for each image. + mapping (Dict): The Emnist mapping dictionary. + + Returns: + defaultdict: A dictionary with characters as keys and list of images as values. + + """ + samples_by_character = defaultdict(list) + for sample, label in zip(samples, labels.flatten()): + samples_by_character[mapping[label]].append(sample) + return samples_by_character + + +def select_letter_samples_for_string( + string: str, samples_by_character: Dict +) -> List[np.ndarray]: + """Randomly selects Emnist characters to use for the senetence. + + Args: + string (str): The word or sentence. + samples_by_character (Dict): The dictionary of emnist images of each character. + + Returns: + List[np.ndarray]: A list of emnist images of the string. + + """ + zero_image = np.zeros((28, 28), np.uint8) + sample_image_by_character = {} + for character in string: + if character in sample_image_by_character: + continue + samples = samples_by_character[character] + sample = samples[np.random.choice(len(samples))] if samples else zero_image + sample_image_by_character[character] = sample.reshape(28, 28).swapaxes(0, 1) + return [sample_image_by_character[character] for character in string] + + +def construct_image_from_string( + string: str, samples_by_character: Dict, min_overlap: float, max_overlap: float +) -> np.ndarray: + """Concatenates images of the characters in the string. + + The concatination is made with randomly selected overlap so that some portion of the character will overlap. + + Args: + string (str): The word or sentence. + samples_by_character (Dict): The dictionary of emnist images of each character. + min_overlap (float): Minimum amount of overlap between Emnist images. + max_overlap (float): Maximum amount of overlap between Emnist images. + + Returns: + np.ndarray: The Emnist image of the string. + + """ + overlap = np.random.uniform(min_overlap, max_overlap) + sampled_images = select_letter_samples_for_string(string, samples_by_character) + length = len(sampled_images) + height, width = sampled_images[0].shape + next_overlap_width = width - int(overlap * width) + concatenated_image = np.zeros((height, width * length), np.uint8) + x = 0 + for image in sampled_images: + concatenated_image[:, x : (x + width)] += image + x += next_overlap_width + return np.minimum(255, concatenated_image) + + +def create_dataset_of_images( + length: int, + samples_by_character: Dict, + sentence_generator: SentenceGenerator, + min_overlap: float, + max_overlap: float, +) -> Tuple[np.ndarray, List[str]]: + """Creates a dataset with images and labels from strings generated from the SentenceGenerator. + + Args: + length (int): The number of characters for each string. + samples_by_character (Dict): The dictionary of emnist images of each character. + sentence_generator (SentenceGenerator): A SentenceGenerator objest. + min_overlap (float): Minimum amount of overlap between Emnist images. + max_overlap (float): Maximum amount of overlap between Emnist images. + + Returns: + Tuple[np.ndarray, List[str]]: A list of Emnist images and a list of the strings (labels). + + Raises: + RuntimeError: If the sentence generator is not able to generate a string. + + """ + sample_label = sentence_generator.generate() + sample_image = construct_image_from_string(sample_label, samples_by_character, 0, 0) + images = np.zeros((length, sample_image.shape[0], sample_image.shape[1]), np.uint8) + labels = [] + for n in range(length): + label = None + # Try several times to generate before actually throwing an error. + for _ in range(10): + try: + label = sentence_generator.generate() + break + except Exception: # pylint: disable=broad-except + pass + if label is None: + raise RuntimeError("Was not able to generate a valid string.") + images[n] = construct_image_from_string( + label, samples_by_character, min_overlap, max_overlap + ) + labels.append(label) + return images, labels + + +def convert_strings_to_categorical_labels( + labels: List[str], mapping: Dict +) -> np.ndarray: + """Translates a string of characters in to a target array of class int.""" + return np.array([[mapping[c] for c in label] for label in labels]) + + +def create_datasets( + max_length: int = 34, + min_overlap: float = 0, + max_overlap: float = 0.33, + num_train: int = 10000, + num_val: int = 1000, +) -> None: + """Creates a training an validation dataset of Emnist lines.""" + emnist_train = EmnistDataset(train=True, sample_to_balance=True) + emnist_val = EmnistDataset(train=False, sample_to_balance=True) + datasets = [emnist_train, emnist_val] + num_samples = [num_train, num_val] + for num, train, dataset in zip(num_samples, [True, False], datasets): + emnist_lines = EmnistLinesDataset( + train=train, + emnist=dataset, + max_length=max_length, + min_overlap=min_overlap, + max_overlap=max_overlap, + num_samples=num, + ) + emnist_lines._load_or_generate_data() diff --git a/src/text_recognizer/datasets/sentence_generator.py b/src/text_recognizer/datasets/sentence_generator.py new file mode 100644 index 0000000..ee86bd4 --- /dev/null +++ b/src/text_recognizer/datasets/sentence_generator.py @@ -0,0 +1,81 @@ +"""Downloading the Brown corpus with NLTK for sentence generating.""" + +import itertools +import re +import string +from typing import Optional + +import nltk +from nltk.corpus.reader.util import ConcatenatedCorpusView +import numpy as np + +from text_recognizer.datasets import DATA_DIRNAME + +NLTK_DATA_DIRNAME = DATA_DIRNAME / "raw" / "nltk" + + +class SentenceGenerator: + """Generates text sentences using the Brown corpus.""" + + def __init__(self, max_length: Optional[int] = None) -> None: + """Loads the corpus and sets word start indices.""" + self.corpus = brown_corpus() + self.word_start_indices = [0] + [ + _.start(0) + 1 for _ in re.finditer(" ", self.corpus) + ] + self.max_length = max_length + + def generate(self, max_length: Optional[int] = None) -> str: + """Generates a word or sentences from the Brown corpus. + + Sample a string from the Brown corpus of length at least one word and at most max_length, padding to + max_length with the '_' characters if sentence is shorter. + + Args: + max_length (Optional[int]): The maximum number of characters in the sentence. Defaults to None. + + Returns: + str: A sentence from the Brown corpus. + + Raises: + ValueError: If max_length was not specified at initialization and not given as an argument. + + """ + if max_length is None: + max_length = self.max_length + if max_length is None: + raise ValueError( + "Must provide max_length to this method or when making this object." + ) + + index = np.random.randint(0, len(self.word_start_indices) - 1) + start_index = self.word_start_indices[index] + end_index_candidates = [] + for index in range(index + 1, len(self.word_start_indices)): + if self.word_start_indices[index] - start_index > max_length: + break + end_index_candidates.append(self.word_start_indices[index]) + end_index = np.random.choice(end_index_candidates) + sampled_text = self.corpus[start_index:end_index].strip() + padding = "_" * (max_length - len(sampled_text)) + return sampled_text + padding + + +def brown_corpus() -> str: + """Returns a single string with the Brown corpus with all punctuations stripped.""" + sentences = load_nltk_brown_corpus() + corpus = " ".join(itertools.chain.from_iterable(sentences)) + corpus = corpus.translate({ord(c): None for c in string.punctuation}) + corpus = re.sub(" +", " ", corpus) + return corpus + + +def load_nltk_brown_corpus() -> ConcatenatedCorpusView: + """Load the Brown corpus using the NLTK library.""" + nltk.data.path.append(NLTK_DATA_DIRNAME) + try: + nltk.corpus.brown.sents() + except LookupError: + NLTK_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) + nltk.download("brown", download_dir=NLTK_DATA_DIRNAME) + return nltk.corpus.brown.sents() diff --git a/src/text_recognizer/datasets/util.py b/src/text_recognizer/datasets/util.py new file mode 100644 index 0000000..6668eef --- /dev/null +++ b/src/text_recognizer/datasets/util.py @@ -0,0 +1,11 @@ +"""Util functions for datasets.""" +import numpy as np +from PIL import Image + + +class Transpose: + """Transposes the EMNIST image to the correct orientation.""" + + def __call__(self, image: Image) -> np.ndarray: + """Swaps axis.""" + return np.array(image).swapaxes(0, 1) |