From b3fbfd72a8f647161685b28d20b4b61519d8a643 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm
Date: Mon, 15 Apr 2024 21:49:51 +0200
Subject: Update transformer model

---
 text_recognizer/network/transformer/attention.py | 19 +++++++++++--------
 1 file changed, 11 insertions(+), 8 deletions(-)

diff --git a/text_recognizer/network/transformer/attention.py b/text_recognizer/network/transformer/attention.py
index bae077f..9a4aa0d 100644
--- a/text_recognizer/network/transformer/attention.py
+++ b/text_recognizer/network/transformer/attention.py
@@ -1,11 +1,12 @@
 """Implements the attention module for the transformer."""
 from typing import Optional
 
-from einops import rearrange
-from text_recognizer.network.transformer.swiglu import SwiGLU
 import torch
+from einops import rearrange
 from torch import Tensor, nn
 
+from text_recognizer.network.transformer.swiglu import SwiGLU
+
 from .attend import Attend
 from .embedding.rotary import RotaryEmbedding, apply_rotary_pos_emb
 
@@ -23,7 +24,8 @@ class Attention(nn.Module):
         dropout_rate: float = 0.0,
         use_flash: bool = True,
         norm_context: bool = False,
-        rotary_emb: Optional[RotaryEmbedding] = None,
+        use_rotary_emb: bool = False,
+        one_kv_head: bool = False,
     ) -> None:
         super().__init__()
         self.heads = heads
@@ -36,12 +38,13 @@ class Attention(nn.Module):
         self.norm = nn.LayerNorm(dim)
         self.context_norm = nn.LayerNorm(dim) if norm_context else nn.Identity()
         self.to_q = nn.Linear(dim, inner_dim, bias=False)
-        self.to_kv = nn.Linear(dim, 2 * inner_dim, bias=False)
+        self.kv_heads = 1 if one_kv_head else heads
+        self.to_kv = nn.Linear(dim, 2 * self.kv_heads * dim_head, bias=False)
 
         self.attend = Attend(use_flash)
 
         self.to_out = nn.Linear(inner_dim, dim, bias=False)
-        self.rotary_emb = rotary_emb
+        self.rotary_emb = RotaryEmbedding(dim_head) if use_rotary_emb else None
         self.pos_emb = None
 
         ff_inner_dim = ff_mult * dim
@@ -68,9 +71,9 @@ class Attention(nn.Module):
         k, v = self.to_kv(x if context is None else self.context_norm(context)).chunk(
             2, dim=-1
         )
-
-        q, k, v = map(
-            lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), (q, k, v)
+        q = rearrange(q, "b n (h d) -> b h n d", h=self.heads)
+        k, v = map(
+            lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.kv_heads), (k, v)
        )
 
         if self.rotary_emb is not None:
-- 
cgit v1.2.3-70-g09d2
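
Note: the patch replaces the injected rotary embedding with an internal `use_rotary_emb` flag and adds a `one_kv_head` option, so keys and values can be projected to a single head that is shared by all query heads (multi-query attention). The following is a minimal, self-contained sketch of that projection pattern, not the repository's code: it omits the project's Attend, SwiGLU, rotary embedding, and context handling, and the class and variable names are illustrative only. It shows why `to_kv` only needs `2 * kv_heads * dim_head` output features and how the single KV head broadcasts against the query heads.

# Hypothetical sketch of the one-KV-head projection introduced by this patch.
import torch
from einops import rearrange
from torch import Tensor, nn


class MultiQueryAttentionSketch(nn.Module):
    def __init__(
        self, dim: int, heads: int, dim_head: int, one_kv_head: bool = True
    ) -> None:
        super().__init__()
        self.heads = heads
        self.scale = dim_head**-0.5
        inner_dim = heads * dim_head
        # One KV head when one_kv_head is set, mirroring the patched __init__.
        self.kv_heads = 1 if one_kv_head else heads
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        # Only 2 * kv_heads * dim_head output features are needed for k and v.
        self.to_kv = nn.Linear(dim, 2 * self.kv_heads * dim_head, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        # Queries keep all heads; keys/values use kv_heads (possibly 1).
        q = rearrange(self.to_q(x), "b n (h d) -> b h n d", h=self.heads)
        k, v = self.to_kv(x).chunk(2, dim=-1)
        k, v = (rearrange(t, "b n (h d) -> b h n d", h=self.kv_heads) for t in (k, v))
        # The size-1 KV head dim broadcasts against the query head dim in matmul.
        attn = (q @ k.transpose(-2, -1) * self.scale).softmax(dim=-1)
        out = rearrange(attn @ v, "b h n d -> b n (h d)")
        return self.to_out(out)


if __name__ == "__main__":
    x = torch.randn(2, 16, 128)
    layer = MultiQueryAttentionSketch(dim=128, heads=8, dim_head=16)
    print(layer(x).shape)  # torch.Size([2, 16, 128])

Sharing one KV head shrinks the key/value projection and cache by a factor of `heads`, at the cost of some expressiveness; the patch keeps this opt-in via the `one_kv_head` constructor flag.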