summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--text_recognizer/data/tokenizer.py39
1 files changed, 38 insertions, 1 deletions
diff --git a/text_recognizer/data/tokenizer.py b/text_recognizer/data/tokenizer.py
index 12617a1..f229c74 100644
--- a/text_recognizer/data/tokenizer.py
+++ b/text_recognizer/data/tokenizer.py
@@ -1,6 +1,5 @@
"""Emnist mapping."""
import json
-from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple, Union
import torch
@@ -19,6 +18,7 @@ class Tokenizer:
start_token: str = "<s>",
end_token: str = "<e>",
pad_token: str = "<p>",
+ replace_after_end: bool = True,
) -> None:
self.extra_symbols = set(extra_symbols) if extra_symbols is not None else None
self.mapping, self.inverse_mapping, self.input_size = self._load_mapping()
@@ -29,6 +29,7 @@ class Tokenizer:
self.end_index = int(self.get_value(self.end_token))
self.pad_index = int(self.get_value(self.pad_token))
self.ignore_indices = set([self.start_index, self.end_index, self.pad_index])
+ self.replace_after_end = replace_after_end
if lower:
self._to_lower()
@@ -86,10 +87,46 @@ class Tokenizer:
]
)
+ def batch_decode(self, ys: Tensor) -> List[str]:
+ return [self.decode(y) for y in ys]
+
+ def decode_logits(self, logits: Tensor) -> List[str]:
+ ys = self.logits_to_indices(logits)
+ return self.batch_decode(ys)
+
def encode(self, text: str) -> Tensor:
"""Returns tensor of indices for a string."""
return Tensor([self.inverse_mapping[token] for token in text])
+ def first_appearance(self, x: Tensor, dim: int) -> Tensor:
+ if x.dim() > 2 or x.dim() == 0:
+ raise ValueError(
+ f"Only 1 or 2 dimensional tensors allowed, got a tensor with dim {x.dim()}"
+ )
+ matches = x == self.end_index
+ mask = (matches.cumsum(dim) == 1) & matches
+ does_match, index = mask.max(dim)
+ first = torch.where(does_match, index, x.shape[dim])
+ return first
+
+ def replace_after(self, x: Tensor) -> Tensor:
+ first_appearance = self.first_appearance(x, dim=1)
+ indices = torch.arange(0, x.shape[-1]).type_as(x)
+ output = torch.where(
+ indices[None, :]
+ <= first_appearance[:, None], # if index is before first appearance
+ x, # return the value from x
+ self.pad_index, # otherwise, return the replacement value
+ )
+ return output # [B, N]
+
+ def logits_to_indices(self, logits: Tensor) -> Tensor:
+ preds = logits.argmax(dim=1)
+ if self.replace_after_end:
+ return self.replace_after(preds) # [B, N]
+ else:
+ return preds # [B, N]
+
def __getitem__(self, x: Union[int, Tensor]) -> str:
"""Returns text for a list of indices."""
return self.get_token(x)