diff options
Diffstat (limited to 'text_recognizer/datasets/sentence_generator.py')
-rw-r--r-- | text_recognizer/datasets/sentence_generator.py | 30 |
1 files changed, 17 insertions, 13 deletions
diff --git a/text_recognizer/datasets/sentence_generator.py b/text_recognizer/datasets/sentence_generator.py index dd76652..53b781c 100644 --- a/text_recognizer/datasets/sentence_generator.py +++ b/text_recognizer/datasets/sentence_generator.py @@ -11,7 +11,7 @@ import numpy as np from text_recognizer.datasets.util import DATA_DIRNAME -NLTK_DATA_DIRNAME = DATA_DIRNAME / "raw" / "nltk" +NLTK_DATA_DIRNAME = DATA_DIRNAME / "downloaded" / "nltk" class SentenceGenerator: @@ -47,18 +47,22 @@ class SentenceGenerator: raise ValueError( "Must provide max_length to this method or when making this object." ) - - index = np.random.randint(0, len(self.word_start_indices) - 1) - start_index = self.word_start_indices[index] - end_index_candidates = [] - for index in range(index + 1, len(self.word_start_indices)): - if self.word_start_indices[index] - start_index > max_length: - break - end_index_candidates.append(self.word_start_indices[index]) - end_index = np.random.choice(end_index_candidates) - sampled_text = self.corpus[start_index:end_index].strip() - padding = "_" * (max_length - len(sampled_text)) - return sampled_text + padding + + for _ in range(10): + try: + index = np.random.randint(0, len(self.word_start_indices) - 1) + start_index = self.word_start_indices[index] + end_index_candidates = [] + for index in range(index + 1, len(self.word_start_indices)): + if self.word_start_indices[index] - start_index > max_length: + break + end_index_candidates.append(self.word_start_indices[index]) + end_index = np.random.choice(end_index_candidates) + sampled_text = self.corpus[start_index:end_index].strip() + return sampled_text + except Exception: + pass + raise RuntimeError("Was not able to generate a valid string") def brown_corpus() -> str: |