author    Gustaf Rydholm <gustaf.rydholm@gmail.com>  2021-10-27 22:14:49 +0200
committer Gustaf Rydholm <gustaf.rydholm@gmail.com>  2021-10-27 22:14:49 +0200
commit    43cf6e431b28b60b62d5689e42a591937d122154 (patch)
tree      bb051e460a1b6d7b2d1bebfeb07ec4d323cd000d /text_recognizer
parent    381bccada3f4c1755dda6b059097cdf7739311ff (diff)
Remove unused import and add comments in attn module
Diffstat (limited to 'text_recognizer')
-rw-r--r--  text_recognizer/networks/transformer/attention.py  |  3
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py
index 3d2ece1..9b33944 100644
--- a/text_recognizer/networks/transformer/attention.py
+++ b/text_recognizer/networks/transformer/attention.py
@@ -3,7 +3,6 @@ from typing import Optional, Tuple
 
 import attr
 from einops import rearrange
-from einops.layers.torch import Rearrange
 import torch
 from torch import einsum
 from torch import nn
@@ -18,6 +17,7 @@ class Attention(nn.Module):
     """Standard attention."""
 
     def __attrs_pre_init__(self) -> None:
+        """Pre init constructor."""
         super().__init__()
 
     dim: int = attr.ib()
@@ -52,6 +52,7 @@ class Attention(nn.Module):
         context_mask: Optional[Tensor] = None,
         rotary_pos_emb: Optional[Tensor] = None,
     ) -> Tuple[Tensor, Tensor]:
+        """Computes the attention."""
         b, n, _, device = *x.shape, x.device
 
         q = self.query(x)
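
Only the start of the forward pass is visible in the last hunk. For context, below is a minimal sketch of the einsum-based scaled dot-product step that a module with these imports (einops.rearrange, torch.einsum) typically performs after projecting q, k, and v. The function name, tensor shapes, and the scale factor are assumptions for illustration, not taken from this repository.

import torch
from einops import rearrange
from torch import einsum


def scaled_dot_product(q, k, v, scale):
    """Sketch of scaled dot-product attention over multi-head tensors.

    Assumes q, k, v have shape (batch, heads, seq, dim_head) and that
    scale is typically dim_head ** -0.5.
    """
    # Similarity between every query position i and key position j.
    energy = einsum("b h i d, b h j d -> b h i j", q, k) * scale
    # Normalize over the key dimension to obtain attention weights.
    attn = energy.softmax(dim=-1)
    # Weighted sum of value vectors.
    out = einsum("b h i j, b h j d -> b h i d", attn, v)
    # Merge the heads back into a single feature dimension.
    out = rearrange(out, "b h n d -> b n (h d)")
    return out, attn


# Usage with hypothetical sizes: batch 1, 8 heads, 16 tokens, dim_head 64.
q = k = v = torch.randn(1, 8, 16, 64)
out, attn = scaled_dot_product(q, k, v, scale=64 ** -0.5)

Returning both the output and the attention map matches the Tuple[Tensor, Tensor] return annotation shown in the hunk above.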