summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--text_recognizer/data/mappings/word_piece_mapping.py28
-rw-r--r--text_recognizer/data/transforms/word_piece.py2
-rw-r--r--text_recognizer/data/utils/iam_preprocessor.py28
3 files changed, 22 insertions, 36 deletions
diff --git a/text_recognizer/data/mappings/word_piece_mapping.py b/text_recognizer/data/mappings/word_piece_mapping.py
index 6f1790e..f9e4e7a 100644
--- a/text_recognizer/data/mappings/word_piece_mapping.py
+++ b/text_recognizer/data/mappings/word_piece_mapping.py
@@ -15,7 +15,6 @@ class WordPieceMapping(EmnistMapping):
def __init__(
self,
- data_dir: Optional[Path] = None,
num_features: int = 1000,
tokens: str = "iamdb_1kwp_tokens_1000.txt",
lexicon: str = "iamdb_1kwp_lex_1000.txt",
@@ -25,37 +24,14 @@ class WordPieceMapping(EmnistMapping):
extra_symbols: Set[str] = {"\n"},
) -> None:
super().__init__(extra_symbols=extra_symbols)
- self.data_dir = (
- (
- Path(__file__).resolve().parents[3]
- / "data"
- / "downloaded"
- / "iam"
- / "iamdb"
- )
- if data_dir is None
- else Path(data_dir)
- )
- log.debug(f"Using data dir: {self.data_dir}")
- if not self.data_dir.exists():
- raise RuntimeError(f"Could not locate iamdb directory at {self.data_dir}")
-
- processed_path = (
- Path(__file__).resolve().parents[3] / "data" / "processed" / "iam_lines"
- )
-
- tokens_path = processed_path / tokens
- lexicon_path = processed_path / lexicon
-
special_tokens = set(special_tokens)
if self.extra_symbols is not None:
special_tokens = special_tokens | set(extra_symbols)
self.wordpiece_processor = Preprocessor(
- data_dir=self.data_dir,
num_features=num_features,
- tokens_path=tokens_path,
- lexicon_path=lexicon_path,
+ tokens=tokens,
+ lexicon=lexicon,
use_words=use_words,
prepend_wordsep=prepend_wordsep,
special_tokens=special_tokens,
diff --git a/text_recognizer/data/transforms/word_piece.py b/text_recognizer/data/transforms/word_piece.py
index 6bf5472..69f0ce1 100644
--- a/text_recognizer/data/transforms/word_piece.py
+++ b/text_recognizer/data/transforms/word_piece.py
@@ -16,7 +16,6 @@ class WordPiece:
num_features: int = 1000,
tokens: str = "iamdb_1kwp_tokens_1000.txt",
lexicon: str = "iamdb_1kwp_lex_1000.txt",
- data_dir: Optional[Union[str, Path]] = None,
use_words: bool = False,
prepend_wordsep: bool = False,
special_tokens: Set[str] = {"<s>", "<e>", "<p>"},
@@ -24,7 +23,6 @@ class WordPiece:
max_len: int = 451,
) -> None:
self.mapping = WordPieceMapping(
- data_dir=data_dir,
num_features=num_features,
tokens=tokens,
lexicon=lexicon,
diff --git a/text_recognizer/data/utils/iam_preprocessor.py b/text_recognizer/data/utils/iam_preprocessor.py
index 60ecff1..4f95007 100644
--- a/text_recognizer/data/utils/iam_preprocessor.py
+++ b/text_recognizer/data/utils/iam_preprocessor.py
@@ -47,19 +47,28 @@ class Preprocessor:
def __init__(
self,
- data_dir: Union[str, Path],
num_features: int,
- tokens_path: Optional[Union[str, Path]] = None,
- lexicon_path: Optional[Union[str, Path]] = None,
+ tokens: Optional[str] = None,
+ lexicon: Optional[str] = None,
use_words: bool = False,
prepend_wordsep: bool = False,
special_tokens: Optional[Set[str]] = None,
) -> None:
+ self.data_dir = (
+ Path(__file__).resolve().parents[3]
+ / "data"
+ / "downloaded"
+ / "iam"
+ / "iamdb"
+ )
+ log.debug(f"Using data dir: {self.data_dir}")
+ if not self.data_dir.exists():
+ raise RuntimeError(f"Could not locate iamdb directory at {self.data_dir}")
+
self.wordsep = "▁"
self._use_word = use_words
self._prepend_wordsep = prepend_wordsep
self.special_tokens = special_tokens if special_tokens is not None else None
- self.data_dir = Path(data_dir)
self.forms = load_metadata(self.data_dir, self.wordsep, use_words=use_words)
# Load the set of graphemes:
@@ -70,14 +79,17 @@ class Preprocessor:
self.graphemes = sorted(graphemes)
# Build the token-to-index and index-to-token maps.
- if tokens_path is not None:
- with open(tokens_path, "r") as f:
+ processed_path = (
+ Path(__file__).resolve().parents[3] / "data" / "processed" / "iam_lines"
+ )
+ if tokens is not None:
+ with open(processed_path / tokens, "r") as f:
self.tokens = [line.strip() for line in f]
else:
self.tokens = self.graphemes
- if lexicon_path is not None:
- with open(lexicon_path, "r") as f:
+ if lexicon is not None:
+ with open(processed_path / lexicon, "r") as f:
lexicon = (line.strip().split() for line in f)
lexicon = {line[0]: line[1:] for line in lexicon}
self.lexicon = lexicon