14 Commits

Author SHA1 Message Date
Awni Hannun
65b792d7c0 fix lazy load 2025-02-06 07:28:59 -08:00
Angelos Katharopoulos
617f9289b9 Make the chat distributed 2025-02-06 07:28:59 -08:00
Angelos Katharopoulos
026362e0f8 Remove async eval and add sequential load 2025-02-06 07:28:58 -08:00
Angelos Katharopoulos
a0ce0594f6 Temporarily remove async_eval 2025-02-06 07:28:03 -08:00
Angelos Katharopoulos
d77840207c Start distributed inference for llama models 2025-02-06 07:28:03 -08:00
Pedro Cuenca
e2e5478da5 READMEs: fix typo in link, minor update. (#1246) 2025-02-04 11:52:32 -08:00
Awni Hannun
21d0ab6e8a fix deepseek sharding (#1242) 2025-02-03 16:59:50 -08:00
Gökdeniz Gülmez
0989c073b0 Optimizations for mamba1 (#1213)
* added mx.einsum() operations: before: 41.293 tokens-per-sec, after: 57.822 tokens-per-sec

* Fused Operations in delta, B, C = ... :. Before: 57.822 tokens-per-sec, after: 83.890 tokens-per-sec

* Pre-computing A_log. After: 83.890 tokens-per-sec, before: 85.848 tokens-per-sec

* Update MambaBlock, Batched Input Processing, Improved Cache Handling, Pre-computed Constants, Cleaner State Management, Explicit Return Values:. Before: 82.442 tokens-per-sec, after: 129.130 tokens-per-sec.

* cleaning up and adding apple copyright to helium modelfile

* update Copyright to this year

* nits + even faster

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2025-02-03 13:36:08 -08:00
Awni Hannun
d9924d08d1 Fix no validation in lora (#1241) 2025-02-03 09:55:24 -08:00
Awni Hannun
9c2ef38d4d only download local shard (#1240) 2025-02-02 13:58:44 -08:00
Awni Hannun
e8afb59de4 better overflow correction (#1229) 2025-01-28 14:37:30 -08:00
Anchen
7a83077cd7 chore(mlx-lm): support text type content in messages (#1225)
* chore(mlx-lm): support text type content

* chore: optimize the messagef content processing

* nits + format

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-01-27 17:13:50 -08:00
Awni Hannun
f44a52e2dc batched min p and fix spec gen sampling (#1222) 2025-01-27 15:40:31 -08:00
Gökdeniz Gülmez
77faa14ba4 adding support for kyutai's helium (#1208)
* initial commit

* adding helium into training

* Update ACKNOWLEDGMENTS.md

* nits

* nits

* fixes / nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-01-26 07:19:07 -08:00
20 changed files with 660 additions and 219 deletions

View File

@@ -14,4 +14,4 @@ MLX Examples was developed with contributions from the following individuals:
 - Markus Enzweiler: Added the `cvae` examples.
 - Prince Canuma: Helped add support for `Starcoder2` models.
 - Shiyu Li: Added the `Segment Anything Model`.
-- Gökdeniz Gülmez: Added support for `MiniCPM`, `Mamba` and support for `full-fine-tuning`.
+- Gökdeniz Gülmez: Added support for `MiniCPM`, `Helium`, `Mamba version 1` and support for `full-fine-tuning`.

View File

@@ -45,7 +45,7 @@ Some more useful examples are listed below.
 ### Hugging Face
 
-Note: You can now directly download a few converted checkpoints from the [MLX
+You can directly use or download converted checkpoints from the [MLX
 Community](https://huggingface.co/mlx-community) organization on Hugging Face.
 We encourage you to join the community and [contribute new
 models](https://github.com/ml-explore/mlx-examples/issues/155).

View File

@@ -164,7 +164,7 @@ mlx_lm.convert \
 ```
 
 Models can also be converted and quantized directly in the
-[mlx-my-repo]https://huggingface.co/spaces/mlx-community/mlx-my-repo) Hugging
+[mlx-my-repo](https://huggingface.co/spaces/mlx-community/mlx-my-repo) Hugging
 Face Space.
 
 ### Long Prompts and Generations

View File

@@ -16,6 +16,25 @@ DEFAULT_MAX_TOKENS = 256
 DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
 
 
+def share_message(world, prompt):
+    if world.size() == 1:
+        return prompt
+
+    if world.rank() == 0:
+        size = mx.array([len(prompt)])
+    else:
+        size = mx.array([0])
+    size = mx.distributed.all_sum(size, stream=mx.cpu).item()
+    if size == 0:
+        return []
+
+    if world.rank() == 0:
+        prompt = mx.array(prompt)
+    else:
+        prompt = mx.array([0] * len(prompt))
+    return mx.distributed.all_sum(size, stream=mx.cpu).tolist()
+
+
 def setup_arg_parser():
     """Set up and return the argument parser."""
     parser = argparse.ArgumentParser(description="Chat with an LLM")
@@ -54,6 +73,7 @@ def setup_arg_parser():
 
 def main():
+    world = mx.distributed.init()
     parser = setup_arg_parser()
     args = parser.parse_args()
@@ -63,16 +83,30 @@ def main():
         args.model,
         adapter_path=args.adapter_path,
         tokenizer_config={"trust_remote_code": True},
+        sequential_load=mx.distributed.init().size() > 1,
     )
-    print(f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.")
+    print(f"Node {world.rank()} of {world.size()}", flush=True)
+    print(
+        f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.",
+        flush=True,
+    )
+    world.barrier()
     prompt_cache = make_prompt_cache(model, args.max_kv_size)
     while True:
-        query = input(">> ")
-        if query == "q":
+        if world.rank() == 0:
+            query = input(">> ")
+            if query == "q":
+                prompt = []
+            else:
+                messages = [{"role": "user", "content": query}]
+                prompt = tokenizer.apply_chat_template(
+                    messages, add_generation_prompt=True
+                )
+        prompt = share_message(world, prompt)
+        if len(prompt) == 0:
             break
-        messages = [{"role": "user", "content": query}]
-        prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
         for response in stream_generate(
             model,
             tokenizer,
@@ -81,8 +115,10 @@ def main():
             sampler=make_sampler(args.temp, args.top_p),
             prompt_cache=prompt_cache,
         ):
-            print(response.text, flush=True, end="")
-        print()
+            if world.rank() == 0:
+                print(response, flush=True, end="")
+        if world.rank() == 0:
+            print()
 
 if __name__ == "__main__":
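
The new `share_message` helper broadcasts the tokenized prompt from rank 0 by summing zero-filled arrays across ranks. Below is a minimal standalone sketch of that pattern; the function and variable names are illustrative and not the committed helper itself.

```
import mlx.core as mx

def broadcast_tokens(group, tokens):
    # All-reduce a length that is non-zero only on rank 0, so every rank
    # learns how many tokens to expect; a zero length doubles as "quit".
    if group.size() == 1:
        return tokens
    length = mx.array([len(tokens) if group.rank() == 0 else 0])
    length = mx.distributed.all_sum(length, stream=mx.cpu).item()
    if length == 0:
        return []
    # Sum the token ids themselves; every rank but rank 0 contributes zeros.
    data = mx.array(tokens) if group.rank() == 0 else mx.zeros((length,), dtype=mx.int32)
    return mx.distributed.all_sum(data, stream=mx.cpu).tolist()

group = mx.distributed.init()
print(broadcast_tokens(group, [1, 2, 3] if group.rank() == 0 else []))
```

With a single process this is a no-op; with several, each rank ends up with rank 0's token list.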

View File

@@ -10,7 +10,7 @@ import logging
 import os
 from importlib.metadata import version
 from pathlib import Path
-from typing import Optional
+from typing import Optional, Union
 
 import lm_eval
 import mlx.core as mx
@@ -20,10 +20,11 @@ from lm_eval.api.model import LM
 from lm_eval.api.registry import register_model
 from tqdm import tqdm
 
-from .models.base import create_causal_mask
 from .models.cache import make_prompt_cache
 from .utils import load, stream_generate
 
+PAD = 0
+
 
 def _len_longest_common_prefix(a, b):
     l = 0
@@ -42,14 +43,31 @@ def _rstrip_until(s, untils):
     return s[: min(f)]
 
 
-def _pad_inputs(inputs):
-    lengths = np.array([len(x) for x in inputs])
-    maxlen = lengths.max()
-    padded = np.stack(
-        [np.pad(x, (0, maxlen - len(x))) for x in inputs],
+def _pad_inputs(
+    inputs,
+    maxlen,
+    genlen=0,
+    pad_left=False,
+    pad_multiple=32,
+    truncate=False,
+):
+    # pad the prompts to the left with at least genlen tokens.
+    actual_maxlen = max(len(p) for p in inputs) + genlen
+    if actual_maxlen > maxlen:
+        if not truncate:
+            raise ValueError("Inputs are too long.")
+        else:  # drop begining
+            actual_maxlen = maxlen
+            inputs = [p[max(0, len(p) - maxlen) :] for p in inputs]
+
+    if pad_multiple > 0:
+        maxlen = (actual_maxlen + pad_multiple - 1) // pad_multiple
+        maxlen *= pad_multiple
+
+    assert PAD == 0
+    lr = np.array((1, 0) if pad_left else (0, 1))
+    return np.stack(
+        [np.pad(np.array(x, np.int32), lr * (maxlen - len(x))) for x in inputs],
         axis=0,
     )
-    return mx.array(padded), mx.array(lengths)
 
 
 @register_model("mlxlm")
@@ -65,33 +83,32 @@ class MLXLM(LM):
         self._batch_size = batch_size
         self._model, self.tokenizer = load(path_or_hf_repo)
         self._max_tokens = max_tokens or self.tokenizer.model_max_length
-        self.use_chat_template = use_chat_template and (
+        self.use_chat_template = use_chat_template or (
             self.tokenizer.chat_template is not None
         )
 
-    def _score_fn(self, inputs, step_size: int = 64):
-        inputs, lengths = _pad_inputs(inputs)
+    def _score_fn(self, inputs, tokenize=True, step_size=32):
+        if tokenize:
+            inputs = self._tokenize(inputs)
+        inputs = _pad_inputs(inputs, self._max_tokens, truncate=False)
+        inputs = mx.array(inputs)
         inputs, targets = inputs[..., :-1], inputs[..., 1:]
 
         cache = make_prompt_cache(self._model)
+        mask = targets != PAD
 
         scores, is_greedy = [], []
         for i in range(0, inputs.shape[1], step_size):
-            inp = inputs[:, i : i + step_size]
-            T = inp.shape[1]
-            offset = cache[0].offset
-            mask = create_causal_mask(T, offset, lengths=lengths)
-            mask = mask == 0
-            logits = self._model(inp, cache=cache, mask=mask)
+            logits = self._model(inputs[:, i : i + step_size], cache=cache)
 
             log_probs = nn.log_softmax(logits.astype(mx.float32))
             score = mx.take_along_axis(
                 log_probs, targets[:, i : i + step_size, mx.newaxis], axis=-1
             )[..., 0]
-            ig = targets[:, i : i + step_size] == mx.argmax(logits, axis=-1)
-            ig = mx.where(mx.arange(T) + offset < lengths[:, None], ig, False)
+            ig = mask[:, i : i + step_size] * (
+                targets[:, i : i + step_size] == mx.argmax(logits, axis=-1)
+            )
 
             mx.eval(score, ig)
             mx.metal.clear_cache()
@@ -102,32 +119,38 @@ class MLXLM(LM):
         scores = mx.concatenate(scores, axis=1)
         is_greedy = mx.concatenate(is_greedy, axis=1)
 
-        return scores, lengths, is_greedy
+        return scores, mask.sum(axis=-1), is_greedy
 
-    def _loglikelihood(self, texts, score_spans=None):
-        all_scores = mx.zeros(len(texts))
-        all_is_greedy = mx.zeros(len(texts), dtype=mx.bool_)
-        for i in tqdm(range(0, len(texts), self._batch_size)):
-            batch = texts[i : i + self._batch_size]
-            scores, lengths, is_greedy = self._score_fn(batch)
-
-            ind = np.arange(scores.shape[-1])
-            if score_spans is not None:
-                spans = score_spans[i : i + self._batch_size]
-                lengths = [end - start for start, end in spans]
-                masks = mx.array(
-                    np.array([(ind >= start) & (ind < end) for start, end in spans])
-                )
-            else:
-                masks = ind[None] < lengths[:, None]
-
-            scores = (masks * scores).sum(axis=-1)
-            is_greedy = (masks * is_greedy).sum(axis=-1)
-
-            all_scores[i : i + self._batch_size] = scores
-            all_is_greedy[i : i + self._batch_size] = is_greedy == lengths
-
-        return all_scores, all_is_greedy
+    def _loglikelihood(self, texts, score_spans=None, tokenize=True):
+        # sort by length to get batches with little padding.
+        sorted_indices = sorted(range(len(texts)), key=lambda i: -len(texts[i]))
+        sorted_inputs = [texts[sorted_indices[i]] for i in range(len(texts))]
+        sorted_spans = None
+        if score_spans is not None:
+            sorted_spans = [score_spans[sorted_indices[i]] for i in range(len(texts))]
+
+        results = []
+        for i in tqdm(range(0, len(sorted_inputs), self._batch_size)):
+            batch = sorted_inputs[i : i + self._batch_size]
+            scores, length, is_greedy = self._score_fn(batch, tokenize=tokenize)
+            for j in range(len(batch)):
+                if sorted_spans is None:  # full sequence score
+                    mask = mx.arange(scores[j].shape[-1]) < length
+                    score = (scores[j].astype(mx.float32) * mask).sum(axis=-1)
+                    ig = (is_greedy[j].astype(mx.int32) * mask).sum(axis=-1)
+                else:  # subsequence score
+                    start, end = sorted_spans[i + j]
+                    score = scores[j][start:end].astype(mx.float32).sum()
+                    ig = is_greedy[j][start:end].astype(mx.int32).sum()
+                    length = end - start
+
+                results.append((score.item(), ig.item(), length))
+
+        # reorder the outputs
+        inv_sort = np.argsort(sorted_indices)
+        results = [results[inv_sort[i]] for i in range(len(results))]
+
+        return results
 
     def _tokenize(self, texts):
         return [
@@ -199,53 +222,16 @@ class MLXLM(LM):
+ "completion longer than context." + "completion longer than context."
) )
num_results = len(shortened)
# sort by length to get batches with little padding.
sorted_indices = sorted(range(len(shortened)), key=lambda i: -len(shortened[i]))
shortened = [shortened[i] for i in sorted_indices]
completion_spans = [completion_spans[i] for i in sorted_indices]
group = mx.distributed.init()
# split strided so we have approximately the same lengths on each node
shortened = shortened[group.rank() :: group.size()]
completion_spans = completion_spans[group.rank() :: group.size()]
# model scoring, returns num_requests x (logp, is_greedy, length). # model scoring, returns num_requests x (logp, is_greedy, length).
scores, is_greedy = self._loglikelihood( results = self._loglikelihood(
shortened, shortened,
score_spans=completion_spans, score_spans=completion_spans,
tokenize=False,
) )
return [(r[0], r[1] == r[2]) for r in results]
# all gather the results across groups
if group.size() > 1:
per_group = int(np.ceil(num_results / group.size()))
scores = mx.pad(scores, ((0, per_group - len(scores)),))
is_greedy = mx.pad(is_greedy, ((0, per_group - len(is_greedy))))
scores = mx.distributed.all_gather(scores[mx.newaxis], stream=mx.cpu)
is_greedy = mx.distributed.all_gather(is_greedy[mx.newaxis], stream=mx.cpu)
scores = scores.T.reshape(-1)
is_greedy = is_greedy.T.reshape(-1)
scores = np.array(scores[:num_results])
is_greedy = np.array(is_greedy[:num_results])
results = [(score, ig) for score, ig in zip(scores, is_greedy)]
inv_sort = np.argsort(sorted_indices)
results = [results[inv_sort[i]] for i in range(len(inv_sort))]
return results
tokenizer_name = lm_eval.models.huggingface.HFLM.tokenizer_name tokenizer_name = lm_eval.models.huggingface.HFLM.tokenizer_name
apply_chat_template = lm_eval.models.huggingface.HFLM.apply_chat_template
def apply_chat_template(
self, chat_history: list[dict[str, str]], add_generation_prompt: bool = True
) -> str:
if len(chat_history) == 0:
return ""
return lm_eval.models.huggingface.HFLM.apply_chat_template(
chat_history, add_generation_prompt
)
def loglikelihood_rolling(self, requests) -> list[float]: def loglikelihood_rolling(self, requests) -> list[float]:
"""Compute full log-likelihood of a string, with no truncation, for perplexity computation """Compute full log-likelihood of a string, with no truncation, for perplexity computation
@@ -282,9 +268,8 @@ class MLXLM(LM):
         logging.info(
             "Estimating loglikelihood rolling for %d sequences." % len(requests)
         )
-        inputs = self._tokenize([req.args[0] for req in requests])
-        scores, _ = self._loglikelihood(inputs)
-        return scores.tolist()
+        inputs = [req.args[0] for req in requests]
+        return [t[0] for t in self._loglikelihood(inputs)]
 
     def generate_until(self, requests) -> list[str]:
         """Generate greedily until a stopping sequence
@@ -347,7 +332,7 @@ def main():
     )
     parser.add_argument(
         "--limit",
-        default=None,
+        default=1.0,
         help="Limit the number of examples per task.",
         type=float,
     )
@@ -361,8 +346,11 @@ def main():
     )
     parser.add_argument(
         "--apply-chat-template",
-        action="store_true",
-        help="Specifies whether to apply a chat template to the prompt.",
+        action=argparse.BooleanOptionalAction,
+        help="Specifies whether to apply a chat template to the prompt. If "
+        "the model has a chat template, this defaults to `True`, "
+        "otherwise `False`.",
+        default=None,
     )
     args = parser.parse_args()
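
For reference, the new `_pad_inputs` pads every prompt in a batch to a shared length rounded up to `pad_multiple`, using a zero `PAD` id. Here is a small sketch of just that padding arithmetic; the function name and sample values are made up for illustration.

```
import numpy as np

PAD = 0

def pad_batch(seqs, pad_left=False, pad_multiple=32):
    maxlen = max(len(s) for s in seqs)
    if pad_multiple > 0:
        # round maxlen up to the next multiple of pad_multiple
        maxlen = ((maxlen + pad_multiple - 1) // pad_multiple) * pad_multiple
    lr = np.array((1, 0) if pad_left else (0, 1))  # which side receives the padding
    return np.stack(
        [np.pad(np.array(s, np.int32), lr * (maxlen - len(s))) for s in seqs]
    )

print(pad_batch([[1, 2, 3], [4, 5]], pad_multiple=4))
# [[1 2 3 0]
#  [4 5 0 0]]
```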

View File

@@ -4,10 +4,11 @@
 Run with:
 
 ```
-/path/to/mpirun \
- -np 2 \
+mlx.launch \
  --hostfile /path/to/hosts.txt \
- python /path/to/pipeline_generate.py --prompt "hello world"
+ --backend mpi \
+ /path/to/pipeline_generate.py \
+ --prompt "hello world"
 ```
 
 Make sure you can run MLX over MPI on two hosts. For more information see the
@@ -17,62 +18,110 @@ https://ml-explore.github.io/mlx/build/html/usage/distributed.html).
""" """
import argparse import argparse
import json
from pathlib import Path
import mlx.core as mx import mlx.core as mx
from huggingface_hub import snapshot_download
from mlx.utils import tree_flatten
from mlx_lm import load, stream_generate from mlx_lm import load, stream_generate
from mlx_lm.utils import load_model, load_tokenizer
parser = argparse.ArgumentParser(description="LLM pipelined inference example")
parser.add_argument(
"--model",
default="mlx-community/DeepSeek-R1-3bit",
help="HF repo or path to local model.",
)
parser.add_argument(
"--prompt",
"-p",
default="Write a quicksort in C++.",
help="Message to be processed by the model ('-' reads from stdin)",
)
parser.add_argument(
"--max-tokens",
"-m",
type=int,
default=256,
help="Maximum number of tokens to generate",
)
args = parser.parse_args()
model, tokenizer = load(args.model, lazy=True)
messages = [{"role": "user", "content": args.prompt}]
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
group = mx.distributed.init()
rank = group.rank()
model.model.pipeline(group)
mx.eval(model.parameters())
# Synchronize processes before generation to avoid timeout if downloading
# model for the first time.
mx.eval(mx.distributed.all_sum(mx.array(1.0), stream=mx.cpu))
def rprint(*args, **kwargs): def download(repo: str, allow_patterns: list[str]) -> Path:
if rank == 0: return Path(
print(*args, **kwargs) snapshot_download(
repo,
allow_patterns=allow_patterns,
)
)
for response in stream_generate(model, tokenizer, prompt, max_tokens=args.max_tokens): def shard_and_load(repo):
rprint(response.text, end="", flush=True) # Get model path with everything but weight safetensors
model_path = download(
args.model,
allow_patterns=["*.json", "*.py", "tokenizer.model", "*.tiktoken", "*.txt"],
)
rprint() # Lazy load and shard model to figure out
rprint("=" * 10) # which weights we need
rprint( model, _ = load_model(model_path, lazy=True, strict=False)
f"Prompt: {response.prompt_tokens} tokens, "
f"{response.prompt_tps:.3f} tokens-per-sec" group = mx.distributed.init(backend="mpi")
) rank = group.rank()
rprint( model.model.pipeline(group)
f"Generation: {response.generation_tokens} tokens, "
f"{response.generation_tps:.3f} tokens-per-sec" # Figure out which files we need for the local shard
) with open(model_path / "model.safetensors.index.json", "r") as fid:
rprint(f"Peak memory: {response.peak_memory:.3f} GB") weight_index = json.load(fid)["weight_map"]
local_files = set()
for k, _ in tree_flatten(model.parameters()):
local_files.add(weight_index[k])
# Download weights for local shard
download(args.model, allow_patterns=local_files)
# Load and shard the model, and load the weights
tokenizer = load_tokenizer(model_path)
model, _ = load_model(model_path, lazy=True, strict=False)
model.model.pipeline(group)
mx.eval(model.parameters())
# Synchronize processes before generation to avoid timeout if downloading
# model for the first time.
mx.eval(mx.distributed.all_sum(mx.array(1.0), stream=mx.cpu))
return model, tokenizer
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="LLM pipelined inference example")
parser.add_argument(
"--model",
default="mlx-community/DeepSeek-R1-3bit",
help="HF repo or path to local model.",
)
parser.add_argument(
"--prompt",
"-p",
default="Write a quicksort in C++.",
help="Message to be processed by the model ('-' reads from stdin)",
)
parser.add_argument(
"--max-tokens",
"-m",
type=int,
default=256,
help="Maximum number of tokens to generate",
)
args = parser.parse_args()
group = mx.distributed.init(backend="mpi")
rank = group.rank()
def rprint(*args, **kwargs):
if rank == 0:
print(*args, **kwargs)
model, tokenizer = shard_and_load(args.model)
messages = [{"role": "user", "content": args.prompt}]
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
for response in stream_generate(
model, tokenizer, prompt, max_tokens=args.max_tokens
):
rprint(response.text, end="", flush=True)
rprint()
rprint("=" * 10)
rprint(
f"Prompt: {response.prompt_tokens} tokens, "
f"{response.prompt_tps:.3f} tokens-per-sec"
)
rprint(
f"Generation: {response.generation_tokens} tokens, "
f"{response.generation_tps:.3f} tokens-per-sec"
)
rprint(f"Peak memory: {response.peak_memory:.3f} GB")

View File

@@ -191,6 +191,7 @@ def main():
         model_path,
         adapter_path=args.adapter_path,
         tokenizer_config=tokenizer_config,
+        sequential_load=mx.distributed.init().size() > 1,
     )
     for eos_token in args.extra_eos_token:
         tokenizer.add_eos_token(eos_token)
@@ -234,13 +235,17 @@ def main():
     else:
         draft_model = None
     sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep)
+
+    world = mx.distributed.init()
+    print(f"Node {world.rank()} of {world.size()}", flush=True)
+    world.barrier()
     response = generate(
         model,
         tokenizer,
         prompt,
         max_tokens=args.max_tokens,
-        verbose=args.verbose,
         sampler=sampler,
+        verbose=args.verbose and world.rank() == 0,
         max_kv_size=args.max_kv_size,
         prompt_cache=prompt_cache if using_cache else None,
         kv_bits=args.kv_bits,
@@ -249,8 +254,10 @@ def main():
         draft_model=draft_model,
         num_draft_tokens=args.num_draft_tokens,
     )
-    if not args.verbose:
+    if not args.verbose and mx.distributed.init().rank() == 0:
         print(response)
+    mx.synchronize()
 
 if __name__ == "__main__":

View File

@@ -364,8 +364,30 @@ class DeepseekV2Model(nn.Module):
             DeepseekV2DecoderLayer(config, idx)
             for idx in range(config.num_hidden_layers)
         ]
+        self.start_idx = 0
+        self.end_idx = len(self.layers)
+        self.num_layers = self.end_idx
+
         self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+        self.pipeline_rank = 0
+        self.pipeline_size = 1
+
+    def pipeline(self, group):
+        # Split layers in reverse so rank=0 gets the last layers and
+        # rank=pipeline_size-1 gets the first
+        self.pipeline_rank = group.rank()
+        self.pipeline_size = group.size()
+        layers_per_rank = (
+            len(self.layers) + self.pipeline_size - 1
+        ) // self.pipeline_size
+        self.start_idx = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
+        self.end_idx = self.start_idx + layers_per_rank
+        self.num_layers = layers_per_rank
+        self.layers = self.layers[: self.end_idx]
+        self.layers[: self.start_idx] = [None] * self.start_idx
+        self.num_layers = len(self.layers) - self.start_idx
 
     def __call__(
         self,
         x: mx.array,
@@ -374,14 +396,31 @@ class DeepseekV2Model(nn.Module):
     ) -> mx.array:
         h = self.embed_tokens(x)
 
+        pipeline_rank = self.pipeline_rank
+        pipeline_size = self.pipeline_size
+
+        # Hack to avoid time-outs during prompt-processing
+        dist_stream = mx.cpu if h.shape[1] > 1 else mx.gpu
+
         if mask is None:
             mask = create_attention_mask(h, cache)
 
         if cache is None:
-            cache = [None] * len(self.layers)
+            cache = [None] * self.num_layers
 
-        for layer, c in zip(self.layers, cache):
-            h = layer(h, mask, c)
+        # Receive from the previous process in the pipeline
+        if pipeline_rank < pipeline_size - 1:
+            h = mx.distributed.recv_like(h, (pipeline_rank + 1), stream=dist_stream)
+
+        for i in range(self.num_layers):
+            h = self.layers[self.start_idx + i](h, mask, cache[i])
+
+        # Send to the next process in the pipeline
+        if pipeline_rank != 0:
+            h = mx.distributed.send(
+                h, (pipeline_rank - 1) % pipeline_size, stream=dist_stream
+            )
+
+        # Broadcast h while keeping it in the graph
+        h = mx.distributed.all_gather(h, stream=dist_stream)[: h.shape[0]]
 
         return self.norm(h)
@@ -418,4 +457,4 @@ class Model(nn.Module):
     @property
     def layers(self):
-        return self.model.layers
+        return self.model.layers[self.model.start_idx : self.model.end_idx]
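
The `pipeline()` split above hands each rank a contiguous block of layers, in reverse rank order, using ceiling division. A small sketch of that arithmetic follows; the clamping with `min` is added here only for illustration, while the model code simply slices.

```
def partition(num_layers: int, pipeline_size: int):
    layers_per_rank = (num_layers + pipeline_size - 1) // pipeline_size  # ceil division
    parts = []
    for rank in range(pipeline_size):
        start = (pipeline_size - rank - 1) * layers_per_rank
        end = min(start + layers_per_rank, num_layers)
        parts.append((rank, start, end))
    return parts

print(partition(61, 2))  # [(0, 31, 61), (1, 0, 31)]: rank 0 holds the last block
```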

View File

@@ -2,6 +2,7 @@
 import math
 from dataclasses import dataclass
+from functools import partial
 from typing import Any, Dict, Optional, Tuple
 
 import mlx.core as mx
@@ -125,6 +126,12 @@ class DeepseekV3YarnRotaryEmbedding(nn.Module):
         )
 
 
+# A clipped silu to prevent fp16 from overflowing
+@partial(mx.compile, shapeless=True)
+def clipped_silu(x):
+    return mx.clip(x * mx.sigmoid(x), a_min=-100, a_max=100)
+
+
 class DeepseekV3Attention(nn.Module):
     def __init__(self, config: ModelArgs):
         super().__init__()
@@ -312,7 +319,10 @@ class DeepseekV3MoE(nn.Module):
         self.config = config
         self.num_experts_per_tok = config.num_experts_per_tok
         self.switch_mlp = SwitchGLU(
-            config.hidden_size, config.moe_intermediate_size, config.n_routed_experts
+            config.hidden_size,
+            config.moe_intermediate_size,
+            config.n_routed_experts,
+            activation=clipped_silu,
         )
 
         self.gate = MoEGate(config)
@@ -359,11 +369,7 @@ class DeepseekV3DecoderLayer(nn.Module):
         r = self.self_attn(self.input_layernorm(x), mask, cache)
         h = x + r
         r = self.mlp(self.post_attention_layernorm(h))
-        out = h + r
-        # Protect against overflow for fp16
-        if out.dtype == mx.float16:
-            out = mx.clip(out, a_min=None, a_max=mx.finfo(mx.float16).max - 1000)
-        return out
+        return h + r
 
 
 class DeepseekV3Model(nn.Module):
@@ -375,6 +381,10 @@ class DeepseekV3Model(nn.Module):
             DeepseekV3DecoderLayer(config, idx)
             for idx in range(config.num_hidden_layers)
         ]
+        self.start_idx = 0
+        self.end_idx = len(self.layers)
+        self.num_layers = self.end_idx
+
         self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.pipeline_rank = 0
         self.pipeline_size = 1
@@ -387,8 +397,11 @@ class DeepseekV3Model(nn.Module):
         layers_per_rank = (
             len(self.layers) + self.pipeline_size - 1
         ) // self.pipeline_size
-        start = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
-        self.layers = self.layers[start : start + layers_per_rank]
+        self.start_idx = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
+        self.end_idx = self.start_idx + layers_per_rank
+        self.layers = self.layers[: self.end_idx]
+        self.layers[: self.start_idx] = [None] * self.start_idx
+        self.num_layers = len(self.layers) - self.start_idx
 
     def __call__(
         self,
@@ -406,15 +419,15 @@ class DeepseekV3Model(nn.Module):
             mask = create_attention_mask(h, cache)
 
         if cache is None:
-            cache = [None] * len(self.layers)
+            cache = [None] * self.num_layers
 
         # Receive from the previous process in the pipeline
         if pipeline_rank < pipeline_size - 1:
             h = mx.distributed.recv_like(h, (pipeline_rank + 1), stream=dist_stream)
 
-        for layer, c in zip(self.layers, cache):
-            h = layer(h, mask, c)
+        for i in range(self.num_layers):
+            h = self.layers[self.start_idx + i](h, mask, cache[i])
 
         # Send to the next process in the pipeline
         if pipeline_rank != 0:
@@ -462,4 +475,4 @@ class Model(nn.Module):
     @property
     def layers(self):
-        return self.model.layers
+        return self.model.layers[self.model.start_idx : self.model.end_idx]
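
`clipped_silu` bounds the SiLU activation so fp16 expert outputs cannot overflow, replacing the old post-layer clipping. A quick standalone check, with arbitrary input values:

```
import mlx.core as mx
from functools import partial

@partial(mx.compile, shapeless=True)
def clipped_silu(x):
    # silu(x) = x * sigmoid(x), clipped to [-100, 100]
    return mx.clip(x * mx.sigmoid(x), a_min=-100, a_max=100)

x = mx.array([-5.0, 0.0, 5.0, 500.0], dtype=mx.float16)
print(clipped_silu(x))  # the 500.0 entry saturates at 100 instead of growing with x
```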

View File

@@ -0,0 +1,185 @@
# Copyright © 2025 Apple Inc.

from dataclasses import dataclass
from typing import Any, Optional, Tuple

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention


@dataclass
class ModelArgs(BaseModelArgs):
    hidden_size: int
    num_hidden_layers: int
    intermediate_size: int
    num_attention_heads: int
    num_key_value_heads: int
    rms_norm_eps: float
    vocab_size: int
    attention_bias: bool
    head_dim: int
    max_position_embeddings: int
    mlp_bias: bool
    model_type: str
    rope_theta: float
    tie_word_embeddings: bool


class HeliumAttention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        dim = args.hidden_size
        self.n_heads = n_heads = args.num_attention_heads
        assert args.num_key_value_heads is not None
        self.n_kv_heads = n_kv_heads = args.num_key_value_heads

        head_dim = args.hidden_size // n_heads
        self.scale = head_dim**-0.5

        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.attention_bias)
        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias)
        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias)
        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)

        self.rope = nn.RoPE(head_dim, traditional=True, base=args.rope_theta)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        B, L, D = x.shape

        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)

        # Prepare the queries, keys and values for the attention computation
        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

        if cache is not None:
            queries = self.rope(queries, offset=cache.offset)
            keys = self.rope(keys, offset=cache.offset)
            keys, values = cache.update_and_fetch(keys, values)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        output = scaled_dot_product_attention(
            queries, keys, values, cache=cache, scale=self.scale, mask=mask
        )

        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(output)


class HeliumMLP(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        self.hidden_size = args.hidden_size
        self.intermediate_size = args.intermediate_size

        self.gate_proj = nn.Linear(
            self.hidden_size, self.intermediate_size, bias=args.mlp_bias
        )
        self.up_proj = nn.Linear(
            self.hidden_size, self.intermediate_size, bias=args.mlp_bias
        )
        self.down_proj = nn.Linear(
            self.intermediate_size, self.hidden_size, bias=args.mlp_bias
        )

    def __call__(self, x: mx.array) -> mx.array:
        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))


class HeliumDecoderLayer(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.hidden_size = args.hidden_size

        self.self_attn = HeliumAttention(args)
        self.mlp = HeliumMLP(args)

        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
        self.post_attention_layernorm = nn.RMSNorm(
            args.hidden_size, eps=args.rms_norm_eps
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        r = self.self_attn(self.input_layernorm(x), mask, cache)
        h = x + r
        r = self.mlp(self.post_attention_layernorm(h))
        out = h + r
        return out


class HeliumModel(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.num_hidden_layers = args.num_hidden_layers
        self.vocab_size = args.vocab_size

        assert self.vocab_size > 0

        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [HeliumDecoderLayer(args) for _ in range(args.num_hidden_layers)]
        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ) -> mx.array:
        h = self.embed_tokens(inputs)

        if mask is None:
            mask = create_attention_mask(h, cache)

        if cache is None:
            cache = [None] * len(self.layers)

        for layer, c in zip(self.layers, cache):
            h = layer(h, mask, c)

        return self.norm(h)


class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.model = HeliumModel(args)

        self.vocab_size = args.vocab_size
        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
        if not args.tie_word_embeddings:
            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ) -> mx.array:
        out = self.model(inputs, mask, cache)
        if self.args.tie_word_embeddings:
            out = self.model.embed_tokens.as_linear(out)
        else:
            out = self.lm_head(out)
        return out

    @property
    def layers(self):
        return self.model.layers

View File

@@ -200,6 +200,36 @@ class Model(nn.Module):
             k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
         }
 
+    def shard(self, group: Optional[mx.distributed.Group] = None):
+        group = group or mx.distributed.init()
+
+        def all_to_sharded(l):
+            if isinstance(l, nn.QuantizedLinear):
+                return nn.QuantizedAllToShardedLinear.from_quantized_linear(l, group)
+            else:
+                return nn.AllToShardedLinear.from_linear(l, group)
+
+        def sharded_to_all(l):
+            if isinstance(l, nn.QuantizedLinear):
+                return nn.QuantizedShardedToAllLinear.from_quantized_linear(l, group)
+            else:
+                return nn.ShardedToAllLinear.from_linear(l, group)
+
+        N = group.size()
+        for layer in self.model.layers:
+            # Shard the self attention
+            layer.self_attn.q_proj = all_to_sharded(layer.self_attn.q_proj)
+            layer.self_attn.k_proj = all_to_sharded(layer.self_attn.k_proj)
+            layer.self_attn.v_proj = all_to_sharded(layer.self_attn.v_proj)
+            layer.self_attn.o_proj = sharded_to_all(layer.self_attn.o_proj)
+            layer.self_attn.n_heads //= N
+            layer.self_attn.n_kv_heads //= N
+
+            # Shard the MLP
+            layer.mlp.gate_proj = all_to_sharded(layer.mlp.gate_proj)
+            layer.mlp.down_proj = sharded_to_all(layer.mlp.down_proj)
+            layer.mlp.up_proj = all_to_sharded(layer.mlp.up_proj)
+
     @property
     def layers(self):
         return self.model.layers
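
The `shard()` method rewrites each projection as an all-to-sharded or sharded-to-all linear, so one matmul is split across ranks and a single reduction restores the full output. Below is a toy tensor-parallel matmul showing the same column/row split with plain ops; the shapes and seeding are illustrative, and this is not the `nn.*Sharded*` implementation itself.

```
import mlx.core as mx

group = mx.distributed.init()
N, r = group.size(), group.rank()

mx.random.seed(0)  # every rank generates the same "full" weights
x = mx.random.normal((1, 8))
W1 = mx.random.normal((8, 16))   # column-split across ranks ("all to sharded")
W2 = mx.random.normal((16, 8))   # row-split across ranks ("sharded to all")

cols = 16 // N
h = x @ W1[:, r * cols : (r + 1) * cols]
y_partial = h @ W2[r * cols : (r + 1) * cols, :]
y = mx.distributed.all_sum(y_partial)  # every rank recovers the full (1, 8) output
print(y.shape)
```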

View File

@@ -1,4 +1,4 @@
-# Copyright © 2024 Apple Inc.
+# Copyright © 2024-2025 Apple Inc.
 
 import math
 from dataclasses import dataclass
@@ -123,17 +123,16 @@ class MambaBlock(nn.Module):
             self.intermediate_size, self.hidden_size, bias=args.use_bias
         )
 
-    def ssm_step(self, x, state=None):
-        A = -mx.exp(self.A_log)
+    def ssm_step(self, x, A, state=None):
         D = self.D
         deltaBC = self.x_proj(x)
-        delta, B, C = mx.split(
-            deltaBC,
-            indices_or_sections=[
-                self.time_step_rank,
-                self.time_step_rank + self.ssm_state_size,
-            ],
-            axis=-1,
+        delta, B, C = map(
+            self.mixer_norm if self.use_bcdt_rms else lambda x: x,
+            mx.split(
+                deltaBC,
+                [self.time_step_rank, self.time_step_rank + self.ssm_state_size],
+                axis=-1,
+            ),
         )
         if self.use_bcdt_rms:
             delta, B, C = map(self.mixer_norm, (delta, B, C))
@@ -145,25 +144,40 @@ class MambaBlock(nn.Module):
         y = y + D * x
         return y, new_state
 
-    def __call__(self, x, cache):
+    def _process_sequence(self, x, conv_cache, state_cache):
         B, T, D = x.shape
-        if cache is None:
-            cache = [None, None]
+        xz = self.in_proj(x)
+        x, z = xz.split(indices_or_sections=2, axis=-1)
+
+        conv_out, new_conv_cache = self.conv1d(x, conv_cache)
+        x = nn.silu(conv_out)
+
+        A = -mx.exp(self.A_log)
 
         outputs = []
+        current_state = state_cache
+        y = []
         for t in range(T):
-            xt = x[:, t, :]
-            xz = self.in_proj(xt)
-            x_t, z_t = xz.split(indices_or_sections=2, axis=1)
-            conv_out, cache[0] = self.conv1d(mx.expand_dims(x_t, 1), cache[0])
-            x_t = conv_out.squeeze(1)
-            x_t = nn.silu(x_t)
-            y_t, cache[1] = self.ssm_step(x_t, cache[1])
-            z_t = nn.silu(z_t)
-            output_t = y_t * z_t
-            output_t = self.out_proj(output_t)
-            outputs.append(output_t)
-        output = mx.stack(outputs, axis=1)
+            y_t, current_state = self.ssm_step(x[:, t], A, current_state)
+            y.append(y_t)
+        y = mx.stack(y, axis=1)
+        z = self.out_proj(nn.silu(z) * y)
+        return z, (new_conv_cache, current_state)
+
+    def __call__(self, x, cache):
+        if cache is None:
+            conv_cache, state_cache = None, None
+        else:
+            conv_cache, state_cache = cache[0], cache[1]
+
+        output, (new_conv_cache, new_state_cache) = self._process_sequence(
+            x, conv_cache, state_cache
+        )
+
+        if isinstance(cache, MambaCache):
+            cache[0] = new_conv_cache
+            cache[1] = new_state_cache
+
         return output

View File

@@ -1,4 +1,4 @@
-# Copyright © 2023-2024 Apple Inc.
+# Copyright © 2023-2025 Apple Inc.
 
 from dataclasses import dataclass
 from typing import Any, Dict, Optional, Tuple, Union

View File

@@ -147,11 +147,11 @@ def min_p_sampling(
     logprobs = logprobs * (1 / temperature)
 
     # Indices sorted in decreasing order
-    sorted_indices = mx.argsort(-logprobs).squeeze(0)
-    sorted_logprobs = logprobs[..., sorted_indices]
+    sorted_indices = mx.argsort(-logprobs, axis=-1)
+    sorted_logprobs = mx.take_along_axis(logprobs, sorted_indices, axis=-1)
 
     # Top probability
-    top_logprobs = logprobs[..., sorted_indices[0]]
+    top_logprobs = sorted_logprobs[:, 0:1]
 
     # Calculate the min_p threshold
     scaled_min_p = top_logprobs + math.log(min_p)
@@ -163,9 +163,9 @@ def min_p_sampling(
     # Create pool of tokens with probability less than scaled min_p
     selected_logprobs = mx.where(tokens_to_remove, -float("inf"), sorted_logprobs)
 
-    # Return sampled token
-    sorted_token = mx.random.categorical(selected_logprobs)
-    return sorted_indices[sorted_token]
+    # Return sampled tokens
+    sorted_tokens = mx.random.categorical(selected_logprobs, axis=-1)[:, None]
+    return mx.take_along_axis(sorted_indices, sorted_tokens, axis=-1).squeeze(1)
 
 
 @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
@@ -185,7 +185,7 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.arr
     # sort probs in ascending order
     sorted_indices = mx.argsort(probs, axis=-1)
-    sorted_probs = probs[..., sorted_indices.squeeze(0)]
+    sorted_probs = mx.take_along_axis(probs, sorted_indices, axis=-1)
 
     cumulative_probs = mx.cumsum(sorted_probs, axis=-1)
@@ -196,10 +196,8 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.arr
         0,
     )
 
-    sorted_token = mx.random.categorical(mx.log(top_probs))
-    token = sorted_indices.squeeze(0)[sorted_token]
-    return token
+    sorted_tokens = mx.random.categorical(mx.log(top_probs), axis=-1)[:, None]
+    return mx.take_along_axis(sorted_indices, sorted_tokens, axis=-1).squeeze(1)
 
 
 @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
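
Both samplers now sort per row and gather with `mx.take_along_axis`, which is what makes them batched. Here is a standalone illustration of that gather-after-argsort pattern; the probabilities are made up and this is not the library sampler itself.

```
import mlx.core as mx

logprobs = mx.log(mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.8, 0.0, 0.2]]))

sorted_indices = mx.argsort(-logprobs, axis=-1)  # per-row descending order
sorted_logprobs = mx.take_along_axis(logprobs, sorted_indices, axis=-1)

# Sample one position per row in sorted space, then map back to vocabulary ids.
sorted_tokens = mx.random.categorical(sorted_logprobs, axis=-1)[:, None]
tokens = mx.take_along_axis(sorted_indices, sorted_tokens, axis=-1).squeeze(1)
print(tokens.shape)  # (2,): one sampled token per batch row
```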

View File

@@ -114,6 +114,33 @@ def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None):
     return prompt.rstrip()
 
 
+def process_message_content(messages):
+    """
+    Convert message content to a format suitable for `apply_chat_template`.
+
+    The function operates on messages in place. It converts the 'content' field
+    to a string instead of a list of text fragments.
+
+    Args:
+        message_list (list): A list of dictionaries, where each dictionary may
+          have a 'content' key containing a list of dictionaries with 'type' and
+          'text' keys.
+
+    Raises:
+        ValueError: If the 'content' type is not supported or if 'text' is missing.
+    """
+    for message in messages:
+        content = message["content"]
+        if isinstance(content, list):
+            text_fragments = [
+                fragment["text"] for fragment in content if fragment["type"] == "text"
+            ]
+            if len(text_fragments) != len(content):
+                raise ValueError("Only 'text' content type is supported.")
+            message["content"] = "".join(text_fragments)
+
+
 @dataclass
 class PromptCache:
     cache: List[Any] = field(default_factory=list)
@@ -591,8 +618,10 @@ class APIHandler(BaseHTTPRequestHandler):
         self.request_id = f"chatcmpl-{uuid.uuid4()}"
         self.object_type = "chat.completion.chunk" if self.stream else "chat.completion"
         if self.tokenizer.chat_template:
+            messages = body["messages"]
+            process_message_content(messages)
             prompt = self.tokenizer.apply_chat_template(
-                body["messages"],
+                messages,
                 body.get("tools", None),
                 add_generation_prompt=True,
             )
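
A quick usage check of the content flattening the server now performs before `apply_chat_template`; the helper is repeated here so the snippet runs on its own, and the messages are invented.

```
def process_message_content(messages):
    for message in messages:
        content = message["content"]
        if isinstance(content, list):
            text_fragments = [f["text"] for f in content if f["type"] == "text"]
            if len(text_fragments) != len(content):
                raise ValueError("Only 'text' content type is supported.")
            message["content"] = "".join(text_fragments)

msgs = [{"role": "user", "content": [{"type": "text", "text": "Hello"}, {"type": "text", "text": "!"}]}]
process_message_content(msgs)
print(msgs)  # [{'role': 'user', 'content': 'Hello!'}]
```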

View File

@@ -140,8 +140,8 @@ def evaluate(
     loss: callable = default_loss,
     iterate_batches: callable = iterate_batches,
 ):
-    all_losses = 0
-    ntokens = 0
+    all_losses = mx.array(0.0)
+    ntokens = mx.array(0)
 
     index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)

View File

@@ -94,6 +94,7 @@ def linear_to_lora_layers(
"phimoe", "phimoe",
"gemma", "gemma",
"gemma2", "gemma2",
"helium",
"starcoder2", "starcoder2",
"cohere", "cohere",
"cohere2", "cohere2",

View File

@@ -306,12 +306,12 @@ def generate_step(
     y, logprobs = _step(y)
-    mx.async_eval(y, logprobs)
+    mx.eval(y, logprobs)
     n = 0
     while True:
         if n != max_tokens:
             next_y, next_logprobs = _step(y)
-            mx.async_eval(next_y, next_logprobs)
+            mx.eval(next_y, next_logprobs)
         if n == 0:
             mx.eval(y)
             prompt_progress_callback(total_prompt_tokens, total_prompt_tokens)
@@ -398,8 +398,9 @@ def speculative_generate_step(
         quantize_cache_fn(cache)
         logprobs = logits - mx.logsumexp(logits, keepdims=True)
-        y = sampler(logprobs).squeeze(0)
-        return y, logprobs.squeeze(0)
+        logprobs = logprobs.squeeze(0)
+        y = sampler(logprobs)
+        return y, logprobs
 
     def _prefill(model, cache, y):
         while y.size > prefill_step_size:
@@ -626,6 +627,8 @@ def load_config(model_path: Path) -> dict:
 def load_model(
     model_path: Path,
     lazy: bool = False,
+    strict: bool = True,
+    sequential_load: bool = False,
     model_config: dict = {},
     get_model_classes: Callable[[dict], Tuple[Type[nn.Module], Type]] = _get_classes,
 ) -> nn.Module:
@@ -637,6 +640,8 @@ def load_model(
         lazy (bool): If False eval the model parameters to make sure they are
             loaded in memory before returning, otherwise they will be loaded
             when needed. Default: ``False``
+        strict (bool): Whether or not to raise an exception if weights don't
+            match. Default: ``True``
         model_config (dict, optional): Optional configuration parameters for the
             model. Defaults to an empty dictionary.
         get_model_classes (Callable[[dict], Tuple[Type[nn.Module], Type]], optional):
@@ -659,7 +664,7 @@ def load_model(
         # Try weight for back-compat
         weight_files = glob.glob(str(model_path / "weight*.safetensors"))
 
-    if not weight_files:
+    if not weight_files and strict:
         logging.error(f"No safetensors found in {model_path}")
         raise FileNotFoundError(f"No safetensors found in {model_path}")
@@ -693,9 +698,18 @@ def load_model(
             class_predicate=class_predicate,
         )
 
-    model.load_weights(list(weights.items()))
+    model.load_weights(list(weights.items()), strict=strict)
+
+    if mx.distributed.init().size() > 1:
+        if not hasattr(model, "shard"):
+            raise RuntimeError("Model doesn't support distributed inference.")
+        model.shard()
 
     if not lazy:
+        weights.clear()
+        if sequential_load:
+            for layer in model.layers:
+                mx.eval(layer.parameters())
         mx.eval(model.parameters())
 
     model.eval()
@@ -708,6 +722,7 @@ def load(
     model_config={},
     adapter_path: Optional[str] = None,
     lazy: bool = False,
+    sequential_load: bool = False,
 ) -> Tuple[nn.Module, TokenizerWrapper]:
     """
     Load the model and tokenizer from a given path or a huggingface repository.
@@ -723,6 +738,8 @@ def load(
         lazy (bool): If ``False`` eval the model parameters to make sure they are
             loaded in memory before returning, otherwise they will be loaded
             when needed. Default: ``False``
+        sequential_load (bool): If True then load each layer sequentially to
+            ensure that we are not wasting memory.
 
     Returns:
         Tuple[nn.Module, TokenizerWrapper]: A tuple containing the loaded model and tokenizer.
@@ -732,7 +749,7 @@ def load(
""" """
model_path = get_model_path(path_or_hf_repo) model_path = get_model_path(path_or_hf_repo)
model, config = load_model(model_path, lazy) model, config = load_model(model_path, lazy=lazy, sequential_load=sequential_load)
if adapter_path is not None: if adapter_path is not None:
model = load_adapters(model, adapter_path) model = load_adapters(model, adapter_path)
model.eval() model.eval()
@@ -746,7 +763,7 @@ def load(
 def fetch_from_hub(
     model_path: Path, lazy: bool = False
 ) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:
-    model, config = load_model(model_path, lazy)
+    model, config = load_model(model_path, lazy=lazy)
     tokenizer = load_tokenizer(
         model_path, eos_token_ids=config.get("eos_token_id", None)
     )
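
Taken together, the loader changes are what the chat and generate scripts rely on: `sequential_load` evaluates one layer at a time and `load_model` shards automatically when more than one rank is present. Below is a sketch of a caller using the new keyword, assuming this branch is installed; the model name is only an example.

```
import mlx.core as mx
from mlx_lm import load, stream_generate

world = mx.distributed.init()
model, tokenizer = load(
    "mlx-community/Llama-3.2-3B-Instruct-4bit",
    sequential_load=world.size() > 1,  # evaluate layer by layer on multi-node runs
)
prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Hello!"}], add_generation_prompt=True
)
for response in stream_generate(model, tokenizer, prompt, max_tokens=64):
    if world.rank() == 0:
        print(response.text, end="", flush=True)
```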

View File

@@ -28,6 +28,12 @@ class TestSampleUtils(unittest.TestCase):
         token = top_p_sampling(logits, 0.95, temperature).item()
         self.assertTrue(token in (1, 2, 3))
 
+        # Batch mode works
+        probs = mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.8, 0.0, 0.1]])
+        logits = mx.log(probs)
+        tokens = top_p_sampling(logits, 0.5, temperature)
+        self.assertEqual(tokens.tolist(), [0, 1])
+
     def test_min_p_sampling(self):
         probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
         logits = mx.log(probs)
@@ -42,6 +48,12 @@ class TestSampleUtils(unittest.TestCase):
             token = min_p_sampling(logits, 0.05)
             self.assertTrue(token in (0, 3))
 
+        # Batch mode works
+        probs = mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.8, 0.0, 0.1]])
+        logits = mx.log(probs)
+        tokens = min_p_sampling(logits, 0.7)
+        self.assertEqual(tokens.tolist(), [0, 1])
+
     def test_top_k_sampling(self):
         probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
         logits = mx.log(probs)

View File

@@ -80,6 +80,29 @@ class TestServer(unittest.TestCase):
self.assertIn("id", response_body) self.assertIn("id", response_body)
self.assertIn("choices", response_body) self.assertIn("choices", response_body)
def test_handle_chat_completions_with_content_fragments(self):
url = f"http://localhost:{self.port}/v1/chat/completions"
chat_post_data = {
"model": "chat_model",
"max_tokens": 10,
"temperature": 0.7,
"top_p": 0.85,
"repetition_penalty": 1.2,
"messages": [
{
"role": "system",
"content": [
{"type": "text", "text": "You are a helpful assistant."}
],
},
{"role": "user", "content": [{"type": "text", "text": "Hello!"}]},
],
}
response = requests.post(url, json=chat_post_data)
response_body = response.text
self.assertIn("id", response_body)
self.assertIn("choices", response_body)
def test_handle_models(self): def test_handle_models(self):
url = f"http://localhost:{self.port}/v1/models" url = f"http://localhost:{self.port}/v1/models"
response = requests.get(url) response = requests.get(url)