summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--text_recognizer/networks/transformer/layers.py16
-rw-r--r--text_recognizer/networks/transformer/local_attention.py233
2 files changed, 1 insertions, 248 deletions
diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py
index 8387fa4..67558ad 100644
--- a/text_recognizer/networks/transformer/layers.py
+++ b/text_recognizer/networks/transformer/layers.py
@@ -6,7 +6,6 @@ import attr
from torch import nn, Tensor
from text_recognizer.networks.transformer.attention import Attention
-from text_recognizer.networks.transformer.local_attention import LocalAttention
from text_recognizer.networks.transformer.mlp import FeedForward
from text_recognizer.networks.transformer.residual import Residual
@@ -24,18 +23,13 @@ class AttentionLayers(nn.Module):
norm: Type[nn.Module] = attr.ib()
ff: FeedForward = attr.ib()
cross_attn: Optional[Attention] = attr.ib(default=None)
- local_self_attn: Optional[LocalAttention] = attr.ib(default=None)
pre_norm: bool = attr.ib(default=True)
- local_depth: Optional[int] = attr.ib(default=None)
has_pos_emb: bool = attr.ib(default=False)
layer_types: Tuple[str, ...] = attr.ib(init=False)
layers: nn.ModuleList = attr.ib(init=False)
def __attrs_post_init__(self) -> None:
"""Post init configuration."""
- if self.local_self_attn is not None:
- if self.local_depth is None:
- ValueError("Local depth has to be specified")
self.layer_types = self._get_layer_types() * self.depth
self.layers = self._build_network()
@@ -45,14 +39,8 @@ class AttentionLayers(nn.Module):
return "a", "c", "f"
return "a", "f"
- def _self_attn_block(self, i: int) -> Type[nn.Module]:
- if self.local_depth is not None and i < self.local_depth:
- return deepcopy(self.local_self_attn)
- return deepcopy(self.self_attn)
-
def _delete(self) -> None:
del self.self_attn
- del self.local_self_attn
del self.ff
del self.norm
del self.cross_attn
@@ -60,11 +48,9 @@ class AttentionLayers(nn.Module):
def _build_network(self) -> nn.ModuleList:
"""Configures transformer network."""
layers = nn.ModuleList([])
- self_attn_depth = 0
for layer_type in self.layer_types:
if layer_type == "a":
- layer = self._self_attn_block(self_attn_depth)
- self_attn_depth += 1
+ layer = deepcopy(self.self_attn)
elif layer_type == "c":
layer = deepcopy(self.cross_attn)
elif layer_type == "f":
diff --git a/text_recognizer/networks/transformer/local_attention.py b/text_recognizer/networks/transformer/local_attention.py
deleted file mode 100644
index a008bab..0000000
--- a/text_recognizer/networks/transformer/local_attention.py
+++ /dev/null
@@ -1,233 +0,0 @@
-"""Local attention module.
-
-Also stolen from lucidrains from here:
-https://github.com/lucidrains/local-attention/blob/master/local_attention/local_attention.py
-"""
-from functools import reduce
-import math
-from operator import mul
-from typing import List, Optional, Tuple
-
-import attr
-import torch
-from torch import einsum
-from torch import nn
-from torch import Tensor
-import torch.nn.functional as F
-
-from text_recognizer.networks.transformer.embeddings.rotary import (
- RotaryEmbedding,
- rotate_half,
-)
-
-
-@attr.s(eq=False)
-class LocalAttention(nn.Module):
- """Local windowed attention."""
-
- dim: int = attr.ib()
- num_heads: int = attr.ib()
- dim_head: int = attr.ib(default=64)
- window_size: int = attr.ib(default=128)
- look_back: int = attr.ib(default=1)
- dropout_rate: float = attr.ib(default=0.0)
- autopad: bool = attr.ib(default=False)
- rotary_embedding: Optional[RotaryEmbedding] = attr.ib(default=None)
-
- def __attrs_pre_init__(self) -> None:
- """Pre init constructor."""
- super().__init__()
-
- def __attrs_post_init__(self) -> None:
- """Post init constructor."""
- self.scale = self.dim ** -0.5
- inner_dim = self.num_heads * self.dim_head
-
- self.to_qkv = nn.Linear(self.dim, 3 * inner_dim, bias=False)
- self.dropout = nn.Dropout(p=self.dropout_rate)
-
- # Feedforward
- self.fc = nn.Linear(inner_dim, self.dim)
-
- def _to_embeddings(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, List[int]]:
- """Convert input into query, key, and value."""
-
- def _split_heads(t: Tensor) -> Tensor:
- return _reshape_dim(t, -1, (-1, self.dim_head)).transpose(1, 2).contiguous()
-
- def _merge_into_batch(t: Tensor) -> Tensor:
- return t.reshape(-1, *t.shape[-2:])
-
- qkv = self.to_qkv(x).chunk(3, dim=-1)
- q, k, v = map(_split_heads, qkv)
- shape = q.shape
-
- q, k, v = map(_merge_into_batch, (q, k, v))
-
- if self.rotary_embedding is not None:
- embedding = self.rotary_embedding(q)
- q, k = _apply_rotary_emb(q, k, embedding)
- return q, k, v, shape
-
- def _create_buckets(
- self, q: Tensor, k: Tensor, v: Tensor, n: int, b: int, num_windows: int
- ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
- b_n = torch.arange(n).type_as(q).reshape(1, num_windows, self.window_size)
- bq, bk, bv = map(
- lambda t: t.reshape(b, num_windows, self.window_size, -1), (q, k, v),
- )
-
- bk = look_around(bk, backward=self.look_back)
- bv = look_around(bv, backward=self.look_back)
- bq_k = look_around(b_n, backward=self.look_back)
- return b_n, bq, bk, bv, bq_k
-
- def _apply_masks(
- self,
- b: int,
- energy: Tensor,
- b_n: Tensor,
- bq_k: Tensor,
- input_mask: Tensor,
- num_windows: int,
- ) -> Tensor:
- mask_value = -torch.finfo(energy.dtype).max
-
- # Causal mask.
- causal_mask = b_n[:, :, :, None] < bq_k[:, :, None, :]
- energy = energy.masked_fill_(causal_mask, mask_value)
- del causal_mask
-
- bucket_mask = bq_k[:, :, None, :] == -1
- energy.masked_fill_(bucket_mask, mask_value)
- del bucket_mask
-
- energy = apply_input_mask(
- b,
- energy=energy,
- input_mask=input_mask,
- backward=self.look_back,
- window_size=self.window_size,
- num_windows=num_windows,
- mask_value=mask_value,
- autopad=self.autopad,
- )
- return energy
-
- def forward(self, x: Tensor, input_mask: Optional[Tensor] = None,) -> Tensor:
- """Computes windowed attention."""
- q, k, v, shape = self._to_embeddings(x)
- d = q.shape[-1]
-
- if self.autopad:
- orig_t = q.shape[1]
- q, k, v = map(
- lambda t: pad_to_multiple(t, self.window_size, dim=-2), (q, k, v)
- )
-
- b, n, d = q.shape
-
- if not n % self.window_size:
- RuntimeError(
- f"Sequence length {n} must be divisable with window size {self.window_size}"
- )
-
- num_windows = n // self.window_size
-
- # Compute buckets
- b_n, bq, bk, bv, bq_k = self._create_buckets(q, k, v, n, b, num_windows)
-
- # Compute the attention.
- energy = einsum("b h i d, b h j d -> b h i j", bq, bk) * self.scale
- energy = self._apply_masks(b, energy, b_n, bq_k, input_mask, num_windows)
-
- attn = F.softmax(energy, dim=-1)
- attn = self.dropout(attn)
-
- out = einsum("b h i j, b h j d -> b h i d", attn, bv)
- out = out.reshape(-1, n, d)
- if self.autopad:
- out = out[:, :orig_t, :]
- n = orig_t
-
- b = x.shape[0]
- out = out.reshape(*shape)
- out = out.reshape(b, n, -1)
- out = self.fc(out)
-
- return out
-
-
-def _apply_rotary_emb(q: Tensor, k: Tensor, freqs: Tensor) -> Tuple[Tensor, Tensor]:
- q, k = map(lambda t: (t * freqs.cos()) + (rotate_half(t) * freqs.sin()), (q, k))
- return q, k
-
-
-def _reshape_dim(t: Tensor, dim: int, split_dims: Tuple[int, int]) -> Tensor:
- shape = list(t.shape)
- dims = len(t.shape)
- dim = (dim + dims) % dims
- shape[dim : dim + 1] = split_dims
- return t.reshape(shape)
-
-
-def merge_dims(ind_from: int, ind_to: int, tensor: Tensor) -> Tensor:
- """Merge dimensions."""
- shape = list(tensor.shape)
- arr_slice = slice(ind_from, ind_to + 1)
- shape[arr_slice] = [reduce(mul, shape[arr_slice])]
- return tensor.reshape(*shape)
-
-
-def expand_dim(t: Tensor, dim: int, k: int, unsqueeze: bool = True) -> Tensor:
- """Expand tensors dimensions."""
- if unsqueeze:
- t = t.unsqueeze(dim)
- expand_shape = [-1] * len(t.shape)
- expand_shape[dim] = k
- return t.expand(*expand_shape)
-
-
-def look_around(x: Tensor, backward: int, pad_value: int = -1, dim: int = 2) -> Tensor:
- """Apply windowing."""
- n = x.shape[1]
- dims = (len(x.shape) - dim) * (0, 0)
- x_pad = F.pad(x, (*dims, backward, 0), value=pad_value)
- tensors = [x_pad[:, ind : (ind + n), ...] for ind in range(backward + 1)]
- return torch.cat(tensors, dim=dim)
-
-
-def apply_input_mask(
- b: int,
- energy: Tensor,
- input_mask: Tensor,
- backward: int,
- window_size: int,
- num_windows: int,
- mask_value: Tensor,
- autopad: bool,
-) -> Tensor:
- """Applies input mask to energy tensor."""
- h = b // input_mask.shape[0]
- if autopad:
- input_mask = pad_to_multiple(input_mask, window_size, dim=-1, value=False)
- input_mask = input_mask.reshape(-1, num_windows, window_size)
- mq = mk = input_mask
- mk = look_around(mk, pad_value=False, backward=backward)
- mask = mq[:, :, :, None] * mk[:, :, None, :]
- mask = merge_dims(0, 1, expand_dim(mask, 1, h))
- energy.masked_fill_(~mask, mask_value)
- del mask
- return energy
-
-
-def pad_to_multiple(
- tensor: Tensor, multiple: int, dim: int = -1, value: int = 0
-) -> Tensor:
- seqlen = tensor.shape[dim]
- m = seqlen / multiple
- if m.is_integer():
- return tensor
- remainder = math.ceil(m) * multiple - seqlen
- pad_offset = (0,) * (-1 - dim) * 2
- return F.pad(tensor, (*pad_offset, 0, remainder), value=value)