blob: 2d896e588dcda442216f88917c36cb05734af38b (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
|
"""Layer norm for conv layers."""
import torch
from torch import nn, Tensor
class LayerNorm(nn.Module):
    """Bias-free layer normalization over the channel axis of conv features.

    Each position of the input is normalized across ``dim=1`` to zero mean and
    unit population variance, then scaled by a learnable per-channel gain
    ``gamma``. There is no learnable bias/shift term.

    Args:
        dim: Number of channels, i.e. the size of axis 1 of the input.
    """

    def __init__(self, dim: int) -> None:
        super().__init__()
        # Kept at shape (1, dim, 1, 1) so existing state_dicts still load;
        # ``forward`` reshapes it to match the rank of the input.
        self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1))

    def forward(self, x: Tensor) -> Tensor:
        """Normalize ``x`` across channels and apply the per-channel gain.

        Args:
            x: Tensor of shape ``(batch, dim, *spatial)`` with at least two
                dimensions, e.g. ``(N, C, H, W)`` for 2-D conv feature maps.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Full-precision inputs (float32 and float64) use the small epsilon;
        # any other dtype (e.g. float16/bfloat16) gets a larger one,
        # presumably for numerical stability at reduced precision.
        eps = 1e-5 if x.dtype in (torch.float32, torch.float64) else 1e-3
        # One pass for both statistics; population variance (no Bessel
        # correction), as in the original formulation.
        var, mean = torch.var_mean(x, dim=1, unbiased=False, keepdim=True)
        # View gamma as (1, dim, 1, ..., 1) so it broadcasts over the channel
        # axis for inputs of any rank >= 2, not only 4-D tensors.
        gamma = self.gamma.view(1, -1, *([1] * (x.ndim - 2)))
        return (x - mean) / (var + eps).sqrt() * gamma
|