diff --git a/llms/mlx_lm/gguf.py b/llms/mlx_lm/gguf.py
index 241ac35a..b28779a8 100644
--- a/llms/mlx_lm/gguf.py
+++ b/llms/mlx_lm/gguf.py
@@ -1,11 +1,19 @@
+import importlib
 import re
+import tempfile
 from enum import IntEnum
 from pathlib import Path
 from typing import Iterable, Optional, Set, Tuple, Union
 
+import gguf
 import mlx.core as mx
+import mlx.nn as nn
+from gguf import GGMLQuantizationType
+from gguf.gguf_reader import GGUFReader
 from transformers import AutoTokenizer
 
+from .tokenizer_utils import TokenizerWrapper
+
 
 class TokenType(IntEnum):
     NORMAL = 1
@@ -312,3 +320,297 @@ def convert_to_gguf(
     output_file_path = output_file_path
     mx.save_gguf(output_file_path, weights, metadata)
     print(f"Converted GGUF model saved as: {output_file_path}")
+
+
+# Adapted from https://github.com/antirez/gguf-tools/blob/4e6455ecaf92b1a59e6a3291646459af3154bef5/gguflib.c#L568
+def parse_q4_k(tensor):
+    bits = 4
+    pack_factor = 32 // bits
+    group_size = 32
+    block_size = 144
+
+    data = mx.array(tensor.data)
+    shape = [int(d) for d in reversed(tensor.shape)]
+    wshape = (*shape[:-1], shape[-1] // pack_factor)
+    gshape = (*shape[:-1], shape[-1] // group_size)
+    num_blocks = data.size // block_size
+    kernel = mx.fast.metal_kernel(
+        name="parse_q4_k",
+        input_names=["data"],
+        output_names=["w", "scales", "biases"],
+        header="""
+        typedef struct {
+            float16_t d;
+            float16_t d_min;
+            uint8_t scales[12];
+            uint8_t qs[128];
+        } block_q4_K;
+        """,
+        source="""
+        uint elem = thread_position_in_grid.x;
+
+        const device block_q4_K* block = reinterpret_cast<const device block_q4_K*>(data);
+
+        block += elem;
+        w += elem * 32;
+        scales += elem * 8;
+        biases += elem * 8;
+
+        // First unpack the quantized scales/biases
+        for (int j = 0; j < 8; j++) {
+            uint8_t d, m;
+            if (j < 4) {
+                d = block->scales[j] & 63;
+                m = block->scales[j + 4] & 63;
+            } else {
+                d = (block->scales[j + 4] & 0xF) | ((block->scales[j - 4] >> 6) << 4);
+                m = (block->scales[j + 4] >> 4) | ((block->scales[j - 0] >> 6) << 4);
+            }
+            scales[j] = d * block->d;
+            biases[j] = -m * block->d_min;
+        }
+
+        uint32_t outputs[32] = {0};
+        for (int i = 0; i < 4; i++) {
+            for (int j = 0; j < 32; j++) {
+                uint8_t val = block->qs[i * 32 + j] & 0xf;
+                int index = i * 8 + (j / 8);
+                outputs[index] += val << (4 * (j % 8));
+            }
+            for (int j = 0; j < 32; j++) {
+                uint8_t val = block->qs[i * 32 + j] >> 4;
+                int index = i * 8 + 4 + (j / 8);
+                outputs[index] += val << (4 * (j % 8));
+            }
+        }
+
+        for (int i = 0; i < 32; i++) {
+            w[i] = outputs[i];
+        }
+        """,
+    )
+    w, scales, biases = kernel(
+        inputs=[data],
+        grid=(num_blocks, 1, 1),
+        threadgroup=(256, 1, 1),
+        output_shapes=[wshape, gshape, gshape],
+        output_dtypes=[mx.uint32, mx.float16, mx.float16],
+    )
+    return w, scales, biases
+
+
+# Adapted from https://github.com/antirez/gguf-tools/blob/4e6455ecaf92b1a59e6a3291646459af3154bef5/gguflib.c#L658
+def parse_q6_k(tensor):
+    bits = 6
+    group_size = 16
+    block_size = 210
+
+    data = mx.array(tensor.data)
+    shape = [int(d) for d in reversed(tensor.shape)]
+    wshape = (*shape[:-1], shape[-1] * bits // 8)
+    gshape = (*shape[:-1], shape[-1] // group_size)
+    num_blocks = data.size // block_size
+    kernel = mx.fast.metal_kernel(
+        name="parse_q6_k",
+        input_names=["data"],
+        output_names=["w", "scales", "biases"],
+        header="""
+        typedef struct {
+            uint8_t ql[128];    // quants, lower 4 bits
+            uint8_t qh[64];     // quants, upper 2 bits
+            int8_t scales[16];  // scales, quantized with 8 bits
+            float16_t d;        // super-block scale
+        } block_q6_K;
+        """,
+        source="""
+        uint elem = thread_position_in_grid.x;
+
+        const device block_q6_K* block = reinterpret_cast<const device block_q6_K*>(data);
+
+        block += elem;
+        w += elem * 192;
+        scales += elem * 16;
+        biases += elem * 16;
+
+        const device uint8_t* ql = &block->ql[0];
+        const device uint8_t* qh = &block->qh[0];
+        const device int8_t* bscales = &block->scales[0];
+
+        uint32_t output = 0;
+        for (int cluster = 0; cluster < 2; cluster++) {
+            for (uint64_t j = 0; j < 128; j++) {
+                uint8_t val = ((ql[j%64] >> (j/64*4)) & 0xF) | (((qh[j%32] >> (j/32*2)) & 3) << 4);
+
+                output += val << (6 * (j % 4));
+
+                // Every 4 values write out 3 bytes
+                if (j % 4 == 3) {
+                    w[0] = output & 0xff;
+                    w[1] = (output & 0xff00) >> 8;
+                    w[2] = (output & 0xff0000) >> 16;
+                    w += 3;
+                    output = 0;
+                }
+
+                if (j % 16 == 0) {
+                    scales[j/16] = block->d * bscales[j/16];
+                    biases[j/16] = -32.0f * scales[j/16];
+                }
+            }
+            ql += 64;
+            qh += 32;
+            bscales += 8;
+            scales += 8;
+            biases += 8;
+        }
+        """,
+    )
+    w, scales, biases = kernel(
+        inputs=[data],
+        grid=(num_blocks, 1, 1),
+        threadgroup=(256, 1, 1),
+        output_shapes=[wshape, gshape, gshape],
+        output_dtypes=[mx.uint8, mx.float16, mx.float16],
+    )
+    w = mx.view(w, dtype=mx.uint32)
+    return w, scales, biases
+
+
+def parse_gguf_tensor(tensor):
+    from gguf import GGMLQuantizationType
+
+    if tensor.tensor_type == GGMLQuantizationType.Q4_K:
+        return parse_q4_k(tensor)
+    elif tensor.tensor_type == GGMLQuantizationType.Q6_K:
+        return parse_q6_k(tensor)
+    elif tensor.tensor_type in [GGMLQuantizationType.F16, GGMLQuantizationType.F32]:
+        return mx.array(tensor.data)
+    else:
+        raise NotImplementedError(f"Type: {tensor.tensor_type} is not yet supported.")
+
+
+def convert_name(name):
+    name = name.replace("blk", "model.layers")
+    name = name.replace("attn_norm", "input_layernorm")
+    name = name.replace("ffn_norm", "post_attention_layernorm")
+    name = name.replace("attn_q", "self_attn.q_proj")
+    name = name.replace("attn_k", "self_attn.k_proj")
+    name = name.replace("attn_v", "self_attn.v_proj")
+    name = name.replace("attn_output", "self_attn.o_proj")
+    name = name.replace("ffn_up", "mlp.up_proj")
+    name = name.replace("ffn_down", "mlp.down_proj")
+    name = name.replace("ffn_gate", "mlp.gate_proj")
+    if "output_norm" in name:
+        name = name.replace("output_norm", "model.norm")
+    else:
+        name = name.replace("output", "lm_head")
+    name = name.replace("token_embd", "model.embed_tokens")
+    return name
+
+
+FIELD_MAPPING = {
+    "{model}.embedding_length": "hidden_size",
+    "{model}.feed_forward_length": "intermediate_size",
+    "{model}.attention.head_count": "num_attention_heads",
+    "{model}.attention.head_count_kv": "num_key_value_heads",
+    "{model}.block_count": "num_hidden_layers",
+    "{model}.attention.layer_norm_rms_epsilon": "rms_norm_eps",
+    "{model}.rope.freq_base": "rope_theta",
+}
+
+
+QUANT_MAPPING = {
+    GGMLQuantizationType.Q4_K: {
+        "bits": 4,
+        "group_size": 32,
+    },
+    GGMLQuantizationType.Q6_K: {
+        "bits": 6,
+        "group_size": 16,
+    },
+    GGMLQuantizationType.F16: None,
+    GGMLQuantizationType.F32: None,
+}
+
+
+# from https://github.com/ggerganov/llama.cpp/blob/40c6d79fb52f995f47507fedfeaae2ac05d9b35c/gguf-py/scripts/gguf_new_metadata.py#L46
+def decode_field(field):
+    if field and field.types:
+        main_type = field.types[0]
+
+        if main_type == gguf.GGUFValueType.ARRAY:
+            sub_type = field.types[-1]
+
+            if sub_type == gguf.GGUFValueType.STRING:
+                return [
+                    str(bytes(field.parts[idx]), encoding="utf-8") for idx in field.data
+                ]
+            else:
+                return [pv for idx in field.data for pv in field.parts[idx].tolist()]
+        if main_type == gguf.GGUFValueType.STRING:
+            return str(bytes(field.parts[-1]), encoding="utf-8")
+        else:
+            return field.parts[-1][0]
+
+    return None
+
+
+def load_gguf(model_path: str) -> tuple[nn.Module, TokenizerWrapper]:
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        base_name = Path(model_path).name
+        (Path(tmp_dir) / base_name).symlink_to(model_path)
+        tokenizer = AutoTokenizer.from_pretrained(tmp_dir, gguf_file=base_name)
+
+    reader = GGUFReader(model_path)
+    model_type = "qwen2"
+    config = {
+        "model_type": model_type,
+        "vocab_size": tokenizer.vocab_size,
+        "tie_word_embeddings": False,
+    }
+    mapping = {k.format(model=model_type): v for k, v in FIELD_MAPPING.items()}
+    for field in reader.fields:
+        if field in mapping:
+            config[mapping[field]] = decode_field(reader.get_field(field))
+    config["quantization"] = {}
+
+    weights = {}
+
+    # Look for any extra gguf files
+    parts = Path(model_path).name.split("-")
+    parts[-3] = "*"
+    gguf_pattern = "-".join(parts)
+
+    for filename in Path(model_path).parent.glob(gguf_pattern):
+        reader = GGUFReader(str(filename))
+        for tensor in reader.tensors:
+            w = parse_gguf_tensor(tensor)
+            mx.eval(w)
+            name = convert_name(tensor.name)
+            base_name = ".".join(name.split(".")[:-1])
+            if quant := QUANT_MAPPING[tensor.tensor_type]:
+                config["quantization"][base_name] = quant
+            if len(w) == 3:
+                w, scales, biases = w
+                weights[name] = w
+                weights[base_name + ".scales"] = scales
+                weights[base_name + ".biases"] = biases
+            else:
+                weights[name] = w
+
+    arch = importlib.import_module(f"mlx_lm.models.{config['model_type']}")
+    model_class, model_args_class = arch.Model, arch.ModelArgs
+
+    model_args = model_args_class.from_dict(config)
+    model = model_class(model_args)
+
+    quant_config = config["quantization"]
+
+    def pred(p, m):
+        return quant_config.get(p)
+
+    nn.quantize(model, class_predicate=pred)
+    model.load_weights(list(weights.items()))
+
+    model.eval()
+    return model, tokenizer
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index d4afd428..74725f12 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -19,6 +19,7 @@ from mlx.utils import tree_flatten, tree_reduce
 from transformers import PreTrainedTokenizer
 
 # Local imports
+from .gguf import load_gguf
 from .models import cache
 from .sample_utils import make_logits_processors, make_sampler
 from .tokenizer_utils import TokenizerWrapper, load_tokenizer
@@ -458,15 +459,20 @@ def load_model(
         weights = model.sanitize(weights)
 
     if (quantization := config.get("quantization", None)) is not None:
-        # Handle legacy models which may not have everything quantized
+
         def class_predicate(p, m):
+            # Handle custom per layer quantizations
+            if p in config["quantization"]:
+                return config["quantization"][p]
             if not hasattr(m, "to_quantized"):
                 return False
+            # Handle legacy models which may not have everything quantized
             return f"{p}.scales" in weights
 
         nn.quantize(
             model,
-            **quantization,
+            group_size=quantization["group_size"],
+            bits=quantization["bits"],
             class_predicate=class_predicate,
         )
 
@@ -507,6 +513,10 @@ def load(
         FileNotFoundError: If config file or safetensors are not found.
         ValueError: If model class or args class are not found.
     """
+    if path_or_hf_repo.endswith(".gguf"):
+        model, tokenizer = load_gguf(path_or_hf_repo)
+        return model, tokenizer
+
     model_path = get_model_path(path_or_hf_repo)
 
     model = load_model(model_path, lazy, model_config)
@@ -669,7 +679,13 @@ def save_weights(
 
 
 def quantize_model(
-    model: nn.Module, config: dict, q_group_size: int, q_bits: int
+    model: nn.Module,
+    config: dict,
+    q_group_size: int,
+    q_bits: int,
+    quant_predicate: Optional[
+        Callable[[str, nn.Module, dict], Union[bool, dict]]
+    ] = None,
 ) -> Tuple:
     """
     Applies quantization to the model weights.
@@ -679,13 +695,31 @@
         config (dict): Model configuration.
         q_group_size (int): Group size for quantization.
         q_bits (int): Bits per weight for quantization.
+        quant_predicate (Callable): A callable that decides how
+            to quantize each layer based on the path.
+            Accepts the layer `path`, the `module` and the model `config`.
+            Returns either a bool to signify quantize/no quantize or
+            a dict of quantization parameters to pass to `to_quantized`.
 
     Returns:
         Tuple: Tuple containing quantized weights and config.
     """
     quantized_config = copy.deepcopy(config)
-    nn.quantize(model, q_group_size, q_bits)
     quantized_config["quantization"] = {"group_size": q_group_size, "bits": q_bits}
+
+    # Add any custom quantization parameters to the config as we go
+    def _class_predicate(p, m):
+        bool_or_params = quant_predicate(p, m, config)
+        if isinstance(bool_or_params, dict):
+            quantized_config["quantization"][p] = bool_or_params
+        return bool_or_params
+
+    nn.quantize(
+        model,
+        q_group_size,
+        q_bits,
+        class_predicate=_class_predicate if quant_predicate else None,
+    )
     # support hf model tree #957
     quantized_config["quantization_config"] = quantized_config["quantization"]
     quantized_weights = dict(tree_flatten(model.parameters()))
@@ -726,6 +760,9 @@ def convert(
     upload_repo: str = None,
     revision: Optional[str] = None,
     dequantize: bool = False,
+    quant_predicate: Optional[
+        Callable[[str, nn.Module, dict], Union[bool, dict]]
+    ] = None,
 ):
     # Check the save path is empty
     if isinstance(mlx_path, str):
@@ -751,7 +788,9 @@ def convert(
     if quantize:
         print("[INFO] Quantizing")
         model.load_weights(list(weights.items()))
-        weights, config = quantize_model(model, config, q_group_size, q_bits)
+        weights, config = quantize_model(
+            model, config, q_group_size, q_bits, quant_predicate=quant_predicate
+        )
 
     if dequantize:
         print("[INFO] Dequantizing")