diff options
Diffstat (limited to 'text_recognizer/networks/transformer/embeddings/absolute.py')
-rw-r--r-- | text_recognizer/networks/transformer/embeddings/absolute.py | 34 |
1 file changed, 0 insertions, 34 deletions
"""Absolute positional embedding."""

import torch
import torch.nn.functional as F
from torch import nn


def l2norm(t, groups=1):
    """L2-normalize ``t`` along its last dimension, independently per group.

    The last dimension is viewed as ``groups`` contiguous, equal-sized
    chunks; each chunk is normalized to unit L2 norm and the chunks are
    then flattened back into a single feature dimension.

    Args:
        t: Tensor whose last dimension is divisible by ``groups``.
        groups: Number of equal-sized groups to normalize separately.

    Returns:
        Tensor with the same shape as ``t``.
    """
    # Split the feature dim into (groups, dim // groups) — a plain reshape,
    # no third-party einops needed.
    grouped = t.reshape(*t.shape[:-1], groups, -1)
    grouped = F.normalize(grouped, p=2, dim=-1)
    # Flatten the group axis back into the feature dimension.
    return grouped.reshape(*t.shape)


class AbsolutePositionalEmbedding(nn.Module):
    """Learned absolute positional embedding.

    Looks up a learned embedding per sequence position. When
    ``l2norm_embed`` is False the embedding is scaled by ``dim ** -0.5``
    (standard transformer scaling); when True the embedding is instead
    L2-normalized and left unscaled.
    """

    def __init__(self, dim, max_seq_len, l2norm_embed=False):
        """Create the embedding table.

        Args:
            dim: Embedding dimension.
            max_seq_len: Maximum supported sequence length.
            l2norm_embed: If True, l2-normalize outputs instead of scaling.
        """
        super().__init__()
        # L2-normalized embeddings need no extra magnitude scaling.
        self.scale = 1.0 if l2norm_embed else dim**-0.5
        self.max_seq_len = max_seq_len
        self.l2norm_embed = l2norm_embed
        self.emb = nn.Embedding(max_seq_len, dim)

    def forward(self, x, pos=None):
        """Return positional embeddings for the sequence in ``x``.

        Args:
            x: Tensor of shape (batch, seq_len, ...); only ``x.shape[1]``
                and ``x.device`` are read.
            pos: Optional explicit position indices; defaults to
                ``arange(seq_len)`` on ``x.device``.

        Returns:
            Tensor of shape (seq_len, dim) (or ``pos.shape + (dim,)`` when
            ``pos`` is given), scaled or l2-normalized per ``l2norm_embed``.
        """
        seq_len = x.shape[1]
        # NOTE: kept as an assert (not ValueError) for backward compatibility
        # with callers that may catch AssertionError; message unchanged.
        assert (
            seq_len <= self.max_seq_len
        ), f"you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}"

        if pos is None:
            pos = torch.arange(seq_len, device=x.device)

        pos_emb = self.emb(pos)
        pos_emb = pos_emb * self.scale
        return l2norm(pos_emb) if self.l2norm_embed else pos_emb