diff options
Diffstat (limited to 'text_recognizer')
-rw-r--r-- | text_recognizer/networks/transformer/layers.py | 16 | ||||
-rw-r--r-- | text_recognizer/networks/transformer/local_attention.py | 233 |
2 files changed, 1 insertions, 248 deletions
diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py index 8387fa4..67558ad 100644 --- a/text_recognizer/networks/transformer/layers.py +++ b/text_recognizer/networks/transformer/layers.py @@ -6,7 +6,6 @@ import attr from torch import nn, Tensor from text_recognizer.networks.transformer.attention import Attention -from text_recognizer.networks.transformer.local_attention import LocalAttention from text_recognizer.networks.transformer.mlp import FeedForward from text_recognizer.networks.transformer.residual import Residual @@ -24,18 +23,13 @@ class AttentionLayers(nn.Module): norm: Type[nn.Module] = attr.ib() ff: FeedForward = attr.ib() cross_attn: Optional[Attention] = attr.ib(default=None) - local_self_attn: Optional[LocalAttention] = attr.ib(default=None) pre_norm: bool = attr.ib(default=True) - local_depth: Optional[int] = attr.ib(default=None) has_pos_emb: bool = attr.ib(default=False) layer_types: Tuple[str, ...] = attr.ib(init=False) layers: nn.ModuleList = attr.ib(init=False) def __attrs_post_init__(self) -> None: """Post init configuration.""" - if self.local_self_attn is not None: - if self.local_depth is None: - ValueError("Local depth has to be specified") self.layer_types = self._get_layer_types() * self.depth self.layers = self._build_network() @@ -45,14 +39,8 @@ class AttentionLayers(nn.Module): return "a", "c", "f" return "a", "f" - def _self_attn_block(self, i: int) -> Type[nn.Module]: - if self.local_depth is not None and i < self.local_depth: - return deepcopy(self.local_self_attn) - return deepcopy(self.self_attn) - def _delete(self) -> None: del self.self_attn - del self.local_self_attn del self.ff del self.norm del self.cross_attn @@ -60,11 +48,9 @@ class AttentionLayers(nn.Module): def _build_network(self) -> nn.ModuleList: """Configures transformer network.""" layers = nn.ModuleList([]) - self_attn_depth = 0 for layer_type in self.layer_types: if layer_type == "a": - layer = self._self_attn_block(self_attn_depth) - self_attn_depth += 1 + layer = deepcopy(self.self_attn) elif layer_type == "c": layer = deepcopy(self.cross_attn) elif layer_type == "f": diff --git a/text_recognizer/networks/transformer/local_attention.py b/text_recognizer/networks/transformer/local_attention.py deleted file mode 100644 index a008bab..0000000 --- a/text_recognizer/networks/transformer/local_attention.py +++ /dev/null @@ -1,233 +0,0 @@ -"""Local attention module. - -Also stolen from lucidrains from here: -https://github.com/lucidrains/local-attention/blob/master/local_attention/local_attention.py -""" -from functools import reduce -import math -from operator import mul -from typing import List, Optional, Tuple - -import attr -import torch -from torch import einsum -from torch import nn -from torch import Tensor -import torch.nn.functional as F - -from text_recognizer.networks.transformer.embeddings.rotary import ( - RotaryEmbedding, - rotate_half, -) - - -@attr.s(eq=False) -class LocalAttention(nn.Module): - """Local windowed attention.""" - - dim: int = attr.ib() - num_heads: int = attr.ib() - dim_head: int = attr.ib(default=64) - window_size: int = attr.ib(default=128) - look_back: int = attr.ib(default=1) - dropout_rate: float = attr.ib(default=0.0) - autopad: bool = attr.ib(default=False) - rotary_embedding: Optional[RotaryEmbedding] = attr.ib(default=None) - - def __attrs_pre_init__(self) -> None: - """Pre init constructor.""" - super().__init__() - - def __attrs_post_init__(self) -> None: - """Post init constructor.""" - self.scale = self.dim ** -0.5 - inner_dim = self.num_heads * self.dim_head - - self.to_qkv = nn.Linear(self.dim, 3 * inner_dim, bias=False) - self.dropout = nn.Dropout(p=self.dropout_rate) - - # Feedforward - self.fc = nn.Linear(inner_dim, self.dim) - - def _to_embeddings(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, List[int]]: - """Convert input into query, key, and value.""" - - def _split_heads(t: Tensor) -> Tensor: - return _reshape_dim(t, -1, (-1, self.dim_head)).transpose(1, 2).contiguous() - - def _merge_into_batch(t: Tensor) -> Tensor: - return t.reshape(-1, *t.shape[-2:]) - - qkv = self.to_qkv(x).chunk(3, dim=-1) - q, k, v = map(_split_heads, qkv) - shape = q.shape - - q, k, v = map(_merge_into_batch, (q, k, v)) - - if self.rotary_embedding is not None: - embedding = self.rotary_embedding(q) - q, k = _apply_rotary_emb(q, k, embedding) - return q, k, v, shape - - def _create_buckets( - self, q: Tensor, k: Tensor, v: Tensor, n: int, b: int, num_windows: int - ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: - b_n = torch.arange(n).type_as(q).reshape(1, num_windows, self.window_size) - bq, bk, bv = map( - lambda t: t.reshape(b, num_windows, self.window_size, -1), (q, k, v), - ) - - bk = look_around(bk, backward=self.look_back) - bv = look_around(bv, backward=self.look_back) - bq_k = look_around(b_n, backward=self.look_back) - return b_n, bq, bk, bv, bq_k - - def _apply_masks( - self, - b: int, - energy: Tensor, - b_n: Tensor, - bq_k: Tensor, - input_mask: Tensor, - num_windows: int, - ) -> Tensor: - mask_value = -torch.finfo(energy.dtype).max - - # Causal mask. - causal_mask = b_n[:, :, :, None] < bq_k[:, :, None, :] - energy = energy.masked_fill_(causal_mask, mask_value) - del causal_mask - - bucket_mask = bq_k[:, :, None, :] == -1 - energy.masked_fill_(bucket_mask, mask_value) - del bucket_mask - - energy = apply_input_mask( - b, - energy=energy, - input_mask=input_mask, - backward=self.look_back, - window_size=self.window_size, - num_windows=num_windows, - mask_value=mask_value, - autopad=self.autopad, - ) - return energy - - def forward(self, x: Tensor, input_mask: Optional[Tensor] = None,) -> Tensor: - """Computes windowed attention.""" - q, k, v, shape = self._to_embeddings(x) - d = q.shape[-1] - - if self.autopad: - orig_t = q.shape[1] - q, k, v = map( - lambda t: pad_to_multiple(t, self.window_size, dim=-2), (q, k, v) - ) - - b, n, d = q.shape - - if not n % self.window_size: - RuntimeError( - f"Sequence length {n} must be divisable with window size {self.window_size}" - ) - - num_windows = n // self.window_size - - # Compute buckets - b_n, bq, bk, bv, bq_k = self._create_buckets(q, k, v, n, b, num_windows) - - # Compute the attention. - energy = einsum("b h i d, b h j d -> b h i j", bq, bk) * self.scale - energy = self._apply_masks(b, energy, b_n, bq_k, input_mask, num_windows) - - attn = F.softmax(energy, dim=-1) - attn = self.dropout(attn) - - out = einsum("b h i j, b h j d -> b h i d", attn, bv) - out = out.reshape(-1, n, d) - if self.autopad: - out = out[:, :orig_t, :] - n = orig_t - - b = x.shape[0] - out = out.reshape(*shape) - out = out.reshape(b, n, -1) - out = self.fc(out) - - return out - - -def _apply_rotary_emb(q: Tensor, k: Tensor, freqs: Tensor) -> Tuple[Tensor, Tensor]: - q, k = map(lambda t: (t * freqs.cos()) + (rotate_half(t) * freqs.sin()), (q, k)) - return q, k - - -def _reshape_dim(t: Tensor, dim: int, split_dims: Tuple[int, int]) -> Tensor: - shape = list(t.shape) - dims = len(t.shape) - dim = (dim + dims) % dims - shape[dim : dim + 1] = split_dims - return t.reshape(shape) - - -def merge_dims(ind_from: int, ind_to: int, tensor: Tensor) -> Tensor: - """Merge dimensions.""" - shape = list(tensor.shape) - arr_slice = slice(ind_from, ind_to + 1) - shape[arr_slice] = [reduce(mul, shape[arr_slice])] - return tensor.reshape(*shape) - - -def expand_dim(t: Tensor, dim: int, k: int, unsqueeze: bool = True) -> Tensor: - """Expand tensors dimensions.""" - if unsqueeze: - t = t.unsqueeze(dim) - expand_shape = [-1] * len(t.shape) - expand_shape[dim] = k - return t.expand(*expand_shape) - - -def look_around(x: Tensor, backward: int, pad_value: int = -1, dim: int = 2) -> Tensor: - """Apply windowing.""" - n = x.shape[1] - dims = (len(x.shape) - dim) * (0, 0) - x_pad = F.pad(x, (*dims, backward, 0), value=pad_value) - tensors = [x_pad[:, ind : (ind + n), ...] for ind in range(backward + 1)] - return torch.cat(tensors, dim=dim) - - -def apply_input_mask( - b: int, - energy: Tensor, - input_mask: Tensor, - backward: int, - window_size: int, - num_windows: int, - mask_value: Tensor, - autopad: bool, -) -> Tensor: - """Applies input mask to energy tensor.""" - h = b // input_mask.shape[0] - if autopad: - input_mask = pad_to_multiple(input_mask, window_size, dim=-1, value=False) - input_mask = input_mask.reshape(-1, num_windows, window_size) - mq = mk = input_mask - mk = look_around(mk, pad_value=False, backward=backward) - mask = mq[:, :, :, None] * mk[:, :, None, :] - mask = merge_dims(0, 1, expand_dim(mask, 1, h)) - energy.masked_fill_(~mask, mask_value) - del mask - return energy - - -def pad_to_multiple( - tensor: Tensor, multiple: int, dim: int = -1, value: int = 0 -) -> Tensor: - seqlen = tensor.shape[dim] - m = seqlen / multiple - if m.is_integer(): - return tensor - remainder = math.ceil(m) * multiple - seqlen - pad_offset = (0,) * (-1 - dim) * 2 - return F.pad(tensor, (*pad_offset, 0, remainder), value=value) |