From db86cef2d308f58325278061c6aa177a535e7e03 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm
Date: Wed, 1 Jun 2022 23:10:12 +0200
Subject: Replace attr with attrs

---
 text_recognizer/networks/transformer/attention.py | 24 +++++++++++-------------
 1 file changed, 11 insertions(+), 13 deletions(-)

(limited to 'text_recognizer/networks/transformer/attention.py')

diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py
index 87792a9..aa15b88 100644
--- a/text_recognizer/networks/transformer/attention.py
+++ b/text_recognizer/networks/transformer/attention.py
@@ -1,7 +1,7 @@
 """Implementes the attention module for the transformer."""
 from typing import Optional, Tuple
 
-import attr
+from attrs import define, field
 from einops import rearrange
 import torch
 from torch import einsum
@@ -15,22 +15,22 @@ from text_recognizer.networks.transformer.embeddings.rotary import (
 )
 
 
-@attr.s(eq=False)
+@define(eq=False)
 class Attention(nn.Module):
     """Standard attention."""
 
     def __attrs_pre_init__(self) -> None:
         super().__init__()
 
-    dim: int = attr.ib()
-    num_heads: int = attr.ib()
-    causal: bool = attr.ib(default=False)
-    dim_head: int = attr.ib(default=64)
-    dropout_rate: float = attr.ib(default=0.0)
-    rotary_embedding: Optional[RotaryEmbedding] = attr.ib(default=None)
-    scale: float = attr.ib(init=False)
-    dropout: nn.Dropout = attr.ib(init=False)
-    fc: nn.Linear = attr.ib(init=False)
+    dim: int = field()
+    num_heads: int = field()
+    causal: bool = field(default=False)
+    dim_head: int = field(default=64)
+    dropout_rate: float = field(default=0.0)
+    rotary_embedding: Optional[RotaryEmbedding] = field(default=None)
+    scale: float = field(init=False)
+    dropout: nn.Dropout = field(init=False)
+    fc: nn.Linear = field(init=False)
 
     def __attrs_post_init__(self) -> None:
         self.scale = self.dim ** -0.5
@@ -120,7 +120,6 @@ def apply_input_mask(
         input_mask = q_mask * k_mask
 
         energy = energy.masked_fill_(~input_mask, mask_value)
-        del input_mask
     return energy
 
 
@@ -133,5 +132,4 @@ def apply_causal_mask(
     mask = rearrange(r, "i -> () () i ()") < rearrange(r, "j -> () () () j")
     mask = F.pad(mask, (j - i, 0), value=False)
     energy.masked_fill_(mask, mask_value)
-    del mask
     return energy
-- 
cgit v1.2.3-70-g09d2
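
For reference, below is a minimal, self-contained sketch of the attrs pattern this patch migrates to: `define`/`field` replace `attr.s`/`attr.ib`, `__attrs_pre_init__` runs `nn.Module.__init__` before attrs assigns any field, and `init=False` fields are built afterwards in `__attrs_post_init__`. The class and field names are illustrative only and do not come from the repository; it assumes attrs >= 20.3.0 (for `__attrs_pre_init__`) and a working PyTorch install.

from attrs import define, field
import torch
from torch import nn


@define(eq=False)
class TinyProjection(nn.Module):
    """Toy module declared with attrs instead of a hand-written __init__."""

    # Declared fields become parameters of the attrs-generated __init__.
    dim: int = field()
    dim_out: int = field(default=64)
    dropout_rate: float = field(default=0.0)
    # init=False fields are excluded from __init__ and set up later.
    dropout: nn.Dropout = field(init=False)
    fc: nn.Linear = field(init=False)

    def __attrs_pre_init__(self) -> None:
        # nn.Module.__init__ must run before attrs assigns any field,
        # otherwise registering parameters and submodules fails.
        super().__init__()

    def __attrs_post_init__(self) -> None:
        # Derive the init=False fields from the declared ones, mirroring
        # how the patched Attention class builds scale/dropout/fc.
        self.dropout = nn.Dropout(p=self.dropout_rate)
        self.fc = nn.Linear(self.dim, self.dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.dropout(self.fc(x))


if __name__ == "__main__":
    module = TinyProjection(dim=32)
    print(module(torch.randn(2, 32)).shape)  # torch.Size([2, 64])

eq=False keeps identity-based hashing from nn.Module, which is why the decorator in the patch is written as @define(eq=False) rather than plain @define.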