Diffstat (limited to 'text_recognizer')
 text_recognizer/data/iam_preprocessor.py      |   1 -
 text_recognizer/data/mapping.py               |   8 --
 text_recognizer/data/mappings.py              | 143 ++++++++++++++
 text_recognizer/data/transforms.py            | 111 +++++-------
 text_recognizer/models/base.py                |   6 ++-
 text_recognizer/networks/image_transformer.py |   4 +-
 6 files changed, 166 insertions(+), 107 deletions(-)
diff --git a/text_recognizer/data/iam_preprocessor.py b/text_recognizer/data/iam_preprocessor.py
index d85787e..60f8a9f 100644
--- a/text_recognizer/data/iam_preprocessor.py
+++ b/text_recognizer/data/iam_preprocessor.py
@@ -119,7 +119,6 @@ class Preprocessor:
                 continue
             self.text.append(example["text"].lower())
-
 
     def _to_index(self, line: str) -> torch.LongTensor:
         if line in self.special_tokens:
             return torch.LongTensor([self.tokens_to_index[line]])
diff --git a/text_recognizer/data/mapping.py b/text_recognizer/data/mapping.py
deleted file mode 100644
index f0edf3f..0000000
--- a/text_recognizer/data/mapping.py
+++ /dev/null
@@ -1,8 +0,0 @@
-"""Mapping to and from word pieces."""
-from pathlib import Path
-
-
-class WordPieces:
-
-    def __init__(self) -> None:
-        pass
diff --git a/text_recognizer/data/mappings.py b/text_recognizer/data/mappings.py
new file mode 100644
index 0000000..cfa0ec7
--- /dev/null
+++ b/text_recognizer/data/mappings.py
@@ -0,0 +1,143 @@
+"""Mapping to and from word pieces."""
+from abc import ABC, abstractmethod
+from pathlib import Path
+from typing import List, Optional, Union, Sequence
+
+from loguru import logger
+import torch
+from torch import Tensor
+
+from text_recognizer.data.emnist import emnist_mapping
+from text_recognizer.data.iam_preprocessor import Preprocessor
+
+
+class AbstractMapping(ABC):
+    @abstractmethod
+    def get_token(self, *args, **kwargs) -> str:
+        ...
+
+    @abstractmethod
+    def get_index(self, *args, **kwargs) -> Tensor:
+        ...
+
+    @abstractmethod
+    def get_text(self, *args, **kwargs) -> str:
+        ...
+
+    @abstractmethod
+    def get_indices(self, *args, **kwargs) -> Tensor:
+        ...
+
+
+class EmnistMapping(AbstractMapping):
+    def __init__(self, extra_symbols: Optional[Sequence[str]] = None) -> None:
+        self.mapping, self.inverse_mapping, self.input_size = emnist_mapping(
+            extra_symbols
+        )
+
+    def get_token(self, index: Union[int, Tensor]) -> str:
+        if (index := int(index)) in self.mapping:
+            return self.mapping[index]
+        raise KeyError(f"Index ({index}) not in mapping.")
+
+    def get_index(self, token: str) -> Tensor:
+        if token in self.inverse_mapping:
+            # Wrap the index in a list: Tensor(n) would allocate an
+            # uninitialized tensor of length n, not a one-element index tensor.
+            return torch.LongTensor([self.inverse_mapping[token]])
+        raise KeyError(f"Token ({token}) not found in inverse mapping.")
+
+    def get_text(self, indices: Union[List[int], Tensor]) -> str:
+        if isinstance(indices, Tensor):
+            indices = indices.tolist()
+        return "".join([self.mapping[index] for index in indices])
+
+    def get_indices(self, text: str) -> Tensor:
+        # Indices are integer ids, so return an integer tensor.
+        return torch.LongTensor([self.inverse_mapping[token] for token in text])
+
+
+class WordPieceMapping(EmnistMapping):
+    def __init__(
+        self,
+        num_features: int,
+        tokens: str,
+        lexicon: str,
+        data_dir: Optional[Union[str, Path]] = None,
+        use_words: bool = False,
+        prepend_wordsep: bool = False,
+        special_tokens: Sequence[str] = ("<s>", "<e>", "<p>"),
+        extra_symbols: Optional[Sequence[str]] = None,
+    ) -> None:
+        super().__init__(extra_symbols)
+        self.wordpiece_processor = self._configure_wordpiece_processor(
+            num_features,
+            tokens,
+            lexicon,
+            data_dir,
+            use_words,
+            prepend_wordsep,
+            special_tokens,
+            extra_symbols,
+        )
+
+    def _configure_wordpiece_processor(
+        self,
+        num_features: int,
+        tokens: str,
+        lexicon: str,
+        data_dir: Optional[Union[str, Path]],
+        use_words: bool,
+        prepend_wordsep: bool,
+        special_tokens: Optional[Sequence[str]],
+        extra_symbols: Optional[Sequence[str]],
+    ) -> Preprocessor:
+        data_dir = (
+            (Path(__file__).resolve().parents[2] / "data" / "raw" / "iam" / "iamdb")
+            if data_dir is None
+            else Path(data_dir)
+        )
+
+        logger.debug(f"Using data dir: {data_dir}")
+        if not data_dir.exists():
+            raise RuntimeError(f"Could not locate iamdb directory at {data_dir}")
+
+        processed_path = (
+            Path(__file__).resolve().parents[2] / "data" / "processed" / "iam_lines"
+        )
+
+        tokens_path = processed_path / tokens
+        lexicon_path = processed_path / lexicon
+
+        if extra_symbols is not None:
+            # special_tokens defaults to a tuple, so build a new list rather
+            # than relying on += between a tuple and an arbitrary sequence.
+            special_tokens = list(special_tokens) + list(extra_symbols)
+
+        return Preprocessor(
+            data_dir,
+            num_features,
+            tokens_path,
+            lexicon_path,
+            use_words,
+            prepend_wordsep,
+            special_tokens,
+        )
+
+    def get_token(self, index: Union[int, Tensor]) -> str:
+        # Strict bound: token indices are zero-based, so num_tokens itself
+        # is already out of range.
+        if (index := int(index)) < self.wordpiece_processor.num_tokens:
+            return self.wordpiece_processor.tokens[index]
+        raise KeyError(f"Index ({index}) not in mapping.")
+
+    def get_index(self, token: str) -> Tensor:
+        if token in self.wordpiece_processor.tokens:
+            # Wrap in a list to build a one-element index tensor.
+            return torch.LongTensor([self.wordpiece_processor.tokens_to_index[token]])
+        raise KeyError(f"Token ({token}) not found in inverse mapping.")
+
+    def get_text(self, indices: Union[List[int], Tensor]) -> str:
+        if isinstance(indices, Tensor):
+            indices = indices.tolist()
+        return self.wordpiece_processor.to_text(indices)
+
+    def get_indices(self, text: str) -> Tensor:
+        return self.wordpiece_processor.to_index(text)
+
+    def emnist_to_wordpiece_indices(self, x: Tensor) -> Tensor:
+        # self.mapping is the raw EMNIST lookup table, not a mapping object,
+        # so decode via the inherited EmnistMapping.get_text.
+        text = super().get_text(x)
+        text = text.lower().replace(" ", "▁")
+        return torch.LongTensor(self.wordpiece_processor.to_index(text))
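
A minimal usage sketch of the new WordPieceMapping, assuming the word-piece
token and lexicon files produced by the IAM preprocessing step; the file names
and num_features value below are placeholders, not the repo's actual artifacts:

    from text_recognizer.data.mappings import WordPieceMapping

    # Hypothetical file names under data/processed/iam_lines.
    mapping = WordPieceMapping(
        num_features=1000,
        tokens="iamdb_1kwp_tokens_1000.txt",
        lexicon="iamdb_1kwp_lex_1000.txt",
    )

    indices = mapping.get_indices("hello")  # str -> index tensor
    text = mapping.get_text(indices)        # index tensor -> str round trip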
diff --git a/text_recognizer/data/transforms.py b/text_recognizer/data/transforms.py
index 297c953..f53df64 100644
--- a/text_recognizer/data/transforms.py
+++ b/text_recognizer/data/transforms.py
@@ -1,115 +1,36 @@
"""Transforms for PyTorch datasets."""
-from abc import abstractmethod
from pathlib import Path
-from typing import Any, Optional, Union
+from typing import Optional, Union, Sequence
-from loguru import logger
-import torch
from torch import Tensor
-from text_recognizer.datasets.iam_preprocessor import Preprocessor
-from text_recognizer.data.emnist import emnist_mapping
+from text_recognizer.datasets.mappings import WordPieceMapping
-class ToLower:
- """Converts target to lower case."""
-
- def __call__(self, target: Tensor) -> Tensor:
- """Corrects index value in target tensor."""
- device = target.device
- return torch.stack([x - 26 if x > 35 else x for x in target]).to(device)
-
-
-class ToCharcters:
- """Converts integers to characters."""
-
- def __init__(self, extra_symbols: Optional[List[str]] = None) -> None:
- self.mapping, _, _ = emnist_mapping(extra_symbols)
-
- def __call__(self, y: Tensor) -> str:
- """Converts a Tensor to a str."""
- return "".join([self.mapping[int(i)] for i in y]).replace(" ", "▁")
-
-
-class WordPieces:
- """Abstract transform for word pieces."""
+class WordPiece:
+ """Converts EMNIST indices to Word Piece indices."""
def __init__(
self,
num_features: int,
+ tokens: str,
+ lexicon: str,
data_dir: Optional[Union[str, Path]] = None,
- tokens: Optional[Union[str, Path]] = None,
- lexicon: Optional[Union[str, Path]] = None,
use_words: bool = False,
prepend_wordsep: bool = False,
+ special_tokens: Sequence[str] = ("<s>", "<e>", "<p>"),
+ extra_symbols: Optional[Sequence[str]] = None,
) -> None:
- if data_dir is None:
- data_dir = (
- Path(__file__).resolve().parents[3] / "data" / "raw" / "iam" / "iamdb"
- )
- logger.debug(f"Using data dir: {data_dir}")
- if not data_dir.exists():
- raise RuntimeError(f"Could not locate iamdb directory at {data_dir}")
- else:
- data_dir = Path(data_dir)
- processed_path = (
- Path(__file__).resolve().parents[3] / "data" / "processed" / "iam_lines"
- )
- tokens_path = processed_path / tokens
- lexicon_path = processed_path / lexicon
-
- self.preprocessor = Preprocessor(
- data_dir,
+ self.mapping = WordPieceMapping(
num_features,
- tokens_path,
- lexicon_path,
+ tokens,
+ lexicon,
+ data_dir,
use_words,
prepend_wordsep,
+ special_tokens,
+ extra_symbols,
)
- @abstractmethod
- def __call__(self, *args, **kwargs) -> Any:
- """Transforms input."""
- ...
-
-
-class ToWordPieces(WordPieces):
- """Transforms str to word pieces."""
-
- def __init__(
- self,
- num_features: int,
- data_dir: Optional[Union[str, Path]] = None,
- tokens: Optional[Union[str, Path]] = None,
- lexicon: Optional[Union[str, Path]] = None,
- use_words: bool = False,
- prepend_wordsep: bool = False,
- ) -> None:
- super().__init__(
- num_features, data_dir, tokens, lexicon, use_words, prepend_wordsep
- )
-
- def __call__(self, line: str) -> Tensor:
- """Transforms str to word pieces."""
- return self.preprocessor.to_index(line)
-
-
-class ToText(WordPieces):
- """Takes word pieces and converts them to text."""
-
- def __init__(
- self,
- num_features: int,
- data_dir: Optional[Union[str, Path]] = None,
- tokens: Optional[Union[str, Path]] = None,
- lexicon: Optional[Union[str, Path]] = None,
- use_words: bool = False,
- prepend_wordsep: bool = False,
- ) -> None:
- super().__init__(
- num_features, data_dir, tokens, lexicon, use_words, prepend_wordsep
- )
-
- def __call__(self, x: Tensor) -> str:
- """Converts tensor to text."""
- return self.preprocessor.to_text(x.tolist())
+ def __call__(self, x: Tensor) -> Tensor:
+ return self.mapping.emnist_to_wordpiece_indices(x)
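
The WordPiece transform above is a thin wrapper around
WordPieceMapping.emnist_to_wordpiece_indices, so it can be dropped in as a
target transform. A small sketch, assuming placeholder token/lexicon file
names and made-up EMNIST indices:

    import torch

    from text_recognizer.data.transforms import WordPiece

    target_transform = WordPiece(
        num_features=1000,
        tokens="iamdb_1kwp_tokens_1000.txt",   # hypothetical file name
        lexicon="iamdb_1kwp_lex_1000.txt",     # hypothetical file name
    )

    emnist_targets = torch.LongTensor([12, 5, 21, 21, 24])  # made-up indices
    wordpiece_targets = target_transform(emnist_targets)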
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index c6d5d73..aeda039 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -49,7 +49,9 @@ class LitBaseModel(pl.LightningModule):
         optimizer_class = getattr(torch.optim, self._optimizer.type)
         return optimizer_class(params=self.parameters(), **args)
 
-    def _configure_lr_scheduler(self, optimizer: Type[torch.optim.Optimizer]) -> Dict[str, Any]:
+    def _configure_lr_scheduler(
+        self, optimizer: Type[torch.optim.Optimizer]
+    ) -> Dict[str, Any]:
         """Configures the lr scheduler."""
         scheduler = {"monitor": self.monitor}
         args = {} or self._lr_scheduler.args
 
@@ -59,7 +61,7 @@ class LitBaseModel(pl.LightningModule):
         scheduler["scheduler"] = getattr(
             torch.optim.lr_scheduler, self._lr_scheduler.type
-            )(optimizer, **args)
+        )(optimizer, **args)
 
         return scheduler
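
For context, the dict built by _configure_lr_scheduler follows the PyTorch
Lightning convention where configure_optimizers returns the optimizer together
with an lr scheduler dict carrying the metric to monitor. A minimal sketch,
with an assumed Adam/ReduceLROnPlateau pairing and a "val_loss" monitor:

    import pytorch_lightning as pl
    import torch


    class DemoModel(pl.LightningModule):
        def configure_optimizers(self):
            optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
            scheduler = {
                "monitor": "val_loss",
                "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer),
            }
            return [optimizer], [scheduler]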
diff --git a/text_recognizer/networks/image_transformer.py b/text_recognizer/networks/image_transformer.py
index daededa..a6aaca4 100644
--- a/text_recognizer/networks/image_transformer.py
+++ b/text_recognizer/networks/image_transformer.py
@@ -44,7 +44,9 @@ class ImageTransformer(nn.Module):
         dropout_rate: float = 0.1,
         transformer_activation: str = "glu",
     ) -> None:
-        self.vocab_size = NUM_WORD_PIECES + NUM_SPECIAL_TOKENS if vocab_size is None else vocab_size
+        self.vocab_size = (
+            NUM_WORD_PIECES + NUM_SPECIAL_TOKENS if vocab_size is None else vocab_size
+        )
         self.hidden_dim = hidden_dim
         self.max_output_length = output_shape[0]
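
The vocab_size fallback mirrors the special tokens added on the mapping side:
the default vocabulary is the word-piece inventory plus the special tokens.
Illustration with made-up constants (the real NUM_WORD_PIECES and
NUM_SPECIAL_TOKENS are defined in the networks module):

    NUM_WORD_PIECES = 1000   # assumed word-piece count
    NUM_SPECIAL_TOKENS = 3   # e.g. <s>, <e>, <p>

    vocab_size = None
    if vocab_size is None:
        vocab_size = NUM_WORD_PIECES + NUM_SPECIAL_TOKENS
    assert vocab_size == 1003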