summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/transformer/positional_encoding.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-03-20 18:09:06 +0100
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-03-20 18:09:06 +0100
commit7e8e54e84c63171e748bbf09516fd517e6821ace (patch)
tree996093f75a5d488dddf7ea1f159ed343a561ef89 /text_recognizer/networks/transformer/positional_encoding.py
parentb0719d84138b6bbe5f04a4982dfca673aea1a368 (diff)
Inital commit for refactoring to lightning
Diffstat (limited to 'text_recognizer/networks/transformer/positional_encoding.py')
-rw-r--r--text_recognizer/networks/transformer/positional_encoding.py32
1 files changed, 32 insertions, 0 deletions
diff --git a/text_recognizer/networks/transformer/positional_encoding.py b/text_recognizer/networks/transformer/positional_encoding.py
new file mode 100644
index 0000000..1ba5537
--- /dev/null
+++ b/text_recognizer/networks/transformer/positional_encoding.py
@@ -0,0 +1,32 @@
+"""A positional encoding for the image features, as the transformer has no notation of the order of the sequence."""
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+
+
+class PositionalEncoding(nn.Module):
+ """Encodes a sense of distance or time for transformer networks."""
+
+ def __init__(
+ self, hidden_dim: int, dropout_rate: float, max_len: int = 1000
+ ) -> None:
+ super().__init__()
+ self.dropout = nn.Dropout(p=dropout_rate)
+ self.max_len = max_len
+
+ pe = torch.zeros(max_len, hidden_dim)
+ position = torch.arange(0, max_len).unsqueeze(1)
+ div_term = torch.exp(
+ torch.arange(0, hidden_dim, 2) * -(np.log(10000.0) / hidden_dim)
+ )
+
+ pe[:, 0::2] = torch.sin(position * div_term)
+ pe[:, 1::2] = torch.cos(position * div_term)
+ pe = pe.unsqueeze(0)
+ self.register_buffer("pe", pe)
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Encodes the tensor with a postional embedding."""
+ x = x + self.pe[:, : x.shape[1]]
+ return self.dropout(x)