-rw-r--r--  text_recognizer/networks/conv_transformer.py                             20
-rw-r--r--  text_recognizer/networks/transformer/axial_attention/__init__.py          0
-rw-r--r--  text_recognizer/networks/transformer/axial_attention/encoder.py          90
-rw-r--r--  text_recognizer/networks/transformer/axial_attention/self_attention.py   40
-rw-r--r--  text_recognizer/networks/transformer/axial_attention/utils.py            79
5 files changed, 215 insertions, 14 deletions
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py
index 365906f..40047ad 100644
--- a/text_recognizer/networks/conv_transformer.py
+++ b/text_recognizer/networks/conv_transformer.py
@@ -5,7 +5,7 @@ from torch import nn, Tensor
 from text_recognizer.networks.transformer.decoder import Decoder
 from text_recognizer.networks.transformer.embeddings.axial import (
-    AxialPositionalEmbedding,
+    AxialPositionalEmbeddingImage,
 )
@@ -20,8 +20,8 @@ class ConvTransformer(nn.Module):
         pad_index: Tensor,
         encoder: Type[nn.Module],
         decoder: Decoder,
-        pixel_embedding: AxialPositionalEmbedding,
-        token_pos_embedding: Optional[Type[nn.Module]] = None,
+        pixel_embedding: AxialPositionalEmbeddingImage,
+        token_pos_embedding: Type[nn.Module],
     ) -> None:
         super().__init__()
         self.input_dims = input_dims
@@ -37,11 +37,7 @@ class ConvTransformer(nn.Module):
         )
         # Positional encoding for decoder tokens.
-        if not self.decoder.has_pos_emb:
-            self.token_pos_embedding = token_pos_embedding
-        else:
-            self.token_pos_embedding = None
-
+        self.token_pos_embedding = token_pos_embedding
         self.pixel_embedding = pixel_embedding
         # Latent projector for down sampling number of filters and 2d
@@ -83,7 +79,7 @@ class ConvTransformer(nn.Module):
"""
z = self.encoder(x)
z = self.conv(z)
- z = self.pixel_embedding(z)
+ z += self.pixel_embedding(z)
z = z.flatten(start_dim=2)
# Permute tensor from [B, E, Ho * Wo] to [B, Sx, E]
@@ -110,11 +106,7 @@ class ConvTransformer(nn.Module):
         trg = trg.long()
         trg_mask = trg != self.pad_index
         trg = self.token_embedding(trg)
-        trg = (
-            self.token_pos_embedding(trg)
-            if self.token_pos_embedding is not None
-            else trg
-        )
+        trg += self.token_pos_embedding(trg)
         out = self.decoder(x=trg, context=src, input_mask=trg_mask)
         logits = self.to_logits(out) # [B, Sy, C]
         logits = logits.permute(0, 2, 1) # [B, C, Sy]
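
Note on the hunks above: both positional embeddings are now applied additively (z += self.pixel_embedding(z), trg += self.token_pos_embedding(trg)), so each embedding module is expected to return only the positional term rather than the already-summed tensor, and token_pos_embedding is no longer optional. A minimal sketch of the additive flow, using a hypothetical stand-in module in place of the repo's AxialPositionalEmbeddingImage:

import torch
from torch import nn


class ToyPixelEmbedding(nn.Module):
    """Hypothetical stand-in: returns a positional term shaped like its input."""

    def __init__(self, dim: int, height: int, width: int) -> None:
        super().__init__()
        self.emb = nn.Parameter(torch.randn(1, dim, height, width))

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        # Return only the positional term; the caller adds it to z.
        return self.emb


pixel_embedding = ToyPixelEmbedding(dim=128, height=7, width=32)
z = torch.randn(2, 128, 7, 32)  # [B, E, Ho, Wo] after the conv projector
z = z + pixel_embedding(z)  # additive, as in the new forward path
z = z.flatten(start_dim=2).permute(0, 2, 1)  # [B, Sx, E]
print(z.shape)  # torch.Size([2, 224, 128])
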
diff --git a/text_recognizer/networks/transformer/axial_attention/__init__.py b/text_recognizer/networks/transformer/axial_attention/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/text_recognizer/networks/transformer/axial_attention/__init__.py
diff --git a/text_recognizer/networks/transformer/axial_attention/encoder.py b/text_recognizer/networks/transformer/axial_attention/encoder.py
new file mode 100644
index 0000000..1cadac1
--- /dev/null
+++ b/text_recognizer/networks/transformer/axial_attention/encoder.py
@@ -0,0 +1,90 @@
+"""Axial transformer encoder."""
+
+from typing import List, Optional, Type
+from text_recognizer.networks.transformer.embeddings.axial import (
+    AxialPositionalEmbeddingImage,
+)
+
+from torch import nn, Tensor
+
+from text_recognizer.networks.transformer.axial_attention.self_attention import (
+    SelfAttention,
+)
+from text_recognizer.networks.transformer.axial_attention.utils import (
+    calculate_permutations,
+    PermuteToForm,
+    Sequential,
+)
+from text_recognizer.networks.transformer.norm import PreNorm
+
+
+class AxialEncoder(nn.Module):
+ """Axial transfomer encoder."""
+
+    def __init__(
+        self,
+        shape: List[int],
+        dim: int,
+        depth: int,
+        heads: int,
+        dim_head: int,
+        dim_index: int,
+        axial_embedding: AxialPositionalEmbeddingImage,
+    ) -> None:
+        super().__init__()
+
+        self.shape = shape
+        self.dim = dim
+        self.depth = depth
+        self.heads = heads
+        self.dim_head = dim_head
+        self.dim_index = dim_index
+        self.axial_embedding = axial_embedding
+
+        self.fn = self._build()
+
+    def _build(self) -> Sequential:
+        permutations = calculate_permutations(2, self.dim_index)
+        get_ff = lambda: nn.Sequential(
+            nn.LayerNorm([self.dim, *self.shape]),
+            nn.Conv2d(
+                in_channels=self.dim,
+                out_channels=4 * self.dim,
+                kernel_size=3,
+                padding=1,
+            ),
+            nn.Mish(inplace=True),
+            nn.Conv2d(
+                in_channels=4 * self.dim,
+                out_channels=self.dim,
+                kernel_size=3,
+                padding=1,
+            ),
+        )
+
+        layers = nn.ModuleList([])
+        for _ in range(self.depth):
+            attns = nn.ModuleList(
+                [
+                    PermuteToForm(
+                        permutation=permutation,
+                        fn=PreNorm(
+                            self.dim,
+                            SelfAttention(
+                                dim=self.dim, heads=self.heads, dim_head=self.dim_head
+                            ),
+                        ),
+                    )
+                    for permutation in permutations
+                ]
+            )
+            convs = nn.ModuleList([get_ff(), get_ff()])
+            layers.append(attns)
+            layers.append(convs)
+
+        return Sequential(layers)
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Applies fn to input."""
+        x += self.axial_embedding(x)
+        return self.fn(x)
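
Usage sketch for the new encoder: each of the depth blocks applies two axial self-attention layers (one per spatial axis) and two convolutional feed-forward layers, all combined residually by Sequential. The example below assumes a [B, dim, H, W] feature map with dim_index=1, assumes PreNorm(dim, fn) layer-normalizes the last (feature) dimension before calling fn, and substitutes a hypothetical stand-in for AxialPositionalEmbeddingImage, since only the returned tensor matters here:

import torch
from torch import nn

from text_recognizer.networks.transformer.axial_attention.encoder import AxialEncoder


class ToyAxialEmbedding(nn.Module):
    """Hypothetical stand-in for AxialPositionalEmbeddingImage."""

    def __init__(self, dim: int, shape: list) -> None:
        super().__init__()
        self.emb = nn.Parameter(torch.randn(1, dim, *shape))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Return only the positional term; AxialEncoder.forward adds it to x.
        return self.emb


encoder = AxialEncoder(
    shape=[7, 32],  # [H, W] of the feature map; must match nn.LayerNorm([dim, *shape])
    dim=128,
    depth=2,
    heads=4,
    dim_head=32,
    dim_index=1,  # channel dimension of [B, C, H, W]
    axial_embedding=ToyAxialEmbedding(128, [7, 32]),
)
x = torch.randn(2, 128, 7, 32)
print(encoder(x).shape)  # torch.Size([2, 128, 7, 32]); the shape is preserved
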
diff --git a/text_recognizer/networks/transformer/axial_attention/self_attention.py b/text_recognizer/networks/transformer/axial_attention/self_attention.py
new file mode 100644
index 0000000..b5e4142
--- /dev/null
+++ b/text_recognizer/networks/transformer/axial_attention/self_attention.py
@@ -0,0 +1,40 @@
+"""Axial self attention module."""
+
+import torch
+from torch import nn
+from torch import Tensor
+
+
+class SelfAttention(nn.Module):
+ """Axial self attention module."""
+
+ def __init__(
+ self,
+ dim: int,
+ dim_head: int,
+ heads: int,
+ ) -> None:
+ super().__init__()
+ self.dim_hidden = heads * dim_head
+ self.heads = heads
+ self.dim_head = dim_head
+ self.to_q = nn.Linear(dim, self.dim_hidden, bias=False)
+ self.to_kv = nn.Linear(dim, 2 * self.dim_hidden, bias=False)
+ self.to_out = nn.Linear(self.dim_hidden, dim)
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Applies self attention."""
+ q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim=-1))
+ b, _, d, h, e = *q.shape, self.heads, self.dim_head
+
+ merge_heads = (
+ lambda x: x.reshape(b, -1, h, e).transpose(1, 2).reshape(b * h, -1, e)
+ )
+ q, k, v = map(merge_heads, (q, k, v))
+
+ energy = torch.einsum("bie,bje->bij", q, k) * (e ** -0.5)
+ energy = energy.softmax(dim=-1)
+ attn = torch.einsum("bij,bje->bie", energy, v)
+
+ out = attn.reshape(b, h, -1, e).transpose(1, 2).reshape(b, -1, d)
+ return self.to_out(out)
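
Shape trace for SelfAttention: the heads are folded into the batch dimension, so attention runs over whatever axis PermuteToForm placed second to last. A small example with dim=128, heads=4, dim_head=32 and an already-flattened input of shape [b, t, dim]:

import torch

from text_recognizer.networks.transformer.axial_attention.self_attention import (
    SelfAttention,
)

attn = SelfAttention(dim=128, dim_head=32, heads=4)
x = torch.randn(6, 7, 128)  # [b, t, dim]; b already contains the merged non-axial dims
out = attn(x)
print(out.shape)  # torch.Size([6, 7, 128])

# Internal shapes:
#   q, k, v      -> [6, 7, 128]   (dim_hidden = heads * dim_head = 128)
#   merge_heads  -> [24, 7, 32]   (heads folded into the batch dimension)
#   energy       -> [24, 7, 7]    (scaled by dim_head ** -0.5, softmax over last axis)
#   attn @ v     -> [24, 7, 32]   (reshaped back to [6, 7, 128] before to_out)
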
diff --git a/text_recognizer/networks/transformer/axial_attention/utils.py b/text_recognizer/networks/transformer/axial_attention/utils.py
new file mode 100644
index 0000000..2f5bf7e
--- /dev/null
+++ b/text_recognizer/networks/transformer/axial_attention/utils.py
@@ -0,0 +1,79 @@
+"""Helper functions for axial attention."""
+from operator import itemgetter
+from typing import Callable, List, Tuple
+
+from torch import nn, Tensor
+
+
+def _map_el_ind(arr: Tensor, ind: int) -> List:
+ return list(map(itemgetter(ind), arr))
+
+
+def _sort_indices(arr: Tensor) -> Tuple[List[int], List[int]]:
+ indices = [i for i in range(len(arr))]
+ arr = zip(arr, indices)
+ arr = sorted(arr)
+ return _map_el_ind(arr, 0), _map_el_ind(arr, 1)
+
+
+def calculate_permutations(num_dims: int, emb_dim: int) -> List[List[int]]:
+ """Returns permutations of tensor."""
+ total_dims = num_dims + 2
+ axial_dims = [i for i in range(1, total_dims) if i != emb_dim]
+
+ permutations = []
+
+ for axial_dim in axial_dims:
+ last_two_dims = [axial_dim, emb_dim]
+ dims_rest = set(range(0, total_dims)) - set(last_two_dims)
+ permutation = [*dims_rest, *last_two_dims]
+ permutations.append(permutation)
+
+ return permutations
+
+
+class PermuteToForm(nn.Module):
+ """Helper class for applying axial attention."""
+
+ def __init__(
+ self,
+ fn: Callable,
+ permutation: List[List[int]],
+ ) -> None:
+ super().__init__()
+
+ self.fn = fn
+ self.permutation = permutation
+ _, self.inv_permutation = _sort_indices(self.permutation)
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Permutes tensor, applies axial attention, permutes tensor back."""
+ x = x.permute(*self.permutation).contiguous()
+ shape = x.shape
+ *_, t, d = shape
+
+ # Merge all but axial dimension
+ x = x.reshape(-1, t, d)
+
+ # Apply attention
+ x = self.fn(x)
+
+ # Restore original shape and permutation
+ x = x.reshape(*shape)
+ x = x.permute(*self.inv_permutation).contiguous()
+ return x
+
+
+class Sequential(nn.Module):
+ """Applies a list of paired functions to input."""
+
+ def __init__(self, fns: nn.ModuleList) -> None:
+ super().__init__()
+ self.fns = fns
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Applies blocks to input."""
+ for f, g in self.fns:
+ x = x + f(x)
+ x = x + g(x)
+ return x
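
Worked example of the helpers above, assuming the usual [B, C, H, W] feature-map layout (two spatial dims, embedding at dim 1): calculate_permutations produces one permutation per spatial axis that moves that axis next to the channel dim, and _sort_indices recovers the inverse permutation that PermuteToForm uses to restore the original layout:

from text_recognizer.networks.transformer.axial_attention.utils import (
    _sort_indices,
    calculate_permutations,
)

perms = calculate_permutations(num_dims=2, emb_dim=1)
print(perms)  # [[0, 3, 2, 1], [0, 2, 3, 1]]
# [0, 3, 2, 1] turns [B, C, H, W] into [B, W, H, C]: attention runs along H.
# [0, 2, 3, 1] turns [B, C, H, W] into [B, H, W, C]: attention runs along W.

print(_sort_indices([0, 2, 3, 1]))  # ([0, 1, 2, 3], [0, 3, 1, 2])
# The second list is the inverse permutation: permuting by [0, 2, 3, 1] and then
# by [0, 3, 1, 2] restores the original dimension order.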