summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/transformer
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks/transformer')
-rw-r--r--text_recognizer/networks/transformer/embeddings/axial.py33
1 files changed, 33 insertions, 0 deletions
diff --git a/text_recognizer/networks/transformer/embeddings/axial.py b/text_recognizer/networks/transformer/embeddings/axial.py
new file mode 100644
index 0000000..56f29c5
--- /dev/null
+++ b/text_recognizer/networks/transformer/embeddings/axial.py
@@ -0,0 +1,33 @@
+"""Axial attention for multi-dimensional data.
+
+Stolen from:
+ https://github.com/lucidrains/axial-attention/blob/eff2c10c2e76c735a70a6b995b571213adffbbb7/axial_attention/axial_attention.py#L100
+"""
+from typing import Sequence
+
+import torch
+from torch import nn, Tensor
+
+
+class AxialPositionalEmbedding(nn.Module):
+ """Axial positional embedding."""
+
+ def __init__(self, dim: int, shape: Sequence[int], emb_dim_index: int = 1) -> None:
+ super().__init__()
+ total_dimensions = len(shape) + 2
+ ax_dim_indexes = [i for i in range(1, total_dimensions) if i != emb_dim_index]
+
+ self.num_axials = len(shape)
+
+ for i, (axial_dim, axial_dim_index) in enumerate(zip(shape, ax_dim_indexes)):
+ shape = [1] * total_dimensions
+ shape[emb_dim_index] = dim
+ shape[axial_dim_index] = axial_dim
+ parameter = nn.Parameter(torch.randn(*shape))
+ setattr(self, f"param_{i}", parameter)
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Applies axial positional embedding."""
+ for i in range(self.num_axials):
+ x = x + getattr(self, f"param_{i}")
+ return x