author      Gustaf Rydholm <gustaf.rydholm@gmail.com>    2021-07-09 00:46:23 +0200
committer   Gustaf Rydholm <gustaf.rydholm@gmail.com>    2021-07-09 00:46:23 +0200
commit      d20802e1f412045f7afa4bd8ac50be3488945e90
tree        dd24469b9dae9cde1a4ff9c8902ed1172b474b21 /text_recognizer/networks/cnn_tranformer.py
parent      3d279b65f19813357ae395e5f72f1efcbd2829f5
Working on cnn transformer, continue with predict
Diffstat (limited to 'text_recognizer/networks/cnn_tranformer.py')
-rw-r--r--   text_recognizer/networks/cnn_tranformer.py   23
1 file changed, 18 insertions(+), 5 deletions(-)
diff --git a/text_recognizer/networks/cnn_tranformer.py b/text_recognizer/networks/cnn_tranformer.py
index ff0ae82..5c13e9a 100644
--- a/text_recognizer/networks/cnn_tranformer.py
+++ b/text_recognizer/networks/cnn_tranformer.py
@@ -1,9 +1,12 @@
 """Vision transformer for character recognition."""
-from typing import Type
+import math
+from typing import Tuple, Type

 import attr
 from torch import nn, Tensor

+from text_recognizer.data.mappings import AbstractMapping
+from text_recognizer.networks.transformer.layers import Decoder
 from text_recognizer.networks.transformer.positional_encodings import (
     PositionalEncoding,
     PositionalEncoding2D,
@@ -21,17 +24,20 @@ class CnnTransformer(nn.Module):
     dropout_rate: float = attr.ib()
     max_output_len: int = attr.ib()
     num_classes: int = attr.ib()
+    padding_idx: int = attr.ib()

     # Modules.
     encoder: Type[nn.Module] = attr.ib()
-    decoder: Type[nn.Module] = attr.ib()
+    decoder: Decoder = attr.ib()
     embedding: nn.Embedding = attr.ib(init=False, default=None)
     latent_encoder: nn.Sequential = attr.ib(init=False, default=None)
     token_embedding: nn.Embedding = attr.ib(init=False, default=None)
     token_pos_encoder: PositionalEncoding = attr.ib(init=False, default=None)
     head: nn.Linear = attr.ib(init=False, default=None)
+    mapping: AbstractMapping = attr.ib(init=False, default=None)

     def __attrs_post_init__(self) -> None:
+        """Post init configuration."""
         # Latent projector for down sampling number of filters and 2d
         # positional encoding.
         self.latent_encoder = nn.Sequential(
@@ -106,6 +112,7 @@ class CnnTransformer(nn.Module):

         Shapes:
             - z: :math: `(B, Sx, E)`
+            - trg: :math: `(B, Sy)`
             - out: :math: `(B, Sy, T)`

         where Sy is the length of the output and T is the number of tokens.
@@ -113,7 +120,12 @@ class CnnTransformer(nn.Module):
         Returns:
             Tensor: Sequence of word piece embeddings.
         """
-        pass
+        trg_mask = trg != self.padding_idx
+        trg = self.token_embedding(trg) * math.sqrt(self.hidden_dim)
+        trg = self.token_pos_encoder(trg)
+        out = self.decoder(x=trg, context=z, mask=trg_mask)
+        logits = self.head(out)
+        return logits

     def forward(self, x: Tensor, trg: Tensor) -> Tensor:
         """Encodes images into word piece logtis.
@@ -130,9 +142,10 @@ class CnnTransformer(nn.Module):
             the image height and W is the image width.
         """
         z = self.encode(x)
-        y = self.decode(z, trg)
-        return y
+        logits = self.decode(z, trg)
+        return logits

     def predict(self, x: Tensor) -> Tensor:
         """Predicts text in image."""
+        # TODO: continue here!!!!!!!!!
         pass
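Note: the new decode() now handles padding masking, token embedding scaling, and the decoder call, while predict() is still only a TODO (as the commit message says). Purely as an illustration of how that TODO might later be filled in, here is a minimal greedy autoregressive decoding sketch built on the encode()/decode() interface from this commit. The attributes self.start_index and self.end_index are hypothetical (the commit only adds a mapping attribute, so the actual start/end token lookup is an assumption, not the author's implementation):

import torch
from torch import Tensor

def predict(self, x: Tensor) -> Tensor:
    """Greedy decoding sketch; NOT part of this commit.

    Assumes hypothetical self.start_index and self.end_index attributes;
    the real code would presumably resolve them via self.mapping.
    """
    bsz = x.shape[0]
    z = self.encode(x)  # (B, Sx, E)

    # Seed every sequence in the batch with the start token.
    trg = torch.full(
        (bsz, 1), self.start_index, dtype=torch.long, device=x.device
    )

    for _ in range(self.max_output_len - 1):
        logits = self.decode(z, trg)  # (B, Sy, num_classes)
        # Take the most likely next token for each sequence.
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # (B, 1)
        trg = torch.cat([trg, next_token], dim=1)

        # Stop once every sequence has produced the end token.
        if (next_token == self.end_index).all():
            break

    return trg

Greedy argmax is just the simplest option here; beam search or sampling could replace the argmax step without changing the encode/decode interface introduced in this diff.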