Diffstat (limited to 'src/text_recognizer/datasets/emnist_lines_dataset.py')
-rw-r--r--  src/text_recognizer/datasets/emnist_lines_dataset.py  | 51
1 file changed, 37 insertions, 14 deletions
diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py
index 6268a01..6871492 100644
--- a/src/text_recognizer/datasets/emnist_lines_dataset.py
+++ b/src/text_recognizer/datasets/emnist_lines_dataset.py
@@ -4,6 +4,7 @@ from collections import defaultdict
 from pathlib import Path
 from typing import Callable, Dict, List, Optional, Tuple, Union
 
+import click
 import h5py
 from loguru import logger
 import numpy as np
@@ -37,6 +38,9 @@ class EmnistLinesDataset(Dataset):
         max_overlap: float = 0.33,
         num_samples: int = 10000,
         seed: int = 4711,
+        init_token: Optional[str] = None,
+        pad_token: Optional[str] = None,
+        eos_token: Optional[str] = None,
     ) -> None:
         """Set attributes and loads the dataset.
 
@@ -50,13 +54,21 @@ class EmnistLinesDataset(Dataset):
             max_overlap (float): The maximum overlap between concatenated images. Defaults to 0.33.
             num_samples (int): Number of samples to generate. Defaults to 10000.
             seed (int): Seed number. Defaults to 4711.
+            init_token (Optional[str]): String representing the start of sequence token. Defaults to None.
+            pad_token (Optional[str]): String representing the pad token. Defaults to None.
+            eos_token (Optional[str]): String representing the end of sequence token. Defaults to None.
 
         """
+        self.pad_token = "_" if pad_token is None else pad_token
+
         super().__init__(
             train=train,
             transform=transform,
             target_transform=target_transform,
             subsample_fraction=subsample_fraction,
+            init_token=init_token,
+            pad_token=self.pad_token,
+            eos_token=eos_token,
         )
 
         # Extract dataset information.
@@ -118,11 +130,7 @@ class EmnistLinesDataset(Dataset):
     @property
     def data_filename(self) -> Path:
         """Path to the h5 file."""
-        filename = f"ml_{self.max_length}_o{self.min_overlap}_{self.max_overlap}_n{self.num_samples}.pt"
-        if self.train:
-            filename = "train_" + filename
-        else:
-            filename = "test_" + filename
+        filename = "train.pt" if self.train else "test.pt"
         return DATA_DIRNAME / filename
 
     def load_or_generate_data(self) -> None:
@@ -138,8 +146,8 @@ class EmnistLinesDataset(Dataset):
         """Loads the dataset from the h5 file."""
         logger.debug("EmnistLinesDataset loading data from HDF5...")
         with h5py.File(self.data_filename, "r") as f:
-            self._data = f["data"][:]
-            self._targets = f["targets"][:]
+            self._data = f["data"][()]
+            self._targets = f["targets"][()]
 
     def _generate_data(self) -> str:
         """Generates a dataset with the Brown corpus and Emnist characters."""
@@ -148,7 +156,10 @@ class EmnistLinesDataset(Dataset):
         sentence_generator = SentenceGenerator(self.max_length)
 
         # Load emnist dataset.
-        emnist = EmnistDataset(train=self.train, sample_to_balance=True)
+        emnist = EmnistDataset(
+            train=self.train, sample_to_balance=True, pad_token=self.pad_token
+        )
+        emnist.load_or_generate_data()
 
         samples_by_character = get_samples_by_character(
             emnist.data.numpy(), emnist.targets.numpy(), self.mapper.mapping,
@@ -298,6 +309,18 @@ def convert_strings_to_categorical_labels(
     return np.array([[mapping[c] for c in label] for label in labels])
 
 
+@click.command()
+@click.option(
+    "--max_length", type=int, default=34, help="Number of characters in a sentence."
+)
+@click.option(
+    "--min_overlap", type=float, default=0.0, help="Min overlap between characters."
+)
+@click.option(
+    "--max_overlap", type=float, default=0.33, help="Max overlap between characters."
+)
+@click.option("--num_train", type=int, default=10_000, help="Number of train examples.")
+@click.option("--num_test", type=int, default=1_000, help="Number of test examples.")
 def create_datasets(
     max_length: int = 34,
     min_overlap: float = 0,
@@ -306,17 +329,17 @@ def create_datasets(
     num_test: int = 1000,
 ) -> None:
     """Creates a training an validation dataset of Emnist lines."""
-    emnist_train = EmnistDataset(train=True, sample_to_balance=True)
-    emnist_test = EmnistDataset(train=False, sample_to_balance=True)
-    datasets = [emnist_train, emnist_test]
     num_samples = [num_train, num_test]
-    for num, train, dataset in zip(num_samples, [True, False], datasets):
+    for num, train in zip(num_samples, [True, False]):
         emnist_lines = EmnistLinesDataset(
             train=train,
-            emnist=dataset,
             max_length=max_length,
             min_overlap=min_overlap,
             max_overlap=max_overlap,
             num_samples=num,
         )
-        emnist_lines._load_or_generate_data()
+        emnist_lines.load_or_generate_data()
+
+
+if __name__ == "__main__":
+    create_datasets()
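A minimal usage sketch of the constructor signature and load_or_generate_data() call shown in this diff, assuming the package is importable as text_recognizer (src/ on PYTHONPATH); the "<sos>"/"<eos>" token strings are placeholders, not values from the commit:

    # Sketch only: generate the training split with the new token arguments.
    from text_recognizer.datasets.emnist_lines_dataset import EmnistLinesDataset

    lines = EmnistLinesDataset(
        train=True,
        max_length=34,
        min_overlap=0.0,
        max_overlap=0.33,
        num_samples=10_000,
        init_token="<sos>",   # placeholder token string
        pad_token="_",        # matches the default set in __init__
        eos_token="<eos>",    # placeholder token string
    )
    lines.load_or_generate_data()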