Diffstat (limited to 'text_recognizer/networks/transformer/attention.py')
-rw-r--r--  text_recognizer/networks/transformer/attention.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py
index cce1ecc..ac75d2f 100644
--- a/text_recognizer/networks/transformer/attention.py
+++ b/text_recognizer/networks/transformer/attention.py
@@ -50,8 +50,9 @@ class MultiHeadAttention(nn.Module):
         )
         nn.init.xavier_normal_(self.fc_out.weight)
 
+    @staticmethod
     def scaled_dot_product_attention(
-        self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None
+        query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None
     ) -> Tensor:
         """Calculates the scaled dot product attention."""
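
For context, here is a minimal, self-contained sketch of what the now-static method might compute. The hunk above shows only the signature, so the body below is an assumption based on the standard formula softmax(QK^T / sqrt(d_k)) V, not the repository's actual implementation; the einsum subscripts and variable names are illustrative.

from typing import Optional

import torch
import torch.nn.functional as F
from torch import Tensor, nn


class MultiHeadAttention(nn.Module):
    @staticmethod
    def scaled_dot_product_attention(
        query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None
    ) -> Tensor:
        """Calculates the scaled dot product attention."""
        d_k = query.shape[-1]
        # Dot-product similarity between queries and keys, scaled by
        # sqrt(d_k) to keep the softmax in a well-conditioned range.
        scores = torch.einsum("...qd,...kd->...qk", query, key) / d_k**0.5
        if mask is not None:
            # Masked positions receive -inf so softmax gives them ~0 weight.
            scores = scores.masked_fill(mask == 0, float("-inf"))
        weights = F.softmax(scores, dim=-1)
        # Attention-weighted sum of the values.
        return torch.einsum("...qk,...kd->...qd", weights, value)


# Example usage (hypothetical shapes: batch, heads, sequence, head_dim):
q = k = v = torch.randn(2, 8, 16, 64)
out = MultiHeadAttention.scaled_dot_product_attention(q, k, v)
assert out.shape == (2, 8, 16, 64)

Dropping self is safe here because the computation reads no instance state; @staticmethod makes that explicit and lets callers invoke the function without constructing the module, as the usage example shows.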