summaryrefslogtreecommitdiff
path: root/text_recognizer/datasets/sentence_generator.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-03-23 21:55:42 +0100
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-03-23 21:55:42 +0100
commitae589fb3ffdbf6c4bb1ae35345f7a3665deeebc5 (patch)
tree1702f74c069679ebdd74a03892275c6eb3a80ffd /text_recognizer/datasets/sentence_generator.py
parente3741de333a3a43a7968241b6eccaaac66dd7b20 (diff)
refactored emnist lines dataset
Diffstat (limited to 'text_recognizer/datasets/sentence_generator.py')
-rw-r--r--text_recognizer/datasets/sentence_generator.py30
1 files changed, 17 insertions, 13 deletions
diff --git a/text_recognizer/datasets/sentence_generator.py b/text_recognizer/datasets/sentence_generator.py
index dd76652..53b781c 100644
--- a/text_recognizer/datasets/sentence_generator.py
+++ b/text_recognizer/datasets/sentence_generator.py
@@ -11,7 +11,7 @@ import numpy as np
from text_recognizer.datasets.util import DATA_DIRNAME
-NLTK_DATA_DIRNAME = DATA_DIRNAME / "raw" / "nltk"
+NLTK_DATA_DIRNAME = DATA_DIRNAME / "downloaded" / "nltk"
class SentenceGenerator:
@@ -47,18 +47,22 @@ class SentenceGenerator:
raise ValueError(
"Must provide max_length to this method or when making this object."
)
-
- index = np.random.randint(0, len(self.word_start_indices) - 1)
- start_index = self.word_start_indices[index]
- end_index_candidates = []
- for index in range(index + 1, len(self.word_start_indices)):
- if self.word_start_indices[index] - start_index > max_length:
- break
- end_index_candidates.append(self.word_start_indices[index])
- end_index = np.random.choice(end_index_candidates)
- sampled_text = self.corpus[start_index:end_index].strip()
- padding = "_" * (max_length - len(sampled_text))
- return sampled_text + padding
+
+ for _ in range(10):
+ try:
+ index = np.random.randint(0, len(self.word_start_indices) - 1)
+ start_index = self.word_start_indices[index]
+ end_index_candidates = []
+ for index in range(index + 1, len(self.word_start_indices)):
+ if self.word_start_indices[index] - start_index > max_length:
+ break
+ end_index_candidates.append(self.word_start_indices[index])
+ end_index = np.random.choice(end_index_candidates)
+ sampled_text = self.corpus[start_index:end_index].strip()
+ return sampled_text
+ except Exception:
+ pass
+ raise RuntimeError("Was not able to generate a valid string")
def brown_corpus() -> str: