import torch
import torch.nn as nn
import torch.nn.functional as F


def _split_value_and_gate(x: torch.Tensor, dim: int) -> "tuple[torch.Tensor, torch.Tensor]":
    """Split ``x`` into equal (value, gate) halves along ``dim``.

    Args:
        x: input tensor; its size along ``dim`` must be even.
        dim: dimension to split along (negative indices allowed).

    Returns:
        ``(value, gate)``: the first and second halves of ``x`` along ``dim``.

    Raises:
        ValueError: if ``x.size(dim)`` is odd. Without this check,
            ``Tensor.chunk`` returns unequal pieces. For example, it returns
            sizes 2 and 1 for a size-3 dim, which then broadcast silently to a
            wrong result. Other odd sizes fail with an opaque shape or
            unpacking error.
    """
    size = x.size(dim)
    if size % 2 != 0:
        raise ValueError(
            f"gated activation expects an even size along dim {dim}, got {size}"
        )
    value, gate = x.chunk(2, dim=dim)
    return value, gate


class SwiGLU(nn.Module):
    """Swish-gated linear unit: ``value * silu(gate)``.

    The input is split into two equal halves along ``dim``. The first half is
    the value and the second half is the gate. The output is therefore half
    the input's size along ``dim``.

    Args:
        dim: dimension to split along. Defaults to ``-1`` (the last dimension).
    """

    # Class-level default, so instances that never ran this ``__init__``
    # still resolve ``self.dim`` to the original behaviour (split on the last
    # dim). One example is a module pickled before ``dim`` existed.
    dim: int = -1

    def __init__(self, dim: int = -1) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        value, gate = _split_value_and_gate(x, self.dim)
        return value * F.silu(gate)

    def extra_repr(self) -> str:
        return f"dim={self.dim}"


class GEGLU(nn.Module):
    """GELU-gated linear unit: ``value * gelu(gate)``.

    The input is split into two equal halves along ``dim``. The first half is
    the value and the second half is the gate. The output is therefore half
    the input's size along ``dim``.

    Args:
        dim: dimension to split along. Defaults to ``-1`` (the last dimension).
    """

    # Class-level default, so instances that never ran this ``__init__``
    # still resolve ``self.dim`` to the original behaviour (split on the last
    # dim). One example is a module pickled before ``dim`` existed.
    dim: int = -1

    def __init__(self, dim: int = -1) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        value, gate = _split_value_and_gate(x, self.dim)
        return value * F.gelu(gate)

    def extra_repr(self) -> str:
        return f"dim={self.dim}"