summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/transformer
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks/transformer')
-rw-r--r--text_recognizer/networks/transformer/attention.py24
1 files changed, 11 insertions, 13 deletions
diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py
index 87792a9..aa15b88 100644
--- a/text_recognizer/networks/transformer/attention.py
+++ b/text_recognizer/networks/transformer/attention.py
@@ -1,7 +1,7 @@
"""Implementes the attention module for the transformer."""
from typing import Optional, Tuple
-import attr
+from attrs import define, field
from einops import rearrange
import torch
from torch import einsum
@@ -15,22 +15,22 @@ from text_recognizer.networks.transformer.embeddings.rotary import (
)
-@attr.s(eq=False)
+@define(eq=False)
class Attention(nn.Module):
"""Standard attention."""
def __attrs_pre_init__(self) -> None:
super().__init__()
- dim: int = attr.ib()
- num_heads: int = attr.ib()
- causal: bool = attr.ib(default=False)
- dim_head: int = attr.ib(default=64)
- dropout_rate: float = attr.ib(default=0.0)
- rotary_embedding: Optional[RotaryEmbedding] = attr.ib(default=None)
- scale: float = attr.ib(init=False)
- dropout: nn.Dropout = attr.ib(init=False)
- fc: nn.Linear = attr.ib(init=False)
+ dim: int = field()
+ num_heads: int = field()
+ causal: bool = field(default=False)
+ dim_head: int = field(default=64)
+ dropout_rate: float = field(default=0.0)
+ rotary_embedding: Optional[RotaryEmbedding] = field(default=None)
+ scale: float = field(init=False)
+ dropout: nn.Dropout = field(init=False)
+ fc: nn.Linear = field(init=False)
def __attrs_post_init__(self) -> None:
self.scale = self.dim ** -0.5
@@ -120,7 +120,6 @@ def apply_input_mask(
input_mask = q_mask * k_mask
energy = energy.masked_fill_(~input_mask, mask_value)
- del input_mask
return energy
@@ -133,5 +132,4 @@ def apply_causal_mask(
mask = rearrange(r, "i -> () () i ()") < rearrange(r, "j -> () () () j")
mask = F.pad(mask, (j - i, 0), value=False)
energy.masked_fill_(mask, mask_value)
- del mask
return energy