force fp16 for quantized models (#240)

commit cf0ad26a89
parent 37856f70a8
Author: Awni Hannun (committed by GitHub)
Date:   2024-01-05 21:29:15 -08:00

@@ -154,12 +154,12 @@ if __name__ == "__main__":
     print("[INFO] Loading")
     weights, config, tokenizer = fetch_from_hub(args.hf_path)
+    dtype = mx.float16 if args.quantize else getattr(mx, args.dtype)
+    weights = {k: v.astype(dtype) for k, v in weights.items()}
     if args.quantize:
         print("[INFO] Quantizing")
         weights, config = quantize(weights, config, args)
-    if not args.quantize:
-        dtype = getattr(mx, args.dtype)
-        weights = {k: v.astype(dtype) for k, v in weights.items()}
     mlx_path = Path(args.mlx_path)
     mlx_path.mkdir(parents=True, exist_ok=True)
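
After this change, every model is cast before quantize() runs: float16 is forced whenever quantization is requested (args.quantize), and the user-supplied args.dtype only takes effect for unquantized conversions. A minimal sketch of the resulting cast step, pulled out as a standalone helper (the function name cast_weights and its signature are illustrative, not part of the diff):

    import mlx.core as mx

    def cast_weights(weights, quantize, dtype_name="float16"):
        # Quantized models are always cast to float16 before quantization;
        # otherwise the user-requested dtype (e.g. "float16", "bfloat16") is used.
        dtype = mx.float16 if quantize else getattr(mx, dtype_name)
        return {k: v.astype(dtype) for k, v in weights.items()}

One practical consequence: choosing a different args.dtype together with quantization no longer changes the output, since the weights are cast to float16 before the quantizer sees them.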