feat(mlx-lm): export the GGUF (fp16) format model weights from fuse.py (#555)

* wip

* wip

* feat: convert mlx model to gguf f16

* chore: convert norm layer to float32 to avoid overflow issue

* chore: add support for mixtral

* chore: clean up

* chore: remove unused import statement

* chore: clean up weight name mapping

* version and readme

* actual version bump

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Authored by Anchen on 2024-03-22 04:34:11 +11:00; committed by GitHub.
parent 8f906c859a
commit fe96ef342f
4 changed files with 351 additions and 6 deletions


@@ -3,10 +3,10 @@ import glob
 import json
 import shutil
 from pathlib import Path
-from typing import Any, Dict, Union

 from mlx.utils import tree_flatten, tree_unflatten

+from .gguf import convert_to_gguf
 from .tuner.lora import LoRALinear
 from .tuner.utils import apply_lora_layers, dequantize
 from .utils import (
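
The new import pulls in convert_to_gguf from the gguf.py module this commit adds (the bulk of the 351 new lines). A large part of that module is the weight-name mapping mentioned in the commit history: tensors named in the Hugging Face convention must be renamed to the llama.cpp/GGUF convention. The sketch below is a minimal illustration of that idea, not the code in gguf.py; the real mapping covers more tensors as well as the Mixtral expert layout.

    import re

    # Illustrative HF -> GGUF tensor-name mapping (hypothetical helper, not
    # part of gguf.py). Names on the right follow the llama.cpp convention.
    NAME_MAP = {
        "model.embed_tokens.weight": "token_embd.weight",
        "model.norm.weight": "output_norm.weight",
        "lm_head.weight": "output.weight",
    }
    LAYER_MAP = {
        "self_attn.q_proj": "attn_q",
        "self_attn.k_proj": "attn_k",
        "self_attn.v_proj": "attn_v",
        "self_attn.o_proj": "attn_output",
        "mlp.gate_proj": "ffn_gate",
        "mlp.down_proj": "ffn_down",
        "mlp.up_proj": "ffn_up",
        "input_layernorm": "attn_norm",
        "post_attention_layernorm": "ffn_norm",
    }

    def hf_to_gguf_name(name: str) -> str:
        # e.g. "model.layers.0.self_attn.q_proj.weight" -> "blk.0.attn_q.weight"
        if name in NAME_MAP:
            return NAME_MAP[name]
        m = re.match(r"model\.layers\.(\d+)\.(.+)\.weight", name)
        if m is None or m.group(2) not in LAYER_MAP:
            raise KeyError(f"unmapped tensor: {name}")
        return f"blk.{m.group(1)}.{LAYER_MAP[m.group(2)]}.weight"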
@@ -53,6 +53,17 @@ def parse_arguments() -> argparse.Namespace:
         help="Generate a de-quantized model.",
         action="store_true",
     )
+    parser.add_argument(
+        "--export-gguf",
+        help="Export model weights in GGUF format.",
+        action="store_true",
+    )
+    parser.add_argument(
+        "--gguf-path",
+        help="Path to save the exported GGUF format model weights. Default is ggml-model-f16.gguf.",
+        default="ggml-model-f16.gguf",
+        type=str,
+    )
     return parser.parse_args()
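
With these two flags, GGUF export becomes an option of the existing fuse step. Assuming the usual module entry point for fuse.py, an invocation would look roughly like this (the model name is only a placeholder):

    python -m mlx_lm.fuse --model mistralai/Mistral-7B-v0.1 \
        --export-gguf --gguf-path ggml-model-f16.gguf

As the main() hunk below shows, --gguf-path is resolved relative to the fused model's save path rather than the working directory.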
@@ -95,6 +106,14 @@ def main() -> None:
     save_config(config, config_path=save_path / "config.json")

+    if args.export_gguf:
+        model_type = config["model_type"]
+        if model_type not in ["llama", "mixtral", "mistral"]:
+            raise ValueError(
+                f"Model type {model_type} not supported for GGUF conversion."
+            )
+        convert_to_gguf(model_path, weights, config, str(save_path / args.gguf_path))
+
     if args.upload_repo is not None:
         hf_path = args.hf_path or (
             args.model if not Path(args.model).exists() else None
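
The export is gated to the three architectures gguf.py can map (llama, mistral, mixtral). The "convert norm layer to float32" bullet in the commit history refers to a dtype detail of the f16 export; the sketch below shows the general idea, assuming weights maps tensor names to mx.array values. It is an illustration of the policy, not the actual gguf.py code.

    import mlx.core as mx

    # Illustrative sketch of the dtype policy described in the commit history:
    # store tensors as fp16 for the f16 GGUF file, but keep norm weights in
    # float32, whose larger range avoids the reported overflow issue.
    def cast_for_f16_gguf(name: str, w: mx.array) -> mx.array:
        if "norm" in name:  # e.g. input_layernorm, post_attention_layernorm
            return w.astype(mx.float32)
        return w.astype(mx.float16)

    # usage: weights = {k: cast_for_f16_gguf(k, w) for k, w in weights.items()}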