diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2024-04-15 21:50:20 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2024-04-15 21:50:20 +0200 |
commit | f8efbce6536025f1c0fb95d3cbb4c489cb51af49 (patch) | |
tree | 8431bab4736ba69858b0fd3efb2d1a8321497395 /text_recognizer/decoder | |
parent | 73edf3fcb4c47cda27e230c98718d4abdc3400e2 (diff) |
Diffstat (limited to 'text_recognizer/decoder')
-rw-r--r-- | text_recognizer/decoder/greedy_decoder.py | 24 |
1 files changed, 13 insertions, 11 deletions
diff --git a/text_recognizer/decoder/greedy_decoder.py b/text_recognizer/decoder/greedy_decoder.py index 8d55a02..78864cd 100644 --- a/text_recognizer/decoder/greedy_decoder.py +++ b/text_recognizer/decoder/greedy_decoder.py @@ -1,8 +1,10 @@ """Greedy decoder.""" from typing import Type -from text_recognizer.data.tokenizer import Tokenizer + import torch -from torch import nn, Tensor +from torch import Tensor, nn + +from text_recognizer.data.tokenizer import Tokenizer class GreedyDecoder: @@ -25,29 +27,29 @@ class GreedyDecoder: img_features = self.network.encode(x) # Create a placeholder matrix for storing outputs from the network - indecies = ( + indices = ( torch.ones((bsz, self.max_output_len), dtype=torch.long, device=x.device) * self.pad_index ) - indecies[:, 0] = self.start_index + indices[:, 0] = self.start_index for i in range(1, self.max_output_len): - tokens = indecies[:, :i] # (B, Sy) + tokens = indices[:, :i] # (B, Sy) logits = self.network.decode(tokens, img_features) # [ B, N, C ] indecies_ = logits.argmax(dim=2) # [ B, N ] - indecies[:, i] = indecies_[:, -1] + indices[:, i] = indecies_[:, -1] # Early stopping of prediction loop if token is end or padding token. if ( - (indecies[:, i] == self.end_index) | (indecies[:, i] == self.pad_index) + (indices[:, i] == self.end_index) | (indices[:, i] == self.pad_index) ).all(): break # Set all tokens after end token to pad token. for i in range(1, self.max_output_len): - idx = (indecies[:, i - 1] == self.end_index) | ( - indecies[:, i - 1] == self.pad_index + idx = (indices[:, i - 1] == self.end_index) | ( + indices[:, i - 1] == self.pad_index ) - indecies[idx, i] = self.pad_index + indices[idx, i] = self.pad_index - return indecies + return indices |