diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-02-06 20:00:29 +0100 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-02-06 20:00:29 +0100 |
commit | f5ed9049064d18b9fb74c44be0c589dce817865e (patch) | |
tree | 183a074812ba655801b3af8392a2059f8bf3bc8f /text_recognizer/data | |
parent | 76098a8da9731dd7cba1a7334ad9ae8a2acc760e (diff) |
chore: remove word pieces
Diffstat (limited to 'text_recognizer/data')
-rw-r--r-- | text_recognizer/data/mappings/word_piece.py | 72 | ||||
-rw-r--r-- | text_recognizer/data/transforms/word_piece.py | 45 | ||||
-rw-r--r-- | text_recognizer/data/utils/make_wordpieces.py | 112 |
3 files changed, 0 insertions, 229 deletions
diff --git a/text_recognizer/data/mappings/word_piece.py b/text_recognizer/data/mappings/word_piece.py deleted file mode 100644 index 861c3bd..0000000 --- a/text_recognizer/data/mappings/word_piece.py +++ /dev/null @@ -1,72 +0,0 @@ -"""Word piece mapping.""" -from typing import List, Set, Union - -import torch -from torch import Tensor - -from text_recognizer.data.mappings.emnist import EmnistMapping -from text_recognizer.data.utils.iam_preprocessor import Preprocessor - - -class WordPieceMapping(EmnistMapping): - """Word piece mapping.""" - - def __init__( - self, - num_features: int = 1000, - tokens: str = "iamdb_1kwp_tokens_1000.txt", - lexicon: str = "iamdb_1kwp_lex_1000.txt", - use_words: bool = False, - prepend_wordsep: bool = False, - special_tokens: Set[str] = {"<s>", "<e>", "<p>"}, - extra_symbols: Set[str] = {"\n"}, - ) -> None: - super().__init__(extra_symbols=extra_symbols) - special_tokens = set(special_tokens) - if self.extra_symbols is not None: - special_tokens = special_tokens | set(extra_symbols) - - self.wordpiece_processor = Preprocessor( - num_features=num_features, - tokens=tokens, - lexicon=lexicon, - use_words=use_words, - prepend_wordsep=prepend_wordsep, - special_tokens=special_tokens, - ) - - def __len__(self) -> int: - """Return number of word pieces.""" - return len(self.wordpiece_processor.tokens) - - def get_token(self, index: Union[int, Tensor]) -> str: - """Returns token for index.""" - if (index := int(index)) <= self.wordpiece_processor.num_tokens: - return self.wordpiece_processor.tokens[index] - raise KeyError(f"Index ({index}) not in mapping.") - - def get_index(self, token: str) -> Tensor: - """Returns index of token.""" - if token in self.wordpiece_processor.tokens: - return torch.LongTensor([self.wordpiece_processor.tokens_to_index[token]]) - raise KeyError(f"Token ({token}) not found in inverse mapping.") - - def get_text(self, indices: Union[List[int], Tensor]) -> str: - """Returns text from indices.""" - if isinstance(indices, Tensor): - indices = indices.tolist() - return self.wordpiece_processor.to_text(indices) - - def get_indices(self, text: str) -> Tensor: - """Returns indices of text.""" - return self.wordpiece_processor.to_index(text) - - def emnist_to_wordpiece_indices(self, x: Tensor) -> Tensor: - """Returns word pieces indices from emnist indices.""" - text = "".join([self.mapping[i] for i in x]) - text = text.lower().replace(" ", "▁") - return torch.LongTensor(self.wordpiece_processor.to_index(text)) - - def __getitem__(self, x: Union[int, Tensor]) -> str: - """Returns token for word piece index.""" - return self.get_token(x) diff --git a/text_recognizer/data/transforms/word_piece.py b/text_recognizer/data/transforms/word_piece.py deleted file mode 100644 index d805c7e..0000000 --- a/text_recognizer/data/transforms/word_piece.py +++ /dev/null @@ -1,45 +0,0 @@ -"""Target transform for word pieces.""" -from typing import Optional, Sequence - -import torch -from torch import Tensor - -from text_recognizer.data.mappings.word_piece_mapping import WordPieceMapping - - -class WordPiece: - """Converts EMNIST indices to Word Piece indices.""" - - def __init__( - self, - num_features: int = 1000, - tokens: str = "iamdb_1kwp_tokens_1000.txt", - lexicon: str = "iamdb_1kwp_lex_1000.txt", - use_words: bool = False, - prepend_wordsep: bool = False, - special_tokens: Sequence[str] = ("<s>", "<e>", "<p>"), - extra_symbols: Optional[Sequence[str]] = ("\n",), - max_len: int = 451, - ) -> None: - self.mapping = WordPieceMapping( - num_features=num_features, - tokens=tokens, - lexicon=lexicon, - use_words=use_words, - prepend_wordsep=prepend_wordsep, - special_tokens=special_tokens, - extra_symbols=extra_symbols, - ) - self.max_len = max_len - - def __call__(self, x: Tensor) -> Tensor: - """Converts Emnist target tensor to Word piece target tensor.""" - y = self.mapping.emnist_to_wordpiece_indices(x) - if len(y) < self.max_len: - pad_len = self.max_len - len(y) - y = torch.cat( - (y, torch.LongTensor([self.mapping.get_index("<p>")] * pad_len)) - ) - else: - y = y[: self.max_len] - return y diff --git a/text_recognizer/data/utils/make_wordpieces.py b/text_recognizer/data/utils/make_wordpieces.py deleted file mode 100644 index 8e53815..0000000 --- a/text_recognizer/data/utils/make_wordpieces.py +++ /dev/null @@ -1,112 +0,0 @@ -"""Creates word pieces from a text file. - -Most code stolen from: - - https://github.com/facebookresearch/gtn_applications/blob/master/scripts/make_wordpieces.py - -""" -import io -from pathlib import Path -from typing import List, Optional, Union - -import click -from loguru import logger as log -import sentencepiece as spm - - -def iamdb_pieces( - data_dir: Path, text_file: str, num_pieces: int, output_prefix: str -) -> None: - """Creates word pieces from the iamdb train text.""" - # Load training text. - with open(data_dir / text_file, "r") as f: - text = [line.strip() for line in f] - - sp = train_spm_model( - iter(text), - num_pieces + 1, # To account for <unk> - user_symbols=["/"], # added so token is in the output set - ) - - vocab = sorted(set(w for t in text for w in t.split("▁") if w)) - if "move" not in vocab: - raise RuntimeError("`MOVE` not in vocab") - - save_pieces(sp, num_pieces, data_dir, output_prefix, vocab) - - -def train_spm_model( - sentences: iter, vocab_size: int, user_symbols: Union[str, List[str]] = "" -) -> spm.SentencePieceProcessor: - """Trains the sentence piece model.""" - model = io.BytesIO() - spm.SentencePieceTrainer.train( - sentence_iterator=sentences, - model_writer=model, - vocab_size=vocab_size, - bos_id=-1, - eos_id=-1, - character_coverage=1.0, - user_defined_symbols=user_symbols, - ) - sp = spm.SentencePieceProcessor(model_proto=model.getvalue()) - return sp - - -def save_pieces( - sp: spm.SentencePieceProcessor, - num_pieces: int, - data_dir: Path, - output_prefix: str, - vocab: set, -) -> None: - """Saves word pieces to disk.""" - log.info(f"Generating word piece list of size {num_pieces}.") - pieces = [sp.id_to_piece(i) for i in range(1, num_pieces + 1)] - log.info(f"Encoding vocabulary of size {len(vocab)}.") - encoded_vocab = [sp.encode_as_pieces(v) for v in vocab] - - # Save pieces to file. - with open(data_dir / f"{output_prefix}_tokens_{num_pieces}.txt", "w") as f: - f.write("\n".join(pieces)) - - # Save lexicon to a file. - with open(data_dir / f"{output_prefix}_lex_{num_pieces}.txt", "w") as f: - for v, p in zip(vocab, encoded_vocab): - f.write(f"{v} {' '.join(p)}\n") - - -@click.command() -@click.option("--data_dir", type=str, default=None, help="Path to processed iam dir.") -@click.option( - "--text_file", type=str, default=None, help="Name of sentence piece training text." -) -@click.option( - "--output_prefix", - type=str, - default="word_pieces", - help="Prefix name to store tokens and lexicon.", -) -@click.option("--num_pieces", type=int, default=1000, help="Number of word pieces.") -def cli( - data_dir: Optional[str], - text_file: Optional[str], - output_prefix: Optional[str], - num_pieces: Optional[int], -) -> None: - """CLI for training the sentence piece model.""" - if data_dir is None: - data_dir = ( - Path(__file__).resolve().parents[2] / "data" / "processed" / "iam_lines" - ) - log.debug(f"Using data dir: {data_dir}") - if not data_dir.exists(): - raise RuntimeError(f"Could not locate iamdb directory at {data_dir}") - else: - data_dir = Path(data_dir) - - iamdb_pieces(data_dir, text_file, num_pieces, output_prefix) - - -if __name__ == "__main__": - cli() |