From 09797e2ba757ec3a3545387a0b57c2be1956b38b Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Wed, 27 Oct 2021 22:16:56 +0200
Subject: Add axial embedding

---
 .../networks/transformer/embeddings/axial.py       | 33 ++++++++++++++++++++++
 1 file changed, 33 insertions(+)
 create mode 100644 text_recognizer/networks/transformer/embeddings/axial.py

(limited to 'text_recognizer/networks/transformer')

diff --git a/text_recognizer/networks/transformer/embeddings/axial.py b/text_recognizer/networks/transformer/embeddings/axial.py
new file mode 100644
index 0000000..56f29c5
--- /dev/null
+++ b/text_recognizer/networks/transformer/embeddings/axial.py
@@ -0,0 +1,33 @@
+"""Axial attention for multi-dimensional data.
+
+Stolen from:
+    https://github.com/lucidrains/axial-attention/blob/eff2c10c2e76c735a70a6b995b571213adffbbb7/axial_attention/axial_attention.py#L100
+"""
+from typing import Sequence
+
+import torch
+from torch import nn, Tensor
+
+
+class AxialPositionalEmbedding(nn.Module):
+    """Axial positional embedding."""
+
+    def __init__(self, dim: int, shape: Sequence[int], emb_dim_index: int = 1) -> None:
+        super().__init__()
+        total_dimensions = len(shape) + 2
+        ax_dim_indexes = [i for i in range(1, total_dimensions) if i != emb_dim_index]
+
+        self.num_axials = len(shape)
+
+        for i, (axial_dim, axial_dim_index) in enumerate(zip(shape, ax_dim_indexes)):
+            shape = [1] * total_dimensions
+            shape[emb_dim_index] = dim
+            shape[axial_dim_index] = axial_dim
+            parameter = nn.Parameter(torch.randn(*shape))
+            setattr(self, f"param_{i}", parameter)
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Applies axial positional embedding."""
+        for i in range(self.num_axials):
+            x = x + getattr(self, f"param_{i}")
+        return x
-- 
cgit v1.2.3-70-g09d2