From 737000da5b44276512beffc1bdf81057df43ab2c Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm
Date: Sun, 2 May 2021 22:27:42 +0200
Subject: Attention layer finished

---
 notebooks/00-testing-stuff-out.ipynb              | 163 ++++++++++++++++++++++
 text_recognizer/models/vqvae.py                   |  12 +-
 text_recognizer/networks/transformer/attention.py |  72 ++++++++--
 training/conf/network/vqvae.yaml                  |   6 +-
 4 files changed, 233 insertions(+), 20 deletions(-)

diff --git a/notebooks/00-testing-stuff-out.ipynb b/notebooks/00-testing-stuff-out.ipynb
index 7c7b3a6..92faaf7 100644
--- a/notebooks/00-testing-stuff-out.ipynb
+++ b/notebooks/00-testing-stuff-out.ipynb
@@ -352,6 +352,169 @@
     "t(datum, trg).shape"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "b, n = 16, 128\n",
+    "device = \"cpu\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "x = lambda: torch.ones((b, n), device=device).bool()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([16, 128])"
+      ]
+     },
+     "execution_count": 10,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "x().shape"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([16, 128])"
+      ]
+     },
+     "execution_count": 12,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "torch.ones((b, n), device=device).bool().shape"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 24,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "x = torch.randn(1, 1, 576, 640)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 30,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "144"
+      ]
+     },
+     "execution_count": 30,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "576 // 4"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 29,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "160"
+      ]
+     },
+     "execution_count": 29,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "640 // 4"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 31,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "x = torch.randn(1, 1, 144, 160)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 32,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from einops import rearrange"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 35,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "patch_size=4\n",
+    "p = rearrange(x, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 36,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([1, 1440, 16])"
+      ]
+     },
+     "execution_count": 36,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "p.shape"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
diff --git a/text_recognizer/models/vqvae.py b/text_recognizer/models/vqvae.py
index ef2213c..078235e 100644
--- a/text_recognizer/models/vqvae.py
+++ b/text_recognizer/models/vqvae.py
@@ -29,12 +29,14 @@ class LitVQVAEModel(LitBaseModel):
         """Forward pass with the transformer network."""
         return self.network.predict(data)
 
-    def _log_prediction(self, data: Tensor, reconstructions: Tensor) -> None:
+    def _log_prediction(
+        self, data: Tensor, reconstructions: Tensor, title: str
+    ) -> None:
         """Logs prediction on image with wandb."""
         try:
             self.logger.experiment.log(
                 {
-                    "val_pred_examples": [
+                    title: [
                         wandb.Image(data[0]),
                         wandb.Image(reconstructions[0]),
                     ]
@@ -59,7 +61,8 @@ class LitVQVAEModel(LitBaseModel):
         loss = self.loss_fn(reconstructions, data)
         loss += vq_loss
         self.log("val_loss", loss, prog_bar=True)
-        self._log_prediction(data, reconstructions)
+        title = "val_pred_examples"
+        self._log_prediction(data, reconstructions, title)
 
     def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
         """Test step."""
@@ -67,4 +70,5 @@ class LitVQVAEModel(LitBaseModel):
         reconstructions, vq_loss = self.network(data)
         loss = self.loss_fn(reconstructions, data)
         loss += vq_loss
-        self._log_prediction(data, reconstructions)
+        title = "test_pred_examples"
+        self._log_prediction(data, reconstructions, title)
diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py
index 8724691..623d680 100644
--- a/text_recognizer/networks/transformer/attention.py
+++ b/text_recognizer/networks/transformer/attention.py
@@ -1,9 +1,10 @@
 """Implementes the attention module for the transformer."""
 from typing import Optional, Tuple
 
+from einops import rearrange
 from einops.layers.torch import Rearrange
-import numpy as np
 import torch
+from torch import einsum
 from torch import nn
 from torch import Tensor
 import torch.nn.functional as F
@@ -34,7 +35,7 @@
         self.attn_fn = F.softmax
 
         # Feedforward
-        self.proj = nn.Linear(inner_dim, dim)
+        self.fc = nn.Linear(inner_dim, dim)
 
     @staticmethod
     def _apply_rotary_emb(
@@ -47,8 +48,42 @@
         k = torch.cat((kl, kr), dim=-1)
         return q, k
 
-    def _cross_attention(self) -> Tensor:
-        pass
+    @staticmethod
+    def _compute_input_mask(
+        b: int,
+        n: int,
+        k: Tensor,
+        mask: Optional[Tensor],
+        context: Optional[Tensor],
+        context_mask: Optional[Tensor],
+        device: str,
+    ) -> Optional[Tensor]:
+        if any(x is not None for x in (mask, context_mask)):
+            q_mask = (
+                mask if mask is not None else torch.ones((b, n), device=device).bool()
+            )
+            k_mask = q_mask if context is None else context_mask
+            k_mask = (
+                torch.ones((b, k.shape[-2]), device=device).bool()
+                if k_mask is None
+                else k_mask
+            )
+            q_mask = rearrange(q_mask, "b i -> b () i ()")
+            k_mask = rearrange(k_mask, "b j -> b () () j")
+            return q_mask * k_mask
+        return
+
+    @staticmethod
+    def _apply_causal_mask(
+        energy: Tensor, mask: Tensor, mask_value: Tensor, device: str
+    ) -> Tensor:
+        i, j = energy.shape[-2:]
+        r = torch.arange(i, device=device)
+        mask = rearrange(r, "i -> () () i ()") < rearrange(r, "j -> () () () j")
+        mask = F.pad(mask, (j - i, 0), value=False)
+        energy.masked_fill_(mask, mask_value)
+        del mask
+        return energy
 
     def forward(
         self,
@@ -67,14 +102,25 @@
             k,
         )
 
-        input_mask = None
-        if any(x is not None for x in (mask, context_mask)):
-            q_mask = (
-                mask
-                if mask is not None
-                else lambda: torch.ones((b, n), device=device).bool()
-            )
-            pass
+        input_mask = self._compute_input_mask(
+            b, n, k, mask, context, context_mask, device
+        )
 
         # Compute the attention
-        energy = (q @ k.transpose(-2, -1)) * self.scale
+        energy = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale
+        mask_value = -torch.finfo(energy.dtype).max
+
+        # Apply input mask
+        if input_mask is not None:
+            energy = energy.masked_fill_(~input_mask, mask_value)
+        del input_mask
+
+        if self.causal:
+            energy = self._apply_causal_mask(energy, mask, mask_value, device)
+
+        attn = self.attn_fn(energy, dim=-1)
+        attn = self.dropout(attn)
+        out = einsum("b h i j, b h j d -> b h i d", attn, v)
+        out = rearrange(out, "b h n d -> b n (h d)")
+        out = self.fc(out)
+        return out, attn
diff --git a/training/conf/network/vqvae.yaml b/training/conf/network/vqvae.yaml
index 8c30bbd..288d2aa 100644
--- a/training/conf/network/vqvae.yaml
+++ b/training/conf/network/vqvae.yaml
@@ -2,9 +2,9 @@
 type: VQVAE
 args:
     in_channels: 1
-    channels: [32, 64, 64]
-    kernel_sizes: [4, 4, 4]
-    strides: [2, 2, 2]
+    channels: [64, 96]
+    kernel_sizes: [4, 4]
+    strides: [2, 2]
     num_residual_layers: 2
     embedding_dim: 64
     num_embeddings: 256
--
cgit v1.2.3-70-g09d2
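
The masking scheme added to attention.py above can be exercised on its own with plain torch and einops. The sketch below is illustrative only, not the Attention module itself; the toy batch, head, and sequence sizes are made up, and it assumes the same conventions as the diff: queries, keys, and values of shape (b, heads, n, d), a boolean padding mask of shape (b, n), and a causal mask that blocks positions j > i.

import torch
import torch.nn.functional as F
from einops import rearrange
from torch import einsum

# Toy shapes chosen for illustration.
b, heads, n, d = 2, 4, 6, 8
q, k, v = (torch.randn(b, heads, n, d) for _ in range(3))
scale = d ** -0.5

# Boolean padding mask over the sequence; pretend the last two tokens are padding.
mask = torch.ones(b, n).bool()
mask[:, -2:] = False

# Input mask: (b, 1, i, 1) * (b, 1, 1, j) broadcasts to (b, 1, i, j),
# mirroring _compute_input_mask for the self-attention case (context is None).
q_mask = rearrange(mask, "b i -> b () i ()")
k_mask = rearrange(mask, "b j -> b () () j")
input_mask = q_mask * k_mask

# Scaled dot-product energies and the large negative fill value.
energy = einsum("b h i d, b h j d -> b h i j", q, k) * scale
mask_value = -torch.finfo(energy.dtype).max
energy = energy.masked_fill(~input_mask, mask_value)

# Causal mask: position i may only attend to positions j <= i.
i, j = energy.shape[-2:]
r = torch.arange(i)
causal = rearrange(r, "i -> () () i ()") < rearrange(r, "j -> () () () j")
causal = F.pad(causal, (j - i, 0), value=False)
energy = energy.masked_fill(causal, mask_value)

# Softmax over keys, weighted sum of values, merge heads.
attn = energy.softmax(dim=-1)
out = einsum("b h i j, b h j d -> b h i d", attn, v)
out = rearrange(out, "b h n d -> b n (h d)")
print(out.shape)  # torch.Size([2, 6, 32])

With both masks applied, padded key positions and future positions hold the minimum finite value before the softmax, so their attention weights are effectively zero.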
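
The notebook cells and the vqvae.yaml change look like related shape bookkeeping: with strides [2, 2], two stride-2 convolutions downsample a 1 x 576 x 640 line image by a factor of four per spatial dimension (to 144 x 160), and the notebook then folds that map into 4 x 4 patches. The check below redoes that arithmetic; reading the factor of four as coming from the two stride-2 layers is an inference on my part, not something stated in the commit.

import torch
from einops import rearrange

# Two stride-2 convolutions reduce each spatial dimension by 4,
# assuming padding that preserves exact halving at every layer.
h, w = 576 // 4, 640 // 4          # 144, 160
x = torch.randn(1, 1, h, w)        # stand-in for the encoder output

# Fold the map into non-overlapping 4 x 4 patches, as in the notebook cell.
patch_size = 4
p = rearrange(x, "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=patch_size, p2=patch_size)

# (144 / 4) * (160 / 4) = 36 * 40 = 1440 patches, each of 4 * 4 * 1 = 16 values.
print(p.shape)  # torch.Size([1, 1440, 16])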