author    Gustaf Rydholm <gustaf.rydholm@gmail.com>    2021-04-04 23:08:16 +0200
committer Gustaf Rydholm <gustaf.rydholm@gmail.com>    2021-04-04 23:08:16 +0200
commit    2d4714fcfeb8914f240a0d36d938b434e82f191b (patch)
tree      32e7b3446332cee4685ec90870210c51f9f1279f /text_recognizer/networks/transformer/attention.py
parent    5dc8a7097ab6b4f39f0a3add408e3fd0f131f85b (diff)
Add new transformer network
Diffstat (limited to 'text_recognizer/networks/transformer/attention.py')
-rw-r--r--    text_recognizer/networks/transformer/attention.py    3
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py
index cce1ecc..ac75d2f 100644
--- a/text_recognizer/networks/transformer/attention.py
+++ b/text_recognizer/networks/transformer/attention.py
@@ -50,8 +50,9 @@ class MultiHeadAttention(nn.Module):
         )
         nn.init.xavier_normal_(self.fc_out.weight)

+    @staticmethod
     def scaled_dot_product_attention(
-        self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None
+        query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None
     ) -> Tensor:
         """Calculates the scaled dot product attention."""