author     Gustaf Rydholm <gustaf.rydholm@gmail.com>   2022-09-05 00:04:09 +0200
committer  Gustaf Rydholm <gustaf.rydholm@gmail.com>   2022-09-05 00:04:09 +0200
commit     281c8602b4d945cf329d5bead104729acf47ed9c (patch)
tree       520bb301ab41c94a713828fc1856c1a0ab37a03c
parent     03c19e0b51e6dbb5a0343e9d1d1bc18c184a164f (diff)
Steal lucidrains axial encoding
-rw-r--r--  text_recognizer/networks/transformer/embeddings/axial.py  100
1 file changed, 78 insertions, 22 deletions
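
The committed module implements axial positional encoding: instead of learning one embedding per grid position, it learns one embedding per axis value and combines them (summed when no per-axis dims are given, otherwise concatenated), so an (h, w) grid needs h + w learned vectors rather than h * w. A minimal sketch of the summed variant, independent of the code below and using illustrative names only:

import torch

h, w, dim = 4, 6, 8
emb_h = torch.randn(h, 1, dim)  # one learned vector per row, broadcast over columns
emb_w = torch.randn(1, w, dim)  # one learned vector per column, broadcast over rows
pos = (emb_h + emb_w).reshape(h * w, dim)  # (24, 8): a positional vector for every grid cell
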
diff --git a/text_recognizer/networks/transformer/embeddings/axial.py b/text_recognizer/networks/transformer/embeddings/axial.py
index 56f29c5..7b84e12 100644
--- a/text_recognizer/networks/transformer/embeddings/axial.py
+++ b/text_recognizer/networks/transformer/embeddings/axial.py
@@ -3,31 +3,87 @@
Stolen from:
https://github.com/lucidrains/axial-attention/blob/eff2c10c2e76c735a70a6b995b571213adffbbb7/axial_attention/axial_attention.py#L100
"""
-from typing import Sequence
-
import torch
-from torch import nn, Tensor
+from torch import nn
+from operator import mul
+from functools import reduce
class AxialPositionalEmbedding(nn.Module):
- """Axial positional embedding."""
+ def __init__(self, dim, axial_shape, axial_dims=None):
+ super().__init__()
+
+ self.dim = dim
+ self.shape = axial_shape
+ self.max_seq_len = reduce(mul, axial_shape, 1)
+
+ self.summed = axial_dims is None
+ axial_dims = ((dim,) * len(axial_shape)) if self.summed else axial_dims
+
+ assert len(self.shape) == len(
+ axial_dims
+ ), "number of axial dimensions must equal the number of dimensions in the shape"
+ assert (
+ self.summed or not self.summed and sum(axial_dims) == dim
+ ), f"axial dimensions must sum up to the target dimension {dim}"
+
+ self.weights = ParameterList(self, "weights", len(axial_shape))
+
+ for ind, (shape, axial_dim) in enumerate(zip(self.shape, axial_dims)):
+ ax_shape = [1] * len(self.shape)
+ ax_shape[ind] = shape
+ ax_shape = (1, *ax_shape, axial_dim)
+ ax_emb = nn.Parameter(torch.zeros(ax_shape).normal_(0, 1))
+ self.weights.append(ax_emb)
+
+ def forward(self, x):
+ b, t, _ = x.shape
+ assert (
+ t <= self.max_seq_len
+        ), f"Sequence length ({t}) must be less than or equal to the maximum sequence length ({self.max_seq_len})"
+ embs = []
+
+ for ax_emb in self.weights.to_list():
+ axial_dim = ax_emb.shape[-1]
+ expand_shape = (b, *self.shape, axial_dim)
+ emb = ax_emb.expand(expand_shape).reshape(b, self.max_seq_len, axial_dim)
+ embs.append(emb)
- def __init__(self, dim: int, shape: Sequence[int], emb_dim_index: int = 1) -> None:
+ pos_emb = sum(embs) if self.summed else torch.cat(embs, dim=-1)
+ return pos_emb[:, :t].to(x)
+
+
+# a mock parameter list object until below issue is resolved
+# https://github.com/pytorch/pytorch/issues/36035
+class ParameterList(object):
+ def __init__(self, kls, prefix, length):
+ self.ind = 0
+ self.kls = kls
+ self.prefix = prefix
+ self.length = length
+
+ def _keyname(self, prefix, ind):
+ return f"{prefix}_{ind}"
+
+ def append(self, x):
+ setattr(self.kls, self._keyname(self.prefix, self.ind), x)
+ self.ind += 1
+
+ def to_list(self):
+ return [
+ getattr(self.kls, self._keyname(self.prefix, i)) for i in range(self.length)
+ ]
+
+
+class AxialPositionalEmbeddingImage(nn.Module):
+ def __init__(self, dim, axial_shape, axial_dims=None):
super().__init__()
- total_dimensions = len(shape) + 2
- ax_dim_indexes = [i for i in range(1, total_dimensions) if i != emb_dim_index]
-
- self.num_axials = len(shape)
-
- for i, (axial_dim, axial_dim_index) in enumerate(zip(shape, ax_dim_indexes)):
- shape = [1] * total_dimensions
- shape[emb_dim_index] = dim
- shape[axial_dim_index] = axial_dim
- parameter = nn.Parameter(torch.randn(*shape))
- setattr(self, f"param_{i}", parameter)
-
- def forward(self, x: Tensor) -> Tensor:
- """Applies axial positional embedding."""
- for i in range(self.num_axials):
- x = x + getattr(self, f"param_{i}")
- return x
+ assert len(axial_shape) == 2, "Axial shape must have 2 dimensions for images"
+ self.dim = dim
+ self.pos_emb = AxialPositionalEmbedding(dim, axial_shape, axial_dims)
+
+ def forward(self, img):
+ b, c, h, w = img.shape
+ img = img.permute(0, 2, 3, 1).reshape(b, h * w, c)
+ pos_emb = self.pos_emb(img)
+ return pos_emb.reshape(b, h, w, self.dim).permute(0, 3, 1, 2)
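
For reference, a minimal usage sketch of the two modules added above, assuming they are imported from the path touched by this commit (text_recognizer.networks.transformer.embeddings.axial); shapes and dimensions are illustrative:

import torch
from text_recognizer.networks.transformer.embeddings.axial import (
    AxialPositionalEmbedding,
    AxialPositionalEmbeddingImage,
)

# Sequence form: tokens laid out on a 64 x 64 grid, embedding dim 256.
pos_emb = AxialPositionalEmbedding(dim=256, axial_shape=(64, 64))
x = torch.randn(1, 1024, 256)  # (batch, seq_len <= 64 * 64, dim)
x = x + pos_emb(x)             # embedding is trimmed to the actual sequence length

# Image form: channels-first feature map in, (batch, dim, h, w) positional map out.
img_pos = AxialPositionalEmbeddingImage(dim=256, axial_shape=(32, 32))
feats = torch.randn(1, 256, 32, 32)
feats = feats + img_pos(feats)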