"""Greedy decoder.

Moved from ``text_recognizer/model/greedy_decoder.py`` to
``text_recognizer/decoder/greedy_decoder.py`` (commit 20225cb, "Move greedy
decoder"); ``text_recognizer/decoder/__init__.py`` is an empty package marker.
"""
from typing import TYPE_CHECKING

import torch
from torch import Tensor, nn

if TYPE_CHECKING:
    # Used only in an annotation, so it is not imported at runtime.
    from text_recognizer.data.tokenizer import Tokenizer


class GreedyDecoder:
    """Autoregressive decoder that always picks the highest-scoring token.

    The wrapped network must expose ``encode(x)`` and
    ``decode(tokens, img_features)``; ``decode`` is expected to return logits
    with the class dimension last, i.e. ``[B, N, C]``.
    """

    def __init__(
        self,
        network: nn.Module,
        tokenizer: "Tokenizer",
        max_output_len: int = 682,
    ) -> None:
        """Store the network and the special token indices.

        Args:
            network: Model instance providing ``encode`` and ``decode``.
            tokenizer: Source of ``start_index``, ``end_index`` and
                ``pad_index``.
            max_output_len: Width of the returned index matrix, including the
                leading start token.
        """
        self.network = network
        self.start_index = tokenizer.start_index
        self.end_index = tokenizer.end_index
        self.pad_index = tokenizer.pad_index
        self.max_output_len = max_output_len

    def _is_terminal(self, tokens: Tensor) -> Tensor:
        """Return a boolean mask of entries equal to the end or pad token."""
        return (tokens == self.end_index) | (tokens == self.pad_index)

    def __call__(self, x: Tensor) -> Tensor:
        """Greedily decode a batch of inputs into token indices.

        Args:
            x: Input batch; only ``x.shape[0]`` (batch size) and ``x.device``
                are read here, the rest is handed to ``network.encode``.

        Returns:
            A ``(B, max_output_len)`` long tensor. Column 0 holds the start
            token and every position after a row's first end/pad token is the
            pad token.
        """
        bsz = x.shape[0]

        # The result is an integer argmax, so no gradient can flow through it;
        # disabling autograd avoids building a graph over every decode step.
        with torch.no_grad():
            # Encode image(s) to latent vectors.
            img_features = self.network.encode(x)

            # Placeholder matrix for the predicted tokens, pre-filled with pad.
            indices = (
                torch.ones(
                    (bsz, self.max_output_len), dtype=torch.long, device=x.device
                )
                * self.pad_index
            )
            indices[:, 0] = self.start_index

            # Rows that have already produced an end (or pad) token.
            finished = self._is_terminal(indices[:, 0])

            for i in range(1, self.max_output_len):
                # Stop as soon as every row is done. Remaining columns already
                # hold the pad token from initialisation.
                if finished.all():
                    break

                tokens = indices[:, :i]  # (B, Sy)
                logits = self.network.decode(tokens, img_features)  # [ B, N, C ]
                next_tokens = logits[:, -1].argmax(dim=-1)  # [ B ]

                # Whatever the network predicts for a finished row is
                # discarded: everything after the end token is padding.
                # NOTE(review): assumes batch rows are decoded independently
                # of each other - confirm for the wrapped network.
                next_tokens[finished] = self.pad_index
                indices[:, i] = next_tokens

                finished = finished | self._is_terminal(next_tokens)

        return indices