From ae589fb3ffdbf6c4bb1ae35345f7a3665deeebc5 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Tue, 23 Mar 2021 21:55:42 +0100 Subject: refactored emnist lines dataset --- text_recognizer/datasets/sentence_generator.py | 30 +++++++++++++++----------- 1 file changed, 17 insertions(+), 13 deletions(-) (limited to 'text_recognizer/datasets/sentence_generator.py') diff --git a/text_recognizer/datasets/sentence_generator.py b/text_recognizer/datasets/sentence_generator.py index dd76652..53b781c 100644 --- a/text_recognizer/datasets/sentence_generator.py +++ b/text_recognizer/datasets/sentence_generator.py @@ -11,7 +11,7 @@ import numpy as np from text_recognizer.datasets.util import DATA_DIRNAME -NLTK_DATA_DIRNAME = DATA_DIRNAME / "raw" / "nltk" +NLTK_DATA_DIRNAME = DATA_DIRNAME / "downloaded" / "nltk" class SentenceGenerator: @@ -47,18 +47,22 @@ class SentenceGenerator: raise ValueError( "Must provide max_length to this method or when making this object." ) - - index = np.random.randint(0, len(self.word_start_indices) - 1) - start_index = self.word_start_indices[index] - end_index_candidates = [] - for index in range(index + 1, len(self.word_start_indices)): - if self.word_start_indices[index] - start_index > max_length: - break - end_index_candidates.append(self.word_start_indices[index]) - end_index = np.random.choice(end_index_candidates) - sampled_text = self.corpus[start_index:end_index].strip() - padding = "_" * (max_length - len(sampled_text)) - return sampled_text + padding + + for _ in range(10): + try: + index = np.random.randint(0, len(self.word_start_indices) - 1) + start_index = self.word_start_indices[index] + end_index_candidates = [] + for index in range(index + 1, len(self.word_start_indices)): + if self.word_start_indices[index] - start_index > max_length: + break + end_index_candidates.append(self.word_start_indices[index]) + end_index = np.random.choice(end_index_candidates) + sampled_text = self.corpus[start_index:end_index].strip() + return sampled_text + except Exception: + pass + raise RuntimeError("Was not able to generate a valid string") def brown_corpus() -> str: -- cgit v1.2.3-70-g09d2