From 16e2e420e077253c3b2bc414283281fea557717d Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Mon, 5 Sep 2022 00:07:13 +0200 Subject: Update conv perceiver --- text_recognizer/networks/__init__.py | 1 + text_recognizer/networks/conv_perceiver.py | 24 ++++++++++++++---------- 2 files changed, 15 insertions(+), 10 deletions(-) (limited to 'text_recognizer/networks') diff --git a/text_recognizer/networks/__init__.py b/text_recognizer/networks/__init__.py index e2d6fd5..99f5d1d 100644 --- a/text_recognizer/networks/__init__.py +++ b/text_recognizer/networks/__init__.py @@ -1,4 +1,5 @@ """Network modules""" +from text_recognizer.networks.conv_perceiver import ConvPerceiver from text_recognizer.networks.conv_transformer import ConvTransformer from text_recognizer.networks.efficientnet.efficientnet import EfficientNet from text_recognizer.networks.vq_transformer import VqTransformer diff --git a/text_recognizer/networks/conv_perceiver.py b/text_recognizer/networks/conv_perceiver.py index cda5e0a..551f04f 100644 --- a/text_recognizer/networks/conv_perceiver.py +++ b/text_recognizer/networks/conv_perceiver.py @@ -1,9 +1,14 @@ """Perceiver network module.""" from typing import Optional, Tuple, Type +from einops import repeat +import torch from torch import nn, Tensor from text_recognizer.networks.perceiver.perceiver import PerceiverIO +from text_recognizer.networks.transformer.embeddings.absolute import ( + AbsolutePositionalEmbedding, +) from text_recognizer.networks.transformer.embeddings.axial import ( AxialPositionalEmbedding, ) @@ -17,12 +22,14 @@ class ConvPerceiver(nn.Module): input_dims: Tuple[int, int, int], hidden_dim: int, queries_dim: int, + num_queries: int, num_classes: int, pad_index: Tensor, encoder: Type[nn.Module], decoder: PerceiverIO, max_length: int, pixel_embedding: AxialPositionalEmbedding, + query_pos_emb: AbsolutePositionalEmbedding, ) -> None: super().__init__() self.input_dims = input_dims @@ -33,25 +40,22 @@ class ConvPerceiver(nn.Module): self.encoder = encoder self.decoder = decoder self.pixel_embedding = pixel_embedding - self.to_queries = nn.Linear(self.hidden_dim, queries_dim) - self.conv = nn.Conv2d( - in_channels=self.encoder.out_channels, - out_channels=self.hidden_dim, - kernel_size=1, - ) + self.query_pos_emb = query_pos_emb + self.queries = nn.Parameter(torch.randn(num_queries, queries_dim)) def encode(self, x: Tensor) -> Tensor: z = self.encoder(x) - z = self.conv(z) - z = self.pixel_embedding(z) + z = torch.concat([z, self.pixel_embedding(z)], dim=1) z = z.flatten(start_dim=2) - # Permute tensor from [B, E, Ho * Wo] to [B, Sx, E] z = z.permute(0, 2, 1) return z def decode(self, z: Tensor) -> Tensor: - queries = self.to_queries(z[:, : self.max_length, :]) + b = z.shape[0] + queries = repeat(self.queries, "n d -> b n d", b=b) + pos_emb = repeat(self.query_pos_emb(queries), "n d -> b n d", b=b) + queries = torch.concat([queries, pos_emb], dim=-1) logits = self.decoder(data=z, queries=queries) logits = logits.permute(0, 2, 1) # [B, C, Sy] return logits -- cgit v1.2.3-70-g09d2