Diffstat (limited to 'text_recognizer/network/transformer/attend.py')
-rw-r--r--  text_recognizer/network/transformer/attend.py | 11 ++++++++---
1 file changed, 8 insertions(+), 3 deletions(-)
diff --git a/text_recognizer/network/transformer/attend.py b/text_recognizer/network/transformer/attend.py
index a5c23c6..23a6487 100644
--- a/text_recognizer/network/transformer/attend.py
+++ b/text_recognizer/network/transformer/attend.py
@@ -1,10 +1,10 @@
-from typing import Optional
 from collections import namedtuple
+from typing import Optional
 
 import torch
-from torch import Tensor, einsum, nn
-from einops import rearrange
 import torch.nn.functional as F
+from einops import rearrange
+from torch import Tensor, einsum, nn
 
 Config = namedtuple(
     "FlashAttentionConfig", ["enable_flash", "enable_math", "enable_mem_efficient"]
@@ -79,6 +79,11 @@ class Attend(nn.Module):
         causal: bool,
         mask: Optional[Tensor] = None,
     ) -> Tensor:
+        if k.ndim == 3:
+            k = rearrange(k, 'b ... -> b 1 ...').expand_as(q)
+        if v.ndim == 3:
+            v = rearrange(v, 'b ... -> b 1 ...').expand_as(q)
+
         if mask is not None:
             mask = rearrange(mask, "b j -> b 1 1 j")
         if self.use_flash:
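
The added guard broadcasts head-less key/value tensors across the query's head dimension before attention. Below is a minimal standalone sketch of the same trick; shapes and variable names are illustrative, not taken from the repository:

import torch
import torch.nn.functional as F
from einops import rearrange

b, h, n, d = 2, 8, 16, 64        # assumed batch, heads, seq len, head dim
q = torch.randn(b, h, n, d)      # multi-head queries: (b, h, n, d)
k = torch.randn(b, n, d)         # shared keys: (b, n, d), no head axis
v = torch.randn(b, n, d)         # shared values: (b, n, d)

# Same pattern as the diff above: insert a singleton head axis, then
# broadcast it so all three tensors have shape (b, h, n, d).
if k.ndim == 3:
    k = rearrange(k, "b ... -> b 1 ...").expand_as(q)
if v.ndim == 3:
    v = rearrange(v, "b ... -> b 1 ...").expand_as(q)

out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)                 # torch.Size([2, 8, 16, 64])

Note that expand_as returns a broadcast view rather than copying memory, so sharing one key/value head across all query heads adds no extra storage.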