summaryrefslogtreecommitdiff
path: root/text_recognizer/network/transformer/swiglu.py
blob: 7bafd06f27fc60d22b4ec975e0619f279e76cfa1 (plain)
1
2
3
4
5
6
7
8
import torch.nn.functional as F
from torch import nn


class SwiGLU(nn.Module):
    """SwiGLU gating activation.

    Splits the last dimension of the input into two equal halves ``(x, gate)``
    and returns ``silu(gate) * x``. The output's last dimension is therefore
    half the size of the input's.
    """

    def forward(self, x):
        """Apply SwiGLU gating over the last dimension of ``x``.

        Args:
            x: Tensor whose last dimension holds the value half followed by
                the gate half; that dimension must have even size.

        Returns:
            Tensor equal to ``silu(gate) * value``, with the last dimension
            halved.

        Raises:
            ValueError: If the last dimension of ``x`` has odd size.
        """
        size = x.shape[-1]
        # ``chunk`` does not require an even split: an odd size yields unequal
        # pieces (e.g. 2 and 1), which would silently broadcast into a result
        # of the wrong shape, so reject it explicitly.
        if size % 2 != 0:
            raise ValueError(
                f"SwiGLU expects an even-sized last dimension, got {size}"
            )
        x, gate = x.chunk(2, dim=-1)
        return F.silu(gate) * x