summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/transformer
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-11-03 22:12:53 +0100
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-11-03 22:12:53 +0100
commit5e9a7a611284c37b7382f271d989d1ef70546d10 (patch)
treeffadcf05c4612a89e4bc8815ab6cecae3345c9f1 /text_recognizer/networks/transformer
parentc64c85c36e67a2bae07cac1adeef70e82e69225c (diff)
Fix local attn to work with any input length
Diffstat (limited to 'text_recognizer/networks/transformer')
-rw-r--r--text_recognizer/networks/transformer/local_attention.py67
1 files changed, 48 insertions, 19 deletions
diff --git a/text_recognizer/networks/transformer/local_attention.py b/text_recognizer/networks/transformer/local_attention.py
index 628cf7e..a008bab 100644
--- a/text_recognizer/networks/transformer/local_attention.py
+++ b/text_recognizer/networks/transformer/local_attention.py
@@ -4,8 +4,9 @@ Also stolen from lucidrains from here:
https://github.com/lucidrains/local-attention/blob/master/local_attention/local_attention.py
"""
from functools import reduce
+import math
from operator import mul
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
import attr
import torch
@@ -30,6 +31,7 @@ class LocalAttention(nn.Module):
window_size: int = attr.ib(default=128)
look_back: int = attr.ib(default=1)
dropout_rate: float = attr.ib(default=0.0)
+ autopad: bool = attr.ib(default=False)
rotary_embedding: Optional[RotaryEmbedding] = attr.ib(default=None)
def __attrs_pre_init__(self) -> None:
@@ -47,7 +49,7 @@ class LocalAttention(nn.Module):
# Feedforward
self.fc = nn.Linear(inner_dim, self.dim)
- def _to_embeddings(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
+ def _to_embeddings(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, List[int]]:
"""Convert input into query, key, and value."""
def _split_heads(t: Tensor) -> Tensor:
@@ -58,13 +60,14 @@ class LocalAttention(nn.Module):
qkv = self.to_qkv(x).chunk(3, dim=-1)
q, k, v = map(_split_heads, qkv)
+ shape = q.shape
q, k, v = map(_merge_into_batch, (q, k, v))
if self.rotary_embedding is not None:
embedding = self.rotary_embedding(q)
q, k = _apply_rotary_emb(q, k, embedding)
- return q, k, v
+ return q, k, v, shape
def _create_buckets(
self, q: Tensor, k: Tensor, v: Tensor, n: int, b: int, num_windows: int
@@ -85,7 +88,7 @@ class LocalAttention(nn.Module):
energy: Tensor,
b_n: Tensor,
bq_k: Tensor,
- mask: Tensor,
+ input_mask: Tensor,
num_windows: int,
) -> Tensor:
mask_value = -torch.finfo(energy.dtype).max
@@ -102,19 +105,27 @@ class LocalAttention(nn.Module):
energy = apply_input_mask(
b,
energy=energy,
- mask=mask,
+ input_mask=input_mask,
backward=self.look_back,
window_size=self.window_size,
num_windows=num_windows,
mask_value=mask_value,
+ autopad=self.autopad,
)
return energy
- def forward(
- self, x: Tensor, mask: Optional[Tensor] = None,
- ) -> Tuple[Tensor, Tensor]:
+ def forward(self, x: Tensor, input_mask: Optional[Tensor] = None,) -> Tensor:
"""Computes windowed attention."""
- b, n, _ = x.shape
+ q, k, v, shape = self._to_embeddings(x)
+ d = q.shape[-1]
+
+ if self.autopad:
+ orig_t = q.shape[1]
+ q, k, v = map(
+ lambda t: pad_to_multiple(t, self.window_size, dim=-2), (q, k, v)
+ )
+
+ b, n, d = q.shape
if not n % self.window_size:
RuntimeError(
@@ -123,25 +134,28 @@ class LocalAttention(nn.Module):
num_windows = n // self.window_size
- q, k, v = self._to_embeddings(x)
- d = q.shape[-1]
-
# Compute buckets
b_n, bq, bk, bv, bq_k = self._create_buckets(q, k, v, n, b, num_windows)
# Compute the attention.
energy = einsum("b h i d, b h j d -> b h i j", bq, bk) * self.scale
- energy = self._apply_masks(b, energy, b_n, bq_k, mask, num_windows)
+ energy = self._apply_masks(b, energy, b_n, bq_k, input_mask, num_windows)
attn = F.softmax(energy, dim=-1)
attn = self.dropout(attn)
out = einsum("b h i j, b h j d -> b h i d", attn, bv)
out = out.reshape(-1, n, d)
+ if self.autopad:
+ out = out[:, :orig_t, :]
+ n = orig_t
- out = out.reshape(b, self.num_heads, n, -1).transpose(1, 2).reshape(b, n, -1)
+ b = x.shape[0]
+ out = out.reshape(*shape)
+ out = out.reshape(b, n, -1)
out = self.fc(out)
- return out, attn
+
+ return out
def _apply_rotary_emb(q: Tensor, k: Tensor, freqs: Tensor) -> Tuple[Tensor, Tensor]:
@@ -186,19 +200,34 @@ def look_around(x: Tensor, backward: int, pad_value: int = -1, dim: int = 2) ->
def apply_input_mask(
b: int,
energy: Tensor,
- mask: Tensor,
+ input_mask: Tensor,
backward: int,
window_size: int,
num_windows: int,
mask_value: Tensor,
+ autopad: bool,
) -> Tensor:
"""Applies input mask to energy tensor."""
- h = b // mask.shape[0]
- mask = mask.reshape(-1, num_windows, window_size)
- mq = mk = mask
+ h = b // input_mask.shape[0]
+ if autopad:
+ input_mask = pad_to_multiple(input_mask, window_size, dim=-1, value=False)
+ input_mask = input_mask.reshape(-1, num_windows, window_size)
+ mq = mk = input_mask
mk = look_around(mk, pad_value=False, backward=backward)
mask = mq[:, :, :, None] * mk[:, :, None, :]
mask = merge_dims(0, 1, expand_dim(mask, 1, h))
energy.masked_fill_(~mask, mask_value)
del mask
return energy
+
+
+def pad_to_multiple(
+ tensor: Tensor, multiple: int, dim: int = -1, value: int = 0
+) -> Tensor:
+ seqlen = tensor.shape[dim]
+ m = seqlen / multiple
+ if m.is_integer():
+ return tensor
+ remainder = math.ceil(m) * multiple - seqlen
+ pad_offset = (0,) * (-1 - dim) * 2
+ return F.pad(tensor, (*pad_offset, 0, remainder), value=value)