From c899c05e801b5c07159353434390e10b8625fe06 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Thu, 30 Sep 2021 23:08:46 +0200 Subject: Major bug fix in attention layer --- text_recognizer/networks/transformer/attention.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) (limited to 'text_recognizer/networks') diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py index 37ce29e..34b6101 100644 --- a/text_recognizer/networks/transformer/attention.py +++ b/text_recognizer/networks/transformer/attention.py @@ -37,11 +37,10 @@ class Attention(nn.Module): self.scale = self.dim ** -0.5 inner_dim = self.dim * self.dim_head - # Attnetion - self.qkv_fn = nn.Sequential( - nn.Linear(self.dim, 3 * inner_dim, bias=False), - Rearrange("b n (qkv h d) -> qkv b h n d", qkv=3, h=self.num_heads), - ) + self.query = nn.Linear(self.dim, inner_dim, bias=False) + self.key = nn.Linear(self.dim, inner_dim, bias=False) + self.value = nn.Linear(self.dim, inner_dim, bias=False) + self.dropout = nn.Dropout(p=self.dropout_rate) # Feedforward @@ -72,7 +71,7 @@ class Attention(nn.Module): q_mask = ( mask if mask is not None else torch.ones((b, n), device=device).bool() ) - k_mask = q_mask if context is not None else context_mask + k_mask = q_mask if context is None else context_mask k_mask = ( torch.ones((b, k.shape[-2]), device=device).bool() if k_mask is None @@ -104,7 +103,13 @@ class Attention(nn.Module): rotary_pos_emb: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor]: b, n, _, device = *x.shape, x.device - q, k, v = self.qkv_fn(x) + + q = self.query(x) + k = self.key(context) if context is not None else self.key(x) + v = self.value(context) if context is not None else self.value(x) + q, k, v = map( + lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.num_heads), (q, k, v) + ) q, k = ( self._apply_rotary_emb(q, k, rotary_pos_emb) if rotary_pos_emb is not None -- cgit v1.2.3-70-g09d2