author     Gustaf Rydholm <gustaf.rydholm@gmail.com>   2023-08-25 23:19:14 +0200
committer  Gustaf Rydholm <gustaf.rydholm@gmail.com>   2023-08-25 23:19:14 +0200
commit     49ca6ade1a19f7f9c702171537fe4be0dfcda66d (patch)
tree       20062ed1910758481f3d5fff11159706c7b990c6 /text_recognizer/networks/transformer/norm.py
parent     0421daf6bd97596703f426ba61c401599b538eeb (diff)
Rename and add flash atten
Diffstat (limited to 'text_recognizer/networks/transformer/norm.py')
-rw-r--r--  text_recognizer/networks/transformer/norm.py  |  51 ---------------------------------------------------
1 file changed, 0 insertions(+), 51 deletions(-)
diff --git a/text_recognizer/networks/transformer/norm.py b/text_recognizer/networks/transformer/norm.py
deleted file mode 100644
index 1431327..0000000
--- a/text_recognizer/networks/transformer/norm.py
+++ /dev/null
@@ -1,51 +0,0 @@
-"""Normalization layers for transfromers.
-
-Copied from lucidrains:
- https://github.com/lucidrains/x-transformers/blob/main/x_transformers/x_transformers.py
-
-"""
-from typing import Optional, Type
-
-import torch
-from torch import Tensor, nn
-
-
-class RMSNorm(nn.Module):
- """Root mean square layer normalization."""
-
- def __init__(self, dim: int, eps: float = 1e-8) -> None:
- super().__init__()
- self.scale = dim**-0.5
- self.eps = eps
- self.g = nn.Parameter(torch.ones(dim))
-
- def forward(self, x: Tensor) -> Tensor:
- """Applies normalization."""
- norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
- return x / norm.clamp(min=self.eps) * self.g
-
-
-class PreNorm(nn.Module):
- """Applies layer normalization then function."""
-
- def __init__(
- self,
- normalized_shape: int,
- fn: Type[nn.Module],
- context_dim: Optional[int] = None,
- ) -> None:
- super().__init__()
- self.norm = nn.LayerNorm(normalized_shape)
- self.fn = fn
- self.norm_context = (
- nn.LayerNorm(context_dim) if context_dim is not None else None
- )
-
- def forward(self, x: Tensor, **kwargs) -> Tensor:
- """Applies pre norm."""
- x = self.norm(x)
- if self.norm_context is not None:
- context = kwargs["context"]
- normed_context = self.norm_context(context)
- kwargs.update(context=normed_context)
- return self.fn(x, **kwargs)
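For reference, the sketch below shows how the removed RMSNorm and PreNorm layers were typically wired into a pre-norm residual block. It is a minimal example assuming the repository state before this commit (the import only resolves pre-deletion); the model dimension and the wrapped feed-forward sub-module are illustrative assumptions, not part of this diff.

    # Illustrative usage of the layers removed above (pre-commit state of the repo).
    # `dim` and the feed-forward sub-module are assumptions for this sketch.
    import torch
    from torch import nn

    from text_recognizer.networks.transformer.norm import PreNorm, RMSNorm

    dim = 512
    ff = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    block = PreNorm(normalized_shape=dim, fn=ff)  # LayerNorm applied before `ff`
    final_norm = RMSNorm(dim=dim)                 # RMS rescaling over the last dim

    x = torch.randn(1, 16, dim)   # (batch, sequence, dim)
    out = x + block(x)            # pre-norm residual branch
    out = final_norm(out)

When a `context_dim` is given, PreNorm also layer-normalizes the `context` keyword argument before forwarding it, which is how it was used around cross-attention blocks.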