summaryrefslogtreecommitdiff
path: root/text_recognizer/data
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-07-30 23:15:03 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-07-30 23:15:03 +0200
commit7268035fb9e57342612a8cc50a1fe04e8841ca2f (patch)
tree8d4cf3743975bd25f2c04d6a56ff3d4608a7e8d9 /text_recognizer/data
parent92fc1c7ed2f9f64552be8f71d9b8ab0d5a0a88d4 (diff)
attr bug fix, properly loading network
Diffstat (limited to 'text_recognizer/data')
-rw-r--r--text_recognizer/data/__init__.py6
-rw-r--r--text_recognizer/data/base_data_module.py2
-rw-r--r--text_recognizer/data/emnist_lines.py2
-rw-r--r--text_recognizer/data/iam_extended_paragraphs.py2
-rw-r--r--text_recognizer/data/iam_lines.py2
-rw-r--r--text_recognizer/data/iam_paragraphs.py2
-rw-r--r--text_recognizer/data/iam_preprocessor.py8
-rw-r--r--text_recognizer/data/iam_synthetic_paragraphs.py2
-rw-r--r--text_recognizer/data/mappings.py111
-rw-r--r--text_recognizer/data/transforms.py16
10 files changed, 67 insertions, 86 deletions
diff --git a/text_recognizer/data/__init__.py b/text_recognizer/data/__init__.py
index 3599a8b..2727b20 100644
--- a/text_recognizer/data/__init__.py
+++ b/text_recognizer/data/__init__.py
@@ -1,7 +1 @@
"""Dataset modules."""
-from .base_dataset import BaseDataset, convert_strings_to_labels, split_dataset
-from .base_data_module import BaseDataModule, load_and_print_info
-from .download_utils import download_dataset
-from .iam_paragraphs import IAMParagraphs
-from .iam_synthetic_paragraphs import IAMSyntheticParagraphs
-from .iam_extended_paragraphs import IAMExtendedParagraphs
diff --git a/text_recognizer/data/base_data_module.py b/text_recognizer/data/base_data_module.py
index 18b1996..408ae36 100644
--- a/text_recognizer/data/base_data_module.py
+++ b/text_recognizer/data/base_data_module.py
@@ -17,7 +17,7 @@ def load_and_print_info(data_module_class: type) -> None:
print(dataset)
-@attr.s
+@attr.s(repr=False)
class BaseDataModule(LightningDataModule):
"""Base PyTorch Lightning DataModule."""
diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py
index 4747508..7548ad5 100644
--- a/text_recognizer/data/emnist_lines.py
+++ b/text_recognizer/data/emnist_lines.py
@@ -32,7 +32,7 @@ IMAGE_X_PADDING = 28
MAX_OUTPUT_LENGTH = 89 # Same as IAMLines
-@attr.s(auto_attribs=True)
+@attr.s(auto_attribs=True, repr=False)
class EMNISTLines(BaseDataModule):
"""EMNIST Lines dataset: synthetic handwritten lines dataset made from EMNIST,"""
diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py
index 58c7369..23e424d 100644
--- a/text_recognizer/data/iam_extended_paragraphs.py
+++ b/text_recognizer/data/iam_extended_paragraphs.py
@@ -10,7 +10,7 @@ from text_recognizer.data.iam_paragraphs import IAMParagraphs
from text_recognizer.data.iam_synthetic_paragraphs import IAMSyntheticParagraphs
-@attr.s(auto_attribs=True)
+@attr.s(auto_attribs=True, repr=False)
class IAMExtendedParagraphs(BaseDataModule):
augment: bool = attr.ib(default=True)
diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py
index 13dd379..b7f3fdd 100644
--- a/text_recognizer/data/iam_lines.py
+++ b/text_recognizer/data/iam_lines.py
@@ -37,7 +37,7 @@ IMAGE_WIDTH = 1024
MAX_LABEL_LENGTH = 89
-@attr.s(auto_attribs=True)
+@attr.s(auto_attribs=True, repr=False)
class IAMLines(BaseDataModule):
"""IAM handwritten lines dataset."""
diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py
index de32875..82058e0 100644
--- a/text_recognizer/data/iam_paragraphs.py
+++ b/text_recognizer/data/iam_paragraphs.py
@@ -34,7 +34,7 @@ IMAGE_WIDTH = 1280 // IMAGE_SCALE_FACTOR
MAX_LABEL_LENGTH = 682
-@attr.s(auto_attribs=True)
+@attr.s(auto_attribs=True, repr=False)
class IAMParagraphs(BaseDataModule):
"""IAM handwriting database paragraphs."""
diff --git a/text_recognizer/data/iam_preprocessor.py b/text_recognizer/data/iam_preprocessor.py
index f7457e4..93a13bb 100644
--- a/text_recognizer/data/iam_preprocessor.py
+++ b/text_recognizer/data/iam_preprocessor.py
@@ -9,7 +9,7 @@ import collections
import itertools
from pathlib import Path
import re
-from typing import List, Optional, Union
+from typing import List, Optional, Union, Sequence
import click
from loguru import logger
@@ -57,15 +57,13 @@ class Preprocessor:
lexicon_path: Optional[Union[str, Path]] = None,
use_words: bool = False,
prepend_wordsep: bool = False,
- special_tokens: Optional[List[str]] = None,
+ special_tokens: Optional[Sequence[str]] = None,
) -> None:
self.wordsep = "▁"
self._use_word = use_words
self._prepend_wordsep = prepend_wordsep
self.special_tokens = special_tokens if special_tokens is not None else None
-
self.data_dir = Path(data_dir)
-
self.forms = load_metadata(self.data_dir, self.wordsep, use_words=use_words)
# Load the set of graphemes:
@@ -123,7 +121,7 @@ class Preprocessor:
self.text.append(example["text"].lower())
def _to_index(self, line: str) -> torch.LongTensor:
- if line in self.special_tokens:
+ if self.special_tokens is not None and line in self.special_tokens:
return torch.LongTensor([self.tokens_to_index[line]])
token_to_index = self.graphemes_to_index
if self.lexicon is not None:
diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py
index a3697e7..f00a494 100644
--- a/text_recognizer/data/iam_synthetic_paragraphs.py
+++ b/text_recognizer/data/iam_synthetic_paragraphs.py
@@ -34,7 +34,7 @@ PROCESSED_DATA_DIRNAME = (
)
-@attr.s(auto_attribs=True)
+@attr.s(auto_attribs=True, repr=False)
class IAMSyntheticParagraphs(IAMParagraphs):
"""IAM Handwriting database of synthetic paragraphs."""
diff --git a/text_recognizer/data/mappings.py b/text_recognizer/data/mappings.py
index 0d778b2..a934fd9 100644
--- a/text_recognizer/data/mappings.py
+++ b/text_recognizer/data/mappings.py
@@ -1,8 +1,9 @@
"""Mapping to and from word pieces."""
from abc import ABC, abstractmethod
from pathlib import Path
-from typing import List, Optional, Union, Sequence
+from typing import Dict, List, Optional, Union, Set, Sequence
+import attr
from loguru import logger
import torch
from torch import Tensor
@@ -29,10 +30,17 @@ class AbstractMapping(ABC):
...
+@attr.s
class EmnistMapping(AbstractMapping):
- def __init__(self, extra_symbols: Optional[Sequence[str]]) -> None:
+ extra_symbols: Optional[Set[str]] = attr.ib(default=None, converter=set)
+ mapping: Sequence[str] = attr.ib(init=False)
+ inverse_mapping: Dict[str, int] = attr.ib(init=False)
+ input_size: List[int] = attr.ib(init=False)
+
+ def __attrs_post_init__(self) -> None:
+ """Post init configuration."""
self.mapping, self.inverse_mapping, self.input_size = emnist_mapping(
- extra_symbols
+ self.extra_symbols
)
def get_token(self, index: Union[int, Tensor]) -> str:
@@ -54,42 +62,21 @@ class EmnistMapping(AbstractMapping):
return Tensor([self.inverse_mapping[token] for token in text])
+@attr.s(auto_attribs=True)
class WordPieceMapping(EmnistMapping):
- def __init__(
- self,
- num_features: int = 1000,
- tokens: str = "iamdb_1kwp_tokens_1000.txt",
- lexicon: str = "iamdb_1kwp_lex_1000.txt",
- data_dir: Optional[Union[str, Path]] = None,
- use_words: bool = False,
- prepend_wordsep: bool = False,
- special_tokens: Sequence[str] = ("<s>", "<e>", "<p>"),
- extra_symbols: Optional[Sequence[str]] = ("\n",),
- ) -> None:
- super().__init__(extra_symbols)
- self.wordpiece_processor = self._configure_wordpiece_processor(
- num_features,
- tokens,
- lexicon,
- data_dir,
- use_words,
- prepend_wordsep,
- special_tokens,
- extra_symbols,
- )
-
- @staticmethod
- def _configure_wordpiece_processor(
- num_features: int,
- tokens: str,
- lexicon: str,
- data_dir: Optional[Union[str, Path]],
- use_words: bool,
- prepend_wordsep: bool,
- special_tokens: Optional[Sequence[str]],
- extra_symbols: Optional[Sequence[str]],
- ) -> Preprocessor:
- data_dir = (
+ data_dir: Optional[Path] = attr.ib(default=None)
+ num_features: int = attr.ib(default=1000)
+ tokens: str = attr.ib(default="iamdb_1kwp_tokens_1000.txt")
+ lexicon: str = attr.ib(default="iamdb_1kwp_lex_1000.txt")
+ use_words: bool = attr.ib(default=False)
+ prepend_wordsep: bool = attr.ib(default=False)
+ special_tokens: Set[str] = attr.ib(default={"<s>", "<e>", "<p>"}, converter=set)
+ extra_symbols: Set[str] = attr.ib(default={"\n",}, converter=set)
+ wordpiece_processor: Preprocessor = attr.ib(init=False)
+
+ def __attrs_post_init__(self) -> None:
+ super().__attrs_post_init__()
+ self.data_dir = (
(
Path(__file__).resolve().parents[2]
/ "data"
@@ -97,32 +84,32 @@ class WordPieceMapping(EmnistMapping):
/ "iam"
/ "iamdb"
)
- if data_dir is None
- else Path(data_dir)
+ if self.data_dir is None
+ else Path(self.data_dir)
)
-
- logger.debug(f"Using data dir: {data_dir}")
- if not data_dir.exists():
- raise RuntimeError(f"Could not locate iamdb directory at {data_dir}")
+ logger.debug(f"Using data dir: {self.data_dir}")
+ if not self.data_dir.exists():
+ raise RuntimeError(f"Could not locate iamdb directory at {self.data_dir}")
processed_path = (
Path(__file__).resolve().parents[2] / "data" / "processed" / "iam_lines"
)
- tokens_path = processed_path / tokens
- lexicon_path = processed_path / lexicon
-
- if extra_symbols is not None:
- special_tokens += extra_symbols
-
- return Preprocessor(
- data_dir,
- num_features,
- tokens_path,
- lexicon_path,
- use_words,
- prepend_wordsep,
- special_tokens,
+ tokens_path = processed_path / self.tokens
+ lexicon_path = processed_path / self.lexicon
+
+ special_tokens = self.special_tokens
+ if self.extra_symbols is not None:
+ special_tokens = special_tokens | self.extra_symbols
+
+ self.wordpiece_processor = Preprocessor(
+ data_dir=self.data_dir,
+ num_features=self.num_features,
+ tokens_path=tokens_path,
+ lexicon_path=lexicon_path,
+ use_words=self.use_words,
+ prepend_wordsep=self.prepend_wordsep,
+ special_tokens=special_tokens,
)
def __len__(self) -> int:
@@ -151,7 +138,9 @@ class WordPieceMapping(EmnistMapping):
text = text.lower().replace(" ", "▁")
return torch.LongTensor(self.wordpiece_processor.to_index(text))
- def __getitem__(self, x: Union[str, int, Tensor]) -> Union[str, Tensor]:
+ def __getitem__(self, x: Union[str, int, List[int], Tensor]) -> Union[str, Tensor]:
+ if isinstance(x, int):
+ x = [x]
if isinstance(x, str):
- return self.get_index(x)
- return self.get_token(x)
+ return self.get_indices(x)
+ return self.get_text(x)
diff --git a/text_recognizer/data/transforms.py b/text_recognizer/data/transforms.py
index 66531a5..3b1b929 100644
--- a/text_recognizer/data/transforms.py
+++ b/text_recognizer/data/transforms.py
@@ -24,14 +24,14 @@ class WordPiece:
max_len: int = 451,
) -> None:
self.mapping = WordPieceMapping(
- num_features,
- tokens,
- lexicon,
- data_dir,
- use_words,
- prepend_wordsep,
- special_tokens,
- extra_symbols,
+ data_dir=data_dir,
+ num_features=num_features,
+ tokens=tokens,
+ lexicon=lexicon,
+ use_words=use_words,
+ prepend_wordsep=prepend_wordsep,
+ special_tokens=special_tokens,
+ extra_symbols=extra_symbols,
)
self.max_len = max_len