diff options
Diffstat (limited to 'text_recognizer/networks/perceiver/perceiver.py')
-rw-r--r-- | text_recognizer/networks/perceiver/perceiver.py | 91 |
1 files changed, 0 insertions, 91 deletions
diff --git a/text_recognizer/networks/perceiver/perceiver.py b/text_recognizer/networks/perceiver/perceiver.py deleted file mode 100644 index 5b4ab26..0000000 --- a/text_recognizer/networks/perceiver/perceiver.py +++ /dev/null @@ -1,91 +0,0 @@ -"""Perceiver IO. - -A copy from lucidrains. -""" -from typing import Optional - -from einops import repeat, rearrange -import torch -from torch import nn, Tensor - -from text_recognizer.networks.perceiver.attention import Attention -from text_recognizer.networks.transformer.ff import FeedForward -from text_recognizer.networks.transformer.norm import PreNorm - - -class PerceiverIO(nn.Module): - def __init__( - self, - dim: int, - cross_heads: int, - cross_head_dim: int, - num_latents: int, - latent_dim: int, - latent_heads: int, - depth: int, - queries_dim: int, - logits_dim: int, - ) -> None: - super().__init__() - self.latents = nn.Parameter(torch.randn(num_latents, latent_dim)) - - self.cross_attn_block = nn.ModuleList( - [ - PreNorm( - latent_dim, - Attention( - latent_dim, dim, heads=cross_heads, dim_head=cross_head_dim - ), - context_dim=dim, - ), - PreNorm(latent_dim, FeedForward(latent_dim)), - ] - ) - - self.layers = nn.ModuleList( - [ - nn.ModuleList( - [ - PreNorm( - latent_dim, - Attention( - latent_dim, heads=latent_heads, dim_head=latent_dim - ), - ), - PreNorm(latent_dim, FeedForward(latent_dim)), - ] - ) - for _ in range(depth) - ] - ) - - self.decoder_cross_attn = PreNorm( - queries_dim, - Attention( - queries_dim, latent_dim, heads=cross_heads, dim_head=cross_head_dim - ), - context_dim=latent_dim, - ) - self.decoder_ff = PreNorm(queries_dim, FeedForward(queries_dim)) - self.to_logits = nn.Linear(queries_dim, logits_dim) - - def forward(self, data: Tensor, queries: Tensor) -> Tensor: - b = data.shape[0] - x = repeat(self.latents, "n d -> b n d", b=b) - - cross_attn, cross_ff = self.cross_attn_block - - x = cross_attn(x, context=data) + x - x = cross_ff(x) + x - - for attn, ff in self.layers: - x = attn(x) + x - x = ff(x) + x - - if queries.ndim == 2: - queries = repeat(queries, "nd->bnd", b=b) - - latents = self.decoder_cross_attn(queries, context=x) - latents = latents + self.decoder_ff(latents) - - return self.to_logits(latents) |