summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--text_recognizer/decoder/greedy_decoder.py24
1 files changed, 13 insertions, 11 deletions
diff --git a/text_recognizer/decoder/greedy_decoder.py b/text_recognizer/decoder/greedy_decoder.py
index 8d55a02..78864cd 100644
--- a/text_recognizer/decoder/greedy_decoder.py
+++ b/text_recognizer/decoder/greedy_decoder.py
@@ -1,8 +1,10 @@
"""Greedy decoder."""
from typing import Type
-from text_recognizer.data.tokenizer import Tokenizer
+
import torch
-from torch import nn, Tensor
+from torch import Tensor, nn
+
+from text_recognizer.data.tokenizer import Tokenizer
class GreedyDecoder:
@@ -25,29 +27,29 @@ class GreedyDecoder:
img_features = self.network.encode(x)
# Create a placeholder matrix for storing outputs from the network
- indecies = (
+ indices = (
torch.ones((bsz, self.max_output_len), dtype=torch.long, device=x.device)
* self.pad_index
)
- indecies[:, 0] = self.start_index
+ indices[:, 0] = self.start_index
for i in range(1, self.max_output_len):
- tokens = indecies[:, :i] # (B, Sy)
+ tokens = indices[:, :i] # (B, Sy)
logits = self.network.decode(tokens, img_features) # [ B, N, C ]
indecies_ = logits.argmax(dim=2) # [ B, N ]
- indecies[:, i] = indecies_[:, -1]
+ indices[:, i] = indecies_[:, -1]
# Early stopping of prediction loop if token is end or padding token.
if (
- (indecies[:, i] == self.end_index) | (indecies[:, i] == self.pad_index)
+ (indices[:, i] == self.end_index) | (indices[:, i] == self.pad_index)
).all():
break
# Set all tokens after end token to pad token.
for i in range(1, self.max_output_len):
- idx = (indecies[:, i - 1] == self.end_index) | (
- indecies[:, i - 1] == self.pad_index
+ idx = (indices[:, i - 1] == self.end_index) | (
+ indices[:, i - 1] == self.pad_index
)
- indecies[idx, i] = self.pad_index
+ indices[idx, i] = self.pad_index
- return indecies
+ return indices