diff options
Diffstat (limited to 'text_recognizer')
-rw-r--r-- | text_recognizer/networks/conv_transformer.py | 2 | ||||
-rw-r--r-- | text_recognizer/networks/transformer/attention.py | 16 | ||||
-rw-r--r-- | text_recognizer/networks/transformer/layers.py | 8 |
3 files changed, 16 insertions, 10 deletions
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py index 0c838d8..59ce814 100644 --- a/text_recognizer/networks/conv_transformer.py +++ b/text_recognizer/networks/conv_transformer.py @@ -118,7 +118,7 @@ class ConvTransformer(nn.Module): if self.token_pos_embedding is not None else trg ) - out = self.decoder(x=trg, context=src, mask=trg_mask) + out = self.decoder(x=trg, context=src, input_mask=trg_mask) logits = self.head(out) # [B, Sy, T] logits = logits.permute(0, 2, 1) # [B, T, Sy] return logits diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py index 54ef5e2..b86636e 100644 --- a/text_recognizer/networks/transformer/attention.py +++ b/text_recognizer/networks/transformer/attention.py @@ -51,7 +51,7 @@ class Attention(nn.Module): self, x: Tensor, context: Optional[Tensor] = None, - mask: Optional[Tensor] = None, + input_mask: Optional[Tensor] = None, context_mask: Optional[Tensor] = None, ) -> Tensor: """Computes the attention.""" @@ -71,10 +71,10 @@ class Attention(nn.Module): energy = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale mask_value = -torch.finfo(energy.dtype).max energy = apply_input_mask( - b, n, k, energy, mask, context, context_mask, mask_value, device + b, n, k, energy, input_mask, context, context_mask, mask_value, device ) if self.causal: - energy = apply_causal_mask(energy, mask, mask_value, device) + energy = apply_causal_mask(energy, input_mask, mask_value, device) attn = F.softmax(energy, dim=-1) attn = self.dropout(attn) @@ -98,15 +98,19 @@ def apply_input_mask( n: int, k: Tensor, energy: Tensor, - mask: Optional[Tensor], + input_mask: Optional[Tensor], context: Optional[Tensor], context_mask: Optional[Tensor], mask_value: Tensor, device: str, ) -> Tensor: """Applies an input mask.""" - if any(x is not None for x in (mask, context_mask)): - q_mask = mask if mask is not None else torch.ones((b, n), device=device).bool() + if any(x is not None for x in (input_mask, context_mask)): + q_mask = ( + input_mask + if input_mask is not None + else torch.ones((b, n), device=device).bool() + ) k_mask = q_mask if context is None else context_mask k_mask = ( torch.ones((b, k.shape[-2]), device=device).bool() diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py index f740244..8387fa4 100644 --- a/text_recognizer/networks/transformer/layers.py +++ b/text_recognizer/networks/transformer/layers.py @@ -77,7 +77,7 @@ class AttentionLayers(nn.Module): self, x: Tensor, context: Optional[Tensor] = None, - mask: Optional[Tensor] = None, + input_mask: Optional[Tensor] = None, context_mask: Optional[Tensor] = None, ) -> Tensor: """Forward pass.""" @@ -91,9 +91,11 @@ class AttentionLayers(nn.Module): x = norm(x) if layer_type == "a": - out = block(x=x, mask=mask) + out = block(x=x, input_mask=input_mask) elif layer_type == "c": - out = block(x, context=context, mask=mask, context_mask=context_mask) + out = block( + x, context=context, input_mask=input_mask, context_mask=context_mask + ) elif layer_type == "f": out = block(x) |