summaryrefslogtreecommitdiff
path: root/text_recognizer/datasets/sentence_generator.py
blob: dd76652d5438df50e66efdd20320aedd2ac6f256 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
"""Downloading the Brown corpus with NLTK for sentence generating."""

import itertools
import re
import string
from typing import Optional

import nltk
from nltk.corpus.reader.util import ConcatenatedCorpusView
import numpy as np

from text_recognizer.datasets.util import DATA_DIRNAME

# Directory that is added to NLTK's search path and used as the download
# target for the Brown corpus.
NLTK_DATA_DIRNAME = DATA_DIRNAME / "raw" / "nltk"


class SentenceGenerator:
    """Generates text sentences using the Brown corpus."""

    # Upper bound on how many random start words are tried before giving up;
    # a start word can be too long to fit into max_length on its own.
    _MAX_SAMPLING_ATTEMPTS = 50

    def __init__(self, max_length: Optional[int] = None) -> None:
        """Loads the corpus and sets word start indices.

        Args:
            max_length (Optional[int]): Default maximum number of characters used by
                generate() when no max_length is passed to it. Defaults to None.

        """
        self.corpus = brown_corpus()
        # A word starts at position 0 and directly after every space character.
        self.word_start_indices = [0] + [
            match.start(0) + 1 for match in re.finditer(" ", self.corpus)
        ]
        self.max_length = max_length

    def generate(self, max_length: Optional[int] = None) -> str:
        """Generates a word or sentences from the Brown corpus.

        Sample a string from the Brown corpus of length at least one word and at most max_length, padding to
        max_length with the '_' characters if sentence is shorter.

        If the randomly chosen start word does not fit into max_length, a new start word is drawn, up to
        _MAX_SAMPLING_ATTEMPTS times.

        Args:
            max_length (Optional[int]): The maximum number of characters in the sentence. Defaults to None.

        Returns:
            str: A sentence from the Brown corpus.

        Raises:
            ValueError: If max_length was not specified at initialization and not given as an argument, if the
                corpus has fewer than two word start indices, or if no text fitting into max_length could be
                sampled.

        """
        if max_length is None:
            max_length = self.max_length
        if max_length is None:
            raise ValueError(
                "Must provide max_length to this method or when making this object."
            )

        num_word_starts = len(self.word_start_indices)
        if num_word_starts < 2:
            # A start word and at least one later word start (the end) are required.
            raise ValueError(
                "The corpus must contain at least two word start indices to sample from."
            )

        for _ in range(self._MAX_SAMPLING_ATTEMPTS):
            # The last word start is excluded: it has no later word start to end at.
            first_word = np.random.randint(0, num_word_starts - 1)
            start_index = self.word_start_indices[first_word]

            # Collect every later word start that is still within max_length characters.
            end_index_candidates = []
            for position in range(first_word + 1, num_word_starts):
                word_start = self.word_start_indices[position]
                if word_start - start_index > max_length:
                    break
                end_index_candidates.append(word_start)

            if end_index_candidates:
                break
        else:
            raise ValueError(
                f"Could not sample text of at most {max_length} characters from the corpus."
            )

        end_index = np.random.choice(end_index_candidates)
        sampled_text = self.corpus[start_index:end_index].strip()
        padding = "_" * (max_length - len(sampled_text))
        return sampled_text + padding


def brown_corpus() -> str:
    """Returns the Brown corpus as one string of space-separated words without punctuation.

    The tokens of all sentences are joined with single spaces, every character in
    string.punctuation is removed, and runs of spaces are collapsed into one space.

    Returns:
        str: The flattened, punctuation-free corpus text.

    """
    words = itertools.chain.from_iterable(load_nltk_brown_corpus())
    remove_punctuation = str.maketrans("", "", string.punctuation)
    text = " ".join(words).translate(remove_punctuation)
    # Removing punctuation-only tokens leaves consecutive spaces behind.
    return re.sub(" +", " ", text)


def load_nltk_brown_corpus() -> ConcatenatedCorpusView:
    """Load the Brown corpus using the NLTK library.

    Registers NLTK_DATA_DIRNAME in NLTK's data search path and downloads the
    Brown corpus into that directory if NLTK cannot find it.

    Returns:
        ConcatenatedCorpusView: The sentences of the Brown corpus, as returned by
            nltk.corpus.brown.sents().

    """
    # nltk.data.path is a process-wide list of directory strings; register the
    # directory only once so repeated calls do not keep growing the search path.
    data_dirname = str(NLTK_DATA_DIRNAME)
    if data_dirname not in nltk.data.path:
        nltk.data.path.append(data_dirname)
    try:
        # Probe for the corpus; a LookupError means it is not available locally.
        nltk.corpus.brown.sents()
    except LookupError:
        NLTK_DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
        nltk.download("brown", download_dir=data_dirname)
    return nltk.corpus.brown.sents()