From 73ccaaa24936faed36fcc467532baa5386d402ae Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sat, 3 Sep 2022 12:13:02 +0200 Subject: Update perceiver --- text_recognizer/networks/perceiver/__init__.py | 1 + text_recognizer/networks/perceiver/attention.py | 10 +++++----- text_recognizer/networks/perceiver/perceiver.py | 22 +++++++++++++--------- 3 files changed, 19 insertions(+), 14 deletions(-) (limited to 'text_recognizer/networks/perceiver') diff --git a/text_recognizer/networks/perceiver/__init__.py b/text_recognizer/networks/perceiver/__init__.py index e69de29..ac2c102 100644 --- a/text_recognizer/networks/perceiver/__init__.py +++ b/text_recognizer/networks/perceiver/__init__.py @@ -0,0 +1 @@ +from text_recognizer.networks.perceiver.perceiver import PerceiverIO diff --git a/text_recognizer/networks/perceiver/attention.py b/text_recognizer/networks/perceiver/attention.py index 66aeaa8..0ee51b1 100644 --- a/text_recognizer/networks/perceiver/attention.py +++ b/text_recognizer/networks/perceiver/attention.py @@ -36,11 +36,11 @@ class Attention(nn.Module): q, v, k = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) sim = einsum("b i d, b j d -> b i j", q, k) * self.scale - if mask is not None: - mask = rearrange(mask, "b ... -> b (...)") - max_neg_val = -torch.finfo(sim.dtype).max - mask = repeat(mask, "b j -> (b h) () j", h=h) - sim.masked_fill_(~mask, max_neg_val) + # if mask is not None: + # mask = rearrange(mask, "b ... -> b (...)") + # max_neg_val = -torch.finfo(sim.dtype).max + # mask = repeat(mask, "b j -> (b h) () j", h=h) + # sim.masked_fill_(~mask, max_neg_val) attn = sim.softmax(dim=-1) out = einsum("b i j, b j d -> b i d", attn, v) diff --git a/text_recognizer/networks/perceiver/perceiver.py b/text_recognizer/networks/perceiver/perceiver.py index 65ee20c..d4bca0b 100644 --- a/text_recognizer/networks/perceiver/perceiver.py +++ b/text_recognizer/networks/perceiver/perceiver.py @@ -2,9 +2,9 @@ A copy from lucidrains. """ -from itertools import repeat from typing import Optional +from einops import repeat, rearrange import torch from torch import nn, Tensor @@ -44,13 +44,17 @@ class PerceiverIO(nn.Module): self.layers = nn.ModuleList( [ - [ - PreNorm( - latent_dim, - Attention(latent_dim, heads=latent_heads, dim_head=latent_dim), - ), - PreNorm(latent_dim, FeedForward(latent_dim)), - ] + nn.ModuleList( + [ + PreNorm( + latent_dim, + Attention( + latent_dim, heads=latent_heads, dim_head=latent_dim + ), + ), + PreNorm(latent_dim, FeedForward(latent_dim)), + ] + ) for _ in range(depth) ] ) @@ -69,7 +73,7 @@ class PerceiverIO(nn.Module): self, data: Tensor, queries: Tensor, mask: Optional[Tensor] = None ) -> Tensor: b = data.shape[0] - x = repeat(self.latents, "nd -> bnd", b=b) + x = repeat(self.latents, "n d -> b n d", b=b) cross_attn, cross_ff = self.cross_attn_block -- cgit v1.2.3-70-g09d2