summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--text_recognizer/model/greedy_decoder.py53
1 files changed, 24 insertions, 29 deletions
diff --git a/text_recognizer/model/greedy_decoder.py b/text_recognizer/model/greedy_decoder.py
index 5cbbb66..2c4c16e 100644
--- a/text_recognizer/model/greedy_decoder.py
+++ b/text_recognizer/model/greedy_decoder.py
@@ -25,34 +25,29 @@ class GreedyDecoder:
img_features = self.network.encode(x)
# Create a placeholder matrix for storing outputs from the network
- indecies = torch.ones((bsz, self.max_output_len), dtype=torch.long).to(x.device)
+ indecies = (
+ torch.ones((bsz, self.max_output_len), dtype=torch.long, device=x.device)
+ * self.pad_index
+ )
indecies[:, 0] = self.start_index
- try:
- for i in range(1, self.max_output_len):
- tokens = indecies[:, :i] # (B, Sy)
- logits = self.network.decode(tokens, img_features) # (B, C, Sy)
- indecies_ = torch.argmax(logits, dim=1) # (B, Sy)
- indecies[:, i : i + 1] = indecies_[:, -1:]
-
- # Early stopping of prediction loop if token is end or padding token.
- if (
- (indecies[:, i - 1] == self.end_index)
- | (indecies[:, i - 1] == self.pad_index)
- ).all():
- break
-
- # Set all tokens after end token to pad token.
- for i in range(1, self.max_output_len):
- idx = (indecies[:, i - 1] == self.end_index) | (
- indecies[:, i - 1] == self.pad_index
- )
- indecies[idx, i] = self.pad_index
-
- return indecies
- except Exception:
- # TODO: investigate this error more
- print(x.shape)
- # print(indecies)
- print(indecies.shape)
- print(img_features.shape)
+ for i in range(1, self.max_output_len):
+ tokens = indecies[:, :i] # (B, Sy)
+ logits = self.network.decode(tokens, img_features) # [ B, N, C ]
+ indecies_ = torch.argmax(logits, dim=2) # [ B, N ]
+ indecies[:, i] = indecies_[:, -1]
+
+ # Early stopping of prediction loop if token is end or padding token.
+ if (
+ (indecies[:, i] == self.end_index) | (indecies[:, i] == self.pad_index)
+ ).all():
+ break
+
+ # Set all tokens after end token to pad token.
+ for i in range(1, self.max_output_len):
+ idx = (indecies[:, i - 1] == self.end_index) | (
+ indecies[:, i - 1] == self.pad_index
+ )
+ indecies[idx, i] = self.pad_index
+
+ return indecies