author     Gustaf Rydholm <gustaf.rydholm@gmail.com>  2021-12-05 20:24:30 +0100
committer  Gustaf Rydholm <gustaf.rydholm@gmail.com>  2021-12-05 20:24:30 +0100
commit     fc884dd8fb48d11b37527df3c897c819f7eaeeeb (patch)
tree       6f6fd75bd536a6b5a0e38bc5a2f6b76db97a9036
parent     bc199a6ae36486c0c98e4808e344c90a6dd452a7 (diff)
Update conv transformer with inheritance from base network
-rw-r--r--  text_recognizer/networks/conv_transformer.py  93
1 file changed, 13 insertions, 80 deletions
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py
index ff98ec6..77e4984 100644
--- a/text_recognizer/networks/conv_transformer.py
+++ b/text_recognizer/networks/conv_transformer.py
@@ -5,6 +5,7 @@ from typing import Optional, Tuple, Type
from loguru import logger as log
from torch import nn, Tensor
+from text_recognizer.networks.base import BaseTransformer
from text_recognizer.networks.transformer.axial_attention.encoder import AxialEncoder
from text_recognizer.networks.transformer.embeddings.axial import (
AxialPositionalEmbedding,
@@ -12,7 +13,7 @@ from text_recognizer.networks.transformer.embeddings.axial import (
from text_recognizer.networks.transformer.layers import Decoder
-class ConvTransformer(nn.Module):
+class ConvTransformer(BaseTransformer):
"""Convolutional encoder and transformer decoder network."""
def __init__(
@@ -27,15 +28,18 @@ class ConvTransformer(nn.Module):
pixel_pos_embedding: AxialPositionalEmbedding,
token_pos_embedding: Optional[Type[nn.Module]] = None,
) -> None:
- super().__init__()
- self.input_dims = input_dims
- self.hidden_dim = hidden_dim
- self.num_classes = num_classes
- self.pad_index = pad_index
- self.encoder = encoder
- self.decoder = decoder
- self.axial_encoder = axial_encoder
+ super().__init__(
+ input_dims,
+ hidden_dim,
+ num_classes,
+ pad_index,
+ encoder,
+ decoder,
+ token_pos_embedding,
+ )
+
self.pixel_pos_embedding = pixel_pos_embedding
+ self.axial_encoder = axial_encoder
# Latent projector for down sampling number of filters and 2d
# positional encoding.
@@ -45,25 +49,6 @@ class ConvTransformer(nn.Module):
kernel_size=1,
)
- # Token embedding.
- self.token_embedding = nn.Embedding(
- num_embeddings=self.num_classes, embedding_dim=self.hidden_dim
- )
-
- # Positional encoding for decoder tokens.
- if not self.decoder.has_pos_emb:
- self.token_pos_embedding = token_pos_embedding
- else:
- self.token_pos_embedding = None
- log.debug("Decoder already have a positional embedding.")
-
- self.norm = nn.LayerNorm(self.hidden_dim)
-
- # Output layer
- self.to_logits = nn.Linear(
- in_features=self.hidden_dim, out_features=self.num_classes
- )
-
# Initalize weights for encoder.
self.init_weights()
@@ -100,55 +85,3 @@ class ConvTransformer(nn.Module):
# Permute tensor from [B, E, Ho * Wo] to [B, Sx, E]
z = z.permute(0, 2, 1)
return z
-
- def decode(self, src: Tensor, trg: Tensor) -> Tensor:
- """Decodes latent images embedding into word pieces.
-
- Args:
- src (Tensor): Latent images embedding.
- trg (Tensor): Word embeddings.
-
- Shapes:
- - z: :math: `(B, Sx, E)`
- - context: :math: `(B, Sy)`
- - out: :math: `(B, Sy, T)`
-
- where Sy is the length of the output and T is the number of tokens.
-
- Returns:
- Tensor: Sequence of word piece embeddings.
- """
- trg = trg.long()
- trg_mask = trg != self.pad_index
- trg = self.token_embedding(trg) * math.sqrt(self.hidden_dim)
- trg = (
- self.token_pos_embedding(trg)
- if self.token_pos_embedding is not None
- else trg
- )
- out = self.decoder(x=trg, context=src, input_mask=trg_mask)
- out = self.norm(out)
- logits = self.to_logits(out) # [B, Sy, T]
- logits = logits.permute(0, 2, 1) # [B, T, Sy]
- return logits
-
- def forward(self, x: Tensor, context: Tensor) -> Tensor:
- """Encodes images into word piece logtis.
-
- Args:
- x (Tensor): Input image(s).
- context (Tensor): Target word embeddings.
-
- Shapes:
- - x: :math: `(B, C, H, W)`
- - context: :math: `(B, Sy, T)`
-
- where B is the batch size, C is the number of input channels, H is
- the image height and W is the image width.
-
- Returns:
- Tensor: Sequence of logits.
- """
- z = self.encode(x)
- logits = self.decode(z, context)
- return logits
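
The base class itself is not shown in this commit. As a rough sketch, assuming text_recognizer/networks/base.py collects the members and methods deleted above: the argument order follows the new super().__init__() call, and the decode/forward bodies are the ones removed from ConvTransformer, so every detail below is an assumption rather than the real implementation.

# Minimal sketch of the assumed BaseTransformer in
# text_recognizer/networks/base.py (not part of this diff).
import math
from typing import Optional, Tuple, Type

from loguru import logger as log
from torch import nn, Tensor


class BaseTransformer(nn.Module):
    """Shared token embedding, decoder wiring, norm, and output head."""

    def __init__(
        self,
        input_dims: Tuple[int, ...],
        hidden_dim: int,
        num_classes: int,
        pad_index: int,
        encoder: nn.Module,
        decoder: nn.Module,
        token_pos_embedding: Optional[Type[nn.Module]] = None,
    ) -> None:
        super().__init__()
        self.input_dims = input_dims
        self.hidden_dim = hidden_dim
        self.num_classes = num_classes
        self.pad_index = pad_index
        self.encoder = encoder
        self.decoder = decoder

        # Token embedding.
        self.token_embedding = nn.Embedding(
            num_embeddings=self.num_classes, embedding_dim=self.hidden_dim
        )

        # Positional encoding for decoder tokens, unless the decoder
        # already applies one of its own.
        if not self.decoder.has_pos_emb:
            self.token_pos_embedding = token_pos_embedding
        else:
            self.token_pos_embedding = None
            log.debug("Decoder already has a positional embedding.")

        self.norm = nn.LayerNorm(self.hidden_dim)

        # Output layer.
        self.to_logits = nn.Linear(
            in_features=self.hidden_dim, out_features=self.num_classes
        )

    def decode(self, src: Tensor, trg: Tensor) -> Tensor:
        """Decodes latent image embeddings into word piece logits."""
        trg = trg.long()
        trg_mask = trg != self.pad_index
        trg = self.token_embedding(trg) * math.sqrt(self.hidden_dim)
        trg = (
            self.token_pos_embedding(trg)
            if self.token_pos_embedding is not None
            else trg
        )
        out = self.decoder(x=trg, context=src, input_mask=trg_mask)
        out = self.norm(out)
        logits = self.to_logits(out)  # [B, Sy, T]
        return logits.permute(0, 2, 1)  # [B, T, Sy]

    def forward(self, x: Tensor, context: Tensor) -> Tensor:
        """Encodes images into word piece logits.

        encode() is expected to be provided by the subclass,
        e.g. ConvTransformer.encode above.
        """
        z = self.encode(x)
        return self.decode(z, context)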