summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/conformer/glu.py
blob: 016b6844618647275b8ac09636eba16fd12dc675 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
"""GLU layer."""
from torch import nn, Tensor


class GLU(nn.Module):
    """Gated linear unit.

    Splits the input into two equal halves ``a`` and ``b`` along ``dim`` and
    returns ``a * sigmoid(b)``. For inputs whose ``dim`` size is even this
    matches ``torch.nn.functional.glu(x, dim)``.

    Args:
        dim: Dimension along which the input is halved. The input's size along
            this dimension must be even; the output's size there is half of it.
    """

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.dim = dim

    def extra_repr(self) -> str:
        """Show the gating dimension in the module's repr."""
        return f"dim={self.dim}"

    def forward(self, x: Tensor) -> Tensor:
        """Apply the gate.

        Args:
            x: Input tensor whose size along ``self.dim`` is even.

        Returns:
            Tensor shaped like ``x`` except halved along ``self.dim``.

        Raises:
            ValueError: If ``x``'s size along ``self.dim`` is odd.
        """
        size = x.size(self.dim)
        # chunk(2) on an odd size yields unequal halves. For size 3 (halves 2
        # and 1) the product would silently broadcast to a wrong result, so
        # reject odd sizes up front.
        if size % 2 != 0:
            raise ValueError(
                f"GLU requires an even size along dim {self.dim}, got {size}"
            )
        out, gate = x.chunk(2, dim=self.dim)
        return out * gate.sigmoid()