summaryrefslogtreecommitdiff
path: root/text_recognizer/networks
diff options
context:
space:
mode:
author    Gustaf Rydholm <gustaf.rydholm@gmail.com>  2022-09-03 12:13:02 +0200
committer Gustaf Rydholm <gustaf.rydholm@gmail.com>  2022-09-03 12:13:02 +0200
commit    73ccaaa24936faed36fcc467532baa5386d402ae (patch)
tree      c7230fff21b8a780c2b0cd8a5d610075cbb7f21e /text_recognizer/networks
parent    5dd76ca9a3ff35c57cbc7c607afbdb4ee1c8b36f (diff)
Update perceiver
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r--  text_recognizer/networks/conv_perceiver.py       | 62
-rw-r--r--  text_recognizer/networks/perceiver/__init__.py   |  1
-rw-r--r--  text_recognizer/networks/perceiver/attention.py  | 10
-rw-r--r--  text_recognizer/networks/perceiver/perceiver.py  | 22
4 files changed, 81 insertions(+), 14 deletions(-)
diff --git a/text_recognizer/networks/conv_perceiver.py b/text_recognizer/networks/conv_perceiver.py
new file mode 100644
index 0000000..cda5e0a
--- /dev/null
+++ b/text_recognizer/networks/conv_perceiver.py
@@ -0,0 +1,62 @@
+"""Perceiver network module."""
+from typing import Optional, Tuple, Type
+
+from torch import nn, Tensor
+
+from text_recognizer.networks.perceiver.perceiver import PerceiverIO
+from text_recognizer.networks.transformer.embeddings.axial import (
+ AxialPositionalEmbedding,
+)
+
+
+class ConvPerceiver(nn.Module):
+ """Base transformer network."""
+
+ def __init__(
+ self,
+ input_dims: Tuple[int, int, int],
+ hidden_dim: int,
+ queries_dim: int,
+ num_classes: int,
+ pad_index: Tensor,
+ encoder: Type[nn.Module],
+ decoder: PerceiverIO,
+ max_length: int,
+ pixel_embedding: AxialPositionalEmbedding,
+ ) -> None:
+ super().__init__()
+ self.input_dims = input_dims
+ self.hidden_dim = hidden_dim
+ self.num_classes = num_classes
+ self.pad_index = pad_index
+ self.max_length = max_length
+ self.encoder = encoder
+ self.decoder = decoder
+ self.pixel_embedding = pixel_embedding
+ self.to_queries = nn.Linear(self.hidden_dim, queries_dim)
+ self.conv = nn.Conv2d(
+ in_channels=self.encoder.out_channels,
+ out_channels=self.hidden_dim,
+ kernel_size=1,
+ )
+
+ def encode(self, x: Tensor) -> Tensor:
+ z = self.encoder(x)
+ z = self.conv(z)
+ z = self.pixel_embedding(z)
+ z = z.flatten(start_dim=2)
+
+ # Permute tensor from [B, E, Ho * Wo] to [B, Sx, E]
+ z = z.permute(0, 2, 1)
+ return z
+
+ def decode(self, z: Tensor) -> Tensor:
+ queries = self.to_queries(z[:, : self.max_length, :])
+ logits = self.decoder(data=z, queries=queries)
+ logits = logits.permute(0, 2, 1) # [B, C, Sy]
+ return logits
+
+ def forward(self, x: Tensor) -> Tensor:
+ z = self.encode(x)
+ logits = self.decode(z)
+ return logits
diff --git a/text_recognizer/networks/perceiver/__init__.py b/text_recognizer/networks/perceiver/__init__.py
index e69de29..ac2c102 100644
--- a/text_recognizer/networks/perceiver/__init__.py
+++ b/text_recognizer/networks/perceiver/__init__.py
@@ -0,0 +1 @@
+from text_recognizer.networks.perceiver.perceiver import PerceiverIO
diff --git a/text_recognizer/networks/perceiver/attention.py b/text_recognizer/networks/perceiver/attention.py
index 66aeaa8..0ee51b1 100644
--- a/text_recognizer/networks/perceiver/attention.py
+++ b/text_recognizer/networks/perceiver/attention.py
@@ -36,11 +36,11 @@ class Attention(nn.Module):
q, v, k = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
- if mask is not None:
- mask = rearrange(mask, "b ... -> b (...)")
- max_neg_val = -torch.finfo(sim.dtype).max
- mask = repeat(mask, "b j -> (b h) () j", h=h)
- sim.masked_fill_(~mask, max_neg_val)
+ # if mask is not None:
+ # mask = rearrange(mask, "b ... -> b (...)")
+ # max_neg_val = -torch.finfo(sim.dtype).max
+ # mask = repeat(mask, "b j -> (b h) () j", h=h)
+ # sim.masked_fill_(~mask, max_neg_val)
attn = sim.softmax(dim=-1)
out = einsum("b i j, b j d -> b i d", attn, v)
diff --git a/text_recognizer/networks/perceiver/perceiver.py b/text_recognizer/networks/perceiver/perceiver.py
index 65ee20c..d4bca0b 100644
--- a/text_recognizer/networks/perceiver/perceiver.py
+++ b/text_recognizer/networks/perceiver/perceiver.py
@@ -2,9 +2,9 @@
A copy from lucidrains.
"""
-from itertools import repeat
from typing import Optional
+from einops import repeat, rearrange
import torch
from torch import nn, Tensor
@@ -44,13 +44,17 @@ class PerceiverIO(nn.Module):
self.layers = nn.ModuleList(
[
- [
- PreNorm(
- latent_dim,
- Attention(latent_dim, heads=latent_heads, dim_head=latent_dim),
- ),
- PreNorm(latent_dim, FeedForward(latent_dim)),
- ]
+ nn.ModuleList(
+ [
+ PreNorm(
+ latent_dim,
+ Attention(
+ latent_dim, heads=latent_heads, dim_head=latent_dim
+ ),
+ ),
+ PreNorm(latent_dim, FeedForward(latent_dim)),
+ ]
+ )
for _ in range(depth)
]
)
@@ -69,7 +73,7 @@ class PerceiverIO(nn.Module):
self, data: Tensor, queries: Tensor, mask: Optional[Tensor] = None
) -> Tensor:
b = data.shape[0]
- x = repeat(self.latents, "nd -> bnd", b=b)
+ x = repeat(self.latents, "n d -> b n d", b=b)
cross_attn, cross_ff = self.cross_attn_block