From 737000da5b44276512beffc1bdf81057df43ab2c Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Sun, 2 May 2021 22:27:42 +0200
Subject: Attention layer finished

---
 notebooks/00-testing-stuff-out.ipynb              | 163 ++++++++++++++++++++++
 text_recognizer/models/vqvae.py                   |  12 +-
 text_recognizer/networks/transformer/attention.py |  72 ++++++++--
 training/conf/network/vqvae.yaml                  |   6 +-
 4 files changed, 233 insertions(+), 20 deletions(-)

diff --git a/notebooks/00-testing-stuff-out.ipynb b/notebooks/00-testing-stuff-out.ipynb
index 7c7b3a6..92faaf7 100644
--- a/notebooks/00-testing-stuff-out.ipynb
+++ b/notebooks/00-testing-stuff-out.ipynb
@@ -352,6 +352,169 @@
     "t(datum, trg).shape"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "b, n = 16, 128\n",
+    "device = \"cpu\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "x = lambda: torch.ones((b, n), device=device).bool()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([16, 128])"
+      ]
+     },
+     "execution_count": 10,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "x().shape"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([16, 128])"
+      ]
+     },
+     "execution_count": 12,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "torch.ones((b, n), device=device).bool().shape"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 24,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "x = torch.randn(1, 1, 576, 640)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 30,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "144"
+      ]
+     },
+     "execution_count": 30,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "576 // 4"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 29,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "160"
+      ]
+     },
+     "execution_count": 29,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "640 // 4"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 31,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "x = torch.randn(1, 1, 144, 160)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 32,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from einops import rearrange"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 35,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "patch_size=4\n",
+    "p = rearrange(x, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 36,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([1, 1440, 16])"
+      ]
+     },
+     "execution_count": 36,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "p.shape"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
diff --git a/text_recognizer/models/vqvae.py b/text_recognizer/models/vqvae.py
index ef2213c..078235e 100644
--- a/text_recognizer/models/vqvae.py
+++ b/text_recognizer/models/vqvae.py
@@ -29,12 +29,14 @@ class LitVQVAEModel(LitBaseModel):
         """Forward pass with the transformer network."""
         return self.network.predict(data)
 
-    def _log_prediction(self, data: Tensor, reconstructions: Tensor) -> None:
+    def _log_prediction(
+        self, data: Tensor, reconstructions: Tensor, title: str
+    ) -> None:
         """Logs prediction on image with wandb."""
         try:
             self.logger.experiment.log(
                 {
-                    "val_pred_examples": [
+                    title: [
                         wandb.Image(data[0]),
                         wandb.Image(reconstructions[0]),
                     ]
@@ -59,7 +61,8 @@ class LitVQVAEModel(LitBaseModel):
         loss = self.loss_fn(reconstructions, data)
         loss += vq_loss
         self.log("val_loss", loss, prog_bar=True)
-        self._log_prediction(data, reconstructions)
+        title = "val_pred_examples"
+        self._log_prediction(data, reconstructions, title)
 
     def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
         """Test step."""
@@ -67,4 +70,5 @@ class LitVQVAEModel(LitBaseModel):
         reconstructions, vq_loss = self.network(data)
         loss = self.loss_fn(reconstructions, data)
         loss += vq_loss
-        self._log_prediction(data, reconstructions)
+        title = "test_pred_examples"
+        self._log_prediction(data, reconstructions, title)
diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py
index 8724691..623d680 100644
--- a/text_recognizer/networks/transformer/attention.py
+++ b/text_recognizer/networks/transformer/attention.py
@@ -1,9 +1,10 @@
 """Implementes the attention module for the transformer."""
 from typing import Optional, Tuple
 
+from einops import rearrange
 from einops.layers.torch import Rearrange
-import numpy as np
 import torch
+from torch import einsum
 from torch import nn
 from torch import Tensor
 import torch.nn.functional as F
@@ -34,7 +35,7 @@ class Attention(nn.Module):
         self.attn_fn = F.softmax
 
         # Feedforward
-        self.proj = nn.Linear(inner_dim, dim)
+        self.fc = nn.Linear(inner_dim, dim)
 
     @staticmethod
     def _apply_rotary_emb(
@@ -47,8 +48,42 @@ class Attention(nn.Module):
         k = torch.cat((kl, kr), dim=-1)
         return q, k
 
-    def _cross_attention(self) -> Tensor:
-        pass
+    @staticmethod
+    def _compute_input_mask(
+        b: int,
+        n: int,
+        k: Tensor,
+        mask: Optional[Tensor],
+        context: Optional[Tensor],
+        context_mask: Optional[Tensor],
+        device: str,
+    ) -> Optional[Tensor]:
+        if any(x is not None for x in (mask, context_mask)):
+            q_mask = (
+                mask if mask is not None else torch.ones((b, n), device=device).bool()
+            )
+            k_mask = q_mask if context is not None else context_mask
+            k_mask = (
+                torch.ones((b, k.shape[-2]), device=device).bool()
+                if k_mask is None
+                else k_mask
+            )
+            q_mask = rearrange(q_mask, "b i -> b () i ()")
+            k_mask = rearrange(k_mask, "b i -> b () () j")
+            return q_mask * k_mask
+        return
+
+    @staticmethod
+    def _apply_causal_mask(
+        energy: Tensor, mask: Tensor, mask_value: Tensor, device: str
+    ) -> Tensor:
+        i, j = energy.shape[-2:]
+        r = torch.arange(i, device=device)
+        mask = rearrange(r, "i -> () () i ()") < rearrange(r, "j -> () () () j")
+        mask = F.pad(mask, (j - i, 0), value=False)
+        energy.masked_fill_(mask, mask_value)
+        del mask
+        return energy
 
     def forward(
         self,
@@ -67,14 +102,25 @@ class Attention(nn.Module):
             k,
         )
 
-        input_mask = None
-        if any(x is not None for x in (mask, context_mask)):
-            q_mask = (
-                mask
-                if mask is not None
-                else lambda: torch.ones((b, n), device=device).bool()
-            )
-            pass
+        input_mask = self._compute_input_mask(
+            b, n, k, mask, context, context_mask, device
+        )
 
         # Compute the attention
-        energy = (q @ k.transpose(-2, -1)) * self.scale
+        energy = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale
+        mask_value = -torch.finfo(energy.dtype).max
+
+        # Apply input mask
+        if input_mask is not None:
+            energy = energy.masked_fill_(~input_mask, mask_value)
+            del input_mask
+
+        if self.causal:
+            energy = self._apply_causal_mask(energy, mask, mask_value, device)
+
+        attn = self.attn_fn(energy, dim=-1)
+        attn = self.dropout(attn)
+        out = einsum("b h i j, b h j d -> b h i d", attn, v)
+        out = rearrange(out, "b h n d -> b n (h d)")
+        out = self.fc(out)
+        return out, attn
diff --git a/training/conf/network/vqvae.yaml b/training/conf/network/vqvae.yaml
index 8c30bbd..288d2aa 100644
--- a/training/conf/network/vqvae.yaml
+++ b/training/conf/network/vqvae.yaml
@@ -2,9 +2,9 @@
 type: VQVAE
 args:
     in_channels: 1
-    channels: [32, 64, 64]
-    kernel_sizes: [4, 4, 4]
-    strides: [2, 2, 2]
+    channels: [64, 96]
+    kernel_sizes: [4, 4]
+    strides: [2, 2]
     num_residual_layers: 2 
     embedding_dim: 64
     num_embeddings: 256
-- 
cgit v1.2.3-70-g09d2