override dtype with quant (#1062)

This commit is contained in:
Awni Hannun
2024-10-22 09:56:45 -07:00
committed by GitHub
parent 743763bc2e
commit 66e7bcb886
3 changed files with 3 additions and 3 deletions

View File

@@ -720,7 +720,7 @@ def convert(
model, config, tokenizer = fetch_from_hub(model_path, lazy=True)
weights = dict(tree_flatten(model.parameters()))
- dtype = mx.float16 if quantize else getattr(mx, dtype)
+ dtype = getattr(mx, dtype)
weights = {k: v.astype(dtype) for k, v in weights.items()}
if quantize and dequantize: