diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-09-05 00:07:13 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-09-05 00:07:13 +0200 |
commit | 16e2e420e077253c3b2bc414283281fea557717d (patch) | |
tree | e8817ffd57e01abbf50a8e04f46a6e347e7a6650 /text_recognizer/networks/conv_perceiver.py | |
parent | cfb460666953c87f606833bf597b53eba0a2900d (diff) |
Update conv perceiver
Diffstat (limited to 'text_recognizer/networks/conv_perceiver.py')
-rw-r--r-- | text_recognizer/networks/conv_perceiver.py | 24 |
1 file changed, 14 insertions, 10 deletions
diff --git a/text_recognizer/networks/conv_perceiver.py b/text_recognizer/networks/conv_perceiver.py index cda5e0a..551f04f 100644 --- a/text_recognizer/networks/conv_perceiver.py +++ b/text_recognizer/networks/conv_perceiver.py @@ -1,9 +1,14 @@ """Perceiver network module.""" from typing import Optional, Tuple, Type +from einops import repeat +import torch from torch import nn, Tensor from text_recognizer.networks.perceiver.perceiver import PerceiverIO +from text_recognizer.networks.transformer.embeddings.absolute import ( + AbsolutePositionalEmbedding, +) from text_recognizer.networks.transformer.embeddings.axial import ( AxialPositionalEmbedding, ) @@ -17,12 +22,14 @@ class ConvPerceiver(nn.Module): input_dims: Tuple[int, int, int], hidden_dim: int, queries_dim: int, + num_queries: int, num_classes: int, pad_index: Tensor, encoder: Type[nn.Module], decoder: PerceiverIO, max_length: int, pixel_embedding: AxialPositionalEmbedding, + query_pos_emb: AbsolutePositionalEmbedding, ) -> None: super().__init__() self.input_dims = input_dims @@ -33,25 +40,22 @@ class ConvPerceiver(nn.Module): self.encoder = encoder self.decoder = decoder self.pixel_embedding = pixel_embedding - self.to_queries = nn.Linear(self.hidden_dim, queries_dim) - self.conv = nn.Conv2d( - in_channels=self.encoder.out_channels, - out_channels=self.hidden_dim, - kernel_size=1, - ) + self.query_pos_emb = query_pos_emb + self.queries = nn.Parameter(torch.randn(num_queries, queries_dim)) def encode(self, x: Tensor) -> Tensor: z = self.encoder(x) - z = self.conv(z) - z = self.pixel_embedding(z) + z = torch.concat([z, self.pixel_embedding(z)], dim=1) z = z.flatten(start_dim=2) - # Permute tensor from [B, E, Ho * Wo] to [B, Sx, E] z = z.permute(0, 2, 1) return z def decode(self, z: Tensor) -> Tensor: - queries = self.to_queries(z[:, : self.max_length, :]) + b = z.shape[0] + queries = repeat(self.queries, "n d -> b n d", b=b) + pos_emb = repeat(self.query_pos_emb(queries), "n d -> b n d", b=b) + queries = 
torch.concat([queries, pos_emb], dim=-1) logits = self.decoder(data=z, queries=queries) logits = logits.permute(0, 2, 1) # [B, C, Sy] return logits |