summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/conv_perceiver.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks/conv_perceiver.py')
-rw-r--r--text_recognizer/networks/conv_perceiver.py66
1 files changed, 0 insertions, 66 deletions
diff --git a/text_recognizer/networks/conv_perceiver.py b/text_recognizer/networks/conv_perceiver.py
deleted file mode 100644
index 551f04f..0000000
--- a/text_recognizer/networks/conv_perceiver.py
+++ /dev/null
@@ -1,66 +0,0 @@
-"""Perceiver network module."""
-from typing import Optional, Tuple, Type
-
-from einops import repeat
-import torch
-from torch import nn, Tensor
-
-from text_recognizer.networks.perceiver.perceiver import PerceiverIO
-from text_recognizer.networks.transformer.embeddings.absolute import (
- AbsolutePositionalEmbedding,
-)
-from text_recognizer.networks.transformer.embeddings.axial import (
- AxialPositionalEmbedding,
-)
-
-
class ConvPerceiver(nn.Module):
    """Convolutional encoder followed by a Perceiver IO style decoder.

    The input is passed through ``encoder``; the resulting feature map is
    concatenated along dim 1 with the output of ``pixel_embedding`` and
    flattened into a sequence. ``decoder`` is then called with that sequence
    as ``data`` and with a set of learned query vectors (concatenated with the
    output of ``query_pos_emb``) as ``queries``.

    Args:
        input_dims: Stored on the instance; not read by this class.
        hidden_dim: Stored on the instance; not read by this class.
        queries_dim: Feature size of each learned query vector.
        num_queries: Number of learned query vectors.
        num_classes: Stored on the instance; not read by this class.
        pad_index: Stored on the instance; not read by this class.
        encoder: Module instance that is called on the input in ``encode``.
        decoder: Module called as ``decoder(data=..., queries=...)``.
        max_length: Stored on the instance; not read by this class.
        pixel_embedding: Module called on the encoder output; its result is
            concatenated with that output along dim 1.
        query_pos_emb: Module called on the batched queries; its result is
            broadcast over the batch and concatenated with the queries along
            the last dim.
    """

    def __init__(
        self,
        input_dims: Tuple[int, int, int],
        hidden_dim: int,
        queries_dim: int,
        num_queries: int,
        num_classes: int,
        pad_index: Tensor,
        # An instance, not a class: ``encode`` calls it directly on the input
        # and treats the result as a tensor.
        encoder: nn.Module,
        decoder: PerceiverIO,
        max_length: int,
        pixel_embedding: AxialPositionalEmbedding,
        query_pos_emb: AbsolutePositionalEmbedding,
    ) -> None:
        super().__init__()
        self.input_dims = input_dims
        self.hidden_dim = hidden_dim
        self.num_classes = num_classes
        self.pad_index = pad_index
        self.max_length = max_length
        self.encoder = encoder
        self.decoder = decoder
        self.pixel_embedding = pixel_embedding
        self.query_pos_emb = query_pos_emb
        # Learned queries, shared by every sample in a batch.
        self.queries = nn.Parameter(torch.randn(num_queries, queries_dim))

    def encode(self, x: Tensor) -> Tensor:
        """Encode the input into a sequence of feature vectors.

        Args:
            x: Input passed unchanged to ``encoder``.

        Returns:
            The encoder output concatenated with its pixel embedding along
            dim 1, flattened from dim 2 onwards and with dims 1 and 2 swapped.
        """
        z = self.encoder(x)
        # Stack the pixel embedding next to the features along dim 1.
        z = torch.cat([z, self.pixel_embedding(z)], dim=1)
        z = z.flatten(start_dim=2)
        # Permute tensor from [B, E, Ho * Wo] to [B, Sx, E]
        z = z.permute(0, 2, 1)
        return z

    def decode(self, z: Tensor) -> Tensor:
        """Decode an encoded sequence into logits.

        Args:
            z: Encoded sequence; its first dimension is used as batch size.

        Returns:
            The decoder output with its last two dims swapped.
        """
        b = z.shape[0]
        # Give every sample in the batch its own copy of the learned queries.
        queries = repeat(self.queries, "n d -> b n d", b=b)
        # The "n d" pattern requires a 2-D embedding here (no batch axis); it
        # is broadcast over the batch before being joined with the queries.
        pos_emb = repeat(self.query_pos_emb(queries), "n d -> b n d", b=b)
        queries = torch.cat([queries, pos_emb], dim=-1)
        logits = self.decoder(data=z, queries=queries)
        logits = logits.permute(0, 2, 1)  # [B, C, Sy]
        return logits

    def forward(self, x: Tensor) -> Tensor:
        """Run ``encode`` followed by ``decode`` and return the logits."""
        z = self.encode(x)
        logits = self.decode(z)
        return logits