author     Gustaf Rydholm <gustaf.rydholm@gmail.com>    2022-09-05 00:03:34 +0200
committer  Gustaf Rydholm <gustaf.rydholm@gmail.com>    2022-09-05 00:03:34 +0200
commit     03c19e0b51e6dbb5a0343e9d1d1bc18c184a164f (patch)
tree       fafa7e06883f111f5ae9c289b31fa9ddc720d26c /text_recognizer/networks/transformer
parent     73ccaaa24936faed36fcc467532baa5386d402ae (diff)
Add absolute pos embedding
Diffstat (limited to 'text_recognizer/networks/transformer')
-rw-r--r--  text_recognizer/networks/transformer/embeddings/absolute.py  34
1 file changed, 34 insertions, 0 deletions
diff --git a/text_recognizer/networks/transformer/embeddings/absolute.py b/text_recognizer/networks/transformer/embeddings/absolute.py
new file mode 100644
index 0000000..e5cdc18
--- /dev/null
+++ b/text_recognizer/networks/transformer/embeddings/absolute.py
@@ -0,0 +1,34 @@
+"""Absolute positional embedding."""
+
+from einops import rearrange
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+
+def l2norm(t, groups=1):
+ t = rearrange(t, "... (g d) -> ... g d", g=groups)
+ t = F.normalize(t, p=2, dim=-1)
+ return rearrange(t, "... g d -> ... (g d)")
+
+
+class AbsolutePositionalEmbedding(nn.Module):
+ def __init__(self, dim, max_seq_len, l2norm_embed=False):
+ super().__init__()
+ self.scale = dim ** -0.5 if not l2norm_embed else 1.0
+ self.max_seq_len = max_seq_len
+ self.l2norm_embed = l2norm_embed
+ self.emb = nn.Embedding(max_seq_len, dim)
+
+ def forward(self, x, pos=None):
+ seq_len = x.shape[1]
+ assert (
+ seq_len <= self.max_seq_len
+ ), f"you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}"
+
+ if pos is None:
+ pos = torch.arange(seq_len, device=x.device)
+
+ pos_emb = self.emb(pos)
+ pos_emb = pos_emb * self.scale
+ return l2norm(pos_emb) if self.l2norm_embed else pos_emb
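
A minimal usage sketch (not part of the commit) showing how the new embedding might be combined with token embeddings before the transformer layers; the import path mirrors the file added above, while the hyperparameters, batch size, and the token_emb layer are illustrative assumptions.

# Usage sketch, not part of the commit: positions default to 0..seq_len-1
# inside forward(), and the (seq_len, dim) output broadcasts over the batch.
import torch
from torch import nn

from text_recognizer.networks.transformer.embeddings.absolute import (
    AbsolutePositionalEmbedding,
)

dim, max_seq_len, vocab_size = 256, 128, 1000   # assumed hyperparameters
token_emb = nn.Embedding(vocab_size, dim)       # assumed token embedding layer
pos_emb = AbsolutePositionalEmbedding(dim, max_seq_len)

tokens = torch.randint(0, vocab_size, (2, 64))  # (batch, seq_len)
x = token_emb(tokens)                           # (batch, seq_len, dim)
x = x + pos_emb(x)                              # add positional information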