path: root/text_recognizer/networks/transformer/attention.py
author     Gustaf Rydholm <gustaf.rydholm@gmail.com>  2021-07-28 15:14:55 +0200
committer  Gustaf Rydholm <gustaf.rydholm@gmail.com>  2021-07-28 15:14:55 +0200
commit     c032ffb05a7ed86f8fe5d596f94e8997c558cae8 (patch)
tree       bf890ffd4c815db7d510cfb281d253b5728f70c6 /text_recognizer/networks/transformer/attention.py
parent     524bf4351ac295bd4ff9914bb1f32eda7f7ff855 (diff)
Reformatting with attrs, config for encoder and decoder
Diffstat (limited to 'text_recognizer/networks/transformer/attention.py')
-rw-r--r--  text_recognizer/networks/transformer/attention.py | 40 ++++++++++++++++++++++++----------------
1 file changed, 24 insertions(+), 16 deletions(-)
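For context, the attrs idiom the diff below adopts: __attrs_pre_init__ runs before the attrs-generated __init__ assigns any fields, making it the right hook for calling nn.Module.__init__ (which must run before submodules can be attached), while __attrs_post_init__ runs after the fields are set and can build submodules from them. A minimal sketch of the pattern with a hypothetical Projection module (illustrative only, not part of the commit):

    import attr
    import torch
    from torch import nn


    @attr.s(eq=False)  # eq=False keeps nn.Module's __hash__, which .modules()/.parameters() rely on
    class Projection(nn.Module):
        """Hypothetical attrs-based module showing the init hooks."""

        in_dim: int = attr.ib()
        out_dim: int = attr.ib()
        fc: nn.Linear = attr.ib(init=False)

        def __attrs_pre_init__(self) -> None:
            # nn.Module.__init__ must run before attrs assigns fields;
            # otherwise assigning nn.Linear below would fail in __setattr__.
            super().__init__()

        def __attrs_post_init__(self) -> None:
            # All attr.ib fields are set at this point.
            self.fc = nn.Linear(self.in_dim, self.out_dim)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.fc(x)

Projection(in_dim=8, out_dim=4) then behaves like any other nn.Module.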
diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py
index 7bafc58..2770dc1 100644
--- a/text_recognizer/networks/transformer/attention.py
+++ b/text_recognizer/networks/transformer/attention.py
@@ -1,6 +1,7 @@
"""Implementes the attention module for the transformer."""
-from typing import Optional, Tuple
+from typing import Callable, Optional, Tuple
+import attr
from einops import rearrange
from einops.layers.torch import Rearrange
import torch
@@ -14,31 +15,38 @@ from text_recognizer.networks.transformer.positional_encodings.rotary_embedding
)
+@attr.s(eq=False)  # eq=False preserves nn.Module's __hash__
class Attention(nn.Module):
- def __init__(
- self,
- dim: int,
- num_heads: int,
- dim_head: int = 64,
- dropout_rate: float = 0.0,
- causal: bool = False,
- ) -> None:
+ """Standard attention."""
+
+ def __attrs_pre_init__(self) -> None:
super().__init__()
- self.scale = dim ** -0.5
- self.num_heads = num_heads
- self.causal = causal
- inner_dim = dim * dim_head
+
+ dim: int = attr.ib()
+ num_heads: int = attr.ib()
+ dim_head: int = attr.ib(default=64)
+ dropout_rate: float = attr.ib(default=0.0)
+ causal: bool = attr.ib(default=False)
+ scale: float = attr.ib(init=False)
+ dropout: nn.Dropout = attr.ib(init=False)
+ fc: nn.Linear = attr.ib(init=False)
+ qkv_fn: nn.Sequential = attr.ib(init=False)
+ attn_fn: Callable = attr.ib(init=False, default=F.softmax)
+
+ def __attrs_post_init__(self) -> None:
+ """Post init configuration."""
+ self.scale = self.dim ** -0.5
+ inner_dim = self.num_heads * self.dim_head  # total width across all heads
# Attention
self.qkv_fn = nn.Sequential(
- nn.Linear(dim, 3 * inner_dim, bias=False),
+ nn.Linear(self.dim, 3 * inner_dim, bias=False),
Rearrange("b n (qkv h d) -> qkv b h n d", qkv=3, h=self.num_heads),
)
- self.dropout = nn.Dropout(dropout_rate)
- self.attn_fn = F.softmax
+ self.dropout = nn.Dropout(p=self.dropout_rate)
# Feedforward
- self.fc = nn.Linear(inner_dim, dim)
+ self.fc = nn.Linear(inner_dim, self.dim)
@staticmethod
def _apply_rotary_emb(
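After the refactor, construction goes through the attrs-generated __init__ and the layers are built in __attrs_post_init__. A hedged usage sketch (the forward() signature lies outside this hunk, so the single-tensor call and the output shape below are assumptions):

    import torch

    # Hypothetical parameters: dim is the model width, dim_head the per-head width.
    attn = Attention(dim=256, num_heads=8, dim_head=64, dropout_rate=0.1, causal=True)
    x = torch.randn(1, 32, 256)  # (batch, sequence length, dim)
    out = attn(x)  # assumed to return a tensor of shape (1, 32, 256)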