Diffstat (limited to 'text_recognizer/networks/transformer')
-rw-r--r--  text_recognizer/networks/transformer/__init__.py             2
-rw-r--r--  text_recognizer/networks/transformer/attention.py            3
-rw-r--r--  text_recognizer/networks/transformer/positional_encoding.py  14
3 files changed, 15 insertions, 4 deletions
diff --git a/text_recognizer/networks/transformer/__init__.py b/text_recognizer/networks/transformer/__init__.py
index 9febc88..139cd23 100644
--- a/text_recognizer/networks/transformer/__init__.py
+++ b/text_recognizer/networks/transformer/__init__.py
@@ -1,3 +1,3 @@
"""Transformer modules."""
-from .positional_encoding import PositionalEncoding
+from .positional_encoding import PositionalEncoding, PositionalEncoding2D, target_padding_mask
from .transformer import Decoder, Encoder, EncoderLayer, Transformer
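
Note: with this change the package re-exports the 2D positional encoding and the target-mask helper alongside the existing symbols. A minimal import sketch (module path taken from the diff above):

from text_recognizer.networks.transformer import (
    PositionalEncoding,
    PositionalEncoding2D,
    target_padding_mask,
)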
diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py
index cce1ecc..ac75d2f 100644
--- a/text_recognizer/networks/transformer/attention.py
+++ b/text_recognizer/networks/transformer/attention.py
@@ -50,8 +50,9 @@ class MultiHeadAttention(nn.Module):
         )
         nn.init.xavier_normal_(self.fc_out.weight)
 
+    @staticmethod
     def scaled_dot_product_attention(
-        self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None
+        query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None
     ) -> Tensor:
         """Calculates the scaled dot product attention."""
diff --git a/text_recognizer/networks/transformer/positional_encoding.py b/text_recognizer/networks/transformer/positional_encoding.py
index d67d297..dbde887 100644
--- a/text_recognizer/networks/transformer/positional_encoding.py
+++ b/text_recognizer/networks/transformer/positional_encoding.py
@@ -56,9 +56,9 @@ class PositionalEncoding2D(nn.Module):
         pe_h = repeat(pe_h, "h w d -> d h (w tile)", tile=max_w)
 
         pe_w = PositionalEncoding.make_pe(
-            hidden_dim // 2, max_len=max_h
+            hidden_dim // 2, max_len=max_w
         )  # [W, 1, D // 2]
-        pe_w = repeat(pe_w, "h w d -> d (h tile) w", tile=max_h)
+        pe_w = repeat(pe_w, "w h d -> d (h tile) w", tile=max_h)
 
         pe = torch.cat([pe_h, pe_w], dim=0)  # [D, H, W]
         return pe
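
Note: the width encoding was previously built with max_len=max_h and tiled with an "h w d" pattern, which yields axes that do not line up with pe_h, so the torch.cat into [D, H, W] would fail for non-square sizes. The fix builds it with max_len=max_w and labels the leading axis as w. A shape sketch with made-up sizes, assuming make_pe(d, max_len) returns a [max_len, 1, d] tensor as the inline comments indicate:

import torch
from einops import repeat

max_h, max_w, hidden_dim = 32, 128, 256

pe_h = torch.zeros(max_h, 1, hidden_dim // 2)             # stand-in for make_pe output, [H, 1, D // 2]
pe_h = repeat(pe_h, "h w d -> d h (w tile)", tile=max_w)  # [D // 2, H, W]

pe_w = torch.zeros(max_w, 1, hidden_dim // 2)             # stand-in, now built with max_len=max_w
pe_w = repeat(pe_w, "w h d -> d (h tile) w", tile=max_h)  # [D // 2, H, W]

pe = torch.cat([pe_h, pe_w], dim=0)
assert pe.shape == (hidden_dim, max_h, max_w)             # [D, H, W]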
@@ -70,3 +70,13 @@ class PositionalEncoding2D(nn.Module):
             raise ValueError("Hidden dimensions does not match.")
         x += self.pe[:, : x.shape[2], : x.shape[3]]
         return x
+
+
+def target_padding_mask(trg: Tensor, pad_index: int) -> Tensor:
+    """Returns causal target mask."""
+    trg_pad_mask = (trg != pad_index)[:, None, None]
+    trg_len = trg.shape[1]
+    trg_sub_mask = torch.tril(
+        torch.ones((trg_len, trg_len), device=trg.device)
+    ).bool()
+    trg_mask = trg_pad_mask & trg_sub_mask
+    return trg_mask
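
Note: despite the docstring, the helper AND-s a padding mask ([B, 1, 1, T]) with a causal lower-triangular mask ([T, T]). Usage sketch with made-up token values; the [B, 1, T, T] result is presumably meant to broadcast against [B, num_heads, T, T] attention energies:

import torch

from text_recognizer.networks.transformer import target_padding_mask

trg = torch.tensor([[5, 7, 2, 0, 0],
                    [9, 3, 0, 0, 0]])  # two right-padded sequences, pad_index=0

mask = target_padding_mask(trg, pad_index=0)
assert mask.shape == (2, 1, 5, 5)  # [B, 1, T, T]: causal and non-padding positions only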