Diffstat (limited to 'text_recognizer/networks/cnn_tranformer.py')
-rw-r--r--  text_recognizer/networks/cnn_tranformer.py  23
1 file changed, 18 insertions(+), 5 deletions(-)
diff --git a/text_recognizer/networks/cnn_tranformer.py b/text_recognizer/networks/cnn_tranformer.py
index ff0ae82..5c13e9a 100644
--- a/text_recognizer/networks/cnn_tranformer.py
+++ b/text_recognizer/networks/cnn_tranformer.py
@@ -1,9 +1,12 @@
 """Vision transformer for character recognition."""
-from typing import Type
+import math
+from typing import Tuple, Type
 
 import attr
 from torch import nn, Tensor
 
+from text_recognizer.data.mappings import AbstractMapping
+from text_recognizer.networks.transformer.layers import Decoder
 from text_recognizer.networks.transformer.positional_encodings import (
     PositionalEncoding,
     PositionalEncoding2D,
@@ -21,17 +24,20 @@ class CnnTransformer(nn.Module):
     dropout_rate: float = attr.ib()
     max_output_len: int = attr.ib()
     num_classes: int = attr.ib()
+    padding_idx: int = attr.ib()
 
     # Modules.
     encoder: Type[nn.Module] = attr.ib()
-    decoder: Type[nn.Module] = attr.ib()
+    decoder: Decoder = attr.ib()
     embedding: nn.Embedding = attr.ib(init=False, default=None)
     latent_encoder: nn.Sequential = attr.ib(init=False, default=None)
     token_embedding: nn.Embedding = attr.ib(init=False, default=None)
     token_pos_encoder: PositionalEncoding = attr.ib(init=False, default=None)
     head: nn.Linear = attr.ib(init=False, default=None)
+    mapping: AbstractMapping = attr.ib(init=False, default=None)
 
     def __attrs_post_init__(self) -> None:
+        """Post init configuration."""
         # Latent projector for down sampling number of filters and 2d
         # positional encoding.
         self.latent_encoder = nn.Sequential(
@@ -106,6 +112,7 @@ class CnnTransformer(nn.Module):
 
         Shapes:
             - z: :math: `(B, Sx, E)`
+            - trg: :math: `(B, Sy)`
             - out: :math: `(B, Sy, T)`
 
             where Sy is the length of the output and T is the number of tokens.
@@ -113,7 +120,12 @@
         Returns:
             Tensor: Sequence of word piece embeddings.
         """
-        pass
+        trg_mask = trg != self.padding_idx
+        trg = self.token_embedding(trg) * math.sqrt(self.hidden_dim)
+        trg = self.token_pos_encoder(trg)
+        out = self.decoder(x=trg, context=z, mask=trg_mask)
+        logits = self.head(out)
+        return logits
 
     def forward(self, x: Tensor, trg: Tensor) -> Tensor:
         """Encodes images into word piece logits.
@@ -130,9 +142,10 @@
             the image height and W is the image width.
         """
         z = self.encode(x)
-        y = self.decode(z, trg)
-        return y
+        logits = self.decode(z, trg)
+        return logits
 
     def predict(self, x: Tensor) -> Tensor:
         """Predicts text in image."""
+        # TODO: continue here!!!!!!!!!
         pass
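
For readers skimming the diff: the new `decode` body is the standard Transformer decoder recipe. It builds a padding mask from `padding_idx`, embeds the target tokens and scales them by sqrt(hidden_dim), adds positional encodings, attends over the encoder output `z`, and projects to vocabulary logits with `head`. Below is a standalone sketch of the mask and scaling steps with toy values; the shapes and numbers are illustrative, not taken from this repository.

import math

import torch
from torch import nn

hidden_dim, num_classes, padding_idx = 8, 10, 3

token_embedding = nn.Embedding(num_classes, hidden_dim)

# A batch of two target sequences, right-padded with padding_idx (3).
trg = torch.tensor([[1, 4, 5, 3, 3], [1, 7, 8, 9, 3]])

# True at real tokens, False at padding positions: shape (B, Sy).
trg_mask = trg != padding_idx
print(trg_mask)
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True, False]])

# Scaling by sqrt(hidden_dim) keeps the embedding magnitudes comparable to
# the positional encodings added afterwards ("Attention Is All You Need").
embedded = token_embedding(trg) * math.sqrt(hidden_dim)
print(embedded.shape)  # torch.Size([2, 5, 8])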
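
`predict` is still a stub (the commit adds only a TODO). One way to finish it is greedy autoregressive decoding: encode the image once, then repeatedly call `decode` and append the argmax token until every sequence has emitted an end token or `max_output_len` is reached. A minimal sketch, assuming the new `mapping` attribute exposes a `get_index` lookup for `<s>` and `<e>` tokens; that helper and those token names are hypothetical, not part of this diff.

import torch
from torch import Tensor


def predict(self, x: Tensor) -> Tensor:
    """Predicts text in image via greedy decoding (sketch)."""
    bsz = x.shape[0]

    # Hypothetical lookups; the real token names/API may differ.
    start_index = self.mapping.get_index("<s>")
    end_index = self.mapping.get_index("<e>")

    # Encode the image once and reuse the latent at every step.
    z = self.encode(x)

    # Start each sequence with the start token; pad the rest.
    output = torch.full(
        (bsz, self.max_output_len), self.padding_idx, dtype=torch.long, device=x.device
    )
    output[:, 0] = start_index

    for i in range(1, self.max_output_len):
        logits = self.decode(z, output[:, :i])       # (B, i, num_classes)
        output[:, i] = logits[:, -1].argmax(dim=-1)  # greedy pick per sequence

        # Stop early once every sequence has produced the end token.
        if (output == end_index).any(dim=1).all():
            break

    return output

Beam search would replace the argmax line; greedy is simply the smallest implementation consistent with `max_output_len` as a hard cap.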