From fe96ef342f734d1b42b216ab5e31bd22d0c875a0 Mon Sep 17 00:00:00 2001 From: Anchen Date: Fri, 22 Mar 2024 04:34:11 +1100 Subject: [PATCH] feat(mlx-lm): export the GGUF (fp16) format model weights from fuse.py (#555) * wip * wip * feat: convert mlx model to gguf f16 * chore: convert norm layer to float32 to avoid overflow issue * chore: add support for mixtral * chore: clean up * chore: remove unused import statement * chore: clean up weight name mapping * version and readme * actual version bump --------- Co-authored-by: Awni Hannun --- llms/mlx_lm/LORA.md | 23 ++- llms/mlx_lm/fuse.py | 21 ++- llms/mlx_lm/gguf.py | 311 +++++++++++++++++++++++++++++++++++++++++ llms/mlx_lm/version.py | 2 +- 4 files changed, 351 insertions(+), 6 deletions(-) create mode 100644 llms/mlx_lm/gguf.py diff --git a/llms/mlx_lm/LORA.md b/llms/mlx_lm/LORA.md index 0f0baf52..d48e7937 100644 --- a/llms/mlx_lm/LORA.md +++ b/llms/mlx_lm/LORA.md @@ -9,6 +9,7 @@ LoRA (QLoRA).[^qlora] LoRA fine-tuning works with the following model families: - Phi2 - Mixtral - Qwen2 +- Gemma - OLMo ## Contents @@ -17,7 +18,7 @@ LoRA (QLoRA).[^qlora] LoRA fine-tuning works with the following model families: * [Fine-tune](#Fine-tune) * [Evaluate](#Evaluate) * [Generate](#Generate) -* [Fuse and Upload](#Fuse-and-Upload) +* [Fuse](#Fuse) * [Data](#Data) * [Memory Issues](#Memory-Issues) @@ -93,11 +94,14 @@ python -m mlx_lm.generate \ --prompt "" ``` -## Fuse and Upload +## Fuse You can generate a model fused with the low-rank adapters using the -`mlx_lm.fuse` command. This command also allows you to upload the fused model -to the Hugging Face Hub. +`mlx_lm.fuse` command. This command also allows you to optionally: + +- Upload the fused model to the Hugging Face Hub. +- Export the fused model to GGUF. Note GGUF support is limited to Mistral, + Mixtral, and Llama style models in fp16 precision. 
To see supported options run: @@ -127,6 +131,17 @@ python -m mlx_lm.fuse \ --hf-path mistralai/Mistral-7B-v0.1 ``` +To export a fused model to GGUF, run: + +```shell +python -m mlx_lm.fuse \ + --model mistralai/Mistral-7B-v0.1 \ + --export-gguf +``` + +This will save the GGUF model in `lora_fused_model/ggml-model-f16.gguf`. You +can specify the file name with `--gguf-path`. + ## Data The LoRA command expects you to provide a dataset with `--data`. The MLX diff --git a/llms/mlx_lm/fuse.py b/llms/mlx_lm/fuse.py index c10b09b2..44c7eaaa 100644 --- a/llms/mlx_lm/fuse.py +++ b/llms/mlx_lm/fuse.py @@ -3,10 +3,10 @@ import glob import json import shutil from pathlib import Path -from typing import Any, Dict, Union from mlx.utils import tree_flatten, tree_unflatten +from .gguf import convert_to_gguf from .tuner.lora import LoRALinear from .tuner.utils import apply_lora_layers, dequantize from .utils import ( @@ -53,6 +53,17 @@ def parse_arguments() -> argparse.Namespace: help="Generate a de-quantized model.", action="store_true", ) + parser.add_argument( + "--export-gguf", + help="Export model weights in GGUF format.", + action="store_true", + ) + parser.add_argument( + "--gguf-path", + help="Path to save the exported GGUF format model weights. Default is ggml-model-f16.gguf.", + default="ggml-model-f16.gguf", + type=str, + ) return parser.parse_args() @@ -95,6 +106,14 @@ def main() -> None: save_config(config, config_path=save_path / "config.json") + if args.export_gguf: + model_type = config["model_type"] + if model_type not in ["llama", "mixtral", "mistral"]: + raise ValueError( + f"Model type {model_type} not supported for GGUF conversion." 
+ ) + convert_to_gguf(model_path, weights, config, str(save_path / args.gguf_path)) + if args.upload_repo is not None: hf_path = args.hf_path or ( args.model if not Path(args.model).exists() else None diff --git a/llms/mlx_lm/gguf.py b/llms/mlx_lm/gguf.py new file mode 100644 index 00000000..382e1dce --- /dev/null +++ b/llms/mlx_lm/gguf.py @@ -0,0 +1,311 @@ +import re +from enum import IntEnum +from pathlib import Path +from typing import Iterable, Union + +import mlx.core as mx +from transformers import AutoTokenizer + + +class TokenType(IntEnum): + NORMAL = 1 + UNKNOWN = 2 + CONTROL = 3 + USER_DEFINED = 4 + UNUSED = 5 + BYTE = 6 + + +class GGMLFileType(IntEnum): + GGML_TYPE_F16 = 1 + + +# copied from https://github.com/ggerganov/llama.cpp/blob/master/convert.py#L455 +class HfVocab: + def __init__( + self, fname_tokenizer: Path, fname_added_tokens: Path | None = None + ) -> None: + self.tokenizer = AutoTokenizer.from_pretrained( + fname_tokenizer, + cache_dir=fname_tokenizer, + local_files_only=True, + ) + self.added_tokens_list = [] + self.added_tokens_dict = dict() + self.added_tokens_ids = set() + for tok, tokidx in sorted( + self.tokenizer.get_added_vocab().items(), key=lambda x: x[1] + ): + if tokidx >= self.tokenizer.vocab_size: + self.added_tokens_list.append(tok) + self.added_tokens_dict[tok] = tokidx + self.added_tokens_ids.add(tokidx) + self.specials = { + tok: self.tokenizer.get_vocab()[tok] + for tok in self.tokenizer.all_special_tokens + } + self.special_ids = set(self.tokenizer.all_special_ids) + self.vocab_size_base = self.tokenizer.vocab_size + self.vocab_size = self.vocab_size_base + len(self.added_tokens_list) + self.fname_tokenizer = fname_tokenizer + self.fname_added_tokens = fname_added_tokens + + def hf_tokens(self) -> Iterable[tuple[bytes, float, TokenType]]: + reverse_vocab = { + id: encoded_tok for encoded_tok, id in self.tokenizer.get_vocab().items() + } + for token_id in range(self.vocab_size_base): + if token_id in 
self.added_tokens_ids: + continue + token_text = reverse_vocab[token_id].encode("utf-8") + yield token_text, self.get_token_score(token_id), self.get_token_type( + token_id, token_text, self.special_ids + ) + + def get_token_type( + self, token_id: int, token_text: bytes, special_ids: set[int] + ) -> TokenType: + if re.fullmatch(rb"<0x[0-9A-Fa-f]{2}>", token_text): + return TokenType.BYTE + return TokenType.CONTROL if token_id in special_ids else TokenType.NORMAL + + def get_token_score(self, token_id: int) -> float: + return -1000.0 + + def added_tokens(self) -> Iterable[tuple[bytes, float, TokenType]]: + for text in self.added_tokens_list: + if text in self.specials: + toktype = self.get_token_type( + self.specials[text], b"", self.special_ids + ) + score = self.get_token_score(self.specials[text]) + else: + toktype = TokenType.USER_DEFINED + score = -1000.0 + yield text.encode("utf-8"), score, toktype + + def has_newline_token(self): + return "<0x0A>" in self.tokenizer.vocab or "\n" in self.tokenizer.vocab + + def all_tokens(self) -> Iterable[tuple[bytes, float, TokenType]]: + yield from self.hf_tokens() + yield from self.added_tokens() + + def __repr__(self) -> str: + return f"" + + @staticmethod + def load(path: Path) -> "HfVocab": + added_tokens_path = path.parent / "added_tokens.json" + return HfVocab(path, added_tokens_path if added_tokens_path.exists() else None) + + +def translate_weight_names(name): + name = name.replace("model.layers.", "blk.") + # for mixtral gate + name = name.replace("block_sparse_moe.gate", "ffn_gate_inp") + # for mixtral experts ffns + pattern = r"block_sparse_moe\.experts\.(\d+)\.w1\.weight" + replacement = r"ffn_gate.\1.weight" + name = re.sub(pattern, replacement, name) + pattern = r"block_sparse_moe\.experts\.(\d+)\.w2\.weight" + replacement = r"ffn_down.\1.weight" + name = re.sub(pattern, replacement, name) + pattern = r"block_sparse_moe\.experts\.(\d+)\.w3\.weight" + replacement = r"ffn_up.\1.weight" + name = re.sub(pattern, 
replacement, name) + + name = name.replace("mlp.gate_proj", "ffn_gate") + name = name.replace("mlp.down_proj", "ffn_down") + name = name.replace("mlp.up_proj", "ffn_up") + name = name.replace("self_attn.q_proj", "attn_q") + name = name.replace("self_attn.k_proj", "attn_k") + name = name.replace("self_attn.v_proj", "attn_v") + name = name.replace("self_attn.o_proj", "attn_output") + name = name.replace("input_layernorm", "attn_norm") + name = name.replace("post_attention_layernorm", "ffn_norm") + name = name.replace("model.embed_tokens", "token_embd") + name = name.replace("model.norm", "output_norm") + name = name.replace("lm_head", "output") + return name + + +def permute_weights(weights, n_head, n_head_kv=None): + if n_head_kv is not None and n_head != n_head_kv: + n_head = n_head_kv + reshaped = weights.reshape( + n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:] + ) + swapped = reshaped.swapaxes(1, 2) + final_shape = weights.shape + return swapped.reshape(final_shape) + + +def prepare_metadata(config, vocab): + metadata = { + "general.name": "llama", + "llama.context_length": ( + mx.array(config["max_position_embeddings"], dtype=mx.uint32) + if config.get("max_position_embeddings") is not None + else None + ), + "llama.embedding_length": ( + mx.array(config["hidden_size"], dtype=mx.uint32) + if config.get("hidden_size") is not None + else None + ), + "llama.block_count": ( + mx.array(config["num_hidden_layers"], dtype=mx.uint32) + if config.get("num_hidden_layers") is not None + else None + ), + "llama.feed_forward_length": ( + mx.array(config["intermediate_size"], dtype=mx.uint32) + if config.get("intermediate_size") is not None + else None + ), + "llama.rope.dimension_count": ( + mx.array( + config["hidden_size"] // config["num_attention_heads"], dtype=mx.uint32 + ) + if config.get("hidden_size") is not None + and config.get("num_attention_heads") is not None + else None + ), + "llama.attention.head_count": ( + 
mx.array(config["num_attention_heads"], dtype=mx.uint32) + if config.get("num_attention_heads") is not None + else None + ), + "llama.attention.head_count_kv": ( + mx.array( + config.get("num_key_value_heads", config["num_attention_heads"]), + dtype=mx.uint32, + ) + if config.get("num_attention_heads") is not None + else None + ), + "llama.expert_count": ( + mx.array(config.get("num_local_experts", None), dtype=mx.uint32) + if config.get("num_local_experts") is not None + else None + ), + "llama.expert_used_count": ( + mx.array(config.get("num_experts_per_tok", None), dtype=mx.uint32) + if config.get("num_experts_per_tok") is not None + else None + ), + "llama.attention.layer_norm_rms_epsilon": ( + mx.array(config.get("rms_norm_eps", 1e-05)) + if config.get("rms_norm_eps") is not None + else None + ), + "llama.rope.freq_base": ( + mx.array(config.get("rope_theta", 10000), dtype=mx.float32) + if config.get("rope_theta") is not None + else None + ), + } + + rope_scaling = config.get("rope_scaling") + if rope_scaling is not None and (typ := rope_scaling.get("type")): + rope_factor = rope_scaling.get("factor") + f_rope_scale = rope_factor + if typ == "linear": + rope_scaling_type = "linear" + metadata["llama.rope.scaling.type"] = rope_scaling_type + metadata["llama.rope.scaling.factor"] = mx.array(f_rope_scale) + + metadata["general.file_type"] = mx.array( + GGMLFileType.GGML_TYPE_F16.value, + dtype=mx.uint32, + ) + metadata["general.quantization_version"] = mx.array( + GGMLFileType.GGML_TYPE_F16.value, + dtype=mx.uint32, + ) + metadata["general.name"] = config.get("_name_or_path", "llama").split("/")[-1] + metadata["general.architecture"] = "llama" + metadata["general.alignment"] = mx.array(32, dtype=mx.uint32) + + # add metadata for vocab + metadata["tokenizer.ggml.model"] = "llama" + tokens = [] + scores = [] + toktypes = [] + for text, score, toktype in vocab.all_tokens(): + tokens.append(text) + scores.append(score) + toktypes.append(toktype.value) + assert 
len(tokens) == vocab.vocab_size + metadata["tokenizer.ggml.tokens"] = tokens + metadata["tokenizer.ggml.scores"] = mx.array(scores, dtype=mx.float32) + metadata["tokenizer.ggml.token_type"] = mx.array(toktypes, dtype=mx.uint32) + metadata["tokenizer.ggml.bos_token_id"] = mx.array( + vocab.tokenizer.bos_token_id, dtype=mx.uint32 + ) + metadata["tokenizer.ggml.eos_token_id"] = mx.array( + vocab.tokenizer.eos_token_id, dtype=mx.uint32 + ) + metadata["tokenizer.ggml.unknown_token_id"] = mx.array( + vocab.tokenizer.unk_token_id, dtype=mx.uint32 + ) + + metadata = {k: v for k, v in metadata.items() if v is not None} + return metadata + + +def convert_to_gguf( + model_path: Union[str, Path], + weights: dict, + config: dict, + output_file_path: str, +): + if isinstance(model_path, str): + model_path = Path(model_path) + + quantization = config.get("quantization", None) + if quantization: + raise NotImplementedError( + "Conversion of quantized models is not yet supported." + ) + print("Converting to GGUF format") + # https://github.com/ggerganov/llama.cpp/blob/master/convert.py#L1182 seems relate to llama.cpp's multihead attention + weights = { + k: ( + permute_weights( + v, config["num_attention_heads"], config["num_attention_heads"] + ) + if "self_attn.q_proj.weight" in k + else ( + permute_weights( + v, config["num_attention_heads"], config["num_key_value_heads"] + ) + if "self_attn.k_proj.weight" in k + else v + ) + ) + for k, v in weights.items() + } + + # rename weights for gguf format + weights = {translate_weight_names(k): v for k, v in weights.items()} + + if not (model_path / "tokenizer.json").exists(): + raise ValueError("Tokenizer json not found") + + vocab = HfVocab.load(model_path) + metadata = prepare_metadata(config, vocab) + + weights = { + k: ( + v.astype(mx.float32).astype(mx.float16) + if v.dtype == mx.bfloat16 + else v.astype(mx.float32) if "norm" in k else v + ) + for k, v in weights.items() + } + + output_file_path = output_file_path + 
mx.save_gguf(output_file_path, weights, metadata) + print(f"Converted GGUF model saved as: {output_file_path}") diff --git a/llms/mlx_lm/version.py b/llms/mlx_lm/version.py index 67c7397c..45e522d1 100644 --- a/llms/mlx_lm/version.py +++ b/llms/mlx_lm/version.py @@ -1,3 +1,3 @@ # Copyright © 2023-2024 Apple Inc. -__version__ = "0.3.0" +__version__ = "0.4.0"