load q4_k_m inefficiently

Alex Barron 2024-12-03 19:54:57 -08:00
parent 042280ce50
commit 64ceb62674
2 changed files with 346 additions and 5 deletions


@@ -1,11 +1,19 @@
import importlib
import re
import tempfile
from enum import IntEnum
from pathlib import Path
from typing import Iterable, Optional, Set, Tuple, Union
import gguf
import mlx.core as mx
import mlx.nn as nn
from gguf import GGMLQuantizationType
from gguf.gguf_reader import GGUFReader
from transformers import AutoTokenizer
from .tokenizer_utils import TokenizerWrapper
class TokenType(IntEnum):
NORMAL = 1
@@ -312,3 +320,297 @@ def convert_to_gguf(
output_file_path = output_file_path
mx.save_gguf(output_file_path, weights, metadata)
print(f"Converted GGUF model saved as: {output_file_path}")
# Adapted from https://github.com/antirez/gguf-tools/blob/4e6455ecaf92b1a59e6a3291646459af3154bef5/gguflib.c#L568
def parse_q4_k(tensor):
bits = 4
pack_factor = 32 // bits
group_size = 32
block_size = 144
data = mx.array(tensor.data)
shape = [int(d) for d in reversed(tensor.shape)]
wshape = (*shape[:-1], shape[-1] // pack_factor)
gshape = (*shape[:-1], shape[-1] // group_size)
num_blocks = data.size // block_size
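# Each 144-byte Q4_K super-block holds 256 4-bit weights: two fp16 super-block
# scales (d, d_min), 12 bytes of packed 6-bit sub-block scales/mins, and 128
# bytes of quants. One kernel thread unpacks one super-block into 32 uint32
# words (8 weights each) plus 8 fp16 scale/bias pairs, one per 32-weight group.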
kernel = mx.fast.metal_kernel(
name="parse_q4_k",
input_names=["data"],
output_names=["w", "scales", "biases"],
header="""
typedef struct {
float16_t d;
float16_t d_min;
uint8_t scales[12];
uint8_t qs[128];
} block_q4_K;
""",
source="""
uint elem = thread_position_in_grid.x;
const device block_q4_K* block = reinterpret_cast<const device block_q4_K*>(data);
block += elem;
w += elem * 32;
scales += elem * 8;
biases += elem * 8;
// First unpack the quantized scales/biases
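// The 12 scale bytes pack 8 6-bit scales and 8 6-bit mins: entries 0-3 sit in
// the low 6 bits of bytes 0-3 (scales) and 4-7 (mins); entries 4-7 take their
// low 4 bits from the nibbles of bytes 8-11 and their top 2 bits from the high
// bits of bytes 0-7.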
for (int j = 0; j < 8; j++) {
uint8_t d, m;
if (j < 4) {
d = block->scales[j] & 63;
m = block->scales[j + 4] & 63;
} else {
d = (block->scales[j + 4] & 0xF) | ((block->scales[j - 4] >> 6) << 4);
m = (block->scales[j + 4] >> 4) | ((block->scales[j - 0] >> 6) << 4);
}
scales[j] = d * block->d;
biases[j] = -m * block->d_min;
}
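// qs holds 4 chunks of 32 bytes: the low nibbles of chunk i are the 32 weights
// of sub-block 2*i and the high nibbles are sub-block 2*i + 1. Repack them
// 8 per uint32, first value in the lowest bits, to match MLX's packed 4-bit
// weight layout.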
uint32_t outputs[32] = {0};
for (int i = 0; i < 4; i++) {
for (int j = 0; j < 32; j++) {
uint8_t val = block->qs[i * 32 + j] & 0xf;
int index = i * 8 + (j / 8);
outputs[index] += val << (4 * (j % 8));
}
for (int j = 0; j < 32; j++) {
uint8_t val = block->qs[i * 32 + j] >> 4;
int index = i * 8 + 4 + (j / 8);
outputs[index] += val << (4 * (j % 8));
}
}
for (int i = 0; i < 32; i++) {
w[i] = outputs[i];
}
""",
)
w, scales, biases = kernel(
inputs=[data],
grid=(num_blocks, 1, 1),
threadgroup=(256, 1, 1),
output_shapes=[wshape, gshape, gshape],
output_dtypes=[mx.uint32, mx.float16, mx.float16],
)
return w, scales, biases
# Adapted from https://github.com/antirez/gguf-tools/blob/4e6455ecaf92b1a59e6a3291646459af3154bef5/gguflib.c#L658
def parse_q6_k(tensor):
bits = 6
group_size = 16
block_size = 210
data = mx.array(tensor.data)
shape = [int(d) for d in reversed(tensor.shape)]
wshape = (*shape[:-1], shape[-1] * bits // 8)
gshape = (*shape[:-1], shape[-1] // group_size)
num_blocks = data.size // block_size
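# Each 210-byte Q6_K super-block holds 256 6-bit weights split across ql (low
# 4 bits) and qh (top 2 bits), 16 int8 sub-block scales, and one fp16
# super-block scale. One thread repacks a block into 192 bytes of contiguous
# 6-bit values plus 16 fp16 scale/bias pairs; Q6_K weights are stored with a
# +32 offset, which is folded into the bias as -32 * scale.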
kernel = mx.fast.metal_kernel(
name="parse_q6_k",
input_names=["data"],
output_names=["w", "scales", "biases"],
header="""
typedef struct {
uint8_t ql[128]; // quants, lower 4 bits
uint8_t qh[64]; // quants, upper 2 bits
int8_t scales[16]; // scales, quantized with 8 bits
float16_t d; // super-block scale
} block_q6_K;
""",
source="""
uint elem = thread_position_in_grid.x;
const device block_q6_K* block = reinterpret_cast<const device block_q6_K*>(data);
block += elem;
w += elem * 192;
scales += elem * 16;
biases += elem * 16;
const device uint8_t* ql = &block->ql[0];
const device uint8_t* qh = &block->qh[0];
const device int8_t* bscales = &block->scales[0];
uint32_t output = 0;
for (int cluster = 0; cluster < 2; cluster++) {
for (uint64_t j = 0; j < 128; j++) {
uint8_t val = ((ql[j%64] >> (j/64*4)) & 0xF) | (((qh[j%32] >> (j/32*2)) & 3) << 4);
output += val << (6 * (j % 4));
// Every 4 values (24 bits), write out 3 bytes
if (j % 4 == 3) {
w[0] = output & 0xff;
w[1] = (output & 0xff00) >> 8;
w[2] = (output & 0xff0000) >> 16;
w += 3;
output = 0;
}
if (j % 16 == 0) {
scales[j/16] = block->d * bscales[j/16];
biases[j/16] = -32.0f * scales[j/16];
}
}
ql += 64;
qh += 32;
bscales += 8;
scales += 8;
biases += 8;
}
""",
)
w, scales, biases = kernel(
inputs=[data],
grid=(num_blocks, 1, 1),
threadgroup=(256, 1, 1),
output_shapes=[wshape, gshape, gshape],
output_dtypes=[mx.uint8, mx.float16, mx.float16],
)
w = mx.view(w, dtype=mx.uint32)
return w, scales, biases
def parse_gguf_tensor(tensor):
from gguf import GGMLQuantizationType
if tensor.tensor_type == GGMLQuantizationType.Q4_K:
return parse_q4_k(tensor)
elif tensor.tensor_type == GGMLQuantizationType.Q6_K:
return parse_q6_k(tensor)
elif tensor.tensor_type in [GGMLQuantizationType.F16, GGMLQuantizationType.F32]:
return mx.array(tensor.data)
else:
raise NotImplementedError(f"Type: {tensor.tensor_type} is not yet supported.")
def convert_name(name):
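# Map GGUF tensor names onto the mlx-lm / Hugging Face naming scheme,
# e.g. "blk.0.attn_q.weight" -> "model.layers.0.self_attn.q_proj.weight".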
name = name.replace("blk", "model.layers")
name = name.replace("attn_norm", "input_layernorm")
name = name.replace("ffn_norm", "post_attention_layernorm")
name = name.replace("attn_q", "self_attn.q_proj")
name = name.replace("attn_k", "self_attn.k_proj")
name = name.replace("attn_v", "self_attn.v_proj")
name = name.replace("attn_output", "self_attn.o_proj")
name = name.replace("ffn_up", "mlp.up_proj")
name = name.replace("ffn_down", "mlp.down_proj")
name = name.replace("ffn_gate", "mlp.gate_proj")
if "output_norm" in name:
name = name.replace("output_norm", "model.norm")
else:
name = name.replace("output", "lm_head")
name = name.replace("token_embd", "model.embed_tokens")
return name
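# GGUF metadata keys -> Hugging Face config keys. "{model}" is filled in with
# the architecture name, e.g. "qwen2.embedding_length" maps to "hidden_size".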
FIELD_MAPPING = {
"{model}.embedding_length": "hidden_size",
"{model}.feed_forward_length": "intermediate_size",
"{model}.attention.head_count": "num_attention_heads",
"{model}.attention.head_count_kv": "num_key_value_heads",
"{model}.block_count": "num_hidden_layers",
"{model}.attention.layer_norm_rms_epsilon": "rms_norm_eps",
"{model}.rope.freq_base": "rope_theta",
}
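# MLX quantization settings matching each unpacked K-quant layout; F16/F32
# tensors are left unquantized.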
QUANT_MAPPING = {
GGMLQuantizationType.Q4_K: {
"bits": 4,
"group_size": 32,
},
GGMLQuantizationType.Q6_K: {
"bits": 6,
"group_size": 16,
},
GGMLQuantizationType.F16: None,
GGMLQuantizationType.F32: None,
}
# from https://github.com/ggerganov/llama.cpp/blob/40c6d79fb52f995f47507fedfeaae2ac05d9b35c/gguf-py/scripts/gguf_new_metadata.py#L46
def decode_field(field):
if field and field.types:
main_type = field.types[0]
if main_type == gguf.GGUFValueType.ARRAY:
sub_type = field.types[-1]
if sub_type == gguf.GGUFValueType.STRING:
return [
str(bytes(field.parts[idx]), encoding="utf-8") for idx in field.data
]
else:
return [pv for idx in field.data for pv in field.parts[idx].tolist()]
if main_type == gguf.GGUFValueType.STRING:
return str(bytes(field.parts[-1]), encoding="utf-8")
else:
return field.parts[-1][0]
return None
def load_gguf(model_path: str) -> tuple[nn.Module, TokenizerWrapper]:
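# AutoTokenizer.from_pretrained expects a directory plus a gguf_file name, so
# the single .gguf file is exposed through a temporary symlink.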
with tempfile.TemporaryDirectory() as tmp_dir:
base_name = Path(model_path).name
(Path(tmp_dir) / base_name).symlink_to(model_path)
tokenizer = AutoTokenizer.from_pretrained(tmp_dir, gguf_file=base_name)
reader = GGUFReader(model_path)
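# The architecture is hardcoded for now; only qwen2-style GGUF checkpoints are
# supported by this loader.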
model_type = "qwen2"
config = {
"model_type": model_type,
"vocab_size": tokenizer.vocab_size,
"tie_word_embeddings": False,
}
mapping = {k.format(model=model_type): v for k, v in FIELD_MAPPING.items()}
for field in reader.fields:
if field in mapping:
config[mapping[field]] = decode_field(reader.get_field(field))
config["quantization"] = {}
weights = {}
# Look for any extra gguf files
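# (multi-part GGUFs follow the "<name>-NNNNN-of-NNNNN.gguf" convention, so
# wildcard the shard index to pick up every part)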
parts = Path(model_path).name.split("-")
parts[-3] = "*"
gguf_pattern = "-".join(parts)
for filename in Path(model_path).parent.glob(gguf_pattern):
reader = GGUFReader(str(filename))
for tensor in reader.tensors:
w = parse_gguf_tensor(tensor)
mx.eval(w)
name = convert_name(tensor.name)
base_name = ".".join(name.split(".")[:-1])
if quant := QUANT_MAPPING[tensor.tensor_type]:
config["quantization"][base_name] = quant
if len(w) == 3:
w, scales, biases = w
weights[name] = w
weights[base_name + ".scales"] = scales
weights[base_name + ".biases"] = biases
else:
weights[name] = w
arch = importlib.import_module(f"mlx_lm.models.{config['model_type']}")
model_class, model_args_class = arch.Model, arch.ModelArgs
model_args = model_args_class.from_dict(config)
model = model_class(model_args)
quant_config = config["quantization"]
def pred(p, m):
return quant_config.get(p)
nn.quantize(model, class_predicate=pred)
model.load_weights(list(weights.items()))
model.eval()
return model, tokenizer
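# Usage sketch (hypothetical file name):
#   model, tokenizer = load_gguf("Qwen2.5-0.5B-Instruct-Q4_K_M.gguf")
#   # `model` comes back quantized per layer according to the GGUF tensor types.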


@@ -19,6 +19,7 @@ from mlx.utils import tree_flatten, tree_reduce
from transformers import PreTrainedTokenizer
# Local imports
from .gguf import load_gguf
from .models import cache
from .sample_utils import make_logits_processors, make_sampler
from .tokenizer_utils import TokenizerWrapper, load_tokenizer
@@ -458,15 +459,20 @@ def load_model(
weights = model.sanitize(weights)
if (quantization := config.get("quantization", None)) is not None:
# Handle legacy models which may not have everything quantized
def class_predicate(p, m):
# Handle custom per layer quantizations
if p in config["quantization"]:
return config["quantization"][p]
if not hasattr(m, "to_quantized"):
return False
# Handle legacy models which may not have everything quantized
return f"{p}.scales" in weights
nn.quantize(
model,
**quantization,
group_size=quantization["group_size"],
bits=quantization["bits"],
class_predicate=class_predicate,
)
@@ -507,6 +513,10 @@ def load(
FileNotFoundError: If config file or safetensors are not found.
ValueError: If model class or args class are not found.
"""
if path_or_hf_repo.endswith(".gguf"):
model, tokenizer = load_gguf(path_or_hf_repo)
return model, tokenizer
model_path = get_model_path(path_or_hf_repo)
model = load_model(model_path, lazy, model_config)
@@ -669,7 +679,13 @@ def save_weights(
def quantize_model(
model: nn.Module, config: dict, q_group_size: int, q_bits: int
model: nn.Module,
config: dict,
q_group_size: int,
q_bits: int,
quant_predicate: Optional[
Callable[[str, nn.Module, dict], Union[bool, dict]]
] = None,
) -> Tuple:
"""
Applies quantization to the model weights.
@@ -679,13 +695,31 @@ def quantize_model(
config (dict): Model configuration.
q_group_size (int): Group size for quantization.
q_bits (int): Bits per weight for quantization.
quant_predicate (Callable): A callable that decides how
to quantize each layer based on the path.
Accepts the layer `path`, the `module` and the model `config`.
Returns either a bool to signify quantize/no quantize or
a dict of quantization parameters to pass to `to_quantized`.
Returns:
Tuple: Tuple containing quantized weights and config.
"""
quantized_config = copy.deepcopy(config)
nn.quantize(model, q_group_size, q_bits)
quantized_config["quantization"] = {"group_size": q_group_size, "bits": q_bits}
# Add any custom quantization parameters to the config as we go
def _class_predicate(p, m):
bool_or_params = quant_predicate(p, m, config)
if isinstance(bool_or_params, dict):
quantized_config["quantization"][p] = bool_or_params
return bool_or_params
nn.quantize(
model,
q_group_size,
q_bits,
class_predicate=_class_predicate if quant_predicate else None,
)
# support hf model tree #957
quantized_config["quantization_config"] = quantized_config["quantization"]
quantized_weights = dict(tree_flatten(model.parameters()))
@@ -726,6 +760,9 @@ def convert(
upload_repo: str = None,
revision: Optional[str] = None,
dequantize: bool = False,
quant_predicate: Optional[
Callable[[str, nn.Module, dict], Union[bool, dict]]
] = None,
):
# Check the save path is empty
if isinstance(mlx_path, str):
@@ -751,7 +788,9 @@ if quantize:
if quantize:
print("[INFO] Quantizing")
model.load_weights(list(weights.items()))
weights, config = quantize_model(model, config, q_group_size, q_bits)
weights, config = quantize_model(
model, config, q_group_size, q_bits, quant_predicate=quant_predicate
)
if dequantize:
print("[INFO] Dequantizing")