diff --git a/llms/mlx_lm/models/gemma3_text.py b/llms/mlx_lm/models/gemma3_text.py index 6f7903f1..81127d4c 100644 --- a/llms/mlx_lm/models/gemma3_text.py +++ b/llms/mlx_lm/models/gemma3_text.py @@ -116,7 +116,6 @@ class MLP(nn.Module): self.up_proj = nn.Linear(dim, hidden_dim, bias=False) def __call__(self, x) -> mx.array: - # This should not be GELU approx, jax.nn.gelu return self.down_proj(nn.gelu_approx(self.gate_proj(x)) * self.up_proj(x))