diff options
Diffstat (limited to 'text_recognizer/network/transformer')
6 files changed, 20 insertions, 17 deletions
diff --git a/text_recognizer/network/transformer/attend.py b/text_recognizer/network/transformer/attend.py
index 4e643fb..d2c17b1 100644
--- a/text_recognizer/network/transformer/attend.py
+++ b/text_recognizer/network/transformer/attend.py
@@ -32,7 +32,7 @@ class Attend(nn.Module):
         out = F.scaled_dot_product_attention(q, k, v, is_causal=causal)
         return out
 
-    def atten(
+    def attn(
         self,
         q: Tensor,
         k: Tensor,
@@ -66,7 +66,7 @@ class Attend(nn.Module):
 
         if self.use_flash:
             return self.flash_attn(q, k, v, causal)
         else:
-            return self.atten(q, k, v, causal, mask)
+            return self.attn(q, k, v, causal, mask)
 
     def apply_input_mask(
diff --git a/text_recognizer/network/transformer/attention.py b/text_recognizer/network/transformer/attention.py
index 8e18f8a..dab2c7b 100644
--- a/text_recognizer/network/transformer/attention.py
+++ b/text_recognizer/network/transformer/attention.py
@@ -1,12 +1,11 @@
 """Implements the attention module for the transformer."""
 from typing import Optional
 
-from text_recognizer.network.transformer.norm import RMSNorm
-from text_recognizer.network.transformer.attend import Attend
-import torch
 from einops import rearrange
 from torch import Tensor, nn
 
+from .attend import Attend
+
 
 class Attention(nn.Module):
     """Standard attention."""
@@ -23,18 +22,19 @@ class Attention(nn.Module):
         super().__init__()
         self.heads = heads
         inner_dim = dim_head * heads
+        self.scale = dim**-0.5
+        self.causal = causal
+        self.dropout_rate = dropout_rate
+        self.dropout = nn.Dropout(p=self.dropout_rate)
+        self.norm = nn.LayerNorm(dim)
 
         self.to_q = nn.Linear(dim, inner_dim, bias=False)
         self.to_k = nn.Linear(dim, inner_dim, bias=False)
         self.to_v = nn.Linear(dim, inner_dim, bias=False)
 
-        # self.q_norm = RMSNorm(heads, dim_head)
-        # self.k_norm = RMSNorm(heads, dim_head)
+        self.attend = Attend(use_flash)
+        self.to_out = nn.Linear(inner_dim, dim, bias=False)
 
-        self.scale = dim**-0.5
-        self.causal = causal
-        self.dropout_rate = dropout_rate
-        self.dropout = nn.Dropout(p=self.dropout_rate)
 
     def forward(
         self,
@@ -47,9 +47,11 @@ class Attention(nn.Module):
         q = self.to_q(x)
         k = self.to_k(x if context is None else context)
         v = self.to_v(x if context is None else context)
+
         q, k, v = map(
             lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), (q, k, v)
         )
+
         out = self.attend(q, k, v, self.causal, mask)
         out = rearrange(out, "b h n d -> b n (h d)")
         out = self.to_out(out)
diff --git a/text_recognizer/network/transformer/decoder.py b/text_recognizer/network/transformer/decoder.py
index 06925ba..24a8ac4 100644
--- a/text_recognizer/network/transformer/decoder.py
+++ b/text_recognizer/network/transformer/decoder.py
@@ -2,8 +2,8 @@
 from typing import Optional
 
 from torch import Tensor, nn
 
-from text_recognizer.network.transformer.attention import Attention
-from text_recognizer.network.transformer.ff import FeedForward
+from .attention import Attention
+from .ff import FeedForward
 
 class Decoder(nn.Module):
diff --git a/text_recognizer/network/transformer/embedding/absolute.py b/text_recognizer/network/transformer/embedding/absolute.py
index 08b2c2a..db34157 100644
--- a/text_recognizer/network/transformer/embedding/absolute.py
+++ b/text_recognizer/network/transformer/embedding/absolute.py
@@ -2,7 +2,8 @@
 from typing import Optional
 
 import torch
 from torch import nn, Tensor
-from text_recognizer.network.transformer.embedding.l2_norm import l2_norm
+
+from .l2_norm import l2_norm
 
 class AbsolutePositionalEmbedding(nn.Module):
diff --git a/text_recognizer/network/transformer/embedding/token.py b/text_recognizer/network/transformer/embedding/token.py
index 1df2fd6..838f514 100644
--- a/text_recognizer/network/transformer/embedding/token.py
+++ b/text_recognizer/network/transformer/embedding/token.py
@@ -1,6 +1,6 @@
 from torch import nn, Tensor
 
-from text_recognizer.network.transformer.embedding.l2_norm import l2_norm
+from .l2_norm import l2_norm
 
 
 class TokenEmbedding(nn.Module):
diff --git a/text_recognizer/network/transformer/encoder.py b/text_recognizer/network/transformer/encoder.py
index ea4b0b3..328a40c 100644
--- a/text_recognizer/network/transformer/encoder.py
+++ b/text_recognizer/network/transformer/encoder.py
@@ -1,8 +1,8 @@
 """Transformer encoder module."""
 from torch import Tensor, nn
 
-from text_recognizer.network.transformer.attention import Attention
-from text_recognizer.network.transformer.ff import FeedForward
+from .attention import Attention
+from .ff import FeedForward
 
 
 class Encoder(nn.Module):