diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-09-05 00:03:34 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-09-05 00:03:34 +0200 |
commit | 03c19e0b51e6dbb5a0343e9d1d1bc18c184a164f (patch) | |
tree | fafa7e06883f111f5ae9c289b31fa9ddc720d26c /text_recognizer | |
parent | 73ccaaa24936faed36fcc467532baa5386d402ae (diff) |
Add absolute pos embedding
Diffstat (limited to 'text_recognizer')
-rw-r--r-- | text_recognizer/networks/transformer/embeddings/absolute.py | 34 |
1 files changed, 34 insertions, 0 deletions
diff --git a/text_recognizer/networks/transformer/embeddings/absolute.py b/text_recognizer/networks/transformer/embeddings/absolute.py new file mode 100644 index 0000000..e5cdc18 --- /dev/null +++ b/text_recognizer/networks/transformer/embeddings/absolute.py @@ -0,0 +1,34 @@ +"""Absolute positional embedding.""" + +from einops import rearrange +import torch +from torch import nn +import torch.nn.functional as F + + +def l2norm(t, groups=1): + t = rearrange(t, "... (g d) -> ... g d", g=groups) + t = F.normalize(t, p=2, dim=-1) + return rearrange(t, "... g d -> ... (g d)") + + +class AbsolutePositionalEmbedding(nn.Module): + def __init__(self, dim, max_seq_len, l2norm_embed=False): + super().__init__() + self.scale = dim ** -0.5 if not l2norm_embed else 1.0 + self.max_seq_len = max_seq_len + self.l2norm_embed = l2norm_embed + self.emb = nn.Embedding(max_seq_len, dim) + + def forward(self, x, pos=None): + seq_len = x.shape[1] + assert ( + seq_len <= self.max_seq_len + ), f"you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}" + + if pos is None: + pos = torch.arange(seq_len, device=x.device) + + pos_emb = self.emb(pos) + pos_emb = pos_emb * self.scale + return l2norm(pos_emb) if self.l2norm_embed else pos_emb |