author     Gustaf Rydholm <gustaf.rydholm@gmail.com>    2021-07-30 23:15:03 +0200
committer  Gustaf Rydholm <gustaf.rydholm@gmail.com>    2021-07-30 23:15:03 +0200
commit     7268035fb9e57342612a8cc50a1fe04e8841ca2f (patch)
tree       8d4cf3743975bd25f2c04d6a56ff3d4608a7e8d9 /text_recognizer/data
parent     92fc1c7ed2f9f64552be8f71d9b8ab0d5a0a88d4 (diff)
attr bug fix, properly loading network
Diffstat (limited to 'text_recognizer/data')
-rw-r--r--  text_recognizer/data/__init__.py                  |   6
-rw-r--r--  text_recognizer/data/base_data_module.py          |   2
-rw-r--r--  text_recognizer/data/emnist_lines.py              |   2
-rw-r--r--  text_recognizer/data/iam_extended_paragraphs.py   |   2
-rw-r--r--  text_recognizer/data/iam_lines.py                 |   2
-rw-r--r--  text_recognizer/data/iam_paragraphs.py            |   2
-rw-r--r--  text_recognizer/data/iam_preprocessor.py          |   8
-rw-r--r--  text_recognizer/data/iam_synthetic_paragraphs.py  |   2
-rw-r--r--  text_recognizer/data/mappings.py                  | 111
-rw-r--r--  text_recognizer/data/transforms.py                |  16
10 files changed, 67 insertions, 86 deletions
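A note on the recurring change in the diffs below: every data module decorated with @attr.s gains repr=False. Without it, attrs generates a __repr__ for the class, shadowing the one inherited from LightningDataModule. A minimal standalone sketch of that behavior (Base stands in for LightningDataModule; the class names are illustrative, not from the repo):

import attr


class Base:
    """Stand-in for LightningDataModule, which defines its own __repr__."""

    def __repr__(self) -> str:
        return f"<{type(self).__name__} (parent repr)>"


@attr.s(repr=False)
class KeepsParentRepr(Base):
    batch_size: int = attr.ib(default=16)


@attr.s
class OverridesParentRepr(Base):
    batch_size: int = attr.ib(default=16)


print(repr(KeepsParentRepr()))      # <KeepsParentRepr (parent repr)>
print(repr(OverridesParentRepr()))  # OverridesParentRepr(batch_size=16)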
diff --git a/text_recognizer/data/__init__.py b/text_recognizer/data/__init__.py
index 3599a8b..2727b20 100644
--- a/text_recognizer/data/__init__.py
+++ b/text_recognizer/data/__init__.py
@@ -1,7 +1 @@
 """Dataset modules."""
-from .base_dataset import BaseDataset, convert_strings_to_labels, split_dataset
-from .base_data_module import BaseDataModule, load_and_print_info
-from .download_utils import download_dataset
-from .iam_paragraphs import IAMParagraphs
-from .iam_synthetic_paragraphs import IAMSyntheticParagraphs
-from .iam_extended_paragraphs import IAMExtendedParagraphs
diff --git a/text_recognizer/data/base_data_module.py b/text_recognizer/data/base_data_module.py
index 18b1996..408ae36 100644
--- a/text_recognizer/data/base_data_module.py
+++ b/text_recognizer/data/base_data_module.py
@@ -17,7 +17,7 @@ def load_and_print_info(data_module_class: type) -> None:
     print(dataset)


-@attr.s
+@attr.s(repr=False)
 class BaseDataModule(LightningDataModule):
     """Base PyTorch Lightning DataModule."""

diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py
index 4747508..7548ad5 100644
--- a/text_recognizer/data/emnist_lines.py
+++ b/text_recognizer/data/emnist_lines.py
@@ -32,7 +32,7 @@ IMAGE_X_PADDING = 28
 MAX_OUTPUT_LENGTH = 89  # Same as IAMLines


-@attr.s(auto_attribs=True)
+@attr.s(auto_attribs=True, repr=False)
 class EMNISTLines(BaseDataModule):
     """EMNIST Lines dataset: synthetic handwritten lines dataset made from EMNIST,"""

diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py
index 58c7369..23e424d 100644
--- a/text_recognizer/data/iam_extended_paragraphs.py
+++ b/text_recognizer/data/iam_extended_paragraphs.py
@@ -10,7 +10,7 @@ from text_recognizer.data.iam_paragraphs import IAMParagraphs
 from text_recognizer.data.iam_synthetic_paragraphs import IAMSyntheticParagraphs


-@attr.s(auto_attribs=True)
+@attr.s(auto_attribs=True, repr=False)
 class IAMExtendedParagraphs(BaseDataModule):

     augment: bool = attr.ib(default=True)
diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py
index 13dd379..b7f3fdd 100644
--- a/text_recognizer/data/iam_lines.py
+++ b/text_recognizer/data/iam_lines.py
@@ -37,7 +37,7 @@ IMAGE_WIDTH = 1024
 MAX_LABEL_LENGTH = 89


-@attr.s(auto_attribs=True)
+@attr.s(auto_attribs=True, repr=False)
 class IAMLines(BaseDataModule):
     """IAM handwritten lines dataset."""

diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py
index de32875..82058e0 100644
--- a/text_recognizer/data/iam_paragraphs.py
+++ b/text_recognizer/data/iam_paragraphs.py
@@ -34,7 +34,7 @@ IMAGE_WIDTH = 1280 // IMAGE_SCALE_FACTOR
 MAX_LABEL_LENGTH = 682


-@attr.s(auto_attribs=True)
+@attr.s(auto_attribs=True, repr=False)
 class IAMParagraphs(BaseDataModule):
     """IAM handwriting database paragraphs."""

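The next diff fixes a crash in Preprocessor._to_index: special_tokens defaults to None, and a bare `line in None` raises TypeError, so the membership test is now short-circuited. A reduced sketch of the guard (the free function here is a hypothetical extraction, not the repo's API):

from typing import Optional, Sequence


def is_special(line: str, special_tokens: Optional[Sequence[str]]) -> bool:
    # `line in special_tokens` alone would raise TypeError when
    # special_tokens is None; the `is not None` check short-circuits first.
    return special_tokens is not None and line in special_tokens


print(is_special("<s>", None))            # False, no TypeError
print(is_special("<s>", ("<s>", "<e>")))  # True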
diff --git a/text_recognizer/data/iam_preprocessor.py b/text_recognizer/data/iam_preprocessor.py
index f7457e4..93a13bb 100644
--- a/text_recognizer/data/iam_preprocessor.py
+++ b/text_recognizer/data/iam_preprocessor.py
@@ -9,7 +9,7 @@ import collections
 import itertools
 from pathlib import Path
 import re
-from typing import List, Optional, Union
+from typing import List, Optional, Union, Sequence

 import click
 from loguru import logger
@@ -57,15 +57,13 @@ class Preprocessor:
         lexicon_path: Optional[Union[str, Path]] = None,
         use_words: bool = False,
         prepend_wordsep: bool = False,
-        special_tokens: Optional[List[str]] = None,
+        special_tokens: Optional[Sequence[str]] = None,
     ) -> None:
         self.wordsep = "▁"
         self._use_word = use_words
         self._prepend_wordsep = prepend_wordsep
         self.special_tokens = special_tokens if special_tokens is not None else None
-
         self.data_dir = Path(data_dir)
-
         self.forms = load_metadata(self.data_dir, self.wordsep, use_words=use_words)

         # Load the set of graphemes:
@@ -123,7 +121,7 @@ class Preprocessor:
         self.text.append(example["text"].lower())

     def _to_index(self, line: str) -> torch.LongTensor:
-        if line in self.special_tokens:
+        if self.special_tokens is not None and line in self.special_tokens:
             return torch.LongTensor([self.tokens_to_index[line]])
         token_to_index = self.graphemes_to_index
         if self.lexicon is not None:
diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py
index a3697e7..f00a494 100644
--- a/text_recognizer/data/iam_synthetic_paragraphs.py
+++ b/text_recognizer/data/iam_synthetic_paragraphs.py
@@ -34,7 +34,7 @@ PROCESSED_DATA_DIRNAME = (
 )


-@attr.s(auto_attribs=True)
+@attr.s(auto_attribs=True, repr=False)
 class IAMSyntheticParagraphs(IAMParagraphs):
     """IAM Handwriting database of synthetic paragraphs."""

diff --git a/text_recognizer/data/mappings.py b/text_recognizer/data/mappings.py
index 0d778b2..a934fd9 100644
--- a/text_recognizer/data/mappings.py
+++ b/text_recognizer/data/mappings.py
@@ -1,8 +1,9 @@
 """Mapping to and from word pieces."""
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import List, Optional, Union, Sequence
+from typing import Dict, List, Optional, Union, Set, Sequence

+import attr
 from loguru import logger
 import torch
 from torch import Tensor
@@ -29,10 +30,17 @@ class AbstractMapping(ABC):
         ...


+@attr.s
 class EmnistMapping(AbstractMapping):
-    def __init__(self, extra_symbols: Optional[Sequence[str]]) -> None:
+    extra_symbols: Optional[Set[str]] = attr.ib(default=None, converter=set)
+    mapping: Sequence[str] = attr.ib(init=False)
+    inverse_mapping: Dict[str, int] = attr.ib(init=False)
+    input_size: List[int] = attr.ib(init=False)
+
+    def __attrs_post_init__(self) -> None:
+        """Post init configuration."""
         self.mapping, self.inverse_mapping, self.input_size = emnist_mapping(
-            extra_symbols
+            self.extra_symbols
         )

     def get_token(self, index: Union[int, Tensor]) -> str:
@@ -54,42 +62,21 @@ class EmnistMapping(AbstractMapping):
         return Tensor([self.inverse_mapping[token] for token in text])


+@attr.s(auto_attribs=True)
 class WordPieceMapping(EmnistMapping):
-    def __init__(
-        self,
-        num_features: int = 1000,
-        tokens: str = "iamdb_1kwp_tokens_1000.txt",
-        lexicon: str = "iamdb_1kwp_lex_1000.txt",
-        data_dir: Optional[Union[str, Path]] = None,
-        use_words: bool = False,
-        prepend_wordsep: bool = False,
-        special_tokens: Sequence[str] = ("<s>", "<e>", "<p>"),
-        extra_symbols: Optional[Sequence[str]] = ("\n",),
-    ) -> None:
-        super().__init__(extra_symbols)
-        self.wordpiece_processor = self._configure_wordpiece_processor(
-            num_features,
-            tokens,
-            lexicon,
-            data_dir,
-            use_words,
-            prepend_wordsep,
-            special_tokens,
-            extra_symbols,
-        )
-
-    @staticmethod
-    def _configure_wordpiece_processor(
-        num_features: int,
-        tokens: str,
-        lexicon: str,
-        data_dir: Optional[Union[str, Path]],
-        use_words: bool,
-        prepend_wordsep: bool,
-        special_tokens: Optional[Sequence[str]],
-        extra_symbols: Optional[Sequence[str]],
-    ) -> Preprocessor:
-        data_dir = (
+    data_dir: Optional[Path] = attr.ib(default=None)
+    num_features: int = attr.ib(default=1000)
+    tokens: str = attr.ib(default="iamdb_1kwp_tokens_1000.txt")
+    lexicon: str = attr.ib(default="iamdb_1kwp_lex_1000.txt")
+    use_words: bool = attr.ib(default=False)
+    prepend_wordsep: bool = attr.ib(default=False)
+    special_tokens: Set[str] = attr.ib(default={"<s>", "<e>", "<p>"}, converter=set)
+    extra_symbols: Set[str] = attr.ib(default={"\n",}, converter=set)
+    wordpiece_processor: Preprocessor = attr.ib(init=False)
+
+    def __attrs_post_init__(self) -> None:
+        super().__attrs_post_init__()
+        self.data_dir = (
             (
                 Path(__file__).resolve().parents[2]
                 / "data"
@@ -97,32 +84,32 @@ class WordPieceMapping(EmnistMapping):
                 / "iam"
                 / "iamdb"
             )
-            if data_dir is None
-            else Path(data_dir)
+            if self.data_dir is None
+            else Path(self.data_dir)
         )
-
-        logger.debug(f"Using data dir: {data_dir}")
-        if not data_dir.exists():
-            raise RuntimeError(f"Could not locate iamdb directory at {data_dir}")
+        logger.debug(f"Using data dir: {self.data_dir}")
+        if not self.data_dir.exists():
+            raise RuntimeError(f"Could not locate iamdb directory at {self.data_dir}")

         processed_path = (
             Path(__file__).resolve().parents[2] / "data" / "processed" / "iam_lines"
         )
-        tokens_path = processed_path / tokens
-        lexicon_path = processed_path / lexicon
-
-        if extra_symbols is not None:
-            special_tokens += extra_symbols
-
-        return Preprocessor(
-            data_dir,
-            num_features,
-            tokens_path,
-            lexicon_path,
-            use_words,
-            prepend_wordsep,
-            special_tokens,
+        tokens_path = processed_path / self.tokens
+        lexicon_path = processed_path / self.lexicon
+
+        special_tokens = self.special_tokens
+        if self.extra_symbols is not None:
+            special_tokens = special_tokens | self.extra_symbols
+
+        self.wordpiece_processor = Preprocessor(
+            data_dir=self.data_dir,
+            num_features=self.num_features,
+            tokens_path=tokens_path,
+            lexicon_path=lexicon_path,
+            use_words=self.use_words,
+            prepend_wordsep=self.prepend_wordsep,
+            special_tokens=special_tokens,
         )

     def __len__(self) -> int:
@@ -151,7 +138,9 @@ class WordPieceMapping(EmnistMapping):
         text = text.lower().replace(" ", "▁")
         return torch.LongTensor(self.wordpiece_processor.to_index(text))

-    def __getitem__(self, x: Union[str, int, Tensor]) -> Union[str, Tensor]:
+    def __getitem__(self, x: Union[str, int, List[int], Tensor]) -> Union[str, Tensor]:
+        if isinstance(x, int):
+            x = [x]
         if isinstance(x, str):
-            return self.get_index(x)
-        return self.get_token(x)
+            return self.get_indices(x)
+        return self.get_text(x)
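In the attrs rewrite of WordPieceMapping above, converter=set normalizes special_tokens and extra_symbols to sets, so the old in-place `special_tokens += extra_symbols` becomes a set union. A standalone sketch of that merge, using the defaults from the diff:

special_tokens = {"<s>", "<e>", "<p>"}
extra_symbols = {"\n"}

# Union instead of +=: no mutation of the defaults, and duplicates collapse.
merged = special_tokens | extra_symbols
print(sorted(merged))  # ['\n', '<e>', '<p>', '<s>']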
diff --git a/text_recognizer/data/transforms.py b/text_recognizer/data/transforms.py
index 66531a5..3b1b929 100644
--- a/text_recognizer/data/transforms.py
+++ b/text_recognizer/data/transforms.py
@@ -24,14 +24,14 @@ class WordPiece:
         max_len: int = 451,
     ) -> None:
         self.mapping = WordPieceMapping(
-            num_features,
-            tokens,
-            lexicon,
-            data_dir,
-            use_words,
-            prepend_wordsep,
-            special_tokens,
-            extra_symbols,
+            data_dir=data_dir,
+            num_features=num_features,
+            tokens=tokens,
+            lexicon=lexicon,
+            use_words=use_words,
+            prepend_wordsep=prepend_wordsep,
+            special_tokens=special_tokens,
+            extra_symbols=extra_symbols,
         )
         self.max_len = max_len
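The switch to keyword arguments in transforms.py is more than style: the attrs rewrite moved data_dir to the front of WordPieceMapping's field list, so the old positional call would have bound num_features where data_dir is expected. A toy illustration of that failure mode (hypothetical class, not the repo's):

import attr


@attr.s(auto_attribs=True)
class Mapping:
    data_dir: str = "data/"
    num_features: int = 1000


# A positional call written against the old field order silently misbinds:
broken = Mapping(1000, "data/")
print(broken)  # Mapping(data_dir=1000, num_features='data/')

# Keyword arguments stay correct regardless of field order:
safe = Mapping(data_dir="data/", num_features=1000)
print(safe)    # Mapping(data_dir='data/', num_features=1000)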