From 7e8e54e84c63171e748bbf09516fd517e6821ace Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm
Date: Sat, 20 Mar 2021 18:09:06 +0100
Subject: Inital commit for refactoring to lightning

---
 .../networks/transformer/positional_encoding.py | 32 ----------------------
 1 file changed, 32 deletions(-)
 delete mode 100644 src/text_recognizer/networks/transformer/positional_encoding.py

(limited to 'src/text_recognizer/networks/transformer/positional_encoding.py')

diff --git a/src/text_recognizer/networks/transformer/positional_encoding.py b/src/text_recognizer/networks/transformer/positional_encoding.py
deleted file mode 100644
index 1ba5537..0000000
--- a/src/text_recognizer/networks/transformer/positional_encoding.py
+++ /dev/null
@@ -1,32 +0,0 @@
-"""A positional encoding for the image features, as the transformer has no notation of the order of the sequence."""
-import numpy as np
-import torch
-from torch import nn
-from torch import Tensor
-
-
-class PositionalEncoding(nn.Module):
-    """Encodes a sense of distance or time for transformer networks."""
-
-    def __init__(
-        self, hidden_dim: int, dropout_rate: float, max_len: int = 1000
-    ) -> None:
-        super().__init__()
-        self.dropout = nn.Dropout(p=dropout_rate)
-        self.max_len = max_len
-
-        pe = torch.zeros(max_len, hidden_dim)
-        position = torch.arange(0, max_len).unsqueeze(1)
-        div_term = torch.exp(
-            torch.arange(0, hidden_dim, 2) * -(np.log(10000.0) / hidden_dim)
-        )
-
-        pe[:, 0::2] = torch.sin(position * div_term)
-        pe[:, 1::2] = torch.cos(position * div_term)
-        pe = pe.unsqueeze(0)
-        self.register_buffer("pe", pe)
-
-    def forward(self, x: Tensor) -> Tensor:
-        """Encodes the tensor with a postional embedding."""
-        x = x + self.pe[:, : x.shape[1]]
-        return self.dropout(x)
--
cgit v1.2.3-70-g09d2
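
For reference, below is a minimal standalone sketch (not part of the patch) of what the deleted module computed: a fixed sinusoidal table, built once and added to the input feature sequence. The values of hidden_dim, max_len, and the input tensor shape are illustrative assumptions, not values taken from the repository.

# Minimal sketch, not part of the patch: reproduces the sinusoidal table that the
# deleted PositionalEncoding registered as a buffer, and adds it to a dummy batch
# of image features. hidden_dim, max_len, and the input shape are assumptions.
import math

import torch

hidden_dim, max_len = 256, 1000

pe = torch.zeros(max_len, hidden_dim)
position = torch.arange(0, max_len).unsqueeze(1).float()
div_term = torch.exp(
    torch.arange(0, hidden_dim, 2).float() * -(math.log(10000.0) / hidden_dim)
)
pe[:, 0::2] = torch.sin(position * div_term)  # even channels: sine
pe[:, 1::2] = torch.cos(position * div_term)  # odd channels: cosine
pe = pe.unsqueeze(0)  # (1, max_len, hidden_dim), broadcasts over the batch

x = torch.randn(8, 144, hidden_dim)  # (batch, sequence length, hidden_dim)
x = x + pe[:, : x.shape[1]]          # add the encoding up to the sequence length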