summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--text_recognizer/networks/conv_transformer.py2
-rw-r--r--text_recognizer/networks/transformer/attention.py16
-rw-r--r--text_recognizer/networks/transformer/layers.py8
3 files changed, 16 insertions, 10 deletions
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py
index 0c838d8..59ce814 100644
--- a/text_recognizer/networks/conv_transformer.py
+++ b/text_recognizer/networks/conv_transformer.py
@@ -118,7 +118,7 @@ class ConvTransformer(nn.Module):
if self.token_pos_embedding is not None
else trg
)
- out = self.decoder(x=trg, context=src, mask=trg_mask)
+ out = self.decoder(x=trg, context=src, input_mask=trg_mask)
logits = self.head(out) # [B, Sy, T]
logits = logits.permute(0, 2, 1) # [B, T, Sy]
return logits
diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py
index 54ef5e2..b86636e 100644
--- a/text_recognizer/networks/transformer/attention.py
+++ b/text_recognizer/networks/transformer/attention.py
@@ -51,7 +51,7 @@ class Attention(nn.Module):
self,
x: Tensor,
context: Optional[Tensor] = None,
- mask: Optional[Tensor] = None,
+ input_mask: Optional[Tensor] = None,
context_mask: Optional[Tensor] = None,
) -> Tensor:
"""Computes the attention."""
@@ -71,10 +71,10 @@ class Attention(nn.Module):
energy = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale
mask_value = -torch.finfo(energy.dtype).max
energy = apply_input_mask(
- b, n, k, energy, mask, context, context_mask, mask_value, device
+ b, n, k, energy, input_mask, context, context_mask, mask_value, device
)
if self.causal:
- energy = apply_causal_mask(energy, mask, mask_value, device)
+ energy = apply_causal_mask(energy, input_mask, mask_value, device)
attn = F.softmax(energy, dim=-1)
attn = self.dropout(attn)
@@ -98,15 +98,19 @@ def apply_input_mask(
n: int,
k: Tensor,
energy: Tensor,
- mask: Optional[Tensor],
+ input_mask: Optional[Tensor],
context: Optional[Tensor],
context_mask: Optional[Tensor],
mask_value: Tensor,
device: str,
) -> Tensor:
"""Applies an input mask."""
- if any(x is not None for x in (mask, context_mask)):
- q_mask = mask if mask is not None else torch.ones((b, n), device=device).bool()
+ if any(x is not None for x in (input_mask, context_mask)):
+ q_mask = (
+ input_mask
+ if input_mask is not None
+ else torch.ones((b, n), device=device).bool()
+ )
k_mask = q_mask if context is None else context_mask
k_mask = (
torch.ones((b, k.shape[-2]), device=device).bool()
diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py
index f740244..8387fa4 100644
--- a/text_recognizer/networks/transformer/layers.py
+++ b/text_recognizer/networks/transformer/layers.py
@@ -77,7 +77,7 @@ class AttentionLayers(nn.Module):
self,
x: Tensor,
context: Optional[Tensor] = None,
- mask: Optional[Tensor] = None,
+ input_mask: Optional[Tensor] = None,
context_mask: Optional[Tensor] = None,
) -> Tensor:
"""Forward pass."""
@@ -91,9 +91,11 @@ class AttentionLayers(nn.Module):
x = norm(x)
if layer_type == "a":
- out = block(x=x, mask=mask)
+ out = block(x=x, input_mask=input_mask)
elif layer_type == "c":
- out = block(x, context=context, mask=mask, context_mask=context_mask)
+ out = block(
+ x, context=context, input_mask=input_mask, context_mask=context_mask
+ )
elif layer_type == "f":
out = block(x)