Diffstat (limited to 'text_recognizer/networks/transformer/attention.py')
-rw-r--r--  text_recognizer/networks/transformer/attention.py | 40
1 file changed, 24 insertions(+), 16 deletions(-)
diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py
index 7bafc58..2770dc1 100644
--- a/text_recognizer/networks/transformer/attention.py
+++ b/text_recognizer/networks/transformer/attention.py
@@ -1,6 +1,7 @@
 """Implements the attention module for the transformer."""
 from typing import Optional, Tuple
 
+import attr
 from einops import rearrange
 from einops.layers.torch import Rearrange
 import torch
@@ -14,31 +15,38 @@ from text_recognizer.networks.transformer.positional_encodings.rotary_embedding
 )
 
 
+@attr.s
 class Attention(nn.Module):
-    def __init__(
-        self,
-        dim: int,
-        num_heads: int,
-        dim_head: int = 64,
-        dropout_rate: float = 0.0,
-        causal: bool = False,
-    ) -> None:
+    """Standard attention."""
+
+    def __attrs_pre_init__(self) -> None:
         super().__init__()
-        self.scale = dim ** -0.5
-        self.num_heads = num_heads
-        self.causal = causal
-        inner_dim = dim * dim_head
+
+    dim: int = attr.ib()
+    num_heads: int = attr.ib()
+    dim_head: int = attr.ib(default=64)
+    dropout_rate: float = attr.ib(default=0.0)
+    causal: bool = attr.ib(default=False)
+    scale: float = attr.ib(init=False)
+    dropout: nn.Dropout = attr.ib(init=False)
+    fc: nn.Linear = attr.ib(init=False)
+    qkv_fn: nn.Sequential = attr.ib(init=False)
+    attn_fn: F.softmax = attr.ib(init=False, default=F.softmax)
+
+    def __attrs_post_init__(self) -> None:
+        """Post init configuration."""
+        self.scale = self.dim ** -0.5
+        inner_dim = self.dim * self.dim_head
         # Attention
         self.qkv_fn = nn.Sequential(
-            nn.Linear(dim, 3 * inner_dim, bias=False),
+            nn.Linear(self.dim, 3 * inner_dim, bias=False),
             Rearrange("b n (qkv h d) -> qkv b h n d", qkv=3, h=self.num_heads),
         )
-        self.dropout = nn.Dropout(dropout_rate)
-        self.attn_fn = F.softmax
+        self.dropout = nn.Dropout(p=self.dropout_rate)
         # Feedforward
-        self.fc = nn.Linear(inner_dim, self.dim)
+        self.fc = nn.Linear(inner_dim, self.dim)
 
     @staticmethod
     def _apply_rotary_emb(
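
For reference, below is a minimal, self-contained sketch of the attrs-on-nn.Module pattern this patch introduces: __attrs_pre_init__ runs nn.Module.__init__ before attrs assigns the declared fields, and __attrs_post_init__ then builds the submodules from those fields. The sketch is illustrative only, not the repository class: the name ScaledDotProductAttention, its forward method, and the per-head scaling are invented for the example; it drops the einops Rearrange, rotary embeddings, and causal masking; and it uses attr.s(eq=False) where the patch uses plain @attr.s, so the module keeps identity-based hashing. It assumes attrs >= 20.3.0, which added __attrs_pre_init__.

import attr
import torch
import torch.nn.functional as F
from torch import nn


@attr.s(eq=False)  # eq=False: keep nn.Module's identity-based __hash__ (attrs' default eq=True makes instances unhashable)
class ScaledDotProductAttention(nn.Module):
    """Illustrative attrs-based multi-head attention (no rotary embeddings, no causal mask)."""

    dim: int = attr.ib()
    num_heads: int = attr.ib()
    dropout_rate: float = attr.ib(default=0.0)

    def __attrs_pre_init__(self) -> None:
        # Called by the attrs-generated __init__ before any field is assigned,
        # so nn.Module's bookkeeping (_parameters, _modules, ...) exists in time.
        super().__init__()

    def __attrs_post_init__(self) -> None:
        # Called after attrs has assigned dim / num_heads / dropout_rate above.
        assert self.dim % self.num_heads == 0, "dim must be divisible by num_heads"
        self.scale = (self.dim // self.num_heads) ** -0.5
        self.qkv = nn.Linear(self.dim, 3 * self.dim, bias=False)
        self.dropout = nn.Dropout(p=self.dropout_rate)
        self.fc = nn.Linear(self.dim, self.dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, n, _ = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # (b, n, dim) -> (b, heads, n, head_dim)
        q, k, v = (t.reshape(b, n, self.num_heads, -1).transpose(1, 2) for t in (q, k, v))
        energy = torch.einsum("bhid,bhjd->bhij", q, k) * self.scale
        attn = self.dropout(F.softmax(energy, dim=-1))
        out = torch.einsum("bhij,bhjd->bhid", attn, v)
        return self.fc(out.transpose(1, 2).reshape(b, n, -1))


if __name__ == "__main__":
    # attrs generates __init__, so construction looks like a hand-written keyword-argument API.
    attention = ScaledDotProductAttention(dim=64, num_heads=4, dropout_rate=0.1)
    x = torch.randn(2, 16, 64)
    print(attention(x).shape)  # torch.Size([2, 16, 64])

The __attrs_pre_init__ hook is the key piece: without it, attrs would assign fields before nn.Module.__init__ had run, and assigning the nn.Linear submodules in __attrs_post_init__ would fail with "cannot assign module before Module.__init__() call".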