summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/transformer
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-11-05 19:25:59 +0100
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-11-05 19:26:39 +0100
commite5e776cb7ce3486d1a9e16f6ae328f55fd20f02b (patch)
tree61ff5abc85015a720793fe724d7b65c4ca82764e /text_recognizer/networks/transformer
parentea525029b8b0355c656280e491796b4821c491a4 (diff)
Rename mask to input_mask
Rename mask to input_mask Rename mask to input_mask
Diffstat (limited to 'text_recognizer/networks/transformer')
-rw-r--r--text_recognizer/networks/transformer/attention.py16
-rw-r--r--text_recognizer/networks/transformer/layers.py8
2 files changed, 15 insertions, 9 deletions
diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py
index 54ef5e2..b86636e 100644
--- a/text_recognizer/networks/transformer/attention.py
+++ b/text_recognizer/networks/transformer/attention.py
@@ -51,7 +51,7 @@ class Attention(nn.Module):
self,
x: Tensor,
context: Optional[Tensor] = None,
- mask: Optional[Tensor] = None,
+ input_mask: Optional[Tensor] = None,
context_mask: Optional[Tensor] = None,
) -> Tensor:
"""Computes the attention."""
@@ -71,10 +71,10 @@ class Attention(nn.Module):
energy = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale
mask_value = -torch.finfo(energy.dtype).max
energy = apply_input_mask(
- b, n, k, energy, mask, context, context_mask, mask_value, device
+ b, n, k, energy, input_mask, context, context_mask, mask_value, device
)
if self.causal:
- energy = apply_causal_mask(energy, mask, mask_value, device)
+ energy = apply_causal_mask(energy, input_mask, mask_value, device)
attn = F.softmax(energy, dim=-1)
attn = self.dropout(attn)
@@ -98,15 +98,19 @@ def apply_input_mask(
n: int,
k: Tensor,
energy: Tensor,
- mask: Optional[Tensor],
+ input_mask: Optional[Tensor],
context: Optional[Tensor],
context_mask: Optional[Tensor],
mask_value: Tensor,
device: str,
) -> Tensor:
"""Applies an input mask."""
- if any(x is not None for x in (mask, context_mask)):
- q_mask = mask if mask is not None else torch.ones((b, n), device=device).bool()
+ if any(x is not None for x in (input_mask, context_mask)):
+ q_mask = (
+ input_mask
+ if input_mask is not None
+ else torch.ones((b, n), device=device).bool()
+ )
k_mask = q_mask if context is None else context_mask
k_mask = (
torch.ones((b, k.shape[-2]), device=device).bool()
diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py
index f740244..8387fa4 100644
--- a/text_recognizer/networks/transformer/layers.py
+++ b/text_recognizer/networks/transformer/layers.py
@@ -77,7 +77,7 @@ class AttentionLayers(nn.Module):
self,
x: Tensor,
context: Optional[Tensor] = None,
- mask: Optional[Tensor] = None,
+ input_mask: Optional[Tensor] = None,
context_mask: Optional[Tensor] = None,
) -> Tensor:
"""Forward pass."""
@@ -91,9 +91,11 @@ class AttentionLayers(nn.Module):
x = norm(x)
if layer_type == "a":
- out = block(x=x, mask=mask)
+ out = block(x=x, input_mask=input_mask)
elif layer_type == "c":
- out = block(x, context=context, mask=mask, context_mask=context_mask)
+ out = block(
+ x, context=context, input_mask=input_mask, context_mask=context_mask
+ )
elif layer_type == "f":
out = block(x)