author    Gustaf Rydholm <gustaf.rydholm@gmail.com>    2021-10-27 22:16:56 +0200
committer Gustaf Rydholm <gustaf.rydholm@gmail.com>    2021-10-27 22:16:56 +0200
commit    09797e2ba757ec3a3545387a0b57c2be1956b38b (patch)
tree      6f211bf3a1c88d93b78eb2c52348f716da9a8ec5 /text_recognizer/networks/transformer/embeddings/axial.py
parent    3e24b92ee1bac124ea8c7bddb15236ccc5fe300d (diff)
Add axial embedding
Diffstat (limited to 'text_recognizer/networks/transformer/embeddings/axial.py')
-rw-r--r--  text_recognizer/networks/transformer/embeddings/axial.py  33
1 file changed, 33 insertions(+), 0 deletions(-)
diff --git a/text_recognizer/networks/transformer/embeddings/axial.py b/text_recognizer/networks/transformer/embeddings/axial.py
new file mode 100644
index 0000000..56f29c5
--- /dev/null
+++ b/text_recognizer/networks/transformer/embeddings/axial.py
@@ -0,0 +1,33 @@
+"""Axial attention for multi-dimensional data.
+
+Stolen from:
+ https://github.com/lucidrains/axial-attention/blob/eff2c10c2e76c735a70a6b995b571213adffbbb7/axial_attention/axial_attention.py#L100
+"""
+from typing import Sequence
+
+import torch
+from torch import nn, Tensor
+
+
+class AxialPositionalEmbedding(nn.Module):
+ """Axial positional embedding."""
+
+ def __init__(self, dim: int, shape: Sequence[int], emb_dim_index: int = 1) -> None:
+ super().__init__()
+ total_dimensions = len(shape) + 2
+ ax_dim_indexes = [i for i in range(1, total_dimensions) if i != emb_dim_index]
+
+ self.num_axials = len(shape)
+
+ for i, (axial_dim, axial_dim_index) in enumerate(zip(shape, ax_dim_indexes)):
+ shape = [1] * total_dimensions
+ shape[emb_dim_index] = dim
+ shape[axial_dim_index] = axial_dim
+ parameter = nn.Parameter(torch.randn(*shape))
+ setattr(self, f"param_{i}", parameter)
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Applies axial positional embedding."""
+ for i in range(self.num_axials):
+ x = x + getattr(self, f"param_{i}")
+ return x
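The module stores one learnable parameter per axis instead of a single positional table: for each axial dimension it builds a tensor that is 1 everywhere except at the embedding dimension (`dim`) and at that axis, and `forward` broadcast-adds each of these parameters to the input. The sketch below is not part of the commit; it is a minimal usage example assuming the repository layout above and illustrative sizes (dim=256, a 32x128 feature map).

# Usage sketch (illustrative only; sizes are example values, not from the commit).
import torch

from text_recognizer.networks.transformer.embeddings.axial import (
    AxialPositionalEmbedding,
)

# One parameter per axis: (1, 256, 32, 1) for height and (1, 256, 1, 128) for width.
emb = AxialPositionalEmbedding(dim=256, shape=(32, 128), emb_dim_index=1)

x = torch.randn(4, 256, 32, 128)  # (batch, dim, height, width)
out = emb(x)                      # each axial parameter is broadcast-added to x
assert out.shape == x.shape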