summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/transformer/attention.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-11-05 19:25:59 +0100
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-11-05 19:26:39 +0100
commite5e776cb7ce3486d1a9e16f6ae328f55fd20f02b (patch)
tree61ff5abc85015a720793fe724d7b65c4ca82764e /text_recognizer/networks/transformer/attention.py
parentea525029b8b0355c656280e491796b4821c491a4 (diff)
Rename mask to input_mask
Rename mask to input_mask
Diffstat (limited to 'text_recognizer/networks/transformer/attention.py')
-rw-r--r--text_recognizer/networks/transformer/attention.py16
1 file changed, 10 insertions, 6 deletions
diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py
index 54ef5e2..b86636e 100644
--- a/text_recognizer/networks/transformer/attention.py
+++ b/text_recognizer/networks/transformer/attention.py
@@ -51,7 +51,7 @@ class Attention(nn.Module):
self,
x: Tensor,
context: Optional[Tensor] = None,
- mask: Optional[Tensor] = None,
+ input_mask: Optional[Tensor] = None,
context_mask: Optional[Tensor] = None,
) -> Tensor:
"""Computes the attention."""
@@ -71,10 +71,10 @@ class Attention(nn.Module):
energy = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale
mask_value = -torch.finfo(energy.dtype).max
energy = apply_input_mask(
- b, n, k, energy, mask, context, context_mask, mask_value, device
+ b, n, k, energy, input_mask, context, context_mask, mask_value, device
)
if self.causal:
- energy = apply_causal_mask(energy, mask, mask_value, device)
+ energy = apply_causal_mask(energy, input_mask, mask_value, device)
attn = F.softmax(energy, dim=-1)
attn = self.dropout(attn)
@@ -98,15 +98,19 @@ def apply_input_mask(
n: int,
k: Tensor,
energy: Tensor,
- mask: Optional[Tensor],
+ input_mask: Optional[Tensor],
context: Optional[Tensor],
context_mask: Optional[Tensor],
mask_value: Tensor,
device: str,
) -> Tensor:
"""Applies an input mask."""
- if any(x is not None for x in (mask, context_mask)):
- q_mask = mask if mask is not None else torch.ones((b, n), device=device).bool()
+ if any(x is not None for x in (input_mask, context_mask)):
+ q_mask = (
+ input_mask
+ if input_mask is not None
+ else torch.ones((b, n), device=device).bool()
+ )
k_mask = q_mask if context is None else context_mask
k_mask = (
torch.ones((b, k.shape[-2]), device=device).bool()