Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-06-24 17:31:18 +08:00
Example reading directly from gguf file (#222)
* Draft of tiny llama from gguf
* Transpose all
* No transposition with new layout
* Read config from gguf
* Create tokenizer from gguf
* Move gguf and update to be similar to hf_llm
* Change model to HF style + updates to README
* Nits in README
* Nit readme
* Only use mlx for metadata
* Fix eos/bos tokenizer
* Fix tokenization
* Quantization runs
* 8-bit works
* Tokenizer fix
* Bump mlx version

---------

Co-authored-by: Juarez Bochi <juarez.bochi@grammarly.com>
Co-authored-by: Awni Hannun <awni@apple.com>
Parent: 40b61c1719
Commit: f5b80c95fb
llms/gguf_llm/README.md (new file)
@@ -0,0 +1,52 @@

# LLMs in MLX with GGUF

An example generating text using GGUF format models in MLX.[^1]

> [!NOTE]
> MLX can read most quantization formats from GGUF directly. However, only a
> few quantizations are natively supported: `Q4_0`, `Q4_1`, and `Q8_0`.
> Unsupported quantizations will be cast to `float16`.

## Setup

Install the dependencies:

```bash
pip install -r requirements.txt
```

### Run

Run with:

```bash
python generate.py \
  --repo <hugging_face_repo> \
  --gguf <file.gguf> \
  --prompt "Write a quicksort in Python"
```

For example, to generate text with Mistral 7B use:

```bash
python generate.py \
  --repo TheBloke/Mistral-7B-v0.1-GGUF \
  --gguf mistral-7b-v0.1.Q8_0.gguf \
  --prompt "Write a quicksort in Python"
```

Run `python generate.py --help` for more options.

Models that have been tested and work include:

- [TheBloke/Mistral-7B-v0.1-GGUF](https://huggingface.co/TheBloke/Mistral-7B-v0.1-GGUF),
  for quantized models use:
  - `mistral-7b-v0.1.Q8_0.gguf`
  - `mistral-7b-v0.1.Q4_0.gguf`

- [TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF](https://huggingface.co/TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF),
  for quantized models use:
  - `tinyllama-1.1b-chat-v1.0.Q8_0.gguf`
  - `tinyllama-1.1b-chat-v1.0.Q4_0.gguf`
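A run with the quantized TinyLlama files listed above looks the same; only the repo and file names change:

```bash
python generate.py \
  --repo TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF \
  --gguf tinyllama-1.1b-chat-v1.0.Q4_0.gguf \
  --prompt "Write a quicksort in Python"
```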
[^1]: For more information on GGUF see [the documentation](https://github.com/ggerganov/ggml/blob/master/docs/gguf.md).
llms/gguf_llm/generate.py (new file)
@@ -0,0 +1,86 @@

# Copyright © 2023 Apple Inc.

import argparse
import time

import mlx.core as mx
import models


def generate(
    model: models.Model,
    tokenizer: models.GGUFTokenizer,
    prompt: str,
    max_tokens: int,
    temp: float = 0.0,
):
    prompt = tokenizer.encode(prompt)

    tic = time.time()
    tokens = []
    skip = 0
    for token, n in zip(
        models.generate(prompt, model, temp),
        range(max_tokens),
    ):
        if token == tokenizer.eos_token_id:
            break

        if n == 0:
            prompt_time = time.time() - tic
            tic = time.time()

        tokens.append(token.item())
        s = tokenizer.decode(tokens)
        print(s[skip:], end="", flush=True)
        skip = len(s)
    print(tokenizer.decode(tokens)[skip:], flush=True)
    gen_time = time.time() - tic
    print("=" * 10)
    if len(tokens) == 0:
        print("No tokens generated for this prompt")
        return
    prompt_tps = prompt.size / prompt_time
    gen_tps = (len(tokens) - 1) / gen_time
    print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
    print(f"Generation: {gen_tps:.3f} tokens-per-sec")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Inference script")
    parser.add_argument(
        "--gguf",
        type=str,
        help="The GGUF file to load (and optionally download).",
    )
    parser.add_argument(
        "--repo",
        type=str,
        default=None,
        help="The Hugging Face repo if downloading from the Hub.",
    )

    parser.add_argument(
        "--prompt",
        help="The message to be processed by the model",
        default="In the beginning the Universe was created.",
    )
    parser.add_argument(
        "--max-tokens",
        "-m",
        type=int,
        default=100,
        help="Maximum number of tokens to generate",
    )
    parser.add_argument(
        "--temp",
        help="The sampling temperature.",
        type=float,
        default=0.0,
    )
    parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")

    args = parser.parse_args()
    mx.random.seed(args.seed)
    model, tokenizer = models.load(args.gguf, args.repo)
    generate(model, tokenizer, args.prompt, args.max_tokens, args.temp)
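The script's flags map directly onto the `generate()` call above. As an illustrative sketch (the flag values here are arbitrary, not repo defaults), a sampled rather than greedy run might look like:

```bash
python generate.py \
  --repo TheBloke/Mistral-7B-v0.1-GGUF \
  --gguf mistral-7b-v0.1.Q4_0.gguf \
  --prompt "Write a quicksort in Python" \
  --max-tokens 256 \
  --temp 0.7
```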
llms/gguf_llm/models.py (new file)
@@ -0,0 +1,341 @@

# Copyright © 2023 Apple Inc.

import inspect
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import mlx.core as mx
import mlx.nn as nn
import numpy as np
import utils
from huggingface_hub import snapshot_download
from mlx.utils import tree_flatten, tree_unflatten


@dataclass
class ModelArgs:
    hidden_size: int
    num_hidden_layers: int
    intermediate_size: int
    num_attention_heads: int
    rms_norm_eps: float
    vocab_size: int
    num_key_value_heads: int = None
    rope_theta: float = 10000
    rope_traditional: bool = False
    model_type: str = None
    rope_scaling: Optional[Dict[str, Union[float, str]]] = None

    def __post_init__(self):
        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.num_attention_heads

        if self.rope_scaling:
            required_keys = {"factor", "type"}
            if not all(key in self.rope_scaling for key in required_keys):
                raise ValueError(f"rope_scaling must contain keys {required_keys}")

            if self.rope_scaling["type"] != "linear":
                raise ValueError("rope_scaling 'type' currently only supports 'linear'")

    @classmethod
    def from_dict(cls, params):
        return cls(
            **{
                k: v
                for k, v in params.items()
                if k in inspect.signature(cls).parameters
            }
        )


class RMSNorm(nn.Module):
    def __init__(self, dims: int, eps: float = 1e-5):
        super().__init__()
        self.weight = mx.ones((dims,))
        self.eps = eps

    def _norm(self, x):
        return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)

    def __call__(self, x):
        output = self._norm(x.astype(mx.float32)).astype(x.dtype)
        return self.weight * output


class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        dim = args.hidden_size
        self.n_heads = n_heads = args.num_attention_heads
        self.n_kv_heads = n_kv_heads = args.num_key_value_heads

        self.repeats = n_heads // n_kv_heads

        head_dim = args.hidden_size // n_heads
        self.scale = head_dim**-0.5

        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
        rope_scale = (
            1 / args.rope_scaling["factor"]
            if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
            else 1
        )
        self.rope = nn.RoPE(
            head_dim,
            traditional=args.rope_traditional,
            base=args.rope_theta,
            scale=rope_scale,
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        B, L, D = x.shape

        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)

        # Prepare the queries, keys and values for the attention computation
        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

        def repeat(a):
            a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
            return a.reshape([B, self.n_heads, L, -1])

        if self.repeats > 1:
            keys, values = map(repeat, (keys, values))

        if cache is not None:
            key_cache, value_cache = cache
            queries = self.rope(queries, offset=key_cache.shape[2])
            keys = self.rope(keys, offset=key_cache.shape[2])
            keys = mx.concatenate([key_cache, keys], axis=2)
            values = mx.concatenate([value_cache, values], axis=2)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
        if mask is not None:
            scores += mask
        scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
        output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(output), (keys, values)


class MLP(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)

    def __call__(self, x) -> mx.array:
        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))


class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.num_attention_heads = args.num_attention_heads
        self.hidden_size = args.hidden_size
        self.self_attn = Attention(args)
        self.mlp = MLP(args.hidden_size, args.intermediate_size)
        self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
        self.args = args

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        r, cache = self.self_attn(self.input_layernorm(x), mask, cache)
        h = x + r
        r = self.mlp(self.post_attention_layernorm(h))
        out = h + r
        return out, cache


class LlamaModel(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.num_hidden_layers = args.num_hidden_layers
        assert self.vocab_size > 0
        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [
            TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
        ]
        self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        h = self.embed_tokens(inputs)

        mask = None
        if h.shape[1] > 1:
            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
            mask = mask.astype(h.dtype)

        if cache is None:
            cache = [None] * len(self.layers)

        for e, layer in enumerate(self.layers):
            h, cache[e] = layer(h, mask, cache[e])

        return self.norm(h), cache


class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.model = LlamaModel(args)
        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        out, cache = self.model(inputs, cache)
        return self.lm_head(out), cache


def get_config(metadata: dict):
    output = {
        "hidden_size": metadata["llama.embedding_length"],
        "num_hidden_layers": metadata["llama.block_count"],
        "num_attention_heads": metadata["llama.attention.head_count"],
        "intermediate_size": metadata["llama.feed_forward_length"],
        "num_key_value_heads": metadata["llama.attention.head_count_kv"],
        "rms_norm_eps": metadata["llama.attention.layer_norm_rms_epsilon"],
        "vocab_size": len(metadata["tokenizer.ggml.tokens"]),
        "rope_theta": metadata["llama.rope.freq_base"],
        "rope_traditional": True,
    }
    output = {k: v.item() if isinstance(v, mx.array) else v for k, v in output.items()}
    return output


class GGUFTokenizer:
    def __init__(self, metadata):
        self._tokenizer = utils.spm_tokenizer(metadata)

    def encode(self, s: str) -> mx.array:
        return mx.array([self._tokenizer.bos_id()] + self._tokenizer.encode(s))

    @property
    def eos_token_id(self):
        return self._tokenizer.eos_id()

    def decode(self, toks: List[int]) -> str:
        return self._tokenizer.decode(toks)
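# Map GGUF tensor names to the Hugging Face-style module paths used by the
# classes above, e.g. "blk.0.attn_q.weight" -> "model.layers.0.self_attn.q_proj.weight".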
def translate_weight_names(name):
    name = name.replace("blk.", "model.layers.")
    name = name.replace("ffn_gate", "mlp.gate_proj")
    name = name.replace("ffn_down", "mlp.down_proj")
    name = name.replace("ffn_up", "mlp.up_proj")
    name = name.replace("attn_q", "self_attn.q_proj")
    name = name.replace("attn_k", "self_attn.k_proj")
    name = name.replace("attn_v", "self_attn.v_proj")
    name = name.replace("attn_output", "self_attn.o_proj")
    name = name.replace("attn_norm", "input_layernorm")
    name = name.replace("ffn_norm", "post_attention_layernorm")
    name = name.replace("token_embd", "model.embed_tokens")
    name = name.replace("output_norm", "model.norm")
    name = name.replace("output", "lm_head")
    return name


def load(gguf_file: str, repo: str = None):
    # If the gguf_file exists, try to load the model from it.
    # Otherwise try to download and cache it from the HF repo.
    if not Path(gguf_file).exists():
        if repo is None:
            raise ValueError(
                f"Could not find file {gguf_file}, and no Hugging Face"
                " repo provided for download."
            )
        model_path = snapshot_download(
            repo_id=repo,
            allow_patterns=[gguf_file],
        )
        if not (Path(model_path) / gguf_file).exists():
            raise ValueError(f"File {gguf_file} not in repo {repo}.")
        gguf_file = str(Path(model_path) / gguf_file)

    print(f"[INFO] Loading model from {gguf_file}")
    weights, metadata = mx.load(gguf_file, return_metadata=True)
    gguf_ft = metadata["general.file_type"]
    if gguf_ft == 0 or gguf_ft == 1:
        # ALL_F32 or MOSTLY_F16
        quantization = None
    elif gguf_ft == 2 or gguf_ft == 3:
        # MOSTLY_Q4_0 or MOSTLY_Q4_1
        quantization = {"group_size": 32, "bits": 4}
    elif gguf_ft == 7:
        # MOSTLY_Q8_0 = 7
        quantization = {"group_size": 32, "bits": 8}
    else:
        quantization = None
        print("[WARNING] Using unsupported GGUF quantization. Casting to float16.")

    weights = {translate_weight_names(k): v for k, v in weights.items()}
    config = get_config(metadata)
    model = Model(ModelArgs(**config))
    if quantization is not None:
        # Was the LM head quantized?
        qm = model if "lm_head.scales" in weights else model.model
        nn.QuantizedLinear.quantize_module(
            qm,
            **quantization,
        )

        def dequantize(k):
            weight = weights.pop(f"{k}.weight")
            scales = weights.pop(f"{k}.scales")
            biases = weights.pop(f"{k}.biases")
            weights[f"{k}.weight"] = mx.dequantize(
                weight, scales=scales, biases=biases, **quantization
            )

        # Dequantize embeddings
        dequantize("model.embed_tokens")

    tokenizer = GGUFTokenizer(metadata)
    model.load_weights(list(weights.items()))
    return model, tokenizer


def generate(prompt: mx.array, model: Model, temp: float = 0.0):
    def sample(logits):
        if temp == 0:
            return mx.argmax(logits, axis=-1)
        else:
            return mx.random.categorical(logits * (1 / temp))

    y = prompt
    cache = None
    while True:
        logits, cache = model(y[None], cache=cache)
        logits = logits[:, -1, :]
        y = sample(logits)
        yield y
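Taken together, `load` and `generate` can be driven directly from Python. Below is a minimal sketch, assuming `models.py` is importable and a GGUF file such as the Mistral one from the README is available locally or in the named repo; it mirrors what `generate.py` does rather than adding anything new:

```python
import models

# Load weights, config, and tokenizer from the GGUF file,
# downloading it from the Hugging Face repo if needed.
model, tokenizer = models.load(
    "mistral-7b-v0.1.Q8_0.gguf", "TheBloke/Mistral-7B-v0.1-GGUF"
)

prompt = tokenizer.encode("Write a quicksort in Python")

# Greedy decoding (temp=0.0), capped at 100 tokens.
tokens = []
for token, _ in zip(models.generate(prompt, model, temp=0.0), range(100)):
    if token == tokenizer.eos_token_id:
        break
    tokens.append(token.item())

print(tokenizer.decode(tokens))
```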
llms/gguf_llm/requirements.txt (new file)
@@ -0,0 +1,4 @@

mlx>=0.0.11
numpy
protobuf==3.20.0
sentencepiece
llms/gguf_llm/utils.py (new file)
@@ -0,0 +1,53 @@

import sentencepiece as spm
import sentencepiece.sentencepiece_model_pb2 as model


def spm_tokenizer(metadata):
    tokens = metadata["tokenizer.ggml.tokens"]
    bos = metadata["tokenizer.ggml.bos_token_id"].item()
    eos = metadata["tokenizer.ggml.eos_token_id"].item()
    unk = metadata["tokenizer.ggml.unknown_token_id"].item()

    normalizer_spec = model.NormalizerSpec(
        name="identity",
        precompiled_charsmap=b"",
        add_dummy_prefix=True,
        remove_extra_whitespaces=False,
        normalization_rule_tsv=b"",
    )
    trainer_spec = model.TrainerSpec(
        model_type="BPE",
        vocab_size=len(tokens),
        input_format="text",
        split_by_unicode_script=True,
        split_by_whitespace=True,
        split_by_number=True,
        treat_whitespace_as_suffix=False,
        split_digits=True,
        allow_whitespace_only_pieces=True,
        vocabulary_output_piece_score=True,
        byte_fallback=True,
        unk_id=unk,
        bos_id=bos,
        eos_id=eos,
        pad_id=-1,
        unk_piece="<unk>",
        bos_piece="<s>",
        eos_piece="</s>",
        pad_piece="<pad>",
        pretokenization_delimiter="",
    )
    m = model.ModelProto(trainer_spec=trainer_spec, normalizer_spec=normalizer_spec)
    scores = metadata.get("tokenizer.ggml.scores", None)
    scores = scores.tolist() if scores is not None else None
    token_types = metadata.get("tokenizer.ggml.token_type", None)
    token_types = token_types.tolist() if token_types is not None else None

    for i, token in enumerate(tokens):
        score = scores[i] if scores else 0
        token_type = token_types[i] if token_types else 0
        m.pieces.append(
            model.ModelProto.SentencePiece(piece=token, score=score, type=token_type)
        )
    tokenizer = spm.SentencePieceProcessor(model_proto=m.SerializeToString())
    return tokenizer
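The GGUF metadata alone is enough to rebuild the SentencePiece model. A small sketch of using `spm_tokenizer` directly, assuming a local GGUF file; this is what `GGUFTokenizer` in `models.py` wraps:

```python
import mlx.core as mx
import utils

# mx.load returns (weights, metadata); only the metadata is needed here.
_, metadata = mx.load("mistral-7b-v0.1.Q8_0.gguf", return_metadata=True)

sp = utils.spm_tokenizer(metadata)
ids = sp.encode("Hello, world!")
print(ids)
print(sp.decode(ids))
```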