summaryrefslogtreecommitdiff
path: root/text_recognizer/data/mappings/base_mapping.py
blob: 572ac9562bd1ab18701e80e650c95757581d7679 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
"""Abstract base class for mappings between tokens and indices."""
from abc import ABC, abstractmethod
from typing import Dict, List

from torch import Tensor


class AbstractMapping(ABC):
    """Abstract interface for mapping objects.

    The base class only stores its three constructor arguments and reports
    its size as the length of ``mapping``. Concrete subclasses must
    implement ``get_token``, ``get_index``, ``get_text`` and
    ``get_indices`` before they can be instantiated.
    """

    def __init__(
        self,
        input_size: List[int],
        mapping: List[str],
        inverse_mapping: Dict[str, int],
    ) -> None:
        """Store the constructor arguments on the instance, unvalidated.

        Args:
            input_size: List of integers kept as ``self.input_size``; not
                used by this base class (presumably an input shape --
                confirm against the subclasses).
            mapping: List of strings kept as ``self.mapping``; its length
                defines ``len(self)`` and ``num_classes``.
            inverse_mapping: Dictionary from string to integer kept as
                ``self.inverse_mapping`` (presumably the reverse lookup of
                ``mapping`` -- consistency is not checked here).
        """
        self.input_size = input_size
        self.mapping = mapping
        self.inverse_mapping = inverse_mapping

    def __len__(self) -> int:
        """Return the number of entries in ``mapping``."""
        entries = self.mapping
        return len(entries)

    @property
    def num_classes(self) -> int:
        """Number of classes, defined as the length of this object."""
        # Goes through __len__, so a subclass overriding it is respected.
        return len(self)

    @abstractmethod
    def get_token(self, *args, **kwargs) -> str:
        """Return a token as a string; arguments are subclass-defined."""

    @abstractmethod
    def get_index(self, *args, **kwargs) -> Tensor:
        """Return an index as a ``Tensor``; arguments are subclass-defined."""

    @abstractmethod
    def get_text(self, *args, **kwargs) -> str:
        """Return text as a string; arguments are subclass-defined."""

    @abstractmethod
    def get_indices(self, *args, **kwargs) -> Tensor:
        """Return indices as a ``Tensor``; arguments are subclass-defined."""