diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-09-18 18:11:21 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-09-18 18:11:53 +0200 |
commit | 2cc6aa059139b57057609817913ad515063c2eab (patch) | |
tree | 5433f69a5eaf63e064a100bf900783127c7b1ff4 /text_recognizer/networks/transformer/attention.py | |
parent | 88caa5c466225d4752541c352c5777235f8f0c61 (diff) |
Format imports
Format imports
Diffstat (limited to 'text_recognizer/networks/transformer/attention.py')
-rw-r--r-- | text_recognizer/networks/transformer/attention.py | 8 |
1 files changed, 3 insertions, 5 deletions
diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py index 3df5333..fca260d 100644 --- a/text_recognizer/networks/transformer/attention.py +++ b/text_recognizer/networks/transformer/attention.py @@ -1,12 +1,10 @@ """Implementes the attention module for the transformer.""" from typing import Optional, Tuple -from einops import rearrange import torch -from torch import einsum -from torch import nn -from torch import Tensor import torch.nn.functional as F +from einops import rearrange +from torch import Tensor, einsum, nn from text_recognizer.networks.transformer.embeddings.rotary import ( RotaryEmbedding, @@ -35,7 +33,7 @@ class Attention(nn.Module): self.dropout_rate = dropout_rate self.rotary_embedding = rotary_embedding - self.scale = self.dim ** -0.5 + self.scale = self.dim**-0.5 inner_dim = self.num_heads * self.dim_head self.to_q = nn.Linear(self.dim, inner_dim, bias=False) |