From b3fbfd72a8f647161685b28d20b4b61519d8a643 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Mon, 15 Apr 2024 21:49:51 +0200 Subject: Update transformer model --- text_recognizer/network/transformer/attend.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) (limited to 'text_recognizer/network/transformer/attend.py') diff --git a/text_recognizer/network/transformer/attend.py b/text_recognizer/network/transformer/attend.py index a5c23c6..23a6487 100644 --- a/text_recognizer/network/transformer/attend.py +++ b/text_recognizer/network/transformer/attend.py @@ -1,10 +1,10 @@ -from typing import Optional from collections import namedtuple +from typing import Optional import torch -from torch import Tensor, einsum, nn -from einops import rearrange import torch.nn.functional as F +from einops import rearrange +from torch import Tensor, einsum, nn Config = namedtuple( "FlashAttentionConfig", ["enable_flash", "enable_math", "enable_mem_efficient"] @@ -79,6 +79,11 @@ class Attend(nn.Module): causal: bool, mask: Optional[Tensor] = None, ) -> Tensor: + if k.ndim == 3: + k = rearrange(k, 'b ... -> b 1 ...').expand_as(q) + if v.ndim == 3: + v = rearrange(v, 'b ... -> b 1 ...').expand_as(q) + if mask is not None: mask = rearrange(mask, "b j -> b 1 1 j") if self.use_flash: -- cgit v1.2.3-70-g09d2