From 03c19e0b51e6dbb5a0343e9d1d1bc18c184a164f Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Mon, 5 Sep 2022 00:03:34 +0200 Subject: Add absolute pos embedding --- .../networks/transformer/embeddings/absolute.py | 34 ++++++++++++++++++++++ 1 file changed, 34 insertions(+) create mode 100644 text_recognizer/networks/transformer/embeddings/absolute.py diff --git a/text_recognizer/networks/transformer/embeddings/absolute.py b/text_recognizer/networks/transformer/embeddings/absolute.py new file mode 100644 index 0000000..e5cdc18 --- /dev/null +++ b/text_recognizer/networks/transformer/embeddings/absolute.py @@ -0,0 +1,34 @@ +"""Absolute positional embedding.""" + +from einops import rearrange +import torch +from torch import nn +import torch.nn.functional as F + + +def l2norm(t, groups=1): + t = rearrange(t, "... (g d) -> ... g d", g=groups) + t = F.normalize(t, p=2, dim=-1) + return rearrange(t, "... g d -> ... (g d)") + + +class AbsolutePositionalEmbedding(nn.Module): + def __init__(self, dim, max_seq_len, l2norm_embed=False): + super().__init__() + self.scale = dim ** -0.5 if not l2norm_embed else 1.0 + self.max_seq_len = max_seq_len + self.l2norm_embed = l2norm_embed + self.emb = nn.Embedding(max_seq_len, dim) + + def forward(self, x, pos=None): + seq_len = x.shape[1] + assert ( + seq_len <= self.max_seq_len + ), f"you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}" + + if pos is None: + pos = torch.arange(seq_len, device=x.device) + + pos_emb = self.emb(pos) + pos_emb = pos_emb * self.scale + return l2norm(pos_emb) if self.l2norm_embed else pos_emb -- cgit v1.2.3-70-g09d2