path: root/src/text_recognizer/networks/transformer/positional_encoding.py
Diffstat (limited to 'src/text_recognizer/networks/transformer/positional_encoding.py')
-rw-r--r--  src/text_recognizer/networks/transformer/positional_encoding.py  32
1 file changed, 0 insertions, 32 deletions
diff --git a/src/text_recognizer/networks/transformer/positional_encoding.py b/src/text_recognizer/networks/transformer/positional_encoding.py
deleted file mode 100644
index 1ba5537..0000000
--- a/src/text_recognizer/networks/transformer/positional_encoding.py
+++ /dev/null
@@ -1,32 +0,0 @@
-"""A positional encoding for the image features, as the transformer has no notation of the order of the sequence."""
-import numpy as np
-import torch
-from torch import nn
-from torch import Tensor
-
-
-class PositionalEncoding(nn.Module):
- """Encodes a sense of distance or time for transformer networks."""
-
- def __init__(
- self, hidden_dim: int, dropout_rate: float, max_len: int = 1000
- ) -> None:
- super().__init__()
- self.dropout = nn.Dropout(p=dropout_rate)
- self.max_len = max_len
-
- pe = torch.zeros(max_len, hidden_dim)
- position = torch.arange(0, max_len).unsqueeze(1)
- div_term = torch.exp(
- torch.arange(0, hidden_dim, 2) * -(np.log(10000.0) / hidden_dim)
- )
-
- pe[:, 0::2] = torch.sin(position * div_term)
- pe[:, 1::2] = torch.cos(position * div_term)
- pe = pe.unsqueeze(0)
- self.register_buffer("pe", pe)
-
- def forward(self, x: Tensor) -> Tensor:
- """Encodes the tensor with a postional embedding."""
- x = x + self.pe[:, : x.shape[1]]
- return self.dropout(x)
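
For reference, the removed class implemented the standard sinusoidal positional encoding (sine on even feature indices, cosine on odd ones, frequencies scaled by 10000^(2i/hidden_dim)). Below is a minimal, self-contained sketch of the same computation; the dummy sizes and the (batch, seq_len, hidden_dim) input shape are illustrative assumptions, not repository code.

# Minimal sketch of the sinusoidal encoding the deleted module computed.
# Sizes and shapes here are assumptions for illustration only.
import math

import torch

hidden_dim, max_len = 8, 16
pe = torch.zeros(max_len, hidden_dim)
position = torch.arange(0, max_len).unsqueeze(1)
div_term = torch.exp(
    torch.arange(0, hidden_dim, 2) * -(math.log(10000.0) / hidden_dim)
)
pe[:, 0::2] = torch.sin(position * div_term)  # even indices: sine
pe[:, 1::2] = torch.cos(position * div_term)  # odd indices: cosine
pe = pe.unsqueeze(0)  # (1, max_len, hidden_dim)

x = torch.zeros(2, 10, hidden_dim)  # dummy batch: (batch, seq_len, hidden_dim)
encoded = x + pe[:, : x.shape[1]]   # add the encoding for the first seq_len positions
print(encoded.shape)                # torch.Size([2, 10, 8])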