Diffstat (limited to 'text_recognizer/networks/perceiver/perceiver.py')
-rw-r--r-- | text_recognizer/networks/perceiver/perceiver.py | 101
1 file changed, 101 insertions, 0 deletions
diff --git a/text_recognizer/networks/perceiver/perceiver.py b/text_recognizer/networks/perceiver/perceiver.py
new file mode 100644
index 0000000..65ee20c
--- /dev/null
+++ b/text_recognizer/networks/perceiver/perceiver.py
@@ -0,0 +1,101 @@
+"""Perceiver IO.
+
+A copy from lucidrains.
+"""
+from typing import Optional
+
+import torch
+from einops import repeat
+from torch import nn, Tensor
+
+from text_recognizer.networks.perceiver.attention import Attention
+from text_recognizer.networks.transformer.ff import FeedForward
+from text_recognizer.networks.transformer.norm import PreNorm
+
+
+class PerceiverIO(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        cross_heads: int,
+        cross_head_dim: int,
+        num_latents: int,
+        latent_dim: int,
+        latent_heads: int,
+        depth: int,
+        queries_dim: int,
+        logits_dim: int,
+    ) -> None:
+        super().__init__()
+        # Learned latent array, shared across the batch.
+        self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))
+
+        # Encoder cross-attention: the latents attend to the input data.
+        self.cross_attn_block = nn.ModuleList(
+            [
+                PreNorm(
+                    latent_dim,
+                    Attention(
+                        latent_dim, dim, heads=cross_heads, dim_head=cross_head_dim
+                    ),
+                    context_dim=dim,
+                ),
+                PreNorm(latent_dim, FeedForward(latent_dim)),
+            ]
+        )
+
+        # Latent transformer: a stack of self-attention blocks over the latents.
+        self.layers = nn.ModuleList(
+            [
+                nn.ModuleList(
+                    [
+                        PreNorm(
+                            latent_dim,
+                            Attention(
+                                latent_dim, heads=latent_heads, dim_head=latent_dim
+                            ),
+                        ),
+                        PreNorm(latent_dim, FeedForward(latent_dim)),
+                    ]
+                )
+                for _ in range(depth)
+            ]
+        )
+
+        # Decoder cross-attention: the output queries attend to the latents.
+        self.decoder_cross_attn = PreNorm(
+            queries_dim,
+            Attention(
+                queries_dim, latent_dim, heads=cross_heads, dim_head=cross_head_dim
+            ),
+            context_dim=latent_dim,
+        )
+        self.decoder_ff = PreNorm(queries_dim, FeedForward(queries_dim))
+        self.to_logits = nn.Linear(queries_dim, logits_dim)
+
+    def forward(
+        self, data: Tensor, queries: Tensor, mask: Optional[Tensor] = None
+    ) -> Tensor:
+        b = data.shape[0]
+        # Expand the latent array over the batch dimension.
+        x = repeat(self.latents, "n d -> b n d", b=b)
+
+        cross_attn, cross_ff = self.cross_attn_block
+
+        # Encode: cross-attend from the latents to the data, with residuals.
+        x = cross_attn(x, context=data, mask=mask) + x
+        x = cross_ff(x) + x
+
+        # Process: latent self-attention stack.
+        for attn, ff in self.layers:
+            x = attn(x) + x
+            x = ff(x) + x
+
+        if queries.ndim == 2:
+            queries = repeat(queries, "n d -> b n d", b=b)
+
+        # Decode: the queries cross-attend to the processed latents.
+        latents = self.decoder_cross_attn(queries, context=x)
+        latents = latents + self.decoder_ff(latents)
+
+        return self.to_logits(latents)
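For a quick sanity check of the shapes involved, a minimal usage sketch follows. All concrete dimensions below are made-up assumptions for illustration, and the module imports this repo's own Attention, FeedForward, and PreNorm helpers, so the sketch only runs inside the repository:

    import torch

    from text_recognizer.networks.perceiver.perceiver import PerceiverIO

    # Hypothetical configuration; every number here is an assumption.
    model = PerceiverIO(
        dim=128,           # feature dim of the input array
        cross_heads=1,
        cross_head_dim=64,
        num_latents=256,
        latent_dim=512,
        latent_heads=8,
        depth=6,
        queries_dim=256,
        logits_dim=95,     # e.g. size of the output vocabulary
    )

    data = torch.randn(2, 1024, 128)  # (batch, input tokens, dim)
    queries = torch.randn(32, 256)    # (output tokens, queries_dim); broadcast over the batch
    logits = model(data, queries)     # -> (2, 32, 95)

Note that 2-dim queries are repeated across the batch inside forward, so one shared query set can decode every element of the batch.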