summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2023-09-11 22:11:39 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2023-09-11 22:11:39 +0200
commitb2ad1ec306d56bbc319b7b41fbbcff04307425d5 (patch)
tree6ddf8861f9d88adf730ef2e7a63bccae681b46c8
parent3a9ca4a230b59e9025216383664da8ef1780a3a0 (diff)
Remove absolute embedding
-rw-r--r--text_recognizer/network/transformer/embedding/absolute.py29
1 files changed, 0 insertions, 29 deletions
diff --git a/text_recognizer/network/transformer/embedding/absolute.py b/text_recognizer/network/transformer/embedding/absolute.py
deleted file mode 100644
index db34157..0000000
--- a/text_recognizer/network/transformer/embedding/absolute.py
+++ /dev/null
@@ -1,29 +0,0 @@
-from typing import Optional
-
-import torch
-from torch import nn, Tensor
-
-from .l2_norm import l2_norm
-
-
-class AbsolutePositionalEmbedding(nn.Module):
- def __init__(self, dim: int, max_length: int, use_l2: bool = False) -> None:
- super().__init__()
- self.scale = dim**-0.5 if not use_l2 else 1.0
- self.max_length = max_length
- self.use_l2 = use_l2
- self.to_embedding = nn.Embedding(max_length, dim)
- if self.use_l2:
- nn.init.normal_(self.to_embedding.weight, std=1e-5)
-
- def forward(self, x: Tensor, pos: Optional[Tensor] = None) -> Tensor:
- n, device = x.shape[1], x.device
- assert (
- n <= self.max_length
- ), f"Sequence length {n} is greater than the maximum positional embedding {self.max_length}"
-
- if pos is None:
- pos = torch.arange(n, device=device)
-
- embedding = self.to_embedding(pos) * self.scale
- return l2_norm(embedding) if self.use_l2 else embedding