diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-03-24 22:15:54 +0100 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-03-24 22:15:54 +0100 |
commit | 8248f173132dfb7e47ec62b08e9235990c8626e3 (patch) | |
tree | 2f3ff85602cbc08b7168bf4f0d3924d32a689852 /text_recognizer/data/sentence_generator.py | |
parent | 74c907a17379688967dc4b3f41a44ba83034f5e0 (diff) |
renamed datasets to data, added iam refactor
Diffstat (limited to 'text_recognizer/data/sentence_generator.py')
-rw-r--r-- | text_recognizer/data/sentence_generator.py | 85 |
1 files changed, 85 insertions, 0 deletions
diff --git a/text_recognizer/data/sentence_generator.py b/text_recognizer/data/sentence_generator.py new file mode 100644 index 0000000..53b781c --- /dev/null +++ b/text_recognizer/data/sentence_generator.py @@ -0,0 +1,85 @@ +"""Downloading the Brown corpus with NLTK for sentence generating.""" + +import itertools +import re +import string +from typing import Optional + +import nltk +from nltk.corpus.reader.util import ConcatenatedCorpusView +import numpy as np + +from text_recognizer.datasets.util import DATA_DIRNAME + +NLTK_DATA_DIRNAME = DATA_DIRNAME / "downloaded" / "nltk" + + +class SentenceGenerator: + """Generates text sentences using the Brown corpus.""" + + def __init__(self, max_length: Optional[int] = None) -> None: + """Loads the corpus and sets word start indices.""" + self.corpus = brown_corpus() + self.word_start_indices = [0] + [ + _.start(0) + 1 for _ in re.finditer(" ", self.corpus) + ] + self.max_length = max_length + + def generate(self, max_length: Optional[int] = None) -> str: + """Generates a word or sentences from the Brown corpus. + + Sample a string from the Brown corpus of length at least one word and at most max_length, padding to + max_length with the '_' characters if sentence is shorter. + + Args: + max_length (Optional[int]): The maximum number of characters in the sentence. Defaults to None. + + Returns: + str: A sentence from the Brown corpus. + + Raises: + ValueError: If max_length was not specified at initialization and not given as an argument. + + """ + if max_length is None: + max_length = self.max_length + if max_length is None: + raise ValueError( + "Must provide max_length to this method or when making this object." + ) + + for _ in range(10): + try: + index = np.random.randint(0, len(self.word_start_indices) - 1) + start_index = self.word_start_indices[index] + end_index_candidates = [] + for index in range(index + 1, len(self.word_start_indices)): + if self.word_start_indices[index] - start_index > max_length: + break + end_index_candidates.append(self.word_start_indices[index]) + end_index = np.random.choice(end_index_candidates) + sampled_text = self.corpus[start_index:end_index].strip() + return sampled_text + except Exception: + pass + raise RuntimeError("Was not able to generate a valid string") + + +def brown_corpus() -> str: + """Returns a single string with the Brown corpus with all punctuations stripped.""" + sentences = load_nltk_brown_corpus() + corpus = " ".join(itertools.chain.from_iterable(sentences)) + corpus = corpus.translate({ord(c): None for c in string.punctuation}) + corpus = re.sub(" +", " ", corpus) + return corpus + + +def load_nltk_brown_corpus() -> ConcatenatedCorpusView: + """Load the Brown corpus using the NLTK library.""" + nltk.data.path.append(NLTK_DATA_DIRNAME) + try: + nltk.corpus.brown.sents() + except LookupError: + NLTK_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) + nltk.download("brown", download_dir=NLTK_DATA_DIRNAME) + return nltk.corpus.brown.sents() |