path: root/text_recognizer/networks/transformer/attention.py
author     Gustaf Rydholm <gustaf.rydholm@gmail.com>  2022-09-18 18:11:21 +0200
committer  Gustaf Rydholm <gustaf.rydholm@gmail.com>  2022-09-18 18:11:53 +0200
commit     2cc6aa059139b57057609817913ad515063c2eab (patch)
tree       5433f69a5eaf63e064a100bf900783127c7b1ff4 /text_recognizer/networks/transformer/attention.py
parent     88caa5c466225d4752541c352c5777235f8f0c61 (diff)
Format imports
Diffstat (limited to 'text_recognizer/networks/transformer/attention.py')
-rw-r--r--  text_recognizer/networks/transformer/attention.py  |  8
1 file changed, 3 insertions(+), 5 deletions(-)
diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py
index 3df5333..fca260d 100644
--- a/text_recognizer/networks/transformer/attention.py
+++ b/text_recognizer/networks/transformer/attention.py
@@ -1,12 +1,10 @@
"""Implementes the attention module for the transformer."""
from typing import Optional, Tuple
-from einops import rearrange
import torch
-from torch import einsum
-from torch import nn
-from torch import Tensor
import torch.nn.functional as F
+from einops import rearrange
+from torch import Tensor, einsum, nn
from text_recognizer.networks.transformer.embeddings.rotary import (
    RotaryEmbedding,
)
@@ -35,7 +33,7 @@ class Attention(nn.Module):
        self.dropout_rate = dropout_rate
        self.rotary_embedding = rotary_embedding
-        self.scale = self.dim ** -0.5
+        self.scale = self.dim**-0.5
        inner_dim = self.num_heads * self.dim_head
        self.to_q = nn.Linear(self.dim, inner_dim, bias=False)
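
For context, below is a minimal, hypothetical sketch of how an einsum-based attention module typically applies a scale factor written with the reformatted `dim**-0.5` idiom. The class name, single-head simplification, and dim_head-based scaling are illustrative assumptions, not the repository's actual Attention.forward (which, per the diff above, scales by self.dim**-0.5).

# Minimal sketch, not the repository's implementation.
import torch
import torch.nn.functional as F
from torch import Tensor, einsum, nn


class MiniAttention(nn.Module):
    """Single-head scaled dot-product attention using einsum, for illustration only."""

    def __init__(self, dim: int, dim_head: int = 64) -> None:
        super().__init__()
        # 1 / sqrt(d_k); the common convention scales by dim_head,
        # whereas the file above uses self.dim.
        self.scale = dim_head**-0.5
        self.to_q = nn.Linear(dim, dim_head, bias=False)
        self.to_k = nn.Linear(dim, dim_head, bias=False)
        self.to_v = nn.Linear(dim, dim_head, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)
        # Scaled dot-product similarity: (b, n, d) x (b, m, d) -> (b, n, m).
        energy = einsum("b n d, b m d -> b n m", q, k) * self.scale
        attn = F.softmax(energy, dim=-1)
        # Weighted sum of values: (b, n, m) x (b, m, d) -> (b, n, d).
        return einsum("b n m, b m d -> b n d", attn, v)


if __name__ == "__main__":
    out = MiniAttention(dim=128)(torch.randn(2, 16, 128))
    print(out.shape)  # torch.Size([2, 16, 64])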