summary | refs | log | tree | commit | diff
path: root/text_recognizer
diff options
context:
space:
mode:
author: Gustaf Rydholm <gustaf.rydholm@gmail.com>, 2022-01-29 15:54:28 +0100
committer: Gustaf Rydholm <gustaf.rydholm@gmail.com>, 2022-01-29 15:54:28 +0100
commit: c2dd53291e34f2ca75c8dbcd9b0653899682fae4 (patch)
tree: d2dd7ce2f52e88fd530c7dfc1b102f2d5c2441c3 /text_recognizer
parent: 8f08ac1d32bfb3c2df2b5530e9ac6563647fea7b (diff)
feat: add RMSNorm
Diffstat (limited to 'text_recognizer')
-rw-r--r-- text_recognizer/networks/transformer/norm.py | 28
1 files changed, 6 insertions, 22 deletions
diff --git a/text_recognizer/networks/transformer/norm.py b/text_recognizer/networks/transformer/norm.py
index 98f4d7f..2b416e6 100644
--- a/text_recognizer/networks/transformer/norm.py
+++ b/text_recognizer/networks/transformer/norm.py
@@ -4,37 +4,21 @@ Copied from lucidrains:
https://github.com/lucidrains/x-transformers/blob/main/x_transformers/x_transformers.py
"""
-from typing import Dict, Type
-
import torch
from torch import nn
from torch import Tensor
-class ScaleNorm(nn.Module):
- """Scaled normalization."""
+class RMSNorm(nn.Module):
+ """Root mean square layer normalization."""
- def __init__(self, normalized_shape: int, eps: float = 1.0e-5) -> None:
+ def __init__(self, dim: int, eps: float = 1e-8) -> None:
super().__init__()
- self.scale = normalized_shape ** -0.5
+ self.scale = dim ** -0.5
self.eps = eps
- self.g = nn.Parameter(torch.ones(1))
+ self.g = nn.Parameter(torch.ones(dim))
def forward(self, x: Tensor) -> Tensor:
- """Applies scale norm."""
+ """Applies normalization."""
norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
return x / norm.clamp(min=self.eps) * self.g
-
-
-class PreNorm(nn.Module):
- """Applies layer normalization then function."""
-
- def __init__(self, normalized_shape: int, fn: Type[nn.Module]) -> None:
- super().__init__()
- self.norm = nn.LayerNorm(normalized_shape)
- self.fn = fn
-
- def forward(self, x: Tensor, **kwargs: Dict) -> Tensor:
- """Applies pre norm."""
- x = self.norm(x)
- return self.fn(x, **kwargs)