diff options
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r-- | text_recognizer/networks/perceiver/__init__.py | 1 | ||||
-rw-r--r-- | text_recognizer/networks/perceiver/attention.py | 40 | ||||
-rw-r--r-- | text_recognizer/networks/perceiver/perceiver.py | 91 |
3 files changed, 0 insertions, 132 deletions
# Recovered from a rendered deletion diff of text_recognizer/networks/perceiver/.
#
# text_recognizer/networks/perceiver/__init__.py (deleted, index ac2c102) held:
#     from text_recognizer.networks.perceiver.perceiver import PerceiverIO
#
# text_recognizer/networks/perceiver/attention.py (deleted, index 19e3e17)
# follows, re-indented and corrected.
"""Attention module."""
from typing import Optional

import torch
from torch import einsum, nn, Tensor
import torch.nn.functional as F


class Attention(nn.Module):
    """Multi-head scaled dot-product attention without masking or dropout.

    Queries are projected from ``x``; keys and values are projected from
    ``context`` (or from ``x`` itself when no context is given).

    Args:
        query_dim: Size of the last dimension of the query input and of the
            output.
        context_dim: Size of the last dimension of the context input.
            Defaults to ``query_dim``.
        heads: Number of attention heads.
        dim_head: Feature size of each head.
    """

    def __init__(
        self,
        query_dim: int,
        context_dim: Optional[int] = None,
        heads: int = 8,
        dim_head: int = 64,
    ) -> None:
        super().__init__()
        inner_dim = heads * dim_head
        if context_dim is None:
            context_dim = query_dim
        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        # Keys and values share one projection; the first half of its output
        # is the keys, the second half the values.
        self.to_kv = nn.Linear(context_dim, 2 * inner_dim, bias=False)
        self.to_out = nn.Linear(inner_dim, query_dim, bias=False)

    def _split_heads(self, t: Tensor) -> Tensor:
        """Reshape ``(b, n, heads * d)`` into ``(b, heads, n, d)``."""
        b, n, _ = t.shape
        return t.reshape(b, n, self.heads, -1).transpose(1, 2)

    @staticmethod
    def _merge_heads(t: Tensor) -> Tensor:
        """Reshape ``(b, heads, n, d)`` back into ``(b, n, heads * d)``."""
        b, h, n, d = t.shape
        return t.transpose(1, 2).reshape(b, n, h * d)

    def forward(self, x: Tensor, context: Optional[Tensor] = None) -> Tensor:
        """Attend from ``x`` to ``context``.

        Args:
            x: Query input of shape ``(batch, n, query_dim)``.
            context: Key/value input of shape ``(batch, m, context_dim)``.
                When ``None`` the layer performs self-attention over ``x``.

        Returns:
            Tensor of shape ``(batch, n, query_dim)``.
        """
        if context is None:
            context = x

        q = self.to_q(x)
        k, v = self.to_kv(context).chunk(2, dim=-1)

        # Keep the (q, k, v) order on both sides: the previous unpacking
        # `q, v, k = ...` silently exchanged keys and values.
        q, k, v = (self._split_heads(t) for t in (q, k, v))

        sim = einsum("bhid,bhjd->bhij", q, k) * self.scale
        attn = sim.softmax(dim=-1)

        out = einsum("bhij,bhjd->bhid", attn, v)
        return self.to_out(self._merge_heads(out))


# Next entry of the diff: text_recognizer/networks/perceiver/perceiver.py
# (deleted, index 5b4ab26); its hunk starts on the following line.
# Former content of text_recognizer/networks/perceiver/perceiver.py
# (deleted in this diff, hunk -1,91 +0,0), re-indented and corrected.
"""Perceiver IO.

A copy from lucidrains.
"""
from typing import Optional

from einops import repeat, rearrange
import torch
from torch import nn, Tensor

from text_recognizer.networks.perceiver.attention import Attention
from text_recognizer.networks.transformer.ff import FeedForward
from text_recognizer.networks.transformer.norm import PreNorm


class PerceiverIO(nn.Module):
    """Perceiver IO: encode into latents, process them, decode with queries.

    The input is read into a learned latent array by one cross-attention
    block, the latents are refined by ``depth`` self-attention blocks, and the
    output is produced by cross-attending from the queries to the latents.

    Args:
        dim: Feature size of the input data.
        cross_heads: Number of heads in both cross-attention layers.
        cross_head_dim: Per-head feature size in both cross-attention layers.
        num_latents: Number of learned latent vectors.
        latent_dim: Feature size of each latent vector.
        latent_heads: Number of heads in the latent self-attention layers.
        depth: Number of latent self-attention blocks.
        queries_dim: Feature size of the decoder queries.
        logits_dim: Feature size of the output.
        latent_head_dim: Per-head feature size in the latent self-attention
            layers. Defaults to ``latent_dim``, the value that was
            previously hard-coded.
    """

    def __init__(
        self,
        dim: int,
        cross_heads: int,
        cross_head_dim: int,
        num_latents: int,
        latent_dim: int,
        latent_heads: int,
        depth: int,
        queries_dim: int,
        logits_dim: int,
        latent_head_dim: Optional[int] = None,
    ) -> None:
        super().__init__()
        if latent_head_dim is None:
            latent_head_dim = latent_dim

        # Learned latent array, shared across the batch.
        self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))

        # Encoder: latents attend to the input data.
        # NOTE(review): PreNorm is given `context_dim`, presumably so it can
        # normalise the context separately -- confirm against PreNorm.
        self.cross_attn_block = nn.ModuleList(
            [
                PreNorm(
                    latent_dim,
                    Attention(
                        latent_dim, dim, heads=cross_heads, dim_head=cross_head_dim
                    ),
                    context_dim=dim,
                ),
                PreNorm(latent_dim, FeedForward(latent_dim)),
            ]
        )

        # Processor: `depth` blocks of latent self-attention + feed-forward.
        self.layers = nn.ModuleList(
            [
                nn.ModuleList(
                    [
                        PreNorm(
                            latent_dim,
                            Attention(
                                latent_dim,
                                heads=latent_heads,
                                dim_head=latent_head_dim,
                            ),
                        ),
                        PreNorm(latent_dim, FeedForward(latent_dim)),
                    ]
                )
                for _ in range(depth)
            ]
        )

        # Decoder: queries attend to the processed latents.
        self.decoder_cross_attn = PreNorm(
            queries_dim,
            Attention(
                queries_dim, latent_dim, heads=cross_heads, dim_head=cross_head_dim
            ),
            context_dim=latent_dim,
        )
        self.decoder_ff = PreNorm(queries_dim, FeedForward(queries_dim))
        self.to_logits = nn.Linear(queries_dim, logits_dim)

    def forward(self, data: Tensor, queries: Tensor) -> Tensor:
        """Encode ``data`` into the latents and decode them with ``queries``.

        Args:
            data: Input whose first dimension is the batch; presumably
                ``(batch, seq, dim)`` -- confirm against callers.
            queries: Decoder queries. A 2-D tensor ``(n, queries_dim)`` is
                repeated along a new leading batch dimension; otherwise it is
                used as given.

        Returns:
            The decoded features after the final projection to ``logits_dim``.
        """
        b = data.shape[0]
        x = repeat(self.latents, "n d -> b n d", b=b)

        cross_attn, cross_ff = self.cross_attn_block

        x = cross_attn(x, context=data) + x
        x = cross_ff(x) + x

        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x

        if queries.ndim == 2:
            # einops axis names must be space-separated; the former pattern
            # "nd->bnd" was parsed as single axes and raised on 2-D queries.
            queries = repeat(queries, "n d -> b n d", b=b)

        # No residual around the decoder cross-attention, only around the
        # decoder feed-forward.
        latents = self.decoder_cross_attn(queries, context=x)
        latents = latents + self.decoder_ff(latents)

        return self.to_logits(latents)