diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-03-23 21:55:42 +0100 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-03-23 21:55:42 +0100 |
commit | ae589fb3ffdbf6c4bb1ae35345f7a3665deeebc5 (patch) | |
tree | 1702f74c069679ebdd74a03892275c6eb3a80ffd /text_recognizer/datasets/sentence_generator.py | |
parent | e3741de333a3a43a7968241b6eccaaac66dd7b20 (diff) |
refactored emnist lines dataset
Diffstat (limited to 'text_recognizer/datasets/sentence_generator.py')
-rw-r--r-- | text_recognizer/datasets/sentence_generator.py | 30 |
1 files changed, 17 insertions, 13 deletions
diff --git a/text_recognizer/datasets/sentence_generator.py b/text_recognizer/datasets/sentence_generator.py index dd76652..53b781c 100644 --- a/text_recognizer/datasets/sentence_generator.py +++ b/text_recognizer/datasets/sentence_generator.py @@ -11,7 +11,7 @@ import numpy as np from text_recognizer.datasets.util import DATA_DIRNAME -NLTK_DATA_DIRNAME = DATA_DIRNAME / "raw" / "nltk" +NLTK_DATA_DIRNAME = DATA_DIRNAME / "downloaded" / "nltk" class SentenceGenerator: @@ -47,18 +47,22 @@ class SentenceGenerator: raise ValueError( "Must provide max_length to this method or when making this object." ) - - index = np.random.randint(0, len(self.word_start_indices) - 1) - start_index = self.word_start_indices[index] - end_index_candidates = [] - for index in range(index + 1, len(self.word_start_indices)): - if self.word_start_indices[index] - start_index > max_length: - break - end_index_candidates.append(self.word_start_indices[index]) - end_index = np.random.choice(end_index_candidates) - sampled_text = self.corpus[start_index:end_index].strip() - padding = "_" * (max_length - len(sampled_text)) - return sampled_text + padding + + for _ in range(10): + try: + index = np.random.randint(0, len(self.word_start_indices) - 1) + start_index = self.word_start_indices[index] + end_index_candidates = [] + for index in range(index + 1, len(self.word_start_indices)): + if self.word_start_indices[index] - start_index > max_length: + break + end_index_candidates.append(self.word_start_indices[index]) + end_index = np.random.choice(end_index_candidates) + sampled_text = self.corpus[start_index:end_index].strip() + return sampled_text + except Exception: + pass + raise RuntimeError("Was not able to generate a valid string") def brown_corpus() -> str: |