diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-06-01 23:10:12 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-06-01 23:10:12 +0200 |
commit | db86cef2d308f58325278061c6aa177a535e7e03 (patch) | |
tree | a013fa85816337269f9cdc5a8992813fa62d299d /text_recognizer/networks/transformer | |
parent | b980a281712a5b1ee7ee5bd8f5d4762cd91a070b (diff) |
Replace attr with attrs
Diffstat (limited to 'text_recognizer/networks/transformer')
-rw-r--r-- | text_recognizer/networks/transformer/attention.py | 24 |
1 files changed, 11 insertions, 13 deletions
diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py index 87792a9..aa15b88 100644 --- a/text_recognizer/networks/transformer/attention.py +++ b/text_recognizer/networks/transformer/attention.py @@ -1,7 +1,7 @@ """Implementes the attention module for the transformer.""" from typing import Optional, Tuple -import attr +from attrs import define, field from einops import rearrange import torch from torch import einsum @@ -15,22 +15,22 @@ from text_recognizer.networks.transformer.embeddings.rotary import ( ) -@attr.s(eq=False) +@define(eq=False) class Attention(nn.Module): """Standard attention.""" def __attrs_pre_init__(self) -> None: super().__init__() - dim: int = attr.ib() - num_heads: int = attr.ib() - causal: bool = attr.ib(default=False) - dim_head: int = attr.ib(default=64) - dropout_rate: float = attr.ib(default=0.0) - rotary_embedding: Optional[RotaryEmbedding] = attr.ib(default=None) - scale: float = attr.ib(init=False) - dropout: nn.Dropout = attr.ib(init=False) - fc: nn.Linear = attr.ib(init=False) + dim: int = field() + num_heads: int = field() + causal: bool = field(default=False) + dim_head: int = field(default=64) + dropout_rate: float = field(default=0.0) + rotary_embedding: Optional[RotaryEmbedding] = field(default=None) + scale: float = field(init=False) + dropout: nn.Dropout = field(init=False) + fc: nn.Linear = field(init=False) def __attrs_post_init__(self) -> None: self.scale = self.dim ** -0.5 @@ -120,7 +120,6 @@ def apply_input_mask( input_mask = q_mask * k_mask energy = energy.masked_fill_(~input_mask, mask_value) - del input_mask return energy @@ -133,5 +132,4 @@ def apply_causal_mask( mask = rearrange(r, "i -> () () i ()") < rearrange(r, "j -> () () () j") mask = F.pad(mask, (j - i, 0), value=False) energy.masked_fill_(mask, mask_value) - del mask return energy |