mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-24 01:17:28 +08:00)
load q4_k_m inefficiently
This commit is contained in:
parent
042280ce50
commit
64ceb62674
@@ -1,11 +1,19 @@
import importlib
import re
import tempfile
from enum import IntEnum
from pathlib import Path
from typing import Iterable, Optional, Set, Tuple, Union

import gguf
import mlx.core as mx
import mlx.nn as nn
from gguf import GGMLQuantizationType
from gguf.gguf_reader import GGUFReader
from transformers import AutoTokenizer

from .tokenizer_utils import TokenizerWrapper


class TokenType(IntEnum):
    NORMAL = 1
@@ -312,3 +320,297 @@ def convert_to_gguf(
    output_file_path = output_file_path
    mx.save_gguf(output_file_path, weights, metadata)
    print(f"Converted GGUF model saved as: {output_file_path}")


# Adapted from https://github.com/antirez/gguf-tools/blob/4e6455ecaf92b1a59e6a3291646459af3154bef5/gguflib.c#L568
def parse_q4_k(tensor):
    bits = 4
    pack_factor = 32 // bits
    group_size = 32
    block_size = 144

    data = mx.array(tensor.data)
    shape = [int(d) for d in reversed(tensor.shape)]
    wshape = (*shape[:-1], shape[-1] // pack_factor)
    gshape = (*shape[:-1], shape[-1] // group_size)
    num_blocks = data.size // block_size
    kernel = mx.fast.metal_kernel(
        name="parse_q4_k",
        input_names=["data"],
        output_names=["w", "scales", "biases"],
        header="""
        typedef struct {
            float16_t d;
            float16_t d_min;
            uint8_t scales[12];
            uint8_t qs[128];
        } block_q4_K;
        """,
        source="""
        uint elem = thread_position_in_grid.x;

        const device block_q4_K* block = reinterpret_cast<const device block_q4_K*>(data);

        block += elem;
        w += elem * 32;
        scales += elem * 8;
        biases += elem * 8;

        // First unpack the quantized scales/biases
        for (int j = 0; j < 8; j++) {
            uint8_t d, m;
            if (j < 4) {
                d = block->scales[j] & 63;
                m = block->scales[j + 4] & 63;
            } else {
                d = (block->scales[j + 4] & 0xF) | ((block->scales[j - 4] >> 6) << 4);
                m = (block->scales[j + 4] >> 4) | ((block->scales[j - 0] >> 6) << 4);
            }
            scales[j] = d * block->d;
            biases[j] = -m * block->d_min;
        }

        uint32_t outputs[32] = {0};
        for (int i = 0; i < 4; i++) {
            for (int j = 0; j < 32; j++) {
                uint8_t val = block->qs[i * 32 + j] & 0xf;
                int index = i * 8 + (j / 8);
                outputs[index] += val << (4 * (j % 8));
            }
            for (int j = 0; j < 32; j++) {
                uint8_t val = block->qs[i * 32 + j] >> 4;
                int index = i * 8 + 4 + (j / 8);
                outputs[index] += val << (4 * (j % 8));
            }
        }

        for (int i = 0; i < 32; i++) {
            w[i] = outputs[i];
        }
        """,
    )
    w, scales, biases = kernel(
        inputs=[data],
        grid=(num_blocks, 1, 1),
        threadgroup=(256, 1, 1),
        output_shapes=[wshape, gshape, gshape],
        output_dtypes=[mx.uint32, mx.float16, mx.float16],
    )
    return w, scales, biases

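# Illustrative aside, not part of the commit: a size check of the Q4_K layout
# that parse_q4_k assumes. Each block_q4_K super-block holds 256 weights in
# eight groups of 32, so the 144 bytes work out to 4.5 bits per weight.
_Q4_K_BLOCK_BYTES = 2 + 2 + 12 + 128  # d, d_min, packed 6-bit scales/mins, 4-bit quants
assert _Q4_K_BLOCK_BYTES == 144
assert _Q4_K_BLOCK_BYTES * 8 / 256 == 4.5
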
# Adapted from https://github.com/antirez/gguf-tools/blob/4e6455ecaf92b1a59e6a3291646459af3154bef5/gguflib.c#L658
def parse_q6_k(tensor):
    bits = 6
    group_size = 16
    block_size = 210

    data = mx.array(tensor.data)
    shape = [int(d) for d in reversed(tensor.shape)]
    wshape = (*shape[:-1], shape[-1] * bits // 8)
    gshape = (*shape[:-1], shape[-1] // group_size)
    num_blocks = data.size // block_size
    kernel = mx.fast.metal_kernel(
        name="parse_q6_k",
        input_names=["data"],
        output_names=["w", "scales", "biases"],
        header="""
        typedef struct {
            uint8_t ql[128];    // quants, lower 4 bits
            uint8_t qh[64];     // quants, upper 2 bits
            int8_t scales[16];  // scales, quantized with 8 bits
            float16_t d;        // super-block scale
        } block_q6_K;
        """,
        source="""
        uint elem = thread_position_in_grid.x;

        const device block_q6_K* block = reinterpret_cast<const device block_q6_K*>(data);

        block += elem;
        w += elem * 192;
        scales += elem * 16;
        biases += elem * 16;

        const device uint8_t* ql = &block->ql[0];
        const device uint8_t* qh = &block->qh[0];
        const device int8_t* bscales = &block->scales[0];

        uint32_t output = 0;
        for (int cluster = 0; cluster < 2; cluster++) {
            for (uint64_t j = 0; j < 128; j++) {
                uint8_t val = ((ql[j%64] >> (j/64*4)) & 0xF) | (((qh[j%32] >> (j/32*2)) & 3) << 4);

                output += val << (6 * (j % 4));

                // Every 4 values write out 3 bytes
                if (j % 4 == 3) {
                    w[0] = output & 0xff;
                    w[1] = (output & 0xff00) >> 8;
                    w[2] = (output & 0xff0000) >> 16;
                    w += 3;
                    output = 0;
                }

                if (j % 16 == 0) {
                    scales[j/16] = block->d * bscales[j/16];
                    biases[j/16] = -32.0f * scales[j/16];
                }
            }
            ql += 64;
            qh += 32;
            bscales += 8;
            scales += 8;
            biases += 8;
        }
        """,
    )
    w, scales, biases = kernel(
        inputs=[data],
        grid=(num_blocks, 1, 1),
        threadgroup=(256, 1, 1),
        output_shapes=[wshape, gshape, gshape],
        output_dtypes=[mx.uint8, mx.float16, mx.float16],
    )
    w = mx.view(w, dtype=mx.uint32)
    return w, scales, biases

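# Illustrative aside, not part of the commit: the matching size check for the
# block_q6_K layout that parse_q6_k assumes. A super-block stores 256 weights
# in sixteen groups of 16, packed into 210 bytes, i.e. 6.5625 bits per weight.
_Q6_K_BLOCK_BYTES = 128 + 64 + 16 + 2  # ql, qh, int8 scales, float16 d
assert _Q6_K_BLOCK_BYTES == 210
assert _Q6_K_BLOCK_BYTES * 8 / 256 == 6.5625
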
def parse_gguf_tensor(tensor):
    from gguf import GGMLQuantizationType

    if tensor.tensor_type == GGMLQuantizationType.Q4_K:
        return parse_q4_k(tensor)
    elif tensor.tensor_type == GGMLQuantizationType.Q6_K:
        return parse_q6_k(tensor)
    elif tensor.tensor_type in [GGMLQuantizationType.F16, GGMLQuantizationType.F32]:
        return mx.array(tensor.data)
    else:
        raise NotImplementedError(f"Type: {tensor.tensor_type} is not yet supported.")


def convert_name(name):
    name = name.replace("blk", "model.layers")
    name = name.replace("attn_norm", "input_layernorm")
    name = name.replace("ffn_norm", "post_attention_layernorm")
    name = name.replace("attn_q", "self_attn.q_proj")
    name = name.replace("attn_k", "self_attn.k_proj")
    name = name.replace("attn_v", "self_attn.v_proj")
    name = name.replace("attn_output", "self_attn.o_proj")
    name = name.replace("ffn_up", "mlp.up_proj")
    name = name.replace("ffn_down", "mlp.down_proj")
    name = name.replace("ffn_gate", "mlp.gate_proj")
    if "output_norm" in name:
        name = name.replace("output_norm", "model.norm")
    else:
        name = name.replace("output", "lm_head")
    name = name.replace("token_embd", "model.embed_tokens")
    return name

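# For example (illustrative, not part of the commit):
#   convert_name("blk.0.attn_q.weight")  -> "model.layers.0.self_attn.q_proj.weight"
#   convert_name("output_norm.weight")   -> "model.norm.weight"
#   convert_name("token_embd.weight")    -> "model.embed_tokens.weight"
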
FIELD_MAPPING = {
    "{model}.embedding_length": "hidden_size",
    "{model}.feed_forward_length": "intermediate_size",
    "{model}.attention.head_count": "num_attention_heads",
    "{model}.attention.head_count_kv": "num_key_value_heads",
    "{model}.block_count": "num_hidden_layers",
    "{model}.attention.layer_norm_rms_epsilon": "rms_norm_eps",
    "{model}.rope.freq_base": "rope_theta",
}


QUANT_MAPPING = {
    GGMLQuantizationType.Q4_K: {
        "bits": 4,
        "group_size": 32,
    },
    GGMLQuantizationType.Q6_K: {
        "bits": 6,
        "group_size": 16,
    },
    GGMLQuantizationType.F16: None,
    GGMLQuantizationType.F32: None,
}


# from https://github.com/ggerganov/llama.cpp/blob/40c6d79fb52f995f47507fedfeaae2ac05d9b35c/gguf-py/scripts/gguf_new_metadata.py#L46
def decode_field(field):
    if field and field.types:
        main_type = field.types[0]

        if main_type == gguf.GGUFValueType.ARRAY:
            sub_type = field.types[-1]

            if sub_type == gguf.GGUFValueType.STRING:
                return [
                    str(bytes(field.parts[idx]), encoding="utf-8") for idx in field.data
                ]
            else:
                return [pv for idx in field.data for pv in field.parts[idx].tolist()]
        if main_type == gguf.GGUFValueType.STRING:
            return str(bytes(field.parts[-1]), encoding="utf-8")
        else:
            return field.parts[-1][0]

    return None


def load_gguf(model_path: str) -> tuple[nn.Module, TokenizerWrapper]:
    with tempfile.TemporaryDirectory() as tmp_dir:
        base_name = Path(model_path).name
        (Path(tmp_dir) / base_name).symlink_to(model_path)
        tokenizer = AutoTokenizer.from_pretrained(tmp_dir, gguf_file=base_name)

    reader = GGUFReader(model_path)
    model_type = "qwen2"
    config = {
        "model_type": model_type,
        "vocab_size": tokenizer.vocab_size,
        "tie_word_embeddings": False,
    }
    mapping = {k.format(model=model_type): v for k, v in FIELD_MAPPING.items()}
    for field in reader.fields:
        if field in mapping:
            config[mapping[field]] = decode_field(reader.get_field(field))
    config["quantization"] = {}

    weights = {}

    # Look for any extra gguf files
    parts = Path(model_path).name.split("-")
    parts[-3] = "*"
    gguf_pattern = "-".join(parts)

    for filename in Path(model_path).parent.glob(gguf_pattern):
        reader = GGUFReader(str(filename))
        for tensor in reader.tensors:
            w = parse_gguf_tensor(tensor)
            mx.eval(w)
            name = convert_name(tensor.name)
            base_name = ".".join(name.split(".")[:-1])
            if quant := QUANT_MAPPING[tensor.tensor_type]:
                config["quantization"][base_name] = quant
            if len(w) == 3:
                w, scales, biases = w
                weights[name] = w
                weights[base_name + ".scales"] = scales
                weights[base_name + ".biases"] = biases
            else:
                weights[name] = w

    arch = importlib.import_module(f"mlx_lm.models.{config['model_type']}")
    model_class, model_args_class = arch.Model, arch.ModelArgs

    model_args = model_args_class.from_dict(config)
    model = model_class(model_args)

    quant_config = config["quantization"]

    def pred(p, m):
        return quant_config.get(p)

    nn.quantize(model, class_predicate=pred)
    model.load_weights(list(weights.items()))

    model.eval()
    return model, tokenizer
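A minimal usage sketch (the file name and prompt are placeholders, not from the commit): with this change, a Q4_K_M or Q6_K GGUF file whose architecture the loader supports (it currently assumes qwen2) goes through the regular mlx_lm load/generate path.

from mlx_lm import load, generate

model, tokenizer = load("Qwen2.5-0.5B-Instruct-Q4_K_M.gguf")  # placeholder file name
print(generate(model, tokenizer, prompt="Hello", max_tokens=32))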
@@ -19,6 +19,7 @@ from mlx.utils import tree_flatten, tree_reduce
from transformers import PreTrainedTokenizer

# Local imports
from .gguf import load_gguf
from .models import cache
from .sample_utils import make_logits_processors, make_sampler
from .tokenizer_utils import TokenizerWrapper, load_tokenizer
@@ -458,15 +459,20 @@ def load_model(
    weights = model.sanitize(weights)

    if (quantization := config.get("quantization", None)) is not None:
        # Handle legacy models which may not have everything quantized

        def class_predicate(p, m):
            # Handle custom per layer quantizations
            if p in config["quantization"]:
                return config["quantization"][p]
            if not hasattr(m, "to_quantized"):
                return False
            # Handle legacy models which may not have everything quantized
            return f"{p}.scales" in weights

        nn.quantize(
            model,
            **quantization,
            group_size=quantization["group_size"],
            bits=quantization["bits"],
            class_predicate=class_predicate,
        )
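# Assumed shape, for illustration only (not from the commit): the "quantization"
# config entry can now mix the global parameters with per-layer overrides, and
# class_predicate returns the override dict directly, e.g.
#   {"group_size": 64, "bits": 4,
#    "model.layers.0.mlp.down_proj": {"group_size": 16, "bits": 6}}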
@@ -507,6 +513,10 @@ def load(
        FileNotFoundError: If config file or safetensors are not found.
        ValueError: If model class or args class are not found.
    """
    if path_or_hf_repo.endswith(".gguf"):
        model, tokenizer = load_gguf(path_or_hf_repo)
        return model, tokenizer

    model_path = get_model_path(path_or_hf_repo)

    model = load_model(model_path, lazy, model_config)
@@ -669,7 +679,13 @@ def save_weights(


def quantize_model(
    model: nn.Module, config: dict, q_group_size: int, q_bits: int
    model: nn.Module,
    config: dict,
    q_group_size: int,
    q_bits: int,
    quant_predicate: Optional[
        Callable[[str, nn.Module, dict], Union[bool, dict]]
    ] = None,
) -> Tuple:
    """
    Applies quantization to the model weights.
@@ -679,13 +695,31 @@ def quantize_model(
        config (dict): Model configuration.
        q_group_size (int): Group size for quantization.
        q_bits (int): Bits per weight for quantization.
        quant_predicate (Callable): A callable that decides how
            to quantize each layer based on the path.
            Accepts the layer `path`, the `module` and the model `config`.
            Returns either a bool to signify quantize/no quantize or
            a dict of quantization parameters to pass to `to_quantized`.

    Returns:
        Tuple: Tuple containing quantized weights and config.
    """
    quantized_config = copy.deepcopy(config)
    nn.quantize(model, q_group_size, q_bits)
    quantized_config["quantization"] = {"group_size": q_group_size, "bits": q_bits}

    # Add any custom quantization parameters to the config as we go
    def _class_predicate(p, m):
        bool_or_params = quant_predicate(p, m, config)
        if isinstance(bool_or_params, dict):
            quantized_config["quantization"][p] = bool_or_params
        return bool_or_params

    nn.quantize(
        model,
        q_group_size,
        q_bits,
        class_predicate=_class_predicate if quant_predicate else None,
    )
    # support hf model tree #957
    quantized_config["quantization_config"] = quantized_config["quantization"]
    quantized_weights = dict(tree_flatten(model.parameters()))
@@ -726,6 +760,9 @@ def convert(
    upload_repo: str = None,
    revision: Optional[str] = None,
    dequantize: bool = False,
    quant_predicate: Optional[
        Callable[[str, nn.Module, dict], Union[bool, dict]]
    ] = None,
):
    # Check the save path is empty
    if isinstance(mlx_path, str):
@@ -751,7 +788,9 @@ def convert(
    if quantize:
        print("[INFO] Quantizing")
        model.load_weights(list(weights.items()))
        weights, config = quantize_model(model, config, q_group_size, q_bits)
        weights, config = quantize_model(
            model, config, q_group_size, q_bits, quant_predicate=quant_predicate
        )

    if dequantize:
        print("[INFO] Dequantizing")
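A sketch of the new `quant_predicate` hook (hypothetical layer choice and placeholder repo, not from the commit): the callable receives the layer path, the module, and the model config, and returns either a bool or a dict of per-layer quantization parameters.

def mixed_precision_predicate(path, module, config):
    # Skip modules that cannot be quantized.
    if not hasattr(module, "to_quantized"):
        return False
    # Hypothetical choice: keep the output projection at 8 bits.
    if path == "lm_head":
        return {"group_size": 64, "bits": 8}
    # Everything else falls back to the global q_group_size / q_bits.
    return True


convert("some-hf-model", quantize=True, quant_predicate=mixed_precision_predicate)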