diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2023-09-02 01:53:03 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2023-09-02 01:53:03 +0200 |
commit | 27b001503f068a89acc40cc960a8b54feb1bddc3 (patch) | |
tree | 6d6c8e86955288db98997e797efce1d2527bca5b | |
parent | 579f3edc3e20ddbe8207ee0c4189a270b2dfedc1 (diff) |
Fix bug in decoder
-rw-r--r-- | text_recognizer/model/greedy_decoder.py | 53 |
1 file changed, 24 insertions, 29 deletions
diff --git a/text_recognizer/model/greedy_decoder.py b/text_recognizer/model/greedy_decoder.py index 5cbbb66..2c4c16e 100644 --- a/text_recognizer/model/greedy_decoder.py +++ b/text_recognizer/model/greedy_decoder.py @@ -25,34 +25,29 @@ class GreedyDecoder: img_features = self.network.encode(x) # Create a placeholder matrix for storing outputs from the network - indecies = torch.ones((bsz, self.max_output_len), dtype=torch.long).to(x.device) + indecies = ( + torch.ones((bsz, self.max_output_len), dtype=torch.long, device=x.device) + * self.pad_index + ) indecies[:, 0] = self.start_index - try: - for i in range(1, self.max_output_len): - tokens = indecies[:, :i] # (B, Sy) - logits = self.network.decode(tokens, img_features) # (B, C, Sy) - indecies_ = torch.argmax(logits, dim=1) # (B, Sy) - indecies[:, i : i + 1] = indecies_[:, -1:] - - # Early stopping of prediction loop if token is end or padding token. - if ( - (indecies[:, i - 1] == self.end_index) - | (indecies[:, i - 1] == self.pad_index) - ).all(): - break - - # Set all tokens after end token to pad token. - for i in range(1, self.max_output_len): - idx = (indecies[:, i - 1] == self.end_index) | ( - indecies[:, i - 1] == self.pad_index - ) - indecies[idx, i] = self.pad_index - - return indecies - except Exception: - # TODO: investigate this error more - print(x.shape) - # print(indecies) - print(indecies.shape) - print(img_features.shape) + for i in range(1, self.max_output_len): + tokens = indecies[:, :i] # (B, Sy) + logits = self.network.decode(tokens, img_features) # [ B, N, C ] + indecies_ = torch.argmax(logits, dim=2) # [ B, N ] + indecies[:, i] = indecies_[:, -1] + + # Early stopping of prediction loop if token is end or padding token. + if ( + (indecies[:, i] == self.end_index) | (indecies[:, i] == self.pad_index) + ).all(): + break + + # Set all tokens after end token to pad token. 
+ for i in range(1, self.max_output_len): + idx = (indecies[:, i - 1] == self.end_index) | ( + indecies[:, i - 1] == self.pad_index + ) + indecies[idx, i] = self.pad_index + + return indecies |