summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/perceiver/perceiver.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks/perceiver/perceiver.py')
-rw-r--r--text_recognizer/networks/perceiver/perceiver.py91
1 files changed, 0 insertions, 91 deletions
diff --git a/text_recognizer/networks/perceiver/perceiver.py b/text_recognizer/networks/perceiver/perceiver.py
deleted file mode 100644
index 5b4ab26..0000000
--- a/text_recognizer/networks/perceiver/perceiver.py
+++ /dev/null
@@ -1,91 +0,0 @@
-"""Perceiver IO.
-
-A copy from lucidrains.
-"""
-from typing import Optional
-
-from einops import repeat, rearrange
-import torch
-from torch import nn, Tensor
-
-from text_recognizer.networks.perceiver.attention import Attention
-from text_recognizer.networks.transformer.ff import FeedForward
-from text_recognizer.networks.transformer.norm import PreNorm
-
-
-class PerceiverIO(nn.Module):
- def __init__(
- self,
- dim: int,
- cross_heads: int,
- cross_head_dim: int,
- num_latents: int,
- latent_dim: int,
- latent_heads: int,
- depth: int,
- queries_dim: int,
- logits_dim: int,
- ) -> None:
- super().__init__()
- self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))
-
- self.cross_attn_block = nn.ModuleList(
- [
- PreNorm(
- latent_dim,
- Attention(
- latent_dim, dim, heads=cross_heads, dim_head=cross_head_dim
- ),
- context_dim=dim,
- ),
- PreNorm(latent_dim, FeedForward(latent_dim)),
- ]
- )
-
- self.layers = nn.ModuleList(
- [
- nn.ModuleList(
- [
- PreNorm(
- latent_dim,
- Attention(
- latent_dim, heads=latent_heads, dim_head=latent_dim
- ),
- ),
- PreNorm(latent_dim, FeedForward(latent_dim)),
- ]
- )
- for _ in range(depth)
- ]
- )
-
- self.decoder_cross_attn = PreNorm(
- queries_dim,
- Attention(
- queries_dim, latent_dim, heads=cross_heads, dim_head=cross_head_dim
- ),
- context_dim=latent_dim,
- )
- self.decoder_ff = PreNorm(queries_dim, FeedForward(queries_dim))
- self.to_logits = nn.Linear(queries_dim, logits_dim)
-
- def forward(self, data: Tensor, queries: Tensor) -> Tensor:
- b = data.shape[0]
- x = repeat(self.latents, "n d -> b n d", b=b)
-
- cross_attn, cross_ff = self.cross_attn_block
-
- x = cross_attn(x, context=data) + x
- x = cross_ff(x) + x
-
- for attn, ff in self.layers:
- x = attn(x) + x
- x = ff(x) + x
-
- if queries.ndim == 2:
- queries = repeat(queries, "nd->bnd", b=b)
-
- latents = self.decoder_cross_attn(queries, context=x)
- latents = latents + self.decoder_ff(latents)
-
- return self.to_logits(latents)