diff options
Diffstat (limited to 'text_recognizer/networks/transformer/embeddings')
-rw-r--r-- | text_recognizer/networks/transformer/embeddings/absolute.py | 34 |
1 files changed, 34 insertions, 0 deletions
diff --git a/text_recognizer/networks/transformer/embeddings/absolute.py b/text_recognizer/networks/transformer/embeddings/absolute.py new file mode 100644 index 0000000..e5cdc18 --- /dev/null +++ b/text_recognizer/networks/transformer/embeddings/absolute.py @@ -0,0 +1,34 @@ +"""Absolute positional embedding.""" + +from einops import rearrange +import torch +from torch import nn +import torch.nn.functional as F + + +def l2norm(t, groups=1): + t = rearrange(t, "... (g d) -> ... g d", g=groups) + t = F.normalize(t, p=2, dim=-1) + return rearrange(t, "... g d -> ... (g d)") + + +class AbsolutePositionalEmbedding(nn.Module): + def __init__(self, dim, max_seq_len, l2norm_embed=False): + super().__init__() + self.scale = dim ** -0.5 if not l2norm_embed else 1.0 + self.max_seq_len = max_seq_len + self.l2norm_embed = l2norm_embed + self.emb = nn.Embedding(max_seq_len, dim) + + def forward(self, x, pos=None): + seq_len = x.shape[1] + assert ( + seq_len <= self.max_seq_len + ), f"you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}" + + if pos is None: + pos = torch.arange(seq_len, device=x.device) + + pos_emb = self.emb(pos) + pos_emb = pos_emb * self.scale + return l2norm(pos_emb) if self.l2norm_embed else pos_emb |