Diffstat (limited to 'text_recognizer/networks/transformer/attention.py')
-rw-r--r-- | text_recognizer/networks/transformer/attention.py | 40
1 file changed, 24 insertions, 16 deletions
diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py
index 7bafc58..2770dc1 100644
--- a/text_recognizer/networks/transformer/attention.py
+++ b/text_recognizer/networks/transformer/attention.py
@@ -1,6 +1,7 @@
 """Implementes the attention module for the transformer."""
 from typing import Optional, Tuple
 
+import attr
 from einops import rearrange
 from einops.layers.torch import Rearrange
 import torch
@@ -14,31 +15,38 @@ from text_recognizer.networks.transformer.positional_encodings.rotary_embedding
 )
 
 
+@attr.s
 class Attention(nn.Module):
-    def __init__(
-        self,
-        dim: int,
-        num_heads: int,
-        dim_head: int = 64,
-        dropout_rate: float = 0.0,
-        causal: bool = False,
-    ) -> None:
+    """Standard attention."""
+
+    def __attrs_pre_init__(self) -> None:
         super().__init__()
-        self.scale = dim ** -0.5
-        self.num_heads = num_heads
-        self.causal = causal
-        inner_dim = dim * dim_head
+
+    dim: int = attr.ib()
+    num_heads: int = attr.ib()
+    dim_head: int = attr.ib(default=64)
+    dropout_rate: float = attr.ib(default=0.0)
+    casual: bool = attr.ib(default=False)
+    scale: float = attr.ib(init=False)
+    dropout: nn.Dropout = attr.ib(init=False)
+    fc: nn.Linear = attr.ib(init=False)
+    qkv_fn: nn.Sequential = attr.ib(init=False)
+    attn_fn: F.softmax = attr.ib(init=False, default=F.softmax)
+
+    def __attrs_post_init__(self) -> None:
+        """Post init configuration."""
+        self.scale = self.dim ** -0.5
+        inner_dim = self.dim * self.dim_head
 
         # Attnetion
         self.qkv_fn = nn.Sequential(
-            nn.Linear(dim, 3 * inner_dim, bias=False),
+            nn.Linear(self.dim, 3 * inner_dim, bias=False),
             Rearrange("b n (qkv h d) -> qkv b h n d", qkv=3, h=self.num_heads),
         )
-        self.dropout = nn.Dropout(dropout_rate)
-        self.attn_fn = F.softmax
+        self.dropout = nn.Dropout(p=self.dropout_rate)
 
         # Feedforward
-        self.fc = nn.Linear(inner_dim, dim)
+        self.fc = nn.Linear(inner_dim, self.dim)
 
     @staticmethod
     def _apply_rotary_emb(
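
The change above replaces the hand-written __init__ with attrs-managed fields: __attrs_pre_init__ runs nn.Module.__init__ before attrs assigns the attr.ib() fields, and __attrs_post_init__ then builds the submodules from those fields. The following minimal sketch of the same pattern is not from this repository (the Projection class name and sizes are hypothetical) and assumes attrs >= 20.1.0, which introduced __attrs_pre_init__:

import attr
import torch
from torch import nn


@attr.s
class Projection(nn.Module):
    """Toy attrs-configured module mirroring the pattern used by Attention."""

    def __attrs_pre_init__(self) -> None:
        # Runs first in the attrs-generated __init__, so nn.Module's parameter
        # and module registries exist before attrs assigns the fields below.
        super().__init__()

    dim: int = attr.ib()
    inner_dim: int = attr.ib(default=64)
    fc: nn.Linear = attr.ib(init=False)

    def __attrs_post_init__(self) -> None:
        # Build submodules from the configured fields, analogous to qkv_fn,
        # dropout, and fc in the diff above.
        self.fc = nn.Linear(self.dim, self.inner_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x)


# attrs generates __init__(self, dim, inner_dim=64); init=False fields are
# left for __attrs_post_init__ to assign.
proj = Projection(dim=128)
out = proj(torch.randn(2, 128))  # shape (2, 64)

One caveat with this pattern: @attr.s generates __eq__ and sets __hash__ to None by default, which makes instances unhashable, and nn.Module methods that collect modules in a set (for example named_modules) then fail; passing eq=False to attr.s avoids that when those methods are needed.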