summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/transformer/norm.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks/transformer/norm.py')
-rw-r--r--text_recognizer/networks/transformer/norm.py8
1 files changed, 4 insertions, 4 deletions
diff --git a/text_recognizer/networks/transformer/norm.py b/text_recognizer/networks/transformer/norm.py
index 8bc3221..4930adf 100644
--- a/text_recognizer/networks/transformer/norm.py
+++ b/text_recognizer/networks/transformer/norm.py
@@ -12,9 +12,9 @@ from torch import Tensor
class ScaleNorm(nn.Module):
- def __init__(self, dim: int, eps: float = 1.0e-5) -> None:
+ def __init__(self, normalized_shape: int, eps: float = 1.0e-5) -> None:
super().__init__()
- self.scale = dim ** -0.5
+ self.scale = normalized_shape ** -0.5
self.eps = eps
self.g = nn.Parameter(torch.ones(1))
@@ -24,9 +24,9 @@ class ScaleNorm(nn.Module):
class PreNorm(nn.Module):
- def __init__(self, dim: int, fn: Type[nn.Module]) -> None:
+ def __init__(self, normalized_shape: int, fn: Type[nn.Module]) -> None:
super().__init__()
- self.norm = nn.LayerNorm(dim)
+ self.norm = nn.LayerNorm(normalized_shape)
self.fn = fn
def forward(self, x: Tensor, **kwargs: Dict) -> Tensor: