author    Gustaf Rydholm <gustaf.rydholm@gmail.com>    2022-09-13 18:12:13 +0200
committer Gustaf Rydholm <gustaf.rydholm@gmail.com>    2022-09-13 18:12:13 +0200
commit    7be90f5f101d7ace7ff07180950dac4c11086ec1 (patch)
tree      a99c0fc55dd45f8e4eda39a958d68863885cfd3f /text_recognizer/networks/transformer/axial_attention/utils.py
parent    12abf17cd7c31ae4599be366505a4423fbba4044 (diff)
Add axial encoder
Diffstat (limited to 'text_recognizer/networks/transformer/axial_attention/utils.py')
-rw-r--r--  text_recognizer/networks/transformer/axial_attention/utils.py  |  79
1 file changed, 79 insertions(+), 0 deletions(-)
diff --git a/text_recognizer/networks/transformer/axial_attention/utils.py b/text_recognizer/networks/transformer/axial_attention/utils.py
new file mode 100644
index 0000000..2f5bf7e
--- /dev/null
+++ b/text_recognizer/networks/transformer/axial_attention/utils.py
@@ -0,0 +1,79 @@
+"""Helper functions for axial attention."""
+from operator import itemgetter
+from typing import Callable, List, Tuple
+
+from torch import nn, Tensor
+
+
+def _map_el_ind(arr: List[Tuple[int, int]], ind: int) -> List[int]:
+    """Extracts element ``ind`` from every tuple in ``arr``."""
+    return list(map(itemgetter(ind), arr))
+
+
+def _sort_indices(arr: List[int]) -> Tuple[List[int], List[int]]:
+    """Sorts ``arr``, returning the sorted values and their argsort indices."""
+    pairs = sorted(zip(arr, range(len(arr))))
+    return _map_el_ind(pairs, 0), _map_el_ind(pairs, 1)
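+
+# For illustration (hypothetical input): _sort_indices([0, 2, 1, 3]) returns
+# ([0, 1, 2, 3], [0, 2, 1, 3]). The second list is the argsort of the input,
+# which for a permutation is its inverse.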
+
+
+def calculate_permutations(num_dims: int, emb_dim: int) -> List[List[int]]:
+    """Returns one permutation per axial dimension.
+
+    Each permutation moves one axial dimension and the embedding dimension to
+    the last two positions, leaving the remaining dimensions in front.
+    """
+    total_dims = num_dims + 2  # The axial dims plus batch and embedding dims.
+    axial_dims = [i for i in range(1, total_dims) if i != emb_dim]
+
+    permutations = []
+
+    for axial_dim in axial_dims:
+        last_two_dims = [axial_dim, emb_dim]
+        # Sorted for determinism; set iteration order is not guaranteed.
+        dims_rest = sorted(set(range(total_dims)) - set(last_two_dims))
+        permutation = [*dims_rest, *last_two_dims]
+        permutations.append(permutation)
+
+    return permutations
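+
+# For illustration (hypothetical shapes): for a channels-last image batch
+# (batch, height, width, dim), calculate_permutations(num_dims=2, emb_dim=3)
+# returns [[0, 2, 1, 3], [0, 1, 2, 3]]: one permutation for attention along
+# height and one for attention along width.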
+
+
+class PermuteToForm(nn.Module):
+    """Permutes the input so that ``fn`` attends along a single axis."""
+
+    def __init__(
+        self,
+        fn: Callable[[Tensor], Tensor],
+        permutation: List[int],
+    ) -> None:
+        super().__init__()
+
+        self.fn = fn
+        self.permutation = permutation
+        # The argsort of a permutation is its inverse permutation.
+        _, self.inv_permutation = _sort_indices(self.permutation)
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Permutes the tensor, applies axial attention, and permutes it back."""
+        x = x.permute(*self.permutation).contiguous()
+        shape = x.shape
+        *_, t, d = shape
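+        # Hypothetical walkthrough: for x of shape (B, H, W, D) with
+        # permutation [0, 2, 1, 3], x is now (B, W, H, D), so t = H and d = D.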
+
+        # Merge all leading dims into the batch, keeping the axial (t) and
+        # embedding (d) dims for attention.
+        x = x.reshape(-1, t, d)
+
+        # Apply attention along the axial dimension.
+        x = self.fn(x)
+
+        # Restore the original shape and permutation.
+        x = x.reshape(*shape)
+        x = x.permute(*self.inv_permutation).contiguous()
+        return x
+
+
+class Sequential(nn.Module):
+    """Applies pairs of residual functions to the input in sequence."""
+
+    def __init__(self, fns: nn.ModuleList) -> None:
+        super().__init__()
+        self.fns = fns
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Adds the output of each (f, g) pair back onto its input."""
+        for f, g in self.fns:
+            x = x + f(x)
+            x = x + g(x)
+        return x
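
For context, a minimal sketch (not part of the commit) of how these helpers
might compose into an axial attention block. ToyAttention, the feed-forward
pairing, and all shapes below are hypothetical stand-ins for the repository's
real attention module:

import torch
from torch import nn, Tensor

from text_recognizer.networks.transformer.axial_attention.utils import (
    PermuteToForm,
    Sequential,
    calculate_permutations,
)


class ToyAttention(nn.Module):
    """Hypothetical stand-in for self-attention over (batch, t, dim) inputs."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads=4, batch_first=True)

    def forward(self, x: Tensor) -> Tensor:
        out, _ = self.attn(x, x, x)
        return out


dim = 64  # Channels-last input: (batch, height, width, dim), i.e. emb_dim=3.
blocks = nn.ModuleList(
    nn.ModuleList(
        [
            PermuteToForm(ToyAttention(dim), permutation),
            nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, dim)),
        ]
    )
    for permutation in calculate_permutations(num_dims=2, emb_dim=3)
)
axial = Sequential(blocks)

x = torch.randn(2, 32, 32, dim)
assert axial(x).shape == x.shape  # Attention along height, then along width.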