From b505de634bd3419c5034352bb886f2ec9e605f8e Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sat, 2 Sep 2023 01:51:16 +0200 Subject: Update tokenizer with logits and batch decoding --- text_recognizer/data/tokenizer.py | 39 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) (limited to 'text_recognizer/data') diff --git a/text_recognizer/data/tokenizer.py b/text_recognizer/data/tokenizer.py index 12617a1..f229c74 100644 --- a/text_recognizer/data/tokenizer.py +++ b/text_recognizer/data/tokenizer.py @@ -1,6 +1,5 @@ """Emnist mapping.""" import json -from pathlib import Path from typing import Dict, List, Optional, Sequence, Tuple, Union import torch @@ -19,6 +18,7 @@ class Tokenizer: start_token: str = "", end_token: str = "", pad_token: str = "

", + replace_after_end: bool = True, ) -> None: self.extra_symbols = set(extra_symbols) if extra_symbols is not None else None self.mapping, self.inverse_mapping, self.input_size = self._load_mapping() @@ -29,6 +29,7 @@ class Tokenizer: self.end_index = int(self.get_value(self.end_token)) self.pad_index = int(self.get_value(self.pad_token)) self.ignore_indices = set([self.start_index, self.end_index, self.pad_index]) + self.replace_after_end = replace_after_end if lower: self._to_lower() @@ -86,10 +87,46 @@ class Tokenizer: ] ) + def batch_decode(self, ys: Tensor) -> List[str]: + return [self.decode(y) for y in ys] + + def decode_logits(self, logits: Tensor) -> List[str]: + ys = self.logits_to_indices(logits) + return self.batch_decode(ys) + def encode(self, text: str) -> Tensor: """Returns tensor of indices for a string.""" return Tensor([self.inverse_mapping[token] for token in text]) + def first_appearance(self, x: Tensor, dim: int) -> Tensor: + if x.dim() > 2 or x.dim() == 0: + raise ValueError( + f"Only 1 or 2 dimensional tensors allowed, got a tensor with dim {x.dim()}" + ) + matches = x == self.end_index + mask = (matches.cumsum(dim) == 1) & matches + does_match, index = mask.max(dim) + first = torch.where(does_match, index, x.shape[dim]) + return first + + def replace_after(self, x: Tensor) -> Tensor: + first_appearance = self.first_appearance(x, dim=1) + indices = torch.arange(0, x.shape[-1]).type_as(x) + output = torch.where( + indices[None, :] + <= first_appearance[:, None], # if index is before first appearance + x, # return the value from x + self.pad_index, # otherwise, return the replacement value + ) + return output # [B, N] + + def logits_to_indices(self, logits: Tensor) -> Tensor: + preds = logits.argmax(dim=1) + if self.replace_after_end: + return self.replace_after(preds) # [B, N] + else: + return preds # [B, N] + def __getitem__(self, x: Union[int, Tensor]) -> str: """Returns text for a list of indices.""" return self.get_token(x) -- cgit v1.2.3-70-g09d2