blob: 3355de93783e5d3271bb569537d3a0e2b9ebbc14 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
|
"""Layer norm for conv layers."""
import torch
from torch import Tensor, nn
class LayerNorm(nn.Module):
    """Layer norm for convolutions.

    Normalizes an input of shape ``(batch, channels, *spatial)`` across the
    channel dimension (``dim=1``) independently at every spatial position,
    then rescales by a learnable per-channel gain ``gamma``. There is no
    learnable bias/shift term.

    Args:
        dim: Number of channels of the input (size of dimension 1).
    """

    def __init__(self, dim: int) -> None:
        super().__init__()
        # Kept as (1, dim, 1, 1) so existing state dicts remain loadable;
        # forward() reshapes it to match the rank of the incoming tensor.
        self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1))

    def forward(self, x: Tensor) -> Tensor:
        """Applies layer norm over the channel dimension.

        Args:
            x: Tensor of shape ``(batch, dim, *spatial)`` with at least two
                dimensions, e.g. ``(B, C, L)``, ``(B, C, H, W)`` or
                ``(B, C, D, H, W)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Every dtype other than float32 gets a larger epsilon, presumably for
        # numerical stability in reduced precision (float16/bfloat16).
        # NOTE(review): float64 inputs also take the 1e-3 branch, which looks
        # unintended; kept unchanged to preserve existing numerics.
        eps = 1e-5 if x.dtype == torch.float32 else 1e-3
        # Biased (population) variance and mean over channels in one pass.
        var, mean = torch.var_mean(x, dim=1, unbiased=False, keepdim=True)
        # Reshape (1, dim, 1, 1) -> (1, dim, 1, ..., 1) with x.ndim entries so
        # the gain lines up with the channel axis for inputs of any rank,
        # not only 4-D ones.
        gamma = self.gamma.view(1, -1, *([1] * (x.ndim - 2)))
        return (x - mean) * (var + eps).rsqrt() * gamma
|