summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/transformer/attention.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks/transformer/attention.py')
-rw-r--r--text_recognizer/networks/transformer/attention.py31
1 files changed, 14 insertions, 17 deletions
diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py
index d687056..b73fec0 100644
--- a/text_recognizer/networks/transformer/attention.py
+++ b/text_recognizer/networks/transformer/attention.py
@@ -9,7 +9,10 @@ from torch import nn
from torch import Tensor
import torch.nn.functional as F
-from text_recognizer.networks.transformer.embeddings.rotary import apply_rotary_pos_emb
+from text_recognizer.networks.transformer.embeddings.rotary import (
+ RotaryEmbedding,
+ rotate_half,
+)
@attr.s(eq=False)
@@ -25,15 +28,15 @@ class Attention(nn.Module):
causal: bool = attr.ib(default=False)
dim_head: int = attr.ib(default=64)
dropout_rate: float = attr.ib(default=0.0)
+ rotary_embedding: Optional[RotaryEmbedding] = attr.ib(default=None)
scale: float = attr.ib(init=False)
dropout: nn.Dropout = attr.ib(init=False)
fc: nn.Linear = attr.ib(init=False)
- qkv_fn: nn.Sequential = attr.ib(init=False)
def __attrs_post_init__(self) -> None:
"""Post init configuration."""
self.scale = self.dim ** -0.5
- inner_dim = self.dim * self.dim_head
+ inner_dim = self.num_heads * self.dim_head
self.query = nn.Linear(self.dim, inner_dim, bias=False)
self.key = nn.Linear(self.dim, inner_dim, bias=False)
@@ -50,7 +53,6 @@ class Attention(nn.Module):
context: Optional[Tensor] = None,
mask: Optional[Tensor] = None,
context_mask: Optional[Tensor] = None,
- rotary_pos_emb: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]:
"""Computes the attention."""
b, n, _, device = *x.shape, x.device
@@ -61,11 +63,10 @@ class Attention(nn.Module):
q, k, v = map(
lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.num_heads), (q, k, v)
)
- q, k, v = (
- apply_rotary_emb(q, k, v, rotary_pos_emb)
- if rotary_pos_emb is not None and context is None
- else (q, k, v,)
- )
+
+ if self.rotary_embedding is not None:
+ embedding = self.rotary_embedding(q)
+ q, k, v = _apply_rotary_emb(q, k, v, embedding[None, ...])
energy = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale
mask_value = -torch.finfo(energy.dtype).max
@@ -83,16 +84,12 @@ class Attention(nn.Module):
return out, attn
-def apply_rotary_emb(
- q: Tensor, k: Tensor, v: Tensor, rotary_pos_emb: Tensor
+def _apply_rotary_emb(
+ q: Tensor, k: Tensor, v: Tensor, freqs: Tensor
) -> Tuple[Tensor, Tensor, Tensor]:
- """Applies rotary embedding."""
- emb_len = rotary_pos_emb.shape[-1]
- (ql, qr), (kl, kr), (vl, vr) = map(
- lambda t: (t[..., :emb_len], t[..., emb_len:]), (q, k, v)
+ q, k, v = map(
+ lambda t: (t * freqs.cos()) + (rotate_half(t) * freqs.sin()), (q, k, v)
)
- ql, kl, vl = map(lambda t: apply_rotary_pos_emb(t, rotary_pos_emb), (ql, kl, vl))
- q, k, v = map(lambda t: torch.cat(t, dim=-1), ((ql, qr), (kl, kr), (vl, vr)))
return q, k, v