summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/conformer/glu.py
blob: 1a7c2019c6f48a5e92851be3bf2e2e13efc63a4e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
"""GLU layer."""
from torch import nn, Tensor


class GLU(nn.Module):
    """Gated linear unit.

    Splits the input into two equal halves ``a`` and ``b`` along ``dim`` and
    returns ``a * sigmoid(b)``. The output therefore has half the size of the
    input along ``dim``.

    Args:
        dim: Dimension along which the input is halved. Negative values
            index from the end. Defaults to ``-1`` (the last dimension).
    """

    def __init__(self, dim: int = -1) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, x: Tensor) -> Tensor:
        """Apply the gating.

        Args:
            x: Input tensor whose size along ``self.dim`` must be even.

        Returns:
            ``x_first_half * sigmoid(x_second_half)`` along ``self.dim``.

        Raises:
            RuntimeError: If the size of ``x`` along ``self.dim`` is odd.
        """
        # Use the built-in GLU rather than ``x.chunk(2)``. For odd sizes
        # ``chunk`` yields unequal halves (e.g. 2 and 1), and the product then
        # silently broadcasts to a wrong result. ``glu`` rejects odd sizes.
        return nn.functional.glu(x, dim=self.dim)

    def extra_repr(self) -> str:
        return f"dim={self.dim}"