summaryrefslogtreecommitdiff
path: root/src/text_recognizer/networks/beam.py
blob: dccccdb75ad4d846315b6a49bdccdf964dbf69f5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
"""Implementation of beam search decoder for a sequence to sequence network.

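Adapted from: https://github.com/budzianowski/PyTorch-Beam-Search-Decoding/blob/master/decode_beam.py

NOTE: everything below is commented out and unfinished (see the TODO in
``beam_decoder``), so this module currently defines no usable names.
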
"""
# from typing import List
# from queue import PriorityQueue  # Python 3 module name (``Queue`` was Python 2)

# from loguru import logger
# import torch
# from torch import nn
# from torch import Tensor
# import torch.nn.functional as F


# class Node:
#     def __init__(
#         self, parent: Node, target_index: int, log_prob: Tensor, length: int
#     ) -> None:
#         self.parent = parent
#         self.target_index = target_index
#         self.log_prob = log_prob
#         self.length = length
#         self.reward = 0.0

#     def eval(self, alpha: float = 1.0) -> Tensor:
#         return self.log_prob / (self.length - 1 + 1e-6) + alpha * self.reward


# @torch.no_grad()
# def beam_decoder(
#     network, mapper, device, memory: Tensor = None, max_len: int = 97,
# ) -> Tensor:
#     beam_width = 10
#     topk = 1  # How many sentences to generate.

#     trg_indices = [mapper(mapper.init_token)]

#     end_nodes = []

#     node = Node(None, trg_indices, 0, 1)
#     nodes = PriorityQueue()

#     nodes.put((node.eval(), node))
#     q_size = 1

#     # Beam search
#     for _ in range(max_len):
#         if q_size > 2000:
#             logger.warning("Could not decode input")
#             break

#         # Fetch the best node.
#         score, n = nodes.get()
#         decoder_input = n.target_index

#         if n.target_index == mapper(mapper.eos_token) and n.parent is not None:
#             end_nodes.append((score, n))

#             # If we reached the maximum number of sentences required.
#             if len(end_nodes) >= 1:
#                 break
#             else:
#                 continue

#         # Forward pass with transformer.
#         trg = torch.tensor(trg_indices, device=device)[None, :].long()
#         trg = network.target_embedding(trg)
#         logits = network.decoder(trg=trg, memory=memory, trg_mask=None)
#         log_prob = F.log_softmax(logits, dim=2)

#         log_prob, indices = torch.topk(log_prob, beam_width)

#         for new_k in range(beam_width):
#             # TODO: continue from here
#             token_index = indices[0][new_k].view(1, -1)
#             log_p = log_prob[0][new_k].item()

#             node = Node()

#             pass

#     pass