summaryrefslogtreecommitdiff
path: root/text_recognizer/data/utils
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-10-10 18:04:50 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-10-10 18:04:50 +0200
commit8291a87c64f9a5f18caec82201bea15579b49730 (patch)
tree1c8bb3e07a3bd06086e182dd320f8408829ba81c /text_recognizer/data/utils
parent30e3ae483c846418b04ed48f014a4af2cf9a0771 (diff)
Move data utils to submodules
Diffstat (limited to 'text_recognizer/data/utils')
-rw-r--r--text_recognizer/data/utils/build_transitions.py261
-rw-r--r--text_recognizer/data/utils/download_utils.py73
-rw-r--r--text_recognizer/data/utils/iam_preprocessor.py209
-rw-r--r--text_recognizer/data/utils/image_utils.py49
-rw-r--r--text_recognizer/data/utils/make_wordpieces.py112
-rw-r--r--text_recognizer/data/utils/sentence_generator.py89
6 files changed, 793 insertions, 0 deletions
diff --git a/text_recognizer/data/utils/build_transitions.py b/text_recognizer/data/utils/build_transitions.py
new file mode 100644
index 0000000..0f987ca
--- /dev/null
+++ b/text_recognizer/data/utils/build_transitions.py
@@ -0,0 +1,261 @@
+"""Builds transition graph.
+
+Most code stolen from here:
+ https://github.com/facebookresearch/gtn_applications/blob/master/scripts/build_transitions.py
+"""
+
+import collections
+import itertools
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+import click
+import gtn
+from loguru import logger
+
+
+START_IDX = -1
+END_IDX = -2
+WORDSEP = "▁"
+
+
+def build_graph(ngrams: List, disable_backoff: bool = False) -> gtn.Graph:
+ """Returns a gtn Graph based on the ngrams."""
+ graph = gtn.Graph(False)
+ ngram = len(ngrams)
+ state_to_node = {}
+
+ def get_node(state: Optional[List]) -> Any:
+ node = state_to_node.get(state, None)
+
+ if node is not None:
+ return node
+
+ start = state == tuple([START_IDX]) if ngram > 1 else True
+ end = state == tuple([END_IDX]) if ngram > 1 else True
+ node = graph.add_node(start, end)
+ state_to_node[state] = node
+
+ if not disable_backoff and not end:
+ # Add back off when adding node.
+ for n in range(1, len(state) + 1):
+ backoff_node = state_to_node.get(state[n:], None)
+
+ # Epsilon transition to the back-off state.
+ if backoff_node is not None:
+ graph.add_arc(node, backoff_node, gtn.epsilon)
+ break
+ return node
+
+ for grams in ngrams:
+ for gram in grams:
+ istate, ostate = gram[:-1], gram[len(gram) - ngram + 1 :]
+ inode = get_node(istate)
+
+ if END_IDX not in gram[1:] and gram[1:] not in state_to_node:
+ raise ValueError(
+ "Ill formed counts: if (x, y_1, ..., y_{n-1}) is above"
+ "the n-gram threshold, then (y_1, ..., y_{n-1}) must be"
+ "above the (n-1)-gram threshold"
+ )
+
+ if END_IDX in ostate:
+ # Merge all state having </s> into one as final graph generated
+ # will be similar.
+ ostate = tuple([END_IDX])
+
+ onode = get_node(ostate)
+ # p(gram[-1] | gram[:-1])
+ graph.add_arc(
+ inode, onode, gtn.epsilon if gram[-1] == END_IDX else gram[-1]
+ )
+ return graph
+
+
+def count_ngrams(lines: List, ngram: List, tokens_to_index: Dict) -> List:
+ """Counts the number of ngrams."""
+ counts = [collections.Counter() for _ in range(ngram)]
+ for line in lines:
+ # Prepend implicit start token.
+ token_line = [START_IDX]
+ for t in line:
+ token_line.append(tokens_to_index[t])
+ token_line.append(END_IDX)
+ for n, counter in enumerate(counts):
+ start_offset = n == 0
+ end_offset = ngram == 1
+ for e in range(n + start_offset, len(token_line) - end_offset):
+ counter[tuple(token_line[e - n : e + 1])] += 1
+
+ return counts
+
+
+def prune_ngrams(ngrams: List, prune: List) -> List:
+ """Prunes ngrams."""
+ pruned_ngrams = []
+ for n, grams in enumerate(ngrams):
+ grams = grams.most_common()
+ pruned_grams = [gram for gram, c in grams if c > prune[n]]
+ pruned_ngrams.append(pruned_grams)
+ return pruned_ngrams
+
+
+def add_blank_grams(pruned_ngrams: List, num_tokens: int, blank: str) -> List:
+ """Adds blank token to grams."""
+ all_grams = [gram for grams in pruned_ngrams for gram in grams]
+ maxorder = len(pruned_ngrams)
+ blank_grams = {}
+ if blank == "forced":
+ pruned_ngrams = [pruned_ngrams[0] if i == 0 else [] for i in range(maxorder)]
+ pruned_ngrams[0].append(tuple([num_tokens]))
+ blank_grams[tuple([num_tokens])] = True
+
+ for gram in all_grams:
+ # Iterate over all possibilities by using a vector of 0s, 1s to
+ # denote whether a blank is being used at each position.
+ if blank == "optional":
+ # Given a gram ab.. if order n, we have n + 1 positions
+ # available whether to use blank or not.
+ onehot_vectors = itertools.product([0, 1], repeat=len(gram) + 1)
+ elif blank == "forced":
+ # Must include a blank token in between.
+ onehot_vectors = [[1] * (len(gram) + 1)]
+ else:
+ raise ValueError(
+ "Invalid value specificed for blank. Must be in |optional|forced|none|"
+ )
+
+ for j in onehot_vectors:
+ new_array = []
+ for idx, oz in enumerate(j[:-1]):
+ if oz == 1 and gram[idx] != START_IDX:
+ new_array.append(num_tokens)
+ new_array.append(gram[idx])
+ if j[-1] == 1 and gram[-1] != END_IDX:
+ new_array.append(num_tokens)
+ for n in range(maxorder):
+ for e in range(n, len(new_array)):
+ cur_gram = tuple(new_array[e - n : e + 1])
+ if num_tokens in cur_gram and cur_gram not in blank_grams:
+ pruned_ngrams[n].append(cur_gram)
+ blank_grams[cur_gram] = True
+
+ return pruned_ngrams
+
+
+def add_self_loops(pruned_ngrams: List) -> List:
+ """Adds self loops to the ngrams."""
+ maxorder = len(pruned_ngrams)
+
+ # Use dict for fast search.
+ all_grams = set([gram for grams in pruned_ngrams for gram in grams])
+ for o in range(1, maxorder):
+ for gram in pruned_ngrams[o - 1]:
+ # Repeat one of the tokens.
+ for pos in range(len(gram)):
+ if gram[pos] == START_IDX or gram[pos] == END_IDX:
+ continue
+ new_gram = gram[:pos] + (gram[pos],) + gram[pos:]
+
+ if new_gram not in all_grams:
+ pruned_ngrams[o].append(new_gram)
+ all_grams.add(new_gram)
+ return pruned_ngrams
+
+
+def parse_lines(lines: List, lexicon: Path) -> List:
+ """Parses lines with a lexicon."""
+ with open(lexicon, "r") as f:
+ lex = (line.strip().split() for line in f)
+ lex = {line[0]: line[1:] for line in lex}
+ print(len(lex))
+ return [[t for w in line.split(WORDSEP) for t in lex[w]] for line in lines]
+
+
+@click.command()
+@click.option("--data_dir", type=str, default=None, help="Path to dataset root.")
+@click.option(
+ "--tokens", type=str, help="Path to token list (in order used with training)."
+)
+@click.option("--lexicon", type=str, default=None, help="Path to lexicon")
+@click.option(
+ "--prune",
+ nargs=2,
+ type=int,
+ help="Threshold values for prune unigrams, bigrams, etc.",
+)
+@click.option(
+ "--blank",
+ default=click.Choice(["none", "optional", "forced"]),
+ help="Specifies the usage of blank token"
+ "'none' - do not use blank token "
+ "'optional' - allow an optional blank inbetween tokens"
+ "'forced' - force a blank inbetween tokens (also referred to as garbage token)",
+)
+@click.option("--self_loops", is_flag=True, help="Add self loops for tokens")
+@click.option("--disable_backoff", is_flag=True, help="Disable backoff transitions")
+@click.option("--save_path", default=None, help="Path to save transition graph.")
+def cli(
+ data_dir: str,
+ tokens: str,
+ lexicon: str,
+ prune: List[int],
+ blank: str,
+ self_loops: bool,
+ disable_backoff: bool,
+ save_path: str,
+) -> None:
+ """CLI for creating the transitions."""
+ logger.info(f"Building {len(prune)}-gram transition models.")
+
+ if data_dir is None:
+ data_dir = (
+ Path(__file__).resolve().parents[2] / "data" / "processed" / "iam_lines"
+ )
+ logger.debug(f"Using data dir: {data_dir}")
+ if not data_dir.exists():
+ raise RuntimeError(f"Could not locate iamdb directory at {data_dir}")
+ else:
+ data_dir = Path(data_dir)
+
+ # Build table of counts and the back-off if below threshold.
+ with open(data_dir / "train.txt", "r") as f:
+ lines = [line.strip() for line in f]
+
+ with open(data_dir / tokens, "r") as f:
+ tokens = [line.strip() for line in f]
+
+ if lexicon is not None:
+ lexicon = data_dir / lexicon
+ lines = parse_lines(lines, lexicon)
+
+ tokens_to_idx = {t: e for e, t in enumerate(tokens)}
+
+ ngram = len(prune)
+
+ logger.info("Counting data...")
+ ngrams = count_ngrams(lines, ngram, tokens_to_idx)
+
+ pruned_ngrams = prune_ngrams(ngrams, prune)
+
+ for n in range(ngram):
+ logger.info(f"Kept {len(pruned_ngrams[n])} of {len(ngrams[n])} {n + 1}-grams")
+
+ if blank == "none":
+ pruned_ngrams = add_blank_grams(pruned_ngrams, len(tokens_to_idx), blank)
+
+ if self_loops:
+ pruned_ngrams = add_self_loops(pruned_ngrams)
+
+ logger.info("Building graph from pruned ngrams...")
+ graph = build_graph(pruned_ngrams, disable_backoff)
+ logger.info(f"Graph has {graph.num_arcs()} arcs and {graph.num_nodes()} nodes.")
+
+ save_path = str(data_dir / save_path)
+
+ logger.info(f"Saving graph to {save_path}")
+ gtn.save(save_path, graph)
+
+
+if __name__ == "__main__":
+ cli()
diff --git a/text_recognizer/data/utils/download_utils.py b/text_recognizer/data/utils/download_utils.py
new file mode 100644
index 0000000..a5a5360
--- /dev/null
+++ b/text_recognizer/data/utils/download_utils.py
@@ -0,0 +1,73 @@
+"""Util functions for downloading datasets."""
+import hashlib
+from pathlib import Path
+from typing import Dict, Optional
+from urllib.request import urlretrieve
+
+from loguru import logger as log
+from tqdm import tqdm
+
+
+def _compute_sha256(filename: Path) -> str:
+ """Returns the SHA256 checksum of a file."""
+ with filename.open(mode="rb") as f:
+ return hashlib.sha256(f.read()).hexdigest()
+
+
+class TqdmUpTo(tqdm):
+ """TQDM progress bar when downloading files.
+
+ From https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py
+
+ """
+
+ def update_to(
+ self, blocks: int = 1, block_size: int = 1, total_size: Optional[int] = None
+ ) -> None:
+ """Updates the progress bar.
+
+ Args:
+ blocks (int): Number of blocks transferred so far. Defaults to 1.
+ block_size (int): Size of each block, in tqdm units. Defaults to 1.
+ total_size (Optional[int]): Total size in tqdm units. Defaults to None.
+ """
+ if total_size is not None:
+ self.total = total_size
+ self.update(blocks * block_size - self.n)
+
+
+def _download_url(url: str, filename: str) -> None:
+ """Downloads a file from url to filename, with a progress bar."""
+ with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t:
+ urlretrieve(url, filename, reporthook=t.update_to, data=None) # nosec
+
+
+def download_dataset(metadata: Dict, dl_dir: Path) -> Optional[Path]:
+ """Downloads dataset using a metadata file.
+
+ Args:
+ metadata (Dict): A metadata file of the dataset.
+ dl_dir (Path): Download directory for the dataset.
+
+ Returns:
+ Optional[Path]: Returns filename if dataset is downloaded, None if it already
+ exists.
+
+ Raises:
+ ValueError: If the SHA-256 value is not the same between the dataset and
+ the metadata file.
+
+ """
+ dl_dir.mkdir(parents=True, exist_ok=True)
+ filename = dl_dir / metadata["filename"]
+ if filename.exists():
+ return
+ log.info(f"Downloading raw dataset from {metadata['url']} to {filename}...")
+ _download_url(metadata["url"], filename)
+ log.info("Computing the SHA-256...")
+ sha256 = _compute_sha256(filename)
+ if sha256 != metadata["sha256"]:
+ raise ValueError(
+ "Downloaded data file SHA-256 does not match that listed in metadata document."
+ )
+ return filename
diff --git a/text_recognizer/data/utils/iam_preprocessor.py b/text_recognizer/data/utils/iam_preprocessor.py
new file mode 100644
index 0000000..60ecff1
--- /dev/null
+++ b/text_recognizer/data/utils/iam_preprocessor.py
@@ -0,0 +1,209 @@
+"""Preprocessor for extracting word letters from the IAM dataset.
+
+The code is mostly stolen from:
+ https://github.com/facebookresearch/gtn_applications/blob/master/datasets/iamdb.py
+"""
+import collections
+import itertools
+from pathlib import Path
+import re
+from typing import List, Optional, Set, Union
+
+import click
+from loguru import logger as log
+import torch
+
+
+def load_metadata(
+ data_dir: Path, wordsep: str, use_words: bool = False
+) -> collections.defaultdict:
+ """Loads IAM metadata and returns it as a dictionary."""
+ forms = collections.defaultdict(list)
+ filename = "words.txt" if use_words else "lines.txt"
+
+ with open(data_dir / "ascii" / filename, "r") as f:
+ lines = (line.strip().split() for line in f if line[0] != "#")
+ for line in lines:
+ # Skip word segmentation errors.
+ if use_words and line[1] == "err":
+ continue
+ text = " ".join(line[8:])
+
+ # Remove garbage tokens:
+ text = text.replace("#", "")
+
+ # Swap word sep form | to wordsep
+ text = re.sub(r"\|+|\s", wordsep, text).strip(wordsep)
+ form_key = "-".join(line[0].split("-")[:2])
+ line_key = "-".join(line[0].split("-")[:3])
+ box_idx = 4 - use_words
+ box = tuple(int(val) for val in line[box_idx : box_idx + 4])
+ forms[form_key].append({"key": line_key, "box": box, "text": text})
+ return forms
+
+
+class Preprocessor:
+ """A preprocessor for the IAM dataset."""
+
+ def __init__(
+ self,
+ data_dir: Union[str, Path],
+ num_features: int,
+ tokens_path: Optional[Union[str, Path]] = None,
+ lexicon_path: Optional[Union[str, Path]] = None,
+ use_words: bool = False,
+ prepend_wordsep: bool = False,
+ special_tokens: Optional[Set[str]] = None,
+ ) -> None:
+ self.wordsep = "▁"
+ self._use_word = use_words
+ self._prepend_wordsep = prepend_wordsep
+ self.special_tokens = special_tokens if special_tokens is not None else None
+ self.data_dir = Path(data_dir)
+ self.forms = load_metadata(self.data_dir, self.wordsep, use_words=use_words)
+
+ # Load the set of graphemes:
+ graphemes = set()
+ for _, form in self.forms.items():
+ for line in form:
+ graphemes.update(line["text"].lower())
+ self.graphemes = sorted(graphemes)
+
+ # Build the token-to-index and index-to-token maps.
+ if tokens_path is not None:
+ with open(tokens_path, "r") as f:
+ self.tokens = [line.strip() for line in f]
+ else:
+ self.tokens = self.graphemes
+
+ if lexicon_path is not None:
+ with open(lexicon_path, "r") as f:
+ lexicon = (line.strip().split() for line in f)
+ lexicon = {line[0]: line[1:] for line in lexicon}
+ self.lexicon = lexicon
+ else:
+ self.lexicon = None
+
+ if self.special_tokens is not None:
+ special_tokens_ = (*self.special_tokens, "#", "*")
+ self.tokens += special_tokens_
+ self.graphemes += special_tokens_
+
+ self.graphemes_to_index = {t: i for i, t in enumerate(self.graphemes)}
+ self.tokens_to_index = {t: i for i, t in enumerate(self.tokens)}
+ self.num_features = num_features
+ self.text = []
+
+ @property
+ def num_tokens(self) -> int:
+ """Returns the number or tokens."""
+ return len(self.tokens)
+
+ @property
+ def use_words(self) -> bool:
+ """If words are used."""
+ return self._use_word
+
+ def extract_train_text(self) -> None:
+ """Extracts training text."""
+ keys = []
+ with open(self.data_dir / "task" / "trainset.txt") as f:
+ keys.extend((line.strip() for line in f))
+
+ for _, examples in self.forms.items():
+ for example in examples:
+ if example["key"] not in keys:
+ continue
+ self.text.append(example["text"].lower())
+
+ def _to_index(self, line: str) -> torch.LongTensor:
+ if self.special_tokens is not None and line in self.special_tokens:
+ return torch.LongTensor([self.tokens_to_index[line]])
+ token_to_index = self.graphemes_to_index
+ if self.lexicon is not None:
+ if len(line) > 0:
+ # If the word is not found in the lexicon, fall back to letters.
+ tokens = [
+ t
+ for w in line.split(self.wordsep)
+ for t in self.lexicon.get(w, self.wordsep + w)
+ ]
+ token_to_index = self.tokens_to_index
+ if self._prepend_wordsep:
+ tokens = itertools.chain([self.wordsep], tokens)
+ return torch.LongTensor([token_to_index[t] for t in tokens])
+
+ def to_index(self, line: str) -> torch.LongTensor:
+ """Converts text to a tensor of indices."""
+ if self.special_tokens is not None:
+ pattern = f"({'|'.join(self.special_tokens)})"
+ lines = list(filter(None, re.split(pattern, line)))
+ return torch.cat([self._to_index(line) for line in lines])
+ return self._to_index(line)
+
+ def to_text(self, indices: List[int]) -> str:
+ """Converts indices to text."""
+ # Roughly the inverse of `to_index`
+ encoding = self.graphemes
+ if self.lexicon is not None:
+ encoding = self.tokens
+ return self._post_process(encoding[i] for i in indices)
+
+ def tokens_to_text(self, indices: List[int]) -> str:
+ """Converts tokens to text."""
+ return self._post_process(self.tokens[i] for i in indices)
+
+ def _post_process(self, indices: List[int]) -> str:
+ """A list join."""
+ return "".join(indices).strip(self.wordsep)
+
+
+@click.command()
+@click.option("--data_dir", type=str, default=None, help="Path to iam dataset")
+@click.option(
+ "--use_words", is_flag=True, help="Load word segmented dataset instead of lines"
+)
+@click.option(
+ "--save_text", type=str, default=None, help="Path to save parsed train text"
+)
+@click.option("--save_tokens", type=str, default=None, help="Path to save tokens")
+def cli(
+ data_dir: Optional[str],
+ use_words: bool,
+ save_text: Optional[str],
+ save_tokens: Optional[str],
+) -> None:
+ """CLI for extracting text data from the iam dataset."""
+ if data_dir is None:
+ data_dir = (
+ Path(__file__).resolve().parents[2]
+ / "data"
+ / "downloaded"
+ / "iam"
+ / "iamdb"
+ )
+ log.debug(f"Using data dir: {data_dir}")
+ if not data_dir.exists():
+ raise RuntimeError(f"Could not locate iamdb directory at {data_dir}")
+ else:
+ data_dir = Path(data_dir)
+
+ preprocessor = Preprocessor(data_dir, 64, use_words=use_words)
+ preprocessor.extract_train_text()
+
+ processed_dir = data_dir.parents[2] / "processed" / "iam_lines"
+ log.debug(f"Saving processed files at: {processed_dir}")
+
+ if save_text is not None:
+ log.info("Saving training text")
+ with open(processed_dir / save_text, "w") as f:
+ f.write("\n".join(t for t in preprocessor.text))
+
+ if save_tokens is not None:
+ log.info("Saving tokens")
+ with open(processed_dir / save_tokens, "w") as f:
+ f.write("\n".join(preprocessor.tokens))
+
+
+if __name__ == "__main__":
+ cli()
diff --git a/text_recognizer/data/utils/image_utils.py b/text_recognizer/data/utils/image_utils.py
new file mode 100644
index 0000000..c2b8915
--- /dev/null
+++ b/text_recognizer/data/utils/image_utils.py
@@ -0,0 +1,49 @@
+"""Image util functions for loading and saving images."""
+from pathlib import Path
+from typing import Union
+from urllib.request import urlopen
+
+import cv2
+import numpy as np
+from PIL import Image
+
+
+def read_image_pil(image_uri: Union[Path, str], grayscale: bool = False) -> Image:
+ """Return PIL image."""
+ image = Image.open(image_uri)
+ if grayscale:
+ image = image.convert("L")
+ return image
+
+
+def read_image(image_uri: Union[Path, str], grayscale: bool = False) -> np.array:
+ """Read image_uri."""
+
+ if isinstance(image_uri, str):
+ image_uri = Path(image_uri)
+
+ def read_image_from_filename(image_filename: Path, imread_flag: int) -> np.array:
+ return cv2.imread(str(image_filename), imread_flag)
+
+ def read_image_from_url(image_url: Path, imread_flag: int) -> np.array:
+ url_response = urlopen(str(image_url)) # nosec
+ image_array = np.array(bytearray(url_response.read()), dtype=np.uint8)
+ return cv2.imdecode(image_array, imread_flag)
+
+ imread_flag = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR
+ image = None
+
+ if image_uri.exists():
+ image = read_image_from_filename(image_uri, imread_flag)
+ else:
+ image = read_image_from_url(image_uri, imread_flag)
+
+ if image is None:
+ raise ValueError(f"Could not load image at {image_uri}")
+
+ return image
+
+
+def write_image(image: np.ndarray, filename: Union[Path, str]) -> None:
+ """Write image to file."""
+ cv2.imwrite(str(filename), image)
diff --git a/text_recognizer/data/utils/make_wordpieces.py b/text_recognizer/data/utils/make_wordpieces.py
new file mode 100644
index 0000000..8e53815
--- /dev/null
+++ b/text_recognizer/data/utils/make_wordpieces.py
@@ -0,0 +1,112 @@
+"""Creates word pieces from a text file.
+
+Most code stolen from:
+
+ https://github.com/facebookresearch/gtn_applications/blob/master/scripts/make_wordpieces.py
+
+"""
+import io
+from pathlib import Path
+from typing import List, Optional, Union
+
+import click
+from loguru import logger as log
+import sentencepiece as spm
+
+
+def iamdb_pieces(
+ data_dir: Path, text_file: str, num_pieces: int, output_prefix: str
+) -> None:
+ """Creates word pieces from the iamdb train text."""
+ # Load training text.
+ with open(data_dir / text_file, "r") as f:
+ text = [line.strip() for line in f]
+
+ sp = train_spm_model(
+ iter(text),
+ num_pieces + 1, # To account for <unk>
+ user_symbols=["/"], # added so token is in the output set
+ )
+
+ vocab = sorted(set(w for t in text for w in t.split("▁") if w))
+ if "move" not in vocab:
+ raise RuntimeError("`MOVE` not in vocab")
+
+ save_pieces(sp, num_pieces, data_dir, output_prefix, vocab)
+
+
+def train_spm_model(
+ sentences: iter, vocab_size: int, user_symbols: Union[str, List[str]] = ""
+) -> spm.SentencePieceProcessor:
+ """Trains the sentence piece model."""
+ model = io.BytesIO()
+ spm.SentencePieceTrainer.train(
+ sentence_iterator=sentences,
+ model_writer=model,
+ vocab_size=vocab_size,
+ bos_id=-1,
+ eos_id=-1,
+ character_coverage=1.0,
+ user_defined_symbols=user_symbols,
+ )
+ sp = spm.SentencePieceProcessor(model_proto=model.getvalue())
+ return sp
+
+
+def save_pieces(
+ sp: spm.SentencePieceProcessor,
+ num_pieces: int,
+ data_dir: Path,
+ output_prefix: str,
+ vocab: set,
+) -> None:
+ """Saves word pieces to disk."""
+ log.info(f"Generating word piece list of size {num_pieces}.")
+ pieces = [sp.id_to_piece(i) for i in range(1, num_pieces + 1)]
+ log.info(f"Encoding vocabulary of size {len(vocab)}.")
+ encoded_vocab = [sp.encode_as_pieces(v) for v in vocab]
+
+ # Save pieces to file.
+ with open(data_dir / f"{output_prefix}_tokens_{num_pieces}.txt", "w") as f:
+ f.write("\n".join(pieces))
+
+ # Save lexicon to a file.
+ with open(data_dir / f"{output_prefix}_lex_{num_pieces}.txt", "w") as f:
+ for v, p in zip(vocab, encoded_vocab):
+ f.write(f"{v} {' '.join(p)}\n")
+
+
+@click.command()
+@click.option("--data_dir", type=str, default=None, help="Path to processed iam dir.")
+@click.option(
+ "--text_file", type=str, default=None, help="Name of sentence piece training text."
+)
+@click.option(
+ "--output_prefix",
+ type=str,
+ default="word_pieces",
+ help="Prefix name to store tokens and lexicon.",
+)
+@click.option("--num_pieces", type=int, default=1000, help="Number of word pieces.")
+def cli(
+ data_dir: Optional[str],
+ text_file: Optional[str],
+ output_prefix: Optional[str],
+ num_pieces: Optional[int],
+) -> None:
+ """CLI for training the sentence piece model."""
+ if data_dir is None:
+ data_dir = (
+ Path(__file__).resolve().parents[2] / "data" / "processed" / "iam_lines"
+ )
+ log.debug(f"Using data dir: {data_dir}")
+ if not data_dir.exists():
+ raise RuntimeError(f"Could not locate iamdb directory at {data_dir}")
+ else:
+ data_dir = Path(data_dir)
+
+ iamdb_pieces(data_dir, text_file, num_pieces, output_prefix)
+
+
+if __name__ == "__main__":
+ cli()
diff --git a/text_recognizer/data/utils/sentence_generator.py b/text_recognizer/data/utils/sentence_generator.py
new file mode 100644
index 0000000..8567e6d
--- /dev/null
+++ b/text_recognizer/data/utils/sentence_generator.py
@@ -0,0 +1,89 @@
+"""Downloading the Brown corpus with NLTK for sentence generating."""
+import itertools
+import re
+import string
+from typing import Optional
+
+import nltk
+from nltk.corpus.reader.util import ConcatenatedCorpusView
+import numpy as np
+
+from text_recognizer.data.base_data_module import BaseDataModule
+
+NLTK_DATA_DIRNAME = BaseDataModule.data_dirname() / "downloaded" / "nltk"
+
+
+class SentenceGenerator:
+ """Generates text sentences using the Brown corpus."""
+
+ def __init__(self, max_length: Optional[int] = None) -> None:
+ """Loads the corpus and sets word start indices."""
+ self.corpus = brown_corpus()
+ self.word_start_indices = [0] + [
+ _.start(0) + 1 for _ in re.finditer(" ", self.corpus)
+ ]
+ self.max_length = max_length
+
+ def generate(self, max_length: Optional[int] = None) -> str:
+ r"""Generates a word or sentences from the Brown corpus.
+
+ Sample a string from the Brown corpus of length at least one word and at most
+ max_length, padding to max_length with the '_' characters if sentence is
+ shorter.
+
+ Args:
+ max_length (Optional[int]): The maximum number of characters in the sentence.
+ Defaults to None.
+
+ Returns:
+ str: A sentence from the Brown corpus.
+
+ Raises:
+ ValueError: If max_length was not specified at initialization and not
+ given as an argument.
+
+ RuntimeError: If a valid string was not generated.
+
+ """
+ if max_length is None:
+ max_length = self.max_length
+ if max_length is None:
+ raise ValueError(
+ "Must provide max_length to this method or when making this object."
+ )
+
+ for _ in range(10):
+ try:
+ index = np.random.randint(0, len(self.word_start_indices) - 1)
+ start_index = self.word_start_indices[index]
+ end_index_candidates = []
+ for index in range(index + 1, len(self.word_start_indices)):
+ if self.word_start_indices[index] - start_index > max_length:
+ break
+ end_index_candidates.append(self.word_start_indices[index])
+ end_index = np.random.choice(end_index_candidates)
+ sampled_text = self.corpus[start_index:end_index].strip()
+ return sampled_text
+ except Exception:
+ pass
+ raise RuntimeError("Was not able to generate a valid string")
+
+
+def brown_corpus() -> str:
+ """Returns a single string with the Brown corpus with all punctuations stripped."""
+ sentences = load_nltk_brown_corpus()
+ corpus = " ".join(itertools.chain.from_iterable(sentences))
+ corpus = corpus.translate({ord(c): None for c in string.punctuation})
+ corpus = re.sub(" +", " ", corpus)
+ return corpus
+
+
+def load_nltk_brown_corpus() -> ConcatenatedCorpusView:
+ """Load the Brown corpus using the NLTK library."""
+ nltk.data.path.append(NLTK_DATA_DIRNAME)
+ try:
+ nltk.corpus.brown.sents()
+ except LookupError:
+ NLTK_DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
+ nltk.download("brown", download_dir=NLTK_DATA_DIRNAME)
+ return nltk.corpus.brown.sents()