summaryrefslogtreecommitdiff
path: root/src/text_recognizer/networks/transformer/positional_encoding.py
diff options
context:
space:
mode:
authoraktersnurra <gustaf.rydholm@gmail.com>2020-11-08 14:54:44 +0100
committeraktersnurra <gustaf.rydholm@gmail.com>2020-11-08 14:54:44 +0100
commitdc28cbe2b4ed77be92ee8b2b69a20689c3bf02a4 (patch)
tree1b5fc0d06952e13727e85c4f973a26d277068453 /src/text_recognizer/networks/transformer/positional_encoding.py
parente181195a699d7fa237f256d90ab4dedffc03d405 (diff)
new updates
Diffstat (limited to 'src/text_recognizer/networks/transformer/positional_encoding.py')
-rw-r--r--src/text_recognizer/networks/transformer/positional_encoding.py32
1 files changed, 32 insertions, 0 deletions
diff --git a/src/text_recognizer/networks/transformer/positional_encoding.py b/src/text_recognizer/networks/transformer/positional_encoding.py
new file mode 100644
index 0000000..1ba5537
--- /dev/null
+++ b/src/text_recognizer/networks/transformer/positional_encoding.py
@@ -0,0 +1,32 @@
+"""A positional encoding for the image features, as the transformer has no notation of the order of the sequence."""
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+
+
+class PositionalEncoding(nn.Module):
+ """Encodes a sense of distance or time for transformer networks."""
+
+ def __init__(
+ self, hidden_dim: int, dropout_rate: float, max_len: int = 1000
+ ) -> None:
+ super().__init__()
+ self.dropout = nn.Dropout(p=dropout_rate)
+ self.max_len = max_len
+
+ pe = torch.zeros(max_len, hidden_dim)
+ position = torch.arange(0, max_len).unsqueeze(1)
+ div_term = torch.exp(
+ torch.arange(0, hidden_dim, 2) * -(np.log(10000.0) / hidden_dim)
+ )
+
+ pe[:, 0::2] = torch.sin(position * div_term)
+ pe[:, 1::2] = torch.cos(position * div_term)
+ pe = pe.unsqueeze(0)
+ self.register_buffer("pe", pe)
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Encodes the tensor with a postional embedding."""
+ x = x + self.pe[:, : x.shape[1]]
+ return self.dropout(x)