summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/vqvae/norm.py
blob: d73f9f8ef0fb4b175598299b86a64cbd03a43035 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
"""Normalizer block."""
import attr
from torch import nn, Tensor


@attr.s(eq=False)
class Normalize(nn.Module):
    """Thin module wrapper around ``nn.GroupNorm`` configured through attrs."""

    # Channel count handed to nn.GroupNorm as ``num_channels``.
    num_channels: int = attr.ib()
    # Group count handed to nn.GroupNorm as ``num_groups``.
    num_groups: int = attr.ib(default=32)
    # Created in __attrs_post_init__, so it is excluded from the generated init.
    norm: nn.GroupNorm = attr.ib(init=False)

    def __attrs_post_init__(self) -> None:
        """Set up the ``nn.Module`` base and build the normalization layer."""
        # The module base must be initialized before a submodule is assigned.
        super().__init__()
        group_norm = nn.GroupNorm(
            self.num_groups,
            self.num_channels,
            eps=1.0e-6,
            affine=True,
        )
        self.norm = group_norm

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x`` passed through the group normalization layer."""
        normalized = self.norm(x)
        return normalized