# mlx-examples/llms/mlx_lm/gguf.py

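"""GGUF conversion and loading utilities for mlx_lm.

`convert_to_gguf` writes an fp16 GGUF file from Hugging Face style weights and a
config dict, and `load_gguf` reads a GGUF checkpoint (F16/F32/Q4_K/Q6_K tensors)
back into an mlx_lm model plus its tokenizer.
"""
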
import importlib
import re
import tempfile
from enum import IntEnum
from pathlib import Path
from typing import Iterable, Optional, Set, Tuple, Union
import gguf
import mlx.core as mx
import mlx.nn as nn
from gguf import GGMLQuantizationType
from gguf.gguf_reader import GGUFReader
from transformers import AutoTokenizer
from .tokenizer_utils import TokenizerWrapper


class TokenType(IntEnum):
    NORMAL = 1
    UNKNOWN = 2
    CONTROL = 3
    USER_DEFINED = 4
    UNUSED = 5
    BYTE = 6


class GGMLFileType(IntEnum):
    GGML_TYPE_F16 = 1


# copied from https://github.com/ggerganov/llama.cpp/blob/master/convert.py#L455
class HfVocab:
    def __init__(
        self,
        fname_tokenizer: Path,
        fname_added_tokens: Optional[Union[Path, None]] = None,
    ) -> None:
        self.tokenizer = AutoTokenizer.from_pretrained(
            fname_tokenizer,
            cache_dir=fname_tokenizer,
            local_files_only=True,
        )
        self.added_tokens_list = []
        self.added_tokens_dict = dict()
        self.added_tokens_ids = set()

        for tok, tokidx in sorted(
            self.tokenizer.get_added_vocab().items(), key=lambda x: x[1]
        ):
            if tokidx >= self.tokenizer.vocab_size:
                self.added_tokens_list.append(tok)
                self.added_tokens_dict[tok] = tokidx
                self.added_tokens_ids.add(tokidx)

        self.specials = {
            tok: self.tokenizer.get_vocab()[tok]
            for tok in self.tokenizer.all_special_tokens
        }
        self.special_ids = set(self.tokenizer.all_special_ids)
        self.vocab_size_base = self.tokenizer.vocab_size
        self.vocab_size = self.vocab_size_base + len(self.added_tokens_list)
        self.fname_tokenizer = fname_tokenizer
        self.fname_added_tokens = fname_added_tokens

    def hf_tokens(self) -> Iterable[Tuple[bytes, float, TokenType]]:
        reverse_vocab = {
            id: encoded_tok for encoded_tok, id in self.tokenizer.get_vocab().items()
        }
        for token_id in range(self.vocab_size_base):
            if token_id in self.added_tokens_ids:
                continue
            token_text = reverse_vocab[token_id]
            yield token_text, self.get_token_score(token_id), self.get_token_type(
                token_id, token_text, self.special_ids
            )

    def get_token_type(
        self, token_id: int, token_text: bytes, special_ids: Set[int]
    ) -> TokenType:
        if re.fullmatch(r"<0x[0-9A-Fa-f]{2}>", token_text):
            return TokenType.BYTE
        return TokenType.CONTROL if token_id in special_ids else TokenType.NORMAL

    def get_token_score(self, token_id: int) -> float:
        return -1000.0

    def added_tokens(self) -> Iterable[Tuple[bytes, float, TokenType]]:
        for text in self.added_tokens_list:
            if text in self.specials:
                toktype = self.get_token_type(self.specials[text], "", self.special_ids)
                score = self.get_token_score(self.specials[text])
            else:
                toktype = TokenType.USER_DEFINED
                score = -1000.0
            yield text, score, toktype

    def has_newline_token(self):
        return "<0x0A>" in self.tokenizer.vocab or "\n" in self.tokenizer.vocab

    def all_tokens(self) -> Iterable[Tuple[bytes, float, TokenType]]:
        yield from self.hf_tokens()
        yield from self.added_tokens()

    def __repr__(self) -> str:
        return f"<HfVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"

    @staticmethod
    def load(path: Path) -> "HfVocab":
        added_tokens_path = path.parent / "added_tokens.json"
        return HfVocab(path, added_tokens_path if added_tokens_path.exists() else None)
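

# Map Hugging Face parameter names to the llama.cpp/GGUF naming scheme, e.g.
# "model.layers.0.self_attn.q_proj.weight" -> "blk.0.attn_q.weight" and
# "model.embed_tokens.weight" -> "token_embd.weight".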
def translate_weight_names(name):
    name = name.replace("model.layers.", "blk.")

    # for mixtral gate
    name = name.replace("block_sparse_moe.gate", "ffn_gate_inp")

    # for mixtral experts ffns
    pattern = r"block_sparse_moe\.experts\.(\d+)\.w1\.weight"
    replacement = r"ffn_gate.\1.weight"
    name = re.sub(pattern, replacement, name)

    pattern = r"block_sparse_moe\.experts\.(\d+)\.w2\.weight"
    replacement = r"ffn_down.\1.weight"
    name = re.sub(pattern, replacement, name)

    pattern = r"block_sparse_moe\.experts\.(\d+)\.w3\.weight"
    replacement = r"ffn_up.\1.weight"
    name = re.sub(pattern, replacement, name)

    name = name.replace("mlp.gate_proj", "ffn_gate")
    name = name.replace("mlp.down_proj", "ffn_down")
    name = name.replace("mlp.up_proj", "ffn_up")
    name = name.replace("self_attn.q_proj", "attn_q")
    name = name.replace("self_attn.k_proj", "attn_k")
    name = name.replace("self_attn.v_proj", "attn_v")
    name = name.replace("self_attn.o_proj", "attn_output")
    name = name.replace("input_layernorm", "attn_norm")
    name = name.replace("post_attention_layernorm", "ffn_norm")
    name = name.replace("model.embed_tokens", "token_embd")
    name = name.replace("model.norm", "output_norm")
    name = name.replace("lm_head", "output")
    return name
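

# Reorder the rows of a Q/K projection so the rotary halves are laid out the way
# llama.cpp expects (this mirrors the permute step in llama.cpp's convert.py).
# For example, with n_head=1 and 8 rows, rows [0, 1, 2, 3, 4, 5, 6, 7] come out
# as [0, 4, 1, 5, 2, 6, 3, 7].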
def permute_weights(weights, n_head, n_head_kv=None):
    if n_head_kv is not None and n_head != n_head_kv:
        n_head = n_head_kv
    reshaped = weights.reshape(
        n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]
    )
    swapped = reshaped.swapaxes(1, 2)
    final_shape = weights.shape
    return swapped.reshape(final_shape)


def prepare_metadata(config, vocab):
    metadata = {
        "general.name": "llama",
        "llama.context_length": (
            mx.array(config["max_position_embeddings"], dtype=mx.uint32)
            if config.get("max_position_embeddings") is not None
            else None
        ),
        "llama.embedding_length": (
            mx.array(config["hidden_size"], dtype=mx.uint32)
            if config.get("hidden_size") is not None
            else None
        ),
        "llama.block_count": (
            mx.array(config["num_hidden_layers"], dtype=mx.uint32)
            if config.get("num_hidden_layers") is not None
            else None
        ),
        "llama.feed_forward_length": (
            mx.array(config["intermediate_size"], dtype=mx.uint32)
            if config.get("intermediate_size") is not None
            else None
        ),
        "llama.rope.dimension_count": (
            mx.array(
                config["hidden_size"] // config["num_attention_heads"], dtype=mx.uint32
            )
            if config.get("hidden_size") is not None
            and config.get("num_attention_heads") is not None
            else None
        ),
        "llama.attention.head_count": (
            mx.array(config["num_attention_heads"], dtype=mx.uint32)
            if config.get("num_attention_heads") is not None
            else None
        ),
        "llama.attention.head_count_kv": (
            mx.array(
                config.get("num_key_value_heads", config["num_attention_heads"]),
                dtype=mx.uint32,
            )
            if config.get("num_attention_heads") is not None
            else None
        ),
        "llama.expert_count": (
            mx.array(config.get("num_local_experts", None), dtype=mx.uint32)
            if config.get("num_local_experts") is not None
            else None
        ),
        "llama.expert_used_count": (
            mx.array(config.get("num_experts_per_tok", None), dtype=mx.uint32)
            if config.get("num_experts_per_tok") is not None
            else None
        ),
        "llama.attention.layer_norm_rms_epsilon": (
            mx.array(config.get("rms_norm_eps", 1e-05))
            if config.get("rms_norm_eps") is not None
            else None
        ),
        "llama.rope.freq_base": (
            mx.array(config.get("rope_theta", 10000), dtype=mx.float32)
            if config.get("rope_theta") is not None
            else None
        ),
    }

    rope_scaling = config.get("rope_scaling")
    if rope_scaling is not None and (typ := rope_scaling.get("type")):
        rope_factor = rope_scaling.get("factor")
        f_rope_scale = rope_factor
        if typ == "linear":
            rope_scaling_type = "linear"
            metadata["llama.rope.scaling.type"] = rope_scaling_type
            metadata["llama.rope.scaling.factor"] = mx.array(f_rope_scale)

    metadata["general.file_type"] = mx.array(
        GGMLFileType.GGML_TYPE_F16.value,
        dtype=mx.uint32,
    )
    metadata["general.quantization_version"] = mx.array(
        GGMLFileType.GGML_TYPE_F16.value,
        dtype=mx.uint32,
    )
    metadata["general.name"] = config.get("_name_or_path", "llama").split("/")[-1]
    metadata["general.architecture"] = "llama"
    metadata["general.alignment"] = mx.array(32, dtype=mx.uint32)

    # add metadata for vocab
    metadata["tokenizer.ggml.model"] = "llama"
    tokens = []
    scores = []
    toktypes = []
    for text, score, toktype in vocab.all_tokens():
        tokens.append(text)
        scores.append(score)
        toktypes.append(toktype.value)
    assert len(tokens) == vocab.vocab_size

    metadata["tokenizer.ggml.tokens"] = tokens
    metadata["tokenizer.ggml.scores"] = mx.array(scores, dtype=mx.float32)
    metadata["tokenizer.ggml.token_type"] = mx.array(toktypes, dtype=mx.uint32)

    if vocab.tokenizer.bos_token_id is not None:
        metadata["tokenizer.ggml.bos_token_id"] = mx.array(
            vocab.tokenizer.bos_token_id, dtype=mx.uint32
        )
    if vocab.tokenizer.eos_token_id is not None:
        metadata["tokenizer.ggml.eos_token_id"] = mx.array(
            vocab.tokenizer.eos_token_id, dtype=mx.uint32
        )
    if vocab.tokenizer.unk_token_id is not None:
        metadata["tokenizer.ggml.unknown_token_id"] = mx.array(
            vocab.tokenizer.unk_token_id, dtype=mx.uint32
        )

    metadata = {k: v for k, v in metadata.items() if v is not None}
    return metadata


def convert_to_gguf(
    model_path: Union[str, Path],
    weights: dict,
    config: dict,
    output_file_path: str,
):
    if isinstance(model_path, str):
        model_path = Path(model_path)

    quantization = config.get("quantization", None)
    if quantization:
        raise NotImplementedError(
            "Conversion of quantized models is not yet supported."
        )

    print("Converting to GGUF format")

    # Permute Q/K projections to match llama.cpp's multi-head attention layout;
    # see https://github.com/ggerganov/llama.cpp/blob/master/convert.py#L1182
    weights = {
        k: (
            permute_weights(
                v, config["num_attention_heads"], config["num_attention_heads"]
            )
            if "self_attn.q_proj.weight" in k
            else (
                permute_weights(
                    v, config["num_attention_heads"], config["num_key_value_heads"]
                )
                if "self_attn.k_proj.weight" in k
                else v
            )
        )
        for k, v in weights.items()
    }

    # rename weights for gguf format
    weights = {translate_weight_names(k): v for k, v in weights.items()}

    if not (model_path / "tokenizer.json").exists():
        raise ValueError("Tokenizer json not found")

    vocab = HfVocab.load(model_path)
    metadata = prepare_metadata(config, vocab)

    # Cast bfloat16 tensors to float16 (via float32); other "norm" weights are
    # stored in float32.
    weights = {
        k: (
            v.astype(mx.float32).astype(mx.float16)
            if v.dtype == mx.bfloat16
            else v.astype(mx.float32) if "norm" in k else v
        )
        for k, v in weights.items()
    }

    mx.save_gguf(output_file_path, weights, metadata)
    print(f"Converted GGUF model saved as: {output_file_path}")
# Adapted from https://github.com/antirez/gguf-tools/blob/4e6455ecaf92b1a59e6a3291646459af3154bef5/gguflib.c#L568
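# A block_q4_K packs 256 4-bit values: a float16 super-block scale (d), a float16
# super-block min (d_min), 12 bytes of 6-bit sub-block scales/mins, and 128 bytes
# of quants (2 + 2 + 12 + 128 = 144 bytes). The kernel below expands each block
# into packed uint32 weights plus float16 scales and biases per group of 32, the
# layout MLX's quantized layers consume.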
def parse_q4_k(tensor):
    bits = 4
    pack_factor = 32 // bits
    group_size = 32
    block_size = 144

    data = mx.array(tensor.data)
    shape = [int(d) for d in reversed(tensor.shape)]
    wshape = (*shape[:-1], shape[-1] // pack_factor)
    gshape = (*shape[:-1], shape[-1] // group_size)
    num_blocks = data.size // block_size

    kernel = mx.fast.metal_kernel(
        name="parse_q4_k",
        input_names=["data"],
        output_names=["w", "scales", "biases"],
        header="""
        typedef struct {
            float16_t d;
            float16_t d_min;
            uint8_t scales[12];
            uint8_t qs[128];
        } block_q4_K;
        """,
        source="""
        uint elem = thread_position_in_grid.x;
        const device block_q4_K* block = reinterpret_cast<const device block_q4_K*>(data);
        block += elem;
        w += elem * 32;
        scales += elem * 8;
        biases += elem * 8;

        // First unpack the quantized scales/biases
        for (int j = 0; j < 8; j++) {
            uint8_t d, m;
            if (j < 4) {
                d = block->scales[j] & 63;
                m = block->scales[j + 4] & 63;
            } else {
                d = (block->scales[j + 4] & 0xF) | ((block->scales[j - 4] >> 6) << 4);
                m = (block->scales[j + 4] >> 4) | ((block->scales[j - 0] >> 6) << 4);
            }
            scales[j] = d * block->d;
            biases[j] = -m * block->d_min;
        }

        uint32_t outputs[32] = {0};
        for (int i = 0; i < 4; i++) {
            for (int j = 0; j < 32; j++) {
                uint8_t val = block->qs[i * 32 + j] & 0xf;
                int index = i * 8 + (j / 8);
                outputs[index] += val << (4 * (j % 8));
            }
            for (int j = 0; j < 32; j++) {
                uint8_t val = block->qs[i * 32 + j] >> 4;
                int index = i * 8 + 4 + (j / 8);
                outputs[index] += val << (4 * (j % 8));
            }
        }
        for (int i = 0; i < 32; i++) {
            w[i] = outputs[i];
        }
        """,
    )
    w, scales, biases = kernel(
        inputs=[data],
        grid=(num_blocks, 1, 1),
        threadgroup=(256, 1, 1),
        output_shapes=[wshape, gshape, gshape],
        output_dtypes=[mx.uint32, mx.float16, mx.float16],
    )
    return w, scales, biases


# Adapted from https://github.com/antirez/gguf-tools/blob/4e6455ecaf92b1a59e6a3291646459af3154bef5/gguflib.c#L658
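# A block_q6_K holds 256 6-bit values: 128 bytes of low 4 bits (ql), 64 bytes of
# high 2 bits (qh), 16 int8 sub-block scales, and a float16 super-block scale
# (128 + 64 + 16 + 2 = 210 bytes). The kernel packs every four 6-bit values into
# three output bytes, later viewed as uint32, and emits float16 scales and biases
# per group of 16.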
def parse_q6_k(tensor):
    bits = 6
    group_size = 16
    block_size = 210

    data = mx.array(tensor.data)
    shape = [int(d) for d in reversed(tensor.shape)]
    wshape = (*shape[:-1], shape[-1] * bits // 8)
    gshape = (*shape[:-1], shape[-1] // group_size)
    num_blocks = data.size // block_size

    kernel = mx.fast.metal_kernel(
        name="parse_q6_k",
        input_names=["data"],
        output_names=["w", "scales", "biases"],
        header="""
        typedef struct {
            uint8_t ql[128];      // quants, lower 4 bits
            uint8_t qh[64];       // quants, upper 2 bits
            int8_t scales[16];    // scales, quantized with 8 bits
            float16_t d;          // super-block scale
        } block_q6_K;
        """,
        source="""
        uint elem = thread_position_in_grid.x;
        const device block_q6_K* block = reinterpret_cast<const device block_q6_K*>(data);
        block += elem;
        w += elem * 192;
        scales += elem * 16;
        biases += elem * 16;

        const device uint8_t* ql = &block->ql[0];
        const device uint8_t* qh = &block->qh[0];
        const device int8_t* bscales = &block->scales[0];

        uint32_t output = 0;
        for (int cluster = 0; cluster < 2; cluster++) {
            for (uint64_t j = 0; j < 128; j++) {
                uint8_t val = ((ql[j%64] >> (j/64*4)) & 0xF) | (((qh[j%32] >> (j/32*2)) & 3) << 4);
                output += val << (6 * (j % 4));

                // Every 4 values write out 3 bytes
                if (j % 4 == 3) {
                    w[0] = output & 0xff;
                    w[1] = (output & 0xff00) >> 8;
                    w[2] = (output & 0xff0000) >> 16;
                    w += 3;
                    output = 0;
                }

                if (j % 16 == 0) {
                    scales[j/16] = block->d * bscales[j/16];
                    biases[j/16] = -32.0f * scales[j/16];
                }
            }
            ql += 64;
            qh += 32;
            bscales += 8;
            scales += 8;
            biases += 8;
        }
        """,
    )
    w, scales, biases = kernel(
        inputs=[data],
        grid=(num_blocks, 1, 1),
        threadgroup=(256, 1, 1),
        output_shapes=[wshape, gshape, gshape],
        output_dtypes=[mx.uint8, mx.float16, mx.float16],
    )
    w = mx.view(w, dtype=mx.uint32)
    return w, scales, biases


def parse_gguf_tensor(tensor):
    from gguf import GGMLQuantizationType

    if tensor.tensor_type == GGMLQuantizationType.Q4_K:
        return parse_q4_k(tensor)
    elif tensor.tensor_type == GGMLQuantizationType.Q6_K:
        return parse_q6_k(tensor)
    elif tensor.tensor_type in [GGMLQuantizationType.F16, GGMLQuantizationType.F32]:
        return mx.array(tensor.data)
    else:
        raise NotImplementedError(f"Type: {tensor.tensor_type} is not yet supported.")


def convert_name(name):
    name = name.replace("blk", "model.layers")
    name = name.replace("attn_norm", "input_layernorm")
    name = name.replace("ffn_norm", "post_attention_layernorm")
    name = name.replace("attn_q", "self_attn.q_proj")
    name = name.replace("attn_k", "self_attn.k_proj")
    name = name.replace("attn_v", "self_attn.v_proj")
    name = name.replace("attn_output", "self_attn.o_proj")
    name = name.replace("ffn_up", "mlp.up_proj")
    name = name.replace("ffn_down", "mlp.down_proj")
    name = name.replace("ffn_gate", "mlp.gate_proj")
    if "output_norm" in name:
        name = name.replace("output_norm", "model.norm")
    else:
        name = name.replace("output", "lm_head")
    name = name.replace("token_embd", "model.embed_tokens")
    return name


FIELD_MAPPING = {
    "{model}.embedding_length": "hidden_size",
    "{model}.feed_forward_length": "intermediate_size",
    "{model}.attention.head_count": "num_attention_heads",
    "{model}.attention.head_count_kv": "num_key_value_heads",
    "{model}.block_count": "num_hidden_layers",
    "{model}.attention.layer_norm_rms_epsilon": "rms_norm_eps",
    "{model}.rope.freq_base": "rope_theta",
}
QUANT_MAPPING = {
    GGMLQuantizationType.Q4_K: {
        "bits": 4,
        "group_size": 32,
    },
    GGMLQuantizationType.Q6_K: {
        "bits": 6,
        "group_size": 16,
    },
    GGMLQuantizationType.F16: None,
    GGMLQuantizationType.F32: None,
}


# from https://github.com/ggerganov/llama.cpp/blob/40c6d79fb52f995f47507fedfeaae2ac05d9b35c/gguf-py/scripts/gguf_new_metadata.py#L46
def decode_field(field):
    if field and field.types:
        main_type = field.types[0]
        if main_type == gguf.GGUFValueType.ARRAY:
            sub_type = field.types[-1]
            if sub_type == gguf.GGUFValueType.STRING:
                return [
                    str(bytes(field.parts[idx]), encoding="utf-8") for idx in field.data
                ]
            else:
                return [pv for idx in field.data for pv in field.parts[idx].tolist()]
        if main_type == gguf.GGUFValueType.STRING:
            return str(bytes(field.parts[-1]), encoding="utf-8")
        else:
            return field.parts[-1][0]
    return None


def load_gguf(model_path: str) -> tuple[nn.Module, TokenizerWrapper]:
    with tempfile.TemporaryDirectory() as tmp_dir:
        base_name = Path(model_path).name
        (Path(tmp_dir) / base_name).symlink_to(model_path)
        tokenizer = AutoTokenizer.from_pretrained(tmp_dir, gguf_file=base_name)

    reader = GGUFReader(model_path)

    model_type = "qwen2"
    config = {
        "model_type": model_type,
        "vocab_size": tokenizer.vocab_size,
        "tie_word_embeddings": False,
    }
    mapping = {k.format(model=model_type): v for k, v in FIELD_MAPPING.items()}
    for field in reader.fields:
        if field in mapping:
            config[mapping[field]] = decode_field(reader.get_field(field))
    config["quantization"] = {}

    weights = {}

    # Look for any extra gguf files
    parts = Path(model_path).name.split("-")
    parts[-3] = "*"
    gguf_pattern = "-".join(parts)
    for filename in Path(model_path).parent.glob(gguf_pattern):
        reader = GGUFReader(str(filename))
        for tensor in reader.tensors:
            w = parse_gguf_tensor(tensor)
            mx.eval(w)
            name = convert_name(tensor.name)
            base_name = ".".join(name.split(".")[:-1])
            if quant := QUANT_MAPPING[tensor.tensor_type]:
                config["quantization"][base_name] = quant
            if len(w) == 3:
                w, scales, biases = w
                weights[name] = w
                weights[base_name + ".scales"] = scales
                weights[base_name + ".biases"] = biases
            else:
                weights[name] = w

    arch = importlib.import_module(f"mlx_lm.models.{config['model_type']}")
    model_class, model_args_class = arch.Model, arch.ModelArgs
    model_args = model_args_class.from_dict(config)
    model = model_class(model_args)

    quant_config = config["quantization"]

    def pred(p, m):
        return quant_config.get(p)

    nn.quantize(model, class_predicate=pred)
    model.load_weights(list(weights.items()))
    model.eval()

    return model, tokenizer
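

# Example (hypothetical path; note that load_gguf hard-codes a qwen2 architecture
# and derives the shard glob from a "...-NNNNN-of-NNNNN.gguf" style file name):
#
#   model, tokenizer = load_gguf("models/Qwen2-7B-Instruct-Q4_K_M-00001-of-00001.gguf")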