summaryrefslogtreecommitdiff
path: root/text_recognizer/networks
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r--text_recognizer/networks/transformer/attention.py42
1 files changed, 22 insertions, 20 deletions
diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py
index aa15b88..3df5333 100644
--- a/text_recognizer/networks/transformer/attention.py
+++ b/text_recognizer/networks/transformer/attention.py
@@ -1,7 +1,6 @@
"""Implementes the attention module for the transformer."""
from typing import Optional, Tuple
-from attrs import define, field
from einops import rearrange
import torch
from torch import einsum
@@ -15,30 +14,33 @@ from text_recognizer.networks.transformer.embeddings.rotary import (
)
-@define(eq=False)
class Attention(nn.Module):
"""Standard attention."""
- def __attrs_pre_init__(self) -> None:
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ causal: bool = False,
+ dim_head: int = 64,
+ dropout_rate: float = 0.0,
+ rotary_embedding: Optional[RotaryEmbedding] = None,
+ ) -> None:
super().__init__()
- dim: int = field()
- num_heads: int = field()
- causal: bool = field(default=False)
- dim_head: int = field(default=64)
- dropout_rate: float = field(default=0.0)
- rotary_embedding: Optional[RotaryEmbedding] = field(default=None)
- scale: float = field(init=False)
- dropout: nn.Dropout = field(init=False)
- fc: nn.Linear = field(init=False)
-
- def __attrs_post_init__(self) -> None:
+ self.dim = dim
+ self.num_heads = num_heads
+ self.causal = causal
+ self.dim_head = dim_head
+ self.dropout_rate = dropout_rate
+ self.rotary_embedding = rotary_embedding
+
self.scale = self.dim ** -0.5
inner_dim = self.num_heads * self.dim_head
- self.query = nn.Linear(self.dim, inner_dim, bias=False)
- self.key = nn.Linear(self.dim, inner_dim, bias=False)
- self.value = nn.Linear(self.dim, inner_dim, bias=False)
+ self.to_q = nn.Linear(self.dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(self.dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(self.dim, inner_dim, bias=False)
self.dropout = nn.Dropout(p=self.dropout_rate)
@@ -55,9 +57,9 @@ class Attention(nn.Module):
"""Computes the attention."""
b, n, _, device = *x.shape, x.device
- q = self.query(x)
- k = self.key(context) if context is not None else self.key(x)
- v = self.value(context) if context is not None else self.value(x)
+ q = self.to_q(x)
+ k = self.to_k(context) if context is not None else self.to_k(x)
+ v = self.to_v(context) if context is not None else self.to_v(x)
q, k, v = map(
lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.num_heads), (q, k, v)
)