diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-09-13 18:12:13 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-09-13 18:12:13 +0200 |
commit | 7be90f5f101d7ace7ff07180950dac4c11086ec1 (patch) | |
tree | a99c0fc55dd45f8e4eda39a958d68863885cfd3f | |
parent | 12abf17cd7c31ae4599be366505a4423fbba4044 (diff) |
Add axial encoder
5 files changed, 215 insertions, 14 deletions
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py index 365906f..40047ad 100644 --- a/text_recognizer/networks/conv_transformer.py +++ b/text_recognizer/networks/conv_transformer.py @@ -5,7 +5,7 @@ from torch import nn, Tensor from text_recognizer.networks.transformer.decoder import Decoder from text_recognizer.networks.transformer.embeddings.axial import ( - AxialPositionalEmbedding, + AxialPositionalEmbeddingImage, ) @@ -20,8 +20,8 @@ class ConvTransformer(nn.Module): pad_index: Tensor, encoder: Type[nn.Module], decoder: Decoder, - pixel_embedding: AxialPositionalEmbedding, - token_pos_embedding: Optional[Type[nn.Module]] = None, + pixel_embedding: AxialPositionalEmbeddingImage, + token_pos_embedding: Type[nn.Module], ) -> None: super().__init__() self.input_dims = input_dims @@ -37,11 +37,7 @@ class ConvTransformer(nn.Module): ) # Positional encoding for decoder tokens. - if not self.decoder.has_pos_emb: - self.token_pos_embedding = token_pos_embedding - else: - self.token_pos_embedding = None - + self.token_pos_embedding = token_pos_embedding self.pixel_embedding = pixel_embedding # Latent projector for down sampling number of filters and 2d @@ -83,7 +79,7 @@ class ConvTransformer(nn.Module): """ z = self.encoder(x) z = self.conv(z) - z = self.pixel_embedding(z) + z += self.pixel_embedding(z) z = z.flatten(start_dim=2) # Permute tensor from [B, E, Ho * Wo] to [B, Sx, E] @@ -110,11 +106,7 @@ class ConvTransformer(nn.Module): trg = trg.long() trg_mask = trg != self.pad_index trg = self.token_embedding(trg) - trg = ( - self.token_pos_embedding(trg) - if self.token_pos_embedding is not None - else trg - ) + trg += self.token_pos_embedding(trg) out = self.decoder(x=trg, context=src, input_mask=trg_mask) logits = self.to_logits(out) # [B, Sy, C] logits = logits.permute(0, 2, 1) # [B, C, Sy] diff --git a/text_recognizer/networks/transformer/axial_attention/__init__.py b/text_recognizer/networks/transformer/axial_attention/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/text_recognizer/networks/transformer/axial_attention/__init__.py diff --git a/text_recognizer/networks/transformer/axial_attention/encoder.py b/text_recognizer/networks/transformer/axial_attention/encoder.py new file mode 100644 index 0000000..1cadac1 --- /dev/null +++ b/text_recognizer/networks/transformer/axial_attention/encoder.py @@ -0,0 +1,90 @@ +"""Axial transformer encoder.""" + +from typing import List, Optional, Type +from text_recognizer.networks.transformer.embeddings.axial import ( + AxialPositionalEmbeddingImage, +) + +from torch import nn, Tensor + +from text_recognizer.networks.transformer.axial_attention.self_attention import ( + SelfAttention, +) +from text_recognizer.networks.transformer.axial_attention.utils import ( + calculate_permutations, + PermuteToForm, + Sequential, +) +from text_recognizer.networks.transformer.norm import PreNorm + + +class AxialEncoder(nn.Module): + """Axial transfomer encoder.""" + + def __init__( + self, + shape: List[int], + dim: int, + depth: int, + heads: int, + dim_head: int, + dim_index: int, + axial_embedding: AxialPositionalEmbeddingImage, + ) -> None: + super().__init__() + + self.shape = shape + self.dim = dim + self.depth = depth + self.heads = heads + self.dim_head = dim_head + self.dim_index = dim_index + self.axial_embedding = axial_embedding + + self.fn = self._build() + + def _build(self) -> Sequential: + permutations = calculate_permutations(2, self.dim_index) + get_ff = lambda: nn.Sequential( + nn.LayerNorm([self.dim, *self.shape]), + nn.Conv2d( + in_channels=self.dim, + out_channels=4 * self.dim, + kernel_size=3, + padding=1, + ), + nn.Mish(inplace=True), + nn.Conv2d( + in_channels=4 * self.dim, + out_channels=self.dim, + kernel_size=3, + padding=1, + ), + ) + + layers = nn.ModuleList([]) + for _ in range(self.depth): + attns = nn.ModuleList( + [ + PermuteToForm( + permutation=permutation, + fn=PreNorm( + self.dim, + SelfAttention( + dim=self.dim, heads=self.heads, dim_head=self.dim_head + ), + ), + ) + for permutation in permutations + ] + ) + convs = nn.ModuleList([get_ff(), get_ff()]) + layers.append(attns) + layers.append(convs) + + return Sequential(layers) + + def forward(self, x: Tensor) -> Tensor: + """Applies fn to input.""" + x += self.axial_embedding(x) + return self.fn(x) diff --git a/text_recognizer/networks/transformer/axial_attention/self_attention.py b/text_recognizer/networks/transformer/axial_attention/self_attention.py new file mode 100644 index 0000000..b5e4142 --- /dev/null +++ b/text_recognizer/networks/transformer/axial_attention/self_attention.py @@ -0,0 +1,40 @@ +"""Axial self attention module.""" + +import torch +from torch import nn +from torch import Tensor + + +class SelfAttention(nn.Module): + """Axial self attention module.""" + + def __init__( + self, + dim: int, + dim_head: int, + heads: int, + ) -> None: + super().__init__() + self.dim_hidden = heads * dim_head + self.heads = heads + self.dim_head = dim_head + self.to_q = nn.Linear(dim, self.dim_hidden, bias=False) + self.to_kv = nn.Linear(dim, 2 * self.dim_hidden, bias=False) + self.to_out = nn.Linear(self.dim_hidden, dim) + + def forward(self, x: Tensor) -> Tensor: + """Applies self attention.""" + q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim=-1)) + b, _, d, h, e = *q.shape, self.heads, self.dim_head + + merge_heads = ( + lambda x: x.reshape(b, -1, h, e).transpose(1, 2).reshape(b * h, -1, e) + ) + q, k, v = map(merge_heads, (q, k, v)) + + energy = torch.einsum("bie,bje->bij", q, k) * (e ** -0.5) + energy = energy.softmax(dim=-1) + attn = torch.einsum("bij,bje->bie", energy, v) + + out = attn.reshape(b, h, -1, e).transpose(1, 2).reshape(b, -1, d) + return self.to_out(out) diff --git a/text_recognizer/networks/transformer/axial_attention/utils.py b/text_recognizer/networks/transformer/axial_attention/utils.py new file mode 100644 index 0000000..2f5bf7e --- /dev/null +++ b/text_recognizer/networks/transformer/axial_attention/utils.py @@ -0,0 +1,79 @@ +"""Helper functions for axial attention.""" +from operator import itemgetter +from typing import Callable, List, Tuple + +from torch import nn, Tensor + + +def _map_el_ind(arr: Tensor, ind: int) -> List: + return list(map(itemgetter(ind), arr)) + + +def _sort_indices(arr: Tensor) -> Tuple[List[int], List[int]]: + indices = [i for i in range(len(arr))] + arr = zip(arr, indices) + arr = sorted(arr) + return _map_el_ind(arr, 0), _map_el_ind(arr, 1) + + +def calculate_permutations(num_dims: int, emb_dim: int) -> List[List[int]]: + """Returns permutations of tensor.""" + total_dims = num_dims + 2 + axial_dims = [i for i in range(1, total_dims) if i != emb_dim] + + permutations = [] + + for axial_dim in axial_dims: + last_two_dims = [axial_dim, emb_dim] + dims_rest = set(range(0, total_dims)) - set(last_two_dims) + permutation = [*dims_rest, *last_two_dims] + permutations.append(permutation) + + return permutations + + +class PermuteToForm(nn.Module): + """Helper class for applying axial attention.""" + + def __init__( + self, + fn: Callable, + permutation: List[List[int]], + ) -> None: + super().__init__() + + self.fn = fn + self.permutation = permutation + _, self.inv_permutation = _sort_indices(self.permutation) + + def forward(self, x: Tensor) -> Tensor: + """Permutes tensor, applies axial attention, permutes tensor back.""" + x = x.permute(*self.permutation).contiguous() + shape = x.shape + *_, t, d = shape + + # Merge all but axial dimension + x = x.reshape(-1, t, d) + + # Apply attention + x = self.fn(x) + + # Restore original shape and permutation + x = x.reshape(*shape) + x = x.permute(*self.inv_permutation).contiguous() + return x + + +class Sequential(nn.Module): + """Applies a list of paired functions to input.""" + + def __init__(self, fns: nn.ModuleList) -> None: + super().__init__() + self.fns = fns + + def forward(self, x: Tensor) -> Tensor: + """Applies blocks to input.""" + for f, g in self.fns: + x = x + f(x) + x = x + g(x) + return x |