summaryrefslogtreecommitdiff
path: root/text_recognizer/data/iam_preprocessor.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/data/iam_preprocessor.py')
-rw-r--r--text_recognizer/data/iam_preprocessor.py16
1 files changed, 7 insertions, 9 deletions
diff --git a/text_recognizer/data/iam_preprocessor.py b/text_recognizer/data/iam_preprocessor.py
index 93a13bb..bcd77b4 100644
--- a/text_recognizer/data/iam_preprocessor.py
+++ b/text_recognizer/data/iam_preprocessor.py
@@ -1,18 +1,16 @@
"""Preprocessor for extracting word letters from the IAM dataset.
The code is mostly stolen from:
-
https://github.com/facebookresearch/gtn_applications/blob/master/datasets/iamdb.py
-
"""
import collections
import itertools
from pathlib import Path
import re
-from typing import List, Optional, Union, Sequence
+from typing import List, Optional, Union, Set
import click
-from loguru import logger
+from loguru import logger as log
import torch
@@ -57,7 +55,7 @@ class Preprocessor:
lexicon_path: Optional[Union[str, Path]] = None,
use_words: bool = False,
prepend_wordsep: bool = False,
- special_tokens: Optional[Sequence[str]] = None,
+ special_tokens: Optional[Set[str]] = None,
) -> None:
self.wordsep = "▁"
self._use_word = use_words
@@ -186,7 +184,7 @@ def cli(
/ "iam"
/ "iamdb"
)
- logger.debug(f"Using data dir: {data_dir}")
+ log.debug(f"Using data dir: {data_dir}")
if not data_dir.exists():
raise RuntimeError(f"Could not locate iamdb directory at {data_dir}")
else:
@@ -196,15 +194,15 @@ def cli(
preprocessor.extract_train_text()
processed_dir = data_dir.parents[2] / "processed" / "iam_lines"
- logger.debug(f"Saving processed files at: {processed_dir}")
+ log.debug(f"Saving processed files at: {processed_dir}")
if save_text is not None:
- logger.info("Saving training text")
+ log.info("Saving training text")
with open(processed_dir / save_text, "w") as f:
f.write("\n".join(t for t in preprocessor.text))
if save_tokens is not None:
- logger.info("Saving tokens")
+ log.info("Saving tokens")
with open(processed_dir / save_tokens, "w") as f:
f.write("\n".join(preprocessor.tokens))