summaryrefslogtreecommitdiff
path: root/src/tasks/make_wordpieces.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/tasks/make_wordpieces.py')
-rw-r--r--src/tasks/make_wordpieces.py114
1 files changed, 114 insertions, 0 deletions
diff --git a/src/tasks/make_wordpieces.py b/src/tasks/make_wordpieces.py
new file mode 100644
index 0000000..f605920
--- /dev/null
+++ b/src/tasks/make_wordpieces.py
@@ -0,0 +1,114 @@
+"""Creates word pieces from a text file.
+
+Most code stolen from:
+
+ https://github.com/facebookresearch/gtn_applications/blob/master/scripts/make_wordpieces.py
+
+"""
+import io
+from pathlib import Path
+from typing import List, Optional, Union
+
+import click
+from loguru import logger
+import sentencepiece as spm
+
+from text_recognizer.datasets.iam_preprocessor import load_metadata
+
+
+def iamdb_pieces(
+ data_dir: Path, text_file: str, num_pieces: int, output_prefix: str
+) -> None:
+ """Creates word pieces from the iamdb train text."""
+ # Load training text.
+ with open(data_dir / text_file, "r") as f:
+ text = [line.strip() for line in f]
+
+ sp = train_spm_model(
+ iter(text),
+ num_pieces + 1, # To account for <unk>
+ user_symbols=["/"], # added so token is in the output set
+ )
+
+ vocab = sorted(set(w for t in text for w in t.split("_") if w))
+ if "move" not in vocab:
+ raise RuntimeError("`MOVE` not in vocab")
+
+ save_pieces(sp, num_pieces, data_dir, output_prefix, vocab)
+
+
+def train_spm_model(
+ sentences: iter, vocab_size: int, user_symbols: Union[str, List[str]] = ""
+) -> spm.SentencePieceProcessor:
+ """Trains the sentence piece model."""
+ model = io.BytesIO()
+ spm.SentencePieceTrainer.train(
+ sentence_iterator=sentences,
+ model_writer=model,
+ vocab_size=vocab_size,
+ bos_id=-1,
+ eos_id=-1,
+ character_coverage=1.0,
+ user_defined_symbols=user_symbols,
+ )
+ sp = spm.SentencePieceProcessor(model_proto=model.getvalue())
+ return sp
+
+
+def save_pieces(
+ sp: spm.SentencePieceProcessor,
+ num_pieces: int,
+ data_dir: Path,
+ output_prefix: str,
+ vocab: set,
+) -> None:
+ """Saves word pieces to disk."""
+ logger.info(f"Generating word piece list of size {num_pieces}.")
+ pieces = [sp.id_to_piece(i) for i in range(1, num_pieces + 1)]
+ logger.info(f"Encoding vocabulary of size {len(vocab)}.")
+ encoded_vocab = [sp.encode_as_pieces(v) for v in vocab]
+
+ # Save pieces to file.
+ with open(data_dir / f"{output_prefix}_tokens_{num_pieces}.txt", "w") as f:
+ f.write("\n".join(pieces))
+
+ # Save lexicon to a file.
+ with open(data_dir / f"{output_prefix}_lex_{num_pieces}.txt", "w") as f:
+ for v, p in zip(vocab, encoded_vocab):
+ f.write(f"{v} {' '.join(p)}\n")
+
+
+@click.command()
+@click.option("--data_dir", type=str, default=None, help="Path to processed iam dir.")
+@click.option(
+ "--text_file", type=str, default=None, help="Name of sentence piece training text."
+)
+@click.option(
+ "--output_prefix",
+ type=str,
+ default="word_pieces",
+ help="Prefix name to store tokens and lexicon.",
+)
+@click.option("--num_pieces", type=int, default=1000, help="Number of word pieces.")
+def cli(
+ data_dir: Optional[str],
+ text_file: Optional[str],
+ output_prefix: Optional[str],
+ num_pieces: Optional[int],
+) -> None:
+ """CLI for training the sentence piece model."""
+ if data_dir is None:
+ data_dir = (
+ Path(__file__).resolve().parents[2] / "data" / "processed" / "iam_lines"
+ )
+ logger.debug(f"Using data dir: {data_dir}")
+ if not data_dir.exists():
+ raise RuntimeError(f"Could not locate iamdb directory at {data_dir}")
+ else:
+ data_dir = Path(data_dir)
+
+ iamdb_pieces(data_dir, text_file, num_pieces, output_prefix)
+
+
+if __name__ == "__main__":
+ cli()