Diffstat (limited to 'text_recognizer/networks/conv_perceiver.py')
-rw-r--r--  text_recognizer/networks/conv_perceiver.py  24
1 file changed, 14 insertions(+), 10 deletions(-)
diff --git a/text_recognizer/networks/conv_perceiver.py b/text_recognizer/networks/conv_perceiver.py
index cda5e0a..551f04f 100644
--- a/text_recognizer/networks/conv_perceiver.py
+++ b/text_recognizer/networks/conv_perceiver.py
@@ -1,9 +1,14 @@
 """Perceiver network module."""
 from typing import Optional, Tuple, Type
 
+from einops import repeat
+import torch
 from torch import nn, Tensor
 
 from text_recognizer.networks.perceiver.perceiver import PerceiverIO
+from text_recognizer.networks.transformer.embeddings.absolute import (
+    AbsolutePositionalEmbedding,
+)
 from text_recognizer.networks.transformer.embeddings.axial import (
     AxialPositionalEmbedding,
 )
@@ -17,12 +22,14 @@ class ConvPerceiver(nn.Module):
         input_dims: Tuple[int, int, int],
         hidden_dim: int,
         queries_dim: int,
+        num_queries: int,
         num_classes: int,
         pad_index: Tensor,
         encoder: Type[nn.Module],
         decoder: PerceiverIO,
         max_length: int,
         pixel_embedding: AxialPositionalEmbedding,
+        query_pos_emb: AbsolutePositionalEmbedding,
     ) -> None:
         super().__init__()
         self.input_dims = input_dims
@@ -33,25 +40,22 @@
         self.encoder = encoder
         self.decoder = decoder
         self.pixel_embedding = pixel_embedding
-        self.to_queries = nn.Linear(self.hidden_dim, queries_dim)
-        self.conv = nn.Conv2d(
-            in_channels=self.encoder.out_channels,
-            out_channels=self.hidden_dim,
-            kernel_size=1,
-        )
+        self.query_pos_emb = query_pos_emb
+        self.queries = nn.Parameter(torch.randn(num_queries, queries_dim))
 
     def encode(self, x: Tensor) -> Tensor:
         z = self.encoder(x)
-        z = self.conv(z)
-        z = self.pixel_embedding(z)
+        z = torch.concat([z, self.pixel_embedding(z)], dim=1)
         z = z.flatten(start_dim=2)
-
         # Permute tensor from [B, E, Ho * Wo] to [B, Sx, E]
         z = z.permute(0, 2, 1)
         return z
 
     def decode(self, z: Tensor) -> Tensor:
-        queries = self.to_queries(z[:, : self.max_length, :])
+        b = z.shape[0]
+        queries = repeat(self.queries, "n d -> b n d", b=b)
+        pos_emb = repeat(self.query_pos_emb(queries), "n d -> b n d", b=b)
+        queries = torch.concat([queries, pos_emb], dim=-1)
         logits = self.decoder(data=z, queries=queries)
         logits = logits.permute(0, 2, 1)  # [B, C, Sy]
         return logits