14 Commits

Author SHA1 Message Date
Angelos Katharopoulos
b9eff0d744 Improve printing for FLUX distributed training 2025-01-13 22:47:54 -08:00
Awni Hannun
c117af83b8 fix gpt bigcode (#1204) 2025-01-13 10:22:32 -08:00
Chime Ogbuji
0228c46434 Custom local dataset features (#1085)
* Generalize prompt_feature and completion_feature for use in local datasets to facilitate compatibility with many other training dataset formats.

* Persist configured prompt/completion key

* rebase + nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-01-13 10:01:18 -08:00
Prince Canuma
bf2da36fc6 Fix Cohere2: mask shape error (long context) (#1202)
* fix mask shape error (long context)

* Update llms/mlx_lm/models/cohere2.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* revert layer_idx

* black formatting

* Update cohere2.py

* format

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2025-01-12 12:58:08 -08:00
Xingjun.Wang
514502da22 Support snapshot_download for ModelScope (#1194)
* add MLX_USE_MODELSCOPE env

* update

* update snapshot_download

* update

* remove modelscope dependency and add import check

* update

* nits

* fix

---------

Co-authored-by: wangxingjun778 <jason@U-C7X6TX5G-2239.local>
Co-authored-by: Awni Hannun <awni@apple.com>
2025-01-10 15:29:34 -08:00
Awni Hannun
93c5cfd781 Add a speculative decoding generator (#1155)
* add a speculative decoding generator

* fix

* fixes

* optional kwarg pop
2025-01-10 15:27:08 -08:00
Awni Hannun
5cae0a60e6 deepseek v3 model with pipeline parallelism (#1191)
* deepseekv3

* use upload_large_file instead of deprecated multi comit

* add pipeline generation and example

* comment

* get fp16 working

* use mlx==0.22
2025-01-09 15:55:53 -08:00
Jarrett
40b88eff48 fix(lora): config yaml & arg default merge bug (#1196) 2025-01-09 11:33:54 -08:00
Pedro Cuenca
b8f0cacfa8 Use upload_large_folder (#1193) 2025-01-07 09:18:31 -08:00
Awni Hannun
9183fe8b6d fix (#1192) 2025-01-06 10:12:07 -08:00
Chime Ogbuji
f2619f507c Add support for fewshot and apply chat template lm_eval functionality (#1180)
* Add support for multiturn fewshot examples and chat templates

Added two new arguments to the evaluation script: `--fewshot-as-multiturn` and `--apply-chat-template` which correspond to lm_eval options of similar names and are very often used to ensure apples-to-apples comparisons of lm_evaluation results

* Add HF overrides for methods needed by added options

* don't add duplicate bos

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-01-06 07:58:43 -08:00
Angelos Katharopoulos
25ec2d8c44 Change the eos-token argument for mlx_lm.generate (#1176) 2025-01-05 22:26:05 -08:00
Awni Hannun
c4833a2f55 fix encoding with special tokens + chat template (#1189) 2025-01-03 10:50:59 -08:00
Ivan Fioravanti
3a58c36109 Improvements to mlx_lm.manage (#1178)
* improvements to manage. Default value is N and size added to deletion confirmation.

* Fixing case for no case

* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-01-01 07:25:57 -08:00
28 changed files with 1040 additions and 648 deletions

View File

@@ -32,7 +32,7 @@ jobs:
pip install --upgrade pip pip install --upgrade pip
pip install unittest-xml-reporting pip install unittest-xml-reporting
cd llms/ cd llms/
pip install -e ".[testing]" pip install -e ".[test]"
- run: - run:
name: Run Python tests name: Run Python tests
command: | command: |

View File

@@ -261,19 +261,23 @@ if __name__ == "__main__":
generate_progress_images(0, flux, args) generate_progress_images(0, flux, args)
grads = None grads = None
losses = [] batch_cnt = 0
total_loss = 0
tic = time.time() tic = time.time()
for i, batch in zip(range(args.iterations), trainer.iterate(args.batch_size)): for i, batch in zip(range(args.iterations), trainer.iterate(args.batch_size)):
loss, grads = step(*batch, guidance, grads, (i + 1) % args.grad_accumulate == 0) loss, grads = step(*batch, guidance, grads, (i + 1) % args.grad_accumulate == 0)
mx.eval(loss, grads, state) total_loss = total_loss + loss
losses.append(loss.item()) batch_cnt += 1
mx.eval(total_loss, grads, state)
if (i + 1) % 10 == 0: if (i + 1) % 10 == 0 and mx.distributed.init().rank() == 0:
toc = time.time() toc = time.time()
peak_mem = mx.metal.get_peak_memory() / 1024**3 peak_mem = mx.metal.get_peak_memory() / 1024**3
total_loss = mx.distributed.all_sum(total_loss, stream=mx.cpu)
total_loss = total_loss.item() / batch_cnt
print( print(
f"Iter: {i + 1} Loss: {sum(losses) / 10:.3f} " f"Iter: {i + 1} Loss: {total_loss:.3f} "
f"It/s: {10 / (toc - tic):.3f} " f"It/s: {batch_cnt / (toc - tic):.3f} "
f"Peak mem: {peak_mem:.3f} GB", f"Peak mem: {peak_mem:.3f} GB",
flush=True, flush=True,
) )
@@ -285,7 +289,8 @@ if __name__ == "__main__":
save_adapters(f"{i + 1:07d}_adapters.safetensors", flux, args) save_adapters(f"{i + 1:07d}_adapters.safetensors", flux, args)
if (i + 1) % 10 == 0: if (i + 1) % 10 == 0:
losses = [] total_loss = 0
batch_cnt = 0
tic = time.time() tic = time.time()
save_adapters("final_adapters.safetensors", flux, args) save_adapters("final_adapters.safetensors", flux, args)

View File

@@ -58,7 +58,7 @@ prompt = "Write a story about Einstein"
messages = [{"role": "user", "content": prompt}] messages = [{"role": "user", "content": prompt}]
prompt = tokenizer.apply_chat_template( prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True messages, add_generation_prompt=True
) )
text = generate(model, tokenizer, prompt=prompt, verbose=True) text = generate(model, tokenizer, prompt=prompt, verbose=True)
@@ -115,7 +115,7 @@ prompt = "Write a story about Einstein"
messages = [{"role": "user", "content": prompt}] messages = [{"role": "user", "content": prompt}]
prompt = tokenizer.apply_chat_template( prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True messages, add_generation_prompt=True
) )
for response in stream_generate(model, tokenizer, prompt, max_tokens=512): for response in stream_generate(model, tokenizer, prompt, max_tokens=512):

View File

@@ -241,14 +241,25 @@ Refer to the documentation for the model you are fine-tuning for more details.
{"prompt": "What is the capital of France?", "completion": "Paris."} {"prompt": "What is the capital of France?", "completion": "Paris."}
``` ```
For the `completions` data format, a different key can be used for the prompt
and completion by specifying the following in the YAML config:
```yaml
prompt_feature: "input"
completion_feature: "output"
```
Here, `"input"` is the expected key instead of the default `"prompt"`, and
`"output"` is the expected key instead of `"completion"`.
`text`: `text`:
```jsonl ```jsonl
{"text": "This is an example for the model."} {"text": "This is an example for the model."}
``` ```
Note, the format is automatically determined by the dataset. Note also, keys in Note, the format is automatically determined by the dataset. Note also, keys
each line not expected by the loader will be ignored. in each line not expected by the loader will be ignored.
> [!NOTE] > [!NOTE]
> Each example in the datasets must be on a single line. Do not put more than > Each example in the datasets must be on a single line. Do not put more than
@@ -270,7 +281,7 @@ Otherwise, provide a mapping of keys in the dataset to the features MLX LM
expects. Use a YAML config to specify the Hugging Face dataset arguments. For expects. Use a YAML config to specify the Hugging Face dataset arguments. For
example: example:
``` ```yaml
hf_dataset: hf_dataset:
name: "billsum" name: "billsum"
prompt_feature: "text" prompt_feature: "text"

View File

@@ -1,3 +1,3 @@
# Copyright © 2023-2024 Apple Inc. # Copyright © 2023-2024 Apple Inc.
__version__ = "0.20.4" __version__ = "0.21.0"

View File

@@ -1,439 +0,0 @@
# Learned quantization using AWQ:
# References:
# AWQ
# https://arxiv.org/abs/2306.00978
# https://github.com/mit-han-lab/llm-awq
import argparse
import glob
import shutil
from pathlib import Path
from typing import Callable
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from datasets import Dataset, load_dataset
from mlx.utils import tree_flatten, tree_map_with_path
from mlx_lm.models.base import create_attention_mask
from mlx_lm.tokenizer_utils import TokenizerWrapper
from mlx_lm.utils import fetch_from_hub, get_model_path, save_config, save_weights
from tqdm import tqdm
def mse(x, y):
return ((x - y).astype(mx.float32)) ** 2
def run_layer(layer: nn.Module, x: mx.array, batch_size: int = 32, **kwargs):
y = []
for i in range(0, x.shape[0], batch_size):
y.append(layer(x[i : i + batch_size], **kwargs))
mx.eval(y)
y = mx.concatenate(y, axis=0)
return y
def dist_split(x: mx.array, group: mx.distributed.Group):
B = x.shape[0]
N = group.size()
assert B % N == 0
r = group.rank()
local_B = (B + N - 1) // N
return x[r * local_B : (r + 1) * local_B]
def search_best_scale(
layers: list[nn.Module],
x: mx.array,
quantize_func: Callable,
block: nn.Module | None = None,
layer_kwargs: dict | None = None,
n_grid: int = 20,
):
group = mx.distributed.init() if mx.distributed.is_available() else None
layer_kwargs = layer_kwargs or {}
block = block or layers[0]
out = block(x, **layer_kwargs)
x_max = x.abs().mean(axis=(0, 1))
best_error = float("inf")
best_scales = None
weights = tree_flatten(block.parameters())
for ratio in tqdm(range(n_grid)):
ratio = ratio * 1 / n_grid
scales = mx.maximum(x_max**ratio, 1e-4).reshape(-1)
scales = scales / (scales.max() * scales.min()).sqrt()
for layer in layers:
layer.weight = quantize_func(layer.weight * scales) / scales
out_q = run_layer(block, x, **layer_kwargs)
loss = mse(out, out_q).sum()
if group is not None:
loss = mx.distributed.all_sum(loss, stream=mx.cpu) / group.size()
loss /= out.size
mx.eval(loss)
is_best = loss < best_error
if is_best:
best_error = loss
best_scales = scales
# reload the original weights
block.load_weights(weights)
best_scales = best_scales.reshape(-1)
mx.eval(best_scales)
return best_scales
def apply_scale(prev_op, layers, scales):
# Apply the scales to the layers
if isinstance(prev_op, nn.Linear):
assert len(layers) == 1
prev_op.weight = prev_op.weight / scales[:, mx.newaxis]
if hasattr(prev_op, "bias"):
prev_op.bias = prev_op.bias / scales
layers[0].weight = layers[0].weight * scales[mx.newaxis]
elif isinstance(prev_op, (nn.LayerNorm, nn.RMSNorm)):
prev_op.weight = prev_op.weight / scales
if hasattr(prev_op, "bias"):
prev_op.bias = prev_op.bias / scales
for layer in layers:
layer.weight = layer.weight * scales
else:
raise NotImplementedError(f"Could not apply scale to prev_op: {prev_op}")
def scale_block(
block, input_feat, quantize_func: Callable, layer_kwargs: dict | None = None
):
layers = [
block.self_attn.q_proj,
block.self_attn.k_proj,
block.self_attn.v_proj,
]
scales = search_best_scale(
layers=layers,
block=block.self_attn,
x=input_feat["q_proj"],
quantize_func=quantize_func,
layer_kwargs=layer_kwargs,
)
apply_scale(block.input_layernorm, layers, scales)
for name in ["q_proj", "k_proj", "v_proj"]:
input_feat[name] = input_feat[name] / scales
layers = [
block.mlp.gate_proj,
block.mlp.up_proj,
]
scales = search_best_scale(
block=block.mlp,
layers=layers,
x=input_feat["gate_proj"],
quantize_func=quantize_func,
)
mlp_norm = getattr(
block, "pre_feedforward_layernorm", block.post_attention_layernorm
)
apply_scale(mlp_norm, layers, scales)
for name in ["gate_proj", "up_proj"]:
input_feat[name] = input_feat[name] / scales
layers = [block.mlp.down_proj]
scales = search_best_scale(
layers=layers,
x=input_feat["down_proj"],
quantize_func=quantize_func,
)
apply_scale(block.mlp.up_proj, layers, scales)
input_feat["down_proj"] = input_feat["down_proj"] / scales
def search_best_clip(
w: mx.array,
x: mx.array,
quantize_func: Callable,
group_size: int,
n_grid: int = 20,
max_shrink: float = 0.5,
subsample: int = 4,
batch_size: int = 64,
):
group = mx.distributed.init() if mx.distributed.is_available() else None
x = x[:, ::subsample]
x = x.reshape(*x.shape[:-1], -1, group_size)
w_all = w
w_max_all = []
w_min_all = []
for b in range(0, w.shape[0], batch_size):
w = w_all[b : b + batch_size]
group_shape = (w.shape[0], w.shape[-1] // group_size)
best_error = mx.full(group_shape, float("inf"))
best_w_max = mx.zeros((*group_shape, 1), dtype=x.dtype)
best_w_min = mx.zeros((*group_shape, 1), dtype=x.dtype)
w_shape = w.shape
w = w.reshape(*w.shape[:-1], -1, group_size)
out = mx.einsum("btdg,odg->btod", x, w)
for i in range(int(max_shrink * n_grid)):
p = 1 - i / n_grid
w_max = p * w.max(axis=-1, keepdims=True)
w_min = p * w.min(axis=-1, keepdims=True)
w_m = mx.clip(w, w_min, w_max).reshape(w_shape)
w_q = quantize_func(w_m)
w_q = w_q.reshape(*w_q.shape[:-1], -1, group_size)
out_q = mx.einsum("btdg,odg->btod", x, w_q)
# Take the mean across the input batch
loss = mse(out, out_q).sum(axis=(0, 1))
if group is not None:
loss = mx.distributed.all_sum(loss, stream=mx.cpu) / group.size()
loss /= out.shape[0] * out.shape[1]
best_indices = loss < best_error
best_error = mx.where(best_indices, loss, best_error)
best_w_max = mx.where(best_indices[..., mx.newaxis], w_max, best_w_max)
best_w_min = mx.where(best_indices[..., mx.newaxis], w_min, best_w_min)
mx.eval(best_w_max, best_w_min, best_error)
w_max_all.append(best_w_max)
w_min_all.append(best_w_min)
best_w_max = mx.concatenate(w_max_all, axis=0)
best_w_min = mx.concatenate(w_min_all, axis=0)
w_r = w_all.reshape(*w_all.shape[:-1], -1, group_size)
best_w = mx.clip(w_r, best_w_min, best_w_max)
best_w = best_w.reshape(w_all.shape)
mx.eval(best_w)
return best_w
def clip_block(
block: nn.Module,
input_feat: dict[str, mx.array],
quantize_func: Callable,
group_size: int,
):
def apply_clip(path, module):
if (
isinstance(module, nn.Linear)
and "q_proj" not in path
and "k_proj" not in path
):
name = path.split(".")[-1]
best_weight = search_best_clip(
module.weight,
input_feat[name],
quantize_func=quantize_func,
group_size=group_size,
)
module.weight = best_weight
tree_map_with_path(apply_clip, block.leaf_modules(), is_leaf=nn.Module.is_module)
def awq_quantize(
model,
inputs: mx.array,
group_size: int = 64,
bits: int = 3,
embed_group_size: int = 32,
embed_bits: int = 4,
):
group = mx.distributed.init() if mx.distributed.is_available() else None
def quantize_func(w):
wq = mx.quantize(w, bits=bits, group_size=group_size)
return mx.dequantize(*wq, bits=bits, group_size=group_size)
mask = create_attention_mask(inputs)
model.model.embed_tokens = model.model.embed_tokens.to_quantized(
group_size=embed_group_size, bits=embed_bits
)
inputs = model.model.embed_tokens(inputs)
input_feat = {}
def capture(path, module):
if not isinstance(module, nn.Linear):
return module
class Catcher(nn.Module):
def __call__(self, x: mx.array):
name = path.split(".")[-1]
input_feat[name] = x
return module(x)
return Catcher()
for i, layer in enumerate(model.model.layers):
import time
s = time.time()
print(f"Starting block {i}")
# capture the inputs to each layer
orig_leaves = layer.leaf_modules()
capture_leaves = tree_map_with_path(
capture, orig_leaves, is_leaf=nn.Module.is_module
)
layer.update_modules(capture_leaves)
outputs = run_layer(layer, inputs, mask=mask)
layer.update_modules(orig_leaves)
del capture_leaves
nn.quantize(layer, group_size=group_size, bits=bits)
outputs_q = run_layer(layer, inputs, mask=mask)
loss = mse(outputs, outputs_q).sum()
if group is not None:
loss = mx.distributed.all_sum(loss, stream=mx.cpu) / group.size()
loss /= outputs.size
print("Before Loss", loss, flush=True)
layer.update_modules(orig_leaves)
del orig_leaves
print("Scaling block", flush=True)
scale_block(
block=layer,
input_feat=input_feat,
quantize_func=quantize_func,
layer_kwargs={"mask": mask},
)
print("Clipping block", flush=True)
clip_block(
block=layer,
input_feat=input_feat,
quantize_func=quantize_func,
group_size=group_size,
)
nn.quantize(layer, group_size=group_size, bits=bits)
outputs_q = run_layer(layer, inputs, mask=mask)
loss = mse(outputs, outputs_q).sum()
if group is not None:
loss = mx.distributed.all_sum(loss, stream=mx.cpu) / group.size()
loss /= outputs.size
print("After Loss", loss, flush=True)
input_feat = {}
inputs = outputs
mx.eval(layer)
mx.metal.clear_cache()
e = time.time()
print("Loop time: ", e - s)
if hasattr(model, "lm_head"):
model.lm_head = model.lm_head.to_quantized(
group_size=embed_group_size, bits=embed_bits
)
def load_wikitext(
tokenizer, num_samples: int = 32, sequence_length: int = 2048, split: str = "train"
) -> mx.array:
dataset = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split=split)
texts = "\n\n".join(dataset["text"])
tokens = tokenizer.encode(texts, return_tensors="mlx")[0]
# Select random chunks
starts = mx.random.randint(
0, len(tokens) - sequence_length - 1, shape=(num_samples, 1)
)
data = tokens[starts + mx.arange(sequence_length)]
if tokenizer.bos_token_id:
data = mx.concatenate(
[mx.full((*data.shape[:2], 1), tokenizer.bos_token_id), data], axis=-1
)
return data
def save_model(
model: nn.Module,
tokenizer: TokenizerWrapper,
config,
model_path: Path,
mlx_path: str,
):
weights = dict(tree_flatten(model.parameters()))
mlx_path = Path(mlx_path)
save_weights(mlx_path, weights, donate_weights=True)
py_files = glob.glob(str(model_path / "*.py"))
for file in py_files:
shutil.copy(file, mlx_path)
tokenizer.save_pretrained(mlx_path)
config["quantization"] = {"group_size": 64, "bits": 4}
def update_config(path, module):
if hasattr(module, "bits"):
config["quantization"][path] = {
"group_size": module.group_size,
"bits": module.bits,
}
else:
config["quantization"][path] = False
tree_map_with_path(update_config, model.leaf_modules(), is_leaf=nn.Module.is_module)
save_config(config, config_path=mlx_path / "config.json")
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model", "-m", default="mlx-community/Qwen2.5-7B-Instruct-bf16"
)
parser.add_argument("--mlx-path", default="mlx_model")
parser.add_argument("--bits", type=int, default=3)
parser.add_argument("--group-size", type=int, default=64)
parser.add_argument("--num-samples", type=int, default=32)
parser.add_argument("--sequence-length", type=int, default=2048)
parser.add_argument("--seed", type=int, default=123)
args = parser.parse_args()
group = mx.distributed.init() if mx.distributed.is_available() else None
num_samples = args.num_samples
if group is not None and num_samples % group.size() > 0:
num_samples += group.size() - num_samples % group.size()
mx.random.seed(args.seed)
model_path = get_model_path(args.model, revision=None)
model, config, tokenizer = fetch_from_hub(model_path, lazy=True)
calibration_data = load_wikitext(tokenizer, args.num_samples, args.sequence_length)
if group is not None:
calibration_data = dist_split(calibration_data, group)
awq_quantize(model, calibration_data, bits=args.bits, group_size=args.group_size)
save_model(model, tokenizer, config, model_path, args.mlx_path)
if __name__ == "__main__":
main()

View File

@@ -110,29 +110,17 @@ def main():
if tokenizer.chat_template is None: if tokenizer.chat_template is None:
tokenizer.chat_template = tokenizer.default_chat_template tokenizer.chat_template = tokenizer.default_chat_template
if not args.ignore_chat_template and ( if not args.ignore_chat_template and tokenizer.chat_template is not None:
hasattr(tokenizer, "apply_chat_template")
and tokenizer.chat_template is not None
):
messages = [{"role": "user", "content": args.prompt}] messages = [{"role": "user", "content": args.prompt}]
prompt = tokenizer.apply_chat_template( prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True messages, add_generation_prompt=False, continue_final_message=True
) )
# Treat the prompt as a prefix assuming that the suffix will be
# provided at generation time.
test_prompt = tokenizer.apply_chat_template(
[{"role": "user", "content": "<query>"}],
tokenize=False,
add_generation_prompt=True,
)
n = len(test_prompt) - test_prompt.index("<query>") - len("<query>")
prompt = prompt[:-n]
else: else:
prompt = args.prompt prompt = tokenizer.encode(args.prompt)
cache = make_prompt_cache(model, args.max_kv_size) cache = make_prompt_cache(model, args.max_kv_size)
y = mx.array(tokenizer.encode(prompt)) y = mx.array(prompt)
# Process the prompt # Process the prompt
start = time.time() start = time.time()

View File

@@ -72,9 +72,7 @@ def main():
if query == "q": if query == "q":
break break
messages = [{"role": "user", "content": query}] messages = [{"role": "user", "content": query}]
prompt = tokenizer.apply_chat_template( prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
messages, tokenize=False, add_generation_prompt=True
)
for response in stream_generate( for response in stream_generate(
model, model,
tokenizer, tokenizer,

View File

@@ -1,4 +1,8 @@
# Adapted from a PyTorch implementation by David Grangier # Copyright © 2024 Apple Inc.
"""
Adapted from a PyTorch implementation by David Grangier
"""
import argparse import argparse
import json import json
@@ -6,7 +10,7 @@ import logging
import os import os
from importlib.metadata import version from importlib.metadata import version
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional, Union
import lm_eval import lm_eval
import mlx.core as mx import mlx.core as mx
@@ -73,15 +77,19 @@ class MLXLM(LM):
path_or_hf_repo: str, path_or_hf_repo: str,
batch_size: int = 16, batch_size: int = 16,
max_tokens: Optional[int] = None, max_tokens: Optional[int] = None,
use_chat_template: Optional[bool] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self._batch_size = batch_size self._batch_size = batch_size
self._model, self._tokenizer = load(path_or_hf_repo) self._model, self.tokenizer = load(path_or_hf_repo)
self._max_tokens = max_tokens or self._tokenizer.model_max_length self._max_tokens = max_tokens or self.tokenizer.model_max_length
self.use_chat_template = use_chat_template or (
self.tokenizer.chat_template is not None
)
def _score_fn(self, inputs, tokenize=True, step_size=32): def _score_fn(self, inputs, tokenize=True, step_size=32):
if tokenize: if tokenize:
inputs = self._tokenizer.encode(inputs) inputs = self._tokenize(inputs)
inputs = _pad_inputs(inputs, self._max_tokens, truncate=False) inputs = _pad_inputs(inputs, self._max_tokens, truncate=False)
inputs = mx.array(inputs) inputs = mx.array(inputs)
inputs, targets = inputs[..., :-1], inputs[..., 1:] inputs, targets = inputs[..., :-1], inputs[..., 1:]
@@ -145,7 +153,12 @@ class MLXLM(LM):
return results return results
def _tokenize(self, texts): def _tokenize(self, texts):
return [tuple(self._tokenizer.encode(t)) for t in texts] return [
tuple(
self.tokenizer.encode(t, add_special_tokens=not self.use_chat_template)
)
for t in texts
]
def loglikelihood(self, requests) -> list[tuple[float, bool]]: def loglikelihood(self, requests) -> list[tuple[float, bool]]:
"""Compute log-likelihood of generating a continuation from a context. """Compute log-likelihood of generating a continuation from a context.
@@ -217,6 +230,9 @@ class MLXLM(LM):
) )
return [(r[0], r[1] == r[2]) for r in results] return [(r[0], r[1] == r[2]) for r in results]
tokenizer_name = lm_eval.models.huggingface.HFLM.tokenizer_name
apply_chat_template = lm_eval.models.huggingface.HFLM.apply_chat_template
def loglikelihood_rolling(self, requests) -> list[float]: def loglikelihood_rolling(self, requests) -> list[float]:
"""Compute full log-likelihood of a string, with no truncation, for perplexity computation """Compute full log-likelihood of a string, with no truncation, for perplexity computation
- We will use the full max context length of the model. - We will use the full max context length of the model.
@@ -277,23 +293,16 @@ class MLXLM(LM):
assert "until" in keys assert "until" in keys
untils = [x["until"] for x in options] untils = [x["until"] for x in options]
completions = [] completions = []
for context, until in tqdm(zip(contexts, untils), total=len(contexts)):
if (
hasattr(self._tokenizer, "apply_chat_template")
and self._tokenizer.chat_template is not None
):
messages = [{"role": "user", "content": context}]
context = self._tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
for context, until in tqdm(zip(contexts, untils), total=len(contexts)):
context = self._tokenize(context)
max_tokens = min( max_tokens = min(
self._max_tokens, self._max_tokens,
self._tokenizer.model_max_length - len(self._tokenizer.encode(context)), self.tokenizer.model_max_length - len(context),
) )
text = "" text = ""
for response in stream_generate( for response in stream_generate(
self._model, self._tokenizer, prompt=context, max_tokens=max_tokens self._model, self.tokenizer, prompt=context, max_tokens=max_tokens
): ):
text += response.text text += response.text
if any(u in text for u in until): if any(u in text for u in until):
@@ -321,7 +330,28 @@ def main():
type=int, type=int,
help="Maximum nunber of tokens to generate. Defaults to the model's max context length.", help="Maximum nunber of tokens to generate. Defaults to the model's max context length.",
) )
parser.add_argument(
"--limit",
default=1.0,
help="Limit the number of examples per task.",
type=float,
)
parser.add_argument("--seed", type=int, default=123, help="Random seed.") parser.add_argument("--seed", type=int, default=123, help="Random seed.")
parser.add_argument(
"--fewshot-as-multiturn",
action="store_true",
help="Whether to provide the fewshot examples as a multiturn "
"conversation or a single user turn.",
default=False,
)
parser.add_argument(
"--apply-chat-template",
action=argparse.BooleanOptionalAction,
help="Specifies whether to apply a chat template to the prompt. If "
"the model has a chat template, this defaults to `True`, "
"otherwise `False`.",
default=None,
)
args = parser.parse_args() args = parser.parse_args()
output_dir = Path(args.output_dir) output_dir = Path(args.output_dir)
@@ -332,12 +362,19 @@ def main():
mx.random.seed(args.seed) mx.random.seed(args.seed)
lm = MLXLM(args.model, batch_size=args.batch_size, max_tokens=args.max_tokens) lm = MLXLM(
args.model,
batch_size=args.batch_size,
max_tokens=args.max_tokens,
use_chat_template=args.apply_chat_template,
)
results = lm_eval.simple_evaluate( results = lm_eval.simple_evaluate(
model=lm, model=lm,
tasks=args.tasks, tasks=args.tasks,
fewshot_as_multiturn=args.fewshot_as_multiturn,
apply_chat_template=lm.use_chat_template,
num_fewshot=args.num_shots, num_fewshot=args.num_shots,
limit=args.limit,
random_seed=args.seed, random_seed=args.seed,
numpy_random_seed=args.seed, numpy_random_seed=args.seed,
torch_random_seed=args.seed, torch_random_seed=args.seed,

View File

@@ -15,9 +15,7 @@ prompt_cache = make_prompt_cache(model)
# User turn # User turn
prompt = "Hi my name is <Name>." prompt = "Hi my name is <Name>."
messages = [{"role": "user", "content": prompt}] messages = [{"role": "user", "content": prompt}]
prompt = tokenizer.apply_chat_template( prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
messages, tokenize=False, add_generation_prompt=True
)
# Assistant response # Assistant response
response = generate( response = generate(
@@ -32,9 +30,7 @@ response = generate(
# User turn # User turn
prompt = "What's my name?" prompt = "What's my name?"
messages = [{"role": "user", "content": prompt}] messages = [{"role": "user", "content": prompt}]
prompt = tokenizer.apply_chat_template( prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
messages, tokenize=False, add_generation_prompt=True
)
# Assistant response # Assistant response
response = generate( response = generate(

View File

@@ -14,7 +14,7 @@ conversation = [{"role": "user", "content": prompt}]
# Transform the prompt into the chat template # Transform the prompt into the chat template
prompt = tokenizer.apply_chat_template( prompt = tokenizer.apply_chat_template(
conversation=conversation, tokenize=False, add_generation_prompt=True conversation=conversation, add_generation_prompt=True
) )
# Specify the maximum number of tokens # Specify the maximum number of tokens

View File

@@ -0,0 +1,75 @@
# Copyright © 2024 Apple Inc.
"""
Run with:
```
/path/to/mpirun \
-np 2 \
--hostfile /path/to/hosts.txt \
python /path/to/pipeline_generate.py --prompt "hello world"
```
Make sure you can run MLX over MPI on two hosts. For more information see the
documentation:
https://ml-explore.github.io/mlx/build/html/usage/distributed.html).
"""
import argparse
import mlx.core as mx
from mlx_lm import load, stream_generate
parser = argparse.ArgumentParser(description="LLM pipelined inference example")
parser.add_argument(
"--prompt",
"-p",
default="Write a quicksort in C++.",
help="Message to be processed by the model ('-' reads from stdin)",
)
parser.add_argument(
"--max-tokens",
"-m",
type=int,
default=256,
help="Maximum number of tokens to generate",
)
args = parser.parse_args()
model_repo = "mlx-community/DeepSeek-V3-3bit"
model, tokenizer = load(model_repo, lazy=True)
messages = [{"role": "user", "content": args.prompt}]
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
group = mx.distributed.init()
rank = group.rank()
model.model.pipeline(group)
mx.eval(model.parameters())
# Synchronize processes before generation to avoid timeout if downloading
# model for the first time.
mx.eval(mx.distributed.all_sum(mx.array(1.0), stream=mx.cpu))
def rprint(*args, **kwargs):
if rank == 0:
print(*args, **kwargs)
for response in stream_generate(model, tokenizer, prompt, max_tokens=args.max_tokens):
rprint(response.text, end="", flush=True)
rprint()
rprint("=" * 10)
rprint(
f"Prompt: {response.prompt_tokens} tokens, "
f"{response.prompt_tps:.3f} tokens-per-sec"
)
rprint(
f"Generation: {response.generation_tokens} tokens, "
f"{response.generation_tps:.3f} tokens-per-sec"
)
rprint(f"Peak memory: {response.peak_memory:.3f} GB")

View File

@@ -43,10 +43,11 @@ def setup_arg_parser():
help="Optional path for the trained adapter weights and config.", help="Optional path for the trained adapter weights and config.",
) )
parser.add_argument( parser.add_argument(
"--eos-token", "--extra-eos-token",
type=str, type=str,
default=None, default=(),
help="End of sequence token for tokenizer", nargs="+",
help="Add tokens in the list of eos tokens that stop generation.",
) )
parser.add_argument( parser.add_argument(
"--system-prompt", "--system-prompt",
@@ -130,6 +131,18 @@ def setup_arg_parser():
type=int, type=int,
default=DEFAULT_QUANTIZED_KV_START, default=DEFAULT_QUANTIZED_KV_START,
) )
parser.add_argument(
"--draft-model",
type=str,
help="A model to be used for speculative decoding.",
default=None,
)
parser.add_argument(
"--num-draft-tokens",
type=int,
help="Number of tokens to draft when using speculative decoding.",
default=2,
)
return parser return parser
@@ -161,8 +174,6 @@ def main():
{} if not using_cache else json.loads(metadata["tokenizer_config"]) {} if not using_cache else json.loads(metadata["tokenizer_config"])
) )
tokenizer_config["trust_remote_code"] = True tokenizer_config["trust_remote_code"] = True
if args.eos_token is not None:
tokenizer_config["eos_token"] = args.eos_token
model_path = args.model model_path = args.model
if using_cache: if using_cache:
@@ -181,6 +192,8 @@ def main():
adapter_path=args.adapter_path, adapter_path=args.adapter_path,
tokenizer_config=tokenizer_config, tokenizer_config=tokenizer_config,
) )
for eos_token in args.extra_eos_token:
tokenizer.add_eos_token(eos_token)
if args.use_default_chat_template: if args.use_default_chat_template:
if tokenizer.chat_template is None: if tokenizer.chat_template is None:
@@ -190,10 +203,7 @@ def main():
prompt = args.prompt.replace("\\n", "\n").replace("\\t", "\t") prompt = args.prompt.replace("\\n", "\n").replace("\\t", "\t")
prompt = sys.stdin.read() if prompt == "-" else prompt prompt = sys.stdin.read() if prompt == "-" else prompt
if not args.ignore_chat_template and ( if not args.ignore_chat_template and tokenizer.chat_template is not None:
hasattr(tokenizer, "apply_chat_template")
and tokenizer.chat_template is not None
):
if args.system_prompt is not None: if args.system_prompt is not None:
messages = [{"role": "system", "content": args.system_prompt}] messages = [{"role": "system", "content": args.system_prompt}]
else: else:
@@ -213,7 +223,16 @@ def main():
add_generation_prompt=True, add_generation_prompt=True,
) )
prompt = prompt[test_prompt.index("<query>") :] prompt = prompt[test_prompt.index("<query>") :]
prompt = tokenizer.encode(prompt, add_special_tokens=False)
else:
prompt = tokenizer.encode(prompt)
if args.draft_model is not None:
draft_model, draft_tokenizer = load(args.draft_model)
if draft_tokenizer.vocab_size != tokenizer.vocab_size:
raise ValueError("Draft model tokenizer does not match model tokenizer.")
else:
draft_model = None
sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep) sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep)
response = generate( response = generate(
model, model,
@@ -227,6 +246,8 @@ def main():
kv_bits=args.kv_bits, kv_bits=args.kv_bits,
kv_group_size=args.kv_group_size, kv_group_size=args.kv_group_size,
quantized_kv_start=args.quantized_kv_start, quantized_kv_start=args.quantized_kv_start,
draft_model=draft_model,
num_draft_tokens=args.num_draft_tokens,
) )
if not args.verbose: if not args.verbose:
print(response) print(response)

View File

@@ -2,6 +2,7 @@
import argparse import argparse
import math import math
import os
import re import re
import types import types
from pathlib import Path from pathlib import Path
@@ -57,6 +58,8 @@ CONFIG_DEFAULTS = {
"test": False, "test": False,
"test_batches": 500, "test_batches": 500,
"max_seq_length": 2048, "max_seq_length": 2048,
"config": None,
"grad_checkpoint": False,
"lr_schedule": None, "lr_schedule": None,
"lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0}, "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
} }
@@ -66,6 +69,7 @@ def build_parser():
parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.") parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")
parser.add_argument( parser.add_argument(
"--model", "--model",
type=str,
help="The path to the local model directory or Hugging Face repo.", help="The path to the local model directory or Hugging Face repo.",
) )
@@ -74,7 +78,6 @@ def build_parser():
"--train", "--train",
action="store_true", action="store_true",
help="Do training", help="Do training",
default=None,
) )
parser.add_argument( parser.add_argument(
"--data", "--data",
@@ -88,7 +91,6 @@ def build_parser():
"--fine-tune-type", "--fine-tune-type",
type=str, type=str,
choices=["lora", "dora", "full"], choices=["lora", "dora", "full"],
default="lora",
help="Type of fine-tuning to perform: lora, dora, or full.", help="Type of fine-tuning to perform: lora, dora, or full.",
) )
parser.add_argument( parser.add_argument(
@@ -133,7 +135,6 @@ def build_parser():
"--test", "--test",
action="store_true", action="store_true",
help="Evaluate on the test set after training", help="Evaluate on the test set after training",
default=None,
) )
parser.add_argument( parser.add_argument(
"--test-batches", "--test-batches",
@@ -148,16 +149,15 @@ def build_parser():
parser.add_argument( parser.add_argument(
"-c", "-c",
"--config", "--config",
default=None, type=str,
help="A YAML configuration file with the training options", help="A YAML configuration file with the training options",
) )
parser.add_argument( parser.add_argument(
"--grad-checkpoint", "--grad-checkpoint",
action="store_true", action="store_true",
help="Use gradient checkpointing to reduce memory use.", help="Use gradient checkpointing to reduce memory use.",
default=None,
) )
parser.add_argument("--seed", type=int, default=None, help="The PRNG seed") parser.add_argument("--seed", type=int, help="The PRNG seed")
return parser return parser
@@ -271,6 +271,7 @@ def run(args, training_callback: TrainingCallback = None):
def main(): def main():
os.environ["TOKENIZERS_PARALLELISM"] = "true"
parser = build_parser() parser = build_parser()
args = parser.parse_args() args = parser.parse_args()
config = args.config config = args.config

View File

@@ -6,19 +6,18 @@ from transformers.commands.user import tabulate
def ask_for_confirmation(message: str) -> bool: def ask_for_confirmation(message: str) -> bool:
"""Ask user for confirmation with Y/N prompt.
Returns True for Y/yes, False for N/no/empty."""
y = ("y", "yes", "1") y = ("y", "yes", "1")
n = ("n", "no", "0") n = ("n", "no", "0", "")
all_values = y + n + ("",) full_message = f"{message} (y/n) "
full_message = f"{message} (Y/n) "
while True: while True:
answer = input(full_message).lower() answer = input(full_message).lower()
if answer == "":
return False
if answer in y: if answer in y:
return True return True
if answer in n: if answer in n:
return False return False
print(f"Invalid input. Must be one of {all_values}") print(f"Invalid input. Must be one of: yes/no/y/n or empty for no")
def main(): def main():
@@ -43,9 +42,7 @@ def main():
args = parser.parse_args() args = parser.parse_args()
if args.scan: if args.scan:
print( print(f'Scanning Hugging Face cache for models with pattern "{args.pattern}".')
"Scanning Hugging Face cache for models with" f'pattern "{args.pattern}".'
)
hf_cache_info = scan_cache_dir() hf_cache_info = scan_cache_dir()
print( print(
tabulate( tabulate(
@@ -86,35 +83,41 @@ def main():
if args.pattern in repo.repo_id if args.pattern in repo.repo_id
] ]
if repos: if repos:
print("\nFound the following models:")
print( print(
tabulate( tabulate(
rows=[ rows=[
[ [
repo.repo_id, repo.repo_id,
repo.size_on_disk_str, # Added size information
str(repo.repo_path), str(repo.repo_path),
] ]
for repo in repos for repo in repos
], ],
headers=[ headers=[
"REPO ID", "REPO ID",
"SIZE", # Added size header
"LOCAL PATH", "LOCAL PATH",
], ],
) )
) )
confirmed = ask_for_confirmation(f"Confirm deletion ?") confirmed = ask_for_confirmation(
"\nAre you sure you want to delete these models?"
)
if confirmed: if confirmed:
for model_info in repos: for model_info in repos:
print(f"\nDeleting {model_info.repo_id}...")
for revision in sorted( for revision in sorted(
model_info.revisions, key=lambda revision: revision.commit_hash model_info.revisions, key=lambda revision: revision.commit_hash
): ):
strategy = hf_cache_info.delete_revisions(revision.commit_hash) strategy = hf_cache_info.delete_revisions(revision.commit_hash)
strategy.execute() strategy.execute()
print("Model(s) deleted.") print("\nModel(s) deleted successfully.")
else: else:
print("Deletion is cancelled. Do nothing.") print("\nDeletion cancelled - no changes made.")
else: else:
print(f"No models found.") print(f'No models found matching pattern "{args.pattern}"')
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -156,12 +156,13 @@ class CohereModel(nn.Module):
): ):
h = self.embed_tokens(inputs) h = self.embed_tokens(inputs)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None: if cache is None:
cache = [None] * len(self.layers) cache = [None] * len(self.layers)
if mask is None:
j = self.args.sliding_window_pattern
mask = create_attention_mask(h, cache[j - 1 : j])
for layer, c in zip(self.layers, cache): for layer, c in zip(self.layers, cache):
h = layer(h, mask, c) h = layer(h, mask, c)

View File

@@ -0,0 +1,460 @@
# Copyright © 2024 Apple Inc.
import math
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .switch_layers import SwitchGLU
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str = "deepseek_v3"
vocab_size: int = 102400
hidden_size: int = 4096
intermediate_size: int = 11008
moe_intermediate_size: int = 1407
num_hidden_layers: int = 30
num_attention_heads: int = 32
num_key_value_heads: int = 32
n_shared_experts: Optional[int] = None
n_routed_experts: Optional[int] = None
routed_scaling_factor: float = 1.0
kv_lora_rank: int = 512
q_lora_rank: int = 1536
qk_rope_head_dim: int = 64
v_head_dim: int = 128
qk_nope_head_dim: int = 128
topk_method: str = "noaux_tc"
scoring_func: str = "sigmoid"
norm_topk_prob: bool = True
n_group: Optional[int] = None
topk_group: Optional[int] = None
num_experts_per_tok: Optional[int] = None
moe_layer_freq: int = 1
first_k_dense_replace: int = 0
max_position_embeddings: int = 2048
rms_norm_eps: float = 1e-6
rope_theta: float = 10000.0
rope_scaling: Dict = None
attention_bias: bool = False
def yarn_find_correction_dim(
num_rotations, dim, base=10000, max_position_embeddings=2048
):
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
2 * math.log(base)
)
def yarn_find_correction_range(
low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
):
low = math.floor(
yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
)
high = math.ceil(
yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
)
return max(low, 0), min(high, dim - 1)
def yarn_get_mscale(scale=1, mscale=1):
if scale <= 1:
return 1.0
return 0.1 * mscale * math.log(scale) + 1.0
def yarn_linear_ramp_mask(min_val, max_val, dim):
if min_val == max_val:
max_val += 0.001 # Prevent singularity
linear_func = (mx.arange(dim, dtype=mx.float32) - min_val) / (max_val - min_val)
return mx.clip(linear_func, 0, 1)
class DeepseekV3YarnRotaryEmbedding(nn.Module):
def __init__(
self,
dim,
max_position_embeddings=2048,
base=10000,
scaling_factor=1.0,
original_max_position_embeddings=4096,
beta_fast=32,
beta_slow=1,
mscale=1,
mscale_all_dim=0,
):
super().__init__()
self.mscale = yarn_get_mscale(scaling_factor, mscale) / yarn_get_mscale(
scaling_factor, mscale_all_dim
)
freq_extra = base ** (mx.arange(0, dim, 2, dtype=mx.float32) / dim)
freq_inter = scaling_factor * base ** (
mx.arange(0, dim, 2, dtype=mx.float32) / dim
)
low, high = yarn_find_correction_range(
beta_fast,
beta_slow,
dim,
base,
original_max_position_embeddings,
)
freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2)
self._freqs = (freq_inter * freq_extra) / (
freq_inter * freq_mask + freq_extra * (1 - freq_mask)
)
def __call__(self, x, offset=0):
if self.mscale != 1.0:
x = self.mscale * x
return mx.fast.rope(
x,
x.shape[-1],
traditional=True,
base=None,
scale=1.0,
offset=offset,
freqs=self._freqs,
)
class DeepseekV3Attention(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.q_lora_rank = config.q_lora_rank
self.qk_rope_head_dim = config.qk_rope_head_dim
self.kv_lora_rank = config.kv_lora_rank
self.v_head_dim = config.v_head_dim
self.qk_nope_head_dim = config.qk_nope_head_dim
self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
self.scale = self.q_head_dim**-0.5
if self.q_lora_rank is None:
self.q_proj = nn.Linear(
self.hidden_size, self.num_heads * self.q_head_dim, bias=False
)
else:
self.q_a_proj = nn.Linear(
self.hidden_size, self.q_lora_rank, bias=config.attention_bias
)
self.q_a_layernorm = nn.RMSNorm(self.q_lora_rank)
self.q_b_proj = nn.Linear(
self.q_lora_rank, self.num_heads * self.q_head_dim, bias=False
)
self.kv_a_proj_with_mqa = nn.Linear(
self.hidden_size,
self.kv_lora_rank + self.qk_rope_head_dim,
bias=config.attention_bias,
)
self.kv_a_layernorm = nn.RMSNorm(self.kv_lora_rank)
self.kv_b_proj = nn.Linear(
self.kv_lora_rank,
self.num_heads
* (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
bias=False,
)
self.o_proj = nn.Linear(
self.num_heads * self.v_head_dim,
self.hidden_size,
bias=config.attention_bias,
)
mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
scaling_factor = self.config.rope_scaling["factor"]
if mscale_all_dim:
mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
self.scale = self.scale * mscale * mscale
rope_kwargs = {
key: self.config.rope_scaling[key]
for key in [
"original_max_position_embeddings",
"beta_fast",
"beta_slow",
"mscale",
"mscale_all_dim",
]
if key in self.config.rope_scaling
}
self.rope = DeepseekV3YarnRotaryEmbedding(
dim=self.qk_rope_head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta,
**rope_kwargs,
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
if self.q_lora_rank is None:
q = self.q_proj(x)
else:
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(x)))
q = q.reshape(B, L, self.num_heads, self.q_head_dim).transpose(0, 2, 1, 3)
q_nope, q_pe = mx.split(q, [self.qk_nope_head_dim], axis=-1)
compressed_kv = self.kv_a_proj_with_mqa(x)
compressed_kv, k_pe = mx.split(compressed_kv, [self.kv_lora_rank], axis=-1)
k_pe = k_pe.reshape(B, L, 1, self.qk_rope_head_dim).transpose(0, 2, 1, 3)
kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
kv = kv.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1)
if cache is not None:
q_pe = self.rope(q_pe, cache.offset)
k_pe = self.rope(k_pe, cache.offset)
k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
keys, values = cache.update_and_fetch(
mx.concatenate([k_nope, k_pe], axis=-1), values
)
else:
q_pe = self.rope(q_pe)
k_pe = self.rope(k_pe)
k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
keys = mx.concatenate([k_nope, k_pe], axis=-1)
queries = mx.concatenate([q_nope, q_pe], axis=-1)
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
class DeepseekV3MLP(nn.Module):
def __init__(
self, config: ModelArgs, hidden_size: int = None, intermediate_size: int = None
):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
self.intermediate_size = (
config.intermediate_size if intermediate_size is None else intermediate_size
)
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
def __call__(self, x):
down_proj = self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
return down_proj
class MoEGate(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.config = config
self.top_k = config.num_experts_per_tok
self.norm_topk_prob = config.norm_topk_prob
self.n_routed_experts = config.n_routed_experts
self.routed_scaling_factor = config.routed_scaling_factor
self.topk_method = config.topk_method
self.n_group = config.n_group
self.topk_group = config.topk_group
self.weight = mx.zeros((self.n_routed_experts, config.hidden_size))
self.e_score_correction_bias = mx.zeros((self.n_routed_experts,))
def __call__(self, x):
gates = x @ self.weight.T
scores = mx.sigmoid(gates.astype(mx.float32))
assert self.topk_method == "noaux_tc", "Unsupported topk method."
bsz, seq_len = x.shape[:2]
scores = scores + self.e_score_correction_bias
scores = scores.reshape(bsz, seq_len, self.n_group, -1)
group_scores = mx.topk(scores, 2, axis=-1).sum(axis=-1)
k = self.n_group - self.topk_group
group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-1)[..., :k]
batch_idx = mx.expand_dims(mx.arange(bsz), (1, 2))
seq_idx = mx.expand_dims(mx.arange(seq_len), (0, 2))
scores[batch_idx, seq_idx, group_idx] = 0.0
scores = scores.reshape(bsz, seq_len, -1)
k = self.top_k
inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]
scores = mx.take_along_axis(scores, inds, axis=-1)
if self.top_k > 1 and self.norm_topk_prob:
denominator = scores.sum(axis=-1, keepdims=True) + 1e-20
scores = scores / denominator
scores = scores * self.routed_scaling_factor
return inds, scores
class DeepseekV3MoE(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.config = config
self.num_experts_per_tok = config.num_experts_per_tok
self.switch_mlp = SwitchGLU(
config.hidden_size, config.moe_intermediate_size, config.n_routed_experts
)
self.gate = MoEGate(config)
if config.n_shared_experts is not None:
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
self.shared_experts = DeepseekV3MLP(
config=config, intermediate_size=intermediate_size
)
def __call__(self, x):
inds, scores = self.gate(x)
y = self.switch_mlp(x, inds)
y = (y * scores[..., None]).sum(axis=-2).astype(y.dtype)
if self.config.n_shared_experts is not None:
y = y + self.shared_experts(x)
return y
class DeepseekV3DecoderLayer(nn.Module):
def __init__(self, config: ModelArgs, layer_idx: int):
super().__init__()
self.self_attn = DeepseekV3Attention(config)
self.mlp = (
DeepseekV3MoE(config)
if (
config.n_routed_experts is not None
and layer_idx >= config.first_k_dense_replace
and layer_idx % config.moe_layer_freq == 0
)
else DeepseekV3MLP(config)
)
self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = nn.RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
out = h + r
# Protect against overflow for fp16
if out.dtype == mx.float16:
out = mx.clip(out, a_min=None, a_max=mx.finfo(mx.float16).max - 1000)
return out
class DeepseekV3Model(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = [
DeepseekV3DecoderLayer(config, idx)
for idx in range(config.num_hidden_layers)
]
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.pipeline_rank = 0
self.pipeline_size = 1
def pipeline(self, group):
# Split layers in reverse so rank=0 gets the last layers and
# rank=pipeline_size-1 gets the first
self.pipeline_rank = group.rank()
self.pipeline_size = group.size()
layers_per_rank = (
len(self.layers) + self.pipeline_size - 1
) // self.pipeline_size
start = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
self.layers = self.layers[start : start + layers_per_rank]
def __call__(
self,
x: mx.array,
cache: Optional[Any] = None,
mask: Optional[mx.array] = None,
) -> mx.array:
h = self.embed_tokens(x)
pipeline_rank = self.pipeline_rank
pipeline_size = self.pipeline_size
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
# Receive from the previous process in the pipeline
if pipeline_rank < pipeline_size - 1:
h = mx.distributed.recv_like(h, (pipeline_rank + 1))
for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)
# Send to the next process in the pipeline
if pipeline_rank != 0:
h = mx.distributed.send(h, (pipeline_rank - 1) % pipeline_size)
# Broadcast h while keeping it in the graph
h = mx.distributed.all_gather(h)[: h.shape[0]]
return self.norm(h)
class Model(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.args = config
self.model_type = config.model_type
self.model = DeepseekV3Model(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
cache: Optional[Any] = None,
mask: Optional[mx.array] = None,
):
out = self.model(inputs, cache, mask)
return self.lm_head(out)
def sanitize(self, weights):
for l in range(self.args.num_hidden_layers):
prefix = f"model.layers.{l}"
for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]:
for k in ["weight", "scales", "biases"]:
if f"{prefix}.mlp.experts.0.{m}.{k}" in weights:
to_join = [
weights.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}")
for e in range(self.args.n_routed_experts)
]
weights[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join)
# Remove multi-token prediction layer
return {k: v for k, v in weights.items() if not k.startswith("model.layers.61")}
@property
def layers(self):
return self.model.layers

View File

@@ -145,16 +145,16 @@ class GPTBigCodeModel(nn.Module):
hidden_states = self.wte(inputs) hidden_states = self.wte(inputs)
mask = None mask = None
if hidden_states.shape[1] > 1: if mask is not None and hidden_states.shape[1] > 1:
mask = create_attention_mask(hidden_states, cache)
position_ids = mx.array(np.arange(L))
hidden_states += self.wpe(position_ids)
if mask is None:
mask = create_attention_mask(hidden_states, cache)
if cache is None: if cache is None:
cache = [None] * len(self.h) cache = [None] * len(self.h)
position_ids = mx.array(np.arange(L))
else:
position_ids = mx.array(np.arange(cache[0].offset, cache[0].offset + L))
hidden_states += self.wpe(position_ids)
for layer, c in zip(self.h, cache): for layer, c in zip(self.h, cache):
hidden_states = layer(hidden_states, mask, cache=c) hidden_states = layer(hidden_states, mask, cache=c)

View File

@@ -1,4 +1,4 @@
mlx>=0.19.2 mlx>=0.22.0
numpy numpy
transformers[sentencepiece]>=4.39.3 transformers[sentencepiece]>=4.39.3
protobuf protobuf

View File

@@ -590,14 +590,10 @@ class APIHandler(BaseHTTPRequestHandler):
# Determine response type # Determine response type
self.request_id = f"chatcmpl-{uuid.uuid4()}" self.request_id = f"chatcmpl-{uuid.uuid4()}"
self.object_type = "chat.completion.chunk" if self.stream else "chat.completion" self.object_type = "chat.completion.chunk" if self.stream else "chat.completion"
if ( if self.tokenizer.chat_template:
hasattr(self.tokenizer, "apply_chat_template")
and self.tokenizer.chat_template
):
prompt = self.tokenizer.apply_chat_template( prompt = self.tokenizer.apply_chat_template(
body["messages"], body["messages"],
body.get("tools", None), body.get("tools", None),
tokenize=True,
add_generation_prompt=True, add_generation_prompt=True,
) )
else: else:

View File

@@ -266,6 +266,18 @@ class TokenizerWrapper:
else {tokenizer.eos_token_id} else {tokenizer.eos_token_id}
) )
def add_eos_token(self, token: str):
token_id = None
try:
token_id = int(token)
except ValueError:
token_id = self._tokenizer.convert_tokens_to_ids(token)
if token_id is None:
raise ValueError(f"'{token}' is not a token for this tokenizer")
self._eos_token_ids.add(token_id)
def __getattr__(self, attr): def __getattr__(self, attr):
if attr == "detokenizer": if attr == "detokenizer":
return self._detokenizer return self._detokenizer

View File

@@ -1,6 +1,6 @@
import json import json
from pathlib import Path from pathlib import Path
from typing import Dict, List from typing import Dict, List, Optional
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
@@ -10,41 +10,47 @@ class Dataset:
Light-weight wrapper to hold a dataset. Light-weight wrapper to hold a dataset.
""" """
def __init__(self, data: List[Dict[str, str]], text_key: str = "text"): def __init__(
self._text_key = text_key self,
self._data = data data: List[Dict[str, str]],
tokenizer: PreTrainedTokenizer,
text_key: str = "text",
):
self._data = [tokenizer.encode(d[text_key]) for d in data]
for d in self._data:
if d[-1] != tokenizer.eos_token_id:
d.append(tokenizer.eos_token_id)
def __getitem__(self, idx: int): def __getitem__(self, idx: int):
return self._data[idx][self._text_key] return self._data[idx]
def __len__(self): def __len__(self):
if self._data is None:
return 0
return len(self._data) return len(self._data)
class ChatDataset(Dataset): class ChatDataset:
""" """
A dataset for chat data in the format of {"messages": [...]} A dataset for chat data in the format of {"messages": [...]}
https://platform.openai.com/docs/guides/fine-tuning/example-format https://platform.openai.com/docs/guides/fine-tuning/example-format
""" """
def __init__(self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer): def __init__(self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer):
super().__init__(data) self._data = [
self._tokenizer = tokenizer tokenizer.apply_chat_template(
d["messages"],
tools=d.get("tools", None),
)
for d in data
]
def __getitem__(self, idx: int): def __getitem__(self, idx: int):
messages = self._data[idx]["messages"] return self._data[idx]
text = self._tokenizer.apply_chat_template(
messages, def __len__(self):
tools=self._data[idx].get("tools", None), return len(self._data)
tokenize=False,
add_generation_prompt=True,
)
return text
class CompletionsDataset(Dataset): class CompletionsDataset:
""" """
A dataset for prompt-completion data in the format of {"prompt": ..., "completion": ...} A dataset for prompt-completion data in the format of {"prompt": ..., "completion": ...}
or using user-provided keys for prompt and completion values or using user-provided keys for prompt and completion values
@@ -55,36 +61,41 @@ class CompletionsDataset(Dataset):
self, self,
data: List[Dict[str, str]], data: List[Dict[str, str]],
tokenizer: PreTrainedTokenizer, tokenizer: PreTrainedTokenizer,
prompt_key: str = "prompt", prompt_key: str,
completion_key: str = "completion", completion_key: str,
): ):
super().__init__(data) self._data = [
self._tokenizer = tokenizer tokenizer.apply_chat_template(
self._prompt_key = prompt_key [
self._completion_key = completion_key {"role": "user", "content": d[prompt_key]},
{"role": "assistant", "content": d[completion_key]},
],
)
for d in data
]
def __getitem__(self, idx: int): def __getitem__(self, idx: int):
data = self._data[idx] return self._data[idx]
text = self._tokenizer.apply_chat_template(
[ def __len__(self):
{"role": "user", "content": data[self._prompt_key]}, return len(self._data)
{"role": "assistant", "content": data[self._completion_key]},
],
tokenize=False,
add_generation_prompt=True,
)
return text
def create_dataset(data, tokenizer: PreTrainedTokenizer = None): def create_dataset(
data,
tokenizer: PreTrainedTokenizer,
prompt_feature: Optional[str] = None,
completion_feature: Optional[str] = None,
):
prompt_feature = prompt_feature or "prompt"
completion_feature = completion_feature or "completion"
sample = data[0] sample = data[0]
if "messages" in sample: if "messages" in sample:
return ChatDataset(data, tokenizer) return ChatDataset(data, tokenizer)
elif "prompt" in sample and "completion" in sample: elif prompt_feature in sample and completion_feature in sample:
return CompletionsDataset(data, tokenizer) return CompletionsDataset(data, tokenizer, prompt_feature, completion_feature)
elif "text" in sample: elif "text" in sample:
return Dataset(data) return Dataset(data, tokenizer)
else: else:
raise ValueError( raise ValueError(
"Unsupported data format, check the supported formats here:\n" "Unsupported data format, check the supported formats here:\n"
@@ -92,20 +103,30 @@ def create_dataset(data, tokenizer: PreTrainedTokenizer = None):
) )
def load_local_dataset(data_path: Path, tokenizer: PreTrainedTokenizer): def load_local_dataset(
data_path: Path,
tokenizer: PreTrainedTokenizer,
prompt_feature: Optional[str] = None,
completion_feature: Optional[str] = None,
):
def load_subset(path): def load_subset(path):
if not path.exists(): if not path.exists():
return [] return []
with open(path, "r") as fid: with open(path, "r") as fid:
data = [json.loads(l) for l in fid] data = [json.loads(l) for l in fid]
return create_dataset(data, tokenizer) return create_dataset(data, tokenizer, prompt_feature, completion_feature)
names = ("train", "valid", "test") names = ("train", "valid", "test")
train, valid, test = [load_subset(data_path / f"{n}.jsonl") for n in names] train, valid, test = [load_subset(data_path / f"{n}.jsonl") for n in names]
return train, valid, test return train, valid, test
def load_hf_dataset(data_id: str, tokenizer: PreTrainedTokenizer): def load_hf_dataset(
data_id: str,
tokenizer: PreTrainedTokenizer,
prompt_feature: Optional[str] = None,
completion_feature: Optional[str] = None,
):
from datasets import exceptions, load_dataset from datasets import exceptions, load_dataset
try: try:
@@ -114,7 +135,13 @@ def load_hf_dataset(data_id: str, tokenizer: PreTrainedTokenizer):
names = ("train", "valid", "test") names = ("train", "valid", "test")
train, valid, test = [ train, valid, test = [
create_dataset(dataset[n], tokenizer) if n in dataset.keys() else [] (
create_dataset(
dataset[n], tokenizer, prompt_feature, completion_feature
)
if n in dataset.keys()
else []
)
for n in names for n in names
] ]
@@ -143,7 +170,7 @@ def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
if prompt_feature and completion_feature: if prompt_feature and completion_feature:
return CompletionsDataset(ds, tokenizer, prompt_feature, completion_feature) return CompletionsDataset(ds, tokenizer, prompt_feature, completion_feature)
elif text_feature: elif text_feature:
return Dataset(train_ds, text_key=text_feature) return Dataset(train_ds, tokenizer, text_key=text_feature)
else: else:
raise ValueError( raise ValueError(
"Specify either a prompt and completion feature or a text " "Specify either a prompt and completion feature or a text "
@@ -166,15 +193,22 @@ def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
def load_dataset(args, tokenizer: PreTrainedTokenizer): def load_dataset(args, tokenizer: PreTrainedTokenizer):
if getattr(args, "hf_dataset", None) is not None: if getattr(args, "hf_dataset", False):
train, valid, test = load_custom_hf_dataset(args, tokenizer) train, valid, test = load_custom_hf_dataset(args, tokenizer)
else: else:
data_path = Path(args.data) data_path = Path(args.data)
prompt_feature = getattr(args, "prompt_feature", None)
completion_feature = getattr(args, "completion_feature", None)
if data_path.exists(): if data_path.exists():
train, valid, test = load_local_dataset(data_path, tokenizer) train, valid, test = load_local_dataset(
data_path, tokenizer, prompt_feature, completion_feature
)
else: else:
print(f"Loading Hugging Face dataset {args.data}.") print(f"Loading Hugging Face dataset {args.data}.")
train, valid, test = load_hf_dataset(args.data, tokenizer) train, valid, test = load_hf_dataset(
args.data, tokenizer, prompt_feature, completion_feature
)
if args.train and len(train) == 0: if args.train and len(train) == 0:
raise ValueError( raise ValueError(

View File

@@ -100,14 +100,8 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
while True: while True:
indices = np.random.permutation(len(batch_idx)) indices = np.random.permutation(len(batch_idx))
for i in indices: for i in indices:
# Encode batch batch = [dataset[j] for j in batch_idx[i]]
batch = [tokenizer.encode(dataset[j]) for j in batch_idx[i]]
for b in batch:
if b[-1] != tokenizer.eos_token_id:
b.append(tokenizer.eos_token_id)
lengths = [len(x) for x in batch] lengths = [len(x) for x in batch]
if max(lengths) > max_seq_length: if max(lengths) > max_seq_length:
print( print(
f"[WARNING] Some sequences are longer than {max_seq_length} tokens. " f"[WARNING] Some sequences are longer than {max_seq_length} tokens. "

View File

@@ -2,10 +2,12 @@
import contextlib import contextlib
import copy import copy
import functools
import glob import glob
import importlib import importlib
import json import json
import logging import logging
import os
import shutil import shutil
import time import time
from dataclasses import dataclass from dataclasses import dataclass
@@ -15,7 +17,17 @@ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type,
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from huggingface_hub import snapshot_download
if os.getenv("MLXLM_USE_MODELSCOPE", "False").lower() == "true":
try:
from modelscope import snapshot_download
except ImportError:
raise ImportError(
"Please run `pip install modelscope` to activate the ModelScope."
)
else:
from huggingface_hub import snapshot_download
from mlx.utils import tree_flatten, tree_reduce from mlx.utils import tree_flatten, tree_reduce
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
@@ -153,11 +165,12 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path
Path: The path to the model. Path: The path to the model.
""" """
model_path = Path(path_or_hf_repo) model_path = Path(path_or_hf_repo)
if not model_path.exists(): if not model_path.exists():
try: try:
model_path = Path( model_path = Path(
snapshot_download( snapshot_download(
repo_id=path_or_hf_repo, path_or_hf_repo,
revision=revision, revision=revision,
allow_patterns=[ allow_patterns=[
"*.json", "*.json",
@@ -207,12 +220,6 @@ def generate_step(
kv_group_size: int = 64, kv_group_size: int = 64,
quantized_kv_start: int = 0, quantized_kv_start: int = 0,
prompt_progress_callback: Optional[Callable[int, int]] = None, prompt_progress_callback: Optional[Callable[int, int]] = None,
temp: Optional[float] = None,
repetition_penalty: Optional[float] = None,
repetition_context_size: Optional[int] = None,
top_p: Optional[float] = None,
min_p: Optional[float] = None,
min_tokens_to_keep: Optional[int] = None,
) -> Generator[Tuple[mx.array, mx.array], None, None]: ) -> Generator[Tuple[mx.array, mx.array], None, None]:
""" """
A generator producing token ids based on the given prompt from the model. A generator producing token ids based on the given prompt from the model.
@@ -256,25 +263,17 @@ def generate_step(
elif len(prompt_cache) != len(model.layers): elif len(prompt_cache) != len(model.layers):
raise ValueError("Wrong number of layers in the prompt cache.") raise ValueError("Wrong number of layers in the prompt cache.")
if temp is not None or top_p is not None or min_tokens_to_keep is not None:
print(
"[Warning] Specifying sampling arguments to ``generate_step`` is "
"deprecated. Pass in a ``sampler`` instead."
)
if repetition_penalty is not None:
print(
"[Warning] Specifying ``repetition_penalty`` is deprecated. "
"Pass in ``logits_processors`` instead."
)
sampler = sampler or make_sampler(
temp or 0.0, top_p or 0.0, min_p or 0.0, min_tokens_to_keep or 1
)
logits_processors = logits_processors or make_logits_processors(
None, repetition_penalty, repetition_context_size or 20
)
prompt_progress_callback = prompt_progress_callback or (lambda *_: None) prompt_progress_callback = prompt_progress_callback or (lambda *_: None)
quantize_cache_fn = functools.partial(
maybe_quantize_kv_cache,
quantized_kv_start=quantized_kv_start,
kv_group_size=kv_group_size,
kv_bits=kv_bits,
)
sampler = sampler or (lambda x: mx.argmax(x, axis=-1))
def _step(y): def _step(y):
with mx.stream(generation_stream): with mx.stream(generation_stream):
logits = model(y[None], cache=prompt_cache) logits = model(y[None], cache=prompt_cache)
@@ -287,9 +286,7 @@ def generate_step(
for processor in logits_processors: for processor in logits_processors:
logits = processor(tokens, logits) logits = processor(tokens, logits)
maybe_quantize_kv_cache( quantize_cache_fn(prompt_cache)
prompt_cache, quantized_kv_start, kv_group_size, kv_bits
)
logprobs = logits - mx.logsumexp(logits, keepdims=True) logprobs = logits - mx.logsumexp(logits, keepdims=True)
y = sampler(logprobs) y = sampler(logprobs)
@@ -300,9 +297,7 @@ def generate_step(
prompt_processed_tokens = 0 prompt_processed_tokens = 0
while y.size > prefill_step_size: while y.size > prefill_step_size:
model(y[:prefill_step_size][None], cache=prompt_cache) model(y[:prefill_step_size][None], cache=prompt_cache)
maybe_quantize_kv_cache( quantize_cache_fn(prompt_cache)
prompt_cache, quantized_kv_start, kv_group_size, kv_bits
)
mx.eval([c.state for c in prompt_cache]) mx.eval([c.state for c in prompt_cache])
prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens) prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens)
prompt_processed_tokens += prefill_step_size prompt_processed_tokens += prefill_step_size
@@ -329,10 +324,162 @@ def generate_step(
n += 1 n += 1
def speculative_generate_step(
prompt: mx.array,
model: nn.Module,
draft_model: nn.Module,
*,
num_draft_tokens=2,
max_tokens: int = 256,
sampler: Optional[Callable[mx.array, mx.array]] = None,
logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
prompt_cache: Optional[Any] = None,
prefill_step_size: int = 512,
kv_bits: Optional[int] = None,
kv_group_size: int = 64,
quantized_kv_start: int = 0,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
"""
A generator producing token ids based on the given prompt from the model.
Args:
prompt (mx.array): The input prompt.
model (nn.Module): The model to use for generation.
draft_model (nn.Module): The draft model for speculative decoding.
num_draft_tokens (int, optional): The number of draft tokens for
speculative decoding. Default: ``2``.
max_tokens (int): The maximum number of tokens. Use``-1`` for an infinite
generator. Default: ``256``.
sampler (Callable[mx.array, mx.array], optional): A sampler for sampling a
token from a vector of log probabilities. Default: ``None``.
logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
A list of functions that take tokens and logits and return the processed
logits. Default: ``None``.
prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
provided, the cache will be updated in place. The cache must be trimmable.
prefill_step_size (int): Step size for processing the prompt.
kv_bits (int, optional): Number of bits to use for KV cache quantization.
None implies no cache quantization. Default: ``None``.
kv_group_size (int): Group size for KV cache quantization. Default: ``64``.
quantized_kv_start (int): Step to begin using a quantized KV cache.
when ``kv_bits`` is non-None. Default: ``0``.
Yields:
Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
"""
y = prompt
tokens = None
# Create the KV cache for generation
if prompt_cache is None:
model_cache = cache.make_prompt_cache(model)
draft_cache = cache.make_prompt_cache(draft_model)
elif len(prompt_cache) != (len(model.layers) + len(draft_model.layers)):
raise ValueError("Wrong number of layers in the prompt cache.")
else:
model_cache = prompt_cache[: len(model.layers)]
draft_cache = prompt_cache[len(model.layers) :]
sampler = sampler or (lambda x: mx.argmax(x, axis=-1))
quantize_cache_fn = functools.partial(
maybe_quantize_kv_cache,
quantized_kv_start=quantized_kv_start,
kv_group_size=kv_group_size,
kv_bits=kv_bits,
)
def _step(model, cache, y, n_predict=1):
with mx.stream(generation_stream):
logits = model(y[None], cache=cache)
logits = logits[:, -n_predict:, :]
quantize_cache_fn(cache)
logprobs = logits - mx.logsumexp(logits, keepdims=True)
y = sampler(logprobs).squeeze(0)
return y, logprobs.squeeze(0)
def _prefill(model, cache, y):
while y.size > prefill_step_size:
model(y[:prefill_step_size][None], cache=cache)
quantize_cache_fn(cache)
mx.eval([c.state for c in cache])
y = y[prefill_step_size:]
mx.metal.clear_cache()
return y
def _rewind_cache(num_draft, num_accept):
cache.trim_prompt_cache(model_cache, num_draft - num_accept)
cache.trim_prompt_cache(draft_cache, max(num_draft - num_accept - 1, 0))
def _draft_generate(y, num_draft):
if num_draft == 0:
return mx.array([], mx.uint32)
ys = []
for _ in range(num_draft):
y, _ = _step(draft_model, draft_cache, y)
mx.async_eval(y)
ys.append(y)
return mx.concatenate(ys)
with mx.stream(generation_stream):
draft_y = _prefill(draft_model, draft_cache, y)
y = _prefill(model, model_cache, y)
ntoks = 0
# Set these so the finally block doesn't raise
num_draft = 0
n = 0
try:
while True:
num_draft = min(max_tokens - ntoks, num_draft_tokens)
draft_tokens = _draft_generate(draft_y, num_draft)
y = mx.concatenate([y, draft_tokens])
tokens, logprobs = _step(model, model_cache, y, num_draft + 1)
mx.eval(tokens, draft_tokens)
draft_tokens = draft_tokens.tolist()
tokens = tokens.tolist()
n = 0
while n < num_draft:
tn, dtn, lpn = tokens[n], draft_tokens[n], logprobs[n]
if tn != dtn:
break
n += 1
ntoks += 1
yield tn, lpn
if ntoks == max_tokens:
break
if ntoks < max_tokens:
ntoks += 1
yield tokens[n], logprobs[n]
if ntoks == max_tokens:
break
y = mx.array([tokens[n]], mx.uint32)
draft_y = y
# If we accpeted all the draft tokens, include the last
# draft token in the next draft step since it hasn't been
# processed yet by the draft model
if n == num_draft:
draft_y = mx.concatenate(
[mx.array(draft_tokens[-1:], mx.uint32), draft_y]
)
_rewind_cache(num_draft, n)
finally:
_rewind_cache(num_draft, n)
def stream_generate( def stream_generate(
model: nn.Module, model: nn.Module,
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper], tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
prompt: Union[str, mx.array, List[int]], prompt: Union[str, mx.array, List[int]],
draft_model: Optional[nn.Module] = None,
**kwargs, **kwargs,
) -> Generator[GenerationResponse, None, None]: ) -> Generator[GenerationResponse, None, None]:
""" """
@@ -341,7 +488,11 @@ def stream_generate(
Args: Args:
model (nn.Module): The model to use for generation. model (nn.Module): The model to use for generation.
tokenizer (PreTrainedTokenizer): The tokenizer. tokenizer (PreTrainedTokenizer): The tokenizer.
prompt (Union[str, mx.array, List[int]]): The input prompt string or integer tokens. prompt (Union[str, mx.array, List[int]]): The input prompt string or
integer tokens.
draft_model (Optional[nn.Module]): An optional draft model. If provided
then speculative decoding is used. The draft model must use the same
tokenizer as the main model. Default: ``None``.
kwargs: The remaining options get passed to :func:`generate_step`. kwargs: The remaining options get passed to :func:`generate_step`.
See :func:`generate_step` for more details. See :func:`generate_step` for more details.
@@ -353,16 +504,28 @@ def stream_generate(
tokenizer = TokenizerWrapper(tokenizer) tokenizer = TokenizerWrapper(tokenizer)
if not isinstance(prompt, mx.array): if not isinstance(prompt, mx.array):
prompt = mx.array( if isinstance(prompt, str):
prompt if isinstance(prompt, list) else tokenizer.encode(prompt) # Try to infer if special tokens are needed
) add_special_tokens = tokenizer.bos_token is None or not prompt.startswith(
tokenizer.bos_token
)
prompt = tokenizer.encode(prompt, add_special_tokens=add_special_tokens)
prompt = mx.array(prompt)
detokenizer = tokenizer.detokenizer detokenizer = tokenizer.detokenizer
if draft_model is None:
kwargs.pop("num_draft_tokens", None)
token_generator = generate_step(prompt, model, **kwargs)
else:
kwargs.pop("max_kv_size", None)
token_generator = speculative_generate_step(
prompt, model, draft_model, **kwargs
)
with wired_limit(model, [generation_stream]): with wired_limit(model, [generation_stream]):
detokenizer.reset() detokenizer.reset()
tic = time.perf_counter() tic = time.perf_counter()
for n, (token, logprobs) in enumerate(generate_step(prompt, model, **kwargs)): for n, (token, logprobs) in enumerate(token_generator):
if n == 0: if n == 0:
prompt_time = time.perf_counter() - tic prompt_time = time.perf_counter() - tic
prompt_tps = prompt.size / prompt_time prompt_tps = prompt.size / prompt_time
@@ -401,7 +564,7 @@ def stream_generate(
def generate( def generate(
model: nn.Module, model: nn.Module,
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper], tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
prompt: str, prompt: Union[str, List[int]],
verbose: bool = False, verbose: bool = False,
formatter: Optional[Callable] = None, formatter: Optional[Callable] = None,
**kwargs, **kwargs,
@@ -412,7 +575,7 @@ def generate(
Args: Args:
model (nn.Module): The language model. model (nn.Module): The language model.
tokenizer (PreTrainedTokenizer): The tokenizer. tokenizer (PreTrainedTokenizer): The tokenizer.
prompt (str): The string prompt. prompt (Union[str, List[int]]): The input prompt string or integer tokens.
verbose (bool): If ``True``, print tokens and timing information. verbose (bool): If ``True``, print tokens and timing information.
Default: ``False``. Default: ``False``.
kwargs: The remaining options get passed to :func:`stream_generate`. kwargs: The remaining options get passed to :func:`stream_generate`.
@@ -425,7 +588,6 @@ def generate(
) )
if verbose: if verbose:
print("=" * 10) print("=" * 10)
print("Prompt:", prompt)
text = "" text = ""
for response in stream_generate(model, tokenizer, prompt, **kwargs): for response in stream_generate(model, tokenizer, prompt, **kwargs):
@@ -558,7 +720,7 @@ def load(
Defaults to an empty dictionary. Defaults to an empty dictionary.
adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers
to the model. Default: ``None``. to the model. Default: ``None``.
lazy (bool): If False eval the model parameters to make sure they are lazy (bool): If ``False`` eval the model parameters to make sure they are
loaded in memory before returning, otherwise they will be loaded loaded in memory before returning, otherwise they will be loaded
when needed. Default: ``False`` when needed. Default: ``False``
Returns: Returns:
@@ -652,12 +814,12 @@ def upload_to_hub(path: str, upload_repo: str, hf_path: str):
model, tokenizer = load("{upload_repo}") model, tokenizer = load("{upload_repo}")
prompt="hello" prompt = "hello"
if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template is not None: if tokenizer.chat_template is not None:
messages = [{{"role": "user", "content": prompt}}] messages = [{{"role": "user", "content": prompt}}]
prompt = tokenizer.apply_chat_template( prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True messages, add_generation_prompt=True
) )
response = generate(model, tokenizer, prompt=prompt, verbose=True) response = generate(model, tokenizer, prompt=prompt, verbose=True)
@@ -670,12 +832,10 @@ def upload_to_hub(path: str, upload_repo: str, hf_path: str):
api = HfApi() api = HfApi()
api.create_repo(repo_id=upload_repo, exist_ok=True) api.create_repo(repo_id=upload_repo, exist_ok=True)
api.upload_folder( api.upload_large_folder(
folder_path=path, folder_path=path,
repo_id=upload_repo, repo_id=upload_repo,
repo_type="model", repo_type="model",
multi_commits=True,
multi_commits_verbose=True,
) )
print(f"Upload successful, go to https://huggingface.co/{upload_repo} for details.") print(f"Upload successful, go to https://huggingface.co/{upload_repo} for details.")

View File

@@ -27,12 +27,11 @@ setup(
packages=["mlx_lm", "mlx_lm.models", "mlx_lm.tuner"], packages=["mlx_lm", "mlx_lm.models", "mlx_lm.tuner"],
python_requires=">=3.8", python_requires=">=3.8",
extras_require={ extras_require={
"testing": ["datasets"], "test": ["datasets"],
"evaluation": ["lm-eval"], "evaluate": ["lm-eval", "tqdm"],
}, },
entry_points={ entry_points={
"console_scripts": [ "console_scripts": [
"mlx_lm.awq = mlx_lm.awq:main",
"mlx_lm.cache_prompt = mlx_lm.cache_prompt:main", "mlx_lm.cache_prompt = mlx_lm.cache_prompt:main",
"mlx_lm.chat = mlx_lm.chat:main", "mlx_lm.chat = mlx_lm.chat:main",
"mlx_lm.convert = mlx_lm.convert:main", "mlx_lm.convert = mlx_lm.convert:main",

View File

@@ -36,7 +36,8 @@ class TestDatasets(unittest.TestCase):
data = {"text": "This is an example for the model."} data = {"text": "This is an example for the model."}
self.save_data(4 * [data]) self.save_data(4 * [data])
args = types.SimpleNamespace(train=True, test=False, data=self.test_dir) args = types.SimpleNamespace(train=True, test=False, data=self.test_dir)
train, valid, test = datasets.load_dataset(args, None) tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_PATH)
train, valid, test = datasets.load_dataset(args, tokenizer)
self.assertEqual(len(train), 4) self.assertEqual(len(train), 4)
self.assertEqual(len(valid), 4) self.assertEqual(len(valid), 4)
self.assertEqual(len(test), 0) self.assertEqual(len(test), 0)
@@ -82,6 +83,8 @@ class TestDatasets(unittest.TestCase):
"name": "billsum", "name": "billsum",
"prompt_feature": "text", "prompt_feature": "text",
"completion_feature": "summary", "completion_feature": "summary",
"train_split": "train[:2%]",
"valid_split": "train[-2%:]",
}, },
test=False, test=False,
train=True, train=True,

View File

@@ -682,6 +682,43 @@ class TestModels(unittest.TestCase):
model, args.model_type, args.vocab_size, args.num_hidden_layers model, args.model_type, args.vocab_size, args.num_hidden_layers
) )
def test_deepseek_v3(self):
from mlx_lm.models import deepseek_v3
args = deepseek_v3.ModelArgs(
model_type="deepseek_v3",
vocab_size=1024,
hidden_size=128,
intermediate_size=256,
moe_intermediate_size=256,
num_hidden_layers=4,
num_attention_heads=4,
num_key_value_heads=2,
n_routed_experts=4,
n_group=2,
topk_group=1,
num_experts_per_tok=2,
n_shared_experts=1,
kv_lora_rank=4,
q_lora_rank=4,
qk_rope_head_dim=32,
v_head_dim=16,
qk_nope_head_dim=32,
rope_scaling={
"beta_fast": 32,
"beta_slow": 1,
"factor": 40,
"mscale": 1.0,
"mscale_all_dim": 1.0,
"original_max_position_embeddings": 4096,
"type": "yarn",
},
)
model = deepseek_v3.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_gemma2(self): def test_gemma2(self):
from mlx_lm.models import gemma2 from mlx_lm.models import gemma2

View File

@@ -17,7 +17,7 @@ class TestLoadModelCustomGetClasses(unittest.TestCase):
self.config = args self.config = args
self.custom_attribute = "This is a custom model" self.custom_attribute = "This is a custom model"
def load_weights(self, weights): def load_weights(self, weights, **kwargs):
self.qwenWeights = weights self.qwenWeights = weights
class CustomQwenConfig: class CustomQwenConfig: