From 1b3b8073a19f939d18a0bb85247eb0d99284f7cc Mon Sep 17 00:00:00 2001 From: aktersnurra Date: Sun, 20 Sep 2020 11:47:24 +0200 Subject: Bash scripts and some bug fixes. --- src/text_recognizer/datasets/emnist_lines_dataset.py | 9 +++------ src/text_recognizer/datasets/iam_paragraphs_dataset.py | 7 ++++++- src/text_recognizer/datasets/util.py | 2 +- 3 files changed, 10 insertions(+), 8 deletions(-) (limited to 'src/text_recognizer/datasets') diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py index 6268a01..beb5343 100644 --- a/src/text_recognizer/datasets/emnist_lines_dataset.py +++ b/src/text_recognizer/datasets/emnist_lines_dataset.py @@ -149,6 +149,7 @@ class EmnistLinesDataset(Dataset): # Load emnist dataset. emnist = EmnistDataset(train=self.train, sample_to_balance=True) + emnist.load_or_generate_data() samples_by_character = get_samples_by_character( emnist.data.numpy(), emnist.targets.numpy(), self.mapper.mapping, @@ -306,17 +307,13 @@ def create_datasets( num_test: int = 1000, ) -> None: """Creates a training an validation dataset of Emnist lines.""" - emnist_train = EmnistDataset(train=True, sample_to_balance=True) - emnist_test = EmnistDataset(train=False, sample_to_balance=True) - datasets = [emnist_train, emnist_test] num_samples = [num_train, num_test] - for num, train, dataset in zip(num_samples, [True, False], datasets): + for num, train in zip(num_samples, [True, False]): emnist_lines = EmnistLinesDataset( train=train, - emnist=dataset, max_length=max_length, min_overlap=min_overlap, max_overlap=max_overlap, num_samples=num, ) - emnist_lines._load_or_generate_data() + emnist_lines.load_or_generate_data() diff --git a/src/text_recognizer/datasets/iam_paragraphs_dataset.py b/src/text_recognizer/datasets/iam_paragraphs_dataset.py index 4b34bd1..c1e8fe2 100644 --- a/src/text_recognizer/datasets/iam_paragraphs_dataset.py +++ b/src/text_recognizer/datasets/iam_paragraphs_dataset.py @@ -266,11 +266,16 @@ def _load_iam_paragraphs() -> None: @click.option( "--subsample_fraction", type=float, - default=0.0, + default=None, help="The subsampling factor of the dataset.", ) def main(subsample_fraction: float) -> None: """Load dataset and print info.""" + logger.info("Creating train set...") + dataset = IamParagraphsDataset(train=True, subsample_fraction=subsample_fraction) + dataset.load_or_generate_data() + print(dataset) + logger.info("Creating test set...") dataset = IamParagraphsDataset(subsample_fraction=subsample_fraction) dataset.load_or_generate_data() print(dataset) diff --git a/src/text_recognizer/datasets/util.py b/src/text_recognizer/datasets/util.py index 73968a1..125f05a 100644 --- a/src/text_recognizer/datasets/util.py +++ b/src/text_recognizer/datasets/util.py @@ -26,7 +26,7 @@ def save_emnist_essentials(emnsit_dataset: type = EMNIST) -> None: mapping = [(i, str(label)) for i, label in enumerate(labels)] essentials = { "mapping": mapping, - "input_shape": tuple(emnsit_dataset[0][0].shape[:]), + "input_shape": tuple(np.array(emnsit_dataset[0][0]).shape[:]), } logger.info("Saving emnist essentials...") with open(ESSENTIALS_FILENAME, "w") as f: -- cgit v1.2.3-70-g09d2