summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/transformer/norm.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-07-30 23:15:03 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-07-30 23:15:03 +0200
commit7268035fb9e57342612a8cc50a1fe04e8841ca2f (patch)
tree8d4cf3743975bd25f2c04d6a56ff3d4608a7e8d9 /text_recognizer/networks/transformer/norm.py
parent92fc1c7ed2f9f64552be8f71d9b8ab0d5a0a88d4 (diff)
attr bug fix, properly loading network
Diffstat (limited to 'text_recognizer/networks/transformer/norm.py')
-rw-r--r--text_recognizer/networks/transformer/norm.py8
1 files changed, 4 insertions, 4 deletions
diff --git a/text_recognizer/networks/transformer/norm.py b/text_recognizer/networks/transformer/norm.py
index 8bc3221..4930adf 100644
--- a/text_recognizer/networks/transformer/norm.py
+++ b/text_recognizer/networks/transformer/norm.py
@@ -12,9 +12,9 @@ from torch import Tensor
class ScaleNorm(nn.Module):
- def __init__(self, dim: int, eps: float = 1.0e-5) -> None:
+ def __init__(self, normalized_shape: int, eps: float = 1.0e-5) -> None:
super().__init__()
- self.scale = dim ** -0.5
+ self.scale = normalized_shape ** -0.5
self.eps = eps
self.g = nn.Parameter(torch.ones(1))
@@ -24,9 +24,9 @@ class ScaleNorm(nn.Module):
class PreNorm(nn.Module):
- def __init__(self, dim: int, fn: Type[nn.Module]) -> None:
+ def __init__(self, normalized_shape: int, fn: Type[nn.Module]) -> None:
super().__init__()
- self.norm = nn.LayerNorm(dim)
+ self.norm = nn.LayerNorm(normalized_shape)
self.fn = fn
def forward(self, x: Tensor, **kwargs: Dict) -> Tensor: