override dtype with quant (#1062)

This commit is contained in:
Awni Hannun
2024-10-22 09:56:45 -07:00
committed by GitHub
parent 743763bc2e
commit 66e7bcb886
3 changed files with 3 additions and 3 deletions

View File

@@ -111,7 +111,7 @@ class MLP(nn.Module):
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
def __call__(self, x) -> mx.array:
return self.down_proj(nn.gelu(self.gate_proj(x)) * self.up_proj(x))
return self.down_proj(nn.gelu_approx(self.gate_proj(x)) * self.up_proj(x))
class TransformerBlock(nn.Module):