From cf0ad26a89c089579dab879646aab80a407b7bce Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 5 Jan 2024 21:29:15 -0800 Subject: [PATCH] force fp16 for quantized models (#240) --- llms/hf_llm/convert.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/llms/hf_llm/convert.py b/llms/hf_llm/convert.py index 0704b24b..f93d01c3 100644 --- a/llms/hf_llm/convert.py +++ b/llms/hf_llm/convert.py @@ -154,12 +154,12 @@ if __name__ == "__main__": print("[INFO] Loading") weights, config, tokenizer = fetch_from_hub(args.hf_path) + + dtype = mx.float16 if args.quantize else getattr(mx, args.dtype) + weights = {k: v.astype(dtype) for k, v in weights.items()} if args.quantize: print("[INFO] Quantizing") weights, config = quantize(weights, config, args) - if not args.quantize: - dtype = getattr(mx, args.dtype) - weights = {k: v.astype(dtype) for k, v in weights.items()} mlx_path = Path(args.mlx_path) mlx_path.mkdir(parents=True, exist_ok=True)