From 27b001503f068a89acc40cc960a8b54feb1bddc3 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sat, 2 Sep 2023 01:53:03 +0200 Subject: Fix bug in decoder --- text_recognizer/model/greedy_decoder.py | 53 +++++++++++++++------------------ 1 file changed, 24 insertions(+), 29 deletions(-) diff --git a/text_recognizer/model/greedy_decoder.py b/text_recognizer/model/greedy_decoder.py index 5cbbb66..2c4c16e 100644 --- a/text_recognizer/model/greedy_decoder.py +++ b/text_recognizer/model/greedy_decoder.py @@ -25,34 +25,29 @@ class GreedyDecoder: img_features = self.network.encode(x) # Create a placeholder matrix for storing outputs from the network - indecies = torch.ones((bsz, self.max_output_len), dtype=torch.long).to(x.device) + indecies = ( + torch.ones((bsz, self.max_output_len), dtype=torch.long, device=x.device) + * self.pad_index + ) indecies[:, 0] = self.start_index - try: - for i in range(1, self.max_output_len): - tokens = indecies[:, :i] # (B, Sy) - logits = self.network.decode(tokens, img_features) # (B, C, Sy) - indecies_ = torch.argmax(logits, dim=1) # (B, Sy) - indecies[:, i : i + 1] = indecies_[:, -1:] - - # Early stopping of prediction loop if token is end or padding token. - if ( - (indecies[:, i - 1] == self.end_index) - | (indecies[:, i - 1] == self.pad_index) - ).all(): - break - - # Set all tokens after end token to pad token. - for i in range(1, self.max_output_len): - idx = (indecies[:, i - 1] == self.end_index) | ( - indecies[:, i - 1] == self.pad_index - ) - indecies[idx, i] = self.pad_index - - return indecies - except Exception: - # TODO: investigate this error more - print(x.shape) - # print(indecies) - print(indecies.shape) - print(img_features.shape) + for i in range(1, self.max_output_len): + tokens = indecies[:, :i] # (B, Sy) + logits = self.network.decode(tokens, img_features) # [ B, N, C ] + indecies_ = torch.argmax(logits, dim=2) # [ B, N ] + indecies[:, i] = indecies_[:, -1] + + # Early stopping of prediction loop if token is end or padding token. + if ( + (indecies[:, i] == self.end_index) | (indecies[:, i] == self.pad_index) + ).all(): + break + + # Set all tokens after end token to pad token. + for i in range(1, self.max_output_len): + idx = (indecies[:, i - 1] == self.end_index) | ( + indecies[:, i - 1] == self.pad_index + ) + indecies[idx, i] = self.pad_index + + return indecies -- cgit v1.2.3-70-g09d2