diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-09-03 12:13:02 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-09-03 12:13:02 +0200 |
commit | 73ccaaa24936faed36fcc467532baa5386d402ae (patch) | |
tree | c7230fff21b8a780c2b0cd8a5d610075cbb7f21e /text_recognizer/networks/perceiver/perceiver.py | |
parent | 5dd76ca9a3ff35c57cbc7c607afbdb4ee1c8b36f (diff) |
Update perceiver
Diffstat (limited to 'text_recognizer/networks/perceiver/perceiver.py')
-rw-r--r-- | text_recognizer/networks/perceiver/perceiver.py | 22 |
1 files changed, 13 insertions, 9 deletions
diff --git a/text_recognizer/networks/perceiver/perceiver.py b/text_recognizer/networks/perceiver/perceiver.py index 65ee20c..d4bca0b 100644 --- a/text_recognizer/networks/perceiver/perceiver.py +++ b/text_recognizer/networks/perceiver/perceiver.py @@ -2,9 +2,9 @@ A copy from lucidrains. """ -from itertools import repeat from typing import Optional +from einops import repeat, rearrange import torch from torch import nn, Tensor @@ -44,13 +44,17 @@ class PerceiverIO(nn.Module): self.layers = nn.ModuleList( [ - [ - PreNorm( - latent_dim, - Attention(latent_dim, heads=latent_heads, dim_head=latent_dim), - ), - PreNorm(latent_dim, FeedForward(latent_dim)), - ] + nn.ModuleList( + [ + PreNorm( + latent_dim, + Attention( + latent_dim, heads=latent_heads, dim_head=latent_dim + ), + ), + PreNorm(latent_dim, FeedForward(latent_dim)), + ] + ) for _ in range(depth) ] ) @@ -69,7 +73,7 @@ class PerceiverIO(nn.Module): self, data: Tensor, queries: Tensor, mask: Optional[Tensor] = None ) -> Tensor: b = data.shape[0] - x = repeat(self.latents, "nd -> bnd", b=b) + x = repeat(self.latents, "n d -> b n d", b=b) cross_attn, cross_ff = self.cross_attn_block |