From d020059f2f71fe7c25765dde9d535195c09ece01 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sun, 3 Sep 2023 01:14:16 +0200 Subject: Update imports --- text_recognizer/network/transformer/attend.py | 4 ++-- text_recognizer/network/transformer/attention.py | 20 +++++++++++--------- text_recognizer/network/transformer/decoder.py | 4 ++-- .../network/transformer/embedding/absolute.py | 3 ++- .../network/transformer/embedding/token.py | 2 +- text_recognizer/network/transformer/encoder.py | 4 ++-- text_recognizer/network/vit.py | 11 +++++------ 7 files changed, 25 insertions(+), 23 deletions(-) diff --git a/text_recognizer/network/transformer/attend.py b/text_recognizer/network/transformer/attend.py index 4e643fb..d2c17b1 100644 --- a/text_recognizer/network/transformer/attend.py +++ b/text_recognizer/network/transformer/attend.py @@ -32,7 +32,7 @@ class Attend(nn.Module): out = F.scaled_dot_product_attention(q, k, v, is_causal=causal) return out - def atten( + def attn( self, q: Tensor, k: Tensor, @@ -66,7 +66,7 @@ class Attend(nn.Module): if self.use_flash: return self.flash_attn(q, k, v, causal) else: - return self.atten(q, k, v, causal, mask) + return self.attn(q, k, v, causal, mask) def apply_input_mask( diff --git a/text_recognizer/network/transformer/attention.py b/text_recognizer/network/transformer/attention.py index 8e18f8a..dab2c7b 100644 --- a/text_recognizer/network/transformer/attention.py +++ b/text_recognizer/network/transformer/attention.py @@ -1,12 +1,11 @@ """Implements the attention module for the transformer.""" from typing import Optional -from text_recognizer.network.transformer.norm import RMSNorm -from text_recognizer.network.transformer.attend import Attend -import torch from einops import rearrange from torch import Tensor, nn +from .attend import Attend + class Attention(nn.Module): """Standard attention.""" @@ -23,18 +22,19 @@ class Attention(nn.Module): super().__init__() self.heads = heads inner_dim = dim_head * heads + self.scale = dim**-0.5 + self.causal = causal + self.dropout_rate = dropout_rate + self.dropout = nn.Dropout(p=self.dropout_rate) + self.norm = nn.LayerNorm(dim) self.to_q = nn.Linear(dim, inner_dim, bias=False) self.to_k = nn.Linear(dim, inner_dim, bias=False) self.to_v = nn.Linear(dim, inner_dim, bias=False) - # self.q_norm = RMSNorm(heads, dim_head) - # self.k_norm = RMSNorm(heads, dim_head) + self.attend = Attend(use_flash) + self.to_out = nn.Linear(inner_dim, dim, bias=False) - self.scale = dim**-0.5 - self.causal = causal - self.dropout_rate = dropout_rate - self.dropout = nn.Dropout(p=self.dropout_rate) def forward( self, @@ -47,9 +47,11 @@ class Attention(nn.Module): q = self.to_q(x) k = self.to_k(x if context is None else context) v = self.to_v(x if context is None else context) + q, k, v = map( lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), (q, k, v) ) + out = self.attend(q, k, v, self.causal, mask) out = rearrange(out, "b h n d -> b n (h d)") out = self.to_out(out) diff --git a/text_recognizer/network/transformer/decoder.py b/text_recognizer/network/transformer/decoder.py index 06925ba..24a8ac4 100644 --- a/text_recognizer/network/transformer/decoder.py +++ b/text_recognizer/network/transformer/decoder.py @@ -2,8 +2,8 @@ from typing import Optional from torch import Tensor, nn -from text_recognizer.network.transformer.attention import Attention -from text_recognizer.network.transformer.ff import FeedForward +from .attention import Attention +from .ff import FeedForward class Decoder(nn.Module): diff --git a/text_recognizer/network/transformer/embedding/absolute.py b/text_recognizer/network/transformer/embedding/absolute.py index 08b2c2a..db34157 100644 --- a/text_recognizer/network/transformer/embedding/absolute.py +++ b/text_recognizer/network/transformer/embedding/absolute.py @@ -2,7 +2,8 @@ from typing import Optional import torch from torch import nn, Tensor -from text_recognizer.network.transformer.embedding.l2_norm import l2_norm + +from .l2_norm import l2_norm class AbsolutePositionalEmbedding(nn.Module): diff --git a/text_recognizer/network/transformer/embedding/token.py b/text_recognizer/network/transformer/embedding/token.py index 1df2fd6..838f514 100644 --- a/text_recognizer/network/transformer/embedding/token.py +++ b/text_recognizer/network/transformer/embedding/token.py @@ -1,6 +1,6 @@ from torch import nn, Tensor -from text_recognizer.network.transformer.embedding.l2_norm import l2_norm +from .l2_norm import l2_norm class TokenEmbedding(nn.Module): diff --git a/text_recognizer/network/transformer/encoder.py b/text_recognizer/network/transformer/encoder.py index ea4b0b3..328a40c 100644 --- a/text_recognizer/network/transformer/encoder.py +++ b/text_recognizer/network/transformer/encoder.py @@ -1,8 +1,8 @@ """Transformer encoder module.""" from torch import Tensor, nn -from text_recognizer.network.transformer.attention import Attention -from text_recognizer.network.transformer.ff import FeedForward +from .attention import Attention +from .ff import FeedForward class Encoder(nn.Module): diff --git a/text_recognizer/network/vit.py b/text_recognizer/network/vit.py index b6203d7..1fbf3fc 100644 --- a/text_recognizer/network/vit.py +++ b/text_recognizer/network/vit.py @@ -4,10 +4,10 @@ from typing import Type from einops.layers.torch import Rearrange from torch import Tensor, nn -from text_recognizer.network.transformer.embedding.token import TokenEmbedding -from text_recognizer.network.transformer.embedding.sincos import sincos_2d -from text_recognizer.network.transformer.decoder import Decoder -from text_recognizer.network.transformer.encoder import Encoder +from .transformer.embedding.token import TokenEmbedding +from .transformer.embedding.sincos import sincos_2d +from .transformer.decoder import Decoder +from .transformer.encoder import Encoder class VisionTransformer(nn.Module): @@ -59,11 +59,10 @@ class VisionTransformer(nn.Module): def decode(self, text: Tensor, img_features: Tensor) -> Tensor: text = text.long() - # TODO: add mask to decoder mask = text != self.pad_index tokens = self.token_embedding(text) tokens = tokens + self.pos_embedding(tokens) - output = self.decoder(tokens, context=img_features) + output = self.decoder(tokens, context=img_features, mask=mask) return self.to_logits(output) def forward( -- cgit v1.2.3-70-g09d2