diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2024-04-15 21:49:51 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2024-04-15 21:49:51 +0200 |
commit | b3fbfd72a8f647161685b28d20b4b61519d8a643 (patch) | |
tree | a5cac4e15186396aae35231d6d6fe266691b0186 /text_recognizer/network/transformer/attend.py | |
parent | c7e5354ffa43eccfc4e411375ce2f531af7bbcff (diff) |
Update transformer model
Diffstat (limited to 'text_recognizer/network/transformer/attend.py')
-rw-r--r-- | text_recognizer/network/transformer/attend.py | 11 |
1 file changed, 8 insertions, 3 deletions
diff --git a/text_recognizer/network/transformer/attend.py b/text_recognizer/network/transformer/attend.py index a5c23c6..23a6487 100644 --- a/text_recognizer/network/transformer/attend.py +++ b/text_recognizer/network/transformer/attend.py @@ -1,10 +1,10 @@ -from typing import Optional from collections import namedtuple +from typing import Optional import torch -from torch import Tensor, einsum, nn -from einops import rearrange import torch.nn.functional as F +from einops import rearrange +from torch import Tensor, einsum, nn Config = namedtuple( "FlashAttentionConfig", ["enable_flash", "enable_math", "enable_mem_efficient"] @@ -79,6 +79,11 @@ class Attend(nn.Module): causal: bool, mask: Optional[Tensor] = None, ) -> Tensor: + if k.ndim == 3: + k = rearrange(k, 'b ... -> b 1 ...').expand_as(q) + if v.ndim == 3: + v = rearrange(v, 'b ... -> b 1 ...').expand_as(q) + if mask is not None: mask = rearrange(mask, "b j -> b 1 1 j") if self.use_flash: |