From d020059f2f71fe7c25765dde9d535195c09ece01 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Sun, 3 Sep 2023 01:14:16 +0200
Subject: Use relative imports and pass decoder padding mask

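Switch the transformer modules to relative imports, rename Attend.atten
to Attend.attn, group the Attention hyperparameters at the top of
__init__, and drop the commented-out RMSNorm lines. Also resolve the
decoder-mask TODO in VisionTransformer.decode by forwarding the padding
mask to the decoder:

    mask = text != self.pad_index
    output = self.decoder(tokens, context=img_features, mask=mask)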
---
 text_recognizer/network/transformer/attend.py        |  4 ++--
 text_recognizer/network/transformer/attention.py     | 20 +++++++++++---------
 text_recognizer/network/transformer/decoder.py       |  4 ++--
 .../network/transformer/embedding/absolute.py        |  3 ++-
 .../network/transformer/embedding/token.py           |  2 +-
 text_recognizer/network/transformer/encoder.py       |  4 ++--
 text_recognizer/network/vit.py                       | 11 +++++------
 7 files changed, 25 insertions(+), 23 deletions(-)

diff --git a/text_recognizer/network/transformer/attend.py b/text_recognizer/network/transformer/attend.py
index 4e643fb..d2c17b1 100644
--- a/text_recognizer/network/transformer/attend.py
+++ b/text_recognizer/network/transformer/attend.py
@@ -32,7 +32,7 @@ class Attend(nn.Module):
             out = F.scaled_dot_product_attention(q, k, v, is_causal=causal)
         return out
 
-    def atten(
+    def attn(
         self,
         q: Tensor,
         k: Tensor,
@@ -66,7 +66,7 @@ class Attend(nn.Module):
         if self.use_flash:
             return self.flash_attn(q, k, v, causal)
         else:
-            return self.atten(q, k, v, causal, mask)
+            return self.attn(q, k, v, causal, mask)
 
 
 def apply_input_mask(
diff --git a/text_recognizer/network/transformer/attention.py b/text_recognizer/network/transformer/attention.py
index 8e18f8a..dab2c7b 100644
--- a/text_recognizer/network/transformer/attention.py
+++ b/text_recognizer/network/transformer/attention.py
@@ -1,12 +1,11 @@
 """Implements the attention module for the transformer."""
 from typing import Optional
-from text_recognizer.network.transformer.norm import RMSNorm
-from text_recognizer.network.transformer.attend import Attend
 
-import torch
 from einops import rearrange
 from torch import Tensor, nn
 
+from .attend import Attend
+
 
 class Attention(nn.Module):
     """Standard attention."""
@@ -23,18 +22,19 @@ class Attention(nn.Module):
         super().__init__()
         self.heads = heads
         inner_dim = dim_head * heads
+        self.scale = dim**-0.5
+        self.causal = causal
+        self.dropout_rate = dropout_rate
+        self.dropout = nn.Dropout(p=self.dropout_rate)
+
         self.norm = nn.LayerNorm(dim)
         self.to_q = nn.Linear(dim, inner_dim, bias=False)
         self.to_k = nn.Linear(dim, inner_dim, bias=False)
         self.to_v = nn.Linear(dim, inner_dim, bias=False)
-        # self.q_norm = RMSNorm(heads, dim_head)
-        # self.k_norm = RMSNorm(heads, dim_head)
+
         self.attend = Attend(use_flash)
+
         self.to_out = nn.Linear(inner_dim, dim, bias=False)
-        self.scale = dim**-0.5
-        self.causal = causal
-        self.dropout_rate = dropout_rate
-        self.dropout = nn.Dropout(p=self.dropout_rate)
 
     def forward(
         self,
@@ -47,9 +47,11 @@ class Attention(nn.Module):
         q = self.to_q(x)
         k = self.to_k(x if context is None else context)
         v = self.to_v(x if context is None else context)
+
         q, k, v = map(
             lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), (q, k, v)
         )
+
         out = self.attend(q, k, v, self.causal, mask)
         out = rearrange(out, "b h n d -> b n (h d)")
         out = self.to_out(out)
diff --git a/text_recognizer/network/transformer/decoder.py b/text_recognizer/network/transformer/decoder.py
index 06925ba..24a8ac4 100644
--- a/text_recognizer/network/transformer/decoder.py
+++ b/text_recognizer/network/transformer/decoder.py
@@ -2,8 +2,8 @@
 from typing import Optional
 from torch import Tensor, nn
 
-from text_recognizer.network.transformer.attention import Attention
-from text_recognizer.network.transformer.ff import FeedForward
+from .attention import Attention
+from .ff import FeedForward
 
 
 class Decoder(nn.Module):
diff --git a/text_recognizer/network/transformer/embedding/absolute.py b/text_recognizer/network/transformer/embedding/absolute.py
index 08b2c2a..db34157 100644
--- a/text_recognizer/network/transformer/embedding/absolute.py
+++ b/text_recognizer/network/transformer/embedding/absolute.py
@@ -2,7 +2,8 @@ from typing import Optional
 
 import torch
 from torch import nn, Tensor
-from text_recognizer.network.transformer.embedding.l2_norm import l2_norm
+
+from .l2_norm import l2_norm
 
 
 class AbsolutePositionalEmbedding(nn.Module):
diff --git a/text_recognizer/network/transformer/embedding/token.py b/text_recognizer/network/transformer/embedding/token.py
index 1df2fd6..838f514 100644
--- a/text_recognizer/network/transformer/embedding/token.py
+++ b/text_recognizer/network/transformer/embedding/token.py
@@ -1,6 +1,6 @@
 from torch import nn, Tensor
 
-from text_recognizer.network.transformer.embedding.l2_norm import l2_norm
+from .l2_norm import l2_norm
 
 
 class TokenEmbedding(nn.Module):
diff --git a/text_recognizer/network/transformer/encoder.py b/text_recognizer/network/transformer/encoder.py
index ea4b0b3..328a40c 100644
--- a/text_recognizer/network/transformer/encoder.py
+++ b/text_recognizer/network/transformer/encoder.py
@@ -1,8 +1,8 @@
 """Transformer encoder module."""
 from torch import Tensor, nn
 
-from text_recognizer.network.transformer.attention import Attention
-from text_recognizer.network.transformer.ff import FeedForward
+from .attention import Attention
+from .ff import FeedForward
 
 
 class Encoder(nn.Module):
diff --git a/text_recognizer/network/vit.py b/text_recognizer/network/vit.py
index b6203d7..1fbf3fc 100644
--- a/text_recognizer/network/vit.py
+++ b/text_recognizer/network/vit.py
@@ -4,10 +4,10 @@ from typing import Type
 from einops.layers.torch import Rearrange
 from torch import Tensor, nn
 
-from text_recognizer.network.transformer.embedding.token import TokenEmbedding
-from text_recognizer.network.transformer.embedding.sincos import sincos_2d
-from text_recognizer.network.transformer.decoder import Decoder
-from text_recognizer.network.transformer.encoder import Encoder
+from .transformer.embedding.token import TokenEmbedding
+from .transformer.embedding.sincos import sincos_2d
+from .transformer.decoder import Decoder
+from .transformer.encoder import Encoder
 
 
 class VisionTransformer(nn.Module):
@@ -59,11 +59,10 @@ class VisionTransformer(nn.Module):
 
     def decode(self, text: Tensor, img_features: Tensor) -> Tensor:
         text = text.long()
-        # TODO: add mask to decoder
         mask = text != self.pad_index
         tokens = self.token_embedding(text)
         tokens = tokens + self.pos_embedding(tokens)
-        output = self.decoder(tokens, context=img_features)
+        output = self.decoder(tokens, context=img_features, mask=mask)
         return self.to_logits(output)
 
     def forward(
-- 
cgit v1.2.3-70-g09d2