diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-04 23:08:16 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-04 23:08:16 +0200 |
commit | 2d4714fcfeb8914f240a0d36d938b434e82f191b (patch) | |
tree | 32e7b3446332cee4685ec90870210c51f9f1279f /text_recognizer/networks/transformer/attention.py | |
parent | 5dc8a7097ab6b4f39f0a3add408e3fd0f131f85b (diff) |
Add new transformer network
Diffstat (limited to 'text_recognizer/networks/transformer/attention.py')
-rw-r--r-- | text_recognizer/networks/transformer/attention.py | 3 |
1 files changed, 2 insertions, 1 deletions
diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py index cce1ecc..ac75d2f 100644 --- a/text_recognizer/networks/transformer/attention.py +++ b/text_recognizer/networks/transformer/attention.py @@ -50,8 +50,9 @@ class MultiHeadAttention(nn.Module): ) nn.init.xavier_normal_(self.fc_out.weight) + @staticmethod def scaled_dot_product_attention( - self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None + query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None ) -> Tensor: """Calculates the scaled dot product attention.""" |