From d020059f2f71fe7c25765dde9d535195c09ece01 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm
Date: Sun, 3 Sep 2023 01:14:16 +0200
Subject: Update imports

---
 text_recognizer/network/transformer/attention.py | 20 +++++++++++---------
 1 file changed, 11 insertions(+), 9 deletions(-)

(limited to 'text_recognizer/network/transformer/attention.py')

diff --git a/text_recognizer/network/transformer/attention.py b/text_recognizer/network/transformer/attention.py
index 8e18f8a..dab2c7b 100644
--- a/text_recognizer/network/transformer/attention.py
+++ b/text_recognizer/network/transformer/attention.py
@@ -1,12 +1,11 @@
 """Implements the attention module for the transformer."""
 from typing import Optional
 
-from text_recognizer.network.transformer.norm import RMSNorm
-from text_recognizer.network.transformer.attend import Attend
-import torch
 from einops import rearrange
 from torch import Tensor, nn
 
+from .attend import Attend
+
 
 class Attention(nn.Module):
     """Standard attention."""
@@ -23,18 +22,19 @@ class Attention(nn.Module):
         super().__init__()
         self.heads = heads
         inner_dim = dim_head * heads
+        self.scale = dim**-0.5
+        self.causal = causal
+        self.dropout_rate = dropout_rate
+        self.dropout = nn.Dropout(p=self.dropout_rate)
+
         self.norm = nn.LayerNorm(dim)
         self.to_q = nn.Linear(dim, inner_dim, bias=False)
         self.to_k = nn.Linear(dim, inner_dim, bias=False)
         self.to_v = nn.Linear(dim, inner_dim, bias=False)
-        # self.q_norm = RMSNorm(heads, dim_head)
-        # self.k_norm = RMSNorm(heads, dim_head)
+
         self.attend = Attend(use_flash)
+
         self.to_out = nn.Linear(inner_dim, dim, bias=False)
-        self.scale = dim**-0.5
-        self.causal = causal
-        self.dropout_rate = dropout_rate
-        self.dropout = nn.Dropout(p=self.dropout_rate)
 
     def forward(
         self,
@@ -47,9 +47,11 @@ class Attention(nn.Module):
         q = self.to_q(x)
         k = self.to_k(x if context is None else context)
         v = self.to_v(x if context is None else context)
+
         q, k, v = map(
             lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), (q, k, v)
        )
+
         out = self.attend(q, k, v, self.causal, mask)
         out = rearrange(out, "b h n d -> b n (h d)")
         out = self.to_out(out)
--
cgit v1.2.3-70-g09d2
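
A minimal usage sketch of the touched module, not part of the patch itself: the constructor arguments (dim, heads, causal, dim_head, dropout_rate, use_flash) are assumed from the attributes assigned in __init__ above, and the forward signature (x, optional context, optional mask) is assumed from the body shown in the hunks; the full signatures lie outside the diff context.

    import torch

    from text_recognizer.network.transformer.attention import Attention

    # Assumed constructor arguments, inferred from the attributes set in __init__.
    attention = Attention(
        dim=256,
        heads=8,
        causal=False,
        dim_head=64,
        dropout_rate=0.0,
        use_flash=True,
    )

    x = torch.randn(1, 32, 256)  # (batch, sequence length, dim)
    out = attention(x)           # self-attention: context and mask assumed to default to None
    print(out.shape)             # expected: torch.Size([1, 32, 256]), since to_out maps back to dim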