diff options
Diffstat (limited to 'text_recognizer/networks/transformer/norm.py')
-rw-r--r-- | text_recognizer/networks/transformer/norm.py | 5 |
1 files changed, 2 insertions, 3 deletions
diff --git a/text_recognizer/networks/transformer/norm.py b/text_recognizer/networks/transformer/norm.py index 4cd3b5b..0bd2e16 100644 --- a/text_recognizer/networks/transformer/norm.py +++ b/text_recognizer/networks/transformer/norm.py @@ -7,8 +7,7 @@ Copied from lucidrains: from typing import Dict, Optional, Type import torch -from torch import nn -from torch import Tensor +from torch import Tensor, nn class RMSNorm(nn.Module): @@ -16,7 +15,7 @@ class RMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-8) -> None: super().__init__() - self.scale = dim ** -0.5 + self.scale = dim**-0.5 self.eps = eps self.g = nn.Parameter(torch.ones(dim)) |