summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/transformer/positional_encodings
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-07-09 00:46:23 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-07-09 00:46:23 +0200
commitd20802e1f412045f7afa4bd8ac50be3488945e90 (patch)
treedd24469b9dae9cde1a4ff9c8902ed1172b474b21 /text_recognizer/networks/transformer/positional_encodings
parent3d279b65f19813357ae395e5f72f1efcbd2829f5 (diff)
Working on cnn transformer, continue with predict
Diffstat (limited to 'text_recognizer/networks/transformer/positional_encodings')
-rw-r--r--text_recognizer/networks/transformer/positional_encodings/__init__.py6
-rw-r--r--text_recognizer/networks/transformer/positional_encodings/rotary_embedding.py1
2 files changed, 5 insertions, 2 deletions
diff --git a/text_recognizer/networks/transformer/positional_encodings/__init__.py b/text_recognizer/networks/transformer/positional_encodings/__init__.py
index 91278ee..2ed8a12 100644
--- a/text_recognizer/networks/transformer/positional_encodings/__init__.py
+++ b/text_recognizer/networks/transformer/positional_encodings/__init__.py
@@ -1,4 +1,8 @@
"""Positional encoding for transformers."""
from .absolute_embedding import AbsolutePositionalEmbedding
-from .positional_encoding import PositionalEncoding, PositionalEncoding2D
+from .positional_encoding import (
+ PositionalEncoding,
+ PositionalEncoding2D,
+ target_padding_mask,
+)
from .rotary_embedding import apply_rotary_pos_emb, RotaryEmbedding
diff --git a/text_recognizer/networks/transformer/positional_encodings/rotary_embedding.py b/text_recognizer/networks/transformer/positional_encodings/rotary_embedding.py
index 5e80572..41290b4 100644
--- a/text_recognizer/networks/transformer/positional_encodings/rotary_embedding.py
+++ b/text_recognizer/networks/transformer/positional_encodings/rotary_embedding.py
@@ -5,7 +5,6 @@ Stolen from lucidrains:
Explanation of roatary:
https://blog.eleuther.ai/rotary-embeddings/
-
"""
from typing import Tuple