-rw-r--r--  text_recognizer/networks/perceiver/__init__.py  |  0
-rw-r--r--  text_recognizer/networks/perceiver/attention.py | 48
-rw-r--r--  text_recognizer/networks/perceiver/perceiver.py | 89
3 files changed, 137 insertions, 0 deletions
diff --git a/text_recognizer/networks/perceiver/__init__.py b/text_recognizer/networks/perceiver/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/text_recognizer/networks/perceiver/__init__.py
diff --git a/text_recognizer/networks/perceiver/attention.py b/text_recognizer/networks/perceiver/attention.py
new file mode 100644
index 0000000..66aeaa8
--- /dev/null
+++ b/text_recognizer/networks/perceiver/attention.py
@@ -0,0 +1,48 @@
+"""Attention module."""
+from typing import Optional
+
+from einops import rearrange, repeat
+import torch
+from torch import einsum, nn, Tensor
+import torch.nn.functional as F
+
+
+class Attention(nn.Module):
+    def __init__(
+        self,
+        query_dim: int,
+        context_dim: Optional[int] = None,
+        heads: int = 8,
+        dim_head: int = 64,
+    ) -> None:
+        super().__init__()
+        inner_dim = heads * dim_head
+        context_dim = context_dim if context_dim is not None else query_dim
+        self.scale = dim_head ** -0.5
+        self.heads = heads
+
+        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+        self.to_kv = nn.Linear(context_dim, 2 * inner_dim, bias=False)
+        self.to_out = nn.Linear(inner_dim, query_dim, bias=False)
+
+    def forward(
+        self, x: Tensor, context: Optional[Tensor] = None, mask: Optional[Tensor] = None
+    ) -> Tensor:
+        h = self.heads
+        q = self.to_q(x)
+        context = context if context is not None else x
+        k, v = self.to_kv(context).chunk(2, dim=-1)
+
+        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
+        sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
+
+        if mask is not None:
+            mask = rearrange(mask, "b ... -> b (...)")
+            max_neg_val = -torch.finfo(sim.dtype).max
+            mask = repeat(mask, "b j -> (b h) () j", h=h)
+            sim.masked_fill_(~mask, max_neg_val)
+
+        attn = sim.softmax(dim=-1)
+        out = einsum("b i j, b j d -> b i d", attn, v)
+        out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
+        return self.to_out(out)
diff --git a/text_recognizer/networks/perceiver/perceiver.py b/text_recognizer/networks/perceiver/perceiver.py
new file mode 100644
index 0000000..65ee20c
--- /dev/null
+++ b/text_recognizer/networks/perceiver/perceiver.py
@@ -0,0 +1,89 @@
+"""Perceiver IO.
+
+A copy from lucidrains.
+"""
+from typing import Optional
+
+from einops import repeat
+import torch
+from torch import nn, Tensor
+
+from text_recognizer.networks.perceiver.attention import Attention
+from text_recognizer.networks.transformer.ff import FeedForward
+from text_recognizer.networks.transformer.norm import PreNorm
+
+
+class PerceiverIO(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        cross_heads: int,
+        cross_head_dim: int,
+        num_latents: int,
+        latent_dim: int,
+        latent_heads: int,
+        depth: int,
+        queries_dim: int,
+        logits_dim: int,
+    ) -> None:
+        super().__init__()
+        self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))
+
+        self.cross_attn_block = nn.ModuleList(
+            [
+                PreNorm(
+                    latent_dim,
+                    Attention(
+                        latent_dim, dim, heads=cross_heads, dim_head=cross_head_dim
+                    ),
+                    context_dim=dim,
+                ),
+                PreNorm(latent_dim, FeedForward(latent_dim)),
+            ]
+        )
+
+        self.layers = nn.ModuleList(
+            nn.ModuleList(
+                [
+                    PreNorm(
+                        latent_dim,
+                        Attention(latent_dim, heads=latent_heads, dim_head=latent_dim),
+                    ),
+                    PreNorm(latent_dim, FeedForward(latent_dim)),
+                ]
+            )
+            for _ in range(depth)
+        )
+
+        self.decoder_cross_attn = PreNorm(
+            queries_dim,
+            Attention(
+                queries_dim, latent_dim, heads=cross_heads, dim_head=cross_head_dim
+            ),
+            context_dim=latent_dim,
+        )
+        self.decoder_ff = PreNorm(queries_dim, FeedForward(queries_dim))
+        self.to_logits = nn.Linear(queries_dim, logits_dim)
+
+    def forward(
+        self, data: Tensor, queries: Tensor, mask: Optional[Tensor] = None
+    ) -> Tensor:
+        b = data.shape[0]
+        x = repeat(self.latents, "n d -> b n d", b=b)
+
+        cross_attn, cross_ff = self.cross_attn_block
+
+        x = cross_attn(x, context=data, mask=mask) + x
+        x = cross_ff(x) + x
+
+        for attn, ff in self.layers:
+            x = attn(x) + x
+            x = ff(x) + x
+
+        if queries.ndim == 2:
+            queries = repeat(queries, "n d -> b n d", b=b)
+
+        latents = self.decoder_cross_attn(queries, context=x)
+        latents = latents + self.decoder_ff(latents)
+
+        return self.to_logits(latents)