mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 12:49:50 +08:00
override dtype with quant (#1062)
This commit is contained in:
@@ -111,7 +111,7 @@ class MLP(nn.Module):
|
||||
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
return self.down_proj(nn.gelu(self.gate_proj(x)) * self.up_proj(x))
|
||||
return self.down_proj(nn.gelu_approx(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
|
Reference in New Issue
Block a user