author Gustaf Rydholm <gustaf.rydholm@gmail.com> 2021-11-21 21:35:34 +0100
committer Gustaf Rydholm <gustaf.rydholm@gmail.com> 2021-11-21 21:35:34 +0100
commit a9363f3944f1ad31590c48d5d51c45df3bbf43b1 (patch)
tree 3e1bedf9220c95a1b6888b9e5ca1c57b59bd98f4 /text_recognizer/networks/conv_transformer.py
parent ce99fd904576b8daeb2985f3341793c2a33e9d45 (diff)
Add axial encoder to conv transformer
Diffstat (limited to 'text_recognizer/networks/conv_transformer.py')
-rw-r--r--  text_recognizer/networks/conv_transformer.py  11  +++++++++--
1 file changed, 9 insertions(+), 2 deletions(-)
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py
index b554695..da99bbf 100644
--- a/text_recognizer/networks/conv_transformer.py
+++ b/text_recognizer/networks/conv_transformer.py
@@ -1,10 +1,14 @@
"""Vision transformer for character recognition."""
import math
-from typing import Optional, Tuple, Type
+from typing import List, Optional, Tuple, Type
from loguru import logger as log
from torch import nn, Tensor
+from text_recognizer.networks.transformer.axial_attention.encoder import AxialEncoder
+from text_recognizer.networks.transformer.embeddings.axial import (
+ AxialPositionalEmbedding,
+)
from text_recognizer.networks.transformer.layers import Decoder
@@ -19,7 +23,8 @@ class ConvTransformer(nn.Module):
pad_index: Tensor,
encoder: nn.Module,
decoder: Decoder,
- pixel_pos_embedding: Type[nn.Module],
+ axial_encoder: Optional[AxialEncoder],
+ pixel_pos_embedding: AxialPositionalEmbedding,
token_pos_embedding: Optional[Type[nn.Module]] = None,
) -> None:
super().__init__()
@@ -29,6 +34,7 @@ class ConvTransformer(nn.Module):
self.pad_index = pad_index
self.encoder = encoder
self.decoder = decoder
+ self.axial_encoder = axial_encoder
self.pixel_pos_embedding = pixel_pos_embedding
# Latent projector for down sampling number of filters and 2d
@@ -86,6 +92,7 @@ class ConvTransformer(nn.Module):
z = self.encoder(x)
z = self.conv(z)
z = self.pixel_pos_embedding(z)
+ z = self.axial_encoder(z) if self.axial_encoder is not None else z
z = z.flatten(start_dim=2)
# Permute tensor from [B, E, Ho * Wo] to [B, Sx, E]
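
For context, a minimal sketch of the encode path after this change: the optional axial encoder runs after the 2D pixel positional embedding and before the feature map is flattened into a sequence. The EncodeStage wrapper below and its constructor arguments are illustrative assumptions, not the project's actual API; only the ordering of the calls mirrors the patched ConvTransformer.

from typing import Optional

from torch import Tensor, nn


class EncodeStage(nn.Module):
    """Illustrative only: backbone -> conv projection -> 2D pos. embedding -> optional axial encoder -> sequence."""

    def __init__(
        self,
        encoder: nn.Module,
        conv: nn.Module,
        pixel_pos_embedding: nn.Module,
        axial_encoder: Optional[nn.Module] = None,
    ) -> None:
        super().__init__()
        self.encoder = encoder
        self.conv = conv
        self.pixel_pos_embedding = pixel_pos_embedding
        self.axial_encoder = axial_encoder

    def forward(self, x: Tensor) -> Tensor:
        z = self.encoder(x)              # [B, C, H, W] -> [B, E, Ho, Wo]
        z = self.conv(z)                 # project to the transformer embedding dim
        z = self.pixel_pos_embedding(z)  # add 2D positional information
        if self.axial_encoder is not None:
            z = self.axial_encoder(z)    # optional axial self-attention refinement
        z = z.flatten(start_dim=2)       # [B, E, Ho * Wo]
        return z.permute(0, 2, 1)        # [B, Sx, E], ready for the decoder

Because axial_encoder is Optional and guarded with an is-not-None check in forward, existing configurations that do not provide an axial encoder keep their previous behaviour unchanged.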