summaryrefslogtreecommitdiff
path: root/text_recognizer/models/greedy_decoder.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2023-08-25 23:18:31 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2023-08-25 23:18:31 +0200
commit0421daf6bd97596703f426ba61c401599b538eeb (patch)
tree3346a27d09bb16e3c891a7d4f3eaf5721a2dd378 /text_recognizer/models/greedy_decoder.py
parent54d8b230eedfdf587e2d2d214d65582fe78c47eb (diff)
Rename
Diffstat (limited to 'text_recognizer/models/greedy_decoder.py')
-rw-r--r--text_recognizer/models/greedy_decoder.py51
1 files changed, 0 insertions, 51 deletions
diff --git a/text_recognizer/models/greedy_decoder.py b/text_recognizer/models/greedy_decoder.py
deleted file mode 100644
index 9d2f192..0000000
--- a/text_recognizer/models/greedy_decoder.py
+++ /dev/null
@@ -1,51 +0,0 @@
-"""Greedy decoder."""
-from typing import Type
-from text_recognizer.data.tokenizer import Tokenizer
-import torch
-from torch import nn, Tensor
-
-
-class GreedyDecoder:
- def __init__(
- self,
- network: Type[nn.Module],
- tokenizer: Tokenizer,
- max_output_len: int = 682,
- ) -> None:
- self.network = network
- self.start_index = tokenizer.start_index
- self.end_index = tokenizer.end_index
- self.pad_index = tokenizer.pad_index
- self.max_output_len = max_output_len
-
- def __call__(self, x: Tensor) -> Tensor:
- bsz = x.shape[0]
-
- # Encode image(s) to latent vectors.
- img_features = self.network.encode(x)
-
- # Create a placeholder matrix for storing outputs from the network
- indecies = torch.ones((bsz, self.max_output_len), dtype=torch.long).to(x.device)
- indecies[:, 0] = self.start_index
-
- for Sy in range(1, self.max_output_len):
- tokens = indecies[:, :Sy] # (B, Sy)
- logits = self.network.decode(tokens, img_features) # (B, C, Sy)
- indecies_ = torch.argmax(logits, dim=1) # (B, Sy)
- indecies[:, Sy : Sy + 1] = indecies_[:, -1:]
-
- # Early stopping of prediction loop if token is end or padding token.
- if (
- (indecies[:, Sy - 1] == self.end_index)
- | (indecies[:, Sy - 1] == self.pad_index)
- ).all():
- break
-
- # Set all tokens after end token to pad token.
- for Sy in range(1, self.max_output_len):
- idx = (indecies[:, Sy - 1] == self.end_index) | (
- indecies[:, Sy - 1] == self.pad_index
- )
- indecies[idx, Sy] = self.pad_index
-
- return indecies