diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-09-18 18:11:21 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-09-18 18:11:53 +0200 |
commit | 2cc6aa059139b57057609817913ad515063c2eab (patch) | |
tree | 5433f69a5eaf63e064a100bf900783127c7b1ff4 /text_recognizer/networks/transformer/norm.py | |
parent | 88caa5c466225d4752541c352c5777235f8f0c61 (diff) |
Format imports
Format imports
Diffstat (limited to 'text_recognizer/networks/transformer/norm.py')
-rw-r--r-- | text_recognizer/networks/transformer/norm.py | 5 |
1 files changed, 2 insertions, 3 deletions
diff --git a/text_recognizer/networks/transformer/norm.py b/text_recognizer/networks/transformer/norm.py index 4cd3b5b..0bd2e16 100644 --- a/text_recognizer/networks/transformer/norm.py +++ b/text_recognizer/networks/transformer/norm.py @@ -7,8 +7,7 @@ Copied from lucidrains: from typing import Dict, Optional, Type import torch -from torch import nn -from torch import Tensor +from torch import Tensor, nn class RMSNorm(nn.Module): @@ -16,7 +15,7 @@ class RMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-8) -> None: super().__init__() - self.scale = dim ** -0.5 + self.scale = dim**-0.5 self.eps = eps self.g = nn.Parameter(torch.ones(dim)) |