From 281c8602b4d945cf329d5bead104729acf47ed9c Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Mon, 5 Sep 2022 00:04:09 +0200 Subject: Steal lucidrains axial encoding --- .../networks/transformer/embeddings/axial.py | 100 ++++++++++++++++----- 1 file changed, 78 insertions(+), 22 deletions(-) (limited to 'text_recognizer/networks/transformer') diff --git a/text_recognizer/networks/transformer/embeddings/axial.py b/text_recognizer/networks/transformer/embeddings/axial.py index 56f29c5..7b84e12 100644 --- a/text_recognizer/networks/transformer/embeddings/axial.py +++ b/text_recognizer/networks/transformer/embeddings/axial.py @@ -3,31 +3,87 @@ Stolen from: https://github.com/lucidrains/axial-attention/blob/eff2c10c2e76c735a70a6b995b571213adffbbb7/axial_attention/axial_attention.py#L100 """ -from typing import Sequence - import torch -from torch import nn, Tensor +from torch import nn +from operator import mul +from functools import reduce class AxialPositionalEmbedding(nn.Module): - """Axial positional embedding.""" + def __init__(self, dim, axial_shape, axial_dims=None): + super().__init__() + + self.dim = dim + self.shape = axial_shape + self.max_seq_len = reduce(mul, axial_shape, 1) + + self.summed = axial_dims is None + axial_dims = ((dim,) * len(axial_shape)) if self.summed else axial_dims + + assert len(self.shape) == len( + axial_dims + ), "number of axial dimensions must equal the number of dimensions in the shape" + assert ( + self.summed or not self.summed and sum(axial_dims) == dim + ), f"axial dimensions must sum up to the target dimension {dim}" + + self.weights = ParameterList(self, "weights", len(axial_shape)) + + for ind, (shape, axial_dim) in enumerate(zip(self.shape, axial_dims)): + ax_shape = [1] * len(self.shape) + ax_shape[ind] = shape + ax_shape = (1, *ax_shape, axial_dim) + ax_emb = nn.Parameter(torch.zeros(ax_shape).normal_(0, 1)) + self.weights.append(ax_emb) + + def forward(self, x): + b, t, _ = x.shape + assert ( + t <= self.max_seq_len + ), f"Sequence length ({t}) must be less than the maximum sequence length allowed ({self.max_seq_len})" + embs = [] + + for ax_emb in self.weights.to_list(): + axial_dim = ax_emb.shape[-1] + expand_shape = (b, *self.shape, axial_dim) + emb = ax_emb.expand(expand_shape).reshape(b, self.max_seq_len, axial_dim) + embs.append(emb) - def __init__(self, dim: int, shape: Sequence[int], emb_dim_index: int = 1) -> None: + pos_emb = sum(embs) if self.summed else torch.cat(embs, dim=-1) + return pos_emb[:, :t].to(x) + + +# a mock parameter list object until below issue is resolved +# https://github.com/pytorch/pytorch/issues/36035 +class ParameterList(object): + def __init__(self, kls, prefix, length): + self.ind = 0 + self.kls = kls + self.prefix = prefix + self.length = length + + def _keyname(self, prefix, ind): + return f"{prefix}_{ind}" + + def append(self, x): + setattr(self.kls, self._keyname(self.prefix, self.ind), x) + self.ind += 1 + + def to_list(self): + return [ + getattr(self.kls, self._keyname(self.prefix, i)) for i in range(self.length) + ] + + +class AxialPositionalEmbeddingImage(nn.Module): + def __init__(self, dim, axial_shape, axial_dims=None): super().__init__() - total_dimensions = len(shape) + 2 - ax_dim_indexes = [i for i in range(1, total_dimensions) if i != emb_dim_index] - - self.num_axials = len(shape) - - for i, (axial_dim, axial_dim_index) in enumerate(zip(shape, ax_dim_indexes)): - shape = [1] * total_dimensions - shape[emb_dim_index] = dim - shape[axial_dim_index] = axial_dim - parameter = nn.Parameter(torch.randn(*shape)) - setattr(self, f"param_{i}", parameter) - - def forward(self, x: Tensor) -> Tensor: - """Applies axial positional embedding.""" - for i in range(self.num_axials): - x = x + getattr(self, f"param_{i}") - return x + assert len(axial_shape) == 2, "Axial shape must have 2 dimensions for images" + self.dim = dim + self.pos_emb = AxialPositionalEmbedding(dim, axial_shape, axial_dims) + + def forward(self, img): + b, c, h, w = img.shape + img = img.permute(0, 2, 3, 1).reshape(b, h * w, c) + pos_emb = self.pos_emb(img) + return pos_emb.reshape(b, h, w, self.dim).permute(0, 3, 1, 2) -- cgit v1.2.3-70-g09d2