diff --git a/llms/hf_llm/convert.py b/llms/hf_llm/convert.py index 0704b24b..f93d01c3 100644 --- a/llms/hf_llm/convert.py +++ b/llms/hf_llm/convert.py @@ -154,12 +154,12 @@ if __name__ == "__main__": print("[INFO] Loading") weights, config, tokenizer = fetch_from_hub(args.hf_path) + + dtype = mx.float16 if args.quantize else getattr(mx, args.dtype) + weights = {k: v.astype(dtype) for k, v in weights.items()} if args.quantize: print("[INFO] Quantizing") weights, config = quantize(weights, config, args) - if not args.quantize: - dtype = getattr(mx, args.dtype) - weights = {k: v.astype(dtype) for k, v in weights.items()} mlx_path = Path(args.mlx_path) mlx_path.mkdir(parents=True, exist_ok=True)