summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--notebooks/00-testing-stuff-out.ipynb163
-rw-r--r--text_recognizer/models/vqvae.py12
-rw-r--r--text_recognizer/networks/transformer/attention.py72
-rw-r--r--training/conf/network/vqvae.yaml6
4 files changed, 233 insertions, 20 deletions
diff --git a/notebooks/00-testing-stuff-out.ipynb b/notebooks/00-testing-stuff-out.ipynb
index 7c7b3a6..92faaf7 100644
--- a/notebooks/00-testing-stuff-out.ipynb
+++ b/notebooks/00-testing-stuff-out.ipynb
@@ -358,6 +358,169 @@
"metadata": {},
"outputs": [],
"source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "b, n = 16, 128\n",
+ "device = \"cpu\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "x = lambda: torch.ones((b, n), device=device).bool()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([16, 128])"
+ ]
+ },
+ "execution_count": 10,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "x().shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([16, 128])"
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "torch.ones((b, n), device=device).bool().shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "x = torch.randn(1, 1, 576, 640)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 30,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "144"
+ ]
+ },
+ "execution_count": 30,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "576 // 4"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "160"
+ ]
+ },
+ "execution_count": 29,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "640 // 4"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 31,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "x = torch.randn(1, 1, 144, 160)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 32,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from einops import rearrange"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 35,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "patch_size=4\n",
+ "p = rearrange(x, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 36,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([1, 1440, 16])"
+ ]
+ },
+ "execution_count": 36,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "p.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
}
],
"metadata": {
diff --git a/text_recognizer/models/vqvae.py b/text_recognizer/models/vqvae.py
index ef2213c..078235e 100644
--- a/text_recognizer/models/vqvae.py
+++ b/text_recognizer/models/vqvae.py
@@ -29,12 +29,14 @@ class LitVQVAEModel(LitBaseModel):
"""Forward pass with the transformer network."""
return self.network.predict(data)
- def _log_prediction(self, data: Tensor, reconstructions: Tensor) -> None:
+ def _log_prediction(
+ self, data: Tensor, reconstructions: Tensor, title: str
+ ) -> None:
"""Logs prediction on image with wandb."""
try:
self.logger.experiment.log(
{
- "val_pred_examples": [
+ title: [
wandb.Image(data[0]),
wandb.Image(reconstructions[0]),
]
@@ -59,7 +61,8 @@ class LitVQVAEModel(LitBaseModel):
loss = self.loss_fn(reconstructions, data)
loss += vq_loss
self.log("val_loss", loss, prog_bar=True)
- self._log_prediction(data, reconstructions)
+ title = "val_pred_examples"
+ self._log_prediction(data, reconstructions, title)
def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
"""Test step."""
@@ -67,4 +70,5 @@ class LitVQVAEModel(LitBaseModel):
reconstructions, vq_loss = self.network(data)
loss = self.loss_fn(reconstructions, data)
loss += vq_loss
- self._log_prediction(data, reconstructions)
+ title = "test_pred_examples"
+ self._log_prediction(data, reconstructions, title)
diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py
index 8724691..623d680 100644
--- a/text_recognizer/networks/transformer/attention.py
+++ b/text_recognizer/networks/transformer/attention.py
@@ -1,9 +1,10 @@
"""Implementes the attention module for the transformer."""
from typing import Optional, Tuple
+from einops import rearrange
from einops.layers.torch import Rearrange
-import numpy as np
import torch
+from torch import einsum
from torch import nn
from torch import Tensor
import torch.nn.functional as F
@@ -34,7 +35,7 @@ class Attention(nn.Module):
self.attn_fn = F.softmax
# Feedforward
- self.proj = nn.Linear(inner_dim, dim)
+ self.fc = nn.Linear(inner_dim, dim)
@staticmethod
def _apply_rotary_emb(
@@ -47,8 +48,42 @@ class Attention(nn.Module):
k = torch.cat((kl, kr), dim=-1)
return q, k
- def _cross_attention(self) -> Tensor:
- pass
+ @staticmethod
+ def _compute_input_mask(
+ b: int,
+ n: int,
+ k: Tensor,
+ mask: Optional[Tensor],
+ context: Optional[Tensor],
+ context_mask: Optional[Tensor],
+ device: str,
+ ) -> Optional[Tensor]:
+ if any(x is not None for x in (mask, context_mask)):
+ q_mask = (
+ mask if mask is not None else torch.ones((b, n), device=device).bool()
+ )
+ k_mask = q_mask if context is not None else context_mask
+ k_mask = (
+ torch.ones((b, k.shape[-2]), device=device).bool()
+ if k_mask is None
+ else k_mask
+ )
+ q_mask = rearrange(q_mask, "b i -> b () i ()")
+ k_mask = rearrange(k_mask, "b i -> b () () j")
+ return q_mask * k_mask
+ return
+
+ @staticmethod
+ def _apply_causal_mask(
+ energy: Tensor, mask: Tensor, mask_value: Tensor, device: str
+ ) -> Tensor:
+ i, j = energy.shape[-2:]
+ r = torch.arange(i, device=device)
+ mask = rearrange(r, "i -> () () i ()") < rearrange(r, "j -> () () () j")
+ mask = F.pad(mask, (j - i, 0), value=False)
+ energy.masked_fill_(mask, mask_value)
+ del mask
+ return energy
def forward(
self,
@@ -67,14 +102,25 @@ class Attention(nn.Module):
k,
)
- input_mask = None
- if any(x is not None for x in (mask, context_mask)):
- q_mask = (
- mask
- if mask is not None
- else lambda: torch.ones((b, n), device=device).bool()
- )
- pass
+ input_mask = self._compute_input_mask(
+ b, n, k, mask, context, context_mask, device
+ )
# Compute the attention
- energy = (q @ k.transpose(-2, -1)) * self.scale
+ energy = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale
+ mask_value = -torch.finfo(energy.dtype).max
+
+ # Apply input mask
+ if input_mask is not None:
+ energy = energy.masked_fill_(~input_mask, mask_value)
+ del input_mask
+
+ if self.causal:
+ energy = self._apply_causal_mask(energy, mask, mask_value, device)
+
+ attn = self.attn_fn(energy, dim=-1)
+ attn = self.dropout(attn)
+ out = einsum("b h i j, b h j d -> b h i d", attn, v)
+ out = rearrange(out, "b h n d -> b n (h d)")
+ out = self.fc(out)
+ return out, attn
diff --git a/training/conf/network/vqvae.yaml b/training/conf/network/vqvae.yaml
index 8c30bbd..288d2aa 100644
--- a/training/conf/network/vqvae.yaml
+++ b/training/conf/network/vqvae.yaml
@@ -2,9 +2,9 @@
type: VQVAE
args:
in_channels: 1
- channels: [32, 64, 64]
- kernel_sizes: [4, 4, 4]
- strides: [2, 2, 2]
+ channels: [64, 96]
+ kernel_sizes: [4, 4]
+ strides: [2, 2]
num_residual_layers: 2
embedding_dim: 64
num_embeddings: 256