32 Commits

Author SHA1 Message Date
Awni Hannun
65b792d7c0 fix lazy load 2025-02-06 07:28:59 -08:00
Angelos Katharopoulos
617f9289b9 Make the chat distributed 2025-02-06 07:28:59 -08:00
Angelos Katharopoulos
026362e0f8 Remove async eval and add sequential load 2025-02-06 07:28:58 -08:00
Angelos Katharopoulos
a0ce0594f6 Temporarily remove async_eval 2025-02-06 07:28:03 -08:00
Angelos Katharopoulos
d77840207c Start distributed inference for llama models 2025-02-06 07:28:03 -08:00
Pedro Cuenca
e2e5478da5 READMEs: fix typo in link, minor update. (#1246) 2025-02-04 11:52:32 -08:00
Awni Hannun
21d0ab6e8a fix deepseek sharding (#1242) 2025-02-03 16:59:50 -08:00
Gökdeniz Gülmez
0989c073b0 Optimizations for mamba1 (#1213)
* added mx.einsum() operations: before: 41.293 tokens-per-sec, after: 57.822 tokens-per-sec

* Fused operations in `delta, B, C = ...`. Before: 57.822 tokens-per-sec, after: 83.890 tokens-per-sec

* Pre-computing A_log. Before: 83.890 tokens-per-sec, after: 85.848 tokens-per-sec

* Updated MambaBlock: batched input processing, improved cache handling, pre-computed constants, cleaner state management, explicit return values. Before: 82.442 tokens-per-sec, after: 129.130 tokens-per-sec.

* cleaning up and adding the Apple copyright to the helium model file

* update Copyright to this year

* nits + even faster

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2025-02-03 13:36:08 -08:00
Awni Hannun
d9924d08d1 Fix no validation in lora (#1241) 2025-02-03 09:55:24 -08:00
Awni Hannun
9c2ef38d4d only download local shard (#1240) 2025-02-02 13:58:44 -08:00
Awni Hannun
e8afb59de4 better overflow correction (#1229) 2025-01-28 14:37:30 -08:00
Anchen
7a83077cd7 chore(mlx-lm): support text type content in messages (#1225)
* chore(mlx-lm): support text type content

* chore: optimize the message content processing

* nits + format

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-01-27 17:13:50 -08:00
Awni Hannun
f44a52e2dc batched min p and fix spec gen sampling (#1222) 2025-01-27 15:40:31 -08:00
Gökdeniz Gülmez
77faa14ba4 adding support for kyutai's helium (#1208)
* initial commit

* adding helium into training

* Update ACKNOWLEDGMENTS.md

* nits

* nits

* fixes / nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-01-26 07:19:07 -08:00
Awni Hannun
9a3ddc3e65 some fixes for pipeline parallel deep seek r1 (#1216) 2025-01-21 19:40:29 -08:00
Victor Nogueira
df1406735b Fix dataset variable name, in datasets.py (#1212) 2025-01-21 14:12:43 -08:00
Jarrett
07f88f8057 fix(lora): add back store_true default args (#1205) 2025-01-16 11:15:42 -08:00
Awni Hannun
50f0a7f6d9 add internlm3 (#1206) 2025-01-15 14:55:41 -08:00
Ivan Fioravanti
6ae6c72c2e reduction moved to CPU in case of distributed training (#1200) 2025-01-14 17:20:42 -08:00
Awni Hannun
c117af83b8 fix gpt bigcode (#1204) 2025-01-13 10:22:32 -08:00
Chime Ogbuji
0228c46434 Custom local dataset features (#1085)
* Generalize prompt_feature and completion_feature for use in local datasets to facilitate compatibility with many other training dataset formats.

* Persist configured prompt/completion key

* rebase + nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-01-13 10:01:18 -08:00
Prince Canuma
bf2da36fc6 Fix Cohere2: mask shape error (long context) (#1202)
* fix mask shape error (long context)

* Update llms/mlx_lm/models/cohere2.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* revert layer_idx

* black formatting

* Update cohere2.py

* format

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2025-01-12 12:58:08 -08:00
Xingjun.Wang
514502da22 Support snapshot_download for ModelScope (#1194)
* add MLX_USE_MODELSCOPE env

* update

* update snapshot_download

* update

* remove modelscope dependency and add import check

* update

* nits

* fix

---------

Co-authored-by: wangxingjun778 <jason@U-C7X6TX5G-2239.local>
Co-authored-by: Awni Hannun <awni@apple.com>
2025-01-10 15:29:34 -08:00
Awni Hannun
93c5cfd781 Add a speculative decoding generator (#1155)
* add a speculative decoding generator

* fix

* fixes

* optional kwarg pop
2025-01-10 15:27:08 -08:00
Awni Hannun
5cae0a60e6 deepseek v3 model with pipeline parallelism (#1191)
* deepseekv3

* use upload_large_file instead of deprecated multi commit

* add pipeline generation and example

* comment

* get fp16 working

* use mlx==0.22
2025-01-09 15:55:53 -08:00
Jarrett
40b88eff48 fix(lora): config yaml & arg default merge bug (#1196) 2025-01-09 11:33:54 -08:00
Pedro Cuenca
b8f0cacfa8 Use upload_large_folder (#1193) 2025-01-07 09:18:31 -08:00
Awni Hannun
9183fe8b6d fix (#1192) 2025-01-06 10:12:07 -08:00
Chime Ogbuji
f2619f507c Add support for fewshot and apply chat template lm_eval functionality (#1180)
* Add support for multiturn fewshot examples and chat templates

Added two new arguments to the evaluation script: `--fewshot-as-multiturn` and `--apply-chat-template`, which correspond to the lm_eval options of the same names and are often used to ensure apples-to-apples comparisons of lm_eval results.

* Add HF overrides for methods needed by added options

* don't add duplicate bos

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-01-06 07:58:43 -08:00
Angelos Katharopoulos
25ec2d8c44 Change the eos-token argument for mlx_lm.generate (#1176) 2025-01-05 22:26:05 -08:00
Awni Hannun
c4833a2f55 fix encoding with special tokens + chat template (#1189) 2025-01-03 10:50:59 -08:00
Ivan Fioravanti
3a58c36109 Improvements to mlx_lm.manage (#1178)
* Improvements to manage: the confirmation prompt now defaults to N, and the size is shown in the deletion confirmation.

* Fixing case for no case

* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-01-01 07:25:57 -08:00
39 changed files with 1814 additions and 264 deletions

View File

@@ -32,7 +32,7 @@ jobs:
pip install --upgrade pip
pip install unittest-xml-reporting
cd llms/
pip install -e ".[testing]"
pip install -e ".[test]"
- run:
name: Run Python tests
command: |

View File

@@ -14,4 +14,4 @@ MLX Examples was developed with contributions from the following individuals:
- Markus Enzweiler: Added the `cvae` examples.
- Prince Canuma: Helped add support for `Starcoder2` models.
- Shiyu Li: Added the `Segment Anything Model`.
- Gökdeniz Gülmez: Added support for `MiniCPM`, `Mamba` and support for `full-fine-tuning`.
- Gökdeniz Gülmez: Added support for `MiniCPM`, `Helium`, `Mamba version 1` and support for `full-fine-tuning`.

View File

@@ -45,7 +45,7 @@ Some more useful examples are listed below.
### Hugging Face
Note: You can now directly download a few converted checkpoints from the [MLX
You can directly use or download converted checkpoints from the [MLX
Community](https://huggingface.co/mlx-community) organization on Hugging Face.
We encourage you to join the community and [contribute new
models](https://github.com/ml-explore/mlx-examples/issues/155).

View File

@@ -58,7 +58,7 @@ prompt = "Write a story about Einstein"
messages = [{"role": "user", "content": prompt}]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
messages, add_generation_prompt=True
)
text = generate(model, tokenizer, prompt=prompt, verbose=True)
@@ -115,7 +115,7 @@ prompt = "Write a story about Einstein"
messages = [{"role": "user", "content": prompt}]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
messages, add_generation_prompt=True
)
for response in stream_generate(model, tokenizer, prompt, max_tokens=512):
@@ -164,7 +164,7 @@ mlx_lm.convert \
```
Models can also be converted and quantized directly in the
[mlx-my-repo]https://huggingface.co/spaces/mlx-community/mlx-my-repo) Hugging
[mlx-my-repo](https://huggingface.co/spaces/mlx-community/mlx-my-repo) Hugging
Face Space.
### Long Prompts and Generations

View File

@@ -241,14 +241,25 @@ Refer to the documentation for the model you are fine-tuning for more details.
{"prompt": "What is the capital of France?", "completion": "Paris."}
```
For the `completions` data format, a different key can be used for the prompt
and completion by specifying the following in the YAML config:
```yaml
prompt_feature: "input"
completion_feature: "output"
```
Here, `"input"` is the expected key instead of the default `"prompt"`, and
`"output"` is the expected key instead of `"completion"`.
`text`:
```jsonl
{"text": "This is an example for the model."}
```
Note, the format is automatically determined by the dataset. Note also, keys in
each line not expected by the loader will be ignored.
Note, the format is automatically determined by the dataset. Note also, keys
in each line not expected by the loader will be ignored.
> [!NOTE]
> Each example in the datasets must be on a single line. Do not put more than
@@ -270,7 +281,7 @@ Otherwise, provide a mapping of keys in the dataset to the features MLX LM
expects. Use a YAML config to specify the Hugging Face dataset arguments. For
example:
```
```yaml
hf_dataset:
name: "billsum"
prompt_feature: "text"

View File

@@ -1,3 +1,3 @@
# Copyright © 2023-2024 Apple Inc.
__version__ = "0.20.4"
__version__ = "0.21.0"

View File

@@ -110,29 +110,17 @@ def main():
if tokenizer.chat_template is None:
tokenizer.chat_template = tokenizer.default_chat_template
if not args.ignore_chat_template and (
hasattr(tokenizer, "apply_chat_template")
and tokenizer.chat_template is not None
):
if not args.ignore_chat_template and tokenizer.chat_template is not None:
messages = [{"role": "user", "content": args.prompt}]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
messages, add_generation_prompt=False, continue_final_message=True
)
# Treat the prompt as a prefix assuming that the suffix will be
# provided at generation time.
test_prompt = tokenizer.apply_chat_template(
[{"role": "user", "content": "<query>"}],
tokenize=False,
add_generation_prompt=True,
)
n = len(test_prompt) - test_prompt.index("<query>") - len("<query>")
prompt = prompt[:-n]
else:
prompt = args.prompt
prompt = tokenizer.encode(args.prompt)
cache = make_prompt_cache(model, args.max_kv_size)
y = mx.array(tokenizer.encode(prompt))
y = mx.array(prompt)
# Process the prompt
start = time.time()

View File

@@ -16,6 +16,25 @@ DEFAULT_MAX_TOKENS = 256
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
def share_message(world, prompt):
if world.size() == 1:
return prompt
if world.rank() == 0:
size = mx.array([len(prompt)])
else:
size = mx.array([0])
size = mx.distributed.all_sum(size, stream=mx.cpu).item()
if size == 0:
return []
if world.rank() == 0:
prompt = mx.array(prompt)
else:
prompt = mx.array([0] * size)
return mx.distributed.all_sum(prompt, stream=mx.cpu).tolist()
def setup_arg_parser():
"""Set up and return the argument parser."""
parser = argparse.ArgumentParser(description="Chat with an LLM")
@@ -54,6 +73,7 @@ def setup_arg_parser():
def main():
world = mx.distributed.init()
parser = setup_arg_parser()
args = parser.parse_args()
@@ -63,18 +83,30 @@ def main():
args.model,
adapter_path=args.adapter_path,
tokenizer_config={"trust_remote_code": True},
sequential_load=mx.distributed.init().size() > 1,
)
print(f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.")
print(f"Node {world.rank()} of {world.size()}", flush=True)
print(
f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.",
flush=True,
)
world.barrier()
prompt_cache = make_prompt_cache(model, args.max_kv_size)
while True:
query = input(">> ")
if query == "q":
if world.rank() == 0:
query = input(">> ")
if query == "q":
prompt = []
else:
messages = [{"role": "user", "content": query}]
prompt = tokenizer.apply_chat_template(
messages, add_generation_prompt=True
)
prompt = share_message(world, prompt)
if len(prompt) == 0:
break
messages = [{"role": "user", "content": query}]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
for response in stream_generate(
model,
tokenizer,
@@ -83,8 +115,10 @@ def main():
sampler=make_sampler(args.temp, args.top_p),
prompt_cache=prompt_cache,
):
print(response.text, flush=True, end="")
print()
if world.rank() == 0:
print(response.text, flush=True, end="")
if world.rank() == 0:
print()
if __name__ == "__main__":
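The `share_message` helper above is a small broadcast protocol built from `all_sum`: rank 0 contributes the real length and token ids while every other rank contributes zeros, so after the two reductions every rank holds the same prompt. A minimal standalone sketch of that pattern (the helper name is hypothetical, and it assumes the process group was started with `mlx.launch`):

```python
import mlx.core as mx

world = mx.distributed.init()

def broadcast_from_rank0(tokens=None):
    # Phase 1: agree on the length; only rank 0 contributes a non-zero value.
    size = mx.array([len(tokens) if world.rank() == 0 else 0])
    size = mx.distributed.all_sum(size, stream=mx.cpu).item()
    if size == 0:
        return []
    # Phase 2: sum the payload; zeros everywhere except on rank 0.
    payload = mx.array(tokens) if world.rank() == 0 else mx.zeros((size,), dtype=mx.int32)
    return mx.distributed.all_sum(payload, stream=mx.cpu).tolist()
```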

View File

@@ -1,4 +1,8 @@
# Adapted from a PyTorch implementation by David Grangier
# Copyright © 2024 Apple Inc.
"""
Adapted from a PyTorch implementation by David Grangier
"""
import argparse
import json
@@ -6,7 +10,7 @@ import logging
import os
from importlib.metadata import version
from pathlib import Path
from typing import Optional
from typing import Optional, Union
import lm_eval
import mlx.core as mx
@@ -73,15 +77,19 @@ class MLXLM(LM):
path_or_hf_repo: str,
batch_size: int = 16,
max_tokens: Optional[int] = None,
use_chat_template: Optional[bool] = None,
) -> None:
super().__init__()
self._batch_size = batch_size
self._model, self._tokenizer = load(path_or_hf_repo)
self._max_tokens = max_tokens or self._tokenizer.model_max_length
self._model, self.tokenizer = load(path_or_hf_repo)
self._max_tokens = max_tokens or self.tokenizer.model_max_length
self.use_chat_template = use_chat_template or (
self.tokenizer.chat_template is not None
)
def _score_fn(self, inputs, tokenize=True, step_size=32):
if tokenize:
inputs = self._tokenizer.encode(inputs)
inputs = self._tokenize(inputs)
inputs = _pad_inputs(inputs, self._max_tokens, truncate=False)
inputs = mx.array(inputs)
inputs, targets = inputs[..., :-1], inputs[..., 1:]
@@ -145,7 +153,12 @@ class MLXLM(LM):
return results
def _tokenize(self, texts):
return [tuple(self._tokenizer.encode(t)) for t in texts]
return [
tuple(
self.tokenizer.encode(t, add_special_tokens=not self.use_chat_template)
)
for t in texts
]
def loglikelihood(self, requests) -> list[tuple[float, bool]]:
"""Compute log-likelihood of generating a continuation from a context.
@@ -217,6 +230,9 @@ class MLXLM(LM):
)
return [(r[0], r[1] == r[2]) for r in results]
tokenizer_name = lm_eval.models.huggingface.HFLM.tokenizer_name
apply_chat_template = lm_eval.models.huggingface.HFLM.apply_chat_template
def loglikelihood_rolling(self, requests) -> list[float]:
"""Compute full log-likelihood of a string, with no truncation, for perplexity computation
- We will use the full max context length of the model.
@@ -277,23 +293,16 @@ class MLXLM(LM):
assert "until" in keys
untils = [x["until"] for x in options]
completions = []
for context, until in tqdm(zip(contexts, untils), total=len(contexts)):
if (
hasattr(self._tokenizer, "apply_chat_template")
and self._tokenizer.chat_template is not None
):
messages = [{"role": "user", "content": context}]
context = self._tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
for context, until in tqdm(zip(contexts, untils), total=len(contexts)):
context = self._tokenize(context)
max_tokens = min(
self._max_tokens,
self._tokenizer.model_max_length - len(self._tokenizer.encode(context)),
self.tokenizer.model_max_length - len(context),
)
text = ""
for response in stream_generate(
self._model, self._tokenizer, prompt=context, max_tokens=max_tokens
self._model, self.tokenizer, prompt=context, max_tokens=max_tokens
):
text += response.text
if any(u in text for u in until):
@@ -321,7 +330,28 @@ def main():
type=int,
help="Maximum nunber of tokens to generate. Defaults to the model's max context length.",
)
parser.add_argument(
"--limit",
default=1.0,
help="Limit the number of examples per task.",
type=float,
)
parser.add_argument("--seed", type=int, default=123, help="Random seed.")
parser.add_argument(
"--fewshot-as-multiturn",
action="store_true",
help="Whether to provide the fewshot examples as a multiturn "
"conversation or a single user turn.",
default=False,
)
parser.add_argument(
"--apply-chat-template",
action=argparse.BooleanOptionalAction,
help="Specifies whether to apply a chat template to the prompt. If "
"the model has a chat template, this defaults to `True`, "
"otherwise `False`.",
default=None,
)
args = parser.parse_args()
output_dir = Path(args.output_dir)
@@ -332,12 +362,19 @@ def main():
mx.random.seed(args.seed)
lm = MLXLM(args.model, batch_size=args.batch_size, max_tokens=args.max_tokens)
lm = MLXLM(
args.model,
batch_size=args.batch_size,
max_tokens=args.max_tokens,
use_chat_template=args.apply_chat_template,
)
results = lm_eval.simple_evaluate(
model=lm,
tasks=args.tasks,
fewshot_as_multiturn=args.fewshot_as_multiturn,
apply_chat_template=lm.use_chat_template,
num_fewshot=args.num_shots,
limit=args.limit,
random_seed=args.seed,
numpy_random_seed=args.seed,
torch_random_seed=args.seed,
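The same wiring can be driven from Python instead of the CLI. A sketch, where the `mlx_lm.evaluate` import path and the task name are assumptions for illustration:

```python
import lm_eval
from mlx_lm.evaluate import MLXLM  # assumed import path for the class shown above

lm = MLXLM(
    "mlx-community/Llama-3.2-3B-Instruct-4bit",  # any chat model
    batch_size=16,
    use_chat_template=True,             # equivalent to --apply-chat-template
)
results = lm_eval.simple_evaluate(
    model=lm,
    tasks=["arc_easy"],                 # example lm_eval task
    num_fewshot=5,
    fewshot_as_multiturn=True,          # equivalent to --fewshot-as-multiturn
    apply_chat_template=lm.use_chat_template,
)
```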

View File

@@ -15,9 +15,7 @@ prompt_cache = make_prompt_cache(model)
# User turn
prompt = "Hi my name is <Name>."
messages = [{"role": "user", "content": prompt}]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
# Assistant response
response = generate(
@@ -32,9 +30,7 @@ response = generate(
# User turn
prompt = "What's my name?"
messages = [{"role": "user", "content": prompt}]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
# Assistant response
response = generate(

View File

@@ -14,7 +14,7 @@ conversation = [{"role": "user", "content": prompt}]
# Transform the prompt into the chat template
prompt = tokenizer.apply_chat_template(
conversation=conversation, tokenize=False, add_generation_prompt=True
conversation=conversation, add_generation_prompt=True
)
# Specify the maximum number of tokens

View File

@@ -0,0 +1,127 @@
# Copyright © 2024 Apple Inc.
"""
Run with:
```
mlx.launch \
--hostfile /path/to/hosts.txt \
--backend mpi \
/path/to/pipeline_generate.py \
--prompt "hello world"
```
Make sure you can run MLX over MPI on two hosts. For more information see the
documentation:
https://ml-explore.github.io/mlx/build/html/usage/distributed.html
"""
import argparse
import json
from pathlib import Path
import mlx.core as mx
from huggingface_hub import snapshot_download
from mlx.utils import tree_flatten
from mlx_lm import load, stream_generate
from mlx_lm.utils import load_model, load_tokenizer
def download(repo: str, allow_patterns: list[str]) -> Path:
return Path(
snapshot_download(
repo,
allow_patterns=allow_patterns,
)
)
def shard_and_load(repo):
# Get model path with everything but weight safetensors
model_path = download(
args.model,
allow_patterns=["*.json", "*.py", "tokenizer.model", "*.tiktoken", "*.txt"],
)
# Lazy load and shard model to figure out
# which weights we need
model, _ = load_model(model_path, lazy=True, strict=False)
group = mx.distributed.init(backend="mpi")
rank = group.rank()
model.model.pipeline(group)
# Figure out which files we need for the local shard
with open(model_path / "model.safetensors.index.json", "r") as fid:
weight_index = json.load(fid)["weight_map"]
local_files = set()
for k, _ in tree_flatten(model.parameters()):
local_files.add(weight_index[k])
# Download weights for local shard
download(args.model, allow_patterns=local_files)
# Load and shard the model, and load the weights
tokenizer = load_tokenizer(model_path)
model, _ = load_model(model_path, lazy=True, strict=False)
model.model.pipeline(group)
mx.eval(model.parameters())
# Synchronize processes before generation to avoid timeout if downloading
# model for the first time.
mx.eval(mx.distributed.all_sum(mx.array(1.0), stream=mx.cpu))
return model, tokenizer
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="LLM pipelined inference example")
parser.add_argument(
"--model",
default="mlx-community/DeepSeek-R1-3bit",
help="HF repo or path to local model.",
)
parser.add_argument(
"--prompt",
"-p",
default="Write a quicksort in C++.",
help="Message to be processed by the model ('-' reads from stdin)",
)
parser.add_argument(
"--max-tokens",
"-m",
type=int,
default=256,
help="Maximum number of tokens to generate",
)
args = parser.parse_args()
group = mx.distributed.init(backend="mpi")
rank = group.rank()
def rprint(*args, **kwargs):
if rank == 0:
print(*args, **kwargs)
model, tokenizer = shard_and_load(args.model)
messages = [{"role": "user", "content": args.prompt}]
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
for response in stream_generate(
model, tokenizer, prompt, max_tokens=args.max_tokens
):
rprint(response.text, end="", flush=True)
rprint()
rprint("=" * 10)
rprint(
f"Prompt: {response.prompt_tokens} tokens, "
f"{response.prompt_tps:.3f} tokens-per-sec"
)
rprint(
f"Generation: {response.generation_tokens} tokens, "
f"{response.generation_tps:.3f} tokens-per-sec"
)
rprint(f"Peak memory: {response.peak_memory:.3f} GB")

View File

@@ -43,10 +43,11 @@ def setup_arg_parser():
help="Optional path for the trained adapter weights and config.",
)
parser.add_argument(
"--eos-token",
"--extra-eos-token",
type=str,
default=None,
help="End of sequence token for tokenizer",
default=(),
nargs="+",
help="Add tokens in the list of eos tokens that stop generation.",
)
parser.add_argument(
"--system-prompt",
@@ -130,6 +131,18 @@ def setup_arg_parser():
type=int,
default=DEFAULT_QUANTIZED_KV_START,
)
parser.add_argument(
"--draft-model",
type=str,
help="A model to be used for speculative decoding.",
default=None,
)
parser.add_argument(
"--num-draft-tokens",
type=int,
help="Number of tokens to draft when using speculative decoding.",
default=2,
)
return parser
@@ -161,8 +174,6 @@ def main():
{} if not using_cache else json.loads(metadata["tokenizer_config"])
)
tokenizer_config["trust_remote_code"] = True
if args.eos_token is not None:
tokenizer_config["eos_token"] = args.eos_token
model_path = args.model
if using_cache:
@@ -180,7 +191,10 @@ def main():
model_path,
adapter_path=args.adapter_path,
tokenizer_config=tokenizer_config,
sequential_load=mx.distributed.init().size() > 1,
)
for eos_token in args.extra_eos_token:
tokenizer.add_eos_token(eos_token)
if args.use_default_chat_template:
if tokenizer.chat_template is None:
@@ -190,10 +204,7 @@ def main():
prompt = args.prompt.replace("\\n", "\n").replace("\\t", "\t")
prompt = sys.stdin.read() if prompt == "-" else prompt
if not args.ignore_chat_template and (
hasattr(tokenizer, "apply_chat_template")
and tokenizer.chat_template is not None
):
if not args.ignore_chat_template and tokenizer.chat_template is not None:
if args.system_prompt is not None:
messages = [{"role": "system", "content": args.system_prompt}]
else:
@@ -213,23 +224,40 @@ def main():
add_generation_prompt=True,
)
prompt = prompt[test_prompt.index("<query>") :]
prompt = tokenizer.encode(prompt, add_special_tokens=False)
else:
prompt = tokenizer.encode(prompt)
if args.draft_model is not None:
draft_model, draft_tokenizer = load(args.draft_model)
if draft_tokenizer.vocab_size != tokenizer.vocab_size:
raise ValueError("Draft model tokenizer does not match model tokenizer.")
else:
draft_model = None
sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep)
world = mx.distributed.init()
print(f"Node {world.rank()} of {world.size()}", flush=True)
world.barrier()
response = generate(
model,
tokenizer,
prompt,
max_tokens=args.max_tokens,
verbose=args.verbose,
sampler=sampler,
verbose=args.verbose and world.rank() == 0,
max_kv_size=args.max_kv_size,
prompt_cache=prompt_cache if using_cache else None,
kv_bits=args.kv_bits,
kv_group_size=args.kv_group_size,
quantized_kv_start=args.quantized_kv_start,
draft_model=draft_model,
num_draft_tokens=args.num_draft_tokens,
)
if not args.verbose:
if not args.verbose and mx.distributed.init().rank() == 0:
print(response)
mx.synchronize()
if __name__ == "__main__":
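The new `--draft-model` and `--num-draft-tokens` flags map directly onto keyword arguments of `generate`, as the call above shows. A Python sketch of the same thing; the repo ids are illustrative, and the only hard requirement (checked above) is that the target and draft tokenizers share a vocabulary:

```python
from mlx_lm import load, generate

model, tokenizer = load("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit")
draft_model, _ = load("mlx-community/Llama-3.2-1B-Instruct-4bit")  # small drafter

prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Write a quicksort in C++."}],
    add_generation_prompt=True,
)
text = generate(
    model,
    tokenizer,
    prompt,
    max_tokens=256,
    draft_model=draft_model,
    num_draft_tokens=2,   # same default as the CLI flag
    verbose=True,
)
```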

View File

@@ -2,6 +2,7 @@
import argparse
import math
import os
import re
import types
from pathlib import Path
@@ -57,6 +58,8 @@ CONFIG_DEFAULTS = {
"test": False,
"test_batches": 500,
"max_seq_length": 2048,
"config": None,
"grad_checkpoint": False,
"lr_schedule": None,
"lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
}
@@ -66,6 +69,7 @@ def build_parser():
parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")
parser.add_argument(
"--model",
type=str,
help="The path to the local model directory or Hugging Face repo.",
)
@@ -88,7 +92,6 @@ def build_parser():
"--fine-tune-type",
type=str,
choices=["lora", "dora", "full"],
default="lora",
help="Type of fine-tuning to perform: lora, dora, or full.",
)
parser.add_argument(
@@ -148,7 +151,7 @@ def build_parser():
parser.add_argument(
"-c",
"--config",
default=None,
type=str,
help="A YAML configuration file with the training options",
)
parser.add_argument(
@@ -157,7 +160,7 @@ def build_parser():
help="Use gradient checkpointing to reduce memory use.",
default=None,
)
parser.add_argument("--seed", type=int, default=None, help="The PRNG seed")
parser.add_argument("--seed", type=int, help="The PRNG seed")
return parser
@@ -271,6 +274,7 @@ def run(args, training_callback: TrainingCallback = None):
def main():
os.environ["TOKENIZERS_PARALLELISM"] = "true"
parser = build_parser()
args = parser.parse_args()
config = args.config

View File

@@ -6,19 +6,18 @@ from transformers.commands.user import tabulate
def ask_for_confirmation(message: str) -> bool:
"""Ask user for confirmation with Y/N prompt.
Returns True for Y/yes, False for N/no/empty."""
y = ("y", "yes", "1")
n = ("n", "no", "0")
all_values = y + n + ("",)
full_message = f"{message} (Y/n) "
n = ("n", "no", "0", "")
full_message = f"{message} (y/n) "
while True:
answer = input(full_message).lower()
if answer == "":
return False
if answer in y:
return True
if answer in n:
return False
print(f"Invalid input. Must be one of {all_values}")
print(f"Invalid input. Must be one of: yes/no/y/n or empty for no")
def main():
@@ -43,9 +42,7 @@ def main():
args = parser.parse_args()
if args.scan:
print(
"Scanning Hugging Face cache for models with" f'pattern "{args.pattern}".'
)
print(f'Scanning Hugging Face cache for models with pattern "{args.pattern}".')
hf_cache_info = scan_cache_dir()
print(
tabulate(
@@ -86,35 +83,41 @@ def main():
if args.pattern in repo.repo_id
]
if repos:
print("\nFound the following models:")
print(
tabulate(
rows=[
[
repo.repo_id,
repo.size_on_disk_str, # Added size information
str(repo.repo_path),
]
for repo in repos
],
headers=[
"REPO ID",
"SIZE", # Added size header
"LOCAL PATH",
],
)
)
confirmed = ask_for_confirmation(f"Confirm deletion ?")
confirmed = ask_for_confirmation(
"\nAre you sure you want to delete these models?"
)
if confirmed:
for model_info in repos:
print(f"\nDeleting {model_info.repo_id}...")
for revision in sorted(
model_info.revisions, key=lambda revision: revision.commit_hash
):
strategy = hf_cache_info.delete_revisions(revision.commit_hash)
strategy.execute()
print("Model(s) deleted.")
print("\nModel(s) deleted successfully.")
else:
print("Deletion is cancelled. Do nothing.")
print("\nDeletion cancelled - no changes made.")
else:
print(f"No models found.")
print(f'No models found matching pattern "{args.pattern}"')
if __name__ == "__main__":

View File

@@ -156,12 +156,13 @@ class CohereModel(nn.Module):
):
h = self.embed_tokens(inputs)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
if mask is None:
j = self.args.sliding_window_pattern
mask = create_attention_mask(h, cache[j - 1 : j])
for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)

View File

@@ -364,8 +364,30 @@ class DeepseekV2Model(nn.Module):
DeepseekV2DecoderLayer(config, idx)
for idx in range(config.num_hidden_layers)
]
self.start_idx = 0
self.end_idx = len(self.layers)
self.num_layers = self.end_idx
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.pipeline_rank = 0
self.pipeline_size = 1
def pipeline(self, group):
# Split layers in reverse so rank=0 gets the last layers and
# rank=pipeline_size-1 gets the first
self.pipeline_rank = group.rank()
self.pipeline_size = group.size()
layers_per_rank = (
len(self.layers) + self.pipeline_size - 1
) // self.pipeline_size
self.start_idx = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
self.end_idx = self.start_idx + layers_per_rank
self.num_layers = layers_per_rank
self.layers = self.layers[: self.end_idx]
self.layers[: self.start_idx] = [None] * self.start_idx
self.num_layers = len(self.layers) - self.start_idx
def __call__(
self,
x: mx.array,
@@ -374,14 +396,31 @@ class DeepseekV2Model(nn.Module):
) -> mx.array:
h = self.embed_tokens(x)
pipeline_rank = self.pipeline_rank
pipeline_size = self.pipeline_size
# Hack to avoid time-outs during prompt-processing
dist_stream = mx.cpu if h.shape[1] > 1 else mx.gpu
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
cache = [None] * self.num_layers
for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)
# Receive from the previous process in the pipeline
if pipeline_rank < pipeline_size - 1:
h = mx.distributed.recv_like(h, (pipeline_rank + 1), stream=dist_stream)
for i in range(self.num_layers):
h = self.layers[self.start_idx + i](h, mask, cache[i])
# Send to the next process in the pipeline
if pipeline_rank != 0:
h = mx.distributed.send(
h, (pipeline_rank - 1) % pipeline_size, stream=dist_stream
)
# Broadcast h while keeping it in the graph
h = mx.distributed.all_gather(h, stream=dist_stream)[: h.shape[0]]
return self.norm(h)
@@ -418,4 +457,4 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
return self.model.layers[self.model.start_idx : self.model.end_idx]
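To make the reverse split concrete, here is the same index arithmetic on a toy configuration (8 layers, 2 ranks): rank 1 runs the first layers right after the embedding, rank 0 runs the last layers and applies the final norm, which matches the send/recv directions in `__call__` above.

```python
num_layers, pipeline_size = 8, 2
layers_per_rank = (num_layers + pipeline_size - 1) // pipeline_size  # 4

for rank in range(pipeline_size):
    start = (pipeline_size - rank - 1) * layers_per_rank
    print(rank, list(range(start, start + layers_per_rank)))
# 0 [4, 5, 6, 7]  <- last layers, feeds self.norm
# 1 [0, 1, 2, 3]  <- first layers, right after embed_tokens
```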

View File

@@ -0,0 +1,478 @@
# Copyright © 2024 Apple Inc.
import math
from dataclasses import dataclass
from functools import partial
from typing import Any, Dict, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .switch_layers import SwitchGLU
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str = "deepseek_v3"
vocab_size: int = 102400
hidden_size: int = 4096
intermediate_size: int = 11008
moe_intermediate_size: int = 1407
num_hidden_layers: int = 30
num_attention_heads: int = 32
num_key_value_heads: int = 32
n_shared_experts: Optional[int] = None
n_routed_experts: Optional[int] = None
routed_scaling_factor: float = 1.0
kv_lora_rank: int = 512
q_lora_rank: int = 1536
qk_rope_head_dim: int = 64
v_head_dim: int = 128
qk_nope_head_dim: int = 128
topk_method: str = "noaux_tc"
scoring_func: str = "sigmoid"
norm_topk_prob: bool = True
n_group: Optional[int] = None
topk_group: Optional[int] = None
num_experts_per_tok: Optional[int] = None
moe_layer_freq: int = 1
first_k_dense_replace: int = 0
max_position_embeddings: int = 2048
rms_norm_eps: float = 1e-6
rope_theta: float = 10000.0
rope_scaling: Dict = None
attention_bias: bool = False
def yarn_find_correction_dim(
num_rotations, dim, base=10000, max_position_embeddings=2048
):
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
2 * math.log(base)
)
def yarn_find_correction_range(
low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
):
low = math.floor(
yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
)
high = math.ceil(
yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
)
return max(low, 0), min(high, dim - 1)
def yarn_get_mscale(scale=1, mscale=1):
if scale <= 1:
return 1.0
return 0.1 * mscale * math.log(scale) + 1.0
def yarn_linear_ramp_mask(min_val, max_val, dim):
if min_val == max_val:
max_val += 0.001 # Prevent singularity
linear_func = (mx.arange(dim, dtype=mx.float32) - min_val) / (max_val - min_val)
return mx.clip(linear_func, 0, 1)
class DeepseekV3YarnRotaryEmbedding(nn.Module):
def __init__(
self,
dim,
max_position_embeddings=2048,
base=10000,
scaling_factor=1.0,
original_max_position_embeddings=4096,
beta_fast=32,
beta_slow=1,
mscale=1,
mscale_all_dim=0,
):
super().__init__()
self.mscale = yarn_get_mscale(scaling_factor, mscale) / yarn_get_mscale(
scaling_factor, mscale_all_dim
)
freq_extra = base ** (mx.arange(0, dim, 2, dtype=mx.float32) / dim)
freq_inter = scaling_factor * base ** (
mx.arange(0, dim, 2, dtype=mx.float32) / dim
)
low, high = yarn_find_correction_range(
beta_fast,
beta_slow,
dim,
base,
original_max_position_embeddings,
)
freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2)
self._freqs = (freq_inter * freq_extra) / (
freq_inter * freq_mask + freq_extra * (1 - freq_mask)
)
def __call__(self, x, offset=0):
if self.mscale != 1.0:
x = self.mscale * x
return mx.fast.rope(
x,
x.shape[-1],
traditional=True,
base=None,
scale=1.0,
offset=offset,
freqs=self._freqs,
)
# A clipped silu to prevent fp16 from overflowing
@partial(mx.compile, shapeless=True)
def clipped_silu(x):
return mx.clip(x * mx.sigmoid(x), a_min=-100, a_max=100)
class DeepseekV3Attention(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.q_lora_rank = config.q_lora_rank
self.qk_rope_head_dim = config.qk_rope_head_dim
self.kv_lora_rank = config.kv_lora_rank
self.v_head_dim = config.v_head_dim
self.qk_nope_head_dim = config.qk_nope_head_dim
self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
self.scale = self.q_head_dim**-0.5
if self.q_lora_rank is None:
self.q_proj = nn.Linear(
self.hidden_size, self.num_heads * self.q_head_dim, bias=False
)
else:
self.q_a_proj = nn.Linear(
self.hidden_size, self.q_lora_rank, bias=config.attention_bias
)
self.q_a_layernorm = nn.RMSNorm(self.q_lora_rank)
self.q_b_proj = nn.Linear(
self.q_lora_rank, self.num_heads * self.q_head_dim, bias=False
)
self.kv_a_proj_with_mqa = nn.Linear(
self.hidden_size,
self.kv_lora_rank + self.qk_rope_head_dim,
bias=config.attention_bias,
)
self.kv_a_layernorm = nn.RMSNorm(self.kv_lora_rank)
self.kv_b_proj = nn.Linear(
self.kv_lora_rank,
self.num_heads
* (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
bias=False,
)
self.o_proj = nn.Linear(
self.num_heads * self.v_head_dim,
self.hidden_size,
bias=config.attention_bias,
)
mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
scaling_factor = self.config.rope_scaling["factor"]
if mscale_all_dim:
mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
self.scale = self.scale * mscale * mscale
rope_kwargs = {
key: self.config.rope_scaling[key]
for key in [
"original_max_position_embeddings",
"beta_fast",
"beta_slow",
"mscale",
"mscale_all_dim",
]
if key in self.config.rope_scaling
}
self.rope = DeepseekV3YarnRotaryEmbedding(
dim=self.qk_rope_head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta,
**rope_kwargs,
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
if self.q_lora_rank is None:
q = self.q_proj(x)
else:
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(x)))
q = q.reshape(B, L, self.num_heads, self.q_head_dim).transpose(0, 2, 1, 3)
q_nope, q_pe = mx.split(q, [self.qk_nope_head_dim], axis=-1)
compressed_kv = self.kv_a_proj_with_mqa(x)
compressed_kv, k_pe = mx.split(compressed_kv, [self.kv_lora_rank], axis=-1)
k_pe = k_pe.reshape(B, L, 1, self.qk_rope_head_dim).transpose(0, 2, 1, 3)
kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
kv = kv.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1)
if cache is not None:
q_pe = self.rope(q_pe, cache.offset)
k_pe = self.rope(k_pe, cache.offset)
k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
keys, values = cache.update_and_fetch(
mx.concatenate([k_nope, k_pe], axis=-1), values
)
else:
q_pe = self.rope(q_pe)
k_pe = self.rope(k_pe)
k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
keys = mx.concatenate([k_nope, k_pe], axis=-1)
queries = mx.concatenate([q_nope, q_pe], axis=-1)
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
class DeepseekV3MLP(nn.Module):
def __init__(
self, config: ModelArgs, hidden_size: int = None, intermediate_size: int = None
):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
self.intermediate_size = (
config.intermediate_size if intermediate_size is None else intermediate_size
)
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
def __call__(self, x):
down_proj = self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
return down_proj
class MoEGate(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.config = config
self.top_k = config.num_experts_per_tok
self.norm_topk_prob = config.norm_topk_prob
self.n_routed_experts = config.n_routed_experts
self.routed_scaling_factor = config.routed_scaling_factor
self.topk_method = config.topk_method
self.n_group = config.n_group
self.topk_group = config.topk_group
self.weight = mx.zeros((self.n_routed_experts, config.hidden_size))
self.e_score_correction_bias = mx.zeros((self.n_routed_experts,))
def __call__(self, x):
gates = x @ self.weight.T
scores = mx.sigmoid(gates.astype(mx.float32))
assert self.topk_method == "noaux_tc", "Unsupported topk method."
bsz, seq_len = x.shape[:2]
scores = scores + self.e_score_correction_bias
scores = scores.reshape(bsz, seq_len, self.n_group, -1)
group_scores = mx.topk(scores, 2, axis=-1).sum(axis=-1)
k = self.n_group - self.topk_group
group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-1)[..., :k]
batch_idx = mx.expand_dims(mx.arange(bsz), (1, 2))
seq_idx = mx.expand_dims(mx.arange(seq_len), (0, 2))
scores[batch_idx, seq_idx, group_idx] = 0.0
scores = scores.reshape(bsz, seq_len, -1)
k = self.top_k
inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]
scores = mx.take_along_axis(scores, inds, axis=-1)
if self.top_k > 1 and self.norm_topk_prob:
denominator = scores.sum(axis=-1, keepdims=True) + 1e-20
scores = scores / denominator
scores = scores * self.routed_scaling_factor
return inds, scores
class DeepseekV3MoE(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.config = config
self.num_experts_per_tok = config.num_experts_per_tok
self.switch_mlp = SwitchGLU(
config.hidden_size,
config.moe_intermediate_size,
config.n_routed_experts,
activation=clipped_silu,
)
self.gate = MoEGate(config)
if config.n_shared_experts is not None:
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
self.shared_experts = DeepseekV3MLP(
config=config, intermediate_size=intermediate_size
)
def __call__(self, x):
inds, scores = self.gate(x)
y = self.switch_mlp(x, inds)
y = (y * scores[..., None]).sum(axis=-2).astype(y.dtype)
if self.config.n_shared_experts is not None:
y = y + self.shared_experts(x)
return y
class DeepseekV3DecoderLayer(nn.Module):
def __init__(self, config: ModelArgs, layer_idx: int):
super().__init__()
self.self_attn = DeepseekV3Attention(config)
self.mlp = (
DeepseekV3MoE(config)
if (
config.n_routed_experts is not None
and layer_idx >= config.first_k_dense_replace
and layer_idx % config.moe_layer_freq == 0
)
else DeepseekV3MLP(config)
)
self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = nn.RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
return h + r
class DeepseekV3Model(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = [
DeepseekV3DecoderLayer(config, idx)
for idx in range(config.num_hidden_layers)
]
self.start_idx = 0
self.end_idx = len(self.layers)
self.num_layers = self.end_idx
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.pipeline_rank = 0
self.pipeline_size = 1
def pipeline(self, group):
# Split layers in reverse so rank=0 gets the last layers and
# rank=pipeline_size-1 gets the first
self.pipeline_rank = group.rank()
self.pipeline_size = group.size()
layers_per_rank = (
len(self.layers) + self.pipeline_size - 1
) // self.pipeline_size
self.start_idx = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
self.end_idx = self.start_idx + layers_per_rank
self.layers = self.layers[: self.end_idx]
self.layers[: self.start_idx] = [None] * self.start_idx
self.num_layers = len(self.layers) - self.start_idx
def __call__(
self,
x: mx.array,
cache: Optional[Any] = None,
mask: Optional[mx.array] = None,
) -> mx.array:
h = self.embed_tokens(x)
pipeline_rank = self.pipeline_rank
pipeline_size = self.pipeline_size
# Hack to avoid time-outs during prompt-processing
dist_stream = mx.cpu if h.shape[1] > 1 else mx.gpu
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * self.num_layers
# Receive from the previous process in the pipeline
if pipeline_rank < pipeline_size - 1:
h = mx.distributed.recv_like(h, (pipeline_rank + 1), stream=dist_stream)
for i in range(self.num_layers):
h = self.layers[self.start_idx + i](h, mask, cache[i])
# Send to the next process in the pipeline
if pipeline_rank != 0:
h = mx.distributed.send(
h, (pipeline_rank - 1) % pipeline_size, stream=dist_stream
)
# Broadcast h while keeping it in the graph
h = mx.distributed.all_gather(h, stream=dist_stream)[: h.shape[0]]
return self.norm(h)
class Model(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.args = config
self.model_type = config.model_type
self.model = DeepseekV3Model(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
cache: Optional[Any] = None,
mask: Optional[mx.array] = None,
):
out = self.model(inputs, cache, mask)
return self.lm_head(out)
def sanitize(self, weights):
for l in range(self.args.num_hidden_layers):
prefix = f"model.layers.{l}"
for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]:
for k in ["weight", "scales", "biases"]:
if f"{prefix}.mlp.experts.0.{m}.{k}" in weights:
to_join = [
weights.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}")
for e in range(self.args.n_routed_experts)
]
weights[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join)
# Remove multi-token prediction layer
return {k: v for k, v in weights.items() if not k.startswith("model.layers.61")}
@property
def layers(self):
return self.model.layers[self.model.start_idx : self.model.end_idx]

View File

@@ -145,16 +145,16 @@ class GPTBigCodeModel(nn.Module):
hidden_states = self.wte(inputs)
mask = None
if hidden_states.shape[1] > 1:
position_ids = mx.array(np.arange(L))
hidden_states += self.wpe(position_ids)
if mask is None:
mask = create_attention_mask(hidden_states, cache)
if mask is not None and hidden_states.shape[1] > 1:
mask = create_attention_mask(hidden_states, cache)
if cache is None:
cache = [None] * len(self.h)
position_ids = mx.array(np.arange(L))
else:
position_ids = mx.array(np.arange(cache[0].offset, cache[0].offset + L))
hidden_states += self.wpe(position_ids)
for layer, c in zip(self.h, cache):
hidden_states = layer(hidden_states, mask, cache=c)

View File

@@ -0,0 +1,185 @@
# Copyright © 2025 Apple Inc.
from dataclasses import dataclass
from typing import Any, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass
class ModelArgs(BaseModelArgs):
hidden_size: int
num_hidden_layers: int
intermediate_size: int
num_attention_heads: int
num_key_value_heads: int
rms_norm_eps: float
vocab_size: int
attention_bias: bool
head_dim: int
max_position_embeddings: int
mlp_bias: bool
model_type: str
rope_theta: float
tie_word_embeddings: bool
class HeliumAttention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
assert args.num_key_value_heads is not None
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
head_dim = args.hidden_size // n_heads
self.scale = head_dim**-0.5
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.attention_bias)
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias)
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
self.rope = nn.RoPE(head_dim, traditional=True, base=args.rope_theta)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
class HeliumMLP(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.hidden_size = args.hidden_size
self.intermediate_size = args.intermediate_size
self.gate_proj = nn.Linear(
self.hidden_size, self.intermediate_size, bias=args.mlp_bias
)
self.up_proj = nn.Linear(
self.hidden_size, self.intermediate_size, bias=args.mlp_bias
)
self.down_proj = nn.Linear(
self.intermediate_size, self.hidden_size, bias=args.mlp_bias
)
def __call__(self, x: mx.array) -> mx.array:
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
class HeliumDecoderLayer(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.hidden_size = args.hidden_size
self.self_attn = HeliumAttention(args)
self.mlp = HeliumMLP(args)
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.post_attention_layernorm = nn.RMSNorm(
args.hidden_size, eps=args.rms_norm_eps
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
out = h + r
return out
class HeliumModel(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.num_hidden_layers = args.num_hidden_layers
self.vocab_size = args.vocab_size
assert self.vocab_size > 0
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [HeliumDecoderLayer(args) for _ in range(args.num_hidden_layers)]
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
) -> mx.array:
h = self.embed_tokens(inputs)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)
return self.norm(h)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.model = HeliumModel(args)
self.vocab_size = args.vocab_size
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
) -> mx.array:
out = self.model(inputs, mask, cache)
if self.args.tie_word_embeddings:
out = self.model.embed_tokens.as_linear(out)
else:
out = self.lm_head(out)
return out
@property
def layers(self):
return self.model.layers
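Once a converted checkpoint is available, the new architecture loads through the usual `mlx_lm` entry points. The repo id below is a placeholder, not a specific published conversion:

```python
from mlx_lm import load, generate

# Placeholder repo id: substitute an actual Helium conversion.
model, tokenizer = load("mlx-community/helium-1-preview-2b")
print(generate(model, tokenizer, "The capital of France is", max_tokens=32))
```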

View File

@@ -0,0 +1,241 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int
num_hidden_layers: int
intermediate_size: int
num_attention_heads: int
rms_norm_eps: float
vocab_size: int
bias: bool = False
qkv_bias: bool = False
max_position_embeddings: int = 32768
num_key_value_heads: int = None
rope_theta: float = 10000
rope_traditional: bool = False
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
tie_word_embeddings: bool = False
def __post_init__(self):
if self.num_key_value_heads is None:
self.num_key_value_heads = self.num_attention_heads
if self.rope_scaling:
required_keys = {"factor", "rope_type"}
if not all(key in self.rope_scaling for key in required_keys):
raise ValueError(f"rope_scaling must contain keys {required_keys}")
if self.rope_scaling["rope_type"] not in ["linear", "dynamic"]:
raise ValueError(
"rope_scaling 'rope_type' currently only supports 'linear' or 'dynamic"
)
class DynamicNTKScalingRoPE(nn.Module):
"""Implements the rotary positional encoding with Dynamic NTK scaling."""
def __init__(
self,
dims: int,
max_position_embeddings: int = 2048,
traditional: bool = False,
base: float = 10000,
scale: float = 1.0,
):
super().__init__()
self.max_position_embeddings = max_position_embeddings
self.original_base = base
self.dims = dims
self.traditional = traditional
self.scale = scale
def extra_repr(self):
return f"{self.dims}, traditional={self.traditional}, max_position_embeddings={self.max_position_embeddings}, scaling_factor={self.scaling_factor}"
def __call__(self, x, offset: int = 0):
seq_len = x.shape[1] + offset
if seq_len > self.max_position_embeddings:
base = self.original_base * (
(self.scale * seq_len / self.max_position_embeddings) - (self.scale - 1)
) ** (self.dims / (self.dims - 2))
else:
base = self.original_base
return mx.fast.rope(
x,
self.dims,
traditional=self.traditional,
base=base,
scale=self.scale,
offset=offset,
)
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.hidden_size
qkv_bias = args.qkv_bias
self.n_heads = n_heads = args.num_attention_heads
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
self.n_kv_groups = n_heads // args.num_key_value_heads
self.head_dim = head_dim = args.hidden_size // n_heads
self.scale = head_dim**-0.5
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=qkv_bias)
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=qkv_bias)
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=qkv_bias)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=qkv_bias)
rope_scale = (
1 / args.rope_scaling["factor"]
if args.rope_scaling is not None
and args.rope_scaling["rope_type"] == "linear"
else 2.0
)
self.rope = DynamicNTKScalingRoPE(
head_dim,
max_position_embeddings=args.max_position_embeddings,
traditional=args.rope_traditional,
base=args.rope_theta,
scale=rope_scale,
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
class MLP(nn.Module):
def __init__(self, dim, hidden_dim, bias):
super().__init__()
self.gate_proj = nn.Linear(dim, hidden_dim, bias=bias)
self.down_proj = nn.Linear(hidden_dim, dim, bias=bias)
self.up_proj = nn.Linear(dim, hidden_dim, bias=bias)
def __call__(self, x) -> mx.array:
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.self_attn = Attention(args)
self.mlp = MLP(args.hidden_size, args.intermediate_size, args.bias)
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.post_attention_layernorm = nn.RMSNorm(
args.hidden_size, eps=args.rms_norm_eps
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
out = h + r
return out
class InternLM2Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
assert args.vocab_size > 0
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
]
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.embed_tokens(inputs)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
h = layer(h, mask, cache=c)
return self.norm(h)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.model = InternLM2Model(args)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, mask, cache)
if self.args.tie_word_embeddings:
out = self.model.embed_tokens.as_linear(out)
else:
out = self.lm_head(out)
return out
def sanitize(self, weights):
# Remove unused precomputed rotary freqs
return {k: v for k, v in weights.items() if "attention.rope.inv_freq" not in k}
@property
def layers(self):
return self.model.layers

View File

@@ -200,6 +200,36 @@ class Model(nn.Module):
k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
}
def shard(self, group: Optional[mx.distributed.Group] = None):
group = group or mx.distributed.init()
def all_to_sharded(l):
if isinstance(l, nn.QuantizedLinear):
return nn.QuantizedAllToShardedLinear.from_quantized_linear(l, group)
else:
return nn.AllToShardedLinear.from_linear(l, group)
def sharded_to_all(l):
if isinstance(l, nn.QuantizedLinear):
return nn.QuantizedShardedToAllLinear.from_quantized_linear(l, group)
else:
return nn.ShardedToAllLinear.from_linear(l, group)
N = group.size()
for layer in self.model.layers:
# Shard the self attention
layer.self_attn.q_proj = all_to_sharded(layer.self_attn.q_proj)
layer.self_attn.k_proj = all_to_sharded(layer.self_attn.k_proj)
layer.self_attn.v_proj = all_to_sharded(layer.self_attn.v_proj)
layer.self_attn.o_proj = sharded_to_all(layer.self_attn.o_proj)
layer.self_attn.n_heads //= N
layer.self_attn.n_kv_heads //= N
# Shard the MLP
layer.mlp.gate_proj = all_to_sharded(layer.mlp.gate_proj)
layer.mlp.down_proj = sharded_to_all(layer.mlp.down_proj)
layer.mlp.up_proj = all_to_sharded(layer.mlp.up_proj)
@property
def layers(self):
return self.model.layers
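A rough sketch of calling the new `shard` hook directly, assuming the processes were started with `mlx.launch` and `model_path` points at a converted checkpoint (the actual scripts wire this up through `sequential_load` and the load helpers shown earlier):

```python
from pathlib import Path

import mlx.core as mx
from mlx_lm.utils import load_model, load_tokenizer  # same helpers used in pipeline_generate above

model_path = Path("/path/to/converted/model")  # placeholder
group = mx.distributed.init()

model, _ = load_model(model_path, lazy=True, strict=False)
tokenizer = load_tokenizer(model_path)
model.shard(group)            # split attention and MLP projections across the group
mx.eval(model.parameters())   # materialize the now-sharded parameters
```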

View File

@@ -1,4 +1,4 @@
# Copyright © 2024 Apple Inc.
# Copyright © 2024-2025 Apple Inc.
import math
from dataclasses import dataclass
@@ -123,17 +123,16 @@ class MambaBlock(nn.Module):
self.intermediate_size, self.hidden_size, bias=args.use_bias
)
def ssm_step(self, x, state=None):
A = -mx.exp(self.A_log)
def ssm_step(self, x, A, state=None):
D = self.D
deltaBC = self.x_proj(x)
delta, B, C = mx.split(
deltaBC,
indices_or_sections=[
self.time_step_rank,
self.time_step_rank + self.ssm_state_size,
],
axis=-1,
delta, B, C = map(
self.mixer_norm if self.use_bcdt_rms else lambda x: x,
mx.split(
deltaBC,
[self.time_step_rank, self.time_step_rank + self.ssm_state_size],
axis=-1,
),
)
if self.use_bcdt_rms:
delta, B, C = map(self.mixer_norm, (delta, B, C))
@@ -145,25 +144,40 @@ class MambaBlock(nn.Module):
y = y + D * x
return y, new_state
def __call__(self, x, cache):
def _process_sequence(self, x, conv_cache, state_cache):
B, T, D = x.shape
if cache is None:
cache = [None, None]
xz = self.in_proj(x)
x, z = xz.split(indices_or_sections=2, axis=-1)
conv_out, new_conv_cache = self.conv1d(x, conv_cache)
x = nn.silu(conv_out)
A = -mx.exp(self.A_log)
outputs = []
current_state = state_cache
y = []
for t in range(T):
xt = x[:, t, :]
xz = self.in_proj(xt)
x_t, z_t = xz.split(indices_or_sections=2, axis=1)
conv_out, cache[0] = self.conv1d(mx.expand_dims(x_t, 1), cache[0])
x_t = conv_out.squeeze(1)
x_t = nn.silu(x_t)
y_t, cache[1] = self.ssm_step(x_t, cache[1])
z_t = nn.silu(z_t)
output_t = y_t * z_t
output_t = self.out_proj(output_t)
outputs.append(output_t)
output = mx.stack(outputs, axis=1)
y_t, current_state = self.ssm_step(x[:, t], A, current_state)
y.append(y_t)
y = mx.stack(y, axis=1)
z = self.out_proj(nn.silu(z) * y)
return z, (new_conv_cache, current_state)
def __call__(self, x, cache):
if cache is None:
conv_cache, state_cache = None, None
else:
conv_cache, state_cache = cache[0], cache[1]
output, (new_conv_cache, new_state_cache) = self._process_sequence(
x, conv_cache, state_cache
)
if isinstance(cache, MambaCache):
cache[0] = new_conv_cache
cache[1] = new_state_cache
return output
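# Toy sketch of the per-step update ssm_step performs, reduced to a single scalar
# state with B = C = 1 (real tensors have shape (batch, intermediate, state);
# the values below are made up):
import mlx.core as mx

A, D = mx.array(-1.0), mx.array(0.5)     # A = -exp(A_log) is now computed once per call
state = mx.array(0.0)
for x_t, delta in [(1.0, 0.1), (0.5, 0.2)]:
    state = mx.exp(delta * A) * state + delta * x_t   # decay old state, mix in the input
    y_t = state + D * x_t                              # readout plus skip connection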

View File

@@ -1,4 +1,4 @@
# Copyright © 2023-2024 Apple Inc.
# Copyright © 2023-2025 Apple Inc.
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union

View File

@@ -1,4 +1,4 @@
mlx>=0.19.2
mlx>=0.22.0
numpy
transformers[sentencepiece]>=4.39.3
protobuf

View File

@@ -147,11 +147,11 @@ def min_p_sampling(
logprobs = logprobs * (1 / temperature)
# Indices sorted in decreasing order
sorted_indices = mx.argsort(-logprobs).squeeze(0)
sorted_logprobs = logprobs[..., sorted_indices]
sorted_indices = mx.argsort(-logprobs, axis=-1)
sorted_logprobs = mx.take_along_axis(logprobs, sorted_indices, axis=-1)
# Top probability
top_logprobs = logprobs[..., sorted_indices[0]]
top_logprobs = sorted_logprobs[:, 0:1]
# Calculate the min_p threshold
scaled_min_p = top_logprobs + math.log(min_p)
@@ -163,9 +163,9 @@ def min_p_sampling(
# Create pool of tokens with probability less than scaled min_p
selected_logprobs = mx.where(tokens_to_remove, -float("inf"), sorted_logprobs)
# Return sampled token
sorted_token = mx.random.categorical(selected_logprobs)
return sorted_indices[sorted_token]
# Return sampled tokens
sorted_tokens = mx.random.categorical(selected_logprobs, axis=-1)[:, None]
return mx.take_along_axis(sorted_indices, sorted_tokens, axis=-1).squeeze(1)
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
@@ -185,7 +185,7 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.arr
# sort probs in ascending order
sorted_indices = mx.argsort(probs, axis=-1)
sorted_probs = probs[..., sorted_indices.squeeze(0)]
sorted_probs = mx.take_along_axis(probs, sorted_indices, axis=-1)
cumulative_probs = mx.cumsum(sorted_probs, axis=-1)
@@ -196,10 +196,8 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.arr
0,
)
sorted_token = mx.random.categorical(mx.log(top_probs))
token = sorted_indices.squeeze(0)[sorted_token]
return token
sorted_tokens = mx.random.categorical(mx.log(top_probs), axis=-1)[:, None]
return mx.take_along_axis(sorted_indices, sorted_tokens, axis=-1).squeeze(1)
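# Sketch of the batched gather-then-sample pattern used above (made-up probabilities):
import mlx.core as mx

logprobs = mx.log(mx.array([[0.9, 0.05, 0.05], [0.1, 0.1, 0.8]]))
order = mx.argsort(-logprobs, axis=-1)                         # per-row descending order
sorted_lp = mx.take_along_axis(logprobs, order, axis=-1)       # gather per row, no squeeze(0)
picks = mx.random.categorical(sorted_lp, axis=-1)[:, None]     # one draw per row
tokens = mx.take_along_axis(order, picks, axis=-1).squeeze(1)  # map back to vocabulary ids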
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)

View File

@@ -114,6 +114,33 @@ def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None):
return prompt.rstrip()
def process_message_content(messages):
"""
Convert message content to a format suitable for `apply_chat_template`.
The function operates on messages in place. It converts the 'content' field
to a string instead of a list of text fragments.
Args:
messages (list): A list of dictionaries, where each dictionary may
have a 'content' key containing a list of dictionaries with 'type' and
'text' keys.
Raises:
ValueError: If the 'content' type is not supported or if 'text' is missing.
"""
for message in messages:
content = message["content"]
if isinstance(content, list):
text_fragments = [
fragment["text"] for fragment in content if fragment["type"] == "text"
]
if len(text_fragments) != len(content):
raise ValueError("Only 'text' content type is supported.")
message["content"] = "".join(text_fragments)
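# Illustration (hypothetical input) of the in-place flattening performed above:
messages = [
    {"role": "user", "content": [{"type": "text", "text": "Hello "},
                                 {"type": "text", "text": "world"}]}
]
process_message_content(messages)
assert messages[0]["content"] == "Hello world"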
@dataclass
class PromptCache:
cache: List[Any] = field(default_factory=list)
@@ -590,14 +617,12 @@ class APIHandler(BaseHTTPRequestHandler):
# Determine response type
self.request_id = f"chatcmpl-{uuid.uuid4()}"
self.object_type = "chat.completion.chunk" if self.stream else "chat.completion"
if (
hasattr(self.tokenizer, "apply_chat_template")
and self.tokenizer.chat_template
):
if self.tokenizer.chat_template:
messages = body["messages"]
process_message_content(messages)
prompt = self.tokenizer.apply_chat_template(
body["messages"],
messages,
body.get("tools", None),
tokenize=True,
add_generation_prompt=True,
)
else:

View File

@@ -266,6 +266,18 @@ class TokenizerWrapper:
else {tokenizer.eos_token_id}
)
def add_eos_token(self, token: str):
token_id = None
try:
token_id = int(token)
except ValueError:
token_id = self._tokenizer.convert_tokens_to_ids(token)
if token_id is None:
raise ValueError(f"'{token}' is not a token for this tokenizer")
self._eos_token_ids.add(token_id)
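# Hypothetical usage: an extra stop token can be given as a numeric id or a token string.
tokenizer.add_eos_token("128001")        # parses as an int, used directly as an id
tokenizer.add_eos_token("<|im_end|>")    # otherwise resolved via convert_tokens_to_ids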
def __getattr__(self, attr):
if attr == "detokenizer":
return self._detokenizer

View File

@@ -1,6 +1,6 @@
import json
from pathlib import Path
from typing import Dict, List
from typing import Dict, List, Optional
from transformers import PreTrainedTokenizer
@@ -10,41 +10,47 @@ class Dataset:
Light-weight wrapper to hold a dataset.
"""
def __init__(self, data: List[Dict[str, str]], text_key: str = "text"):
self._text_key = text_key
self._data = data
def __init__(
self,
data: List[Dict[str, str]],
tokenizer: PreTrainedTokenizer,
text_key: str = "text",
):
self._data = [tokenizer.encode(d[text_key]) for d in data]
for d in self._data:
if d[-1] != tokenizer.eos_token_id:
d.append(tokenizer.eos_token_id)
def __getitem__(self, idx: int):
return self._data[idx][self._text_key]
return self._data[idx]
def __len__(self):
if self._data is None:
return 0
return len(self._data)
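# Sketch (made-up record): items are now pre-tokenized, with EOS appended, at load time.
ds = Dataset([{"text": "hello world"}], tokenizer)
assert ds[0][-1] == tokenizer.eos_token_id   # __getitem__ returns token ids, not text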
class ChatDataset(Dataset):
class ChatDataset:
"""
A dataset for chat data in the format of {"messages": [...]}
https://platform.openai.com/docs/guides/fine-tuning/example-format
"""
def __init__(self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer):
super().__init__(data)
self._tokenizer = tokenizer
self._data = [
tokenizer.apply_chat_template(
d["messages"],
tools=d.get("tools", None),
)
for d in data
]
def __getitem__(self, idx: int):
messages = self._data[idx]["messages"]
text = self._tokenizer.apply_chat_template(
messages,
tools=self._data[idx].get("tools", None),
tokenize=False,
add_generation_prompt=True,
)
return text
return self._data[idx]
def __len__(self):
return len(self._data)
class CompletionsDataset(Dataset):
class CompletionsDataset:
"""
A dataset for prompt-completion data in the format of {"prompt": ..., "completion": ...}
or using user-provided keys for prompt and completion values
@@ -55,36 +61,41 @@ class CompletionsDataset(Dataset):
self,
data: List[Dict[str, str]],
tokenizer: PreTrainedTokenizer,
prompt_key: str = "prompt",
completion_key: str = "completion",
prompt_key: str,
completion_key: str,
):
super().__init__(data)
self._tokenizer = tokenizer
self._prompt_key = prompt_key
self._completion_key = completion_key
self._data = [
tokenizer.apply_chat_template(
[
{"role": "user", "content": d[prompt_key]},
{"role": "assistant", "content": d[completion_key]},
],
)
for d in data
]
def __getitem__(self, idx: int):
data = self._data[idx]
text = self._tokenizer.apply_chat_template(
[
{"role": "user", "content": data[self._prompt_key]},
{"role": "assistant", "content": data[self._completion_key]},
],
tokenize=False,
add_generation_prompt=True,
)
return text
return self._data[idx]
def __len__(self):
return len(self._data)
def create_dataset(data, tokenizer: PreTrainedTokenizer = None):
def create_dataset(
data,
tokenizer: PreTrainedTokenizer,
prompt_feature: Optional[str] = None,
completion_feature: Optional[str] = None,
):
prompt_feature = prompt_feature or "prompt"
completion_feature = completion_feature or "completion"
sample = data[0]
if "messages" in sample:
return ChatDataset(data, tokenizer)
elif "prompt" in sample and "completion" in sample:
return CompletionsDataset(data, tokenizer)
elif prompt_feature in sample and completion_feature in sample:
return CompletionsDataset(data, tokenizer, prompt_feature, completion_feature)
elif "text" in sample:
return Dataset(data)
return Dataset(data, tokenizer)
else:
raise ValueError(
"Unsupported data format, check the supported formats here:\n"
@@ -92,20 +103,30 @@ def create_dataset(data, tokenizer: PreTrainedTokenizer = None):
)
def load_local_dataset(data_path: Path, tokenizer: PreTrainedTokenizer):
def load_local_dataset(
data_path: Path,
tokenizer: PreTrainedTokenizer,
prompt_feature: Optional[str] = None,
completion_feature: Optional[str] = None,
):
def load_subset(path):
if not path.exists():
return []
with open(path, "r") as fid:
data = [json.loads(l) for l in fid]
return create_dataset(data, tokenizer)
return create_dataset(data, tokenizer, prompt_feature, completion_feature)
names = ("train", "valid", "test")
train, valid, test = [load_subset(data_path / f"{n}.jsonl") for n in names]
return train, valid, test
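# Sketch (made-up keys) of the new custom prompt/completion feature path:
data = [{"question": "What is 2 + 2?", "answer": "4"}]
ds = create_dataset(data, tokenizer, prompt_feature="question", completion_feature="answer")
print(len(ds), ds[0][:8])   # a CompletionsDataset with the chat template already applied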
def load_hf_dataset(data_id: str, tokenizer: PreTrainedTokenizer):
def load_hf_dataset(
data_id: str,
tokenizer: PreTrainedTokenizer,
prompt_feature: Optional[str] = None,
completion_feature: Optional[str] = None,
):
from datasets import exceptions, load_dataset
try:
@@ -114,7 +135,13 @@ def load_hf_dataset(data_id: str, tokenizer: PreTrainedTokenizer):
names = ("train", "valid", "test")
train, valid, test = [
create_dataset(dataset[n], tokenizer) if n in dataset.keys() else []
(
create_dataset(
dataset[n], tokenizer, prompt_feature, completion_feature
)
if n in dataset.keys()
else []
)
for n in names
]
@@ -143,7 +170,7 @@ def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
if prompt_feature and completion_feature:
return CompletionsDataset(ds, tokenizer, prompt_feature, completion_feature)
elif text_feature:
return Dataset(train_ds, text_key=text_feature)
return Dataset(ds, tokenizer, text_key=text_feature)
else:
raise ValueError(
"Specify either a prompt and completion feature or a text "
@@ -166,15 +193,22 @@ def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
def load_dataset(args, tokenizer: PreTrainedTokenizer):
if getattr(args, "hf_dataset", None) is not None:
if getattr(args, "hf_dataset", False):
train, valid, test = load_custom_hf_dataset(args, tokenizer)
else:
data_path = Path(args.data)
prompt_feature = getattr(args, "prompt_feature", None)
completion_feature = getattr(args, "completion_feature", None)
if data_path.exists():
train, valid, test = load_local_dataset(data_path, tokenizer)
train, valid, test = load_local_dataset(
data_path, tokenizer, prompt_feature, completion_feature
)
else:
print(f"Loading Hugging Face dataset {args.data}.")
train, valid, test = load_hf_dataset(args.data, tokenizer)
train, valid, test = load_hf_dataset(
args.data, tokenizer, prompt_feature, completion_feature
)
if args.train and len(train) == 0:
raise ValueError(

View File

@@ -100,14 +100,8 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
while True:
indices = np.random.permutation(len(batch_idx))
for i in indices:
# Encode batch
batch = [tokenizer.encode(dataset[j]) for j in batch_idx[i]]
for b in batch:
if b[-1] != tokenizer.eos_token_id:
b.append(tokenizer.eos_token_id)
batch = [dataset[j] for j in batch_idx[i]]
lengths = [len(x) for x in batch]
if max(lengths) > max_seq_length:
print(
f"[WARNING] Some sequences are longer than {max_seq_length} tokens. "
@@ -146,8 +140,8 @@ def evaluate(
loss: callable = default_loss,
iterate_batches: callable = iterate_batches,
):
all_losses = 0
ntokens = 0
all_losses = mx.array(0.0)
ntokens = mx.array(0)
index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
@@ -165,8 +159,8 @@ def evaluate(
ntokens += toks
mx.eval(all_losses, ntokens)
all_losses = mx.distributed.all_sum(all_losses)
ntokens = mx.distributed.all_sum(ntokens)
all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu)
ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu)
return (all_losses / ntokens).item()
@@ -278,9 +272,9 @@ def train(
if it % args.steps_per_report == 0 or it == args.iters:
stop = time.perf_counter()
train_loss = mx.distributed.all_sum(losses).item()
train_loss = mx.distributed.all_sum(losses, stream=mx.cpu).item()
train_loss /= steps * mx.distributed.init().size()
n_tokens = mx.distributed.all_sum(n_tokens).item()
n_tokens = mx.distributed.all_sum(n_tokens, stream=mx.cpu).item()
learning_rate = optimizer.learning_rate.item()
it_sec = args.steps_per_report / (stop - start)
tokens_sec = float(n_tokens) / (stop - start)
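# Sketch of the pattern above: the distributed reductions are explicitly placed on the
# CPU stream (made-up value; the result equals the group size when every rank sends 1.0).
import mlx.core as mx

total = mx.distributed.all_sum(mx.array(1.0), stream=mx.cpu)
print(total.item())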

View File

@@ -94,12 +94,14 @@ def linear_to_lora_layers(
"phimoe",
"gemma",
"gemma2",
"helium",
"starcoder2",
"cohere",
"cohere2",
"minicpm",
"deepseek",
"olmo2",
"internlm3",
]:
keys = set(["self_attn.q_proj", "self_attn.v_proj"])
if model.model_type in ["mixtral", "phimoe"]:

View File

@@ -2,10 +2,12 @@
import contextlib
import copy
import functools
import glob
import importlib
import json
import logging
import os
import shutil
import time
from dataclasses import dataclass
@@ -15,7 +17,17 @@ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type,
import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download
if os.getenv("MLXLM_USE_MODELSCOPE", "False").lower() == "true":
try:
from modelscope import snapshot_download
except ImportError:
raise ImportError(
"Please run `pip install modelscope` to activate the ModelScope."
)
else:
from huggingface_hub import snapshot_download
from mlx.utils import tree_flatten, tree_reduce
from transformers import PreTrainedTokenizer
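# Hypothetical usage: opt into ModelScope downloads before mlx_lm.utils is imported.
# Requires `pip install modelscope`; otherwise the default huggingface_hub path is used.
import os
os.environ["MLXLM_USE_MODELSCOPE"] = "true"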
@@ -153,11 +165,12 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path
Path: The path to the model.
"""
model_path = Path(path_or_hf_repo)
if not model_path.exists():
try:
model_path = Path(
snapshot_download(
repo_id=path_or_hf_repo,
path_or_hf_repo,
revision=revision,
allow_patterns=[
"*.json",
@@ -207,12 +220,6 @@ def generate_step(
kv_group_size: int = 64,
quantized_kv_start: int = 0,
prompt_progress_callback: Optional[Callable[int, int]] = None,
temp: Optional[float] = None,
repetition_penalty: Optional[float] = None,
repetition_context_size: Optional[int] = None,
top_p: Optional[float] = None,
min_p: Optional[float] = None,
min_tokens_to_keep: Optional[int] = None,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
"""
A generator producing token ids based on the given prompt from the model.
@@ -256,25 +263,17 @@ def generate_step(
elif len(prompt_cache) != len(model.layers):
raise ValueError("Wrong number of layers in the prompt cache.")
if temp is not None or top_p is not None or min_tokens_to_keep is not None:
print(
"[Warning] Specifying sampling arguments to ``generate_step`` is "
"deprecated. Pass in a ``sampler`` instead."
)
if repetition_penalty is not None:
print(
"[Warning] Specifying ``repetition_penalty`` is deprecated. "
"Pass in ``logits_processors`` instead."
)
sampler = sampler or make_sampler(
temp or 0.0, top_p or 0.0, min_p or 0.0, min_tokens_to_keep or 1
)
logits_processors = logits_processors or make_logits_processors(
None, repetition_penalty, repetition_context_size or 20
)
prompt_progress_callback = prompt_progress_callback or (lambda *_: None)
quantize_cache_fn = functools.partial(
maybe_quantize_kv_cache,
quantized_kv_start=quantized_kv_start,
kv_group_size=kv_group_size,
kv_bits=kv_bits,
)
sampler = sampler or (lambda x: mx.argmax(x, axis=-1))
def _step(y):
with mx.stream(generation_stream):
logits = model(y[None], cache=prompt_cache)
@@ -287,9 +286,7 @@ def generate_step(
for processor in logits_processors:
logits = processor(tokens, logits)
maybe_quantize_kv_cache(
prompt_cache, quantized_kv_start, kv_group_size, kv_bits
)
quantize_cache_fn(prompt_cache)
logprobs = logits - mx.logsumexp(logits, keepdims=True)
y = sampler(logprobs)
@@ -300,9 +297,7 @@ def generate_step(
prompt_processed_tokens = 0
while y.size > prefill_step_size:
model(y[:prefill_step_size][None], cache=prompt_cache)
maybe_quantize_kv_cache(
prompt_cache, quantized_kv_start, kv_group_size, kv_bits
)
quantize_cache_fn(prompt_cache)
mx.eval([c.state for c in prompt_cache])
prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens)
prompt_processed_tokens += prefill_step_size
@@ -311,12 +306,12 @@ def generate_step(
y, logprobs = _step(y)
mx.async_eval(y, logprobs)
mx.eval(y, logprobs)
n = 0
while True:
if n != max_tokens:
next_y, next_logprobs = _step(y)
mx.async_eval(next_y, next_logprobs)
mx.eval(next_y, next_logprobs)
if n == 0:
mx.eval(y)
prompt_progress_callback(total_prompt_tokens, total_prompt_tokens)
@@ -329,10 +324,163 @@ def generate_step(
n += 1
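# Sketch: with the deprecated keyword arguments removed, sampling options travel through
# an explicit sampler (prompt and model are assumed to be prepared elsewhere).
from mlx_lm.sample_utils import make_sampler

sampler = make_sampler(temp=0.7, top_p=0.9)
for token, logprobs in generate_step(prompt, model, sampler=sampler, max_tokens=32):
    print(token)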
def speculative_generate_step(
prompt: mx.array,
model: nn.Module,
draft_model: nn.Module,
*,
num_draft_tokens=2,
max_tokens: int = 256,
sampler: Optional[Callable[mx.array, mx.array]] = None,
logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
prompt_cache: Optional[Any] = None,
prefill_step_size: int = 512,
kv_bits: Optional[int] = None,
kv_group_size: int = 64,
quantized_kv_start: int = 0,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
"""
A generator producing token ids based on the given prompt from the model.
Args:
prompt (mx.array): The input prompt.
model (nn.Module): The model to use for generation.
draft_model (nn.Module): The draft model for speculative decoding.
num_draft_tokens (int, optional): The number of draft tokens for
speculative decoding. Default: ``2``.
max_tokens (int): The maximum number of tokens. Use ``-1`` for an infinite
generator. Default: ``256``.
sampler (Callable[mx.array, mx.array], optional): A sampler for sampling a
token from a vector of log probabilities. Default: ``None``.
logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
A list of functions that take tokens and logits and return the processed
logits. Default: ``None``.
prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
provided, the cache will be updated in place. The cache must be trimmable.
prefill_step_size (int): Step size for processing the prompt.
kv_bits (int, optional): Number of bits to use for KV cache quantization.
None implies no cache quantization. Default: ``None``.
kv_group_size (int): Group size for KV cache quantization. Default: ``64``.
quantized_kv_start (int): Step to begin using a quantized KV cache
when ``kv_bits`` is non-None. Default: ``0``.
Yields:
Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
"""
y = prompt
tokens = None
# Create the KV cache for generation
if prompt_cache is None:
model_cache = cache.make_prompt_cache(model)
draft_cache = cache.make_prompt_cache(draft_model)
elif len(prompt_cache) != (len(model.layers) + len(draft_model.layers)):
raise ValueError("Wrong number of layers in the prompt cache.")
else:
model_cache = prompt_cache[: len(model.layers)]
draft_cache = prompt_cache[len(model.layers) :]
sampler = sampler or (lambda x: mx.argmax(x, axis=-1))
quantize_cache_fn = functools.partial(
maybe_quantize_kv_cache,
quantized_kv_start=quantized_kv_start,
kv_group_size=kv_group_size,
kv_bits=kv_bits,
)
def _step(model, cache, y, n_predict=1):
with mx.stream(generation_stream):
logits = model(y[None], cache=cache)
logits = logits[:, -n_predict:, :]
quantize_cache_fn(cache)
logprobs = logits - mx.logsumexp(logits, keepdims=True)
logprobs = logprobs.squeeze(0)
y = sampler(logprobs)
return y, logprobs
def _prefill(model, cache, y):
while y.size > prefill_step_size:
model(y[:prefill_step_size][None], cache=cache)
quantize_cache_fn(cache)
mx.eval([c.state for c in cache])
y = y[prefill_step_size:]
mx.metal.clear_cache()
return y
def _rewind_cache(num_draft, num_accept):
cache.trim_prompt_cache(model_cache, num_draft - num_accept)
cache.trim_prompt_cache(draft_cache, max(num_draft - num_accept - 1, 0))
def _draft_generate(y, num_draft):
if num_draft == 0:
return mx.array([], mx.uint32)
ys = []
for _ in range(num_draft):
y, _ = _step(draft_model, draft_cache, y)
mx.async_eval(y)
ys.append(y)
return mx.concatenate(ys)
with mx.stream(generation_stream):
draft_y = _prefill(draft_model, draft_cache, y)
y = _prefill(model, model_cache, y)
ntoks = 0
# Set these so the finally block doesn't raise
num_draft = 0
n = 0
try:
while True:
num_draft = min(max_tokens - ntoks, num_draft_tokens)
draft_tokens = _draft_generate(draft_y, num_draft)
y = mx.concatenate([y, draft_tokens])
tokens, logprobs = _step(model, model_cache, y, num_draft + 1)
mx.eval(tokens, draft_tokens)
draft_tokens = draft_tokens.tolist()
tokens = tokens.tolist()
n = 0
while n < num_draft:
tn, dtn, lpn = tokens[n], draft_tokens[n], logprobs[n]
if tn != dtn:
break
n += 1
ntoks += 1
yield tn, lpn
if ntoks == max_tokens:
break
if ntoks < max_tokens:
ntoks += 1
yield tokens[n], logprobs[n]
if ntoks == max_tokens:
break
y = mx.array([tokens[n]], mx.uint32)
draft_y = y
# If we accepted all the draft tokens, include the last
# draft token in the next draft step since it hasn't been
# processed yet by the draft model
if n == num_draft:
draft_y = mx.concatenate(
[mx.array(draft_tokens[-1:], mx.uint32), draft_y]
)
_rewind_cache(num_draft, n)
finally:
_rewind_cache(num_draft, n)
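# Hypothetical end-to-end sketch (placeholder repo names): passing a draft_model to
# stream_generate, updated below, routes generation through speculative_generate_step.
from mlx_lm import load, stream_generate

model, tokenizer = load("org/big-model")
draft_model, _ = load("org/small-draft-model")
for response in stream_generate(model, tokenizer, "Hello", draft_model=draft_model,
                                num_draft_tokens=2, max_tokens=64):
    print(response.text, end="", flush=True)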
def stream_generate(
model: nn.Module,
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
prompt: Union[str, mx.array, List[int]],
draft_model: Optional[nn.Module] = None,
**kwargs,
) -> Generator[GenerationResponse, None, None]:
"""
@@ -341,7 +489,11 @@ def stream_generate(
Args:
model (nn.Module): The model to use for generation.
tokenizer (PreTrainedTokenizer): The tokenizer.
prompt (Union[str, mx.array, List[int]]): The input prompt string or integer tokens.
prompt (Union[str, mx.array, List[int]]): The input prompt string or
integer tokens.
draft_model (Optional[nn.Module]): An optional draft model. If provided
then speculative decoding is used. The draft model must use the same
tokenizer as the main model. Default: ``None``.
kwargs: The remaining options get passed to :func:`generate_step`.
See :func:`generate_step` for more details.
@@ -353,16 +505,28 @@ def stream_generate(
tokenizer = TokenizerWrapper(tokenizer)
if not isinstance(prompt, mx.array):
prompt = mx.array(
prompt if isinstance(prompt, list) else tokenizer.encode(prompt)
)
if isinstance(prompt, str):
# Try to infer if special tokens are needed
add_special_tokens = tokenizer.bos_token is None or not prompt.startswith(
tokenizer.bos_token
)
prompt = tokenizer.encode(prompt, add_special_tokens=add_special_tokens)
prompt = mx.array(prompt)
detokenizer = tokenizer.detokenizer
if draft_model is None:
kwargs.pop("num_draft_tokens", None)
token_generator = generate_step(prompt, model, **kwargs)
else:
kwargs.pop("max_kv_size", None)
token_generator = speculative_generate_step(
prompt, model, draft_model, **kwargs
)
with wired_limit(model, [generation_stream]):
detokenizer.reset()
tic = time.perf_counter()
for n, (token, logprobs) in enumerate(generate_step(prompt, model, **kwargs)):
for n, (token, logprobs) in enumerate(token_generator):
if n == 0:
prompt_time = time.perf_counter() - tic
prompt_tps = prompt.size / prompt_time
@@ -401,7 +565,7 @@ def stream_generate(
def generate(
model: nn.Module,
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
prompt: str,
prompt: Union[str, List[int]],
verbose: bool = False,
formatter: Optional[Callable] = None,
**kwargs,
@@ -412,7 +576,7 @@ def generate(
Args:
model (nn.Module): The language model.
tokenizer (PreTrainedTokenizer): The tokenizer.
prompt (str): The string prompt.
prompt (Union[str, List[int]]): The input prompt string or integer tokens.
verbose (bool): If ``True``, print tokens and timing information.
Default: ``False``.
kwargs: The remaining options get passed to :func:`stream_generate`.
@@ -425,7 +589,6 @@ def generate(
)
if verbose:
print("=" * 10)
print("Prompt:", prompt)
text = ""
for response in stream_generate(model, tokenizer, prompt, **kwargs):
@@ -464,6 +627,8 @@ def load_config(model_path: Path) -> dict:
def load_model(
model_path: Path,
lazy: bool = False,
strict: bool = True,
sequential_load: bool = False,
model_config: dict = {},
get_model_classes: Callable[[dict], Tuple[Type[nn.Module], Type]] = _get_classes,
) -> nn.Module:
@@ -475,6 +640,8 @@ def load_model(
lazy (bool): If False eval the model parameters to make sure they are
loaded in memory before returning, otherwise they will be loaded
when needed. Default: ``False``
strict (bool): Whether or not to raise an exception if weights don't
match. Default: ``True``
model_config (dict, optional): Optional configuration parameters for the
model. Defaults to an empty dictionary.
get_model_classes (Callable[[dict], Tuple[Type[nn.Module], Type]], optional):
@@ -497,7 +664,7 @@ def load_model(
# Try weight for back-compat
weight_files = glob.glob(str(model_path / "weight*.safetensors"))
if not weight_files:
if not weight_files and strict:
logging.error(f"No safetensors found in {model_path}")
raise FileNotFoundError(f"No safetensors found in {model_path}")
@@ -531,9 +698,18 @@ def load_model(
class_predicate=class_predicate,
)
model.load_weights(list(weights.items()))
model.load_weights(list(weights.items()), strict=strict)
if mx.distributed.init().size() > 1:
if not hasattr(model, "shard"):
raise RuntimeError("Model doesn't support distributed inference.")
model.shard()
if not lazy:
weights.clear()
if sequential_load:
for layer in model.layers:
mx.eval(layer.parameters())
mx.eval(model.parameters())
model.eval()
@@ -546,6 +722,7 @@ def load(
model_config={},
adapter_path: Optional[str] = None,
lazy: bool = False,
sequential_load: bool = False,
) -> Tuple[nn.Module, TokenizerWrapper]:
"""
Load the model and tokenizer from a given path or a huggingface repository.
@@ -558,9 +735,11 @@ def load(
Defaults to an empty dictionary.
adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers
to the model. Default: ``None``.
lazy (bool): If False eval the model parameters to make sure they are
lazy (bool): If ``False`` eval the model parameters to make sure they are
loaded in memory before returning, otherwise they will be loaded
when needed. Default: ``False``
sequential_load (bool): If ``True``, load each layer sequentially to
avoid keeping all of the weights in memory at once. Default: ``False``.
Returns:
Tuple[nn.Module, TokenizerWrapper]: A tuple containing the loaded model and tokenizer.
@@ -570,7 +749,7 @@ def load(
"""
model_path = get_model_path(path_or_hf_repo)
model, config = load_model(model_path, lazy)
model, config = load_model(model_path, lazy=lazy, sequential_load=sequential_load)
if adapter_path is not None:
model = load_adapters(model, adapter_path)
model.eval()
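# Sketch (placeholder repo name): ask load() to evaluate weights one layer at a time.
model, tokenizer = load("org/model-repo", sequential_load=True)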
@@ -584,7 +763,7 @@ def load(
def fetch_from_hub(
model_path: Path, lazy: bool = False
) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:
model, config = load_model(model_path, lazy)
model, config = load_model(model_path, lazy=lazy)
tokenizer = load_tokenizer(
model_path, eos_token_ids=config.get("eos_token_id", None)
)
@@ -652,12 +831,12 @@ def upload_to_hub(path: str, upload_repo: str, hf_path: str):
model, tokenizer = load("{upload_repo}")
prompt="hello"
prompt = "hello"
if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template is not None:
if tokenizer.chat_template is not None:
messages = [{{"role": "user", "content": prompt}}]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
messages, add_generation_prompt=True
)
response = generate(model, tokenizer, prompt=prompt, verbose=True)
@@ -670,12 +849,10 @@ def upload_to_hub(path: str, upload_repo: str, hf_path: str):
api = HfApi()
api.create_repo(repo_id=upload_repo, exist_ok=True)
api.upload_folder(
api.upload_large_folder(
folder_path=path,
repo_id=upload_repo,
repo_type="model",
multi_commits=True,
multi_commits_verbose=True,
)
print(f"Upload successful, go to https://huggingface.co/{upload_repo} for details.")

View File

@@ -27,8 +27,8 @@ setup(
packages=["mlx_lm", "mlx_lm.models", "mlx_lm.tuner"],
python_requires=">=3.8",
extras_require={
"testing": ["datasets"],
"evaluation": ["lm-eval"],
"test": ["datasets"],
"evaluate": ["lm-eval", "tqdm"],
},
entry_points={
"console_scripts": [

View File

@@ -36,7 +36,8 @@ class TestDatasets(unittest.TestCase):
data = {"text": "This is an example for the model."}
self.save_data(4 * [data])
args = types.SimpleNamespace(train=True, test=False, data=self.test_dir)
train, valid, test = datasets.load_dataset(args, None)
tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_PATH)
train, valid, test = datasets.load_dataset(args, tokenizer)
self.assertEqual(len(train), 4)
self.assertEqual(len(valid), 4)
self.assertEqual(len(test), 0)
@@ -82,6 +83,8 @@ class TestDatasets(unittest.TestCase):
"name": "billsum",
"prompt_feature": "text",
"completion_feature": "summary",
"train_split": "train[:2%]",
"valid_split": "train[-2%:]",
},
test=False,
train=True,

View File

@@ -21,7 +21,7 @@ from mlx_lm.tuner.utils import build_schedule
@contextmanager
def swapped_with_identity(obj, func):
old_func = getattr(obj, func)
setattr(obj, func, lambda x: x)
setattr(obj, func, lambda x, **kwargs: x)
yield
setattr(obj, func, old_func)

View File

@@ -682,6 +682,43 @@ class TestModels(unittest.TestCase):
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_deepseek_v3(self):
from mlx_lm.models import deepseek_v3
args = deepseek_v3.ModelArgs(
model_type="deepseek_v3",
vocab_size=1024,
hidden_size=128,
intermediate_size=256,
moe_intermediate_size=256,
num_hidden_layers=4,
num_attention_heads=4,
num_key_value_heads=2,
n_routed_experts=4,
n_group=2,
topk_group=1,
num_experts_per_tok=2,
n_shared_experts=1,
kv_lora_rank=4,
q_lora_rank=4,
qk_rope_head_dim=32,
v_head_dim=16,
qk_nope_head_dim=32,
rope_scaling={
"beta_fast": 32,
"beta_slow": 1,
"factor": 40,
"mscale": 1.0,
"mscale_all_dim": 1.0,
"original_max_position_embeddings": 4096,
"type": "yarn",
},
)
model = deepseek_v3.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_gemma2(self):
from mlx_lm.models import gemma2
@@ -890,6 +927,23 @@ class TestModels(unittest.TestCase):
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_internlm3(self):
from mlx_lm.models import internlm3
args = internlm3.ModelArgs(
model_type="internlm3",
hidden_size=1024,
num_hidden_layers=4,
intermediate_size=2048,
num_attention_heads=4,
rms_norm_eps=1e-5,
vocab_size=10_000,
)
model = internlm3.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
if __name__ == "__main__":
unittest.main()

View File

@@ -28,6 +28,12 @@ class TestSampleUtils(unittest.TestCase):
token = top_p_sampling(logits, 0.95, temperature).item()
self.assertTrue(token in (1, 2, 3))
# Batch mode works
probs = mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.8, 0.0, 0.1]])
logits = mx.log(probs)
tokens = top_p_sampling(logits, 0.5, temperature)
self.assertEqual(tokens.tolist(), [0, 1])
def test_min_p_sampling(self):
probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
logits = mx.log(probs)
@@ -42,6 +48,12 @@ class TestSampleUtils(unittest.TestCase):
token = min_p_sampling(logits, 0.05)
self.assertTrue(token in (0, 3))
# Batch mode works
probs = mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.8, 0.0, 0.1]])
logits = mx.log(probs)
tokens = min_p_sampling(logits, 0.7)
self.assertEqual(tokens.tolist(), [0, 1])
def test_top_k_sampling(self):
probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
logits = mx.log(probs)

View File

@@ -80,6 +80,29 @@ class TestServer(unittest.TestCase):
self.assertIn("id", response_body)
self.assertIn("choices", response_body)
def test_handle_chat_completions_with_content_fragments(self):
url = f"http://localhost:{self.port}/v1/chat/completions"
chat_post_data = {
"model": "chat_model",
"max_tokens": 10,
"temperature": 0.7,
"top_p": 0.85,
"repetition_penalty": 1.2,
"messages": [
{
"role": "system",
"content": [
{"type": "text", "text": "You are a helpful assistant."}
],
},
{"role": "user", "content": [{"type": "text", "text": "Hello!"}]},
],
}
response = requests.post(url, json=chat_post_data)
response_body = response.text
self.assertIn("id", response_body)
self.assertIn("choices", response_body)
def test_handle_models(self):
url = f"http://localhost:{self.port}/v1/models"
response = requests.get(url)

View File

@@ -17,7 +17,7 @@ class TestLoadModelCustomGetClasses(unittest.TestCase):
self.config = args
self.custom_attribute = "This is a custom model"
def load_weights(self, weights):
def load_weights(self, weights, **kwargs):
self.qwenWeights = weights
class CustomQwenConfig: