Diffstat (limited to 'text_recognizer/data')
-rw-r--r--  text_recognizer/data/base_data_module.py          |   6
-rw-r--r--  text_recognizer/data/base_mapping.py               |  37
-rw-r--r--  text_recognizer/data/download_utils.py             |   2
-rw-r--r--  text_recognizer/data/emnist_mapping.py             |  37
-rw-r--r--  text_recognizer/data/iam_extended_paragraphs.py    |   3
-rw-r--r--  text_recognizer/data/iam_lines.py                  |   2
-rw-r--r--  text_recognizer/data/iam_paragraphs.py             |  12
-rw-r--r--  text_recognizer/data/iam_synthetic_paragraphs.py   |   4
-rw-r--r--  text_recognizer/data/make_wordpieces.py            |   2
-rw-r--r--  text_recognizer/data/mappings.py                   | 156
-rw-r--r--  text_recognizer/data/transforms.py                 |   8
-rw-r--r--  text_recognizer/data/word_piece_mapping.py         |  93
12 files changed, 184 insertions, 178 deletions
diff --git a/text_recognizer/data/base_data_module.py b/text_recognizer/data/base_data_module.py
index fd914b6..16a06d9 100644
--- a/text_recognizer/data/base_data_module.py
+++ b/text_recognizer/data/base_data_module.py
@@ -1,12 +1,12 @@
"""Base lightning DataModule class."""
from pathlib import Path
-from typing import Dict, Tuple
+from typing import Dict, Tuple, Type
import attr
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
-from text_recognizer.data.mappings import AbstractMapping
+from text_recognizer.data.base_mapping import AbstractMapping
from text_recognizer.data.base_dataset import BaseDataset
@@ -25,7 +25,7 @@ class BaseDataModule(LightningDataModule):
def __attrs_pre_init__(self) -> None:
super().__init__()
- mapping: AbstractMapping = attr.ib()
+ mapping: Type[AbstractMapping] = attr.ib()
batch_size: int = attr.ib(default=16)
num_workers: int = attr.ib(default=0)
pin_memory: bool = attr.ib(default=True)
diff --git a/text_recognizer/data/base_mapping.py b/text_recognizer/data/base_mapping.py
new file mode 100644
index 0000000..572ac95
--- /dev/null
+++ b/text_recognizer/data/base_mapping.py
@@ -0,0 +1,37 @@
+"""Mapping to and from word pieces."""
+from abc import ABC, abstractmethod
+from typing import Dict, List
+
+from torch import Tensor
+
+
+class AbstractMapping(ABC):
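+ """Abstract mapping between tokens, text, and integer indices."""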
+ def __init__(
+ self, input_size: List[int], mapping: List[str], inverse_mapping: Dict[str, int]
+ ) -> None:
+ self.input_size = input_size
+ self.mapping = mapping
+ self.inverse_mapping = inverse_mapping
+
+ def __len__(self) -> int:
+ return len(self.mapping)
+
+ @property
+ def num_classes(self) -> int:
+ return self.__len__()
+
+ @abstractmethod
+ def get_token(self, *args, **kwargs) -> str:
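+ """Returns the token at a given index."""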
+ ...
+
+ @abstractmethod
+ def get_index(self, *args, **kwargs) -> Tensor:
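+ """Returns the index of a token as a tensor."""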
+ ...
+
+ @abstractmethod
+ def get_text(self, *args, **kwargs) -> str:
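+ """Decodes a sequence of indices into text."""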
+ ...
+
+ @abstractmethod
+ def get_indices(self, *args, **kwargs) -> Tensor:
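+ """Encodes text into a tensor of indices."""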
+ ...
diff --git a/text_recognizer/data/download_utils.py b/text_recognizer/data/download_utils.py
index 8938830..a5a5360 100644
--- a/text_recognizer/data/download_utils.py
+++ b/text_recognizer/data/download_utils.py
@@ -1,7 +1,7 @@
"""Util functions for downloading datasets."""
import hashlib
from pathlib import Path
-from typing import Dict, List, Optional
+from typing import Dict, Optional
from urllib.request import urlretrieve
from loguru import logger as log
diff --git a/text_recognizer/data/emnist_mapping.py b/text_recognizer/data/emnist_mapping.py
new file mode 100644
index 0000000..6c4c43b
--- /dev/null
+++ b/text_recognizer/data/emnist_mapping.py
@@ -0,0 +1,37 @@
+"""Emnist mapping."""
+from typing import List, Optional, Union, Set
+
+from torch import Tensor
+
+from text_recognizer.data.base_mapping import AbstractMapping
+from text_recognizer.data.emnist import emnist_mapping
+
+
+class EmnistMapping(AbstractMapping):
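+ """Mapping between EMNIST characters and integer indices."""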
+ def __init__(self, extra_symbols: Optional[Set[str]] = None) -> None:
+ self.extra_symbols = set(extra_symbols) if extra_symbols is not None else None
+ self.mapping, self.inverse_mapping, self.input_size = emnist_mapping(
+ self.extra_symbols
+ )
+ super().__init__(self.input_size, self.mapping, self.inverse_mapping)
+
+ def __attrs_post_init__(self) -> None:
+ """Post init configuration."""
+
+ def get_token(self, index: Union[int, Tensor]) -> str:
+ if (index := int(index)) in self.mapping:
+ return self.mapping[index]
+ raise KeyError(f"Index ({index}) not in mapping.")
+
+ def get_index(self, token: str) -> Tensor:
+ if token in self.inverse_mapping:
+ return Tensor(self.inverse_mapping[token])
+ raise KeyError(f"Token ({token}) not found in inverse mapping.")
+
+ def get_text(self, indices: Union[List[int], Tensor]) -> str:
+ if isinstance(indices, Tensor):
+ indices = indices.tolist()
+ return "".join([self.mapping[index] for index in indices])
+
+ def get_indices(self, text: str) -> Tensor:
+ return Tensor([self.inverse_mapping[token] for token in text])
diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py
index ccf0759..df0c0e1 100644
--- a/text_recognizer/data/iam_extended_paragraphs.py
+++ b/text_recognizer/data/iam_extended_paragraphs.py
@@ -1,6 +1,4 @@
"""IAM original and sythetic dataset class."""
-from typing import Dict, List
-
import attr
from torch.utils.data import ConcatDataset
@@ -15,7 +13,6 @@ class IAMExtendedParagraphs(BaseDataModule):
augment: bool = attr.ib(default=True)
train_fraction: float = attr.ib(default=0.8)
word_pieces: bool = attr.ib(default=False)
- num_classes: int = attr.ib(init=False)
def __attrs_post_init__(self) -> None:
self.iam_paragraphs = IAMParagraphs(
diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py
index 1c63729..aba38f9 100644
--- a/text_recognizer/data/iam_lines.py
+++ b/text_recognizer/data/iam_lines.py
@@ -22,7 +22,7 @@ from text_recognizer.data.base_dataset import (
split_dataset,
)
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
-from text_recognizer.data.mappings import EmnistMapping
+from text_recognizer.data.emnist_mapping import EmnistMapping
from text_recognizer.data.iam import IAM
from text_recognizer.data import image_utils
diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py
index 6189f7d..11f899f 100644
--- a/text_recognizer/data/iam_paragraphs.py
+++ b/text_recognizer/data/iam_paragraphs.py
@@ -17,7 +17,7 @@ from text_recognizer.data.base_dataset import (
split_dataset,
)
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
-from text_recognizer.data.mappings import EmnistMapping
+from text_recognizer.data.emnist_mapping import EmnistMapping
from text_recognizer.data.iam import IAM
from text_recognizer.data.transforms import WordPiece
@@ -50,11 +50,9 @@ class IAMParagraphs(BaseDataModule):
if PROCESSED_DATA_DIRNAME.exists():
return
- log.info(
- "Cropping IAM paragraph regions and saving them along with labels..."
- )
+ log.info("Cropping IAM paragraph regions and saving them along with labels...")
- iam = IAM(mapping=EmnistMapping())
+ iam = IAM(mapping=EmnistMapping(extra_symbols={NEW_LINE_TOKEN,}))
iam.prepare_data()
properties = {}
@@ -83,7 +81,9 @@ class IAMParagraphs(BaseDataModule):
crops, labels = _load_processed_crops_and_labels(split)
data = [resize_image(crop, IMAGE_SCALE_FACTOR) for crop in crops]
targets = convert_strings_to_labels(
- strings=labels, mapping=self.mapping.inverse_mapping, length=self.output_dims[0]
+ strings=labels,
+ mapping=self.mapping.inverse_mapping,
+ length=self.output_dims[0],
)
return BaseDataset(
data,
diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py
index c938f8b..24ca896 100644
--- a/text_recognizer/data/iam_synthetic_paragraphs.py
+++ b/text_recognizer/data/iam_synthetic_paragraphs.py
@@ -21,7 +21,7 @@ from text_recognizer.data.iam_paragraphs import (
IMAGE_SCALE_FACTOR,
resize_image,
)
-from text_recognizer.data.mappings import EmnistMapping
+from text_recognizer.data.emnist_mapping import EmnistMapping
from text_recognizer.data.iam import IAM
from text_recognizer.data.iam_lines import (
line_crops_and_labels,
@@ -47,7 +47,7 @@ class IAMSyntheticParagraphs(IAMParagraphs):
log.info("Preparing IAM lines for synthetic paragraphs dataset.")
log.info("Cropping IAM line regions and loading labels.")
- iam = IAM(mapping=EmnistMapping())
+ iam = IAM(mapping=EmnistMapping(extra_symbols={NEW_LINE_TOKEN,}))
iam.prepare_data()
crops_train, labels_train = line_crops_and_labels(iam, "train")
diff --git a/text_recognizer/data/make_wordpieces.py b/text_recognizer/data/make_wordpieces.py
index 40fbee4..8e53815 100644
--- a/text_recognizer/data/make_wordpieces.py
+++ b/text_recognizer/data/make_wordpieces.py
@@ -13,8 +13,6 @@ import click
from loguru import logger as log
import sentencepiece as spm
-from text_recognizer.data.iam_preprocessor import load_metadata
-
def iamdb_pieces(
data_dir: Path, text_file: str, num_pieces: int, output_prefix: str
diff --git a/text_recognizer/data/mappings.py b/text_recognizer/data/mappings.py
deleted file mode 100644
index d1c64dd..0000000
--- a/text_recognizer/data/mappings.py
+++ /dev/null
@@ -1,156 +0,0 @@
-"""Mapping to and from word pieces."""
-from abc import ABC, abstractmethod
-from pathlib import Path
-from typing import Dict, List, Optional, Union, Set
-
-import attr
-import torch
-from loguru import logger as log
-from torch import Tensor
-
-from text_recognizer.data.emnist import emnist_mapping
-from text_recognizer.data.iam_preprocessor import Preprocessor
-
-
-@attr.s
-class AbstractMapping(ABC):
- input_size: List[int] = attr.ib(init=False)
- mapping: List[str] = attr.ib(init=False)
- inverse_mapping: Dict[str, int] = attr.ib(init=False)
-
- def __len__(self) -> int:
- return len(self.mapping)
-
- @property
- def num_classes(self) -> int:
- return self.__len__()
-
- @abstractmethod
- def get_token(self, *args, **kwargs) -> str:
- ...
-
- @abstractmethod
- def get_index(self, *args, **kwargs) -> Tensor:
- ...
-
- @abstractmethod
- def get_text(self, *args, **kwargs) -> str:
- ...
-
- @abstractmethod
- def get_indices(self, *args, **kwargs) -> Tensor:
- ...
-
-
-@attr.s(auto_attribs=True)
-class EmnistMapping(AbstractMapping):
- extra_symbols: Optional[Set[str]] = attr.ib(default=None)
-
- def __attrs_post_init__(self) -> None:
- """Post init configuration."""
- self.extra_symbols = set(self.extra_symbols) if self.extra_symbols is not None else None
- self.mapping, self.inverse_mapping, self.input_size = emnist_mapping(
- self.extra_symbols
- )
-
- def get_token(self, index: Union[int, Tensor]) -> str:
- if (index := int(index)) in self.mapping:
- return self.mapping[index]
- raise KeyError(f"Index ({index}) not in mapping.")
-
- def get_index(self, token: str) -> Tensor:
- if token in self.inverse_mapping:
- return Tensor(self.inverse_mapping[token])
- raise KeyError(f"Token ({token}) not found in inverse mapping.")
-
- def get_text(self, indices: Union[List[int], Tensor]) -> str:
- if isinstance(indices, Tensor):
- indices = indices.tolist()
- return "".join([self.mapping[index] for index in indices])
-
- def get_indices(self, text: str) -> Tensor:
- return Tensor([self.inverse_mapping[token] for token in text])
-
-
-@attr.s(auto_attribs=True)
-class WordPieceMapping(EmnistMapping):
- data_dir: Optional[Path] = attr.ib(default=None)
- num_features: int = attr.ib(default=1000)
- tokens: str = attr.ib(default="iamdb_1kwp_tokens_1000.txt")
- lexicon: str = attr.ib(default="iamdb_1kwp_lex_1000.txt")
- use_words: bool = attr.ib(default=False)
- prepend_wordsep: bool = attr.ib(default=False)
- special_tokens: Set[str] = attr.ib(default={"<s>", "<e>", "<p>"}, converter=set)
- extra_symbols: Set[str] = attr.ib(default={"\n",}, converter=set)
- wordpiece_processor: Preprocessor = attr.ib(init=False)
-
- def __attrs_post_init__(self) -> None:
- super().__attrs_post_init__()
- self.data_dir = (
- (
- Path(__file__).resolve().parents[2]
- / "data"
- / "downloaded"
- / "iam"
- / "iamdb"
- )
- if self.data_dir is None
- else Path(self.data_dir)
- )
- log.debug(f"Using data dir: {self.data_dir}")
- if not self.data_dir.exists():
- raise RuntimeError(f"Could not locate iamdb directory at {self.data_dir}")
-
- processed_path = (
- Path(__file__).resolve().parents[2] / "data" / "processed" / "iam_lines"
- )
-
- tokens_path = processed_path / self.tokens
- lexicon_path = processed_path / self.lexicon
-
- special_tokens = self.special_tokens
- if self.extra_symbols is not None:
- special_tokens = special_tokens | self.extra_symbols
-
- self.wordpiece_processor = Preprocessor(
- data_dir=self.data_dir,
- num_features=self.num_features,
- tokens_path=tokens_path,
- lexicon_path=lexicon_path,
- use_words=self.use_words,
- prepend_wordsep=self.prepend_wordsep,
- special_tokens=special_tokens,
- )
-
- def __len__(self) -> int:
- return len(self.wordpiece_processor.tokens)
-
- def get_token(self, index: Union[int, Tensor]) -> str:
- if (index := int(index)) <= self.wordpiece_processor.num_tokens:
- return self.wordpiece_processor.tokens[index]
- raise KeyError(f"Index ({index}) not in mapping.")
-
- def get_index(self, token: str) -> Tensor:
- if token in self.wordpiece_processor.tokens:
- return torch.LongTensor([self.wordpiece_processor.tokens_to_index[token]])
- raise KeyError(f"Token ({token}) not found in inverse mapping.")
-
- def get_text(self, indices: Union[List[int], Tensor]) -> str:
- if isinstance(indices, Tensor):
- indices = indices.tolist()
- return self.wordpiece_processor.to_text(indices)
-
- def get_indices(self, text: str) -> Tensor:
- return self.wordpiece_processor.to_index(text)
-
- def emnist_to_wordpiece_indices(self, x: Tensor) -> Tensor:
- text = "".join([self.mapping[i] for i in x])
- text = text.lower().replace(" ", "▁")
- return torch.LongTensor(self.wordpiece_processor.to_index(text))
-
- def __getitem__(self, x: Union[str, int, List[int], Tensor]) -> Union[str, Tensor]:
- if isinstance(x, int):
- x = [x]
- if isinstance(x, str):
- return self.get_indices(x)
- return self.get_text(x)
diff --git a/text_recognizer/data/transforms.py b/text_recognizer/data/transforms.py
index 3b1b929..047496f 100644
--- a/text_recognizer/data/transforms.py
+++ b/text_recognizer/data/transforms.py
@@ -1,11 +1,11 @@
"""Transforms for PyTorch datasets."""
from pathlib import Path
-from typing import Optional, Union, Sequence
+from typing import Optional, Union, Set
import torch
from torch import Tensor
-from text_recognizer.data.mappings import WordPieceMapping
+from text_recognizer.data.word_piece_mapping import WordPieceMapping
class WordPiece:
@@ -19,8 +19,8 @@ class WordPiece:
data_dir: Optional[Union[str, Path]] = None,
use_words: bool = False,
prepend_wordsep: bool = False,
- special_tokens: Sequence[str] = ("<s>", "<e>", "<p>"),
- extra_symbols: Optional[Sequence[str]] = ("\n",),
+ special_tokens: Set[str] = {"<s>", "<e>", "<p>"},
+ extra_symbols: Optional[Set[str]] = {"\n",},
max_len: int = 451,
) -> None:
self.mapping = WordPieceMapping(
diff --git a/text_recognizer/data/word_piece_mapping.py b/text_recognizer/data/word_piece_mapping.py
new file mode 100644
index 0000000..59488c3
--- /dev/null
+++ b/text_recognizer/data/word_piece_mapping.py
@@ -0,0 +1,93 @@
+"""Word piece mapping."""
+from pathlib import Path
+from typing import List, Optional, Union, Set
+
+import torch
+from loguru import logger as log
+from torch import Tensor
+
+from text_recognizer.data.emnist_mapping import EmnistMapping
+from text_recognizer.data.iam_preprocessor import Preprocessor
+
+
+class WordPieceMapping(EmnistMapping):
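+ """Mapping between word pieces and integer indices via the IAM Preprocessor."""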
+ def __init__(
+ self,
+ data_dir: Optional[Path] = None,
+ num_features: int = 1000,
+ tokens: str = "iamdb_1kwp_tokens_1000.txt",
+ lexicon: str = "iamdb_1kwp_lex_1000.txt",
+ use_words: bool = False,
+ prepend_wordsep: bool = False,
+ special_tokens: Set[str] = {"<s>", "<e>", "<p>"},
+ extra_symbols: Set[str] = {"\n",},
+ ) -> None:
+ super().__init__(extra_symbols=extra_symbols)
+ self.data_dir = (
+ (
+ Path(__file__).resolve().parents[2]
+ / "data"
+ / "downloaded"
+ / "iam"
+ / "iamdb"
+ )
+ if data_dir is None
+ else Path(data_dir)
+ )
+ log.debug(f"Using data dir: {self.data_dir}")
+ if not self.data_dir.exists():
+ raise RuntimeError(f"Could not locate iamdb directory at {self.data_dir}")
+
+ processed_path = (
+ Path(__file__).resolve().parents[2] / "data" / "processed" / "iam_lines"
+ )
+
+ tokens_path = processed_path / tokens
+ lexicon_path = processed_path / lexicon
+
+ special_tokens = set(special_tokens)
+ if self.extra_symbols is not None:
+ special_tokens = special_tokens | set(extra_symbols)
+
+ self.wordpiece_processor = Preprocessor(
+ data_dir=self.data_dir,
+ num_features=num_features,
+ tokens_path=tokens_path,
+ lexicon_path=lexicon_path,
+ use_words=use_words,
+ prepend_wordsep=prepend_wordsep,
+ special_tokens=special_tokens,
+ )
+
+ def __len__(self) -> int:
+ return len(self.wordpiece_processor.tokens)
+
+ def get_token(self, index: Union[int, Tensor]) -> str:
+ if (index := int(index)) <= self.wordpiece_processor.num_tokens:
+ return self.wordpiece_processor.tokens[index]
+ raise KeyError(f"Index ({index}) not in mapping.")
+
+ def get_index(self, token: str) -> Tensor:
+ if token in self.wordpiece_processor.tokens:
+ return torch.LongTensor([self.wordpiece_processor.tokens_to_index[token]])
+ raise KeyError(f"Token ({token}) not found in inverse mapping.")
+
+ def get_text(self, indices: Union[List[int], Tensor]) -> str:
+ if isinstance(indices, Tensor):
+ indices = indices.tolist()
+ return self.wordpiece_processor.to_text(indices).replace(" ", "▁")
+
+ def get_indices(self, text: str) -> Tensor:
+ return self.wordpiece_processor.to_index(text)
+
+ def emnist_to_wordpiece_indices(self, x: Tensor) -> Tensor:
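+ """Converts a tensor of EMNIST character indices to word piece indices."""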
+ text = "".join([self.mapping[i] for i in x])
+ text = text.lower().replace(" ", "▁")
+ return torch.LongTensor(self.wordpiece_processor.to_index(text))
+
+ def __getitem__(self, x: Union[str, int, List[int], Tensor]) -> Union[str, Tensor]:
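+ """Encodes a string to indices, or decodes an index or indices to text."""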
+ if isinstance(x, int):
+ x = [x]
+ if isinstance(x, str):
+ return self.get_indices(x)
+ return self.get_text(x)