"""Emnist mapping."""
import json
from typing import Dict, List, Optional, Sequence, Tuple, Union
import torch
from torch import Tensor
import text_recognizer.metadata.shared as metadata
class Tokenizer:
    """Mapping for EMNIST labels."""

    def __init__(
        self,
        extra_symbols: Optional[Sequence[str]] = None,
        lower: bool = True,
        start_token: str = "<s>",
        end_token: str = "<e>",
        pad_token: str = "<p>",
        replace_after_end: bool = True,
    ) -> None:
        # Deduplicate the extra symbols while preserving their order, so the
        # index assigned to each symbol is deterministic across runs (a plain
        # set has no stable iteration order).
        self.extra_symbols = (
            list(dict.fromkeys(extra_symbols)) if extra_symbols is not None else None
        )
        self.mapping, self.inverse_mapping, self.input_size = self._load_mapping()
        self.start_token = start_token
        self.end_token = end_token
        self.pad_token = pad_token
        self.start_index = int(self.get_value(self.start_token))
        self.end_index = int(self.get_value(self.end_token))
        self.pad_index = int(self.get_value(self.pad_token))
        self.ignore_indices = {self.start_index, self.end_index, self.pad_index}
        self.replace_after_end = replace_after_end
        if lower:
            self._to_lower()

    def __len__(self) -> int:
        return len(self.mapping)

    @property
    def num_classes(self) -> int:
        """Return the number of classes in the dataset."""
        return len(self)

    def _load_mapping(self) -> Tuple[List[str], Dict[str, int], List[int]]:
        """Return the EMNIST mapping, its inverse, and the input shape."""
        with metadata.ESSENTIALS_FILENAME.open() as f:
            essentials = json.load(f)
        mapping = list(essentials["characters"])
        if self.extra_symbols is not None:
            # Extra symbols are appended after the base characters, so the
            # indices of the base EMNIST characters are unchanged.
            mapping += self.extra_symbols
        inverse_mapping = {v: k for k, v in enumerate(mapping)}
        input_shape = essentials["input_shape"]
        return mapping, inverse_mapping, input_shape
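
    # The essentials file is expected to look roughly like this, inferred
    # from the keys read above (the concrete values are illustrative):
    #   {"characters": ["<s>", "<e>", "<p>", "0", ..., "z"], "input_shape": [28, 28]}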

    def _to_lower(self) -> None:
        """Merges each uppercase letter into its lowercase counterpart."""

        def _filter(x: int) -> int:
            # The shift assumes the 26 uppercase letters all sit below index
            # 40 in the mapping: removing them shifts every later character
            # down by 26, and each lowercase letter lands on the index its
            # uppercase twin used to occupy.
            if 40 <= x:
                return x - 26
            return x

        self.inverse_mapping = {v: _filter(k) for k, v in enumerate(self.mapping)}
        self.mapping = [c for c in self.mapping if not c.isupper()]

    def get_token(self, index: Union[int, Tensor]) -> str:
        """Returns the token for an index value."""
        # Use a strict upper bound: an index equal to len(self.mapping) is
        # out of range and should raise KeyError, not IndexError.
        if 0 <= (index := int(index)) < len(self.mapping):
            return self.mapping[index]
        raise KeyError(f"Index ({index}) not in mapping.")

    def get_value(self, token: str) -> Tensor:
        """Returns index value of token."""
        if token in self.inverse_mapping:
            return torch.LongTensor([self.inverse_mapping[token]])
        raise KeyError(f"Token ({token}) not found in inverse mapping.")

    def decode(self, indices: Union[List[int], Tensor]) -> str:
        """Returns the text from a list of indices."""
        if isinstance(indices, Tensor):
            indices = indices.tolist()
        return "".join(
            self.mapping[index]
            for index in indices
            if index not in self.ignore_indices
        )

    def batch_decode(self, ys: Tensor) -> List[str]:
        """Decodes a batch of index sequences into strings."""
        return [self.decode(y) for y in ys]

    def decode_logits(self, logits: Tensor) -> List[str]:
        """Decodes a batch of logits into strings."""
        ys = self.logits_to_indices(logits)
        return self.batch_decode(ys)

    def encode(self, text: str) -> Tensor:
        """Returns a tensor of indices for a string."""
        # Use an integer tensor rather than the default float Tensor, so the
        # result matches what get_value returns and can index embeddings.
        return torch.LongTensor([self.inverse_mapping[token] for token in text])

    def first_appearance(self, x: Tensor, dim: int) -> Tensor:
        """Returns the position of the first end token along dim, or the size of dim if absent."""
        if x.dim() > 2 or x.dim() == 0:
            raise ValueError(
                f"Only 1 or 2 dimensional tensors allowed, got a tensor with dim {x.dim()}"
            )
        matches = x == self.end_index
        # The running count of matches equals 1 from the first match until
        # the second one; intersecting with `matches` isolates the first.
        mask = (matches.cumsum(dim) == 1) & matches
        does_match, index = mask.max(dim)
        # Rows without an end token fall back to the full sequence length.
        first = torch.where(does_match, index, x.shape[dim])
        return first
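
    # A worked toy example of the masking trick above (assuming
    # end_index == 2):
    #   x                        = [[5, 2, 7, 2],
    #                               [5, 7, 9, 4]]
    #   matches                  = [[F, T, F, T],
    #                               [F, F, F, F]]
    #   (cumsum == 1) & matches  = [[F, T, F, F],
    #                               [F, F, F, F]]
    #   mask.max(dim=1)          -> does_match=[T, F], index=[1, 0]
    #   first                    = [1, 4]   (row 1 has no end token)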

    def replace_after(self, x: Tensor) -> Tensor:
        """Replaces everything after the first end token with the pad index."""
        first_appearance = self.first_appearance(x, dim=1)
        indices = torch.arange(0, x.shape[-1]).type_as(x)
        output = torch.where(
            indices[None, :]
            <= first_appearance[:, None],  # if index is at or before the first end token
            x,  # return the value from x
            self.pad_index,  # otherwise, return the pad index
        )
        return output  # [B, N]

    def logits_to_indices(self, logits: Tensor) -> Tensor:
        """Converts logits of shape [B, C, N] into predicted indices of shape [B, N]."""
        preds = logits.argmax(dim=1)
        if self.replace_after_end:
            # Pad out everything that follows the first end token.
            return self.replace_after(preds)  # [B, N]
        return preds  # [B, N]

    def __getitem__(self, x: Union[int, Tensor]) -> str:
        """Returns the token for an index."""
        return self.get_token(x)
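

if __name__ == "__main__":
    # Usage sketch, assuming the EMNIST essentials file referenced by
    # metadata.ESSENTIALS_FILENAME exists; the "\n" extra symbol is
    # illustrative, not required.
    tokenizer = Tokenizer(extra_symbols=["\n"])
    encoded = tokenizer.encode("hello")  # LongTensor of class indices
    print(tokenizer.decode(encoded))  # -> "hello"
    print(tokenizer.num_classes)  # vocabulary size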