From d20802e1f412045f7afa4bd8ac50be3488945e90 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm
Date: Fri, 9 Jul 2021 00:46:23 +0200
Subject: Working on cnn transformer, continue with predict

---
 text_recognizer/networks/cnn_tranformer.py         | 23 +++++++++++++++++-----
 .../transformer/positional_encodings/__init__.py   |  6 +++++-
 .../positional_encodings/rotary_embedding.py       |  1 -
 3 files changed, 23 insertions(+), 7 deletions(-)

diff --git a/text_recognizer/networks/cnn_tranformer.py b/text_recognizer/networks/cnn_tranformer.py
index ff0ae82..5c13e9a 100644
--- a/text_recognizer/networks/cnn_tranformer.py
+++ b/text_recognizer/networks/cnn_tranformer.py
@@ -1,9 +1,12 @@
 """Vision transformer for character recognition."""
-from typing import Type
+import math
+from typing import Tuple, Type
 
 import attr
 from torch import nn, Tensor
 
+from text_recognizer.data.mappings import AbstractMapping
+from text_recognizer.networks.transformer.layers import Decoder
 from text_recognizer.networks.transformer.positional_encodings import (
     PositionalEncoding,
     PositionalEncoding2D,
@@ -21,17 +24,20 @@ class CnnTransformer(nn.Module):
     dropout_rate: float = attr.ib()
     max_output_len: int = attr.ib()
     num_classes: int = attr.ib()
+    padding_idx: int = attr.ib()
 
     # Modules.
     encoder: Type[nn.Module] = attr.ib()
-    decoder: Type[nn.Module] = attr.ib()
+    decoder: Decoder = attr.ib()
     embedding: nn.Embedding = attr.ib(init=False, default=None)
     latent_encoder: nn.Sequential = attr.ib(init=False, default=None)
     token_embedding: nn.Embedding = attr.ib(init=False, default=None)
     token_pos_encoder: PositionalEncoding = attr.ib(init=False, default=None)
     head: nn.Linear = attr.ib(init=False, default=None)
+    mapping: AbstractMapping = attr.ib(init=False, default=None)
 
     def __attrs_post_init__(self) -> None:
+        """Post init configuration."""
         # Latent projector for down sampling number of filters and 2d
         # positional encoding.
         self.latent_encoder = nn.Sequential(
@@ -106,6 +112,7 @@ class CnnTransformer(nn.Module):
 
         Shapes:
             - z: :math: `(B, Sx, E)`
+            - trg: :math: `(B, Sy)`
             - out: :math: `(B, Sy, T)`
 
             where Sy is the length of the output and T is the number of tokens.
@@ -113,7 +120,12 @@ class CnnTransformer(nn.Module):
         Returns:
             Tensor: Sequence of word piece embeddings.
         """
-        pass
+        trg_mask = trg != self.padding_idx
+        trg = self.token_embedding(trg) * math.sqrt(self.hidden_dim)
+        trg = self.token_pos_encoder(trg)
+        out = self.decoder(x=trg, context=z, mask=trg_mask)
+        logits = self.head(out)
+        return logits
 
     def forward(self, x: Tensor, trg: Tensor) -> Tensor:
         """Encodes images into word piece logits.
@@ -130,9 +142,10 @@ class CnnTransformer(nn.Module):
             the image height and W is the image width.
         """
         z = self.encode(x)
-        y = self.decode(z, trg)
-        return y
+        logits = self.decode(z, trg)
+        return logits
 
     def predict(self, x: Tensor) -> Tensor:
         """Predicts text in image."""
+        # TODO: continue here!!!!!!!!!
         pass
 
diff --git a/text_recognizer/networks/transformer/positional_encodings/__init__.py b/text_recognizer/networks/transformer/positional_encodings/__init__.py
index 91278ee..2ed8a12 100644
--- a/text_recognizer/networks/transformer/positional_encodings/__init__.py
+++ b/text_recognizer/networks/transformer/positional_encodings/__init__.py
@@ -1,4 +1,8 @@
 """Positional encoding for transformers."""
 from .absolute_embedding import AbsolutePositionalEmbedding
-from .positional_encoding import PositionalEncoding, PositionalEncoding2D
+from .positional_encoding import (
+    PositionalEncoding,
+    PositionalEncoding2D,
+    target_padding_mask,
+)
 from .rotary_embedding import apply_rotary_pos_emb, RotaryEmbedding

diff --git a/text_recognizer/networks/transformer/positional_encodings/rotary_embedding.py b/text_recognizer/networks/transformer/positional_encodings/rotary_embedding.py
index 5e80572..41290b4 100644
--- a/text_recognizer/networks/transformer/positional_encodings/rotary_embedding.py
+++ b/text_recognizer/networks/transformer/positional_encodings/rotary_embedding.py
@@ -5,7 +5,6 @@ Stolen from lucidrains:
 
 Explanation of rotary:
 https://blog.eleuther.ai/rotary-embeddings/
-
 """
 from typing import Tuple
 
-- 
cgit v1.2.3-70-g09d2
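
The predict method is the TODO this commit stops at. As a reference for that next step, below is a minimal sketch of a greedy autoregressive decoding loop built only on the encode/decode methods added above. The function name greedy_decode and the explicit start_index/end_index/padding_index arguments are hypothetical (in the repo they would presumably be looked up through the mapping attribute), so treat this as an illustration under those assumptions, not the project's implementation.

import torch
from torch import nn, Tensor


@torch.no_grad()
def greedy_decode(
    model: nn.Module,  # anything exposing encode(x) and decode(z, trg)
    x: Tensor,
    start_index: int,    # hypothetical: start-of-sequence token index
    end_index: int,      # hypothetical: end-of-sequence token index
    padding_index: int,  # should match the model's padding_idx
    max_output_len: int,
) -> Tensor:
    """Greedily decodes token indices from images (sketch).

    Seeds every sequence with the start token, then repeatedly runs the
    tokens generated so far through model.decode and appends the argmax
    of the last position, until every sequence has emitted the end token
    or max_output_len is reached.
    """
    bsz = x.shape[0]
    z = model.encode(x)  # (B, Sx, E)

    # Pre-fill the output buffer with padding; positions are overwritten
    # one step at a time as tokens are generated.
    indices = torch.full(
        (bsz, max_output_len), padding_index, dtype=torch.long, device=x.device
    )
    indices[:, 0] = start_index

    for i in range(1, max_output_len):
        logits = model.decode(z, indices[:, :i])     # (B, i, num_classes)
        indices[:, i] = logits[:, -1].argmax(dim=-1)  # greedy pick at step i
        # Stop early once every sequence has produced an end token.
        if bool((indices == end_index).any(dim=-1).all()):
            break

    return indices

Wiring this into predict would then reduce to resolving the three special token indices (presumably from AbstractMapping) and returning greedy_decode(self, x, ...).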