diff options
Diffstat (limited to 'text_recognizer/networks/beam.py')
-rw-r--r-- | text_recognizer/networks/beam.py | 83 |
1 files changed, 83 insertions, 0 deletions
diff --git a/text_recognizer/networks/beam.py b/text_recognizer/networks/beam.py new file mode 100644 index 0000000..dccccdb --- /dev/null +++ b/text_recognizer/networks/beam.py @@ -0,0 +1,83 @@ +"""Implementation of beam search decoder for a sequence to sequence network. + +Stolen from: https://github.com/budzianowski/PyTorch-Beam-Search-Decoding/blob/master/decode_beam.py + +""" +# from typing import List +# from Queue import PriorityQueue + +# from loguru import logger +# import torch +# from torch import nn +# from torch import Tensor +# import torch.nn.functional as F + + +# class Node: +# def __init__( +# self, parent: Node, target_index: int, log_prob: Tensor, length: int +# ) -> None: +# self.parent = parent +# self.target_index = target_index +# self.log_prob = log_prob +# self.length = length +# self.reward = 0.0 + +# def eval(self, alpha: float = 1.0) -> Tensor: +# return self.log_prob / (self.length - 1 + 1e-6) + alpha * self.reward + + +# @torch.no_grad() +# def beam_decoder( +# network, mapper, device, memory: Tensor = None, max_len: int = 97, +# ) -> Tensor: +# beam_width = 10 +# topk = 1 # How many sentences to generate. + +# trg_indices = [mapper(mapper.init_token)] + +# end_nodes = [] + +# node = Node(None, trg_indices, 0, 1) +# nodes = PriorityQueue() + +# nodes.put((node.eval(), node)) +# q_size = 1 + +# # Beam search +# for _ in range(max_len): +# if q_size > 2000: +# logger.warning("Could not decoder input") +# break + +# # Fetch the best node. +# score, n = nodes.get() +# decoder_input = n.target_index + +# if n.target_index == mapper(mapper.eos_token) and n.parent is not None: +# end_nodes.append((score, n)) + +# # If we reached the maximum number of sentences required. +# if len(end_nodes) >= 1: +# break +# else: +# continue + +# # Forward pass with transformer. +# trg = torch.tensor(trg_indices, device=device)[None, :].long() +# trg = network.target_embedding(trg) +# logits = network.decoder(trg=trg, memory=memory, trg_mask=None) +# log_prob = F.log_softmax(logits, dim=2) + +# log_prob, indices = torch.topk(log_prob, beam_width) + +# for new_k in range(beam_width): +# # TODO: continue from here +# token_index = indices[0][new_k].view(1, -1) +# log_p = log_prob[0][new_k].item() + +# node = Node() + +# pass + +# pass |