1 2 3 4 5 6 7 8
from torch import nn
import torch.nn.functional as F


class SwiGLU(nn.Module):
    """Gated activation that splits the last dimension into value and gate.

    The input is cut in half along its last dimension; the first half is
    the value, the second half is the gate. The result is
    ``silu(gate) * value``, so the last dimension of the output is half
    that of the input.
    """

    def forward(self, x):
        """Apply the gated activation to ``x``.

        Args:
            x: Tensor whose last dimension has an even size.

        Returns:
            Tensor shaped like ``x`` except that the last dimension is
            halved.

        Raises:
            ValueError: If the last dimension of ``x`` has an odd size.
        """
        # An odd last dimension makes chunk() return unequal pieces. For
        # size 3 the pieces are 2 and 1 wide, which broadcast silently and
        # yield a wrong result, so reject odd sizes up front. Zero-dim
        # input is left to chunk() to report, as before.
        if x.dim() > 0 and x.size(-1) % 2 != 0:
            raise ValueError(
                "SwiGLU expects an even-sized last dimension, "
                f"got {x.size(-1)}"
            )
        value, gate = x.chunk(2, dim=-1)
        return F.silu(gate) * value