summaryrefslogtreecommitdiff
path: root/text_recognizer/data/transforms
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/data/transforms')
-rw-r--r--text_recognizer/data/transforms/word_piece.py45
1 files changed, 0 insertions, 45 deletions
diff --git a/text_recognizer/data/transforms/word_piece.py b/text_recognizer/data/transforms/word_piece.py
deleted file mode 100644
index d805c7e..0000000
--- a/text_recognizer/data/transforms/word_piece.py
+++ /dev/null
@@ -1,45 +0,0 @@
-"""Target transform for word pieces."""
-from typing import Optional, Sequence
-
-import torch
-from torch import Tensor
-
-from text_recognizer.data.mappings.word_piece_mapping import WordPieceMapping
-
-
-class WordPiece:
- """Converts EMNIST indices to Word Piece indices."""
-
- def __init__(
- self,
- num_features: int = 1000,
- tokens: str = "iamdb_1kwp_tokens_1000.txt",
- lexicon: str = "iamdb_1kwp_lex_1000.txt",
- use_words: bool = False,
- prepend_wordsep: bool = False,
- special_tokens: Sequence[str] = ("<s>", "<e>", "<p>"),
- extra_symbols: Optional[Sequence[str]] = ("\n",),
- max_len: int = 451,
- ) -> None:
- self.mapping = WordPieceMapping(
- num_features=num_features,
- tokens=tokens,
- lexicon=lexicon,
- use_words=use_words,
- prepend_wordsep=prepend_wordsep,
- special_tokens=special_tokens,
- extra_symbols=extra_symbols,
- )
- self.max_len = max_len
-
- def __call__(self, x: Tensor) -> Tensor:
- """Converts Emnist target tensor to Word piece target tensor."""
- y = self.mapping.emnist_to_wordpiece_indices(x)
- if len(y) < self.max_len:
- pad_len = self.max_len - len(y)
- y = torch.cat(
- (y, torch.LongTensor([self.mapping.get_index("<p>")] * pad_len))
- )
- else:
- y = y[: self.max_len]
- return y