diff --git a/llms/mlx_lm/models/cohere2.py b/llms/mlx_lm/models/cohere2.py index ae19f4d8..b4068679 100644 --- a/llms/mlx_lm/models/cohere2.py +++ b/llms/mlx_lm/models/cohere2.py @@ -115,7 +115,7 @@ class MLP(nn.Module): self.down_proj = nn.Linear(hidden_dim, dim, bias=False) def __call__(self, x): - return self.down_proj(nn.gelu(self.gate_proj(x)) * self.up_proj(x)) + return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) class TransformerBlock(nn.Module):