mirror of https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00

load q4_k_m inefficiently

This commit is contained in:
parent 042280ce50
commit 64ceb62674
@@ -1,11 +1,19 @@
+import importlib
 import re
+import tempfile
 from enum import IntEnum
 from pathlib import Path
 from typing import Iterable, Optional, Set, Tuple, Union
 
+import gguf
 import mlx.core as mx
+import mlx.nn as nn
+from gguf import GGMLQuantizationType
+from gguf.gguf_reader import GGUFReader
 from transformers import AutoTokenizer
 
+from .tokenizer_utils import TokenizerWrapper
+
 
 class TokenType(IntEnum):
     NORMAL = 1
@@ -312,3 +320,297 @@ def convert_to_gguf(
     output_file_path = output_file_path
     mx.save_gguf(output_file_path, weights, metadata)
     print(f"Converted GGUF model saved as: {output_file_path}")
+
+
+# Adapted from https://github.com/antirez/gguf-tools/blob/4e6455ecaf92b1a59e6a3291646459af3154bef5/gguflib.c#L568
+def parse_q4_k(tensor):
+    bits = 4
+    pack_factor = 32 // bits
+    group_size = 32
+    block_size = 144
+
+    data = mx.array(tensor.data)
+    shape = [int(d) for d in reversed(tensor.shape)]
+    wshape = (*shape[:-1], shape[-1] // pack_factor)
+    gshape = (*shape[:-1], shape[-1] // group_size)
+    num_blocks = data.size // block_size
+    kernel = mx.fast.metal_kernel(
+        name="parse_q4_k",
+        input_names=["data"],
+        output_names=["w", "scales", "biases"],
+        header="""
+        typedef struct {
+            float16_t d;
+            float16_t d_min;
+            uint8_t scales[12];
+            uint8_t qs[128];
+        } block_q4_K;
+        """,
+        source="""
+        uint elem = thread_position_in_grid.x;
+
+        const device block_q4_K* block = reinterpret_cast<const device block_q4_K*>(data);
+        block += elem;
+        w += elem * 32;
+        scales += elem * 8;
+        biases += elem * 8;
+
+        // First unpack the quantized scales/biases
+        for (int j = 0; j < 8; j++) {
+            uint8_t d, m;
+            if (j < 4) {
+                d = block->scales[j] & 63;
+                m = block->scales[j + 4] & 63;
+            } else {
+                d = (block->scales[j + 4] & 0xF) | ((block->scales[j - 4] >> 6) << 4);
+                m = (block->scales[j + 4] >> 4) | ((block->scales[j - 0] >> 6) << 4);
+            }
+            scales[j] = d * block->d;
+            biases[j] = -m * block->d_min;
+        }
+
+        uint32_t outputs[32] = {0};
+        for (int i = 0; i < 4; i++) {
+            for (int j = 0; j < 32; j++) {
+                uint8_t val = block->qs[i * 32 + j] & 0xf;
+                int index = i * 8 + (j / 8);
+                outputs[index] += val << (4 * (j % 8));
+            }
+            for (int j = 0; j < 32; j++) {
+                uint8_t val = block->qs[i * 32 + j] >> 4;
+                int index = i * 8 + 4 + (j / 8);
+                outputs[index] += val << (4 * (j % 8));
+            }
+        }
+
+        for (int i = 0; i < 32; i++) {
+            w[i] = outputs[i];
+        }
+        """,
+    )
+    w, scales, biases = kernel(
+        inputs=[data],
+        grid=(num_blocks, 1, 1),
+        threadgroup=(256, 1, 1),
+        output_shapes=[wshape, gshape, gshape],
+        output_dtypes=[mx.uint32, mx.float16, mx.float16],
+    )
+    return w, scales, biases
+
+
+# Adapted from https://github.com/antirez/gguf-tools/blob/4e6455ecaf92b1a59e6a3291646459af3154bef5/gguflib.c#L658
+def parse_q6_k(tensor):
+    bits = 6
+    group_size = 16
+    block_size = 210
+
+    data = mx.array(tensor.data)
+    shape = [int(d) for d in reversed(tensor.shape)]
+    wshape = (*shape[:-1], shape[-1] * bits // 8)
+    gshape = (*shape[:-1], shape[-1] // group_size)
+    num_blocks = data.size // block_size
+    kernel = mx.fast.metal_kernel(
+        name="parse_q6_k",
+        input_names=["data"],
+        output_names=["w", "scales", "biases"],
+        header="""
+        typedef struct {
+            uint8_t ql[128];    // quants, lower 4 bits
+            uint8_t qh[64];     // quants, upper 2 bits
+            int8_t scales[16];  // scales, quantized with 8 bits
+            float16_t d;        // super-block scale
+        } block_q6_K;
+        """,
+        source="""
+        uint elem = thread_position_in_grid.x;
+
+        const device block_q6_K* block = reinterpret_cast<const device block_q6_K*>(data);
+        block += elem;
+        w += elem * 192;
+        scales += elem * 16;
+        biases += elem * 16;
+
+        const device uint8_t* ql = &block->ql[0];
+        const device uint8_t* qh = &block->qh[0];
+        const device int8_t* bscales = &block->scales[0];
+
+        uint32_t output = 0;
+        for (int cluster = 0; cluster < 2; cluster++) {
+            for (uint64_t j = 0; j < 128; j++) {
+                uint8_t val = ((ql[j%64] >> (j/64*4)) & 0xF) | (((qh[j%32] >> (j/32*2)) & 3) << 4);
+
+                output += val << (6 * (j % 4));
+
+                // Every 4 values write out 3 bytes
+                if (j % 4 == 3) {
+                    w[0] = output & 0xff;
+                    w[1] = (output & 0xff00) >> 8;
+                    w[2] = (output & 0xff0000) >> 16;
+                    w += 3;
+                    output = 0;
+                }
+
+                if (j % 16 == 0) {
+                    scales[j/16] = block->d * bscales[j/16];
+                    biases[j/16] = -32.0f * scales[j/16];
+                }
+            }
+            ql += 64;
+            qh += 32;
+            bscales += 8;
+            scales += 8;
+            biases += 8;
+        }
+        """,
+    )
+    w, scales, biases = kernel(
+        inputs=[data],
+        grid=(num_blocks, 1, 1),
+        threadgroup=(256, 1, 1),
+        output_shapes=[wshape, gshape, gshape],
+        output_dtypes=[mx.uint8, mx.float16, mx.float16],
+    )
+    w = mx.view(w, dtype=mx.uint32)
+    return w, scales, biases
+
+
+def parse_gguf_tensor(tensor):
+    from gguf import GGMLQuantizationType
+
+    if tensor.tensor_type == GGMLQuantizationType.Q4_K:
+        return parse_q4_k(tensor)
+    elif tensor.tensor_type == GGMLQuantizationType.Q6_K:
+        return parse_q6_k(tensor)
+    elif tensor.tensor_type in [GGMLQuantizationType.F16, GGMLQuantizationType.F32]:
+        return mx.array(tensor.data)
+    else:
+        raise NotImplementedError(f"Type: {tensor.tensor_type} is not yet supported.")
+
+
+def convert_name(name):
+    name = name.replace("blk", "model.layers")
+    name = name.replace("attn_norm", "input_layernorm")
+    name = name.replace("ffn_norm", "post_attention_layernorm")
+    name = name.replace("attn_q", "self_attn.q_proj")
+    name = name.replace("attn_k", "self_attn.k_proj")
+    name = name.replace("attn_v", "self_attn.v_proj")
+    name = name.replace("attn_output", "self_attn.o_proj")
+    name = name.replace("ffn_up", "mlp.up_proj")
+    name = name.replace("ffn_down", "mlp.down_proj")
+    name = name.replace("ffn_gate", "mlp.gate_proj")
+    if "output_norm" in name:
+        name = name.replace("output_norm", "model.norm")
+    else:
+        name = name.replace("output", "lm_head")
+    name = name.replace("token_embd", "model.embed_tokens")
+    return name
+
+
+FIELD_MAPPING = {
+    "{model}.embedding_length": "hidden_size",
+    "{model}.feed_forward_length": "intermediate_size",
+    "{model}.attention.head_count": "num_attention_heads",
+    "{model}.attention.head_count_kv": "num_key_value_heads",
+    "{model}.block_count": "num_hidden_layers",
+    "{model}.attention.layer_norm_rms_epsilon": "rms_norm_eps",
+    "{model}.rope.freq_base": "rope_theta",
+}
+
+
+QUANT_MAPPING = {
+    GGMLQuantizationType.Q4_K: {
+        "bits": 4,
+        "group_size": 32,
+    },
+    GGMLQuantizationType.Q6_K: {
+        "bits": 6,
+        "group_size": 16,
+    },
+    GGMLQuantizationType.F16: None,
+    GGMLQuantizationType.F32: None,
+}
+
+
+# from https://github.com/ggerganov/llama.cpp/blob/40c6d79fb52f995f47507fedfeaae2ac05d9b35c/gguf-py/scripts/gguf_new_metadata.py#L46
+def decode_field(field):
+    if field and field.types:
+        main_type = field.types[0]
+
+        if main_type == gguf.GGUFValueType.ARRAY:
+            sub_type = field.types[-1]
+
+            if sub_type == gguf.GGUFValueType.STRING:
+                return [
+                    str(bytes(field.parts[idx]), encoding="utf-8") for idx in field.data
+                ]
+            else:
+                return [pv for idx in field.data for pv in field.parts[idx].tolist()]
+        if main_type == gguf.GGUFValueType.STRING:
+            return str(bytes(field.parts[-1]), encoding="utf-8")
+        else:
+            return field.parts[-1][0]
+
+    return None
+
+
+def load_gguf(model_path: str) -> tuple[nn.Module, TokenizerWrapper]:
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        base_name = Path(model_path).name
+        (Path(tmp_dir) / base_name).symlink_to(model_path)
+        tokenizer = AutoTokenizer.from_pretrained(tmp_dir, gguf_file=base_name)
+
+    reader = GGUFReader(model_path)
+    model_type = "qwen2"
+    config = {
+        "model_type": model_type,
+        "vocab_size": tokenizer.vocab_size,
+        "tie_word_embeddings": False,
+    }
+    mapping = {k.format(model=model_type): v for k, v in FIELD_MAPPING.items()}
+    for field in reader.fields:
+        if field in mapping:
+            config[mapping[field]] = decode_field(reader.get_field(field))
+    config["quantization"] = {}
+
+    weights = {}
+
+    # Look for any extra gguf files
+    parts = Path(model_path).name.split("-")
+    parts[-3] = "*"
+    gguf_pattern = "-".join(parts)
+
+    for filename in Path(model_path).parent.glob(gguf_pattern):
+        reader = GGUFReader(str(filename))
+        for tensor in reader.tensors:
+            w = parse_gguf_tensor(tensor)
+            mx.eval(w)
+            name = convert_name(tensor.name)
+            base_name = ".".join(name.split(".")[:-1])
+            if quant := QUANT_MAPPING[tensor.tensor_type]:
+                config["quantization"][base_name] = quant
+            if len(w) == 3:
+                w, scales, biases = w
+                weights[name] = w
+                weights[base_name + ".scales"] = scales
+                weights[base_name + ".biases"] = biases
+            else:
+                weights[name] = w
+
+    arch = importlib.import_module(f"mlx_lm.models.{config['model_type']}")
+    model_class, model_args_class = arch.Model, arch.ModelArgs
+
+    model_args = model_args_class.from_dict(config)
+    model = model_class(model_args)
+
+    quant_config = config["quantization"]
+
+    def pred(p, m):
+        return quant_config.get(p)
+
+    nn.quantize(model, class_predicate=pred)
+    model.load_weights(list(weights.items()))
+
+    model.eval()
+    return model, tokenizer
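Note on the layout the `parse_q4_k` kernel assumes: each 144-byte Q4_K super-block holds 256 weights (a float16 super-scale `d`, a float16 `d_min`, 12 bytes of packed 6-bit sub-block scales/mins, and 128 bytes of 4-bit quants), which is where `block_size = 144`, the 8 groups per block, and the 32 packed `uint32` outputs per block come from. The kernel emits weights in MLX's packed quantized format together with per-group `scales`/`biases`, so a weight dequantizes as roughly `scales * q + biases`. A minimal sanity-check sketch, assuming the module lives at `mlx_lm.gguf` and a local Q4_K_M checkpoint (file name is illustrative):

```python
import mlx.core as mx
from gguf import GGMLQuantizationType
from gguf.gguf_reader import GGUFReader

from mlx_lm.gguf import parse_q4_k  # assumed import path for the code above

reader = GGUFReader("Qwen2.5-0.5B-Instruct-Q4_K_M.gguf")  # illustrative file
tensor = next(t for t in reader.tensors if t.tensor_type == GGMLQuantizationType.Q4_K)

# Unpack with the Metal kernel, then reconstruct floats with MLX's own
# dequantization using the same group_size/bits as QUANT_MAPPING[Q4_K].
w, scales, biases = parse_q4_k(tensor)
w_float = mx.dequantize(w, scales, biases, group_size=32, bits=4)
print(w_float.shape, w_float.dtype)
```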
@@ -19,6 +19,7 @@ from mlx.utils import tree_flatten, tree_reduce
 from transformers import PreTrainedTokenizer
 
 # Local imports
+from .gguf import load_gguf
 from .models import cache
 from .sample_utils import make_logits_processors, make_sampler
 from .tokenizer_utils import TokenizerWrapper, load_tokenizer
@@ -458,15 +459,20 @@ def load_model(
     weights = model.sanitize(weights)
 
     if (quantization := config.get("quantization", None)) is not None:
-        # Handle legacy models which may not have everything quantized
+
         def class_predicate(p, m):
+            # Handle custom per layer quantizations
+            if p in config["quantization"]:
+                return config["quantization"][p]
             if not hasattr(m, "to_quantized"):
                 return False
+            # Handle legacy models which may not have everything quantized
             return f"{p}.scales" in weights
 
         nn.quantize(
             model,
-            **quantization,
+            group_size=quantization["group_size"],
+            bits=quantization["bits"],
             class_predicate=class_predicate,
         )
 
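With this change the `quantization` section of the config can mix the global settings with per-layer overrides keyed by module path; when the predicate returns such a dict, `nn.quantize` applies those parameters to that layer. A hypothetical fragment of what `load_model` now accepts (paths and values are illustrative):

```python
# Hypothetical config fragment consumed by the new class_predicate.
config = {
    "quantization": {
        "group_size": 32,
        "bits": 4,
        # per-layer override returned by class_predicate for this path
        "model.layers.0.mlp.down_proj": {"group_size": 16, "bits": 6},
    }
}
```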
@@ -507,6 +513,10 @@ def load(
         FileNotFoundError: If config file or safetensors are not found.
         ValueError: If model class or args class are not found.
     """
+    if path_or_hf_repo.endswith(".gguf"):
+        model, tokenizer = load_gguf(path_or_hf_repo)
+        return model, tokenizer
+
     model_path = get_model_path(path_or_hf_repo)
 
     model = load_model(model_path, lazy, model_config)
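With the `.gguf` dispatch above, a GGUF checkpoint can be passed to `load` directly by file path. A rough usage sketch (file name is illustrative; at this point `load_gguf` hard-codes the qwen2 architecture and handles only Q4_K, Q6_K, F16 and F32 tensors):

```python
from mlx_lm import load, generate

# A path ending in ".gguf" takes the new load_gguf branch.
model, tokenizer = load("Qwen2.5-0.5B-Instruct-Q4_K_M.gguf")
print(generate(model, tokenizer, prompt="Hello", max_tokens=32))
```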
@@ -669,7 +679,13 @@ def save_weights(
 
 
 def quantize_model(
-    model: nn.Module, config: dict, q_group_size: int, q_bits: int
+    model: nn.Module,
+    config: dict,
+    q_group_size: int,
+    q_bits: int,
+    quant_predicate: Optional[
+        Callable[[str, nn.Module, dict], Union[bool, dict]]
+    ] = None,
 ) -> Tuple:
     """
     Applies quantization to the model weights.
@@ -679,13 +695,31 @@ def quantize_model(
         config (dict): Model configuration.
         q_group_size (int): Group size for quantization.
         q_bits (int): Bits per weight for quantization.
+        quant_predicate (Callable): A callable that decides how
+            to quantize each layer based on the path.
+            Accepts the layer `path`, the `module` and the model `config`.
+            Returns either a bool to signify quantize/no quantize or
+            a dict of quantization parameters to pass to `to_quantized`.
 
     Returns:
         Tuple: Tuple containing quantized weights and config.
     """
     quantized_config = copy.deepcopy(config)
-    nn.quantize(model, q_group_size, q_bits)
     quantized_config["quantization"] = {"group_size": q_group_size, "bits": q_bits}
+
+    # Add any custom quantization parameters to the config as we go
+    def _class_predicate(p, m):
+        bool_or_params = quant_predicate(p, m, config)
+        if isinstance(bool_or_params, dict):
+            quantized_config["quantization"][p] = bool_or_params
+        return bool_or_params
+
+    nn.quantize(
+        model,
+        q_group_size,
+        q_bits,
+        class_predicate=_class_predicate if quant_predicate else None,
+    )
     # support hf model tree #957
     quantized_config["quantization_config"] = quantized_config["quantization"]
     quantized_weights = dict(tree_flatten(model.parameters()))
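As the docstring describes, `quant_predicate` receives the layer path, the module and the model config, and returns either a bool or a dict of parameters for `to_quantized`. A sketch of such a predicate (the layer-name checks are illustrative):

```python
def mixed_quant_predicate(path, module, config):
    # Keep the embedding and output head at higher precision,
    # quantize everything else with the global q_bits / q_group_size.
    if "embed_tokens" in path or "lm_head" in path:
        return {"group_size": 64, "bits": 6}
    return True
```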
@@ -726,6 +760,9 @@ def convert(
     upload_repo: str = None,
     revision: Optional[str] = None,
     dequantize: bool = False,
+    quant_predicate: Optional[
+        Callable[[str, nn.Module, dict], Union[bool, dict]]
+    ] = None,
 ):
     # Check the save path is empty
     if isinstance(mlx_path, str):
@@ -751,7 +788,9 @@ def convert(
     if quantize:
         print("[INFO] Quantizing")
         model.load_weights(list(weights.items()))
-        weights, config = quantize_model(model, config, q_group_size, q_bits)
+        weights, config = quantize_model(
+            model, config, q_group_size, q_bits, quant_predicate=quant_predicate
+        )
 
     if dequantize:
         print("[INFO] Dequantizing")
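And a sketch of threading such a predicate through `convert` (the repo name and output path are illustrative):

```python
from mlx_lm import convert

convert(
    "Qwen/Qwen2.5-0.5B-Instruct",  # illustrative HF repo
    mlx_path="mlx_model",
    quantize=True,
    q_group_size=64,
    q_bits=4,
    quant_predicate=mixed_quant_predicate,  # defined in the sketch above
)
```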