summaryrefslogtreecommitdiff
path: root/text_recognizer/data/mappings.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/data/mappings.py')
-rw-r--r--text_recognizer/data/mappings.py24
1 files changed, 17 insertions, 7 deletions
diff --git a/text_recognizer/data/mappings.py b/text_recognizer/data/mappings.py
index b69e888..d1c64dd 100644
--- a/text_recognizer/data/mappings.py
+++ b/text_recognizer/data/mappings.py
@@ -1,18 +1,30 @@
"""Mapping to and from word pieces."""
from abc import ABC, abstractmethod
from pathlib import Path
-from typing import Dict, List, Optional, Union, Set, Sequence
+from typing import Dict, List, Optional, Union, Set
import attr
-import loguru.logger as log
import torch
+from loguru import logger as log
from torch import Tensor
from text_recognizer.data.emnist import emnist_mapping
from text_recognizer.data.iam_preprocessor import Preprocessor
+@attr.s
class AbstractMapping(ABC):
+ input_size: List[int] = attr.ib(init=False)
+ mapping: List[str] = attr.ib(init=False)
+ inverse_mapping: Dict[str, int] = attr.ib(init=False)
+
+ def __len__(self) -> int:
+ return len(self.mapping)
+
+ @property
+ def num_classes(self) -> int:
+ return self.__len__()
+
@abstractmethod
def get_token(self, *args, **kwargs) -> str:
...
@@ -30,15 +42,13 @@ class AbstractMapping(ABC):
...
-@attr.s
+@attr.s(auto_attribs=True)
class EmnistMapping(AbstractMapping):
- extra_symbols: Optional[Set[str]] = attr.ib(default=None, converter=set)
- mapping: Sequence[str] = attr.ib(init=False)
- inverse_mapping: Dict[str, int] = attr.ib(init=False)
- input_size: List[int] = attr.ib(init=False)
+ extra_symbols: Optional[Set[str]] = attr.ib(default=None)
def __attrs_post_init__(self) -> None:
"""Post init configuration."""
+ self.extra_symbols = set(self.extra_symbols) if self.extra_symbols is not None else None
self.mapping, self.inverse_mapping, self.input_size = emnist_mapping(
self.extra_symbols
)