summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/perceiver
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks/perceiver')
-rw-r--r--text_recognizer/networks/perceiver/__init__.py1
-rw-r--r--text_recognizer/networks/perceiver/attention.py40
-rw-r--r--text_recognizer/networks/perceiver/perceiver.py91
3 files changed, 0 insertions, 132 deletions
diff --git a/text_recognizer/networks/perceiver/__init__.py b/text_recognizer/networks/perceiver/__init__.py
deleted file mode 100644
index ac2c102..0000000
--- a/text_recognizer/networks/perceiver/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from text_recognizer.networks.perceiver.perceiver import PerceiverIO
diff --git a/text_recognizer/networks/perceiver/attention.py b/text_recognizer/networks/perceiver/attention.py
deleted file mode 100644
index 19e3e17..0000000
--- a/text_recognizer/networks/perceiver/attention.py
+++ /dev/null
@@ -1,40 +0,0 @@
-"""Attention module."""
-from typing import Optional
-
-from einops import rearrange, repeat
-import torch
-from torch import einsum, nn, Tensor
-import torch.nn.functional as F
-
-
-class Attention(nn.Module):
- def __init__(
- self,
- query_dim: int,
- context_dim: Optional[int] = None,
- heads: int = 8,
- dim_head: int = 64,
- ) -> None:
- super().__init__()
- inner_dim = heads * dim_head
- context_dim = context_dim if context_dim is not None else query_dim
- self.scale = dim_head ** -0.5
- self.heads = heads
-
- self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
- self.to_kv = nn.Linear(context_dim, 2 * inner_dim, bias=False)
- self.to_out = nn.Linear(inner_dim, query_dim, bias=False)
-
- def forward(self, x: Tensor, context: Optional[Tensor] = None) -> Tensor:
- h = self.heads
- q = self.to_q(x)
- context = context if context is not None else x
- k, v = self.to_kv(context).chunk(2, dim=-1)
-
- q, v, k = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
- sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
-
- attn = sim.softmax(dim=-1)
- out = einsum("b i j, b j d -> b i d", attn, v)
- out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
- return self.to_out(out)
diff --git a/text_recognizer/networks/perceiver/perceiver.py b/text_recognizer/networks/perceiver/perceiver.py
deleted file mode 100644
index 5b4ab26..0000000
--- a/text_recognizer/networks/perceiver/perceiver.py
+++ /dev/null
@@ -1,91 +0,0 @@
-"""Perceiver IO.
-
-A copy from lucidrains.
-"""
-from typing import Optional
-
-from einops import repeat, rearrange
-import torch
-from torch import nn, Tensor
-
-from text_recognizer.networks.perceiver.attention import Attention
-from text_recognizer.networks.transformer.ff import FeedForward
-from text_recognizer.networks.transformer.norm import PreNorm
-
-
-class PerceiverIO(nn.Module):
- def __init__(
- self,
- dim: int,
- cross_heads: int,
- cross_head_dim: int,
- num_latents: int,
- latent_dim: int,
- latent_heads: int,
- depth: int,
- queries_dim: int,
- logits_dim: int,
- ) -> None:
- super().__init__()
- self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))
-
- self.cross_attn_block = nn.ModuleList(
- [
- PreNorm(
- latent_dim,
- Attention(
- latent_dim, dim, heads=cross_heads, dim_head=cross_head_dim
- ),
- context_dim=dim,
- ),
- PreNorm(latent_dim, FeedForward(latent_dim)),
- ]
- )
-
- self.layers = nn.ModuleList(
- [
- nn.ModuleList(
- [
- PreNorm(
- latent_dim,
- Attention(
- latent_dim, heads=latent_heads, dim_head=latent_dim
- ),
- ),
- PreNorm(latent_dim, FeedForward(latent_dim)),
- ]
- )
- for _ in range(depth)
- ]
- )
-
- self.decoder_cross_attn = PreNorm(
- queries_dim,
- Attention(
- queries_dim, latent_dim, heads=cross_heads, dim_head=cross_head_dim
- ),
- context_dim=latent_dim,
- )
- self.decoder_ff = PreNorm(queries_dim, FeedForward(queries_dim))
- self.to_logits = nn.Linear(queries_dim, logits_dim)
-
- def forward(self, data: Tensor, queries: Tensor) -> Tensor:
- b = data.shape[0]
- x = repeat(self.latents, "n d -> b n d", b=b)
-
- cross_attn, cross_ff = self.cross_attn_block
-
- x = cross_attn(x, context=data) + x
- x = cross_ff(x) + x
-
- for attn, ff in self.layers:
- x = attn(x) + x
- x = ff(x) + x
-
- if queries.ndim == 2:
- queries = repeat(queries, "nd->bnd", b=b)
-
- latents = self.decoder_cross_attn(queries, context=x)
- latents = latents + self.decoder_ff(latents)
-
- return self.to_logits(latents)