summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/convnext/norm.py
blob: 23cf07a23b0f7843e8ffa6fe8f323b332757a9a6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
"""Layer norm for conv layers."""
import torch
from torch import Tensor, nn


class LayerNorm(nn.Module):
    """Layer norm over the channel dimension of conv feature maps.

    Each position is normalized to zero mean and unit (biased) variance across
    dim 1. The result is then scaled by a learnable per-channel ``gamma``. No
    bias term is applied. ``gamma`` has shape ``(1, dim, 1, 1)``, so the input
    is presumably ``(batch, dim, height, width)`` — TODO confirm against callers.

    Args:
        dim: Number of channels (size of dim 1 of the input).
        eps: Variance epsilon for full-precision inputs (float32/float64).
        low_precision_eps: Larger variance epsilon used for float16/bfloat16
            inputs, whose limited precision cannot resolve a tiny eps.
    """

    def __init__(
        self, dim: int, eps: float = 1e-5, low_precision_eps: float = 1e-3
    ) -> None:
        super().__init__()
        self.eps = eps
        self.low_precision_eps = low_precision_eps
        self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1))

    def forward(self, x: Tensor) -> Tensor:
        """Normalize ``x`` across dim 1 and scale by ``gamma``."""
        # Only reduced-precision floats get the larger eps. float64 must not
        # fall into that branch, or its output would diverge from float32.
        if x.dtype in (torch.float16, torch.bfloat16):
            eps = self.low_precision_eps
        else:
            eps = self.eps
        # Biased variance (divide by N), matching the usual layer-norm definition.
        var, mean = torch.var_mean(x, dim=1, unbiased=False, keepdim=True)
        return (x - mean) / (var + eps).sqrt() * self.gamma