From fe96ef342f734d1b42b216ab5e31bd22d0c875a0 Mon Sep 17 00:00:00 2001
From: Anchen
Date: Fri, 22 Mar 2024 04:34:11 +1100
Subject: [PATCH] feat(mlx-lm): export the GGUF (fp16) format model weights
from fuse.py (#555)
* wip
* wip
* feat: convert mlx model to gguf f16
* chore: conver norm layer to float32 to avoid overflow issue
* chore: add support for mixtral
* chore: clean up
* chore: remove unused import statement
* chore: clean up weight name mapping
* version and readme
* actual version bump
---------
Co-authored-by: Awni Hannun
---
llms/mlx_lm/LORA.md | 23 ++-
llms/mlx_lm/fuse.py | 21 ++-
llms/mlx_lm/gguf.py | 311 +++++++++++++++++++++++++++++++++++++++++
llms/mlx_lm/version.py | 2 +-
4 files changed, 351 insertions(+), 6 deletions(-)
create mode 100644 llms/mlx_lm/gguf.py
diff --git a/llms/mlx_lm/LORA.md b/llms/mlx_lm/LORA.md
index 0f0baf52..d48e7937 100644
--- a/llms/mlx_lm/LORA.md
+++ b/llms/mlx_lm/LORA.md
@@ -9,6 +9,7 @@ LoRA (QLoRA).[^qlora] LoRA fine-tuning works with the following model families:
- Phi2
- Mixtral
- Qwen2
+- Gemma
- OLMo
## Contents
@@ -17,7 +18,7 @@ LoRA (QLoRA).[^qlora] LoRA fine-tuning works with the following model families:
* [Fine-tune](#Fine-tune)
* [Evaluate](#Evaluate)
* [Generate](#Generate)
-* [Fuse and Upload](#Fuse-and-Upload)
+* [Fuse](#Fuse)
* [Data](#Data)
* [Memory Issues](#Memory-Issues)
@@ -93,11 +94,14 @@ python -m mlx_lm.generate \
--prompt ""
```
-## Fuse and Upload
+## Fuse
You can generate a model fused with the low-rank adapters using the
-`mlx_lm.fuse` command. This command also allows you to upload the fused model
-to the Hugging Face Hub.
+`mlx_lm.fuse` command. This command also allows you to optionally:
+
+- Upload the fused model to the Hugging Face Hub.
+- Export the fused model to GGUF. Note GGUF support is limited to Mistral,
+ Mixtral, and Llama style models in fp16 precision.
To see supported options run:
@@ -127,6 +131,17 @@ python -m mlx_lm.fuse \
--hf-path mistralai/Mistral-7B-v0.1
```
+To export a fused model to GGUF, run:
+
+```shell
+python -m mlx_lm.fuse \
+ --model mistralai/Mistral-7B-v0.1 \
+ --export-gguf
+```
+
+This will save the GGUF model in `lora_fused_model/ggml-model-f16.gguf`. You
+can specify the file name with `--gguf-path`.
+
## Data
The LoRA command expects you to provide a dataset with `--data`. The MLX
diff --git a/llms/mlx_lm/fuse.py b/llms/mlx_lm/fuse.py
index c10b09b2..44c7eaaa 100644
--- a/llms/mlx_lm/fuse.py
+++ b/llms/mlx_lm/fuse.py
@@ -3,10 +3,10 @@ import glob
import json
import shutil
from pathlib import Path
-from typing import Any, Dict, Union
from mlx.utils import tree_flatten, tree_unflatten
+from .gguf import convert_to_gguf
from .tuner.lora import LoRALinear
from .tuner.utils import apply_lora_layers, dequantize
from .utils import (
@@ -53,6 +53,17 @@ def parse_arguments() -> argparse.Namespace:
help="Generate a de-quantized model.",
action="store_true",
)
+ parser.add_argument(
+ "--export-gguf",
+ help="Export model weights in GGUF format.",
+ action="store_true",
+ )
+ parser.add_argument(
+ "--gguf-path",
+ help="Path to save the exported GGUF format model weights. Default is ggml-model-f16.gguf.",
+ default="ggml-model-f16.gguf",
+ type=str,
+ )
return parser.parse_args()
@@ -95,6 +106,14 @@ def main() -> None:
save_config(config, config_path=save_path / "config.json")
+ if args.export_gguf:
+ model_type = config["model_type"]
+ if model_type not in ["llama", "mixtral", "mistral"]:
+ raise ValueError(
+ f"Model type {model_type} not supported for GGUF conversion."
+ )
+ convert_to_gguf(model_path, weights, config, str(save_path / args.gguf_path))
+
if args.upload_repo is not None:
hf_path = args.hf_path or (
args.model if not Path(args.model).exists() else None
diff --git a/llms/mlx_lm/gguf.py b/llms/mlx_lm/gguf.py
new file mode 100644
index 00000000..382e1dce
--- /dev/null
+++ b/llms/mlx_lm/gguf.py
@@ -0,0 +1,311 @@
+import re
+from enum import IntEnum
+from pathlib import Path
+from typing import Iterable, Union
+
+import mlx.core as mx
+from transformers import AutoTokenizer
+
+
+class TokenType(IntEnum):
+ NORMAL = 1
+ UNKNOWN = 2
+ CONTROL = 3
+ USER_DEFINED = 4
+ UNUSED = 5
+ BYTE = 6
+
+
+class GGMLFileType(IntEnum):
+ GGML_TYPE_F16 = 1
+
+
+# copied from https://github.com/ggerganov/llama.cpp/blob/master/convert.py#L455
+class HfVocab:
+ def __init__(
+ self, fname_tokenizer: Path, fname_added_tokens: Path | None = None
+ ) -> None:
+ self.tokenizer = AutoTokenizer.from_pretrained(
+ fname_tokenizer,
+ cache_dir=fname_tokenizer,
+ local_files_only=True,
+ )
+ self.added_tokens_list = []
+ self.added_tokens_dict = dict()
+ self.added_tokens_ids = set()
+ for tok, tokidx in sorted(
+ self.tokenizer.get_added_vocab().items(), key=lambda x: x[1]
+ ):
+ if tokidx >= self.tokenizer.vocab_size:
+ self.added_tokens_list.append(tok)
+ self.added_tokens_dict[tok] = tokidx
+ self.added_tokens_ids.add(tokidx)
+ self.specials = {
+ tok: self.tokenizer.get_vocab()[tok]
+ for tok in self.tokenizer.all_special_tokens
+ }
+ self.special_ids = set(self.tokenizer.all_special_ids)
+ self.vocab_size_base = self.tokenizer.vocab_size
+ self.vocab_size = self.vocab_size_base + len(self.added_tokens_list)
+ self.fname_tokenizer = fname_tokenizer
+ self.fname_added_tokens = fname_added_tokens
+
+ def hf_tokens(self) -> Iterable[tuple[bytes, float, TokenType]]:
+ reverse_vocab = {
+ id: encoded_tok for encoded_tok, id in self.tokenizer.get_vocab().items()
+ }
+ for token_id in range(self.vocab_size_base):
+ if token_id in self.added_tokens_ids:
+ continue
+ token_text = reverse_vocab[token_id].encode("utf-8")
+ yield token_text, self.get_token_score(token_id), self.get_token_type(
+ token_id, token_text, self.special_ids
+ )
+
+ def get_token_type(
+ self, token_id: int, token_text: bytes, special_ids: set[int]
+ ) -> TokenType:
+ if re.fullmatch(rb"<0x[0-9A-Fa-f]{2}>", token_text):
+ return TokenType.BYTE
+ return TokenType.CONTROL if token_id in special_ids else TokenType.NORMAL
+
+ def get_token_score(self, token_id: int) -> float:
+ return -1000.0
+
+ def added_tokens(self) -> Iterable[tuple[bytes, float, TokenType]]:
+ for text in self.added_tokens_list:
+ if text in self.specials:
+ toktype = self.get_token_type(
+ self.specials[text], b"", self.special_ids
+ )
+ score = self.get_token_score(self.specials[text])
+ else:
+ toktype = TokenType.USER_DEFINED
+ score = -1000.0
+ yield text.encode("utf-8"), score, toktype
+
+ def has_newline_token(self):
+ return "<0x0A>" in self.tokenizer.vocab or "\n" in self.tokenizer.vocab
+
+ def all_tokens(self) -> Iterable[tuple[bytes, float, TokenType]]:
+ yield from self.hf_tokens()
+ yield from self.added_tokens()
+
+ def __repr__(self) -> str:
+ return f""
+
+ @staticmethod
+ def load(path: Path) -> "HfVocab":
+ added_tokens_path = path.parent / "added_tokens.json"
+ return HfVocab(path, added_tokens_path if added_tokens_path.exists() else None)
+
+
+def translate_weight_names(name):
+ name = name.replace("model.layers.", "blk.")
+ # for mixtral gate
+ name = name.replace("block_sparse_moe.gate", "ffn_gate_inp")
+ # for mixtral experts ffns
+ pattern = r"block_sparse_moe\.experts\.(\d+)\.w1\.weight"
+ replacement = r"ffn_gate.\1.weight"
+ name = re.sub(pattern, replacement, name)
+ pattern = r"block_sparse_moe\.experts\.(\d+)\.w2\.weight"
+ replacement = r"ffn_down.\1.weight"
+ name = re.sub(pattern, replacement, name)
+ pattern = r"block_sparse_moe\.experts\.(\d+)\.w3\.weight"
+ replacement = r"ffn_up.\1.weight"
+ name = re.sub(pattern, replacement, name)
+
+ name = name.replace("mlp.gate_proj", "ffn_gate")
+ name = name.replace("mlp.down_proj", "ffn_down")
+ name = name.replace("mlp.up_proj", "ffn_up")
+ name = name.replace("self_attn.q_proj", "attn_q")
+ name = name.replace("self_attn.k_proj", "attn_k")
+ name = name.replace("self_attn.v_proj", "attn_v")
+ name = name.replace("self_attn.o_proj", "attn_output")
+ name = name.replace("input_layernorm", "attn_norm")
+ name = name.replace("post_attention_layernorm", "ffn_norm")
+ name = name.replace("model.embed_tokens", "token_embd")
+ name = name.replace("model.norm", "output_norm")
+ name = name.replace("lm_head", "output")
+ return name
+
+
+def permute_weights(weights, n_head, n_head_kv=None):
+ if n_head_kv is not None and n_head != n_head_kv:
+ n_head = n_head_kv
+ reshaped = weights.reshape(
+ n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]
+ )
+ swapped = reshaped.swapaxes(1, 2)
+ final_shape = weights.shape
+ return swapped.reshape(final_shape)
+
+
+def prepare_metadata(config, vocab):
+ metadata = {
+ "general.name": "llama",
+ "llama.context_length": (
+ mx.array(config["max_position_embeddings"], dtype=mx.uint32)
+ if config.get("max_position_embeddings") is not None
+ else None
+ ),
+ "llama.embedding_length": (
+ mx.array(config["hidden_size"], dtype=mx.uint32)
+ if config.get("hidden_size") is not None
+ else None
+ ),
+ "llama.block_count": (
+ mx.array(config["num_hidden_layers"], dtype=mx.uint32)
+ if config.get("num_hidden_layers") is not None
+ else None
+ ),
+ "llama.feed_forward_length": (
+ mx.array(config["intermediate_size"], dtype=mx.uint32)
+ if config.get("intermediate_size") is not None
+ else None
+ ),
+ "llama.rope.dimension_count": (
+ mx.array(
+ config["hidden_size"] // config["num_attention_heads"], dtype=mx.uint32
+ )
+ if config.get("hidden_size") is not None
+ and config.get("num_attention_heads") is not None
+ else None
+ ),
+ "llama.attention.head_count": (
+ mx.array(config["num_attention_heads"], dtype=mx.uint32)
+ if config.get("num_attention_heads") is not None
+ else None
+ ),
+ "llama.attention.head_count_kv": (
+ mx.array(
+ config.get("num_key_value_heads", config["num_attention_heads"]),
+ dtype=mx.uint32,
+ )
+ if config.get("num_attention_heads") is not None
+ else None
+ ),
+ "llama.expert_count": (
+ mx.array(config.get("num_local_experts", None), dtype=mx.uint32)
+ if config.get("num_local_experts") is not None
+ else None
+ ),
+ "llama.expert_used_count": (
+ mx.array(config.get("num_experts_per_tok", None), dtype=mx.uint32)
+ if config.get("num_experts_per_tok") is not None
+ else None
+ ),
+ "llama.attention.layer_norm_rms_epsilon": (
+ mx.array(config.get("rms_norm_eps", 1e-05))
+ if config.get("rms_norm_eps") is not None
+ else None
+ ),
+ "llama.rope.freq_base": (
+ mx.array(config.get("rope_theta", 10000), dtype=mx.float32)
+ if config.get("rope_theta") is not None
+ else None
+ ),
+ }
+
+ rope_scaling = config.get("rope_scaling")
+ if rope_scaling is not None and (typ := rope_scaling.get("type")):
+ rope_factor = rope_scaling.get("factor")
+ f_rope_scale = rope_factor
+ if typ == "linear":
+ rope_scaling_type = "linear"
+ metadata["llama.rope.scaling.type"] = rope_scaling_type
+ metadata["llama.rope.scaling.factor"] = mx.array(f_rope_scale)
+
+ metadata["general.file_type"] = mx.array(
+ GGMLFileType.GGML_TYPE_F16.value,
+ dtype=mx.uint32,
+ )
+ metadata["general.quantization_version"] = mx.array(
+ GGMLFileType.GGML_TYPE_F16.value,
+ dtype=mx.uint32,
+ )
+ metadata["general.name"] = config.get("_name_or_path", "llama").split("/")[-1]
+ metadata["general.architecture"] = "llama"
+ metadata["general.alignment"] = mx.array(32, dtype=mx.uint32)
+
+ # add metadata for vocab
+ metadata["tokenizer.ggml.model"] = "llama"
+ tokens = []
+ scores = []
+ toktypes = []
+ for text, score, toktype in vocab.all_tokens():
+ tokens.append(text)
+ scores.append(score)
+ toktypes.append(toktype.value)
+ assert len(tokens) == vocab.vocab_size
+ metadata["tokenizer.ggml.tokens"] = tokens
+ metadata["tokenizer.ggml.scores"] = mx.array(scores, dtype=mx.float32)
+ metadata["tokenizer.ggml.token_type"] = mx.array(toktypes, dtype=mx.uint32)
+ metadata["tokenizer.ggml.bos_token_id"] = mx.array(
+ vocab.tokenizer.bos_token_id, dtype=mx.uint32
+ )
+ metadata["tokenizer.ggml.eos_token_id"] = mx.array(
+ vocab.tokenizer.eos_token_id, dtype=mx.uint32
+ )
+ metadata["tokenizer.ggml.unknown_token_id"] = mx.array(
+ vocab.tokenizer.unk_token_id, dtype=mx.uint32
+ )
+
+ metadata = {k: v for k, v in metadata.items() if v is not None}
+ return metadata
+
+
+def convert_to_gguf(
+ model_path: Union[str, Path],
+ weights: dict,
+ config: dict,
+ output_file_path: str,
+):
+ if isinstance(model_path, str):
+ model_path = Path(model_path)
+
+ quantization = config.get("quantization", None)
+ if quantization:
+ raise NotImplementedError(
+ "Conversion of quantized models is not yet supported."
+ )
+ print("Converting to GGUF format")
+ # https://github.com/ggerganov/llama.cpp/blob/master/convert.py#L1182 seems relate to llama.cpp's multihead attention
+ weights = {
+ k: (
+ permute_weights(
+ v, config["num_attention_heads"], config["num_attention_heads"]
+ )
+ if "self_attn.q_proj.weight" in k
+ else (
+ permute_weights(
+ v, config["num_attention_heads"], config["num_key_value_heads"]
+ )
+ if "self_attn.k_proj.weight" in k
+ else v
+ )
+ )
+ for k, v in weights.items()
+ }
+
+ # rename weights for gguf format
+ weights = {translate_weight_names(k): v for k, v in weights.items()}
+
+ if not (model_path / "tokenizer.json").exists():
+ raise ValueError("Tokenizer json not found")
+
+ vocab = HfVocab.load(model_path)
+ metadata = prepare_metadata(config, vocab)
+
+ weights = {
+ k: (
+ v.astype(mx.float32).astype(mx.float16)
+ if v.dtype == mx.bfloat16
+ else v.astype(mx.float32) if "norm" in k else v
+ )
+ for k, v in weights.items()
+ }
+
+ output_file_path = output_file_path
+ mx.save_gguf(output_file_path, weights, metadata)
+ print(f"Converted GGUF model saved as: {output_file_path}")
diff --git a/llms/mlx_lm/version.py b/llms/mlx_lm/version.py
index 67c7397c..45e522d1 100644
--- a/llms/mlx_lm/version.py
+++ b/llms/mlx_lm/version.py
@@ -1,3 +1,3 @@
# Copyright © 2023-2024 Apple Inc.
-__version__ = "0.3.0"
+__version__ = "0.4.0"