diff options
Diffstat (limited to 'text_recognizer/networks/transformer/attention.py')
-rw-r--r-- | text_recognizer/networks/transformer/attention.py | 42 |
1 files changed, 22 insertions, 20 deletions
diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py
index aa15b88..3df5333 100644
--- a/text_recognizer/networks/transformer/attention.py
+++ b/text_recognizer/networks/transformer/attention.py
@@ -1,7 +1,6 @@
 """Implementes the attention module for the transformer."""
 from typing import Optional, Tuple
 
-from attrs import define, field
 from einops import rearrange
 import torch
 from torch import einsum
@@ -15,30 +14,33 @@ from text_recognizer.networks.transformer.embeddings.rotary import (
 )
 
 
-@define(eq=False)
 class Attention(nn.Module):
     """Standard attention."""
 
-    def __attrs_pre_init__(self) -> None:
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int,
+        causal: bool = False,
+        dim_head: int = 64,
+        dropout_rate: float = 0.0,
+        rotary_embedding: Optional[RotaryEmbedding] = None,
+    ) -> None:
         super().__init__()
 
-    dim: int = field()
-    num_heads: int = field()
-    causal: bool = field(default=False)
-    dim_head: int = field(default=64)
-    dropout_rate: float = field(default=0.0)
-    rotary_embedding: Optional[RotaryEmbedding] = field(default=None)
-    scale: float = field(init=False)
-    dropout: nn.Dropout = field(init=False)
-    fc: nn.Linear = field(init=False)
-
-    def __attrs_post_init__(self) -> None:
+        self.dim = dim
+        self.num_heads = num_heads
+        self.causal = causal
+        self.dim_head = dim_head
+        self.dropout_rate = dropout_rate
+        self.rotary_embedding = rotary_embedding
+
         self.scale = self.dim ** -0.5
         inner_dim = self.num_heads * self.dim_head
 
-        self.query = nn.Linear(self.dim, inner_dim, bias=False)
-        self.key = nn.Linear(self.dim, inner_dim, bias=False)
-        self.value = nn.Linear(self.dim, inner_dim, bias=False)
+        self.to_q = nn.Linear(self.dim, inner_dim, bias=False)
+        self.to_k = nn.Linear(self.dim, inner_dim, bias=False)
+        self.to_v = nn.Linear(self.dim, inner_dim, bias=False)
 
         self.dropout = nn.Dropout(p=self.dropout_rate)
 
@@ -55,9 +57,9 @@ class Attention(nn.Module):
         """Computes the attention."""
         b, n, _, device = *x.shape, x.device
 
-        q = self.query(x)
-        k = self.key(context) if context is not None else self.key(x)
-        v = self.value(context) if context is not None else self.value(x)
+        q = self.to_q(x)
+        k = self.to_k(context) if context is not None else self.to_k(x)
+        v = self.to_v(context) if context is not None else self.to_v(x)
         q, k, v = map(
             lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.num_heads), (q, k, v)
         )
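
For reference, below is a minimal, self-contained sketch of what the refactored module looks like after this commit: an explicit __init__ that stores the hyperparameters and builds the to_q/to_k/to_v projections, replacing the attrs field() declarations. The constructor signature and the q/k/v projection plus head rearrangement follow the diff; the rest of the forward pass (scaled dot-product, causal masking, output projection) is an assumed standard implementation, and the repository-specific rotary_embedding argument is omitted, so this is illustrative rather than the project's exact code.

# Sketch of the post-refactor Attention module. Constructor and q/k/v
# projections mirror the diff; the attention computation itself is assumed
# (standard scaled dot-product) and not taken from this commit.
from typing import Optional

import torch
from einops import rearrange
from torch import Tensor, einsum, nn


class Attention(nn.Module):
    """Standard attention (illustrative sketch)."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        causal: bool = False,
        dim_head: int = 64,
        dropout_rate: float = 0.0,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.causal = causal
        self.scale = dim ** -0.5
        inner_dim = num_heads * dim_head

        # Separate linear projections for queries, keys, and values.
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_k = nn.Linear(dim, inner_dim, bias=False)
        self.to_v = nn.Linear(dim, inner_dim, bias=False)

        self.dropout = nn.Dropout(p=dropout_rate)
        self.fc = nn.Linear(inner_dim, dim)

    def forward(self, x: Tensor, context: Optional[Tensor] = None) -> Tensor:
        """Computes self-attention, or cross-attention when context is given."""
        q = self.to_q(x)
        k = self.to_k(context) if context is not None else self.to_k(x)
        v = self.to_v(context) if context is not None else self.to_v(x)
        q, k, v = map(
            lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.num_heads), (q, k, v)
        )

        # Assumed standard scaled dot-product attention with optional causal mask.
        energy = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale
        if self.causal:
            i, j = energy.shape[-2:]
            mask = torch.ones((i, j), dtype=torch.bool, device=x.device).triu(j - i + 1)
            energy = energy.masked_fill(mask, torch.finfo(energy.dtype).min)
        attn = self.dropout(energy.softmax(dim=-1))
        out = einsum("b h i j, b h j d -> b h i d", attn, v)
        out = rearrange(out, "b h n d -> b n (h d)")
        return self.fc(out)


if __name__ == "__main__":
    attention = Attention(dim=128, num_heads=4, causal=True)
    x = torch.randn(2, 16, 128)
    print(attention(x).shape)  # torch.Size([2, 16, 128])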