diff options
Diffstat (limited to 'text_recognizer/networks/transformer')
-rw-r--r-- | text_recognizer/networks/transformer/local_attention.py | 67 |
1 files changed, 48 insertions, 19 deletions
diff --git a/text_recognizer/networks/transformer/local_attention.py b/text_recognizer/networks/transformer/local_attention.py index 628cf7e..a008bab 100644 --- a/text_recognizer/networks/transformer/local_attention.py +++ b/text_recognizer/networks/transformer/local_attention.py @@ -4,8 +4,9 @@ Also stolen from lucidrains from here: https://github.com/lucidrains/local-attention/blob/master/local_attention/local_attention.py """ from functools import reduce +import math from operator import mul -from typing import Optional, Tuple +from typing import List, Optional, Tuple import attr import torch @@ -30,6 +31,7 @@ class LocalAttention(nn.Module): window_size: int = attr.ib(default=128) look_back: int = attr.ib(default=1) dropout_rate: float = attr.ib(default=0.0) + autopad: bool = attr.ib(default=False) rotary_embedding: Optional[RotaryEmbedding] = attr.ib(default=None) def __attrs_pre_init__(self) -> None: @@ -47,7 +49,7 @@ class LocalAttention(nn.Module): # Feedforward self.fc = nn.Linear(inner_dim, self.dim) - def _to_embeddings(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]: + def _to_embeddings(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, List[int]]: """Convert input into query, key, and value.""" def _split_heads(t: Tensor) -> Tensor: @@ -58,13 +60,14 @@ class LocalAttention(nn.Module): qkv = self.to_qkv(x).chunk(3, dim=-1) q, k, v = map(_split_heads, qkv) + shape = q.shape q, k, v = map(_merge_into_batch, (q, k, v)) if self.rotary_embedding is not None: embedding = self.rotary_embedding(q) q, k = _apply_rotary_emb(q, k, embedding) - return q, k, v + return q, k, v, shape def _create_buckets( self, q: Tensor, k: Tensor, v: Tensor, n: int, b: int, num_windows: int @@ -85,7 +88,7 @@ class LocalAttention(nn.Module): energy: Tensor, b_n: Tensor, bq_k: Tensor, - mask: Tensor, + input_mask: Tensor, num_windows: int, ) -> Tensor: mask_value = -torch.finfo(energy.dtype).max @@ -102,19 +105,27 @@ class LocalAttention(nn.Module): energy = apply_input_mask( b, energy=energy, - mask=mask, + input_mask=input_mask, backward=self.look_back, window_size=self.window_size, num_windows=num_windows, mask_value=mask_value, + autopad=self.autopad, ) return energy - def forward( - self, x: Tensor, mask: Optional[Tensor] = None, - ) -> Tuple[Tensor, Tensor]: + def forward(self, x: Tensor, input_mask: Optional[Tensor] = None,) -> Tensor: """Computes windowed attention.""" - b, n, _ = x.shape + q, k, v, shape = self._to_embeddings(x) + d = q.shape[-1] + + if self.autopad: + orig_t = q.shape[1] + q, k, v = map( + lambda t: pad_to_multiple(t, self.window_size, dim=-2), (q, k, v) + ) + + b, n, d = q.shape if not n % self.window_size: RuntimeError( @@ -123,25 +134,28 @@ class LocalAttention(nn.Module): num_windows = n // self.window_size - q, k, v = self._to_embeddings(x) - d = q.shape[-1] - # Compute buckets b_n, bq, bk, bv, bq_k = self._create_buckets(q, k, v, n, b, num_windows) # Compute the attention. energy = einsum("b h i d, b h j d -> b h i j", bq, bk) * self.scale - energy = self._apply_masks(b, energy, b_n, bq_k, mask, num_windows) + energy = self._apply_masks(b, energy, b_n, bq_k, input_mask, num_windows) attn = F.softmax(energy, dim=-1) attn = self.dropout(attn) out = einsum("b h i j, b h j d -> b h i d", attn, bv) out = out.reshape(-1, n, d) + if self.autopad: + out = out[:, :orig_t, :] + n = orig_t - out = out.reshape(b, self.num_heads, n, -1).transpose(1, 2).reshape(b, n, -1) + b = x.shape[0] + out = out.reshape(*shape) + out = out.reshape(b, n, -1) out = self.fc(out) - return out, attn + + return out def _apply_rotary_emb(q: Tensor, k: Tensor, freqs: Tensor) -> Tuple[Tensor, Tensor]: @@ -186,19 +200,34 @@ def look_around(x: Tensor, backward: int, pad_value: int = -1, dim: int = 2) -> def apply_input_mask( b: int, energy: Tensor, - mask: Tensor, + input_mask: Tensor, backward: int, window_size: int, num_windows: int, mask_value: Tensor, + autopad: bool, ) -> Tensor: """Applies input mask to energy tensor.""" - h = b // mask.shape[0] - mask = mask.reshape(-1, num_windows, window_size) - mq = mk = mask + h = b // input_mask.shape[0] + if autopad: + input_mask = pad_to_multiple(input_mask, window_size, dim=-1, value=False) + input_mask = input_mask.reshape(-1, num_windows, window_size) + mq = mk = input_mask mk = look_around(mk, pad_value=False, backward=backward) mask = mq[:, :, :, None] * mk[:, :, None, :] mask = merge_dims(0, 1, expand_dim(mask, 1, h)) energy.masked_fill_(~mask, mask_value) del mask return energy + + +def pad_to_multiple( + tensor: Tensor, multiple: int, dim: int = -1, value: int = 0 +) -> Tensor: + seqlen = tensor.shape[dim] + m = seqlen / multiple + if m.is_integer(): + return tensor + remainder = math.ceil(m) * multiple - seqlen + pad_offset = (0,) * (-1 - dim) * 2 + return F.pad(tensor, (*pad_offset, 0, remainder), value=value) |