mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-25 09:51:19 +08:00
force fp16 for quantized models (#240)
This commit is contained in:
parent
37856f70a8
commit
cf0ad26a89
@ -154,12 +154,12 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
print("[INFO] Loading")
|
print("[INFO] Loading")
|
||||||
weights, config, tokenizer = fetch_from_hub(args.hf_path)
|
weights, config, tokenizer = fetch_from_hub(args.hf_path)
|
||||||
|
|
||||||
|
dtype = mx.float16 if args.quantize else getattr(mx, args.dtype)
|
||||||
|
weights = {k: v.astype(dtype) for k, v in weights.items()}
|
||||||
if args.quantize:
|
if args.quantize:
|
||||||
print("[INFO] Quantizing")
|
print("[INFO] Quantizing")
|
||||||
weights, config = quantize(weights, config, args)
|
weights, config = quantize(weights, config, args)
|
||||||
if not args.quantize:
|
|
||||||
dtype = getattr(mx, args.dtype)
|
|
||||||
weights = {k: v.astype(dtype) for k, v in weights.items()}
|
|
||||||
|
|
||||||
mlx_path = Path(args.mlx_path)
|
mlx_path = Path(args.mlx_path)
|
||||||
mlx_path.mkdir(parents=True, exist_ok=True)
|
mlx_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
Loading…
Reference in New Issue
Block a user