diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-11-21 21:35:34 +0100 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-11-21 21:35:34 +0100 |
commit | a9363f3944f1ad31590c48d5d51c45df3bbf43b1 (patch) | |
tree | 3e1bedf9220c95a1b6888b9e5ca1c57b59bd98f4 /text_recognizer/networks | |
parent | ce99fd904576b8daeb2985f3341793c2a33e9d45 (diff) |
Add axial encoder to conv transformer
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r-- | text_recognizer/networks/conv_transformer.py | 11 |
1 files changed, 9 insertions, 2 deletions
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py index b554695..da99bbf 100644 --- a/text_recognizer/networks/conv_transformer.py +++ b/text_recognizer/networks/conv_transformer.py @@ -1,10 +1,14 @@ """Vision transformer for character recognition.""" import math -from typing import Optional, Tuple, Type +from typing import List, Optional, Tuple, Type from loguru import logger as log from torch import nn, Tensor +from text_recognizer.networks.transformer.axial_attention.encoder import AxialEncoder +from text_recognizer.networks.transformer.embeddings.axial import ( + AxialPositionalEmbedding, +) from text_recognizer.networks.transformer.layers import Decoder @@ -19,7 +23,8 @@ class ConvTransformer(nn.Module): pad_index: Tensor, encoder: nn.Module, decoder: Decoder, - pixel_pos_embedding: Type[nn.Module], + axial_encoder: Optional[AxialEncoder], + pixel_pos_embedding: AxialPositionalEmbedding, token_pos_embedding: Optional[Type[nn.Module]] = None, ) -> None: super().__init__() @@ -29,6 +34,7 @@ class ConvTransformer(nn.Module): self.pad_index = pad_index self.encoder = encoder self.decoder = decoder + self.axial_encoder = axial_encoder self.pixel_pos_embedding = pixel_pos_embedding # Latent projector for down sampling number of filters and 2d @@ -86,6 +92,7 @@ class ConvTransformer(nn.Module): z = self.encoder(x) z = self.conv(z) z = self.pixel_pos_embedding(z) + z = self.axial_encoder(z) if self.axial_encoder is not None else z z = z.flatten(start_dim=2) # Permute tensor from [B, E, Ho * Wo] to [B, Sx, E] |