From 66e7bcb8866a050727849d9a303c54a0119f0f99 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Tue, 22 Oct 2024 09:56:45 -0700
Subject: [PATCH] override dtype with quant (#1062)

---
 llms/mlx_lm/convert.py       | 2 +-
 llms/mlx_lm/models/gemma2.py | 2 +-
 llms/mlx_lm/utils.py         | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/llms/mlx_lm/convert.py b/llms/mlx_lm/convert.py
index a3f43f71..9bac77a5 100644
--- a/llms/mlx_lm/convert.py
+++ b/llms/mlx_lm/convert.py
@@ -31,7 +31,7 @@ def configure_parser() -> argparse.ArgumentParser:
     )
     parser.add_argument(
         "--dtype",
-        help="Type to save the parameters, ignored if -q is given.",
+        help="Type to save the non-quantized parameters.",
         type=str,
         choices=["float16", "bfloat16", "float32"],
         default="float16",
diff --git a/llms/mlx_lm/models/gemma2.py b/llms/mlx_lm/models/gemma2.py
index ccc327a8..64951ae4 100644
--- a/llms/mlx_lm/models/gemma2.py
+++ b/llms/mlx_lm/models/gemma2.py
@@ -111,7 +111,7 @@ class MLP(nn.Module):
         self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
 
     def __call__(self, x) -> mx.array:
-        return self.down_proj(nn.gelu(self.gate_proj(x)) * self.up_proj(x))
+        return self.down_proj(nn.gelu_approx(self.gate_proj(x)) * self.up_proj(x))
 
 
 class TransformerBlock(nn.Module):
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 4f872982..92741b68 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -720,7 +720,7 @@ def convert(
     model, config, tokenizer = fetch_from_hub(model_path, lazy=True)
 
     weights = dict(tree_flatten(model.parameters()))
-    dtype = mx.float16 if quantize else getattr(mx, dtype)
+    dtype = getattr(mx, dtype)
     weights = {k: v.astype(dtype) for k, v in weights.items()}
 
     if quantize and dequantize: