1 2 3 4 5 6 7 8
import torch.nn.functional as F
from torch import Tensor, nn


class SwiGLU(nn.Module):
    """SiLU-gated linear unit applied over the last dimension.

    The input is split into two equal halves along its last dimension.
    The first half is the value, the second half is the gate, and the
    result is ``silu(gate) * value``. The module holds no parameters.
    """

    def forward(self, x: Tensor) -> Tensor:
        """Apply the gated activation.

        Args:
            x: Tensor whose last dimension has an even size; it holds the
                value half followed by the gate half.

        Returns:
            Tensor shaped like ``x`` except that the last dimension is
            half as large.

        Raises:
            ValueError: If the last dimension of ``x`` has an odd size.
        """
        width = x.shape[-1]
        # ``chunk`` returns unequal pieces for an odd width; sizes such as
        # 3 would then split into 2 and 1 and broadcast silently in the
        # product below, so reject odd widths explicitly.
        if width % 2 != 0:
            raise ValueError(
                f"SwiGLU expects an even-sized last dimension, got {width}"
            )
        value, gate = x.chunk(2, dim=-1)
        return F.silu(gate) * value