author    Gustaf Rydholm <gustaf.rydholm@gmail.com>  2021-07-09 00:46:23 +0200
committer Gustaf Rydholm <gustaf.rydholm@gmail.com>  2021-07-09 00:46:23 +0200
commit    d20802e1f412045f7afa4bd8ac50be3488945e90 (patch)
tree      dd24469b9dae9cde1a4ff9c8902ed1172b474b21
parent    3d279b65f19813357ae395e5f72f1efcbd2829f5 (diff)
Working on cnn transformer, continue with predict
-rw-r--r--  text_recognizer/networks/cnn_tranformer.py                                     | 23
-rw-r--r--  text_recognizer/networks/transformer/positional_encodings/__init__.py          |  6
-rw-r--r--  text_recognizer/networks/transformer/positional_encodings/rotary_embedding.py  |  1
3 files changed, 23 insertions(+), 7 deletions(-)
diff --git a/text_recognizer/networks/cnn_tranformer.py b/text_recognizer/networks/cnn_tranformer.py
index ff0ae82..5c13e9a 100644
--- a/text_recognizer/networks/cnn_tranformer.py
+++ b/text_recognizer/networks/cnn_tranformer.py
@@ -1,9 +1,12 @@
"""Vision transformer for character recognition."""
-from typing import Type
+import math
+from typing import Tuple, Type
import attr
from torch import nn, Tensor
+from text_recognizer.data.mappings import AbstractMapping
+from text_recognizer.networks.transformer.layers import Decoder
from text_recognizer.networks.transformer.positional_encodings import (
PositionalEncoding,
PositionalEncoding2D,
@@ -21,17 +24,20 @@ class CnnTransformer(nn.Module):
dropout_rate: float = attr.ib()
max_output_len: int = attr.ib()
num_classes: int = attr.ib()
+ padding_idx: int = attr.ib()
# Modules.
encoder: Type[nn.Module] = attr.ib()
- decoder: Type[nn.Module] = attr.ib()
+ decoder: Decoder = attr.ib()
embedding: nn.Embedding = attr.ib(init=False, default=None)
latent_encoder: nn.Sequential = attr.ib(init=False, default=None)
token_embedding: nn.Embedding = attr.ib(init=False, default=None)
token_pos_encoder: PositionalEncoding = attr.ib(init=False, default=None)
head: nn.Linear = attr.ib(init=False, default=None)
+ mapping: AbstractMapping = attr.ib(init=False, default=None)
def __attrs_post_init__(self) -> None:
+ """Post init configuration."""
# Latent projector for downsampling the number of filters and adding 2D
# positional encoding.
self.latent_encoder = nn.Sequential(
@@ -106,6 +112,7 @@ class CnnTransformer(nn.Module):
Shapes:
- z: :math:`(B, Sx, E)`
+ - trg: :math:`(B, Sy)`
- out: :math:`(B, Sy, T)`
where Sy is the length of the output and T is the number of tokens.
@@ -113,7 +120,12 @@ class CnnTransformer(nn.Module):
Returns:
Tensor: Sequence of word piece logits.
"""
- pass
+ trg_mask = trg != self.padding_idx
+ trg = self.token_embedding(trg) * math.sqrt(self.hidden_dim)
+ trg = self.token_pos_encoder(trg)
+ out = self.decoder(x=trg, context=z, mask=trg_mask)
+ logits = self.head(out)
+ return logits
def forward(self, x: Tensor, trg: Tensor) -> Tensor:
"""Encodes images into word piece logtis.
@@ -130,9 +142,10 @@ class CnnTransformer(nn.Module):
the image height and W is the image width.
"""
z = self.encode(x)
- y = self.decode(z, trg)
- return y
+ logits = self.decode(z, trg)
+ return logits
def predict(self, x: Tensor) -> Tensor:
"""Predicts text in image."""
+ # TODO: continue here!!!!!!!!!
pass
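
The commit message flags predict() as the next piece of work, and its body is still a stub. A minimal greedy-decoding sketch of what the method could become, assuming the start/end token indices would come from self.mapping (the literal indices below are hypothetical placeholders), and not the author's final implementation:

    import torch
    from torch import Tensor

    def predict(self, x: Tensor) -> Tensor:
        """Sketch: greedy autoregressive decoding (assumed, not final)."""
        bsz = x.shape[0]
        z = self.encode(x)  # (B, Sx, E)

        # Hypothetical indices; the real ones would come from self.mapping.
        start_idx, end_idx, pad_idx = 1, 2, self.padding_idx

        trg = torch.full(
            (bsz, self.max_output_len), pad_idx, dtype=torch.long, device=x.device
        )
        trg[:, 0] = start_idx
        done = torch.zeros(bsz, dtype=torch.bool, device=x.device)

        for i in range(1, self.max_output_len):
            logits = self.decode(z, trg[:, :i])        # (B, i, num_classes)
            next_token = logits[:, -1].argmax(dim=-1)  # most probable next token
            trg[:, i] = torch.where(
                done, torch.full_like(next_token, pad_idx), next_token
            )
            done |= next_token == end_idx
            if done.all():
                break
        return trg

Each step re-runs the decoder on the tokens generated so far and appends the argmax, stopping once every sequence in the batch has emitted the end token.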
diff --git a/text_recognizer/networks/transformer/positional_encodings/__init__.py b/text_recognizer/networks/transformer/positional_encodings/__init__.py
index 91278ee..2ed8a12 100644
--- a/text_recognizer/networks/transformer/positional_encodings/__init__.py
+++ b/text_recognizer/networks/transformer/positional_encodings/__init__.py
@@ -1,4 +1,8 @@
"""Positional encoding for transformers."""
from .absolute_embedding import AbsolutePositionalEmbedding
-from .positional_encoding import PositionalEncoding, PositionalEncoding2D
+from .positional_encoding import (
+ PositionalEncoding,
+ PositionalEncoding2D,
+ target_padding_mask,
+)
from .rotary_embedding import apply_rotary_pos_emb, RotaryEmbedding
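
target_padding_mask is newly exported here, but its definition is not part of this diff. For orientation only: a helper with this name commonly combines a padding mask with a causal mask so the decoder attends neither to pad tokens nor to future positions. A rough sketch under that assumption:

    import torch
    from torch import Tensor

    def target_padding_mask(trg: Tensor, pad_index: int) -> Tensor:
        """Assumed sketch: hide pad tokens and future positions."""
        pad_mask = (trg != pad_index).unsqueeze(1)  # (B, 1, Sy)
        seq_len = trg.shape[-1]
        causal = torch.tril(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=trg.device)
        )  # (Sy, Sy), True on and below the diagonal
        return pad_mask & causal  # (B, Sy, Sy)

Note that decode() above currently builds only the padding part (trg != self.padding_idx); whether the causal part lives in this helper or inside the Decoder is not visible in this diff.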
diff --git a/text_recognizer/networks/transformer/positional_encodings/rotary_embedding.py b/text_recognizer/networks/transformer/positional_encodings/rotary_embedding.py
index 5e80572..41290b4 100644
--- a/text_recognizer/networks/transformer/positional_encodings/rotary_embedding.py
+++ b/text_recognizer/networks/transformer/positional_encodings/rotary_embedding.py
@@ -5,7 +5,6 @@ Stolen from lucidrains:
Explanation of rotary:
https://blog.eleuther.ai/rotary-embeddings/
-
"""
from typing import Tuple
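
This hunk only trims a blank line, but since the module's exports (apply_rotary_pos_emb, RotaryEmbedding) appear in the __init__.py change above, a condensed sketch of the rotary scheme the linked blog post describes may help. It follows the lucidrains code this file credits, not necessarily this file's exact contents:

    import torch
    from torch import Tensor

    def rotate_half(x: Tensor) -> Tensor:
        """Split the features in half and rotate: (x1, x2) -> (-x2, x1)."""
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary_pos_emb(t: Tensor, freqs: Tensor) -> Tensor:
        """Rotate query/key features by position-dependent angles."""
        return (t * freqs.cos()) + (rotate_half(t) * freqs.sin())

Because the rotation is applied to query/key vectors before the dot product, relative position falls out of the attention scores without any learned positional table.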