summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--text_recognizer/network/transformer/attention.py35
1 files changed, 30 insertions, 5 deletions
diff --git a/text_recognizer/network/transformer/attention.py b/text_recognizer/network/transformer/attention.py
index dab2c7b..bae077f 100644
--- a/text_recognizer/network/transformer/attention.py
+++ b/text_recognizer/network/transformer/attention.py
@@ -2,9 +2,12 @@
from typing import Optional
from einops import rearrange
+from text_recognizer.network.transformer.swiglu import SwiGLU
+import torch
from torch import Tensor, nn
from .attend import Attend
+from .embedding.rotary import RotaryEmbedding, apply_rotary_pos_emb
class Attention(nn.Module):
@@ -16,8 +19,11 @@ class Attention(nn.Module):
heads: int,
causal: bool = False,
dim_head: int = 64,
+ ff_mult: int = 4,
dropout_rate: float = 0.0,
use_flash: bool = True,
+ norm_context: bool = False,
+ rotary_emb: Optional[RotaryEmbedding] = None,
) -> None:
super().__init__()
self.heads = heads
@@ -28,13 +34,27 @@ class Attention(nn.Module):
self.dropout = nn.Dropout(p=self.dropout_rate)
self.norm = nn.LayerNorm(dim)
+ self.context_norm = nn.LayerNorm(dim) if norm_context else nn.Identity()
self.to_q = nn.Linear(dim, inner_dim, bias=False)
- self.to_k = nn.Linear(dim, inner_dim, bias=False)
- self.to_v = nn.Linear(dim, inner_dim, bias=False)
+ self.to_kv = nn.Linear(dim, 2 * inner_dim, bias=False)
self.attend = Attend(use_flash)
self.to_out = nn.Linear(inner_dim, dim, bias=False)
+ self.rotary_emb = rotary_emb
+ self.pos_emb = None
+ ff_inner_dim = ff_mult * dim
+
+ self.ff = nn.Sequential(
+ nn.Linear(dim, 2 * ff_inner_dim), SwiGLU(), nn.Linear(ff_inner_dim, dim)
+ )
+
+ def get_rotary_embedding(self, n: int, device: torch.device) -> Tensor:
+ assert self.rotary_emb is not None, "No rotary embedding"
+ if self.pos_emb is not None and self.pos_emb.shape[-2] >= n:
+ return self.pos_emb[:n].to(device)
+ self.pos_emb = self.rotary_emb(n, device=device)
+ return self.pos_emb
def forward(
self,
@@ -45,14 +65,19 @@ class Attention(nn.Module):
"""Computes the attention."""
x = self.norm(x)
q = self.to_q(x)
- k = self.to_k(x if context is None else context)
- v = self.to_v(x if context is None else context)
+ k, v = self.to_kv(x if context is None else self.context_norm(context)).chunk(
+ 2, dim=-1
+ )
q, k, v = map(
lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), (q, k, v)
)
+ if self.rotary_emb is not None:
+ pos_emb = self.get_rotary_embedding(x.shape[1], x.device)
+ q, k = map(lambda t: apply_rotary_pos_emb(pos_emb, t), (q, k))
+
out = self.attend(q, k, v, self.causal, mask)
out = rearrange(out, "b h n d -> b n (h d)")
out = self.to_out(out)
- return out
+ return out + self.ff(x)