summaryrefslogtreecommitdiff
path: root/text_recognizer/decoder
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/decoder')
-rw-r--r--text_recognizer/decoder/__init__.py0
-rw-r--r--text_recognizer/decoder/greedy_decoder.py53
2 files changed, 53 insertions, 0 deletions
diff --git a/text_recognizer/decoder/__init__.py b/text_recognizer/decoder/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/text_recognizer/decoder/__init__.py
diff --git a/text_recognizer/decoder/greedy_decoder.py b/text_recognizer/decoder/greedy_decoder.py
new file mode 100644
index 0000000..8d55a02
--- /dev/null
+++ b/text_recognizer/decoder/greedy_decoder.py
@@ -0,0 +1,53 @@
+"""Greedy decoder."""
+from typing import Type
+from text_recognizer.data.tokenizer import Tokenizer
+import torch
+from torch import nn, Tensor
+
+
+class GreedyDecoder:
+ def __init__(
+ self,
+ network: Type[nn.Module],
+ tokenizer: Tokenizer,
+ max_output_len: int = 682,
+ ) -> None:
+ self.network = network
+ self.start_index = tokenizer.start_index
+ self.end_index = tokenizer.end_index
+ self.pad_index = tokenizer.pad_index
+ self.max_output_len = max_output_len
+
+ def __call__(self, x: Tensor) -> Tensor:
+ bsz = x.shape[0]
+
+ # Encode image(s) to latent vectors.
+ img_features = self.network.encode(x)
+
+ # Create a placeholder matrix for storing outputs from the network
+ indecies = (
+ torch.ones((bsz, self.max_output_len), dtype=torch.long, device=x.device)
+ * self.pad_index
+ )
+ indecies[:, 0] = self.start_index
+
+ for i in range(1, self.max_output_len):
+ tokens = indecies[:, :i] # (B, Sy)
+ logits = self.network.decode(tokens, img_features) # [ B, N, C ]
+ indecies_ = logits.argmax(dim=2) # [ B, N ]
+ indecies[:, i] = indecies_[:, -1]
+
+ # Early stopping of prediction loop if token is end or padding token.
+ if (
+ (indecies[:, i] == self.end_index) | (indecies[:, i] == self.pad_index)
+ ).all():
+ break
+
+ # Set all tokens after end token to pad token.
+ for i in range(1, self.max_output_len):
+ idx = (indecies[:, i - 1] == self.end_index) | (
+ indecies[:, i - 1] == self.pad_index
+ )
+ indecies[idx, i] = self.pad_index
+
+ return indecies