summaryrefslogtreecommitdiff
path: root/text_recognizer
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2022-02-06 20:00:29 +0100
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2022-02-06 20:00:29 +0100
commitf5ed9049064d18b9fb74c44be0c589dce817865e (patch)
tree183a074812ba655801b3af8392a2059f8bf3bc8f /text_recognizer
parent76098a8da9731dd7cba1a7334ad9ae8a2acc760e (diff)
chore: remove word pieces
Diffstat (limited to 'text_recognizer')
-rw-r--r--text_recognizer/data/mappings/word_piece.py72
-rw-r--r--text_recognizer/data/transforms/word_piece.py45
-rw-r--r--text_recognizer/data/utils/make_wordpieces.py112
3 files changed, 0 insertions, 229 deletions
diff --git a/text_recognizer/data/mappings/word_piece.py b/text_recognizer/data/mappings/word_piece.py
deleted file mode 100644
index 861c3bd..0000000
--- a/text_recognizer/data/mappings/word_piece.py
+++ /dev/null
@@ -1,72 +0,0 @@
-"""Word piece mapping."""
-from typing import List, Set, Union
-
-import torch
-from torch import Tensor
-
-from text_recognizer.data.mappings.emnist import EmnistMapping
-from text_recognizer.data.utils.iam_preprocessor import Preprocessor
-
-
-class WordPieceMapping(EmnistMapping):
- """Word piece mapping."""
-
- def __init__(
- self,
- num_features: int = 1000,
- tokens: str = "iamdb_1kwp_tokens_1000.txt",
- lexicon: str = "iamdb_1kwp_lex_1000.txt",
- use_words: bool = False,
- prepend_wordsep: bool = False,
- special_tokens: Set[str] = {"<s>", "<e>", "<p>"},
- extra_symbols: Set[str] = {"\n"},
- ) -> None:
- super().__init__(extra_symbols=extra_symbols)
- special_tokens = set(special_tokens)
- if self.extra_symbols is not None:
- special_tokens = special_tokens | set(extra_symbols)
-
- self.wordpiece_processor = Preprocessor(
- num_features=num_features,
- tokens=tokens,
- lexicon=lexicon,
- use_words=use_words,
- prepend_wordsep=prepend_wordsep,
- special_tokens=special_tokens,
- )
-
- def __len__(self) -> int:
- """Return number of word pieces."""
- return len(self.wordpiece_processor.tokens)
-
- def get_token(self, index: Union[int, Tensor]) -> str:
- """Returns token for index."""
- if (index := int(index)) <= self.wordpiece_processor.num_tokens:
- return self.wordpiece_processor.tokens[index]
- raise KeyError(f"Index ({index}) not in mapping.")
-
- def get_index(self, token: str) -> Tensor:
- """Returns index of token."""
- if token in self.wordpiece_processor.tokens:
- return torch.LongTensor([self.wordpiece_processor.tokens_to_index[token]])
- raise KeyError(f"Token ({token}) not found in inverse mapping.")
-
- def get_text(self, indices: Union[List[int], Tensor]) -> str:
- """Returns text from indices."""
- if isinstance(indices, Tensor):
- indices = indices.tolist()
- return self.wordpiece_processor.to_text(indices)
-
- def get_indices(self, text: str) -> Tensor:
- """Returns indices of text."""
- return self.wordpiece_processor.to_index(text)
-
- def emnist_to_wordpiece_indices(self, x: Tensor) -> Tensor:
- """Returns word pieces indices from emnist indices."""
- text = "".join([self.mapping[i] for i in x])
- text = text.lower().replace(" ", "▁")
- return torch.LongTensor(self.wordpiece_processor.to_index(text))
-
- def __getitem__(self, x: Union[int, Tensor]) -> str:
- """Returns token for word piece index."""
- return self.get_token(x)
diff --git a/text_recognizer/data/transforms/word_piece.py b/text_recognizer/data/transforms/word_piece.py
deleted file mode 100644
index d805c7e..0000000
--- a/text_recognizer/data/transforms/word_piece.py
+++ /dev/null
@@ -1,45 +0,0 @@
-"""Target transform for word pieces."""
-from typing import Optional, Sequence
-
-import torch
-from torch import Tensor
-
-from text_recognizer.data.mappings.word_piece_mapping import WordPieceMapping
-
-
-class WordPiece:
- """Converts EMNIST indices to Word Piece indices."""
-
- def __init__(
- self,
- num_features: int = 1000,
- tokens: str = "iamdb_1kwp_tokens_1000.txt",
- lexicon: str = "iamdb_1kwp_lex_1000.txt",
- use_words: bool = False,
- prepend_wordsep: bool = False,
- special_tokens: Sequence[str] = ("<s>", "<e>", "<p>"),
- extra_symbols: Optional[Sequence[str]] = ("\n",),
- max_len: int = 451,
- ) -> None:
- self.mapping = WordPieceMapping(
- num_features=num_features,
- tokens=tokens,
- lexicon=lexicon,
- use_words=use_words,
- prepend_wordsep=prepend_wordsep,
- special_tokens=special_tokens,
- extra_symbols=extra_symbols,
- )
- self.max_len = max_len
-
- def __call__(self, x: Tensor) -> Tensor:
- """Converts Emnist target tensor to Word piece target tensor."""
- y = self.mapping.emnist_to_wordpiece_indices(x)
- if len(y) < self.max_len:
- pad_len = self.max_len - len(y)
- y = torch.cat(
- (y, torch.LongTensor([self.mapping.get_index("<p>")] * pad_len))
- )
- else:
- y = y[: self.max_len]
- return y
diff --git a/text_recognizer/data/utils/make_wordpieces.py b/text_recognizer/data/utils/make_wordpieces.py
deleted file mode 100644
index 8e53815..0000000
--- a/text_recognizer/data/utils/make_wordpieces.py
+++ /dev/null
@@ -1,112 +0,0 @@
-"""Creates word pieces from a text file.
-
-Most code stolen from:
-
- https://github.com/facebookresearch/gtn_applications/blob/master/scripts/make_wordpieces.py
-
-"""
-import io
-from pathlib import Path
-from typing import List, Optional, Union
-
-import click
-from loguru import logger as log
-import sentencepiece as spm
-
-
-def iamdb_pieces(
- data_dir: Path, text_file: str, num_pieces: int, output_prefix: str
-) -> None:
- """Creates word pieces from the iamdb train text."""
- # Load training text.
- with open(data_dir / text_file, "r") as f:
- text = [line.strip() for line in f]
-
- sp = train_spm_model(
- iter(text),
- num_pieces + 1, # To account for <unk>
- user_symbols=["/"], # added so token is in the output set
- )
-
- vocab = sorted(set(w for t in text for w in t.split("▁") if w))
- if "move" not in vocab:
- raise RuntimeError("`MOVE` not in vocab")
-
- save_pieces(sp, num_pieces, data_dir, output_prefix, vocab)
-
-
-def train_spm_model(
- sentences: iter, vocab_size: int, user_symbols: Union[str, List[str]] = ""
-) -> spm.SentencePieceProcessor:
- """Trains the sentence piece model."""
- model = io.BytesIO()
- spm.SentencePieceTrainer.train(
- sentence_iterator=sentences,
- model_writer=model,
- vocab_size=vocab_size,
- bos_id=-1,
- eos_id=-1,
- character_coverage=1.0,
- user_defined_symbols=user_symbols,
- )
- sp = spm.SentencePieceProcessor(model_proto=model.getvalue())
- return sp
-
-
-def save_pieces(
- sp: spm.SentencePieceProcessor,
- num_pieces: int,
- data_dir: Path,
- output_prefix: str,
- vocab: set,
-) -> None:
- """Saves word pieces to disk."""
- log.info(f"Generating word piece list of size {num_pieces}.")
- pieces = [sp.id_to_piece(i) for i in range(1, num_pieces + 1)]
- log.info(f"Encoding vocabulary of size {len(vocab)}.")
- encoded_vocab = [sp.encode_as_pieces(v) for v in vocab]
-
- # Save pieces to file.
- with open(data_dir / f"{output_prefix}_tokens_{num_pieces}.txt", "w") as f:
- f.write("\n".join(pieces))
-
- # Save lexicon to a file.
- with open(data_dir / f"{output_prefix}_lex_{num_pieces}.txt", "w") as f:
- for v, p in zip(vocab, encoded_vocab):
- f.write(f"{v} {' '.join(p)}\n")
-
-
-@click.command()
-@click.option("--data_dir", type=str, default=None, help="Path to processed iam dir.")
-@click.option(
- "--text_file", type=str, default=None, help="Name of sentence piece training text."
-)
-@click.option(
- "--output_prefix",
- type=str,
- default="word_pieces",
- help="Prefix name to store tokens and lexicon.",
-)
-@click.option("--num_pieces", type=int, default=1000, help="Number of word pieces.")
-def cli(
- data_dir: Optional[str],
- text_file: Optional[str],
- output_prefix: Optional[str],
- num_pieces: Optional[int],
-) -> None:
- """CLI for training the sentence piece model."""
- if data_dir is None:
- data_dir = (
- Path(__file__).resolve().parents[2] / "data" / "processed" / "iam_lines"
- )
- log.debug(f"Using data dir: {data_dir}")
- if not data_dir.exists():
- raise RuntimeError(f"Could not locate iamdb directory at {data_dir}")
- else:
- data_dir = Path(data_dir)
-
- iamdb_pieces(data_dir, text_file, num_pieces, output_prefix)
-
-
-if __name__ == "__main__":
- cli()