Diffstat (limited to 'text_recognizer/network/transformer/attention.py')
-rw-r--r--  text_recognizer/network/transformer/attention.py  19
1 file changed, 11 insertions, 8 deletions
diff --git a/text_recognizer/network/transformer/attention.py b/text_recognizer/network/transformer/attention.py
index bae077f..9a4aa0d 100644
--- a/text_recognizer/network/transformer/attention.py
+++ b/text_recognizer/network/transformer/attention.py
@@ -1,11 +1,12 @@
"""Implements the attention module for the transformer."""
from typing import Optional
-from einops import rearrange
-from text_recognizer.network.transformer.swiglu import SwiGLU
import torch
+from einops import rearrange
from torch import Tensor, nn
+from text_recognizer.network.transformer.swiglu import SwiGLU
+
from .attend import Attend
from .embedding.rotary import RotaryEmbedding, apply_rotary_pos_emb
@@ -23,7 +24,8 @@ class Attention(nn.Module):
dropout_rate: float = 0.0,
use_flash: bool = True,
norm_context: bool = False,
- rotary_emb: Optional[RotaryEmbedding] = None,
+ use_rotary_emb: bool = False,
+ one_kv_head: bool = False,
) -> None:
super().__init__()
self.heads = heads
@@ -36,12 +38,13 @@ class Attention(nn.Module):
self.norm = nn.LayerNorm(dim)
self.context_norm = nn.LayerNorm(dim) if norm_context else nn.Identity()
self.to_q = nn.Linear(dim, inner_dim, bias=False)
- self.to_kv = nn.Linear(dim, 2 * inner_dim, bias=False)
+ self.kv_heads = 1 if one_kv_head else heads
+ self.to_kv = nn.Linear(dim, 2 * self.kv_heads * dim_head, bias=False)
self.attend = Attend(use_flash)
self.to_out = nn.Linear(inner_dim, dim, bias=False)
- self.rotary_emb = rotary_emb
+ self.rotary_emb = RotaryEmbedding(dim_head) if use_rotary_emb else None
self.pos_emb = None
ff_inner_dim = ff_mult * dim
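
Note: the main change in this hunk is the new one_kv_head flag. When it is set, the key/value projection collapses to a single head (multi-query style), so to_kv emits 2 * dim_head features instead of 2 * heads * dim_head, while to_q still produces one set of features per query head. A minimal standalone sketch of that sizing follows; the concrete numbers (dim=256, heads=8, dim_head=64) and the input tensor are illustrative assumptions, not values from the repository.

import torch
from torch import nn

# Illustrative dimensions (assumed, not from the repo).
dim, heads, dim_head = 256, 8, 64
inner_dim = heads * dim_head

one_kv_head = True
kv_heads = 1 if one_kv_head else heads

to_q = nn.Linear(dim, inner_dim, bias=False)                 # 256 -> 512
to_kv = nn.Linear(dim, 2 * kv_heads * dim_head, bias=False)  # 256 -> 128 (vs. 1024 with per-head K/V)

x = torch.randn(2, 10, dim)                 # (batch, seq, dim)
q = to_q(x)                                 # (2, 10, 512)
k, v = to_kv(x).chunk(2, dim=-1)            # each (2, 10, 64): one shared K/V head
print(q.shape, k.shape, v.shape)

Shrinking the K/V projection this way reduces both the parameter count and the size of any cached keys/values, which is the usual motivation for multi-query attention.
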
@@ -68,9 +71,9 @@ class Attention(nn.Module):
k, v = self.to_kv(x if context is None else self.context_norm(context)).chunk(
2, dim=-1
)
-
- q, k, v = map(
- lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), (q, k, v)
+ q = rearrange(q, "b n (h d) -> b h n d", h=self.heads)
+ k, v = map(
+ lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.kv_heads), (k, v)
)
if self.rotary_emb is not None:
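
Note: in the forward pass the queries are still split into all heads while k and v are split into kv_heads; with a single K/V head, the attention matmul simply broadcasts it across the query heads. The sketch below shows that broadcasting with plain softmax attention; it does not reproduce the project's Attend module (or its flash path), and all shapes are made-up examples.

import torch
from einops import rearrange

# Illustrative shapes (assumed): batch 2, sequence 10, 8 query heads, 1 K/V head.
b, n, heads, kv_heads, dim_head = 2, 10, 8, 1, 64

q = torch.randn(b, n, heads * dim_head)
k = torch.randn(b, n, kv_heads * dim_head)
v = torch.randn(b, n, kv_heads * dim_head)

q = rearrange(q, "b n (h d) -> b h n d", h=heads)                          # (2, 8, 10, 64)
k, v = (rearrange(t, "b n (h d) -> b h n d", h=kv_heads) for t in (k, v))  # (2, 1, 10, 64)

scores = q @ k.transpose(-2, -1) * dim_head**-0.5   # broadcasts to (2, 8, 10, 10)
out = scores.softmax(dim=-1) @ v                    # (2, 8, 10, 64)
print(out.shape)
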