From da29c1cf4d062087f1b29dc9402ee6384203b690 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Mon, 25 Oct 2021 22:31:36 +0200 Subject: Make data dir static in preprocessor and word piece --- .../data/mappings/word_piece_mapping.py | 28 ++-------------------- text_recognizer/data/transforms/word_piece.py | 2 -- text_recognizer/data/utils/iam_preprocessor.py | 28 +++++++++++++++------- 3 files changed, 22 insertions(+), 36 deletions(-) diff --git a/text_recognizer/data/mappings/word_piece_mapping.py b/text_recognizer/data/mappings/word_piece_mapping.py index 6f1790e..f9e4e7a 100644 --- a/text_recognizer/data/mappings/word_piece_mapping.py +++ b/text_recognizer/data/mappings/word_piece_mapping.py @@ -15,7 +15,6 @@ class WordPieceMapping(EmnistMapping): def __init__( self, - data_dir: Optional[Path] = None, num_features: int = 1000, tokens: str = "iamdb_1kwp_tokens_1000.txt", lexicon: str = "iamdb_1kwp_lex_1000.txt", @@ -25,37 +24,14 @@ class WordPieceMapping(EmnistMapping): extra_symbols: Set[str] = {"\n"}, ) -> None: super().__init__(extra_symbols=extra_symbols) - self.data_dir = ( - ( - Path(__file__).resolve().parents[3] - / "data" - / "downloaded" - / "iam" - / "iamdb" - ) - if data_dir is None - else Path(data_dir) - ) - log.debug(f"Using data dir: {self.data_dir}") - if not self.data_dir.exists(): - raise RuntimeError(f"Could not locate iamdb directory at {self.data_dir}") - - processed_path = ( - Path(__file__).resolve().parents[3] / "data" / "processed" / "iam_lines" - ) - - tokens_path = processed_path / tokens - lexicon_path = processed_path / lexicon - special_tokens = set(special_tokens) if self.extra_symbols is not None: special_tokens = special_tokens | set(extra_symbols) self.wordpiece_processor = Preprocessor( - data_dir=self.data_dir, num_features=num_features, - tokens_path=tokens_path, - lexicon_path=lexicon_path, + tokens=tokens, + lexicon=lexicon, use_words=use_words, prepend_wordsep=prepend_wordsep, special_tokens=special_tokens, diff --git a/text_recognizer/data/transforms/word_piece.py b/text_recognizer/data/transforms/word_piece.py index 6bf5472..69f0ce1 100644 --- a/text_recognizer/data/transforms/word_piece.py +++ b/text_recognizer/data/transforms/word_piece.py @@ -16,7 +16,6 @@ class WordPiece: num_features: int = 1000, tokens: str = "iamdb_1kwp_tokens_1000.txt", lexicon: str = "iamdb_1kwp_lex_1000.txt", - data_dir: Optional[Union[str, Path]] = None, use_words: bool = False, prepend_wordsep: bool = False, special_tokens: Set[str] = {"", "", "

"}, @@ -24,7 +23,6 @@ class WordPiece: max_len: int = 451, ) -> None: self.mapping = WordPieceMapping( - data_dir=data_dir, num_features=num_features, tokens=tokens, lexicon=lexicon, diff --git a/text_recognizer/data/utils/iam_preprocessor.py b/text_recognizer/data/utils/iam_preprocessor.py index 60ecff1..4f95007 100644 --- a/text_recognizer/data/utils/iam_preprocessor.py +++ b/text_recognizer/data/utils/iam_preprocessor.py @@ -47,19 +47,28 @@ class Preprocessor: def __init__( self, - data_dir: Union[str, Path], num_features: int, - tokens_path: Optional[Union[str, Path]] = None, - lexicon_path: Optional[Union[str, Path]] = None, + tokens: Optional[str] = None, + lexicon: Optional[str] = None, use_words: bool = False, prepend_wordsep: bool = False, special_tokens: Optional[Set[str]] = None, ) -> None: + self.data_dir = ( + Path(__file__).resolve().parents[3] + / "data" + / "downloaded" + / "iam" + / "iamdb" + ) + log.debug(f"Using data dir: {self.data_dir}") + if not self.data_dir.exists(): + raise RuntimeError(f"Could not locate iamdb directory at {self.data_dir}") + self.wordsep = "▁" self._use_word = use_words self._prepend_wordsep = prepend_wordsep self.special_tokens = special_tokens if special_tokens is not None else None - self.data_dir = Path(data_dir) self.forms = load_metadata(self.data_dir, self.wordsep, use_words=use_words) # Load the set of graphemes: @@ -70,14 +79,17 @@ class Preprocessor: self.graphemes = sorted(graphemes) # Build the token-to-index and index-to-token maps. - if tokens_path is not None: - with open(tokens_path, "r") as f: + processed_path = ( + Path(__file__).resolve().parents[3] / "data" / "processed" / "iam_lines" + ) + if tokens is not None: + with open(processed_path / tokens, "r") as f: self.tokens = [line.strip() for line in f] else: self.tokens = self.graphemes - if lexicon_path is not None: - with open(lexicon_path, "r") as f: + if lexicon is not None: + with open(processed_path / lexicon, "r") as f: lexicon = (line.strip().split() for line in f) lexicon = {line[0]: line[1:] for line in lexicon} self.lexicon = lexicon -- cgit v1.2.3-70-g09d2