# Source: root/text_recognizer/networks/convnext/norm.py
# (cgit blob 3355de93783e5d3271bb569537d3a0e2b9ebbc14)
"""Layer norm for conv layers."""
import torch
from torch import Tensor, nn


class LayerNorm(nn.Module):
    """Channel-wise layer norm for convolutional feature maps.

    Normalizes the input over the channel axis (dim 1) to zero mean and unit
    (biased) variance, then scales by a learnable per-channel ``gamma``.
    There is no learnable bias/shift term.

    Args:
        dim: Number of channels. ``gamma`` has shape ``(1, dim, 1, 1)``, so it
            broadcasts against a 4-D input with channels on dim 1.
    """

    def __init__(self, dim: int) -> None:
        super().__init__()
        # Per-channel scale, initialized to identity.
        self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1))

    def forward(self, x: Tensor) -> Tensor:
        """Applies layer norm over the channel dimension.

        Args:
            x: Input tensor with channels on dim 1.

        Returns:
            Tensor of the same shape as ``x``, normalized across dim 1 and
            scaled by ``gamma``.
        """
        # Low-precision dtypes get a larger epsilon to keep the denominator
        # well away from zero; full-precision dtypes (float32 AND float64)
        # use the standard 1e-5. Previously float64 fell into the 1e-3 branch.
        eps = 1e-3 if x.dtype in (torch.float16, torch.bfloat16) else 1e-5
        # Compute both statistics in a single reduction over the channel axis.
        var, mean = torch.var_mean(x, dim=1, unbiased=False, keepdim=True)
        return (x - mean) / (var + eps).sqrt() * self.gamma