mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-12-16 02:08:55 +08:00
Compare commits
3 Commits
distribute
...
dist-eval
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f787c08585 | ||
|
|
d5f49d65b9 | ||
|
|
4385363c0f |
@@ -14,4 +14,4 @@ MLX Examples was developed with contributions from the following individuals:
|
|||||||
- Markus Enzweiler: Added the `cvae` examples.
|
- Markus Enzweiler: Added the `cvae` examples.
|
||||||
- Prince Canuma: Helped add support for `Starcoder2` models.
|
- Prince Canuma: Helped add support for `Starcoder2` models.
|
||||||
- Shiyu Li: Added the `Segment Anything Model`.
|
- Shiyu Li: Added the `Segment Anything Model`.
|
||||||
- Gökdeniz Gülmez: Added support for `MiniCPM`, `Helium`, `Mamba version 1` and support for `full-fine-tuning`.
|
- Gökdeniz Gülmez: Added support for `MiniCPM`, `Mamba` and support for `full-fine-tuning`.
|
||||||
@@ -45,7 +45,7 @@ Some more useful examples are listed below.
|
|||||||
|
|
||||||
### Hugging Face
|
### Hugging Face
|
||||||
|
|
||||||
You can directly use or download converted checkpoints from the [MLX
|
Note: You can now directly download a few converted checkpoints from the [MLX
|
||||||
Community](https://huggingface.co/mlx-community) organization on Hugging Face.
|
Community](https://huggingface.co/mlx-community) organization on Hugging Face.
|
||||||
We encourage you to join the community and [contribute new
|
We encourage you to join the community and [contribute new
|
||||||
models](https://github.com/ml-explore/mlx-examples/issues/155).
|
models](https://github.com/ml-explore/mlx-examples/issues/155).
|
||||||
|
|||||||
@@ -164,7 +164,7 @@ mlx_lm.convert \
|
|||||||
```
|
```
|
||||||
|
|
||||||
Models can also be converted and quantized directly in the
|
Models can also be converted and quantized directly in the
|
||||||
[mlx-my-repo](https://huggingface.co/spaces/mlx-community/mlx-my-repo) Hugging
|
[mlx-my-repo]https://huggingface.co/spaces/mlx-community/mlx-my-repo) Hugging
|
||||||
Face Space.
|
Face Space.
|
||||||
|
|
||||||
### Long Prompts and Generations
|
### Long Prompts and Generations
|
||||||
|
|||||||
@@ -16,25 +16,6 @@ DEFAULT_MAX_TOKENS = 256
|
|||||||
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
|
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
|
||||||
|
|
||||||
|
|
||||||
def share_message(world, prompt):
|
|
||||||
if world.size() == 1:
|
|
||||||
return prompt
|
|
||||||
|
|
||||||
if world.rank() == 0:
|
|
||||||
size = mx.array([len(prompt)])
|
|
||||||
else:
|
|
||||||
size = mx.array([0])
|
|
||||||
size = mx.distributed.all_sum(size, stream=mx.cpu).item()
|
|
||||||
if size == 0:
|
|
||||||
return []
|
|
||||||
|
|
||||||
if world.rank() == 0:
|
|
||||||
prompt = mx.array(prompt)
|
|
||||||
else:
|
|
||||||
prompt = mx.array([0] * len(prompt))
|
|
||||||
return mx.distributed.all_sum(size, stream=mx.cpu).tolist()
|
|
||||||
|
|
||||||
|
|
||||||
def setup_arg_parser():
|
def setup_arg_parser():
|
||||||
"""Set up and return the argument parser."""
|
"""Set up and return the argument parser."""
|
||||||
parser = argparse.ArgumentParser(description="Chat with an LLM")
|
parser = argparse.ArgumentParser(description="Chat with an LLM")
|
||||||
@@ -73,7 +54,6 @@ def setup_arg_parser():
|
|||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
world = mx.distributed.init()
|
|
||||||
parser = setup_arg_parser()
|
parser = setup_arg_parser()
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@@ -83,30 +63,16 @@ def main():
|
|||||||
args.model,
|
args.model,
|
||||||
adapter_path=args.adapter_path,
|
adapter_path=args.adapter_path,
|
||||||
tokenizer_config={"trust_remote_code": True},
|
tokenizer_config={"trust_remote_code": True},
|
||||||
sequential_load=mx.distributed.init().size() > 1,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"Node {world.rank()} of {world.size()}", flush=True)
|
print(f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.")
|
||||||
print(
|
|
||||||
f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.",
|
|
||||||
flush=True,
|
|
||||||
)
|
|
||||||
world.barrier()
|
|
||||||
prompt_cache = make_prompt_cache(model, args.max_kv_size)
|
prompt_cache = make_prompt_cache(model, args.max_kv_size)
|
||||||
while True:
|
while True:
|
||||||
if world.rank() == 0:
|
|
||||||
query = input(">> ")
|
query = input(">> ")
|
||||||
if query == "q":
|
if query == "q":
|
||||||
prompt = []
|
|
||||||
else:
|
|
||||||
messages = [{"role": "user", "content": query}]
|
|
||||||
prompt = tokenizer.apply_chat_template(
|
|
||||||
messages, add_generation_prompt=True
|
|
||||||
)
|
|
||||||
|
|
||||||
prompt = share_message(world, prompt)
|
|
||||||
if len(prompt) == 0:
|
|
||||||
break
|
break
|
||||||
|
messages = [{"role": "user", "content": query}]
|
||||||
|
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
|
||||||
for response in stream_generate(
|
for response in stream_generate(
|
||||||
model,
|
model,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
@@ -115,9 +81,7 @@ def main():
|
|||||||
sampler=make_sampler(args.temp, args.top_p),
|
sampler=make_sampler(args.temp, args.top_p),
|
||||||
prompt_cache=prompt_cache,
|
prompt_cache=prompt_cache,
|
||||||
):
|
):
|
||||||
if world.rank() == 0:
|
print(response.text, flush=True, end="")
|
||||||
print(response, flush=True, end="")
|
|
||||||
if world.rank() == 0:
|
|
||||||
print()
|
print()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import logging
|
|||||||
import os
|
import os
|
||||||
from importlib.metadata import version
|
from importlib.metadata import version
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Union
|
from typing import Optional
|
||||||
|
|
||||||
import lm_eval
|
import lm_eval
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
@@ -20,11 +20,10 @@ from lm_eval.api.model import LM
|
|||||||
from lm_eval.api.registry import register_model
|
from lm_eval.api.registry import register_model
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from .models.base import create_causal_mask
|
||||||
from .models.cache import make_prompt_cache
|
from .models.cache import make_prompt_cache
|
||||||
from .utils import load, stream_generate
|
from .utils import load, stream_generate
|
||||||
|
|
||||||
PAD = 0
|
|
||||||
|
|
||||||
|
|
||||||
def _len_longest_common_prefix(a, b):
|
def _len_longest_common_prefix(a, b):
|
||||||
l = 0
|
l = 0
|
||||||
@@ -43,31 +42,14 @@ def _rstrip_until(s, untils):
|
|||||||
return s[: min(f)]
|
return s[: min(f)]
|
||||||
|
|
||||||
|
|
||||||
def _pad_inputs(
|
def _pad_inputs(inputs):
|
||||||
inputs,
|
lengths = np.array([len(x) for x in inputs])
|
||||||
maxlen,
|
maxlen = lengths.max()
|
||||||
genlen=0,
|
padded = np.stack(
|
||||||
pad_left=False,
|
[np.pad(x, (0, maxlen - len(x))) for x in inputs],
|
||||||
pad_multiple=32,
|
|
||||||
truncate=False,
|
|
||||||
):
|
|
||||||
# pad the prompts to the left with at least genlen tokens.
|
|
||||||
actual_maxlen = max(len(p) for p in inputs) + genlen
|
|
||||||
if actual_maxlen > maxlen:
|
|
||||||
if not truncate:
|
|
||||||
raise ValueError("Inputs are too long.")
|
|
||||||
else: # drop begining
|
|
||||||
actual_maxlen = maxlen
|
|
||||||
inputs = [p[max(0, len(p) - maxlen) :] for p in inputs]
|
|
||||||
if pad_multiple > 0:
|
|
||||||
maxlen = (actual_maxlen + pad_multiple - 1) // pad_multiple
|
|
||||||
maxlen *= pad_multiple
|
|
||||||
assert PAD == 0
|
|
||||||
lr = np.array((1, 0) if pad_left else (0, 1))
|
|
||||||
return np.stack(
|
|
||||||
[np.pad(np.array(x, np.int32), lr * (maxlen - len(x))) for x in inputs],
|
|
||||||
axis=0,
|
axis=0,
|
||||||
)
|
)
|
||||||
|
return mx.array(padded), mx.array(lengths)
|
||||||
|
|
||||||
|
|
||||||
@register_model("mlxlm")
|
@register_model("mlxlm")
|
||||||
@@ -83,32 +65,33 @@ class MLXLM(LM):
|
|||||||
self._batch_size = batch_size
|
self._batch_size = batch_size
|
||||||
self._model, self.tokenizer = load(path_or_hf_repo)
|
self._model, self.tokenizer = load(path_or_hf_repo)
|
||||||
self._max_tokens = max_tokens or self.tokenizer.model_max_length
|
self._max_tokens = max_tokens or self.tokenizer.model_max_length
|
||||||
self.use_chat_template = use_chat_template or (
|
self.use_chat_template = use_chat_template and (
|
||||||
self.tokenizer.chat_template is not None
|
self.tokenizer.chat_template is not None
|
||||||
)
|
)
|
||||||
|
|
||||||
def _score_fn(self, inputs, tokenize=True, step_size=32):
|
def _score_fn(self, inputs, step_size: int = 64):
|
||||||
if tokenize:
|
inputs, lengths = _pad_inputs(inputs)
|
||||||
inputs = self._tokenize(inputs)
|
|
||||||
inputs = _pad_inputs(inputs, self._max_tokens, truncate=False)
|
|
||||||
inputs = mx.array(inputs)
|
|
||||||
inputs, targets = inputs[..., :-1], inputs[..., 1:]
|
inputs, targets = inputs[..., :-1], inputs[..., 1:]
|
||||||
|
|
||||||
cache = make_prompt_cache(self._model)
|
cache = make_prompt_cache(self._model)
|
||||||
|
|
||||||
mask = targets != PAD
|
|
||||||
|
|
||||||
scores, is_greedy = [], []
|
scores, is_greedy = [], []
|
||||||
for i in range(0, inputs.shape[1], step_size):
|
for i in range(0, inputs.shape[1], step_size):
|
||||||
logits = self._model(inputs[:, i : i + step_size], cache=cache)
|
inp = inputs[:, i : i + step_size]
|
||||||
|
T = inp.shape[1]
|
||||||
|
|
||||||
|
offset = cache[0].offset
|
||||||
|
mask = create_causal_mask(T, offset, lengths=lengths)
|
||||||
|
mask = mask == 0
|
||||||
|
|
||||||
|
logits = self._model(inp, cache=cache, mask=mask)
|
||||||
log_probs = nn.log_softmax(logits.astype(mx.float32))
|
log_probs = nn.log_softmax(logits.astype(mx.float32))
|
||||||
|
|
||||||
score = mx.take_along_axis(
|
score = mx.take_along_axis(
|
||||||
log_probs, targets[:, i : i + step_size, mx.newaxis], axis=-1
|
log_probs, targets[:, i : i + step_size, mx.newaxis], axis=-1
|
||||||
)[..., 0]
|
)[..., 0]
|
||||||
ig = mask[:, i : i + step_size] * (
|
ig = targets[:, i : i + step_size] == mx.argmax(logits, axis=-1)
|
||||||
targets[:, i : i + step_size] == mx.argmax(logits, axis=-1)
|
ig = mx.where(mx.arange(T) + offset < lengths[:, None], ig, False)
|
||||||
)
|
|
||||||
|
|
||||||
mx.eval(score, ig)
|
mx.eval(score, ig)
|
||||||
mx.metal.clear_cache()
|
mx.metal.clear_cache()
|
||||||
@@ -119,38 +102,32 @@ class MLXLM(LM):
|
|||||||
scores = mx.concatenate(scores, axis=1)
|
scores = mx.concatenate(scores, axis=1)
|
||||||
is_greedy = mx.concatenate(is_greedy, axis=1)
|
is_greedy = mx.concatenate(is_greedy, axis=1)
|
||||||
|
|
||||||
return scores, mask.sum(axis=-1), is_greedy
|
return scores, lengths, is_greedy
|
||||||
|
|
||||||
def _loglikelihood(self, texts, score_spans=None, tokenize=True):
|
def _loglikelihood(self, texts, score_spans=None):
|
||||||
# sort by length to get batches with little padding.
|
all_scores = mx.zeros(len(texts))
|
||||||
sorted_indices = sorted(range(len(texts)), key=lambda i: -len(texts[i]))
|
all_is_greedy = mx.zeros(len(texts), dtype=mx.bool_)
|
||||||
sorted_inputs = [texts[sorted_indices[i]] for i in range(len(texts))]
|
for i in tqdm(range(0, len(texts), self._batch_size)):
|
||||||
sorted_spans = None
|
batch = texts[i : i + self._batch_size]
|
||||||
|
scores, lengths, is_greedy = self._score_fn(batch)
|
||||||
|
|
||||||
|
ind = np.arange(scores.shape[-1])
|
||||||
if score_spans is not None:
|
if score_spans is not None:
|
||||||
sorted_spans = [score_spans[sorted_indices[i]] for i in range(len(texts))]
|
spans = score_spans[i : i + self._batch_size]
|
||||||
|
lengths = [end - start for start, end in spans]
|
||||||
|
masks = mx.array(
|
||||||
|
np.array([(ind >= start) & (ind < end) for start, end in spans])
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
masks = ind[None] < lengths[:, None]
|
||||||
|
|
||||||
results = []
|
scores = (masks * scores).sum(axis=-1)
|
||||||
for i in tqdm(range(0, len(sorted_inputs), self._batch_size)):
|
is_greedy = (masks * is_greedy).sum(axis=-1)
|
||||||
batch = sorted_inputs[i : i + self._batch_size]
|
|
||||||
scores, length, is_greedy = self._score_fn(batch, tokenize=tokenize)
|
|
||||||
for j in range(len(batch)):
|
|
||||||
if sorted_spans is None: # full sequence score
|
|
||||||
mask = mx.arange(scores[j].shape[-1]) < length
|
|
||||||
score = (scores[j].astype(mx.float32) * mask).sum(axis=-1)
|
|
||||||
ig = (is_greedy[j].astype(mx.int32) * mask).sum(axis=-1)
|
|
||||||
else: # subsequence score
|
|
||||||
start, end = sorted_spans[i + j]
|
|
||||||
score = scores[j][start:end].astype(mx.float32).sum()
|
|
||||||
ig = is_greedy[j][start:end].astype(mx.int32).sum()
|
|
||||||
length = end - start
|
|
||||||
|
|
||||||
results.append((score.item(), ig.item(), length))
|
all_scores[i : i + self._batch_size] = scores
|
||||||
|
all_is_greedy[i : i + self._batch_size] = is_greedy == lengths
|
||||||
|
|
||||||
# reorder the outputs
|
return all_scores, all_is_greedy
|
||||||
inv_sort = np.argsort(sorted_indices)
|
|
||||||
results = [results[inv_sort[i]] for i in range(len(results))]
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
def _tokenize(self, texts):
|
def _tokenize(self, texts):
|
||||||
return [
|
return [
|
||||||
@@ -222,16 +199,53 @@ class MLXLM(LM):
|
|||||||
+ "completion longer than context."
|
+ "completion longer than context."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
num_results = len(shortened)
|
||||||
|
|
||||||
|
# sort by length to get batches with little padding.
|
||||||
|
sorted_indices = sorted(range(len(shortened)), key=lambda i: -len(shortened[i]))
|
||||||
|
shortened = [shortened[i] for i in sorted_indices]
|
||||||
|
completion_spans = [completion_spans[i] for i in sorted_indices]
|
||||||
|
|
||||||
|
group = mx.distributed.init()
|
||||||
|
|
||||||
|
# split strided so we have approximately the same lengths on each node
|
||||||
|
shortened = shortened[group.rank() :: group.size()]
|
||||||
|
completion_spans = completion_spans[group.rank() :: group.size()]
|
||||||
|
|
||||||
# model scoring, returns num_requests x (logp, is_greedy, length).
|
# model scoring, returns num_requests x (logp, is_greedy, length).
|
||||||
results = self._loglikelihood(
|
scores, is_greedy = self._loglikelihood(
|
||||||
shortened,
|
shortened,
|
||||||
score_spans=completion_spans,
|
score_spans=completion_spans,
|
||||||
tokenize=False,
|
|
||||||
)
|
)
|
||||||
return [(r[0], r[1] == r[2]) for r in results]
|
|
||||||
|
# all gather the results across groups
|
||||||
|
if group.size() > 1:
|
||||||
|
per_group = int(np.ceil(num_results / group.size()))
|
||||||
|
scores = mx.pad(scores, ((0, per_group - len(scores)),))
|
||||||
|
is_greedy = mx.pad(is_greedy, ((0, per_group - len(is_greedy))))
|
||||||
|
scores = mx.distributed.all_gather(scores[mx.newaxis], stream=mx.cpu)
|
||||||
|
is_greedy = mx.distributed.all_gather(is_greedy[mx.newaxis], stream=mx.cpu)
|
||||||
|
scores = scores.T.reshape(-1)
|
||||||
|
is_greedy = is_greedy.T.reshape(-1)
|
||||||
|
|
||||||
|
scores = np.array(scores[:num_results])
|
||||||
|
is_greedy = np.array(is_greedy[:num_results])
|
||||||
|
|
||||||
|
results = [(score, ig) for score, ig in zip(scores, is_greedy)]
|
||||||
|
inv_sort = np.argsort(sorted_indices)
|
||||||
|
results = [results[inv_sort[i]] for i in range(len(inv_sort))]
|
||||||
|
return results
|
||||||
|
|
||||||
tokenizer_name = lm_eval.models.huggingface.HFLM.tokenizer_name
|
tokenizer_name = lm_eval.models.huggingface.HFLM.tokenizer_name
|
||||||
apply_chat_template = lm_eval.models.huggingface.HFLM.apply_chat_template
|
|
||||||
|
def apply_chat_template(
|
||||||
|
self, chat_history: list[dict[str, str]], add_generation_prompt: bool = True
|
||||||
|
) -> str:
|
||||||
|
if len(chat_history) == 0:
|
||||||
|
return ""
|
||||||
|
return lm_eval.models.huggingface.HFLM.apply_chat_template(
|
||||||
|
chat_history, add_generation_prompt
|
||||||
|
)
|
||||||
|
|
||||||
def loglikelihood_rolling(self, requests) -> list[float]:
|
def loglikelihood_rolling(self, requests) -> list[float]:
|
||||||
"""Compute full log-likelihood of a string, with no truncation, for perplexity computation
|
"""Compute full log-likelihood of a string, with no truncation, for perplexity computation
|
||||||
@@ -268,8 +282,9 @@ class MLXLM(LM):
|
|||||||
logging.info(
|
logging.info(
|
||||||
"Estimating loglikelihood rolling for %d sequences." % len(requests)
|
"Estimating loglikelihood rolling for %d sequences." % len(requests)
|
||||||
)
|
)
|
||||||
inputs = [req.args[0] for req in requests]
|
inputs = self._tokenize([req.args[0] for req in requests])
|
||||||
return [t[0] for t in self._loglikelihood(inputs)]
|
scores, _ = self._loglikelihood(inputs)
|
||||||
|
return scores.tolist()
|
||||||
|
|
||||||
def generate_until(self, requests) -> list[str]:
|
def generate_until(self, requests) -> list[str]:
|
||||||
"""Generate greedily until a stopping sequence
|
"""Generate greedily until a stopping sequence
|
||||||
@@ -332,7 +347,7 @@ def main():
|
|||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--limit",
|
"--limit",
|
||||||
default=1.0,
|
default=None,
|
||||||
help="Limit the number of examples per task.",
|
help="Limit the number of examples per task.",
|
||||||
type=float,
|
type=float,
|
||||||
)
|
)
|
||||||
@@ -346,11 +361,8 @@ def main():
|
|||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--apply-chat-template",
|
"--apply-chat-template",
|
||||||
action=argparse.BooleanOptionalAction,
|
action="store_true",
|
||||||
help="Specifies whether to apply a chat template to the prompt. If "
|
help="Specifies whether to apply a chat template to the prompt.",
|
||||||
"the model has a chat template, this defaults to `True`, "
|
|
||||||
"otherwise `False`.",
|
|
||||||
default=None,
|
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|||||||
@@ -4,11 +4,10 @@
|
|||||||
Run with:
|
Run with:
|
||||||
|
|
||||||
```
|
```
|
||||||
mlx.launch \
|
/path/to/mpirun \
|
||||||
|
-np 2 \
|
||||||
--hostfile /path/to/hosts.txt \
|
--hostfile /path/to/hosts.txt \
|
||||||
--backend mpi \
|
python /path/to/pipeline_generate.py --prompt "hello world"
|
||||||
/path/to/pipeline_generate.py \
|
|
||||||
--prompt "hello world"
|
|
||||||
```
|
```
|
||||||
|
|
||||||
Make sure you can run MLX over MPI on two hosts. For more information see the
|
Make sure you can run MLX over MPI on two hosts. For more information see the
|
||||||
@@ -18,110 +17,62 @@ https://ml-explore.github.io/mlx/build/html/usage/distributed.html).
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import json
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
from huggingface_hub import snapshot_download
|
|
||||||
from mlx.utils import tree_flatten
|
|
||||||
from mlx_lm import load, stream_generate
|
from mlx_lm import load, stream_generate
|
||||||
from mlx_lm.utils import load_model, load_tokenizer
|
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description="LLM pipelined inference example")
|
||||||
def download(repo: str, allow_patterns: list[str]) -> Path:
|
parser.add_argument(
|
||||||
return Path(
|
|
||||||
snapshot_download(
|
|
||||||
repo,
|
|
||||||
allow_patterns=allow_patterns,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def shard_and_load(repo):
|
|
||||||
# Get model path with everything but weight safetensors
|
|
||||||
model_path = download(
|
|
||||||
args.model,
|
|
||||||
allow_patterns=["*.json", "*.py", "tokenizer.model", "*.tiktoken", "*.txt"],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Lazy load and shard model to figure out
|
|
||||||
# which weights we need
|
|
||||||
model, _ = load_model(model_path, lazy=True, strict=False)
|
|
||||||
|
|
||||||
group = mx.distributed.init(backend="mpi")
|
|
||||||
rank = group.rank()
|
|
||||||
model.model.pipeline(group)
|
|
||||||
|
|
||||||
# Figure out which files we need for the local shard
|
|
||||||
with open(model_path / "model.safetensors.index.json", "r") as fid:
|
|
||||||
weight_index = json.load(fid)["weight_map"]
|
|
||||||
|
|
||||||
local_files = set()
|
|
||||||
for k, _ in tree_flatten(model.parameters()):
|
|
||||||
local_files.add(weight_index[k])
|
|
||||||
|
|
||||||
# Download weights for local shard
|
|
||||||
download(args.model, allow_patterns=local_files)
|
|
||||||
|
|
||||||
# Load and shard the model, and load the weights
|
|
||||||
tokenizer = load_tokenizer(model_path)
|
|
||||||
model, _ = load_model(model_path, lazy=True, strict=False)
|
|
||||||
model.model.pipeline(group)
|
|
||||||
mx.eval(model.parameters())
|
|
||||||
|
|
||||||
# Synchronize processes before generation to avoid timeout if downloading
|
|
||||||
# model for the first time.
|
|
||||||
mx.eval(mx.distributed.all_sum(mx.array(1.0), stream=mx.cpu))
|
|
||||||
return model, tokenizer
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
parser = argparse.ArgumentParser(description="LLM pipelined inference example")
|
|
||||||
parser.add_argument(
|
|
||||||
"--model",
|
"--model",
|
||||||
default="mlx-community/DeepSeek-R1-3bit",
|
default="mlx-community/DeepSeek-R1-3bit",
|
||||||
help="HF repo or path to local model.",
|
help="HF repo or path to local model.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--prompt",
|
"--prompt",
|
||||||
"-p",
|
"-p",
|
||||||
default="Write a quicksort in C++.",
|
default="Write a quicksort in C++.",
|
||||||
help="Message to be processed by the model ('-' reads from stdin)",
|
help="Message to be processed by the model ('-' reads from stdin)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max-tokens",
|
"--max-tokens",
|
||||||
"-m",
|
"-m",
|
||||||
type=int,
|
type=int,
|
||||||
default=256,
|
default=256,
|
||||||
help="Maximum number of tokens to generate",
|
help="Maximum number of tokens to generate",
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
group = mx.distributed.init(backend="mpi")
|
model, tokenizer = load(args.model, lazy=True)
|
||||||
rank = group.rank()
|
|
||||||
|
|
||||||
def rprint(*args, **kwargs):
|
messages = [{"role": "user", "content": args.prompt}]
|
||||||
|
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
|
||||||
|
|
||||||
|
group = mx.distributed.init()
|
||||||
|
rank = group.rank()
|
||||||
|
model.model.pipeline(group)
|
||||||
|
mx.eval(model.parameters())
|
||||||
|
|
||||||
|
# Synchronize processes before generation to avoid timeout if downloading
|
||||||
|
# model for the first time.
|
||||||
|
mx.eval(mx.distributed.all_sum(mx.array(1.0), stream=mx.cpu))
|
||||||
|
|
||||||
|
|
||||||
|
def rprint(*args, **kwargs):
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
print(*args, **kwargs)
|
print(*args, **kwargs)
|
||||||
|
|
||||||
model, tokenizer = shard_and_load(args.model)
|
|
||||||
|
|
||||||
messages = [{"role": "user", "content": args.prompt}]
|
for response in stream_generate(model, tokenizer, prompt, max_tokens=args.max_tokens):
|
||||||
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
|
|
||||||
|
|
||||||
for response in stream_generate(
|
|
||||||
model, tokenizer, prompt, max_tokens=args.max_tokens
|
|
||||||
):
|
|
||||||
rprint(response.text, end="", flush=True)
|
rprint(response.text, end="", flush=True)
|
||||||
|
|
||||||
rprint()
|
rprint()
|
||||||
rprint("=" * 10)
|
rprint("=" * 10)
|
||||||
rprint(
|
rprint(
|
||||||
f"Prompt: {response.prompt_tokens} tokens, "
|
f"Prompt: {response.prompt_tokens} tokens, "
|
||||||
f"{response.prompt_tps:.3f} tokens-per-sec"
|
f"{response.prompt_tps:.3f} tokens-per-sec"
|
||||||
)
|
)
|
||||||
rprint(
|
rprint(
|
||||||
f"Generation: {response.generation_tokens} tokens, "
|
f"Generation: {response.generation_tokens} tokens, "
|
||||||
f"{response.generation_tps:.3f} tokens-per-sec"
|
f"{response.generation_tps:.3f} tokens-per-sec"
|
||||||
)
|
)
|
||||||
rprint(f"Peak memory: {response.peak_memory:.3f} GB")
|
rprint(f"Peak memory: {response.peak_memory:.3f} GB")
|
||||||
|
|||||||
@@ -191,7 +191,6 @@ def main():
|
|||||||
model_path,
|
model_path,
|
||||||
adapter_path=args.adapter_path,
|
adapter_path=args.adapter_path,
|
||||||
tokenizer_config=tokenizer_config,
|
tokenizer_config=tokenizer_config,
|
||||||
sequential_load=mx.distributed.init().size() > 1,
|
|
||||||
)
|
)
|
||||||
for eos_token in args.extra_eos_token:
|
for eos_token in args.extra_eos_token:
|
||||||
tokenizer.add_eos_token(eos_token)
|
tokenizer.add_eos_token(eos_token)
|
||||||
@@ -235,17 +234,13 @@ def main():
|
|||||||
else:
|
else:
|
||||||
draft_model = None
|
draft_model = None
|
||||||
sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep)
|
sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep)
|
||||||
|
|
||||||
world = mx.distributed.init()
|
|
||||||
print(f"Node {world.rank()} of {world.size()}", flush=True)
|
|
||||||
world.barrier()
|
|
||||||
response = generate(
|
response = generate(
|
||||||
model,
|
model,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
prompt,
|
prompt,
|
||||||
max_tokens=args.max_tokens,
|
max_tokens=args.max_tokens,
|
||||||
|
verbose=args.verbose,
|
||||||
sampler=sampler,
|
sampler=sampler,
|
||||||
verbose=args.verbose and world.rank() == 0,
|
|
||||||
max_kv_size=args.max_kv_size,
|
max_kv_size=args.max_kv_size,
|
||||||
prompt_cache=prompt_cache if using_cache else None,
|
prompt_cache=prompt_cache if using_cache else None,
|
||||||
kv_bits=args.kv_bits,
|
kv_bits=args.kv_bits,
|
||||||
@@ -254,10 +249,8 @@ def main():
|
|||||||
draft_model=draft_model,
|
draft_model=draft_model,
|
||||||
num_draft_tokens=args.num_draft_tokens,
|
num_draft_tokens=args.num_draft_tokens,
|
||||||
)
|
)
|
||||||
|
if not args.verbose:
|
||||||
if not args.verbose and mx.distributed.init().rank() == 0:
|
|
||||||
print(response)
|
print(response)
|
||||||
mx.synchronize()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -364,30 +364,8 @@ class DeepseekV2Model(nn.Module):
|
|||||||
DeepseekV2DecoderLayer(config, idx)
|
DeepseekV2DecoderLayer(config, idx)
|
||||||
for idx in range(config.num_hidden_layers)
|
for idx in range(config.num_hidden_layers)
|
||||||
]
|
]
|
||||||
self.start_idx = 0
|
|
||||||
self.end_idx = len(self.layers)
|
|
||||||
self.num_layers = self.end_idx
|
|
||||||
|
|
||||||
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
|
||||||
self.pipeline_rank = 0
|
|
||||||
self.pipeline_size = 1
|
|
||||||
|
|
||||||
def pipeline(self, group):
|
|
||||||
# Split layers in reverse so rank=0 gets the last layers and
|
|
||||||
# rank=pipeline_size-1 gets the first
|
|
||||||
self.pipeline_rank = group.rank()
|
|
||||||
self.pipeline_size = group.size()
|
|
||||||
layers_per_rank = (
|
|
||||||
len(self.layers) + self.pipeline_size - 1
|
|
||||||
) // self.pipeline_size
|
|
||||||
self.start_idx = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
|
|
||||||
self.end_idx = self.start_idx + layers_per_rank
|
|
||||||
self.num_layers = layers_per_rank
|
|
||||||
self.layers = self.layers[: self.end_idx]
|
|
||||||
self.layers[: self.start_idx] = [None] * self.start_idx
|
|
||||||
self.num_layers = len(self.layers) - self.start_idx
|
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
x: mx.array,
|
x: mx.array,
|
||||||
@@ -396,31 +374,14 @@ class DeepseekV2Model(nn.Module):
|
|||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
h = self.embed_tokens(x)
|
h = self.embed_tokens(x)
|
||||||
|
|
||||||
pipeline_rank = self.pipeline_rank
|
|
||||||
pipeline_size = self.pipeline_size
|
|
||||||
# Hack to avoid time-outs during prompt-processing
|
|
||||||
dist_stream = mx.cpu if h.shape[1] > 1 else mx.gpu
|
|
||||||
if mask is None:
|
if mask is None:
|
||||||
mask = create_attention_mask(h, cache)
|
mask = create_attention_mask(h, cache)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
cache = [None] * self.num_layers
|
cache = [None] * len(self.layers)
|
||||||
|
|
||||||
# Receive from the previous process in the pipeline
|
for layer, c in zip(self.layers, cache):
|
||||||
if pipeline_rank < pipeline_size - 1:
|
h = layer(h, mask, c)
|
||||||
h = mx.distributed.recv_like(h, (pipeline_rank + 1), stream=dist_stream)
|
|
||||||
|
|
||||||
for i in range(self.num_layers):
|
|
||||||
h = self.layers[self.start_idx + i](h, mask, cache[i])
|
|
||||||
|
|
||||||
# Send to the next process in the pipeline
|
|
||||||
if pipeline_rank != 0:
|
|
||||||
h = mx.distributed.send(
|
|
||||||
h, (pipeline_rank - 1) % pipeline_size, stream=dist_stream
|
|
||||||
)
|
|
||||||
|
|
||||||
# Broadcast h while keeping it in the graph
|
|
||||||
h = mx.distributed.all_gather(h, stream=dist_stream)[: h.shape[0]]
|
|
||||||
|
|
||||||
return self.norm(h)
|
return self.norm(h)
|
||||||
|
|
||||||
@@ -457,4 +418,4 @@ class Model(nn.Module):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def layers(self):
|
def layers(self):
|
||||||
return self.model.layers[self.model.start_idx : self.model.end_idx]
|
return self.model.layers
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
import math
|
import math
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from functools import partial
|
|
||||||
from typing import Any, Dict, Optional, Tuple
|
from typing import Any, Dict, Optional, Tuple
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
@@ -126,12 +125,6 @@ class DeepseekV3YarnRotaryEmbedding(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# A clipped silu to prevent fp16 from overflowing
|
|
||||||
@partial(mx.compile, shapeless=True)
|
|
||||||
def clipped_silu(x):
|
|
||||||
return mx.clip(x * mx.sigmoid(x), a_min=-100, a_max=100)
|
|
||||||
|
|
||||||
|
|
||||||
class DeepseekV3Attention(nn.Module):
|
class DeepseekV3Attention(nn.Module):
|
||||||
def __init__(self, config: ModelArgs):
|
def __init__(self, config: ModelArgs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -319,10 +312,7 @@ class DeepseekV3MoE(nn.Module):
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.num_experts_per_tok = config.num_experts_per_tok
|
self.num_experts_per_tok = config.num_experts_per_tok
|
||||||
self.switch_mlp = SwitchGLU(
|
self.switch_mlp = SwitchGLU(
|
||||||
config.hidden_size,
|
config.hidden_size, config.moe_intermediate_size, config.n_routed_experts
|
||||||
config.moe_intermediate_size,
|
|
||||||
config.n_routed_experts,
|
|
||||||
activation=clipped_silu,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.gate = MoEGate(config)
|
self.gate = MoEGate(config)
|
||||||
@@ -369,7 +359,11 @@ class DeepseekV3DecoderLayer(nn.Module):
|
|||||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||||
h = x + r
|
h = x + r
|
||||||
r = self.mlp(self.post_attention_layernorm(h))
|
r = self.mlp(self.post_attention_layernorm(h))
|
||||||
return h + r
|
out = h + r
|
||||||
|
# Protect against overflow for fp16
|
||||||
|
if out.dtype == mx.float16:
|
||||||
|
out = mx.clip(out, a_min=None, a_max=mx.finfo(mx.float16).max - 1000)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
class DeepseekV3Model(nn.Module):
|
class DeepseekV3Model(nn.Module):
|
||||||
@@ -381,10 +375,6 @@ class DeepseekV3Model(nn.Module):
|
|||||||
DeepseekV3DecoderLayer(config, idx)
|
DeepseekV3DecoderLayer(config, idx)
|
||||||
for idx in range(config.num_hidden_layers)
|
for idx in range(config.num_hidden_layers)
|
||||||
]
|
]
|
||||||
self.start_idx = 0
|
|
||||||
self.end_idx = len(self.layers)
|
|
||||||
self.num_layers = self.end_idx
|
|
||||||
|
|
||||||
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
self.pipeline_rank = 0
|
self.pipeline_rank = 0
|
||||||
self.pipeline_size = 1
|
self.pipeline_size = 1
|
||||||
@@ -397,11 +387,8 @@ class DeepseekV3Model(nn.Module):
|
|||||||
layers_per_rank = (
|
layers_per_rank = (
|
||||||
len(self.layers) + self.pipeline_size - 1
|
len(self.layers) + self.pipeline_size - 1
|
||||||
) // self.pipeline_size
|
) // self.pipeline_size
|
||||||
self.start_idx = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
|
start = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
|
||||||
self.end_idx = self.start_idx + layers_per_rank
|
self.layers = self.layers[start : start + layers_per_rank]
|
||||||
self.layers = self.layers[: self.end_idx]
|
|
||||||
self.layers[: self.start_idx] = [None] * self.start_idx
|
|
||||||
self.num_layers = len(self.layers) - self.start_idx
|
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
@@ -419,15 +406,15 @@ class DeepseekV3Model(nn.Module):
|
|||||||
mask = create_attention_mask(h, cache)
|
mask = create_attention_mask(h, cache)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
cache = [None] * self.num_layers
|
cache = [None] * len(self.layers)
|
||||||
|
|
||||||
# Receive from the previous process in the pipeline
|
# Receive from the previous process in the pipeline
|
||||||
|
|
||||||
if pipeline_rank < pipeline_size - 1:
|
if pipeline_rank < pipeline_size - 1:
|
||||||
h = mx.distributed.recv_like(h, (pipeline_rank + 1), stream=dist_stream)
|
h = mx.distributed.recv_like(h, (pipeline_rank + 1), stream=dist_stream)
|
||||||
|
|
||||||
for i in range(self.num_layers):
|
for layer, c in zip(self.layers, cache):
|
||||||
h = self.layers[self.start_idx + i](h, mask, cache[i])
|
h = layer(h, mask, c)
|
||||||
|
|
||||||
# Send to the next process in the pipeline
|
# Send to the next process in the pipeline
|
||||||
if pipeline_rank != 0:
|
if pipeline_rank != 0:
|
||||||
@@ -475,4 +462,4 @@ class Model(nn.Module):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def layers(self):
|
def layers(self):
|
||||||
return self.model.layers[self.model.start_idx : self.model.end_idx]
|
return self.model.layers
|
||||||
|
|||||||
@@ -1,185 +0,0 @@
|
|||||||
# Copyright © 2025 Apple Inc.
|
|
||||||
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Any, Optional, Tuple
|
|
||||||
|
|
||||||
import mlx.core as mx
|
|
||||||
import mlx.nn as nn
|
|
||||||
|
|
||||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ModelArgs(BaseModelArgs):
|
|
||||||
hidden_size: int
|
|
||||||
num_hidden_layers: int
|
|
||||||
intermediate_size: int
|
|
||||||
num_attention_heads: int
|
|
||||||
num_key_value_heads: int
|
|
||||||
rms_norm_eps: float
|
|
||||||
vocab_size: int
|
|
||||||
attention_bias: bool
|
|
||||||
head_dim: int
|
|
||||||
max_position_embeddings: int
|
|
||||||
mlp_bias: bool
|
|
||||||
model_type: str
|
|
||||||
rope_theta: float
|
|
||||||
tie_word_embeddings: bool
|
|
||||||
|
|
||||||
|
|
||||||
class HeliumAttention(nn.Module):
|
|
||||||
def __init__(self, args: ModelArgs):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
dim = args.hidden_size
|
|
||||||
self.n_heads = n_heads = args.num_attention_heads
|
|
||||||
assert args.num_key_value_heads is not None
|
|
||||||
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
|
||||||
|
|
||||||
head_dim = args.hidden_size // n_heads
|
|
||||||
self.scale = head_dim**-0.5
|
|
||||||
|
|
||||||
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.attention_bias)
|
|
||||||
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias)
|
|
||||||
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias)
|
|
||||||
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
|
|
||||||
self.rope = nn.RoPE(head_dim, traditional=True, base=args.rope_theta)
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
x: mx.array,
|
|
||||||
mask: Optional[mx.array] = None,
|
|
||||||
cache: Optional[Any] = None,
|
|
||||||
) -> mx.array:
|
|
||||||
B, L, D = x.shape
|
|
||||||
|
|
||||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
|
||||||
|
|
||||||
# Prepare the queries, keys and values for the attention computation
|
|
||||||
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
|
||||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
|
||||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
|
||||||
|
|
||||||
if cache is not None:
|
|
||||||
queries = self.rope(queries, offset=cache.offset)
|
|
||||||
keys = self.rope(keys, offset=cache.offset)
|
|
||||||
keys, values = cache.update_and_fetch(keys, values)
|
|
||||||
else:
|
|
||||||
queries = self.rope(queries)
|
|
||||||
keys = self.rope(keys)
|
|
||||||
|
|
||||||
output = scaled_dot_product_attention(
|
|
||||||
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
|
||||||
)
|
|
||||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
|
||||||
return self.o_proj(output)
|
|
||||||
|
|
||||||
|
|
||||||
class HeliumMLP(nn.Module):
|
|
||||||
def __init__(self, args: ModelArgs):
|
|
||||||
super().__init__()
|
|
||||||
self.hidden_size = args.hidden_size
|
|
||||||
self.intermediate_size = args.intermediate_size
|
|
||||||
|
|
||||||
self.gate_proj = nn.Linear(
|
|
||||||
self.hidden_size, self.intermediate_size, bias=args.mlp_bias
|
|
||||||
)
|
|
||||||
self.up_proj = nn.Linear(
|
|
||||||
self.hidden_size, self.intermediate_size, bias=args.mlp_bias
|
|
||||||
)
|
|
||||||
self.down_proj = nn.Linear(
|
|
||||||
self.intermediate_size, self.hidden_size, bias=args.mlp_bias
|
|
||||||
)
|
|
||||||
|
|
||||||
def __call__(self, x: mx.array) -> mx.array:
|
|
||||||
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
|
|
||||||
|
|
||||||
|
|
||||||
class HeliumDecoderLayer(nn.Module):
|
|
||||||
def __init__(self, args: ModelArgs):
|
|
||||||
super().__init__()
|
|
||||||
self.hidden_size = args.hidden_size
|
|
||||||
|
|
||||||
self.self_attn = HeliumAttention(args)
|
|
||||||
self.mlp = HeliumMLP(args)
|
|
||||||
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
|
||||||
self.post_attention_layernorm = nn.RMSNorm(
|
|
||||||
args.hidden_size, eps=args.rms_norm_eps
|
|
||||||
)
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
x: mx.array,
|
|
||||||
mask: Optional[mx.array] = None,
|
|
||||||
cache: Optional[Any] = None,
|
|
||||||
) -> mx.array:
|
|
||||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
|
||||||
h = x + r
|
|
||||||
r = self.mlp(self.post_attention_layernorm(h))
|
|
||||||
out = h + r
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
class HeliumModel(nn.Module):
|
|
||||||
def __init__(self, args: ModelArgs):
|
|
||||||
super().__init__()
|
|
||||||
self.num_hidden_layers = args.num_hidden_layers
|
|
||||||
self.vocab_size = args.vocab_size
|
|
||||||
|
|
||||||
assert self.vocab_size > 0
|
|
||||||
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
|
||||||
|
|
||||||
self.layers = [HeliumDecoderLayer(args) for _ in range(args.num_hidden_layers)]
|
|
||||||
|
|
||||||
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
inputs: mx.array,
|
|
||||||
mask: mx.array = None,
|
|
||||||
cache=None,
|
|
||||||
) -> mx.array:
|
|
||||||
h = self.embed_tokens(inputs)
|
|
||||||
|
|
||||||
if mask is None:
|
|
||||||
mask = create_attention_mask(h, cache)
|
|
||||||
|
|
||||||
if cache is None:
|
|
||||||
cache = [None] * len(self.layers)
|
|
||||||
|
|
||||||
for layer, c in zip(self.layers, cache):
|
|
||||||
h = layer(h, mask, c)
|
|
||||||
|
|
||||||
return self.norm(h)
|
|
||||||
|
|
||||||
|
|
||||||
class Model(nn.Module):
|
|
||||||
def __init__(self, args: ModelArgs):
|
|
||||||
super().__init__()
|
|
||||||
self.args = args
|
|
||||||
self.model_type = args.model_type
|
|
||||||
|
|
||||||
self.model = HeliumModel(args)
|
|
||||||
|
|
||||||
self.vocab_size = args.vocab_size
|
|
||||||
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
|
||||||
|
|
||||||
if not args.tie_word_embeddings:
|
|
||||||
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
inputs: mx.array,
|
|
||||||
mask: mx.array = None,
|
|
||||||
cache=None,
|
|
||||||
) -> mx.array:
|
|
||||||
out = self.model(inputs, mask, cache)
|
|
||||||
if self.args.tie_word_embeddings:
|
|
||||||
out = self.model.embed_tokens.as_linear(out)
|
|
||||||
else:
|
|
||||||
out = self.lm_head(out)
|
|
||||||
return out
|
|
||||||
|
|
||||||
@property
|
|
||||||
def layers(self):
|
|
||||||
return self.model.layers
|
|
||||||
@@ -200,36 +200,6 @@ class Model(nn.Module):
|
|||||||
k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
|
k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
|
||||||
}
|
}
|
||||||
|
|
||||||
def shard(self, group: Optional[mx.distributed.Group] = None):
|
|
||||||
group = group or mx.distributed.init()
|
|
||||||
|
|
||||||
def all_to_sharded(l):
|
|
||||||
if isinstance(l, nn.QuantizedLinear):
|
|
||||||
return nn.QuantizedAllToShardedLinear.from_quantized_linear(l, group)
|
|
||||||
else:
|
|
||||||
return nn.AllToShardedLinear.from_linear(l, group)
|
|
||||||
|
|
||||||
def sharded_to_all(l):
|
|
||||||
if isinstance(l, nn.QuantizedLinear):
|
|
||||||
return nn.QuantizedShardedToAllLinear.from_quantized_linear(l, group)
|
|
||||||
else:
|
|
||||||
return nn.ShardedToAllLinear.from_linear(l, group)
|
|
||||||
|
|
||||||
N = group.size()
|
|
||||||
for layer in self.model.layers:
|
|
||||||
# Shard the self attention
|
|
||||||
layer.self_attn.q_proj = all_to_sharded(layer.self_attn.q_proj)
|
|
||||||
layer.self_attn.k_proj = all_to_sharded(layer.self_attn.k_proj)
|
|
||||||
layer.self_attn.v_proj = all_to_sharded(layer.self_attn.v_proj)
|
|
||||||
layer.self_attn.o_proj = sharded_to_all(layer.self_attn.o_proj)
|
|
||||||
layer.self_attn.n_heads //= N
|
|
||||||
layer.self_attn.n_kv_heads //= N
|
|
||||||
|
|
||||||
# Shard the MLP
|
|
||||||
layer.mlp.gate_proj = all_to_sharded(layer.mlp.gate_proj)
|
|
||||||
layer.mlp.down_proj = sharded_to_all(layer.mlp.down_proj)
|
|
||||||
layer.mlp.up_proj = all_to_sharded(layer.mlp.up_proj)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def layers(self):
|
def layers(self):
|
||||||
return self.model.layers
|
return self.model.layers
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
# Copyright © 2024-2025 Apple Inc.
|
# Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@@ -123,16 +123,17 @@ class MambaBlock(nn.Module):
|
|||||||
self.intermediate_size, self.hidden_size, bias=args.use_bias
|
self.intermediate_size, self.hidden_size, bias=args.use_bias
|
||||||
)
|
)
|
||||||
|
|
||||||
def ssm_step(self, x, A, state=None):
|
def ssm_step(self, x, state=None):
|
||||||
|
A = -mx.exp(self.A_log)
|
||||||
D = self.D
|
D = self.D
|
||||||
deltaBC = self.x_proj(x)
|
deltaBC = self.x_proj(x)
|
||||||
delta, B, C = map(
|
delta, B, C = mx.split(
|
||||||
self.mixer_norm if self.use_bcdt_rms else lambda x: x,
|
|
||||||
mx.split(
|
|
||||||
deltaBC,
|
deltaBC,
|
||||||
[self.time_step_rank, self.time_step_rank + self.ssm_state_size],
|
indices_or_sections=[
|
||||||
|
self.time_step_rank,
|
||||||
|
self.time_step_rank + self.ssm_state_size,
|
||||||
|
],
|
||||||
axis=-1,
|
axis=-1,
|
||||||
),
|
|
||||||
)
|
)
|
||||||
if self.use_bcdt_rms:
|
if self.use_bcdt_rms:
|
||||||
delta, B, C = map(self.mixer_norm, (delta, B, C))
|
delta, B, C = map(self.mixer_norm, (delta, B, C))
|
||||||
@@ -144,40 +145,25 @@ class MambaBlock(nn.Module):
|
|||||||
y = y + D * x
|
y = y + D * x
|
||||||
return y, new_state
|
return y, new_state
|
||||||
|
|
||||||
def _process_sequence(self, x, conv_cache, state_cache):
|
def __call__(self, x, cache):
|
||||||
B, T, D = x.shape
|
B, T, D = x.shape
|
||||||
xz = self.in_proj(x)
|
if cache is None:
|
||||||
x, z = xz.split(indices_or_sections=2, axis=-1)
|
cache = [None, None]
|
||||||
|
|
||||||
conv_out, new_conv_cache = self.conv1d(x, conv_cache)
|
|
||||||
x = nn.silu(conv_out)
|
|
||||||
|
|
||||||
A = -mx.exp(self.A_log)
|
|
||||||
|
|
||||||
outputs = []
|
outputs = []
|
||||||
current_state = state_cache
|
|
||||||
y = []
|
|
||||||
for t in range(T):
|
for t in range(T):
|
||||||
y_t, current_state = self.ssm_step(x[:, t], A, current_state)
|
xt = x[:, t, :]
|
||||||
y.append(y_t)
|
xz = self.in_proj(xt)
|
||||||
y = mx.stack(y, axis=1)
|
x_t, z_t = xz.split(indices_or_sections=2, axis=1)
|
||||||
z = self.out_proj(nn.silu(z) * y)
|
conv_out, cache[0] = self.conv1d(mx.expand_dims(x_t, 1), cache[0])
|
||||||
return z, (new_conv_cache, current_state)
|
x_t = conv_out.squeeze(1)
|
||||||
|
x_t = nn.silu(x_t)
|
||||||
def __call__(self, x, cache):
|
y_t, cache[1] = self.ssm_step(x_t, cache[1])
|
||||||
if cache is None:
|
z_t = nn.silu(z_t)
|
||||||
conv_cache, state_cache = None, None
|
output_t = y_t * z_t
|
||||||
else:
|
output_t = self.out_proj(output_t)
|
||||||
conv_cache, state_cache = cache[0], cache[1]
|
outputs.append(output_t)
|
||||||
|
output = mx.stack(outputs, axis=1)
|
||||||
output, (new_conv_cache, new_state_cache) = self._process_sequence(
|
|
||||||
x, conv_cache, state_cache
|
|
||||||
)
|
|
||||||
|
|
||||||
if isinstance(cache, MambaCache):
|
|
||||||
cache[0] = new_conv_cache
|
|
||||||
cache[1] = new_state_cache
|
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
# Copyright © 2023-2025 Apple Inc.
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict, Optional, Tuple, Union
|
from typing import Any, Dict, Optional, Tuple, Union
|
||||||
|
|||||||
@@ -147,11 +147,11 @@ def min_p_sampling(
|
|||||||
logprobs = logprobs * (1 / temperature)
|
logprobs = logprobs * (1 / temperature)
|
||||||
|
|
||||||
# Indices sorted in decreasing order
|
# Indices sorted in decreasing order
|
||||||
sorted_indices = mx.argsort(-logprobs, axis=-1)
|
sorted_indices = mx.argsort(-logprobs).squeeze(0)
|
||||||
sorted_logprobs = mx.take_along_axis(logprobs, sorted_indices, axis=-1)
|
sorted_logprobs = logprobs[..., sorted_indices]
|
||||||
|
|
||||||
# Top probability
|
# Top probability
|
||||||
top_logprobs = sorted_logprobs[:, 0:1]
|
top_logprobs = logprobs[..., sorted_indices[0]]
|
||||||
|
|
||||||
# Calculate the min_p threshold
|
# Calculate the min_p threshold
|
||||||
scaled_min_p = top_logprobs + math.log(min_p)
|
scaled_min_p = top_logprobs + math.log(min_p)
|
||||||
@@ -163,9 +163,9 @@ def min_p_sampling(
|
|||||||
# Create pool of tokens with probability less than scaled min_p
|
# Create pool of tokens with probability less than scaled min_p
|
||||||
selected_logprobs = mx.where(tokens_to_remove, -float("inf"), sorted_logprobs)
|
selected_logprobs = mx.where(tokens_to_remove, -float("inf"), sorted_logprobs)
|
||||||
|
|
||||||
# Return sampled tokens
|
# Return sampled token
|
||||||
sorted_tokens = mx.random.categorical(selected_logprobs, axis=-1)[:, None]
|
sorted_token = mx.random.categorical(selected_logprobs)
|
||||||
return mx.take_along_axis(sorted_indices, sorted_tokens, axis=-1).squeeze(1)
|
return sorted_indices[sorted_token]
|
||||||
|
|
||||||
|
|
||||||
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
|
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
|
||||||
@@ -185,7 +185,7 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.arr
|
|||||||
|
|
||||||
# sort probs in ascending order
|
# sort probs in ascending order
|
||||||
sorted_indices = mx.argsort(probs, axis=-1)
|
sorted_indices = mx.argsort(probs, axis=-1)
|
||||||
sorted_probs = mx.take_along_axis(probs, sorted_indices, axis=-1)
|
sorted_probs = probs[..., sorted_indices.squeeze(0)]
|
||||||
|
|
||||||
cumulative_probs = mx.cumsum(sorted_probs, axis=-1)
|
cumulative_probs = mx.cumsum(sorted_probs, axis=-1)
|
||||||
|
|
||||||
@@ -196,8 +196,10 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.arr
|
|||||||
0,
|
0,
|
||||||
)
|
)
|
||||||
|
|
||||||
sorted_tokens = mx.random.categorical(mx.log(top_probs), axis=-1)[:, None]
|
sorted_token = mx.random.categorical(mx.log(top_probs))
|
||||||
return mx.take_along_axis(sorted_indices, sorted_tokens, axis=-1).squeeze(1)
|
token = sorted_indices.squeeze(0)[sorted_token]
|
||||||
|
|
||||||
|
return token
|
||||||
|
|
||||||
|
|
||||||
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
|
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
|
||||||
|
|||||||
@@ -114,33 +114,6 @@ def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None):
|
|||||||
return prompt.rstrip()
|
return prompt.rstrip()
|
||||||
|
|
||||||
|
|
||||||
def process_message_content(messages):
|
|
||||||
"""
|
|
||||||
Convert message content to a format suitable for `apply_chat_template`.
|
|
||||||
|
|
||||||
The function operates on messages in place. It converts the 'content' field
|
|
||||||
to a string instead of a list of text fragments.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
message_list (list): A list of dictionaries, where each dictionary may
|
|
||||||
have a 'content' key containing a list of dictionaries with 'type' and
|
|
||||||
'text' keys.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If the 'content' type is not supported or if 'text' is missing.
|
|
||||||
|
|
||||||
"""
|
|
||||||
for message in messages:
|
|
||||||
content = message["content"]
|
|
||||||
if isinstance(content, list):
|
|
||||||
text_fragments = [
|
|
||||||
fragment["text"] for fragment in content if fragment["type"] == "text"
|
|
||||||
]
|
|
||||||
if len(text_fragments) != len(content):
|
|
||||||
raise ValueError("Only 'text' content type is supported.")
|
|
||||||
message["content"] = "".join(text_fragments)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PromptCache:
|
class PromptCache:
|
||||||
cache: List[Any] = field(default_factory=list)
|
cache: List[Any] = field(default_factory=list)
|
||||||
@@ -618,10 +591,8 @@ class APIHandler(BaseHTTPRequestHandler):
|
|||||||
self.request_id = f"chatcmpl-{uuid.uuid4()}"
|
self.request_id = f"chatcmpl-{uuid.uuid4()}"
|
||||||
self.object_type = "chat.completion.chunk" if self.stream else "chat.completion"
|
self.object_type = "chat.completion.chunk" if self.stream else "chat.completion"
|
||||||
if self.tokenizer.chat_template:
|
if self.tokenizer.chat_template:
|
||||||
messages = body["messages"]
|
|
||||||
process_message_content(messages)
|
|
||||||
prompt = self.tokenizer.apply_chat_template(
|
prompt = self.tokenizer.apply_chat_template(
|
||||||
messages,
|
body["messages"],
|
||||||
body.get("tools", None),
|
body.get("tools", None),
|
||||||
add_generation_prompt=True,
|
add_generation_prompt=True,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -140,8 +140,8 @@ def evaluate(
|
|||||||
loss: callable = default_loss,
|
loss: callable = default_loss,
|
||||||
iterate_batches: callable = iterate_batches,
|
iterate_batches: callable = iterate_batches,
|
||||||
):
|
):
|
||||||
all_losses = mx.array(0.0)
|
all_losses = 0
|
||||||
ntokens = mx.array(0)
|
ntokens = 0
|
||||||
|
|
||||||
index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
|
index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
|
||||||
|
|
||||||
|
|||||||
@@ -94,7 +94,6 @@ def linear_to_lora_layers(
|
|||||||
"phimoe",
|
"phimoe",
|
||||||
"gemma",
|
"gemma",
|
||||||
"gemma2",
|
"gemma2",
|
||||||
"helium",
|
|
||||||
"starcoder2",
|
"starcoder2",
|
||||||
"cohere",
|
"cohere",
|
||||||
"cohere2",
|
"cohere2",
|
||||||
|
|||||||
@@ -306,12 +306,12 @@ def generate_step(
|
|||||||
|
|
||||||
y, logprobs = _step(y)
|
y, logprobs = _step(y)
|
||||||
|
|
||||||
mx.eval(y, logprobs)
|
mx.async_eval(y, logprobs)
|
||||||
n = 0
|
n = 0
|
||||||
while True:
|
while True:
|
||||||
if n != max_tokens:
|
if n != max_tokens:
|
||||||
next_y, next_logprobs = _step(y)
|
next_y, next_logprobs = _step(y)
|
||||||
mx.eval(next_y, next_logprobs)
|
mx.async_eval(next_y, next_logprobs)
|
||||||
if n == 0:
|
if n == 0:
|
||||||
mx.eval(y)
|
mx.eval(y)
|
||||||
prompt_progress_callback(total_prompt_tokens, total_prompt_tokens)
|
prompt_progress_callback(total_prompt_tokens, total_prompt_tokens)
|
||||||
@@ -398,9 +398,8 @@ def speculative_generate_step(
|
|||||||
quantize_cache_fn(cache)
|
quantize_cache_fn(cache)
|
||||||
|
|
||||||
logprobs = logits - mx.logsumexp(logits, keepdims=True)
|
logprobs = logits - mx.logsumexp(logits, keepdims=True)
|
||||||
logprobs = logprobs.squeeze(0)
|
y = sampler(logprobs).squeeze(0)
|
||||||
y = sampler(logprobs)
|
return y, logprobs.squeeze(0)
|
||||||
return y, logprobs
|
|
||||||
|
|
||||||
def _prefill(model, cache, y):
|
def _prefill(model, cache, y):
|
||||||
while y.size > prefill_step_size:
|
while y.size > prefill_step_size:
|
||||||
@@ -627,8 +626,6 @@ def load_config(model_path: Path) -> dict:
|
|||||||
def load_model(
|
def load_model(
|
||||||
model_path: Path,
|
model_path: Path,
|
||||||
lazy: bool = False,
|
lazy: bool = False,
|
||||||
strict: bool = True,
|
|
||||||
sequential_load: bool = False,
|
|
||||||
model_config: dict = {},
|
model_config: dict = {},
|
||||||
get_model_classes: Callable[[dict], Tuple[Type[nn.Module], Type]] = _get_classes,
|
get_model_classes: Callable[[dict], Tuple[Type[nn.Module], Type]] = _get_classes,
|
||||||
) -> nn.Module:
|
) -> nn.Module:
|
||||||
@@ -640,8 +637,6 @@ def load_model(
|
|||||||
lazy (bool): If False eval the model parameters to make sure they are
|
lazy (bool): If False eval the model parameters to make sure they are
|
||||||
loaded in memory before returning, otherwise they will be loaded
|
loaded in memory before returning, otherwise they will be loaded
|
||||||
when needed. Default: ``False``
|
when needed. Default: ``False``
|
||||||
strict (bool): Whether or not to raise an exception if weights don't
|
|
||||||
match. Default: ``True``
|
|
||||||
model_config (dict, optional): Optional configuration parameters for the
|
model_config (dict, optional): Optional configuration parameters for the
|
||||||
model. Defaults to an empty dictionary.
|
model. Defaults to an empty dictionary.
|
||||||
get_model_classes (Callable[[dict], Tuple[Type[nn.Module], Type]], optional):
|
get_model_classes (Callable[[dict], Tuple[Type[nn.Module], Type]], optional):
|
||||||
@@ -664,7 +659,7 @@ def load_model(
|
|||||||
# Try weight for back-compat
|
# Try weight for back-compat
|
||||||
weight_files = glob.glob(str(model_path / "weight*.safetensors"))
|
weight_files = glob.glob(str(model_path / "weight*.safetensors"))
|
||||||
|
|
||||||
if not weight_files and strict:
|
if not weight_files:
|
||||||
logging.error(f"No safetensors found in {model_path}")
|
logging.error(f"No safetensors found in {model_path}")
|
||||||
raise FileNotFoundError(f"No safetensors found in {model_path}")
|
raise FileNotFoundError(f"No safetensors found in {model_path}")
|
||||||
|
|
||||||
@@ -698,18 +693,9 @@ def load_model(
|
|||||||
class_predicate=class_predicate,
|
class_predicate=class_predicate,
|
||||||
)
|
)
|
||||||
|
|
||||||
model.load_weights(list(weights.items()), strict=strict)
|
model.load_weights(list(weights.items()))
|
||||||
|
|
||||||
if mx.distributed.init().size() > 1:
|
|
||||||
if not hasattr(model, "shard"):
|
|
||||||
raise RuntimeError("Model doesn't support distributed inference.")
|
|
||||||
model.shard()
|
|
||||||
|
|
||||||
if not lazy:
|
if not lazy:
|
||||||
weights.clear()
|
|
||||||
if sequential_load:
|
|
||||||
for layer in model.layers:
|
|
||||||
mx.eval(layer.parameters())
|
|
||||||
mx.eval(model.parameters())
|
mx.eval(model.parameters())
|
||||||
|
|
||||||
model.eval()
|
model.eval()
|
||||||
@@ -722,7 +708,6 @@ def load(
|
|||||||
model_config={},
|
model_config={},
|
||||||
adapter_path: Optional[str] = None,
|
adapter_path: Optional[str] = None,
|
||||||
lazy: bool = False,
|
lazy: bool = False,
|
||||||
sequential_load: bool = False,
|
|
||||||
) -> Tuple[nn.Module, TokenizerWrapper]:
|
) -> Tuple[nn.Module, TokenizerWrapper]:
|
||||||
"""
|
"""
|
||||||
Load the model and tokenizer from a given path or a huggingface repository.
|
Load the model and tokenizer from a given path or a huggingface repository.
|
||||||
@@ -738,8 +723,6 @@ def load(
|
|||||||
lazy (bool): If ``False`` eval the model parameters to make sure they are
|
lazy (bool): If ``False`` eval the model parameters to make sure they are
|
||||||
loaded in memory before returning, otherwise they will be loaded
|
loaded in memory before returning, otherwise they will be loaded
|
||||||
when needed. Default: ``False``
|
when needed. Default: ``False``
|
||||||
sequential_load (bool): If True then load each layer sequentially to
|
|
||||||
ensure that we are not wasting memory.
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[nn.Module, TokenizerWrapper]: A tuple containing the loaded model and tokenizer.
|
Tuple[nn.Module, TokenizerWrapper]: A tuple containing the loaded model and tokenizer.
|
||||||
|
|
||||||
@@ -749,7 +732,7 @@ def load(
|
|||||||
"""
|
"""
|
||||||
model_path = get_model_path(path_or_hf_repo)
|
model_path = get_model_path(path_or_hf_repo)
|
||||||
|
|
||||||
model, config = load_model(model_path, lazy=lazy, sequential_load=sequential_load)
|
model, config = load_model(model_path, lazy)
|
||||||
if adapter_path is not None:
|
if adapter_path is not None:
|
||||||
model = load_adapters(model, adapter_path)
|
model = load_adapters(model, adapter_path)
|
||||||
model.eval()
|
model.eval()
|
||||||
@@ -763,7 +746,7 @@ def load(
|
|||||||
def fetch_from_hub(
|
def fetch_from_hub(
|
||||||
model_path: Path, lazy: bool = False
|
model_path: Path, lazy: bool = False
|
||||||
) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:
|
) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:
|
||||||
model, config = load_model(model_path, lazy=lazy)
|
model, config = load_model(model_path, lazy)
|
||||||
tokenizer = load_tokenizer(
|
tokenizer = load_tokenizer(
|
||||||
model_path, eos_token_ids=config.get("eos_token_id", None)
|
model_path, eos_token_ids=config.get("eos_token_id", None)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -28,12 +28,6 @@ class TestSampleUtils(unittest.TestCase):
|
|||||||
token = top_p_sampling(logits, 0.95, temperature).item()
|
token = top_p_sampling(logits, 0.95, temperature).item()
|
||||||
self.assertTrue(token in (1, 2, 3))
|
self.assertTrue(token in (1, 2, 3))
|
||||||
|
|
||||||
# Batch mode works
|
|
||||||
probs = mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.8, 0.0, 0.1]])
|
|
||||||
logits = mx.log(probs)
|
|
||||||
tokens = top_p_sampling(logits, 0.5, temperature)
|
|
||||||
self.assertEqual(tokens.tolist(), [0, 1])
|
|
||||||
|
|
||||||
def test_min_p_sampling(self):
|
def test_min_p_sampling(self):
|
||||||
probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
|
probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
|
||||||
logits = mx.log(probs)
|
logits = mx.log(probs)
|
||||||
@@ -48,12 +42,6 @@ class TestSampleUtils(unittest.TestCase):
|
|||||||
token = min_p_sampling(logits, 0.05)
|
token = min_p_sampling(logits, 0.05)
|
||||||
self.assertTrue(token in (0, 3))
|
self.assertTrue(token in (0, 3))
|
||||||
|
|
||||||
# Batch mode works
|
|
||||||
probs = mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.8, 0.0, 0.1]])
|
|
||||||
logits = mx.log(probs)
|
|
||||||
tokens = min_p_sampling(logits, 0.7)
|
|
||||||
self.assertEqual(tokens.tolist(), [0, 1])
|
|
||||||
|
|
||||||
def test_top_k_sampling(self):
|
def test_top_k_sampling(self):
|
||||||
probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
|
probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
|
||||||
logits = mx.log(probs)
|
logits = mx.log(probs)
|
||||||
|
|||||||
@@ -80,29 +80,6 @@ class TestServer(unittest.TestCase):
|
|||||||
self.assertIn("id", response_body)
|
self.assertIn("id", response_body)
|
||||||
self.assertIn("choices", response_body)
|
self.assertIn("choices", response_body)
|
||||||
|
|
||||||
def test_handle_chat_completions_with_content_fragments(self):
|
|
||||||
url = f"http://localhost:{self.port}/v1/chat/completions"
|
|
||||||
chat_post_data = {
|
|
||||||
"model": "chat_model",
|
|
||||||
"max_tokens": 10,
|
|
||||||
"temperature": 0.7,
|
|
||||||
"top_p": 0.85,
|
|
||||||
"repetition_penalty": 1.2,
|
|
||||||
"messages": [
|
|
||||||
{
|
|
||||||
"role": "system",
|
|
||||||
"content": [
|
|
||||||
{"type": "text", "text": "You are a helpful assistant."}
|
|
||||||
],
|
|
||||||
},
|
|
||||||
{"role": "user", "content": [{"type": "text", "text": "Hello!"}]},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
response = requests.post(url, json=chat_post_data)
|
|
||||||
response_body = response.text
|
|
||||||
self.assertIn("id", response_body)
|
|
||||||
self.assertIn("choices", response_body)
|
|
||||||
|
|
||||||
def test_handle_models(self):
|
def test_handle_models(self):
|
||||||
url = f"http://localhost:{self.port}/v1/models"
|
url = f"http://localhost:{self.port}/v1/models"
|
||||||
response = requests.get(url)
|
response = requests.get(url)
|
||||||
|
|||||||
Reference in New Issue
Block a user