summaryrefslogtreecommitdiff
path: root/text_recognizer
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer')
-rw-r--r--text_recognizer/networks/transformer/attention.py3
1 files changed, 2 insertions, 1 deletions
diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py
index 3d2ece1..9b33944 100644
--- a/text_recognizer/networks/transformer/attention.py
+++ b/text_recognizer/networks/transformer/attention.py
@@ -3,7 +3,6 @@ from typing import Optional, Tuple
import attr
from einops import rearrange
-from einops.layers.torch import Rearrange
import torch
from torch import einsum
from torch import nn
@@ -18,6 +17,7 @@ class Attention(nn.Module):
"""Standard attention."""
def __attrs_pre_init__(self) -> None:
+ """Pre init constructor."""
super().__init__()
dim: int = attr.ib()
@@ -52,6 +52,7 @@ class Attention(nn.Module):
context_mask: Optional[Tensor] = None,
rotary_pos_emb: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]:
+ """Computes the attention."""
b, n, _, device = *x.shape, x.device
q = self.query(x)