mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-12-16 02:08:55 +08:00
Compare commits
14 Commits
dist-eval
...
distribute
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
65b792d7c0 | ||
|
|
617f9289b9 | ||
|
|
026362e0f8 | ||
|
|
a0ce0594f6 | ||
|
|
d77840207c | ||
|
|
e2e5478da5 | ||
|
|
21d0ab6e8a | ||
|
|
0989c073b0 | ||
|
|
d9924d08d1 | ||
|
|
9c2ef38d4d | ||
|
|
e8afb59de4 | ||
|
|
7a83077cd7 | ||
|
|
f44a52e2dc | ||
|
|
77faa14ba4 |
@@ -14,4 +14,4 @@ MLX Examples was developed with contributions from the following individuals:
|
|||||||
- Markus Enzweiler: Added the `cvae` examples.
|
- Markus Enzweiler: Added the `cvae` examples.
|
||||||
- Prince Canuma: Helped add support for `Starcoder2` models.
|
- Prince Canuma: Helped add support for `Starcoder2` models.
|
||||||
- Shiyu Li: Added the `Segment Anything Model`.
|
- Shiyu Li: Added the `Segment Anything Model`.
|
||||||
- Gökdeniz Gülmez: Added support for `MiniCPM`, `Mamba` and support for `full-fine-tuning`.
|
- Gökdeniz Gülmez: Added support for `MiniCPM`, `Helium`, `Mamba version 1` and support for `full-fine-tuning`.
|
||||||
@@ -45,7 +45,7 @@ Some more useful examples are listed below.
|
|||||||
|
|
||||||
### Hugging Face
|
### Hugging Face
|
||||||
|
|
||||||
Note: You can now directly download a few converted checkpoints from the [MLX
|
You can directly use or download converted checkpoints from the [MLX
|
||||||
Community](https://huggingface.co/mlx-community) organization on Hugging Face.
|
Community](https://huggingface.co/mlx-community) organization on Hugging Face.
|
||||||
We encourage you to join the community and [contribute new
|
We encourage you to join the community and [contribute new
|
||||||
models](https://github.com/ml-explore/mlx-examples/issues/155).
|
models](https://github.com/ml-explore/mlx-examples/issues/155).
|
||||||
|
|||||||
@@ -164,7 +164,7 @@ mlx_lm.convert \
|
|||||||
```
|
```
|
||||||
|
|
||||||
Models can also be converted and quantized directly in the
|
Models can also be converted and quantized directly in the
|
||||||
[mlx-my-repo]https://huggingface.co/spaces/mlx-community/mlx-my-repo) Hugging
|
[mlx-my-repo](https://huggingface.co/spaces/mlx-community/mlx-my-repo) Hugging
|
||||||
Face Space.
|
Face Space.
|
||||||
|
|
||||||
### Long Prompts and Generations
|
### Long Prompts and Generations
|
||||||
|
|||||||
@@ -16,6 +16,25 @@ DEFAULT_MAX_TOKENS = 256
|
|||||||
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
|
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
|
||||||
|
|
||||||
|
|
||||||
|
def share_message(world, prompt):
|
||||||
|
if world.size() == 1:
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
if world.rank() == 0:
|
||||||
|
size = mx.array([len(prompt)])
|
||||||
|
else:
|
||||||
|
size = mx.array([0])
|
||||||
|
size = mx.distributed.all_sum(size, stream=mx.cpu).item()
|
||||||
|
if size == 0:
|
||||||
|
return []
|
||||||
|
|
||||||
|
if world.rank() == 0:
|
||||||
|
prompt = mx.array(prompt)
|
||||||
|
else:
|
||||||
|
prompt = mx.array([0] * len(prompt))
|
||||||
|
return mx.distributed.all_sum(size, stream=mx.cpu).tolist()
|
||||||
|
|
||||||
|
|
||||||
def setup_arg_parser():
|
def setup_arg_parser():
|
||||||
"""Set up and return the argument parser."""
|
"""Set up and return the argument parser."""
|
||||||
parser = argparse.ArgumentParser(description="Chat with an LLM")
|
parser = argparse.ArgumentParser(description="Chat with an LLM")
|
||||||
@@ -54,6 +73,7 @@ def setup_arg_parser():
|
|||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
world = mx.distributed.init()
|
||||||
parser = setup_arg_parser()
|
parser = setup_arg_parser()
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@@ -63,16 +83,30 @@ def main():
|
|||||||
args.model,
|
args.model,
|
||||||
adapter_path=args.adapter_path,
|
adapter_path=args.adapter_path,
|
||||||
tokenizer_config={"trust_remote_code": True},
|
tokenizer_config={"trust_remote_code": True},
|
||||||
|
sequential_load=mx.distributed.init().size() > 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.")
|
print(f"Node {world.rank()} of {world.size()}", flush=True)
|
||||||
|
print(
|
||||||
|
f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
world.barrier()
|
||||||
prompt_cache = make_prompt_cache(model, args.max_kv_size)
|
prompt_cache = make_prompt_cache(model, args.max_kv_size)
|
||||||
while True:
|
while True:
|
||||||
|
if world.rank() == 0:
|
||||||
query = input(">> ")
|
query = input(">> ")
|
||||||
if query == "q":
|
if query == "q":
|
||||||
break
|
prompt = []
|
||||||
|
else:
|
||||||
messages = [{"role": "user", "content": query}]
|
messages = [{"role": "user", "content": query}]
|
||||||
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
|
prompt = tokenizer.apply_chat_template(
|
||||||
|
messages, add_generation_prompt=True
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt = share_message(world, prompt)
|
||||||
|
if len(prompt) == 0:
|
||||||
|
break
|
||||||
for response in stream_generate(
|
for response in stream_generate(
|
||||||
model,
|
model,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
@@ -81,7 +115,9 @@ def main():
|
|||||||
sampler=make_sampler(args.temp, args.top_p),
|
sampler=make_sampler(args.temp, args.top_p),
|
||||||
prompt_cache=prompt_cache,
|
prompt_cache=prompt_cache,
|
||||||
):
|
):
|
||||||
print(response.text, flush=True, end="")
|
if world.rank() == 0:
|
||||||
|
print(response, flush=True, end="")
|
||||||
|
if world.rank() == 0:
|
||||||
print()
|
print()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import logging
|
|||||||
import os
|
import os
|
||||||
from importlib.metadata import version
|
from importlib.metadata import version
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional, Union
|
||||||
|
|
||||||
import lm_eval
|
import lm_eval
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
@@ -20,10 +20,11 @@ from lm_eval.api.model import LM
|
|||||||
from lm_eval.api.registry import register_model
|
from lm_eval.api.registry import register_model
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from .models.base import create_causal_mask
|
|
||||||
from .models.cache import make_prompt_cache
|
from .models.cache import make_prompt_cache
|
||||||
from .utils import load, stream_generate
|
from .utils import load, stream_generate
|
||||||
|
|
||||||
|
PAD = 0
|
||||||
|
|
||||||
|
|
||||||
def _len_longest_common_prefix(a, b):
|
def _len_longest_common_prefix(a, b):
|
||||||
l = 0
|
l = 0
|
||||||
@@ -42,14 +43,31 @@ def _rstrip_until(s, untils):
|
|||||||
return s[: min(f)]
|
return s[: min(f)]
|
||||||
|
|
||||||
|
|
||||||
def _pad_inputs(inputs):
|
def _pad_inputs(
|
||||||
lengths = np.array([len(x) for x in inputs])
|
inputs,
|
||||||
maxlen = lengths.max()
|
maxlen,
|
||||||
padded = np.stack(
|
genlen=0,
|
||||||
[np.pad(x, (0, maxlen - len(x))) for x in inputs],
|
pad_left=False,
|
||||||
|
pad_multiple=32,
|
||||||
|
truncate=False,
|
||||||
|
):
|
||||||
|
# pad the prompts to the left with at least genlen tokens.
|
||||||
|
actual_maxlen = max(len(p) for p in inputs) + genlen
|
||||||
|
if actual_maxlen > maxlen:
|
||||||
|
if not truncate:
|
||||||
|
raise ValueError("Inputs are too long.")
|
||||||
|
else: # drop begining
|
||||||
|
actual_maxlen = maxlen
|
||||||
|
inputs = [p[max(0, len(p) - maxlen) :] for p in inputs]
|
||||||
|
if pad_multiple > 0:
|
||||||
|
maxlen = (actual_maxlen + pad_multiple - 1) // pad_multiple
|
||||||
|
maxlen *= pad_multiple
|
||||||
|
assert PAD == 0
|
||||||
|
lr = np.array((1, 0) if pad_left else (0, 1))
|
||||||
|
return np.stack(
|
||||||
|
[np.pad(np.array(x, np.int32), lr * (maxlen - len(x))) for x in inputs],
|
||||||
axis=0,
|
axis=0,
|
||||||
)
|
)
|
||||||
return mx.array(padded), mx.array(lengths)
|
|
||||||
|
|
||||||
|
|
||||||
@register_model("mlxlm")
|
@register_model("mlxlm")
|
||||||
@@ -65,33 +83,32 @@ class MLXLM(LM):
|
|||||||
self._batch_size = batch_size
|
self._batch_size = batch_size
|
||||||
self._model, self.tokenizer = load(path_or_hf_repo)
|
self._model, self.tokenizer = load(path_or_hf_repo)
|
||||||
self._max_tokens = max_tokens or self.tokenizer.model_max_length
|
self._max_tokens = max_tokens or self.tokenizer.model_max_length
|
||||||
self.use_chat_template = use_chat_template and (
|
self.use_chat_template = use_chat_template or (
|
||||||
self.tokenizer.chat_template is not None
|
self.tokenizer.chat_template is not None
|
||||||
)
|
)
|
||||||
|
|
||||||
def _score_fn(self, inputs, step_size: int = 64):
|
def _score_fn(self, inputs, tokenize=True, step_size=32):
|
||||||
inputs, lengths = _pad_inputs(inputs)
|
if tokenize:
|
||||||
|
inputs = self._tokenize(inputs)
|
||||||
|
inputs = _pad_inputs(inputs, self._max_tokens, truncate=False)
|
||||||
|
inputs = mx.array(inputs)
|
||||||
inputs, targets = inputs[..., :-1], inputs[..., 1:]
|
inputs, targets = inputs[..., :-1], inputs[..., 1:]
|
||||||
|
|
||||||
cache = make_prompt_cache(self._model)
|
cache = make_prompt_cache(self._model)
|
||||||
|
|
||||||
|
mask = targets != PAD
|
||||||
|
|
||||||
scores, is_greedy = [], []
|
scores, is_greedy = [], []
|
||||||
for i in range(0, inputs.shape[1], step_size):
|
for i in range(0, inputs.shape[1], step_size):
|
||||||
inp = inputs[:, i : i + step_size]
|
logits = self._model(inputs[:, i : i + step_size], cache=cache)
|
||||||
T = inp.shape[1]
|
|
||||||
|
|
||||||
offset = cache[0].offset
|
|
||||||
mask = create_causal_mask(T, offset, lengths=lengths)
|
|
||||||
mask = mask == 0
|
|
||||||
|
|
||||||
logits = self._model(inp, cache=cache, mask=mask)
|
|
||||||
log_probs = nn.log_softmax(logits.astype(mx.float32))
|
log_probs = nn.log_softmax(logits.astype(mx.float32))
|
||||||
|
|
||||||
score = mx.take_along_axis(
|
score = mx.take_along_axis(
|
||||||
log_probs, targets[:, i : i + step_size, mx.newaxis], axis=-1
|
log_probs, targets[:, i : i + step_size, mx.newaxis], axis=-1
|
||||||
)[..., 0]
|
)[..., 0]
|
||||||
ig = targets[:, i : i + step_size] == mx.argmax(logits, axis=-1)
|
ig = mask[:, i : i + step_size] * (
|
||||||
ig = mx.where(mx.arange(T) + offset < lengths[:, None], ig, False)
|
targets[:, i : i + step_size] == mx.argmax(logits, axis=-1)
|
||||||
|
)
|
||||||
|
|
||||||
mx.eval(score, ig)
|
mx.eval(score, ig)
|
||||||
mx.metal.clear_cache()
|
mx.metal.clear_cache()
|
||||||
@@ -102,32 +119,38 @@ class MLXLM(LM):
|
|||||||
scores = mx.concatenate(scores, axis=1)
|
scores = mx.concatenate(scores, axis=1)
|
||||||
is_greedy = mx.concatenate(is_greedy, axis=1)
|
is_greedy = mx.concatenate(is_greedy, axis=1)
|
||||||
|
|
||||||
return scores, lengths, is_greedy
|
return scores, mask.sum(axis=-1), is_greedy
|
||||||
|
|
||||||
def _loglikelihood(self, texts, score_spans=None):
|
def _loglikelihood(self, texts, score_spans=None, tokenize=True):
|
||||||
all_scores = mx.zeros(len(texts))
|
# sort by length to get batches with little padding.
|
||||||
all_is_greedy = mx.zeros(len(texts), dtype=mx.bool_)
|
sorted_indices = sorted(range(len(texts)), key=lambda i: -len(texts[i]))
|
||||||
for i in tqdm(range(0, len(texts), self._batch_size)):
|
sorted_inputs = [texts[sorted_indices[i]] for i in range(len(texts))]
|
||||||
batch = texts[i : i + self._batch_size]
|
sorted_spans = None
|
||||||
scores, lengths, is_greedy = self._score_fn(batch)
|
|
||||||
|
|
||||||
ind = np.arange(scores.shape[-1])
|
|
||||||
if score_spans is not None:
|
if score_spans is not None:
|
||||||
spans = score_spans[i : i + self._batch_size]
|
sorted_spans = [score_spans[sorted_indices[i]] for i in range(len(texts))]
|
||||||
lengths = [end - start for start, end in spans]
|
|
||||||
masks = mx.array(
|
|
||||||
np.array([(ind >= start) & (ind < end) for start, end in spans])
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
masks = ind[None] < lengths[:, None]
|
|
||||||
|
|
||||||
scores = (masks * scores).sum(axis=-1)
|
results = []
|
||||||
is_greedy = (masks * is_greedy).sum(axis=-1)
|
for i in tqdm(range(0, len(sorted_inputs), self._batch_size)):
|
||||||
|
batch = sorted_inputs[i : i + self._batch_size]
|
||||||
|
scores, length, is_greedy = self._score_fn(batch, tokenize=tokenize)
|
||||||
|
for j in range(len(batch)):
|
||||||
|
if sorted_spans is None: # full sequence score
|
||||||
|
mask = mx.arange(scores[j].shape[-1]) < length
|
||||||
|
score = (scores[j].astype(mx.float32) * mask).sum(axis=-1)
|
||||||
|
ig = (is_greedy[j].astype(mx.int32) * mask).sum(axis=-1)
|
||||||
|
else: # subsequence score
|
||||||
|
start, end = sorted_spans[i + j]
|
||||||
|
score = scores[j][start:end].astype(mx.float32).sum()
|
||||||
|
ig = is_greedy[j][start:end].astype(mx.int32).sum()
|
||||||
|
length = end - start
|
||||||
|
|
||||||
all_scores[i : i + self._batch_size] = scores
|
results.append((score.item(), ig.item(), length))
|
||||||
all_is_greedy[i : i + self._batch_size] = is_greedy == lengths
|
|
||||||
|
|
||||||
return all_scores, all_is_greedy
|
# reorder the outputs
|
||||||
|
inv_sort = np.argsort(sorted_indices)
|
||||||
|
results = [results[inv_sort[i]] for i in range(len(results))]
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
def _tokenize(self, texts):
|
def _tokenize(self, texts):
|
||||||
return [
|
return [
|
||||||
@@ -199,53 +222,16 @@ class MLXLM(LM):
|
|||||||
+ "completion longer than context."
|
+ "completion longer than context."
|
||||||
)
|
)
|
||||||
|
|
||||||
num_results = len(shortened)
|
|
||||||
|
|
||||||
# sort by length to get batches with little padding.
|
|
||||||
sorted_indices = sorted(range(len(shortened)), key=lambda i: -len(shortened[i]))
|
|
||||||
shortened = [shortened[i] for i in sorted_indices]
|
|
||||||
completion_spans = [completion_spans[i] for i in sorted_indices]
|
|
||||||
|
|
||||||
group = mx.distributed.init()
|
|
||||||
|
|
||||||
# split strided so we have approximately the same lengths on each node
|
|
||||||
shortened = shortened[group.rank() :: group.size()]
|
|
||||||
completion_spans = completion_spans[group.rank() :: group.size()]
|
|
||||||
|
|
||||||
# model scoring, returns num_requests x (logp, is_greedy, length).
|
# model scoring, returns num_requests x (logp, is_greedy, length).
|
||||||
scores, is_greedy = self._loglikelihood(
|
results = self._loglikelihood(
|
||||||
shortened,
|
shortened,
|
||||||
score_spans=completion_spans,
|
score_spans=completion_spans,
|
||||||
|
tokenize=False,
|
||||||
)
|
)
|
||||||
|
return [(r[0], r[1] == r[2]) for r in results]
|
||||||
# all gather the results across groups
|
|
||||||
if group.size() > 1:
|
|
||||||
per_group = int(np.ceil(num_results / group.size()))
|
|
||||||
scores = mx.pad(scores, ((0, per_group - len(scores)),))
|
|
||||||
is_greedy = mx.pad(is_greedy, ((0, per_group - len(is_greedy))))
|
|
||||||
scores = mx.distributed.all_gather(scores[mx.newaxis], stream=mx.cpu)
|
|
||||||
is_greedy = mx.distributed.all_gather(is_greedy[mx.newaxis], stream=mx.cpu)
|
|
||||||
scores = scores.T.reshape(-1)
|
|
||||||
is_greedy = is_greedy.T.reshape(-1)
|
|
||||||
|
|
||||||
scores = np.array(scores[:num_results])
|
|
||||||
is_greedy = np.array(is_greedy[:num_results])
|
|
||||||
|
|
||||||
results = [(score, ig) for score, ig in zip(scores, is_greedy)]
|
|
||||||
inv_sort = np.argsort(sorted_indices)
|
|
||||||
results = [results[inv_sort[i]] for i in range(len(inv_sort))]
|
|
||||||
return results
|
|
||||||
|
|
||||||
tokenizer_name = lm_eval.models.huggingface.HFLM.tokenizer_name
|
tokenizer_name = lm_eval.models.huggingface.HFLM.tokenizer_name
|
||||||
|
apply_chat_template = lm_eval.models.huggingface.HFLM.apply_chat_template
|
||||||
def apply_chat_template(
|
|
||||||
self, chat_history: list[dict[str, str]], add_generation_prompt: bool = True
|
|
||||||
) -> str:
|
|
||||||
if len(chat_history) == 0:
|
|
||||||
return ""
|
|
||||||
return lm_eval.models.huggingface.HFLM.apply_chat_template(
|
|
||||||
chat_history, add_generation_prompt
|
|
||||||
)
|
|
||||||
|
|
||||||
def loglikelihood_rolling(self, requests) -> list[float]:
|
def loglikelihood_rolling(self, requests) -> list[float]:
|
||||||
"""Compute full log-likelihood of a string, with no truncation, for perplexity computation
|
"""Compute full log-likelihood of a string, with no truncation, for perplexity computation
|
||||||
@@ -282,9 +268,8 @@ class MLXLM(LM):
|
|||||||
logging.info(
|
logging.info(
|
||||||
"Estimating loglikelihood rolling for %d sequences." % len(requests)
|
"Estimating loglikelihood rolling for %d sequences." % len(requests)
|
||||||
)
|
)
|
||||||
inputs = self._tokenize([req.args[0] for req in requests])
|
inputs = [req.args[0] for req in requests]
|
||||||
scores, _ = self._loglikelihood(inputs)
|
return [t[0] for t in self._loglikelihood(inputs)]
|
||||||
return scores.tolist()
|
|
||||||
|
|
||||||
def generate_until(self, requests) -> list[str]:
|
def generate_until(self, requests) -> list[str]:
|
||||||
"""Generate greedily until a stopping sequence
|
"""Generate greedily until a stopping sequence
|
||||||
@@ -347,7 +332,7 @@ def main():
|
|||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--limit",
|
"--limit",
|
||||||
default=None,
|
default=1.0,
|
||||||
help="Limit the number of examples per task.",
|
help="Limit the number of examples per task.",
|
||||||
type=float,
|
type=float,
|
||||||
)
|
)
|
||||||
@@ -361,8 +346,11 @@ def main():
|
|||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--apply-chat-template",
|
"--apply-chat-template",
|
||||||
action="store_true",
|
action=argparse.BooleanOptionalAction,
|
||||||
help="Specifies whether to apply a chat template to the prompt.",
|
help="Specifies whether to apply a chat template to the prompt. If "
|
||||||
|
"the model has a chat template, this defaults to `True`, "
|
||||||
|
"otherwise `False`.",
|
||||||
|
default=None,
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|||||||
@@ -4,10 +4,11 @@
|
|||||||
Run with:
|
Run with:
|
||||||
|
|
||||||
```
|
```
|
||||||
/path/to/mpirun \
|
mlx.launch \
|
||||||
-np 2 \
|
|
||||||
--hostfile /path/to/hosts.txt \
|
--hostfile /path/to/hosts.txt \
|
||||||
python /path/to/pipeline_generate.py --prompt "hello world"
|
--backend mpi \
|
||||||
|
/path/to/pipeline_generate.py \
|
||||||
|
--prompt "hello world"
|
||||||
```
|
```
|
||||||
|
|
||||||
Make sure you can run MLX over MPI on two hosts. For more information see the
|
Make sure you can run MLX over MPI on two hosts. For more information see the
|
||||||
@@ -17,62 +18,110 @@ https://ml-explore.github.io/mlx/build/html/usage/distributed.html).
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
from mlx.utils import tree_flatten
|
||||||
from mlx_lm import load, stream_generate
|
from mlx_lm import load, stream_generate
|
||||||
|
from mlx_lm.utils import load_model, load_tokenizer
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="LLM pipelined inference example")
|
|
||||||
parser.add_argument(
|
def download(repo: str, allow_patterns: list[str]) -> Path:
|
||||||
|
return Path(
|
||||||
|
snapshot_download(
|
||||||
|
repo,
|
||||||
|
allow_patterns=allow_patterns,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def shard_and_load(repo):
|
||||||
|
# Get model path with everything but weight safetensors
|
||||||
|
model_path = download(
|
||||||
|
args.model,
|
||||||
|
allow_patterns=["*.json", "*.py", "tokenizer.model", "*.tiktoken", "*.txt"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Lazy load and shard model to figure out
|
||||||
|
# which weights we need
|
||||||
|
model, _ = load_model(model_path, lazy=True, strict=False)
|
||||||
|
|
||||||
|
group = mx.distributed.init(backend="mpi")
|
||||||
|
rank = group.rank()
|
||||||
|
model.model.pipeline(group)
|
||||||
|
|
||||||
|
# Figure out which files we need for the local shard
|
||||||
|
with open(model_path / "model.safetensors.index.json", "r") as fid:
|
||||||
|
weight_index = json.load(fid)["weight_map"]
|
||||||
|
|
||||||
|
local_files = set()
|
||||||
|
for k, _ in tree_flatten(model.parameters()):
|
||||||
|
local_files.add(weight_index[k])
|
||||||
|
|
||||||
|
# Download weights for local shard
|
||||||
|
download(args.model, allow_patterns=local_files)
|
||||||
|
|
||||||
|
# Load and shard the model, and load the weights
|
||||||
|
tokenizer = load_tokenizer(model_path)
|
||||||
|
model, _ = load_model(model_path, lazy=True, strict=False)
|
||||||
|
model.model.pipeline(group)
|
||||||
|
mx.eval(model.parameters())
|
||||||
|
|
||||||
|
# Synchronize processes before generation to avoid timeout if downloading
|
||||||
|
# model for the first time.
|
||||||
|
mx.eval(mx.distributed.all_sum(mx.array(1.0), stream=mx.cpu))
|
||||||
|
return model, tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="LLM pipelined inference example")
|
||||||
|
parser.add_argument(
|
||||||
"--model",
|
"--model",
|
||||||
default="mlx-community/DeepSeek-R1-3bit",
|
default="mlx-community/DeepSeek-R1-3bit",
|
||||||
help="HF repo or path to local model.",
|
help="HF repo or path to local model.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--prompt",
|
"--prompt",
|
||||||
"-p",
|
"-p",
|
||||||
default="Write a quicksort in C++.",
|
default="Write a quicksort in C++.",
|
||||||
help="Message to be processed by the model ('-' reads from stdin)",
|
help="Message to be processed by the model ('-' reads from stdin)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max-tokens",
|
"--max-tokens",
|
||||||
"-m",
|
"-m",
|
||||||
type=int,
|
type=int,
|
||||||
default=256,
|
default=256,
|
||||||
help="Maximum number of tokens to generate",
|
help="Maximum number of tokens to generate",
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
model, tokenizer = load(args.model, lazy=True)
|
group = mx.distributed.init(backend="mpi")
|
||||||
|
rank = group.rank()
|
||||||
|
|
||||||
messages = [{"role": "user", "content": args.prompt}]
|
def rprint(*args, **kwargs):
|
||||||
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
|
|
||||||
|
|
||||||
group = mx.distributed.init()
|
|
||||||
rank = group.rank()
|
|
||||||
model.model.pipeline(group)
|
|
||||||
mx.eval(model.parameters())
|
|
||||||
|
|
||||||
# Synchronize processes before generation to avoid timeout if downloading
|
|
||||||
# model for the first time.
|
|
||||||
mx.eval(mx.distributed.all_sum(mx.array(1.0), stream=mx.cpu))
|
|
||||||
|
|
||||||
|
|
||||||
def rprint(*args, **kwargs):
|
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
print(*args, **kwargs)
|
print(*args, **kwargs)
|
||||||
|
|
||||||
|
model, tokenizer = shard_and_load(args.model)
|
||||||
|
|
||||||
for response in stream_generate(model, tokenizer, prompt, max_tokens=args.max_tokens):
|
messages = [{"role": "user", "content": args.prompt}]
|
||||||
|
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
|
||||||
|
|
||||||
|
for response in stream_generate(
|
||||||
|
model, tokenizer, prompt, max_tokens=args.max_tokens
|
||||||
|
):
|
||||||
rprint(response.text, end="", flush=True)
|
rprint(response.text, end="", flush=True)
|
||||||
|
|
||||||
rprint()
|
rprint()
|
||||||
rprint("=" * 10)
|
rprint("=" * 10)
|
||||||
rprint(
|
rprint(
|
||||||
f"Prompt: {response.prompt_tokens} tokens, "
|
f"Prompt: {response.prompt_tokens} tokens, "
|
||||||
f"{response.prompt_tps:.3f} tokens-per-sec"
|
f"{response.prompt_tps:.3f} tokens-per-sec"
|
||||||
)
|
)
|
||||||
rprint(
|
rprint(
|
||||||
f"Generation: {response.generation_tokens} tokens, "
|
f"Generation: {response.generation_tokens} tokens, "
|
||||||
f"{response.generation_tps:.3f} tokens-per-sec"
|
f"{response.generation_tps:.3f} tokens-per-sec"
|
||||||
)
|
)
|
||||||
rprint(f"Peak memory: {response.peak_memory:.3f} GB")
|
rprint(f"Peak memory: {response.peak_memory:.3f} GB")
|
||||||
|
|||||||
@@ -191,6 +191,7 @@ def main():
|
|||||||
model_path,
|
model_path,
|
||||||
adapter_path=args.adapter_path,
|
adapter_path=args.adapter_path,
|
||||||
tokenizer_config=tokenizer_config,
|
tokenizer_config=tokenizer_config,
|
||||||
|
sequential_load=mx.distributed.init().size() > 1,
|
||||||
)
|
)
|
||||||
for eos_token in args.extra_eos_token:
|
for eos_token in args.extra_eos_token:
|
||||||
tokenizer.add_eos_token(eos_token)
|
tokenizer.add_eos_token(eos_token)
|
||||||
@@ -234,13 +235,17 @@ def main():
|
|||||||
else:
|
else:
|
||||||
draft_model = None
|
draft_model = None
|
||||||
sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep)
|
sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep)
|
||||||
|
|
||||||
|
world = mx.distributed.init()
|
||||||
|
print(f"Node {world.rank()} of {world.size()}", flush=True)
|
||||||
|
world.barrier()
|
||||||
response = generate(
|
response = generate(
|
||||||
model,
|
model,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
prompt,
|
prompt,
|
||||||
max_tokens=args.max_tokens,
|
max_tokens=args.max_tokens,
|
||||||
verbose=args.verbose,
|
|
||||||
sampler=sampler,
|
sampler=sampler,
|
||||||
|
verbose=args.verbose and world.rank() == 0,
|
||||||
max_kv_size=args.max_kv_size,
|
max_kv_size=args.max_kv_size,
|
||||||
prompt_cache=prompt_cache if using_cache else None,
|
prompt_cache=prompt_cache if using_cache else None,
|
||||||
kv_bits=args.kv_bits,
|
kv_bits=args.kv_bits,
|
||||||
@@ -249,8 +254,10 @@ def main():
|
|||||||
draft_model=draft_model,
|
draft_model=draft_model,
|
||||||
num_draft_tokens=args.num_draft_tokens,
|
num_draft_tokens=args.num_draft_tokens,
|
||||||
)
|
)
|
||||||
if not args.verbose:
|
|
||||||
|
if not args.verbose and mx.distributed.init().rank() == 0:
|
||||||
print(response)
|
print(response)
|
||||||
|
mx.synchronize()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -364,8 +364,30 @@ class DeepseekV2Model(nn.Module):
|
|||||||
DeepseekV2DecoderLayer(config, idx)
|
DeepseekV2DecoderLayer(config, idx)
|
||||||
for idx in range(config.num_hidden_layers)
|
for idx in range(config.num_hidden_layers)
|
||||||
]
|
]
|
||||||
|
self.start_idx = 0
|
||||||
|
self.end_idx = len(self.layers)
|
||||||
|
self.num_layers = self.end_idx
|
||||||
|
|
||||||
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
|
||||||
|
self.pipeline_rank = 0
|
||||||
|
self.pipeline_size = 1
|
||||||
|
|
||||||
|
def pipeline(self, group):
|
||||||
|
# Split layers in reverse so rank=0 gets the last layers and
|
||||||
|
# rank=pipeline_size-1 gets the first
|
||||||
|
self.pipeline_rank = group.rank()
|
||||||
|
self.pipeline_size = group.size()
|
||||||
|
layers_per_rank = (
|
||||||
|
len(self.layers) + self.pipeline_size - 1
|
||||||
|
) // self.pipeline_size
|
||||||
|
self.start_idx = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
|
||||||
|
self.end_idx = self.start_idx + layers_per_rank
|
||||||
|
self.num_layers = layers_per_rank
|
||||||
|
self.layers = self.layers[: self.end_idx]
|
||||||
|
self.layers[: self.start_idx] = [None] * self.start_idx
|
||||||
|
self.num_layers = len(self.layers) - self.start_idx
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
x: mx.array,
|
x: mx.array,
|
||||||
@@ -374,14 +396,31 @@ class DeepseekV2Model(nn.Module):
|
|||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
h = self.embed_tokens(x)
|
h = self.embed_tokens(x)
|
||||||
|
|
||||||
|
pipeline_rank = self.pipeline_rank
|
||||||
|
pipeline_size = self.pipeline_size
|
||||||
|
# Hack to avoid time-outs during prompt-processing
|
||||||
|
dist_stream = mx.cpu if h.shape[1] > 1 else mx.gpu
|
||||||
if mask is None:
|
if mask is None:
|
||||||
mask = create_attention_mask(h, cache)
|
mask = create_attention_mask(h, cache)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
cache = [None] * len(self.layers)
|
cache = [None] * self.num_layers
|
||||||
|
|
||||||
for layer, c in zip(self.layers, cache):
|
# Receive from the previous process in the pipeline
|
||||||
h = layer(h, mask, c)
|
if pipeline_rank < pipeline_size - 1:
|
||||||
|
h = mx.distributed.recv_like(h, (pipeline_rank + 1), stream=dist_stream)
|
||||||
|
|
||||||
|
for i in range(self.num_layers):
|
||||||
|
h = self.layers[self.start_idx + i](h, mask, cache[i])
|
||||||
|
|
||||||
|
# Send to the next process in the pipeline
|
||||||
|
if pipeline_rank != 0:
|
||||||
|
h = mx.distributed.send(
|
||||||
|
h, (pipeline_rank - 1) % pipeline_size, stream=dist_stream
|
||||||
|
)
|
||||||
|
|
||||||
|
# Broadcast h while keeping it in the graph
|
||||||
|
h = mx.distributed.all_gather(h, stream=dist_stream)[: h.shape[0]]
|
||||||
|
|
||||||
return self.norm(h)
|
return self.norm(h)
|
||||||
|
|
||||||
@@ -418,4 +457,4 @@ class Model(nn.Module):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def layers(self):
|
def layers(self):
|
||||||
return self.model.layers
|
return self.model.layers[self.model.start_idx : self.model.end_idx]
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
import math
|
import math
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from functools import partial
|
||||||
from typing import Any, Dict, Optional, Tuple
|
from typing import Any, Dict, Optional, Tuple
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
@@ -125,6 +126,12 @@ class DeepseekV3YarnRotaryEmbedding(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# A clipped silu to prevent fp16 from overflowing
|
||||||
|
@partial(mx.compile, shapeless=True)
|
||||||
|
def clipped_silu(x):
|
||||||
|
return mx.clip(x * mx.sigmoid(x), a_min=-100, a_max=100)
|
||||||
|
|
||||||
|
|
||||||
class DeepseekV3Attention(nn.Module):
|
class DeepseekV3Attention(nn.Module):
|
||||||
def __init__(self, config: ModelArgs):
|
def __init__(self, config: ModelArgs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -312,7 +319,10 @@ class DeepseekV3MoE(nn.Module):
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.num_experts_per_tok = config.num_experts_per_tok
|
self.num_experts_per_tok = config.num_experts_per_tok
|
||||||
self.switch_mlp = SwitchGLU(
|
self.switch_mlp = SwitchGLU(
|
||||||
config.hidden_size, config.moe_intermediate_size, config.n_routed_experts
|
config.hidden_size,
|
||||||
|
config.moe_intermediate_size,
|
||||||
|
config.n_routed_experts,
|
||||||
|
activation=clipped_silu,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.gate = MoEGate(config)
|
self.gate = MoEGate(config)
|
||||||
@@ -359,11 +369,7 @@ class DeepseekV3DecoderLayer(nn.Module):
|
|||||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||||
h = x + r
|
h = x + r
|
||||||
r = self.mlp(self.post_attention_layernorm(h))
|
r = self.mlp(self.post_attention_layernorm(h))
|
||||||
out = h + r
|
return h + r
|
||||||
# Protect against overflow for fp16
|
|
||||||
if out.dtype == mx.float16:
|
|
||||||
out = mx.clip(out, a_min=None, a_max=mx.finfo(mx.float16).max - 1000)
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
class DeepseekV3Model(nn.Module):
|
class DeepseekV3Model(nn.Module):
|
||||||
@@ -375,6 +381,10 @@ class DeepseekV3Model(nn.Module):
|
|||||||
DeepseekV3DecoderLayer(config, idx)
|
DeepseekV3DecoderLayer(config, idx)
|
||||||
for idx in range(config.num_hidden_layers)
|
for idx in range(config.num_hidden_layers)
|
||||||
]
|
]
|
||||||
|
self.start_idx = 0
|
||||||
|
self.end_idx = len(self.layers)
|
||||||
|
self.num_layers = self.end_idx
|
||||||
|
|
||||||
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
self.pipeline_rank = 0
|
self.pipeline_rank = 0
|
||||||
self.pipeline_size = 1
|
self.pipeline_size = 1
|
||||||
@@ -387,8 +397,11 @@ class DeepseekV3Model(nn.Module):
|
|||||||
layers_per_rank = (
|
layers_per_rank = (
|
||||||
len(self.layers) + self.pipeline_size - 1
|
len(self.layers) + self.pipeline_size - 1
|
||||||
) // self.pipeline_size
|
) // self.pipeline_size
|
||||||
start = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
|
self.start_idx = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
|
||||||
self.layers = self.layers[start : start + layers_per_rank]
|
self.end_idx = self.start_idx + layers_per_rank
|
||||||
|
self.layers = self.layers[: self.end_idx]
|
||||||
|
self.layers[: self.start_idx] = [None] * self.start_idx
|
||||||
|
self.num_layers = len(self.layers) - self.start_idx
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
@@ -406,15 +419,15 @@ class DeepseekV3Model(nn.Module):
|
|||||||
mask = create_attention_mask(h, cache)
|
mask = create_attention_mask(h, cache)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
cache = [None] * len(self.layers)
|
cache = [None] * self.num_layers
|
||||||
|
|
||||||
# Receive from the previous process in the pipeline
|
# Receive from the previous process in the pipeline
|
||||||
|
|
||||||
if pipeline_rank < pipeline_size - 1:
|
if pipeline_rank < pipeline_size - 1:
|
||||||
h = mx.distributed.recv_like(h, (pipeline_rank + 1), stream=dist_stream)
|
h = mx.distributed.recv_like(h, (pipeline_rank + 1), stream=dist_stream)
|
||||||
|
|
||||||
for layer, c in zip(self.layers, cache):
|
for i in range(self.num_layers):
|
||||||
h = layer(h, mask, c)
|
h = self.layers[self.start_idx + i](h, mask, cache[i])
|
||||||
|
|
||||||
# Send to the next process in the pipeline
|
# Send to the next process in the pipeline
|
||||||
if pipeline_rank != 0:
|
if pipeline_rank != 0:
|
||||||
@@ -462,4 +475,4 @@ class Model(nn.Module):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def layers(self):
|
def layers(self):
|
||||||
return self.model.layers
|
return self.model.layers[self.model.start_idx : self.model.end_idx]
|
||||||
|
|||||||
185
llms/mlx_lm/models/helium.py
Normal file
185
llms/mlx_lm/models/helium.py
Normal file
@@ -0,0 +1,185 @@
|
|||||||
|
# Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Optional, Tuple
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModelArgs(BaseModelArgs):
|
||||||
|
hidden_size: int
|
||||||
|
num_hidden_layers: int
|
||||||
|
intermediate_size: int
|
||||||
|
num_attention_heads: int
|
||||||
|
num_key_value_heads: int
|
||||||
|
rms_norm_eps: float
|
||||||
|
vocab_size: int
|
||||||
|
attention_bias: bool
|
||||||
|
head_dim: int
|
||||||
|
max_position_embeddings: int
|
||||||
|
mlp_bias: bool
|
||||||
|
model_type: str
|
||||||
|
rope_theta: float
|
||||||
|
tie_word_embeddings: bool
|
||||||
|
|
||||||
|
|
||||||
|
class HeliumAttention(nn.Module):
|
||||||
|
def __init__(self, args: ModelArgs):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
dim = args.hidden_size
|
||||||
|
self.n_heads = n_heads = args.num_attention_heads
|
||||||
|
assert args.num_key_value_heads is not None
|
||||||
|
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
||||||
|
|
||||||
|
head_dim = args.hidden_size // n_heads
|
||||||
|
self.scale = head_dim**-0.5
|
||||||
|
|
||||||
|
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.attention_bias)
|
||||||
|
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias)
|
||||||
|
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias)
|
||||||
|
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
|
||||||
|
self.rope = nn.RoPE(head_dim, traditional=True, base=args.rope_theta)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
x: mx.array,
|
||||||
|
mask: Optional[mx.array] = None,
|
||||||
|
cache: Optional[Any] = None,
|
||||||
|
) -> mx.array:
|
||||||
|
B, L, D = x.shape
|
||||||
|
|
||||||
|
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||||
|
|
||||||
|
# Prepare the queries, keys and values for the attention computation
|
||||||
|
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||||
|
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||||
|
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||||
|
|
||||||
|
if cache is not None:
|
||||||
|
queries = self.rope(queries, offset=cache.offset)
|
||||||
|
keys = self.rope(keys, offset=cache.offset)
|
||||||
|
keys, values = cache.update_and_fetch(keys, values)
|
||||||
|
else:
|
||||||
|
queries = self.rope(queries)
|
||||||
|
keys = self.rope(keys)
|
||||||
|
|
||||||
|
output = scaled_dot_product_attention(
|
||||||
|
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||||
|
)
|
||||||
|
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
|
return self.o_proj(output)
|
||||||
|
|
||||||
|
|
||||||
|
class HeliumMLP(nn.Module):
|
||||||
|
def __init__(self, args: ModelArgs):
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_size = args.hidden_size
|
||||||
|
self.intermediate_size = args.intermediate_size
|
||||||
|
|
||||||
|
self.gate_proj = nn.Linear(
|
||||||
|
self.hidden_size, self.intermediate_size, bias=args.mlp_bias
|
||||||
|
)
|
||||||
|
self.up_proj = nn.Linear(
|
||||||
|
self.hidden_size, self.intermediate_size, bias=args.mlp_bias
|
||||||
|
)
|
||||||
|
self.down_proj = nn.Linear(
|
||||||
|
self.intermediate_size, self.hidden_size, bias=args.mlp_bias
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
|
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
|
||||||
|
|
||||||
|
|
||||||
|
class HeliumDecoderLayer(nn.Module):
|
||||||
|
def __init__(self, args: ModelArgs):
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_size = args.hidden_size
|
||||||
|
|
||||||
|
self.self_attn = HeliumAttention(args)
|
||||||
|
self.mlp = HeliumMLP(args)
|
||||||
|
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||||
|
self.post_attention_layernorm = nn.RMSNorm(
|
||||||
|
args.hidden_size, eps=args.rms_norm_eps
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
x: mx.array,
|
||||||
|
mask: Optional[mx.array] = None,
|
||||||
|
cache: Optional[Any] = None,
|
||||||
|
) -> mx.array:
|
||||||
|
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||||
|
h = x + r
|
||||||
|
r = self.mlp(self.post_attention_layernorm(h))
|
||||||
|
out = h + r
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class HeliumModel(nn.Module):
|
||||||
|
def __init__(self, args: ModelArgs):
|
||||||
|
super().__init__()
|
||||||
|
self.num_hidden_layers = args.num_hidden_layers
|
||||||
|
self.vocab_size = args.vocab_size
|
||||||
|
|
||||||
|
assert self.vocab_size > 0
|
||||||
|
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||||
|
|
||||||
|
self.layers = [HeliumDecoderLayer(args) for _ in range(args.num_hidden_layers)]
|
||||||
|
|
||||||
|
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
|
cache=None,
|
||||||
|
) -> mx.array:
|
||||||
|
h = self.embed_tokens(inputs)
|
||||||
|
|
||||||
|
if mask is None:
|
||||||
|
mask = create_attention_mask(h, cache)
|
||||||
|
|
||||||
|
if cache is None:
|
||||||
|
cache = [None] * len(self.layers)
|
||||||
|
|
||||||
|
for layer, c in zip(self.layers, cache):
|
||||||
|
h = layer(h, mask, c)
|
||||||
|
|
||||||
|
return self.norm(h)
|
||||||
|
|
||||||
|
|
||||||
|
class Model(nn.Module):
|
||||||
|
def __init__(self, args: ModelArgs):
|
||||||
|
super().__init__()
|
||||||
|
self.args = args
|
||||||
|
self.model_type = args.model_type
|
||||||
|
|
||||||
|
self.model = HeliumModel(args)
|
||||||
|
|
||||||
|
self.vocab_size = args.vocab_size
|
||||||
|
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
||||||
|
|
||||||
|
if not args.tie_word_embeddings:
|
||||||
|
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
|
cache=None,
|
||||||
|
) -> mx.array:
|
||||||
|
out = self.model(inputs, mask, cache)
|
||||||
|
if self.args.tie_word_embeddings:
|
||||||
|
out = self.model.embed_tokens.as_linear(out)
|
||||||
|
else:
|
||||||
|
out = self.lm_head(out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
@property
|
||||||
|
def layers(self):
|
||||||
|
return self.model.layers
|
||||||
@@ -200,6 +200,36 @@ class Model(nn.Module):
|
|||||||
k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
|
k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def shard(self, group: Optional[mx.distributed.Group] = None):
|
||||||
|
group = group or mx.distributed.init()
|
||||||
|
|
||||||
|
def all_to_sharded(l):
|
||||||
|
if isinstance(l, nn.QuantizedLinear):
|
||||||
|
return nn.QuantizedAllToShardedLinear.from_quantized_linear(l, group)
|
||||||
|
else:
|
||||||
|
return nn.AllToShardedLinear.from_linear(l, group)
|
||||||
|
|
||||||
|
def sharded_to_all(l):
|
||||||
|
if isinstance(l, nn.QuantizedLinear):
|
||||||
|
return nn.QuantizedShardedToAllLinear.from_quantized_linear(l, group)
|
||||||
|
else:
|
||||||
|
return nn.ShardedToAllLinear.from_linear(l, group)
|
||||||
|
|
||||||
|
N = group.size()
|
||||||
|
for layer in self.model.layers:
|
||||||
|
# Shard the self attention
|
||||||
|
layer.self_attn.q_proj = all_to_sharded(layer.self_attn.q_proj)
|
||||||
|
layer.self_attn.k_proj = all_to_sharded(layer.self_attn.k_proj)
|
||||||
|
layer.self_attn.v_proj = all_to_sharded(layer.self_attn.v_proj)
|
||||||
|
layer.self_attn.o_proj = sharded_to_all(layer.self_attn.o_proj)
|
||||||
|
layer.self_attn.n_heads //= N
|
||||||
|
layer.self_attn.n_kv_heads //= N
|
||||||
|
|
||||||
|
# Shard the MLP
|
||||||
|
layer.mlp.gate_proj = all_to_sharded(layer.mlp.gate_proj)
|
||||||
|
layer.mlp.down_proj = sharded_to_all(layer.mlp.down_proj)
|
||||||
|
layer.mlp.up_proj = all_to_sharded(layer.mlp.up_proj)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def layers(self):
|
def layers(self):
|
||||||
return self.model.layers
|
return self.model.layers
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
# Copyright © 2024 Apple Inc.
|
# Copyright © 2024-2025 Apple Inc.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@@ -123,17 +123,16 @@ class MambaBlock(nn.Module):
|
|||||||
self.intermediate_size, self.hidden_size, bias=args.use_bias
|
self.intermediate_size, self.hidden_size, bias=args.use_bias
|
||||||
)
|
)
|
||||||
|
|
||||||
def ssm_step(self, x, state=None):
|
def ssm_step(self, x, A, state=None):
|
||||||
A = -mx.exp(self.A_log)
|
|
||||||
D = self.D
|
D = self.D
|
||||||
deltaBC = self.x_proj(x)
|
deltaBC = self.x_proj(x)
|
||||||
delta, B, C = mx.split(
|
delta, B, C = map(
|
||||||
|
self.mixer_norm if self.use_bcdt_rms else lambda x: x,
|
||||||
|
mx.split(
|
||||||
deltaBC,
|
deltaBC,
|
||||||
indices_or_sections=[
|
[self.time_step_rank, self.time_step_rank + self.ssm_state_size],
|
||||||
self.time_step_rank,
|
|
||||||
self.time_step_rank + self.ssm_state_size,
|
|
||||||
],
|
|
||||||
axis=-1,
|
axis=-1,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
if self.use_bcdt_rms:
|
if self.use_bcdt_rms:
|
||||||
delta, B, C = map(self.mixer_norm, (delta, B, C))
|
delta, B, C = map(self.mixer_norm, (delta, B, C))
|
||||||
@@ -145,25 +144,40 @@ class MambaBlock(nn.Module):
|
|||||||
y = y + D * x
|
y = y + D * x
|
||||||
return y, new_state
|
return y, new_state
|
||||||
|
|
||||||
def __call__(self, x, cache):
|
def _process_sequence(self, x, conv_cache, state_cache):
|
||||||
B, T, D = x.shape
|
B, T, D = x.shape
|
||||||
if cache is None:
|
xz = self.in_proj(x)
|
||||||
cache = [None, None]
|
x, z = xz.split(indices_or_sections=2, axis=-1)
|
||||||
|
|
||||||
|
conv_out, new_conv_cache = self.conv1d(x, conv_cache)
|
||||||
|
x = nn.silu(conv_out)
|
||||||
|
|
||||||
|
A = -mx.exp(self.A_log)
|
||||||
|
|
||||||
outputs = []
|
outputs = []
|
||||||
|
current_state = state_cache
|
||||||
|
y = []
|
||||||
for t in range(T):
|
for t in range(T):
|
||||||
xt = x[:, t, :]
|
y_t, current_state = self.ssm_step(x[:, t], A, current_state)
|
||||||
xz = self.in_proj(xt)
|
y.append(y_t)
|
||||||
x_t, z_t = xz.split(indices_or_sections=2, axis=1)
|
y = mx.stack(y, axis=1)
|
||||||
conv_out, cache[0] = self.conv1d(mx.expand_dims(x_t, 1), cache[0])
|
z = self.out_proj(nn.silu(z) * y)
|
||||||
x_t = conv_out.squeeze(1)
|
return z, (new_conv_cache, current_state)
|
||||||
x_t = nn.silu(x_t)
|
|
||||||
y_t, cache[1] = self.ssm_step(x_t, cache[1])
|
def __call__(self, x, cache):
|
||||||
z_t = nn.silu(z_t)
|
if cache is None:
|
||||||
output_t = y_t * z_t
|
conv_cache, state_cache = None, None
|
||||||
output_t = self.out_proj(output_t)
|
else:
|
||||||
outputs.append(output_t)
|
conv_cache, state_cache = cache[0], cache[1]
|
||||||
output = mx.stack(outputs, axis=1)
|
|
||||||
|
output, (new_conv_cache, new_state_cache) = self._process_sequence(
|
||||||
|
x, conv_cache, state_cache
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(cache, MambaCache):
|
||||||
|
cache[0] = new_conv_cache
|
||||||
|
cache[1] = new_state_cache
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
# Copyright © 2023-2024 Apple Inc.
|
# Copyright © 2023-2025 Apple Inc.
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict, Optional, Tuple, Union
|
from typing import Any, Dict, Optional, Tuple, Union
|
||||||
|
|||||||
@@ -147,11 +147,11 @@ def min_p_sampling(
|
|||||||
logprobs = logprobs * (1 / temperature)
|
logprobs = logprobs * (1 / temperature)
|
||||||
|
|
||||||
# Indices sorted in decreasing order
|
# Indices sorted in decreasing order
|
||||||
sorted_indices = mx.argsort(-logprobs).squeeze(0)
|
sorted_indices = mx.argsort(-logprobs, axis=-1)
|
||||||
sorted_logprobs = logprobs[..., sorted_indices]
|
sorted_logprobs = mx.take_along_axis(logprobs, sorted_indices, axis=-1)
|
||||||
|
|
||||||
# Top probability
|
# Top probability
|
||||||
top_logprobs = logprobs[..., sorted_indices[0]]
|
top_logprobs = sorted_logprobs[:, 0:1]
|
||||||
|
|
||||||
# Calculate the min_p threshold
|
# Calculate the min_p threshold
|
||||||
scaled_min_p = top_logprobs + math.log(min_p)
|
scaled_min_p = top_logprobs + math.log(min_p)
|
||||||
@@ -163,9 +163,9 @@ def min_p_sampling(
|
|||||||
# Create pool of tokens with probability less than scaled min_p
|
# Create pool of tokens with probability less than scaled min_p
|
||||||
selected_logprobs = mx.where(tokens_to_remove, -float("inf"), sorted_logprobs)
|
selected_logprobs = mx.where(tokens_to_remove, -float("inf"), sorted_logprobs)
|
||||||
|
|
||||||
# Return sampled token
|
# Return sampled tokens
|
||||||
sorted_token = mx.random.categorical(selected_logprobs)
|
sorted_tokens = mx.random.categorical(selected_logprobs, axis=-1)[:, None]
|
||||||
return sorted_indices[sorted_token]
|
return mx.take_along_axis(sorted_indices, sorted_tokens, axis=-1).squeeze(1)
|
||||||
|
|
||||||
|
|
||||||
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
|
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
|
||||||
@@ -185,7 +185,7 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.arr
|
|||||||
|
|
||||||
# sort probs in ascending order
|
# sort probs in ascending order
|
||||||
sorted_indices = mx.argsort(probs, axis=-1)
|
sorted_indices = mx.argsort(probs, axis=-1)
|
||||||
sorted_probs = probs[..., sorted_indices.squeeze(0)]
|
sorted_probs = mx.take_along_axis(probs, sorted_indices, axis=-1)
|
||||||
|
|
||||||
cumulative_probs = mx.cumsum(sorted_probs, axis=-1)
|
cumulative_probs = mx.cumsum(sorted_probs, axis=-1)
|
||||||
|
|
||||||
@@ -196,10 +196,8 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.arr
|
|||||||
0,
|
0,
|
||||||
)
|
)
|
||||||
|
|
||||||
sorted_token = mx.random.categorical(mx.log(top_probs))
|
sorted_tokens = mx.random.categorical(mx.log(top_probs), axis=-1)[:, None]
|
||||||
token = sorted_indices.squeeze(0)[sorted_token]
|
return mx.take_along_axis(sorted_indices, sorted_tokens, axis=-1).squeeze(1)
|
||||||
|
|
||||||
return token
|
|
||||||
|
|
||||||
|
|
||||||
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
|
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
|
||||||
|
|||||||
@@ -114,6 +114,33 @@ def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None):
|
|||||||
return prompt.rstrip()
|
return prompt.rstrip()
|
||||||
|
|
||||||
|
|
||||||
|
def process_message_content(messages):
|
||||||
|
"""
|
||||||
|
Convert message content to a format suitable for `apply_chat_template`.
|
||||||
|
|
||||||
|
The function operates on messages in place. It converts the 'content' field
|
||||||
|
to a string instead of a list of text fragments.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message_list (list): A list of dictionaries, where each dictionary may
|
||||||
|
have a 'content' key containing a list of dictionaries with 'type' and
|
||||||
|
'text' keys.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the 'content' type is not supported or if 'text' is missing.
|
||||||
|
|
||||||
|
"""
|
||||||
|
for message in messages:
|
||||||
|
content = message["content"]
|
||||||
|
if isinstance(content, list):
|
||||||
|
text_fragments = [
|
||||||
|
fragment["text"] for fragment in content if fragment["type"] == "text"
|
||||||
|
]
|
||||||
|
if len(text_fragments) != len(content):
|
||||||
|
raise ValueError("Only 'text' content type is supported.")
|
||||||
|
message["content"] = "".join(text_fragments)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PromptCache:
|
class PromptCache:
|
||||||
cache: List[Any] = field(default_factory=list)
|
cache: List[Any] = field(default_factory=list)
|
||||||
@@ -591,8 +618,10 @@ class APIHandler(BaseHTTPRequestHandler):
|
|||||||
self.request_id = f"chatcmpl-{uuid.uuid4()}"
|
self.request_id = f"chatcmpl-{uuid.uuid4()}"
|
||||||
self.object_type = "chat.completion.chunk" if self.stream else "chat.completion"
|
self.object_type = "chat.completion.chunk" if self.stream else "chat.completion"
|
||||||
if self.tokenizer.chat_template:
|
if self.tokenizer.chat_template:
|
||||||
|
messages = body["messages"]
|
||||||
|
process_message_content(messages)
|
||||||
prompt = self.tokenizer.apply_chat_template(
|
prompt = self.tokenizer.apply_chat_template(
|
||||||
body["messages"],
|
messages,
|
||||||
body.get("tools", None),
|
body.get("tools", None),
|
||||||
add_generation_prompt=True,
|
add_generation_prompt=True,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -140,8 +140,8 @@ def evaluate(
|
|||||||
loss: callable = default_loss,
|
loss: callable = default_loss,
|
||||||
iterate_batches: callable = iterate_batches,
|
iterate_batches: callable = iterate_batches,
|
||||||
):
|
):
|
||||||
all_losses = 0
|
all_losses = mx.array(0.0)
|
||||||
ntokens = 0
|
ntokens = mx.array(0)
|
||||||
|
|
||||||
index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
|
index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
|
||||||
|
|
||||||
|
|||||||
@@ -94,6 +94,7 @@ def linear_to_lora_layers(
|
|||||||
"phimoe",
|
"phimoe",
|
||||||
"gemma",
|
"gemma",
|
||||||
"gemma2",
|
"gemma2",
|
||||||
|
"helium",
|
||||||
"starcoder2",
|
"starcoder2",
|
||||||
"cohere",
|
"cohere",
|
||||||
"cohere2",
|
"cohere2",
|
||||||
|
|||||||
@@ -306,12 +306,12 @@ def generate_step(
|
|||||||
|
|
||||||
y, logprobs = _step(y)
|
y, logprobs = _step(y)
|
||||||
|
|
||||||
mx.async_eval(y, logprobs)
|
mx.eval(y, logprobs)
|
||||||
n = 0
|
n = 0
|
||||||
while True:
|
while True:
|
||||||
if n != max_tokens:
|
if n != max_tokens:
|
||||||
next_y, next_logprobs = _step(y)
|
next_y, next_logprobs = _step(y)
|
||||||
mx.async_eval(next_y, next_logprobs)
|
mx.eval(next_y, next_logprobs)
|
||||||
if n == 0:
|
if n == 0:
|
||||||
mx.eval(y)
|
mx.eval(y)
|
||||||
prompt_progress_callback(total_prompt_tokens, total_prompt_tokens)
|
prompt_progress_callback(total_prompt_tokens, total_prompt_tokens)
|
||||||
@@ -398,8 +398,9 @@ def speculative_generate_step(
|
|||||||
quantize_cache_fn(cache)
|
quantize_cache_fn(cache)
|
||||||
|
|
||||||
logprobs = logits - mx.logsumexp(logits, keepdims=True)
|
logprobs = logits - mx.logsumexp(logits, keepdims=True)
|
||||||
y = sampler(logprobs).squeeze(0)
|
logprobs = logprobs.squeeze(0)
|
||||||
return y, logprobs.squeeze(0)
|
y = sampler(logprobs)
|
||||||
|
return y, logprobs
|
||||||
|
|
||||||
def _prefill(model, cache, y):
|
def _prefill(model, cache, y):
|
||||||
while y.size > prefill_step_size:
|
while y.size > prefill_step_size:
|
||||||
@@ -626,6 +627,8 @@ def load_config(model_path: Path) -> dict:
|
|||||||
def load_model(
|
def load_model(
|
||||||
model_path: Path,
|
model_path: Path,
|
||||||
lazy: bool = False,
|
lazy: bool = False,
|
||||||
|
strict: bool = True,
|
||||||
|
sequential_load: bool = False,
|
||||||
model_config: dict = {},
|
model_config: dict = {},
|
||||||
get_model_classes: Callable[[dict], Tuple[Type[nn.Module], Type]] = _get_classes,
|
get_model_classes: Callable[[dict], Tuple[Type[nn.Module], Type]] = _get_classes,
|
||||||
) -> nn.Module:
|
) -> nn.Module:
|
||||||
@@ -637,6 +640,8 @@ def load_model(
|
|||||||
lazy (bool): If False eval the model parameters to make sure they are
|
lazy (bool): If False eval the model parameters to make sure they are
|
||||||
loaded in memory before returning, otherwise they will be loaded
|
loaded in memory before returning, otherwise they will be loaded
|
||||||
when needed. Default: ``False``
|
when needed. Default: ``False``
|
||||||
|
strict (bool): Whether or not to raise an exception if weights don't
|
||||||
|
match. Default: ``True``
|
||||||
model_config (dict, optional): Optional configuration parameters for the
|
model_config (dict, optional): Optional configuration parameters for the
|
||||||
model. Defaults to an empty dictionary.
|
model. Defaults to an empty dictionary.
|
||||||
get_model_classes (Callable[[dict], Tuple[Type[nn.Module], Type]], optional):
|
get_model_classes (Callable[[dict], Tuple[Type[nn.Module], Type]], optional):
|
||||||
@@ -659,7 +664,7 @@ def load_model(
|
|||||||
# Try weight for back-compat
|
# Try weight for back-compat
|
||||||
weight_files = glob.glob(str(model_path / "weight*.safetensors"))
|
weight_files = glob.glob(str(model_path / "weight*.safetensors"))
|
||||||
|
|
||||||
if not weight_files:
|
if not weight_files and strict:
|
||||||
logging.error(f"No safetensors found in {model_path}")
|
logging.error(f"No safetensors found in {model_path}")
|
||||||
raise FileNotFoundError(f"No safetensors found in {model_path}")
|
raise FileNotFoundError(f"No safetensors found in {model_path}")
|
||||||
|
|
||||||
@@ -693,9 +698,18 @@ def load_model(
|
|||||||
class_predicate=class_predicate,
|
class_predicate=class_predicate,
|
||||||
)
|
)
|
||||||
|
|
||||||
model.load_weights(list(weights.items()))
|
model.load_weights(list(weights.items()), strict=strict)
|
||||||
|
|
||||||
|
if mx.distributed.init().size() > 1:
|
||||||
|
if not hasattr(model, "shard"):
|
||||||
|
raise RuntimeError("Model doesn't support distributed inference.")
|
||||||
|
model.shard()
|
||||||
|
|
||||||
if not lazy:
|
if not lazy:
|
||||||
|
weights.clear()
|
||||||
|
if sequential_load:
|
||||||
|
for layer in model.layers:
|
||||||
|
mx.eval(layer.parameters())
|
||||||
mx.eval(model.parameters())
|
mx.eval(model.parameters())
|
||||||
|
|
||||||
model.eval()
|
model.eval()
|
||||||
@@ -708,6 +722,7 @@ def load(
|
|||||||
model_config={},
|
model_config={},
|
||||||
adapter_path: Optional[str] = None,
|
adapter_path: Optional[str] = None,
|
||||||
lazy: bool = False,
|
lazy: bool = False,
|
||||||
|
sequential_load: bool = False,
|
||||||
) -> Tuple[nn.Module, TokenizerWrapper]:
|
) -> Tuple[nn.Module, TokenizerWrapper]:
|
||||||
"""
|
"""
|
||||||
Load the model and tokenizer from a given path or a huggingface repository.
|
Load the model and tokenizer from a given path or a huggingface repository.
|
||||||
@@ -723,6 +738,8 @@ def load(
|
|||||||
lazy (bool): If ``False`` eval the model parameters to make sure they are
|
lazy (bool): If ``False`` eval the model parameters to make sure they are
|
||||||
loaded in memory before returning, otherwise they will be loaded
|
loaded in memory before returning, otherwise they will be loaded
|
||||||
when needed. Default: ``False``
|
when needed. Default: ``False``
|
||||||
|
sequential_load (bool): If True then load each layer sequentially to
|
||||||
|
ensure that we are not wasting memory.
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[nn.Module, TokenizerWrapper]: A tuple containing the loaded model and tokenizer.
|
Tuple[nn.Module, TokenizerWrapper]: A tuple containing the loaded model and tokenizer.
|
||||||
|
|
||||||
@@ -732,7 +749,7 @@ def load(
|
|||||||
"""
|
"""
|
||||||
model_path = get_model_path(path_or_hf_repo)
|
model_path = get_model_path(path_or_hf_repo)
|
||||||
|
|
||||||
model, config = load_model(model_path, lazy)
|
model, config = load_model(model_path, lazy=lazy, sequential_load=sequential_load)
|
||||||
if adapter_path is not None:
|
if adapter_path is not None:
|
||||||
model = load_adapters(model, adapter_path)
|
model = load_adapters(model, adapter_path)
|
||||||
model.eval()
|
model.eval()
|
||||||
@@ -746,7 +763,7 @@ def load(
|
|||||||
def fetch_from_hub(
|
def fetch_from_hub(
|
||||||
model_path: Path, lazy: bool = False
|
model_path: Path, lazy: bool = False
|
||||||
) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:
|
) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:
|
||||||
model, config = load_model(model_path, lazy)
|
model, config = load_model(model_path, lazy=lazy)
|
||||||
tokenizer = load_tokenizer(
|
tokenizer = load_tokenizer(
|
||||||
model_path, eos_token_ids=config.get("eos_token_id", None)
|
model_path, eos_token_ids=config.get("eos_token_id", None)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -28,6 +28,12 @@ class TestSampleUtils(unittest.TestCase):
|
|||||||
token = top_p_sampling(logits, 0.95, temperature).item()
|
token = top_p_sampling(logits, 0.95, temperature).item()
|
||||||
self.assertTrue(token in (1, 2, 3))
|
self.assertTrue(token in (1, 2, 3))
|
||||||
|
|
||||||
|
# Batch mode works
|
||||||
|
probs = mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.8, 0.0, 0.1]])
|
||||||
|
logits = mx.log(probs)
|
||||||
|
tokens = top_p_sampling(logits, 0.5, temperature)
|
||||||
|
self.assertEqual(tokens.tolist(), [0, 1])
|
||||||
|
|
||||||
def test_min_p_sampling(self):
|
def test_min_p_sampling(self):
|
||||||
probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
|
probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
|
||||||
logits = mx.log(probs)
|
logits = mx.log(probs)
|
||||||
@@ -42,6 +48,12 @@ class TestSampleUtils(unittest.TestCase):
|
|||||||
token = min_p_sampling(logits, 0.05)
|
token = min_p_sampling(logits, 0.05)
|
||||||
self.assertTrue(token in (0, 3))
|
self.assertTrue(token in (0, 3))
|
||||||
|
|
||||||
|
# Batch mode works
|
||||||
|
probs = mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.8, 0.0, 0.1]])
|
||||||
|
logits = mx.log(probs)
|
||||||
|
tokens = min_p_sampling(logits, 0.7)
|
||||||
|
self.assertEqual(tokens.tolist(), [0, 1])
|
||||||
|
|
||||||
def test_top_k_sampling(self):
|
def test_top_k_sampling(self):
|
||||||
probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
|
probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
|
||||||
logits = mx.log(probs)
|
logits = mx.log(probs)
|
||||||
|
|||||||
@@ -80,6 +80,29 @@ class TestServer(unittest.TestCase):
|
|||||||
self.assertIn("id", response_body)
|
self.assertIn("id", response_body)
|
||||||
self.assertIn("choices", response_body)
|
self.assertIn("choices", response_body)
|
||||||
|
|
||||||
|
def test_handle_chat_completions_with_content_fragments(self):
|
||||||
|
url = f"http://localhost:{self.port}/v1/chat/completions"
|
||||||
|
chat_post_data = {
|
||||||
|
"model": "chat_model",
|
||||||
|
"max_tokens": 10,
|
||||||
|
"temperature": 0.7,
|
||||||
|
"top_p": 0.85,
|
||||||
|
"repetition_penalty": 1.2,
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "You are a helpful assistant."}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{"role": "user", "content": [{"type": "text", "text": "Hello!"}]},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
response = requests.post(url, json=chat_post_data)
|
||||||
|
response_body = response.text
|
||||||
|
self.assertIn("id", response_body)
|
||||||
|
self.assertIn("choices", response_body)
|
||||||
|
|
||||||
def test_handle_models(self):
|
def test_handle_models(self):
|
||||||
url = f"http://localhost:{self.port}/v1/models"
|
url = f"http://localhost:{self.port}/v1/models"
|
||||||
response = requests.get(url)
|
response = requests.get(url)
|
||||||
|
|||||||
Reference in New Issue
Block a user