mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-12-16 02:08:55 +08:00
Compare commits
32 Commits
awq
...
distribute
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
65b792d7c0 | ||
|
|
617f9289b9 | ||
|
|
026362e0f8 | ||
|
|
a0ce0594f6 | ||
|
|
d77840207c | ||
|
|
e2e5478da5 | ||
|
|
21d0ab6e8a | ||
|
|
0989c073b0 | ||
|
|
d9924d08d1 | ||
|
|
9c2ef38d4d | ||
|
|
e8afb59de4 | ||
|
|
7a83077cd7 | ||
|
|
f44a52e2dc | ||
|
|
77faa14ba4 | ||
|
|
9a3ddc3e65 | ||
|
|
df1406735b | ||
|
|
07f88f8057 | ||
|
|
50f0a7f6d9 | ||
|
|
6ae6c72c2e | ||
|
|
c117af83b8 | ||
|
|
0228c46434 | ||
|
|
bf2da36fc6 | ||
|
|
514502da22 | ||
|
|
93c5cfd781 | ||
|
|
5cae0a60e6 | ||
|
|
40b88eff48 | ||
|
|
b8f0cacfa8 | ||
|
|
9183fe8b6d | ||
|
|
f2619f507c | ||
|
|
25ec2d8c44 | ||
|
|
c4833a2f55 | ||
|
|
3a58c36109 |
@@ -32,7 +32,7 @@ jobs:
|
||||
pip install --upgrade pip
|
||||
pip install unittest-xml-reporting
|
||||
cd llms/
|
||||
pip install -e ".[testing]"
|
||||
pip install -e ".[test]"
|
||||
- run:
|
||||
name: Run Python tests
|
||||
command: |
|
||||
|
||||
@@ -14,4 +14,4 @@ MLX Examples was developed with contributions from the following individuals:
|
||||
- Markus Enzweiler: Added the `cvae` examples.
|
||||
- Prince Canuma: Helped add support for `Starcoder2` models.
|
||||
- Shiyu Li: Added the `Segment Anything Model`.
|
||||
- Gökdeniz Gülmez: Added support for `MiniCPM`, `Mamba` and support for `full-fine-tuning`.
|
||||
- Gökdeniz Gülmez: Added support for `MiniCPM`, `Helium`, `Mamba version 1` and support for `full-fine-tuning`.
|
||||
@@ -45,7 +45,7 @@ Some more useful examples are listed below.
|
||||
|
||||
### Hugging Face
|
||||
|
||||
Note: You can now directly download a few converted checkpoints from the [MLX
|
||||
You can directly use or download converted checkpoints from the [MLX
|
||||
Community](https://huggingface.co/mlx-community) organization on Hugging Face.
|
||||
We encourage you to join the community and [contribute new
|
||||
models](https://github.com/ml-explore/mlx-examples/issues/155).
|
||||
|
||||
@@ -58,7 +58,7 @@ prompt = "Write a story about Einstein"
|
||||
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
messages, add_generation_prompt=True
|
||||
)
|
||||
|
||||
text = generate(model, tokenizer, prompt=prompt, verbose=True)
|
||||
@@ -115,7 +115,7 @@ prompt = "Write a story about Einstein"
|
||||
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
messages, add_generation_prompt=True
|
||||
)
|
||||
|
||||
for response in stream_generate(model, tokenizer, prompt, max_tokens=512):
|
||||
@@ -164,7 +164,7 @@ mlx_lm.convert \
|
||||
```
|
||||
|
||||
Models can also be converted and quantized directly in the
|
||||
[mlx-my-repo]https://huggingface.co/spaces/mlx-community/mlx-my-repo) Hugging
|
||||
[mlx-my-repo](https://huggingface.co/spaces/mlx-community/mlx-my-repo) Hugging
|
||||
Face Space.
|
||||
|
||||
### Long Prompts and Generations
|
||||
|
||||
@@ -241,14 +241,25 @@ Refer to the documentation for the model you are fine-tuning for more details.
|
||||
{"prompt": "What is the capital of France?", "completion": "Paris."}
|
||||
```
|
||||
|
||||
For the `completions` data format, a different key can be used for the prompt
|
||||
and completion by specifying the following in the YAML config:
|
||||
|
||||
```yaml
|
||||
prompt_feature: "input"
|
||||
completion_feature: "output"
|
||||
```
|
||||
|
||||
Here, `"input"` is the expected key instead of the default `"prompt"`, and
|
||||
`"output"` is the expected key instead of `"completion"`.
|
||||
|
||||
`text`:
|
||||
|
||||
```jsonl
|
||||
{"text": "This is an example for the model."}
|
||||
```
|
||||
|
||||
Note, the format is automatically determined by the dataset. Note also, keys in
|
||||
each line not expected by the loader will be ignored.
|
||||
Note, the format is automatically determined by the dataset. Note also, keys
|
||||
in each line not expected by the loader will be ignored.
|
||||
|
||||
> [!NOTE]
|
||||
> Each example in the datasets must be on a single line. Do not put more than
|
||||
@@ -270,7 +281,7 @@ Otherwise, provide a mapping of keys in the dataset to the features MLX LM
|
||||
expects. Use a YAML config to specify the Hugging Face dataset arguments. For
|
||||
example:
|
||||
|
||||
```
|
||||
```yaml
|
||||
hf_dataset:
|
||||
name: "billsum"
|
||||
prompt_feature: "text"
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
__version__ = "0.20.4"
|
||||
__version__ = "0.21.0"
|
||||
|
||||
@@ -110,29 +110,17 @@ def main():
|
||||
if tokenizer.chat_template is None:
|
||||
tokenizer.chat_template = tokenizer.default_chat_template
|
||||
|
||||
if not args.ignore_chat_template and (
|
||||
hasattr(tokenizer, "apply_chat_template")
|
||||
and tokenizer.chat_template is not None
|
||||
):
|
||||
if not args.ignore_chat_template and tokenizer.chat_template is not None:
|
||||
messages = [{"role": "user", "content": args.prompt}]
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
messages, add_generation_prompt=False, continue_final_message=True
|
||||
)
|
||||
|
||||
# Treat the prompt as a prefix assuming that the suffix will be
|
||||
# provided at generation time.
|
||||
test_prompt = tokenizer.apply_chat_template(
|
||||
[{"role": "user", "content": "<query>"}],
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
n = len(test_prompt) - test_prompt.index("<query>") - len("<query>")
|
||||
prompt = prompt[:-n]
|
||||
else:
|
||||
prompt = args.prompt
|
||||
prompt = tokenizer.encode(args.prompt)
|
||||
|
||||
cache = make_prompt_cache(model, args.max_kv_size)
|
||||
y = mx.array(tokenizer.encode(prompt))
|
||||
y = mx.array(prompt)
|
||||
|
||||
# Process the prompt
|
||||
start = time.time()
|
||||
|
||||
@@ -16,6 +16,25 @@ DEFAULT_MAX_TOKENS = 256
|
||||
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
|
||||
|
||||
|
||||
def share_message(world, prompt):
|
||||
if world.size() == 1:
|
||||
return prompt
|
||||
|
||||
if world.rank() == 0:
|
||||
size = mx.array([len(prompt)])
|
||||
else:
|
||||
size = mx.array([0])
|
||||
size = mx.distributed.all_sum(size, stream=mx.cpu).item()
|
||||
if size == 0:
|
||||
return []
|
||||
|
||||
if world.rank() == 0:
|
||||
prompt = mx.array(prompt)
|
||||
else:
|
||||
prompt = mx.array([0] * len(prompt))
|
||||
return mx.distributed.all_sum(size, stream=mx.cpu).tolist()
|
||||
|
||||
|
||||
def setup_arg_parser():
|
||||
"""Set up and return the argument parser."""
|
||||
parser = argparse.ArgumentParser(description="Chat with an LLM")
|
||||
@@ -54,6 +73,7 @@ def setup_arg_parser():
|
||||
|
||||
|
||||
def main():
|
||||
world = mx.distributed.init()
|
||||
parser = setup_arg_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -63,18 +83,30 @@ def main():
|
||||
args.model,
|
||||
adapter_path=args.adapter_path,
|
||||
tokenizer_config={"trust_remote_code": True},
|
||||
sequential_load=mx.distributed.init().size() > 1,
|
||||
)
|
||||
|
||||
print(f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.")
|
||||
print(f"Node {world.rank()} of {world.size()}", flush=True)
|
||||
print(
|
||||
f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.",
|
||||
flush=True,
|
||||
)
|
||||
world.barrier()
|
||||
prompt_cache = make_prompt_cache(model, args.max_kv_size)
|
||||
while True:
|
||||
query = input(">> ")
|
||||
if query == "q":
|
||||
if world.rank() == 0:
|
||||
query = input(">> ")
|
||||
if query == "q":
|
||||
prompt = []
|
||||
else:
|
||||
messages = [{"role": "user", "content": query}]
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
messages, add_generation_prompt=True
|
||||
)
|
||||
|
||||
prompt = share_message(world, prompt)
|
||||
if len(prompt) == 0:
|
||||
break
|
||||
messages = [{"role": "user", "content": query}]
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
for response in stream_generate(
|
||||
model,
|
||||
tokenizer,
|
||||
@@ -83,8 +115,10 @@ def main():
|
||||
sampler=make_sampler(args.temp, args.top_p),
|
||||
prompt_cache=prompt_cache,
|
||||
):
|
||||
print(response.text, flush=True, end="")
|
||||
print()
|
||||
if world.rank() == 0:
|
||||
print(response, flush=True, end="")
|
||||
if world.rank() == 0:
|
||||
print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1,4 +1,8 @@
|
||||
# Adapted from a PyTorch implementation by David Grangier
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
"""
|
||||
Adapted from a PyTorch implementation by David Grangier
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
@@ -6,7 +10,7 @@ import logging
|
||||
import os
|
||||
from importlib.metadata import version
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from typing import Optional, Union
|
||||
|
||||
import lm_eval
|
||||
import mlx.core as mx
|
||||
@@ -73,15 +77,19 @@ class MLXLM(LM):
|
||||
path_or_hf_repo: str,
|
||||
batch_size: int = 16,
|
||||
max_tokens: Optional[int] = None,
|
||||
use_chat_template: Optional[bool] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self._batch_size = batch_size
|
||||
self._model, self._tokenizer = load(path_or_hf_repo)
|
||||
self._max_tokens = max_tokens or self._tokenizer.model_max_length
|
||||
self._model, self.tokenizer = load(path_or_hf_repo)
|
||||
self._max_tokens = max_tokens or self.tokenizer.model_max_length
|
||||
self.use_chat_template = use_chat_template or (
|
||||
self.tokenizer.chat_template is not None
|
||||
)
|
||||
|
||||
def _score_fn(self, inputs, tokenize=True, step_size=32):
|
||||
if tokenize:
|
||||
inputs = self._tokenizer.encode(inputs)
|
||||
inputs = self._tokenize(inputs)
|
||||
inputs = _pad_inputs(inputs, self._max_tokens, truncate=False)
|
||||
inputs = mx.array(inputs)
|
||||
inputs, targets = inputs[..., :-1], inputs[..., 1:]
|
||||
@@ -145,7 +153,12 @@ class MLXLM(LM):
|
||||
return results
|
||||
|
||||
def _tokenize(self, texts):
|
||||
return [tuple(self._tokenizer.encode(t)) for t in texts]
|
||||
return [
|
||||
tuple(
|
||||
self.tokenizer.encode(t, add_special_tokens=not self.use_chat_template)
|
||||
)
|
||||
for t in texts
|
||||
]
|
||||
|
||||
def loglikelihood(self, requests) -> list[tuple[float, bool]]:
|
||||
"""Compute log-likelihood of generating a continuation from a context.
|
||||
@@ -217,6 +230,9 @@ class MLXLM(LM):
|
||||
)
|
||||
return [(r[0], r[1] == r[2]) for r in results]
|
||||
|
||||
tokenizer_name = lm_eval.models.huggingface.HFLM.tokenizer_name
|
||||
apply_chat_template = lm_eval.models.huggingface.HFLM.apply_chat_template
|
||||
|
||||
def loglikelihood_rolling(self, requests) -> list[float]:
|
||||
"""Compute full log-likelihood of a string, with no truncation, for perplexity computation
|
||||
- We will use the full max context length of the model.
|
||||
@@ -277,23 +293,16 @@ class MLXLM(LM):
|
||||
assert "until" in keys
|
||||
untils = [x["until"] for x in options]
|
||||
completions = []
|
||||
for context, until in tqdm(zip(contexts, untils), total=len(contexts)):
|
||||
if (
|
||||
hasattr(self._tokenizer, "apply_chat_template")
|
||||
and self._tokenizer.chat_template is not None
|
||||
):
|
||||
messages = [{"role": "user", "content": context}]
|
||||
context = self._tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
|
||||
for context, until in tqdm(zip(contexts, untils), total=len(contexts)):
|
||||
context = self._tokenize(context)
|
||||
max_tokens = min(
|
||||
self._max_tokens,
|
||||
self._tokenizer.model_max_length - len(self._tokenizer.encode(context)),
|
||||
self.tokenizer.model_max_length - len(context),
|
||||
)
|
||||
text = ""
|
||||
for response in stream_generate(
|
||||
self._model, self._tokenizer, prompt=context, max_tokens=max_tokens
|
||||
self._model, self.tokenizer, prompt=context, max_tokens=max_tokens
|
||||
):
|
||||
text += response.text
|
||||
if any(u in text for u in until):
|
||||
@@ -321,7 +330,28 @@ def main():
|
||||
type=int,
|
||||
help="Maximum nunber of tokens to generate. Defaults to the model's max context length.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--limit",
|
||||
default=1.0,
|
||||
help="Limit the number of examples per task.",
|
||||
type=float,
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=123, help="Random seed.")
|
||||
parser.add_argument(
|
||||
"--fewshot-as-multiturn",
|
||||
action="store_true",
|
||||
help="Whether to provide the fewshot examples as a multiturn "
|
||||
"conversation or a single user turn.",
|
||||
default=False,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--apply-chat-template",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Specifies whether to apply a chat template to the prompt. If "
|
||||
"the model has a chat template, this defaults to `True`, "
|
||||
"otherwise `False`.",
|
||||
default=None,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
output_dir = Path(args.output_dir)
|
||||
@@ -332,12 +362,19 @@ def main():
|
||||
|
||||
mx.random.seed(args.seed)
|
||||
|
||||
lm = MLXLM(args.model, batch_size=args.batch_size, max_tokens=args.max_tokens)
|
||||
|
||||
lm = MLXLM(
|
||||
args.model,
|
||||
batch_size=args.batch_size,
|
||||
max_tokens=args.max_tokens,
|
||||
use_chat_template=args.apply_chat_template,
|
||||
)
|
||||
results = lm_eval.simple_evaluate(
|
||||
model=lm,
|
||||
tasks=args.tasks,
|
||||
fewshot_as_multiturn=args.fewshot_as_multiturn,
|
||||
apply_chat_template=lm.use_chat_template,
|
||||
num_fewshot=args.num_shots,
|
||||
limit=args.limit,
|
||||
random_seed=args.seed,
|
||||
numpy_random_seed=args.seed,
|
||||
torch_random_seed=args.seed,
|
||||
|
||||
@@ -15,9 +15,7 @@ prompt_cache = make_prompt_cache(model)
|
||||
# User turn
|
||||
prompt = "Hi my name is <Name>."
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
|
||||
|
||||
# Assistant response
|
||||
response = generate(
|
||||
@@ -32,9 +30,7 @@ response = generate(
|
||||
# User turn
|
||||
prompt = "What's my name?"
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
|
||||
|
||||
# Assistant response
|
||||
response = generate(
|
||||
|
||||
@@ -14,7 +14,7 @@ conversation = [{"role": "user", "content": prompt}]
|
||||
|
||||
# Transform the prompt into the chat template
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
conversation=conversation, tokenize=False, add_generation_prompt=True
|
||||
conversation=conversation, add_generation_prompt=True
|
||||
)
|
||||
|
||||
# Specify the maximum number of tokens
|
||||
|
||||
127
llms/mlx_lm/examples/pipeline_generate.py
Normal file
127
llms/mlx_lm/examples/pipeline_generate.py
Normal file
@@ -0,0 +1,127 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
"""
|
||||
Run with:
|
||||
|
||||
```
|
||||
mlx.launch \
|
||||
--hostfile /path/to/hosts.txt \
|
||||
--backend mpi \
|
||||
/path/to/pipeline_generate.py \
|
||||
--prompt "hello world"
|
||||
```
|
||||
|
||||
Make sure you can run MLX over MPI on two hosts. For more information see the
|
||||
documentation:
|
||||
|
||||
https://ml-explore.github.io/mlx/build/html/usage/distributed.html).
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import mlx.core as mx
|
||||
from huggingface_hub import snapshot_download
|
||||
from mlx.utils import tree_flatten
|
||||
from mlx_lm import load, stream_generate
|
||||
from mlx_lm.utils import load_model, load_tokenizer
|
||||
|
||||
|
||||
def download(repo: str, allow_patterns: list[str]) -> Path:
|
||||
return Path(
|
||||
snapshot_download(
|
||||
repo,
|
||||
allow_patterns=allow_patterns,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def shard_and_load(repo):
|
||||
# Get model path with everything but weight safetensors
|
||||
model_path = download(
|
||||
args.model,
|
||||
allow_patterns=["*.json", "*.py", "tokenizer.model", "*.tiktoken", "*.txt"],
|
||||
)
|
||||
|
||||
# Lazy load and shard model to figure out
|
||||
# which weights we need
|
||||
model, _ = load_model(model_path, lazy=True, strict=False)
|
||||
|
||||
group = mx.distributed.init(backend="mpi")
|
||||
rank = group.rank()
|
||||
model.model.pipeline(group)
|
||||
|
||||
# Figure out which files we need for the local shard
|
||||
with open(model_path / "model.safetensors.index.json", "r") as fid:
|
||||
weight_index = json.load(fid)["weight_map"]
|
||||
|
||||
local_files = set()
|
||||
for k, _ in tree_flatten(model.parameters()):
|
||||
local_files.add(weight_index[k])
|
||||
|
||||
# Download weights for local shard
|
||||
download(args.model, allow_patterns=local_files)
|
||||
|
||||
# Load and shard the model, and load the weights
|
||||
tokenizer = load_tokenizer(model_path)
|
||||
model, _ = load_model(model_path, lazy=True, strict=False)
|
||||
model.model.pipeline(group)
|
||||
mx.eval(model.parameters())
|
||||
|
||||
# Synchronize processes before generation to avoid timeout if downloading
|
||||
# model for the first time.
|
||||
mx.eval(mx.distributed.all_sum(mx.array(1.0), stream=mx.cpu))
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="LLM pipelined inference example")
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
default="mlx-community/DeepSeek-R1-3bit",
|
||||
help="HF repo or path to local model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prompt",
|
||||
"-p",
|
||||
default="Write a quicksort in C++.",
|
||||
help="Message to be processed by the model ('-' reads from stdin)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-tokens",
|
||||
"-m",
|
||||
type=int,
|
||||
default=256,
|
||||
help="Maximum number of tokens to generate",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
group = mx.distributed.init(backend="mpi")
|
||||
rank = group.rank()
|
||||
|
||||
def rprint(*args, **kwargs):
|
||||
if rank == 0:
|
||||
print(*args, **kwargs)
|
||||
|
||||
model, tokenizer = shard_and_load(args.model)
|
||||
|
||||
messages = [{"role": "user", "content": args.prompt}]
|
||||
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
|
||||
|
||||
for response in stream_generate(
|
||||
model, tokenizer, prompt, max_tokens=args.max_tokens
|
||||
):
|
||||
rprint(response.text, end="", flush=True)
|
||||
|
||||
rprint()
|
||||
rprint("=" * 10)
|
||||
rprint(
|
||||
f"Prompt: {response.prompt_tokens} tokens, "
|
||||
f"{response.prompt_tps:.3f} tokens-per-sec"
|
||||
)
|
||||
rprint(
|
||||
f"Generation: {response.generation_tokens} tokens, "
|
||||
f"{response.generation_tps:.3f} tokens-per-sec"
|
||||
)
|
||||
rprint(f"Peak memory: {response.peak_memory:.3f} GB")
|
||||
@@ -43,10 +43,11 @@ def setup_arg_parser():
|
||||
help="Optional path for the trained adapter weights and config.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--eos-token",
|
||||
"--extra-eos-token",
|
||||
type=str,
|
||||
default=None,
|
||||
help="End of sequence token for tokenizer",
|
||||
default=(),
|
||||
nargs="+",
|
||||
help="Add tokens in the list of eos tokens that stop generation.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--system-prompt",
|
||||
@@ -130,6 +131,18 @@ def setup_arg_parser():
|
||||
type=int,
|
||||
default=DEFAULT_QUANTIZED_KV_START,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--draft-model",
|
||||
type=str,
|
||||
help="A model to be used for speculative decoding.",
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-draft-tokens",
|
||||
type=int,
|
||||
help="Number of tokens to draft when using speculative decoding.",
|
||||
default=2,
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
@@ -161,8 +174,6 @@ def main():
|
||||
{} if not using_cache else json.loads(metadata["tokenizer_config"])
|
||||
)
|
||||
tokenizer_config["trust_remote_code"] = True
|
||||
if args.eos_token is not None:
|
||||
tokenizer_config["eos_token"] = args.eos_token
|
||||
|
||||
model_path = args.model
|
||||
if using_cache:
|
||||
@@ -180,7 +191,10 @@ def main():
|
||||
model_path,
|
||||
adapter_path=args.adapter_path,
|
||||
tokenizer_config=tokenizer_config,
|
||||
sequential_load=mx.distributed.init().size() > 1,
|
||||
)
|
||||
for eos_token in args.extra_eos_token:
|
||||
tokenizer.add_eos_token(eos_token)
|
||||
|
||||
if args.use_default_chat_template:
|
||||
if tokenizer.chat_template is None:
|
||||
@@ -190,10 +204,7 @@ def main():
|
||||
|
||||
prompt = args.prompt.replace("\\n", "\n").replace("\\t", "\t")
|
||||
prompt = sys.stdin.read() if prompt == "-" else prompt
|
||||
if not args.ignore_chat_template and (
|
||||
hasattr(tokenizer, "apply_chat_template")
|
||||
and tokenizer.chat_template is not None
|
||||
):
|
||||
if not args.ignore_chat_template and tokenizer.chat_template is not None:
|
||||
if args.system_prompt is not None:
|
||||
messages = [{"role": "system", "content": args.system_prompt}]
|
||||
else:
|
||||
@@ -213,23 +224,40 @@ def main():
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
prompt = prompt[test_prompt.index("<query>") :]
|
||||
prompt = tokenizer.encode(prompt, add_special_tokens=False)
|
||||
else:
|
||||
prompt = tokenizer.encode(prompt)
|
||||
|
||||
if args.draft_model is not None:
|
||||
draft_model, draft_tokenizer = load(args.draft_model)
|
||||
if draft_tokenizer.vocab_size != tokenizer.vocab_size:
|
||||
raise ValueError("Draft model tokenizer does not match model tokenizer.")
|
||||
else:
|
||||
draft_model = None
|
||||
sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep)
|
||||
|
||||
world = mx.distributed.init()
|
||||
print(f"Node {world.rank()} of {world.size()}", flush=True)
|
||||
world.barrier()
|
||||
response = generate(
|
||||
model,
|
||||
tokenizer,
|
||||
prompt,
|
||||
max_tokens=args.max_tokens,
|
||||
verbose=args.verbose,
|
||||
sampler=sampler,
|
||||
verbose=args.verbose and world.rank() == 0,
|
||||
max_kv_size=args.max_kv_size,
|
||||
prompt_cache=prompt_cache if using_cache else None,
|
||||
kv_bits=args.kv_bits,
|
||||
kv_group_size=args.kv_group_size,
|
||||
quantized_kv_start=args.quantized_kv_start,
|
||||
draft_model=draft_model,
|
||||
num_draft_tokens=args.num_draft_tokens,
|
||||
)
|
||||
if not args.verbose:
|
||||
|
||||
if not args.verbose and mx.distributed.init().rank() == 0:
|
||||
print(response)
|
||||
mx.synchronize()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
import argparse
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
import types
|
||||
from pathlib import Path
|
||||
@@ -57,6 +58,8 @@ CONFIG_DEFAULTS = {
|
||||
"test": False,
|
||||
"test_batches": 500,
|
||||
"max_seq_length": 2048,
|
||||
"config": None,
|
||||
"grad_checkpoint": False,
|
||||
"lr_schedule": None,
|
||||
"lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
|
||||
}
|
||||
@@ -66,6 +69,7 @@ def build_parser():
|
||||
parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
help="The path to the local model directory or Hugging Face repo.",
|
||||
)
|
||||
|
||||
@@ -88,7 +92,6 @@ def build_parser():
|
||||
"--fine-tune-type",
|
||||
type=str,
|
||||
choices=["lora", "dora", "full"],
|
||||
default="lora",
|
||||
help="Type of fine-tuning to perform: lora, dora, or full.",
|
||||
)
|
||||
parser.add_argument(
|
||||
@@ -148,7 +151,7 @@ def build_parser():
|
||||
parser.add_argument(
|
||||
"-c",
|
||||
"--config",
|
||||
default=None,
|
||||
type=str,
|
||||
help="A YAML configuration file with the training options",
|
||||
)
|
||||
parser.add_argument(
|
||||
@@ -157,7 +160,7 @@ def build_parser():
|
||||
help="Use gradient checkpointing to reduce memory use.",
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=None, help="The PRNG seed")
|
||||
parser.add_argument("--seed", type=int, help="The PRNG seed")
|
||||
return parser
|
||||
|
||||
|
||||
@@ -271,6 +274,7 @@ def run(args, training_callback: TrainingCallback = None):
|
||||
|
||||
|
||||
def main():
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "true"
|
||||
parser = build_parser()
|
||||
args = parser.parse_args()
|
||||
config = args.config
|
||||
|
||||
@@ -6,19 +6,18 @@ from transformers.commands.user import tabulate
|
||||
|
||||
|
||||
def ask_for_confirmation(message: str) -> bool:
|
||||
"""Ask user for confirmation with Y/N prompt.
|
||||
Returns True for Y/yes, False for N/no/empty."""
|
||||
y = ("y", "yes", "1")
|
||||
n = ("n", "no", "0")
|
||||
all_values = y + n + ("",)
|
||||
full_message = f"{message} (Y/n) "
|
||||
n = ("n", "no", "0", "")
|
||||
full_message = f"{message} (y/n) "
|
||||
while True:
|
||||
answer = input(full_message).lower()
|
||||
if answer == "":
|
||||
return False
|
||||
if answer in y:
|
||||
return True
|
||||
if answer in n:
|
||||
return False
|
||||
print(f"Invalid input. Must be one of {all_values}")
|
||||
print(f"Invalid input. Must be one of: yes/no/y/n or empty for no")
|
||||
|
||||
|
||||
def main():
|
||||
@@ -43,9 +42,7 @@ def main():
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.scan:
|
||||
print(
|
||||
"Scanning Hugging Face cache for models with" f'pattern "{args.pattern}".'
|
||||
)
|
||||
print(f'Scanning Hugging Face cache for models with pattern "{args.pattern}".')
|
||||
hf_cache_info = scan_cache_dir()
|
||||
print(
|
||||
tabulate(
|
||||
@@ -86,35 +83,41 @@ def main():
|
||||
if args.pattern in repo.repo_id
|
||||
]
|
||||
if repos:
|
||||
print("\nFound the following models:")
|
||||
print(
|
||||
tabulate(
|
||||
rows=[
|
||||
[
|
||||
repo.repo_id,
|
||||
repo.size_on_disk_str, # Added size information
|
||||
str(repo.repo_path),
|
||||
]
|
||||
for repo in repos
|
||||
],
|
||||
headers=[
|
||||
"REPO ID",
|
||||
"SIZE", # Added size header
|
||||
"LOCAL PATH",
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
confirmed = ask_for_confirmation(f"Confirm deletion ?")
|
||||
confirmed = ask_for_confirmation(
|
||||
"\nAre you sure you want to delete these models?"
|
||||
)
|
||||
if confirmed:
|
||||
for model_info in repos:
|
||||
print(f"\nDeleting {model_info.repo_id}...")
|
||||
for revision in sorted(
|
||||
model_info.revisions, key=lambda revision: revision.commit_hash
|
||||
):
|
||||
strategy = hf_cache_info.delete_revisions(revision.commit_hash)
|
||||
strategy.execute()
|
||||
print("Model(s) deleted.")
|
||||
print("\nModel(s) deleted successfully.")
|
||||
else:
|
||||
print("Deletion is cancelled. Do nothing.")
|
||||
print("\nDeletion cancelled - no changes made.")
|
||||
else:
|
||||
print(f"No models found.")
|
||||
print(f'No models found matching pattern "{args.pattern}"')
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -156,12 +156,13 @@ class CohereModel(nn.Module):
|
||||
):
|
||||
h = self.embed_tokens(inputs)
|
||||
|
||||
if mask is None:
|
||||
mask = create_attention_mask(h, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
if mask is None:
|
||||
j = self.args.sliding_window_pattern
|
||||
mask = create_attention_mask(h, cache[j - 1 : j])
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, c)
|
||||
|
||||
|
||||
@@ -364,8 +364,30 @@ class DeepseekV2Model(nn.Module):
|
||||
DeepseekV2DecoderLayer(config, idx)
|
||||
for idx in range(config.num_hidden_layers)
|
||||
]
|
||||
self.start_idx = 0
|
||||
self.end_idx = len(self.layers)
|
||||
self.num_layers = self.end_idx
|
||||
|
||||
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
self.pipeline_rank = 0
|
||||
self.pipeline_size = 1
|
||||
|
||||
def pipeline(self, group):
|
||||
# Split layers in reverse so rank=0 gets the last layers and
|
||||
# rank=pipeline_size-1 gets the first
|
||||
self.pipeline_rank = group.rank()
|
||||
self.pipeline_size = group.size()
|
||||
layers_per_rank = (
|
||||
len(self.layers) + self.pipeline_size - 1
|
||||
) // self.pipeline_size
|
||||
self.start_idx = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
|
||||
self.end_idx = self.start_idx + layers_per_rank
|
||||
self.num_layers = layers_per_rank
|
||||
self.layers = self.layers[: self.end_idx]
|
||||
self.layers[: self.start_idx] = [None] * self.start_idx
|
||||
self.num_layers = len(self.layers) - self.start_idx
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
@@ -374,14 +396,31 @@ class DeepseekV2Model(nn.Module):
|
||||
) -> mx.array:
|
||||
h = self.embed_tokens(x)
|
||||
|
||||
pipeline_rank = self.pipeline_rank
|
||||
pipeline_size = self.pipeline_size
|
||||
# Hack to avoid time-outs during prompt-processing
|
||||
dist_stream = mx.cpu if h.shape[1] > 1 else mx.gpu
|
||||
if mask is None:
|
||||
mask = create_attention_mask(h, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
cache = [None] * self.num_layers
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, c)
|
||||
# Receive from the previous process in the pipeline
|
||||
if pipeline_rank < pipeline_size - 1:
|
||||
h = mx.distributed.recv_like(h, (pipeline_rank + 1), stream=dist_stream)
|
||||
|
||||
for i in range(self.num_layers):
|
||||
h = self.layers[self.start_idx + i](h, mask, cache[i])
|
||||
|
||||
# Send to the next process in the pipeline
|
||||
if pipeline_rank != 0:
|
||||
h = mx.distributed.send(
|
||||
h, (pipeline_rank - 1) % pipeline_size, stream=dist_stream
|
||||
)
|
||||
|
||||
# Broadcast h while keeping it in the graph
|
||||
h = mx.distributed.all_gather(h, stream=dist_stream)[: h.shape[0]]
|
||||
|
||||
return self.norm(h)
|
||||
|
||||
@@ -418,4 +457,4 @@ class Model(nn.Module):
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
return self.model.layers[self.model.start_idx : self.model.end_idx]
|
||||
|
||||
478
llms/mlx_lm/models/deepseek_v3.py
Normal file
478
llms/mlx_lm/models/deepseek_v3.py
Normal file
@@ -0,0 +1,478 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||
from .switch_layers import SwitchGLU
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str = "deepseek_v3"
|
||||
vocab_size: int = 102400
|
||||
hidden_size: int = 4096
|
||||
intermediate_size: int = 11008
|
||||
moe_intermediate_size: int = 1407
|
||||
num_hidden_layers: int = 30
|
||||
num_attention_heads: int = 32
|
||||
num_key_value_heads: int = 32
|
||||
n_shared_experts: Optional[int] = None
|
||||
n_routed_experts: Optional[int] = None
|
||||
routed_scaling_factor: float = 1.0
|
||||
kv_lora_rank: int = 512
|
||||
q_lora_rank: int = 1536
|
||||
qk_rope_head_dim: int = 64
|
||||
v_head_dim: int = 128
|
||||
qk_nope_head_dim: int = 128
|
||||
topk_method: str = "noaux_tc"
|
||||
scoring_func: str = "sigmoid"
|
||||
norm_topk_prob: bool = True
|
||||
n_group: Optional[int] = None
|
||||
topk_group: Optional[int] = None
|
||||
num_experts_per_tok: Optional[int] = None
|
||||
moe_layer_freq: int = 1
|
||||
first_k_dense_replace: int = 0
|
||||
max_position_embeddings: int = 2048
|
||||
rms_norm_eps: float = 1e-6
|
||||
rope_theta: float = 10000.0
|
||||
rope_scaling: Dict = None
|
||||
attention_bias: bool = False
|
||||
|
||||
|
||||
def yarn_find_correction_dim(
|
||||
num_rotations, dim, base=10000, max_position_embeddings=2048
|
||||
):
|
||||
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
|
||||
2 * math.log(base)
|
||||
)
|
||||
|
||||
|
||||
def yarn_find_correction_range(
|
||||
low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
|
||||
):
|
||||
low = math.floor(
|
||||
yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
|
||||
)
|
||||
high = math.ceil(
|
||||
yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
|
||||
)
|
||||
return max(low, 0), min(high, dim - 1)
|
||||
|
||||
|
||||
def yarn_get_mscale(scale=1, mscale=1):
|
||||
if scale <= 1:
|
||||
return 1.0
|
||||
return 0.1 * mscale * math.log(scale) + 1.0
|
||||
|
||||
|
||||
def yarn_linear_ramp_mask(min_val, max_val, dim):
|
||||
if min_val == max_val:
|
||||
max_val += 0.001 # Prevent singularity
|
||||
|
||||
linear_func = (mx.arange(dim, dtype=mx.float32) - min_val) / (max_val - min_val)
|
||||
return mx.clip(linear_func, 0, 1)
|
||||
|
||||
|
||||
class DeepseekV3YarnRotaryEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
max_position_embeddings=2048,
|
||||
base=10000,
|
||||
scaling_factor=1.0,
|
||||
original_max_position_embeddings=4096,
|
||||
beta_fast=32,
|
||||
beta_slow=1,
|
||||
mscale=1,
|
||||
mscale_all_dim=0,
|
||||
):
|
||||
super().__init__()
|
||||
self.mscale = yarn_get_mscale(scaling_factor, mscale) / yarn_get_mscale(
|
||||
scaling_factor, mscale_all_dim
|
||||
)
|
||||
freq_extra = base ** (mx.arange(0, dim, 2, dtype=mx.float32) / dim)
|
||||
freq_inter = scaling_factor * base ** (
|
||||
mx.arange(0, dim, 2, dtype=mx.float32) / dim
|
||||
)
|
||||
low, high = yarn_find_correction_range(
|
||||
beta_fast,
|
||||
beta_slow,
|
||||
dim,
|
||||
base,
|
||||
original_max_position_embeddings,
|
||||
)
|
||||
freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2)
|
||||
self._freqs = (freq_inter * freq_extra) / (
|
||||
freq_inter * freq_mask + freq_extra * (1 - freq_mask)
|
||||
)
|
||||
|
||||
def __call__(self, x, offset=0):
|
||||
if self.mscale != 1.0:
|
||||
x = self.mscale * x
|
||||
return mx.fast.rope(
|
||||
x,
|
||||
x.shape[-1],
|
||||
traditional=True,
|
||||
base=None,
|
||||
scale=1.0,
|
||||
offset=offset,
|
||||
freqs=self._freqs,
|
||||
)
|
||||
|
||||
|
||||
# A clipped silu to prevent fp16 from overflowing
|
||||
@partial(mx.compile, shapeless=True)
|
||||
def clipped_silu(x):
|
||||
return mx.clip(x * mx.sigmoid(x), a_min=-100, a_max=100)
|
||||
|
||||
|
||||
class DeepseekV3Attention(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.rope_theta = config.rope_theta
|
||||
self.q_lora_rank = config.q_lora_rank
|
||||
self.qk_rope_head_dim = config.qk_rope_head_dim
|
||||
self.kv_lora_rank = config.kv_lora_rank
|
||||
self.v_head_dim = config.v_head_dim
|
||||
self.qk_nope_head_dim = config.qk_nope_head_dim
|
||||
self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
|
||||
|
||||
self.scale = self.q_head_dim**-0.5
|
||||
|
||||
if self.q_lora_rank is None:
|
||||
self.q_proj = nn.Linear(
|
||||
self.hidden_size, self.num_heads * self.q_head_dim, bias=False
|
||||
)
|
||||
else:
|
||||
self.q_a_proj = nn.Linear(
|
||||
self.hidden_size, self.q_lora_rank, bias=config.attention_bias
|
||||
)
|
||||
self.q_a_layernorm = nn.RMSNorm(self.q_lora_rank)
|
||||
self.q_b_proj = nn.Linear(
|
||||
self.q_lora_rank, self.num_heads * self.q_head_dim, bias=False
|
||||
)
|
||||
|
||||
self.kv_a_proj_with_mqa = nn.Linear(
|
||||
self.hidden_size,
|
||||
self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
bias=config.attention_bias,
|
||||
)
|
||||
self.kv_a_layernorm = nn.RMSNorm(self.kv_lora_rank)
|
||||
self.kv_b_proj = nn.Linear(
|
||||
self.kv_lora_rank,
|
||||
self.num_heads
|
||||
* (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
|
||||
bias=False,
|
||||
)
|
||||
|
||||
self.o_proj = nn.Linear(
|
||||
self.num_heads * self.v_head_dim,
|
||||
self.hidden_size,
|
||||
bias=config.attention_bias,
|
||||
)
|
||||
|
||||
mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
|
||||
scaling_factor = self.config.rope_scaling["factor"]
|
||||
if mscale_all_dim:
|
||||
mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
|
||||
self.scale = self.scale * mscale * mscale
|
||||
|
||||
rope_kwargs = {
|
||||
key: self.config.rope_scaling[key]
|
||||
for key in [
|
||||
"original_max_position_embeddings",
|
||||
"beta_fast",
|
||||
"beta_slow",
|
||||
"mscale",
|
||||
"mscale_all_dim",
|
||||
]
|
||||
if key in self.config.rope_scaling
|
||||
}
|
||||
self.rope = DeepseekV3YarnRotaryEmbedding(
|
||||
dim=self.qk_rope_head_dim,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
scaling_factor=scaling_factor,
|
||||
base=self.rope_theta,
|
||||
**rope_kwargs,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
if self.q_lora_rank is None:
|
||||
q = self.q_proj(x)
|
||||
else:
|
||||
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(x)))
|
||||
|
||||
q = q.reshape(B, L, self.num_heads, self.q_head_dim).transpose(0, 2, 1, 3)
|
||||
q_nope, q_pe = mx.split(q, [self.qk_nope_head_dim], axis=-1)
|
||||
compressed_kv = self.kv_a_proj_with_mqa(x)
|
||||
compressed_kv, k_pe = mx.split(compressed_kv, [self.kv_lora_rank], axis=-1)
|
||||
k_pe = k_pe.reshape(B, L, 1, self.qk_rope_head_dim).transpose(0, 2, 1, 3)
|
||||
kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
|
||||
kv = kv.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1)
|
||||
|
||||
if cache is not None:
|
||||
q_pe = self.rope(q_pe, cache.offset)
|
||||
k_pe = self.rope(k_pe, cache.offset)
|
||||
k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
|
||||
keys, values = cache.update_and_fetch(
|
||||
mx.concatenate([k_nope, k_pe], axis=-1), values
|
||||
)
|
||||
else:
|
||||
q_pe = self.rope(q_pe)
|
||||
k_pe = self.rope(k_pe)
|
||||
k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
|
||||
keys = mx.concatenate([k_nope, k_pe], axis=-1)
|
||||
|
||||
queries = mx.concatenate([q_nope, q_pe], axis=-1)
|
||||
|
||||
output = scaled_dot_product_attention(
|
||||
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||
)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output)
|
||||
|
||||
|
||||
class DeepseekV3MLP(nn.Module):
|
||||
def __init__(
|
||||
self, config: ModelArgs, hidden_size: int = None, intermediate_size: int = None
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
|
||||
self.intermediate_size = (
|
||||
config.intermediate_size if intermediate_size is None else intermediate_size
|
||||
)
|
||||
|
||||
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
||||
|
||||
def __call__(self, x):
|
||||
down_proj = self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
|
||||
return down_proj
|
||||
|
||||
|
||||
class MoEGate(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.top_k = config.num_experts_per_tok
|
||||
self.norm_topk_prob = config.norm_topk_prob
|
||||
self.n_routed_experts = config.n_routed_experts
|
||||
self.routed_scaling_factor = config.routed_scaling_factor
|
||||
self.topk_method = config.topk_method
|
||||
self.n_group = config.n_group
|
||||
self.topk_group = config.topk_group
|
||||
self.weight = mx.zeros((self.n_routed_experts, config.hidden_size))
|
||||
self.e_score_correction_bias = mx.zeros((self.n_routed_experts,))
|
||||
|
||||
def __call__(self, x):
|
||||
gates = x @ self.weight.T
|
||||
|
||||
scores = mx.sigmoid(gates.astype(mx.float32))
|
||||
|
||||
assert self.topk_method == "noaux_tc", "Unsupported topk method."
|
||||
bsz, seq_len = x.shape[:2]
|
||||
scores = scores + self.e_score_correction_bias
|
||||
scores = scores.reshape(bsz, seq_len, self.n_group, -1)
|
||||
group_scores = mx.topk(scores, 2, axis=-1).sum(axis=-1)
|
||||
k = self.n_group - self.topk_group
|
||||
group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-1)[..., :k]
|
||||
batch_idx = mx.expand_dims(mx.arange(bsz), (1, 2))
|
||||
seq_idx = mx.expand_dims(mx.arange(seq_len), (0, 2))
|
||||
scores[batch_idx, seq_idx, group_idx] = 0.0
|
||||
scores = scores.reshape(bsz, seq_len, -1)
|
||||
|
||||
k = self.top_k
|
||||
inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]
|
||||
scores = mx.take_along_axis(scores, inds, axis=-1)
|
||||
if self.top_k > 1 and self.norm_topk_prob:
|
||||
denominator = scores.sum(axis=-1, keepdims=True) + 1e-20
|
||||
scores = scores / denominator
|
||||
scores = scores * self.routed_scaling_factor
|
||||
|
||||
return inds, scores
|
||||
|
||||
|
||||
class DeepseekV3MoE(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.num_experts_per_tok = config.num_experts_per_tok
|
||||
self.switch_mlp = SwitchGLU(
|
||||
config.hidden_size,
|
||||
config.moe_intermediate_size,
|
||||
config.n_routed_experts,
|
||||
activation=clipped_silu,
|
||||
)
|
||||
|
||||
self.gate = MoEGate(config)
|
||||
if config.n_shared_experts is not None:
|
||||
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
|
||||
self.shared_experts = DeepseekV3MLP(
|
||||
config=config, intermediate_size=intermediate_size
|
||||
)
|
||||
|
||||
def __call__(self, x):
|
||||
inds, scores = self.gate(x)
|
||||
y = self.switch_mlp(x, inds)
|
||||
y = (y * scores[..., None]).sum(axis=-2).astype(y.dtype)
|
||||
if self.config.n_shared_experts is not None:
|
||||
y = y + self.shared_experts(x)
|
||||
|
||||
return y
|
||||
|
||||
|
||||
class DeepseekV3DecoderLayer(nn.Module):
|
||||
def __init__(self, config: ModelArgs, layer_idx: int):
|
||||
super().__init__()
|
||||
self.self_attn = DeepseekV3Attention(config)
|
||||
self.mlp = (
|
||||
DeepseekV3MoE(config)
|
||||
if (
|
||||
config.n_routed_experts is not None
|
||||
and layer_idx >= config.first_k_dense_replace
|
||||
and layer_idx % config.moe_layer_freq == 0
|
||||
)
|
||||
else DeepseekV3MLP(config)
|
||||
)
|
||||
self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = nn.RMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.mlp(self.post_attention_layernorm(h))
|
||||
return h + r
|
||||
|
||||
|
||||
class DeepseekV3Model(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.vocab_size = config.vocab_size
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
self.layers = [
|
||||
DeepseekV3DecoderLayer(config, idx)
|
||||
for idx in range(config.num_hidden_layers)
|
||||
]
|
||||
self.start_idx = 0
|
||||
self.end_idx = len(self.layers)
|
||||
self.num_layers = self.end_idx
|
||||
|
||||
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.pipeline_rank = 0
|
||||
self.pipeline_size = 1
|
||||
|
||||
def pipeline(self, group):
|
||||
# Split layers in reverse so rank=0 gets the last layers and
|
||||
# rank=pipeline_size-1 gets the first
|
||||
self.pipeline_rank = group.rank()
|
||||
self.pipeline_size = group.size()
|
||||
layers_per_rank = (
|
||||
len(self.layers) + self.pipeline_size - 1
|
||||
) // self.pipeline_size
|
||||
self.start_idx = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
|
||||
self.end_idx = self.start_idx + layers_per_rank
|
||||
self.layers = self.layers[: self.end_idx]
|
||||
self.layers[: self.start_idx] = [None] * self.start_idx
|
||||
self.num_layers = len(self.layers) - self.start_idx
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
cache: Optional[Any] = None,
|
||||
mask: Optional[mx.array] = None,
|
||||
) -> mx.array:
|
||||
h = self.embed_tokens(x)
|
||||
|
||||
pipeline_rank = self.pipeline_rank
|
||||
pipeline_size = self.pipeline_size
|
||||
# Hack to avoid time-outs during prompt-processing
|
||||
dist_stream = mx.cpu if h.shape[1] > 1 else mx.gpu
|
||||
if mask is None:
|
||||
mask = create_attention_mask(h, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * self.num_layers
|
||||
|
||||
# Receive from the previous process in the pipeline
|
||||
|
||||
if pipeline_rank < pipeline_size - 1:
|
||||
h = mx.distributed.recv_like(h, (pipeline_rank + 1), stream=dist_stream)
|
||||
|
||||
for i in range(self.num_layers):
|
||||
h = self.layers[self.start_idx + i](h, mask, cache[i])
|
||||
|
||||
# Send to the next process in the pipeline
|
||||
if pipeline_rank != 0:
|
||||
h = mx.distributed.send(
|
||||
h, (pipeline_rank - 1) % pipeline_size, stream=dist_stream
|
||||
)
|
||||
|
||||
# Broadcast h while keeping it in the graph
|
||||
h = mx.distributed.all_gather(h, stream=dist_stream)[: h.shape[0]]
|
||||
|
||||
return self.norm(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = config
|
||||
self.model_type = config.model_type
|
||||
self.model = DeepseekV3Model(config)
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache: Optional[Any] = None,
|
||||
mask: Optional[mx.array] = None,
|
||||
):
|
||||
out = self.model(inputs, cache, mask)
|
||||
return self.lm_head(out)
|
||||
|
||||
def sanitize(self, weights):
|
||||
for l in range(self.args.num_hidden_layers):
|
||||
prefix = f"model.layers.{l}"
|
||||
for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]:
|
||||
for k in ["weight", "scales", "biases"]:
|
||||
if f"{prefix}.mlp.experts.0.{m}.{k}" in weights:
|
||||
to_join = [
|
||||
weights.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}")
|
||||
for e in range(self.args.n_routed_experts)
|
||||
]
|
||||
weights[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join)
|
||||
|
||||
# Remove multi-token prediction layer
|
||||
return {k: v for k, v in weights.items() if not k.startswith("model.layers.61")}
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers[self.model.start_idx : self.model.end_idx]
|
||||
@@ -145,16 +145,16 @@ class GPTBigCodeModel(nn.Module):
|
||||
hidden_states = self.wte(inputs)
|
||||
|
||||
mask = None
|
||||
if hidden_states.shape[1] > 1:
|
||||
|
||||
position_ids = mx.array(np.arange(L))
|
||||
hidden_states += self.wpe(position_ids)
|
||||
|
||||
if mask is None:
|
||||
mask = create_attention_mask(hidden_states, cache)
|
||||
if mask is not None and hidden_states.shape[1] > 1:
|
||||
mask = create_attention_mask(hidden_states, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.h)
|
||||
position_ids = mx.array(np.arange(L))
|
||||
else:
|
||||
position_ids = mx.array(np.arange(cache[0].offset, cache[0].offset + L))
|
||||
|
||||
hidden_states += self.wpe(position_ids)
|
||||
|
||||
for layer, c in zip(self.h, cache):
|
||||
hidden_states = layer(hidden_states, mask, cache=c)
|
||||
|
||||
185
llms/mlx_lm/models/helium.py
Normal file
185
llms/mlx_lm/models/helium.py
Normal file
@@ -0,0 +1,185 @@
|
||||
# Copyright © 2025 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
hidden_size: int
|
||||
num_hidden_layers: int
|
||||
intermediate_size: int
|
||||
num_attention_heads: int
|
||||
num_key_value_heads: int
|
||||
rms_norm_eps: float
|
||||
vocab_size: int
|
||||
attention_bias: bool
|
||||
head_dim: int
|
||||
max_position_embeddings: int
|
||||
mlp_bias: bool
|
||||
model_type: str
|
||||
rope_theta: float
|
||||
tie_word_embeddings: bool
|
||||
|
||||
|
||||
class HeliumAttention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
dim = args.hidden_size
|
||||
self.n_heads = n_heads = args.num_attention_heads
|
||||
assert args.num_key_value_heads is not None
|
||||
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
||||
|
||||
head_dim = args.hidden_size // n_heads
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.attention_bias)
|
||||
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias)
|
||||
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias)
|
||||
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
|
||||
self.rope = nn.RoPE(head_dim, traditional=True, base=args.rope_theta)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
output = scaled_dot_product_attention(
|
||||
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||
)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output)
|
||||
|
||||
|
||||
class HeliumMLP(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.hidden_size = args.hidden_size
|
||||
self.intermediate_size = args.intermediate_size
|
||||
|
||||
self.gate_proj = nn.Linear(
|
||||
self.hidden_size, self.intermediate_size, bias=args.mlp_bias
|
||||
)
|
||||
self.up_proj = nn.Linear(
|
||||
self.hidden_size, self.intermediate_size, bias=args.mlp_bias
|
||||
)
|
||||
self.down_proj = nn.Linear(
|
||||
self.intermediate_size, self.hidden_size, bias=args.mlp_bias
|
||||
)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
class HeliumDecoderLayer(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.hidden_size = args.hidden_size
|
||||
|
||||
self.self_attn = HeliumAttention(args)
|
||||
self.mlp = HeliumMLP(args)
|
||||
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.post_attention_layernorm = nn.RMSNorm(
|
||||
args.hidden_size, eps=args.rms_norm_eps
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.mlp(self.post_attention_layernorm(h))
|
||||
out = h + r
|
||||
return out
|
||||
|
||||
|
||||
class HeliumModel(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.num_hidden_layers = args.num_hidden_layers
|
||||
self.vocab_size = args.vocab_size
|
||||
|
||||
assert self.vocab_size > 0
|
||||
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
|
||||
self.layers = [HeliumDecoderLayer(args) for _ in range(args.num_hidden_layers)]
|
||||
|
||||
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
) -> mx.array:
|
||||
h = self.embed_tokens(inputs)
|
||||
|
||||
if mask is None:
|
||||
mask = create_attention_mask(h, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, c)
|
||||
|
||||
return self.norm(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.model_type = args.model_type
|
||||
|
||||
self.model = HeliumModel(args)
|
||||
|
||||
self.vocab_size = args.vocab_size
|
||||
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
||||
|
||||
if not args.tie_word_embeddings:
|
||||
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
) -> mx.array:
|
||||
out = self.model(inputs, mask, cache)
|
||||
if self.args.tie_word_embeddings:
|
||||
out = self.model.embed_tokens.as_linear(out)
|
||||
else:
|
||||
out = self.lm_head(out)
|
||||
return out
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
241
llms/mlx_lm/models/internlm3.py
Normal file
241
llms/mlx_lm/models/internlm3.py
Normal file
@@ -0,0 +1,241 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
hidden_size: int
|
||||
num_hidden_layers: int
|
||||
intermediate_size: int
|
||||
num_attention_heads: int
|
||||
rms_norm_eps: float
|
||||
vocab_size: int
|
||||
bias: bool = False
|
||||
qkv_bias: bool = False
|
||||
max_position_embeddings: int = 32768
|
||||
num_key_value_heads: int = None
|
||||
rope_theta: float = 10000
|
||||
rope_traditional: bool = False
|
||||
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
|
||||
tie_word_embeddings: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
if self.num_key_value_heads is None:
|
||||
self.num_key_value_heads = self.num_attention_heads
|
||||
|
||||
if self.rope_scaling:
|
||||
required_keys = {"factor", "rope_type"}
|
||||
if not all(key in self.rope_scaling for key in required_keys):
|
||||
raise ValueError(f"rope_scaling must contain keys {required_keys}")
|
||||
|
||||
if self.rope_scaling["rope_type"] not in ["linear", "dynamic"]:
|
||||
raise ValueError(
|
||||
"rope_scaling 'rope_type' currently only supports 'linear' or 'dynamic"
|
||||
)
|
||||
|
||||
|
||||
class DynamicNTKScalingRoPE(nn.Module):
|
||||
"""Implements the rotary positional encoding with Dynamic NTK scaling."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dims: int,
|
||||
max_position_embeddings: int = 2048,
|
||||
traditional: bool = False,
|
||||
base: float = 10000,
|
||||
scale: float = 1.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.original_base = base
|
||||
self.dims = dims
|
||||
self.traditional = traditional
|
||||
self.scale = scale
|
||||
|
||||
def extra_repr(self):
|
||||
return f"{self.dims}, traditional={self.traditional}, max_position_embeddings={self.max_position_embeddings}, scaling_factor={self.scaling_factor}"
|
||||
|
||||
def __call__(self, x, offset: int = 0):
|
||||
seq_len = x.shape[1] + offset
|
||||
if seq_len > self.max_position_embeddings:
|
||||
base = self.original_base * (
|
||||
(self.scale * seq_len / self.max_position_embeddings) - (self.scale - 1)
|
||||
) ** (self.dims / (self.dims - 2))
|
||||
else:
|
||||
base = self.original_base
|
||||
|
||||
return mx.fast.rope(
|
||||
x,
|
||||
self.dims,
|
||||
traditional=self.traditional,
|
||||
base=base,
|
||||
scale=self.scale,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
dim = args.hidden_size
|
||||
qkv_bias = args.qkv_bias
|
||||
self.n_heads = n_heads = args.num_attention_heads
|
||||
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
||||
self.n_kv_groups = n_heads // args.num_key_value_heads
|
||||
|
||||
self.head_dim = head_dim = args.hidden_size // n_heads
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=qkv_bias)
|
||||
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=qkv_bias)
|
||||
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=qkv_bias)
|
||||
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=qkv_bias)
|
||||
|
||||
rope_scale = (
|
||||
1 / args.rope_scaling["factor"]
|
||||
if args.rope_scaling is not None
|
||||
and args.rope_scaling["rope_type"] == "linear"
|
||||
else 2.0
|
||||
)
|
||||
|
||||
self.rope = DynamicNTKScalingRoPE(
|
||||
head_dim,
|
||||
max_position_embeddings=args.max_position_embeddings,
|
||||
traditional=args.rope_traditional,
|
||||
base=args.rope_theta,
|
||||
scale=rope_scale,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
output = scaled_dot_product_attention(
|
||||
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||
)
|
||||
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, bias):
|
||||
super().__init__()
|
||||
self.gate_proj = nn.Linear(dim, hidden_dim, bias=bias)
|
||||
self.down_proj = nn.Linear(hidden_dim, dim, bias=bias)
|
||||
self.up_proj = nn.Linear(dim, hidden_dim, bias=bias)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.self_attn = Attention(args)
|
||||
self.mlp = MLP(args.hidden_size, args.intermediate_size, args.bias)
|
||||
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.post_attention_layernorm = nn.RMSNorm(
|
||||
args.hidden_size, eps=args.rms_norm_eps
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.mlp(self.post_attention_layernorm(h))
|
||||
out = h + r
|
||||
return out
|
||||
|
||||
|
||||
class InternLM2Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
assert args.vocab_size > 0
|
||||
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
self.layers = [
|
||||
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
|
||||
]
|
||||
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
h = self.embed_tokens(inputs)
|
||||
|
||||
if mask is None:
|
||||
mask = create_attention_mask(h, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, cache=c)
|
||||
|
||||
return self.norm(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.model_type = args.model_type
|
||||
self.model = InternLM2Model(args)
|
||||
if not args.tie_word_embeddings:
|
||||
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
out = self.model(inputs, mask, cache)
|
||||
if self.args.tie_word_embeddings:
|
||||
out = self.model.embed_tokens.as_linear(out)
|
||||
else:
|
||||
out = self.lm_head(out)
|
||||
return out
|
||||
|
||||
def sanitize(self, weights):
|
||||
# Remove unused precomputed rotary freqs
|
||||
return {k: v for k, v in weights.items() if "attention.rope.inv_freq" not in k}
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
@@ -200,6 +200,36 @@ class Model(nn.Module):
|
||||
k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
|
||||
}
|
||||
|
||||
def shard(self, group: Optional[mx.distributed.Group] = None):
|
||||
group = group or mx.distributed.init()
|
||||
|
||||
def all_to_sharded(l):
|
||||
if isinstance(l, nn.QuantizedLinear):
|
||||
return nn.QuantizedAllToShardedLinear.from_quantized_linear(l, group)
|
||||
else:
|
||||
return nn.AllToShardedLinear.from_linear(l, group)
|
||||
|
||||
def sharded_to_all(l):
|
||||
if isinstance(l, nn.QuantizedLinear):
|
||||
return nn.QuantizedShardedToAllLinear.from_quantized_linear(l, group)
|
||||
else:
|
||||
return nn.ShardedToAllLinear.from_linear(l, group)
|
||||
|
||||
N = group.size()
|
||||
for layer in self.model.layers:
|
||||
# Shard the self attention
|
||||
layer.self_attn.q_proj = all_to_sharded(layer.self_attn.q_proj)
|
||||
layer.self_attn.k_proj = all_to_sharded(layer.self_attn.k_proj)
|
||||
layer.self_attn.v_proj = all_to_sharded(layer.self_attn.v_proj)
|
||||
layer.self_attn.o_proj = sharded_to_all(layer.self_attn.o_proj)
|
||||
layer.self_attn.n_heads //= N
|
||||
layer.self_attn.n_kv_heads //= N
|
||||
|
||||
# Shard the MLP
|
||||
layer.mlp.gate_proj = all_to_sharded(layer.mlp.gate_proj)
|
||||
layer.mlp.down_proj = sharded_to_all(layer.mlp.down_proj)
|
||||
layer.mlp.up_proj = all_to_sharded(layer.mlp.up_proj)
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
# Copyright © 2024-2025 Apple Inc.
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
@@ -123,17 +123,16 @@ class MambaBlock(nn.Module):
|
||||
self.intermediate_size, self.hidden_size, bias=args.use_bias
|
||||
)
|
||||
|
||||
def ssm_step(self, x, state=None):
|
||||
A = -mx.exp(self.A_log)
|
||||
def ssm_step(self, x, A, state=None):
|
||||
D = self.D
|
||||
deltaBC = self.x_proj(x)
|
||||
delta, B, C = mx.split(
|
||||
deltaBC,
|
||||
indices_or_sections=[
|
||||
self.time_step_rank,
|
||||
self.time_step_rank + self.ssm_state_size,
|
||||
],
|
||||
axis=-1,
|
||||
delta, B, C = map(
|
||||
self.mixer_norm if self.use_bcdt_rms else lambda x: x,
|
||||
mx.split(
|
||||
deltaBC,
|
||||
[self.time_step_rank, self.time_step_rank + self.ssm_state_size],
|
||||
axis=-1,
|
||||
),
|
||||
)
|
||||
if self.use_bcdt_rms:
|
||||
delta, B, C = map(self.mixer_norm, (delta, B, C))
|
||||
@@ -145,25 +144,40 @@ class MambaBlock(nn.Module):
|
||||
y = y + D * x
|
||||
return y, new_state
|
||||
|
||||
def __call__(self, x, cache):
|
||||
def _process_sequence(self, x, conv_cache, state_cache):
|
||||
B, T, D = x.shape
|
||||
if cache is None:
|
||||
cache = [None, None]
|
||||
xz = self.in_proj(x)
|
||||
x, z = xz.split(indices_or_sections=2, axis=-1)
|
||||
|
||||
conv_out, new_conv_cache = self.conv1d(x, conv_cache)
|
||||
x = nn.silu(conv_out)
|
||||
|
||||
A = -mx.exp(self.A_log)
|
||||
|
||||
outputs = []
|
||||
current_state = state_cache
|
||||
y = []
|
||||
for t in range(T):
|
||||
xt = x[:, t, :]
|
||||
xz = self.in_proj(xt)
|
||||
x_t, z_t = xz.split(indices_or_sections=2, axis=1)
|
||||
conv_out, cache[0] = self.conv1d(mx.expand_dims(x_t, 1), cache[0])
|
||||
x_t = conv_out.squeeze(1)
|
||||
x_t = nn.silu(x_t)
|
||||
y_t, cache[1] = self.ssm_step(x_t, cache[1])
|
||||
z_t = nn.silu(z_t)
|
||||
output_t = y_t * z_t
|
||||
output_t = self.out_proj(output_t)
|
||||
outputs.append(output_t)
|
||||
output = mx.stack(outputs, axis=1)
|
||||
y_t, current_state = self.ssm_step(x[:, t], A, current_state)
|
||||
y.append(y_t)
|
||||
y = mx.stack(y, axis=1)
|
||||
z = self.out_proj(nn.silu(z) * y)
|
||||
return z, (new_conv_cache, current_state)
|
||||
|
||||
def __call__(self, x, cache):
|
||||
if cache is None:
|
||||
conv_cache, state_cache = None, None
|
||||
else:
|
||||
conv_cache, state_cache = cache[0], cache[1]
|
||||
|
||||
output, (new_conv_cache, new_state_cache) = self._process_sequence(
|
||||
x, conv_cache, state_cache
|
||||
)
|
||||
|
||||
if isinstance(cache, MambaCache):
|
||||
cache[0] = new_conv_cache
|
||||
cache[1] = new_state_cache
|
||||
|
||||
return output
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
# Copyright © 2023-2025 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
mlx>=0.19.2
|
||||
mlx>=0.22.0
|
||||
numpy
|
||||
transformers[sentencepiece]>=4.39.3
|
||||
protobuf
|
||||
|
||||
@@ -147,11 +147,11 @@ def min_p_sampling(
|
||||
logprobs = logprobs * (1 / temperature)
|
||||
|
||||
# Indices sorted in decreasing order
|
||||
sorted_indices = mx.argsort(-logprobs).squeeze(0)
|
||||
sorted_logprobs = logprobs[..., sorted_indices]
|
||||
sorted_indices = mx.argsort(-logprobs, axis=-1)
|
||||
sorted_logprobs = mx.take_along_axis(logprobs, sorted_indices, axis=-1)
|
||||
|
||||
# Top probability
|
||||
top_logprobs = logprobs[..., sorted_indices[0]]
|
||||
top_logprobs = sorted_logprobs[:, 0:1]
|
||||
|
||||
# Calculate the min_p threshold
|
||||
scaled_min_p = top_logprobs + math.log(min_p)
|
||||
@@ -163,9 +163,9 @@ def min_p_sampling(
|
||||
# Create pool of tokens with probability less than scaled min_p
|
||||
selected_logprobs = mx.where(tokens_to_remove, -float("inf"), sorted_logprobs)
|
||||
|
||||
# Return sampled token
|
||||
sorted_token = mx.random.categorical(selected_logprobs)
|
||||
return sorted_indices[sorted_token]
|
||||
# Return sampled tokens
|
||||
sorted_tokens = mx.random.categorical(selected_logprobs, axis=-1)[:, None]
|
||||
return mx.take_along_axis(sorted_indices, sorted_tokens, axis=-1).squeeze(1)
|
||||
|
||||
|
||||
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
|
||||
@@ -185,7 +185,7 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.arr
|
||||
|
||||
# sort probs in ascending order
|
||||
sorted_indices = mx.argsort(probs, axis=-1)
|
||||
sorted_probs = probs[..., sorted_indices.squeeze(0)]
|
||||
sorted_probs = mx.take_along_axis(probs, sorted_indices, axis=-1)
|
||||
|
||||
cumulative_probs = mx.cumsum(sorted_probs, axis=-1)
|
||||
|
||||
@@ -196,10 +196,8 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.arr
|
||||
0,
|
||||
)
|
||||
|
||||
sorted_token = mx.random.categorical(mx.log(top_probs))
|
||||
token = sorted_indices.squeeze(0)[sorted_token]
|
||||
|
||||
return token
|
||||
sorted_tokens = mx.random.categorical(mx.log(top_probs), axis=-1)[:, None]
|
||||
return mx.take_along_axis(sorted_indices, sorted_tokens, axis=-1).squeeze(1)
|
||||
|
||||
|
||||
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
|
||||
|
||||
@@ -114,6 +114,33 @@ def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None):
|
||||
return prompt.rstrip()
|
||||
|
||||
|
||||
def process_message_content(messages):
|
||||
"""
|
||||
Convert message content to a format suitable for `apply_chat_template`.
|
||||
|
||||
The function operates on messages in place. It converts the 'content' field
|
||||
to a string instead of a list of text fragments.
|
||||
|
||||
Args:
|
||||
message_list (list): A list of dictionaries, where each dictionary may
|
||||
have a 'content' key containing a list of dictionaries with 'type' and
|
||||
'text' keys.
|
||||
|
||||
Raises:
|
||||
ValueError: If the 'content' type is not supported or if 'text' is missing.
|
||||
|
||||
"""
|
||||
for message in messages:
|
||||
content = message["content"]
|
||||
if isinstance(content, list):
|
||||
text_fragments = [
|
||||
fragment["text"] for fragment in content if fragment["type"] == "text"
|
||||
]
|
||||
if len(text_fragments) != len(content):
|
||||
raise ValueError("Only 'text' content type is supported.")
|
||||
message["content"] = "".join(text_fragments)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PromptCache:
|
||||
cache: List[Any] = field(default_factory=list)
|
||||
@@ -590,14 +617,12 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
# Determine response type
|
||||
self.request_id = f"chatcmpl-{uuid.uuid4()}"
|
||||
self.object_type = "chat.completion.chunk" if self.stream else "chat.completion"
|
||||
if (
|
||||
hasattr(self.tokenizer, "apply_chat_template")
|
||||
and self.tokenizer.chat_template
|
||||
):
|
||||
if self.tokenizer.chat_template:
|
||||
messages = body["messages"]
|
||||
process_message_content(messages)
|
||||
prompt = self.tokenizer.apply_chat_template(
|
||||
body["messages"],
|
||||
messages,
|
||||
body.get("tools", None),
|
||||
tokenize=True,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -266,6 +266,18 @@ class TokenizerWrapper:
|
||||
else {tokenizer.eos_token_id}
|
||||
)
|
||||
|
||||
def add_eos_token(self, token: str):
|
||||
token_id = None
|
||||
try:
|
||||
token_id = int(token)
|
||||
except ValueError:
|
||||
token_id = self._tokenizer.convert_tokens_to_ids(token)
|
||||
|
||||
if token_id is None:
|
||||
raise ValueError(f"'{token}' is not a token for this tokenizer")
|
||||
|
||||
self._eos_token_ids.add(token_id)
|
||||
|
||||
def __getattr__(self, attr):
|
||||
if attr == "detokenizer":
|
||||
return self._detokenizer
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
@@ -10,41 +10,47 @@ class Dataset:
|
||||
Light-weight wrapper to hold a dataset.
|
||||
"""
|
||||
|
||||
def __init__(self, data: List[Dict[str, str]], text_key: str = "text"):
|
||||
self._text_key = text_key
|
||||
self._data = data
|
||||
def __init__(
|
||||
self,
|
||||
data: List[Dict[str, str]],
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
text_key: str = "text",
|
||||
):
|
||||
self._data = [tokenizer.encode(d[text_key]) for d in data]
|
||||
for d in self._data:
|
||||
if d[-1] != tokenizer.eos_token_id:
|
||||
d.append(tokenizer.eos_token_id)
|
||||
|
||||
def __getitem__(self, idx: int):
|
||||
return self._data[idx][self._text_key]
|
||||
return self._data[idx]
|
||||
|
||||
def __len__(self):
|
||||
if self._data is None:
|
||||
return 0
|
||||
return len(self._data)
|
||||
|
||||
|
||||
class ChatDataset(Dataset):
|
||||
class ChatDataset:
|
||||
"""
|
||||
A dataset for chat data in the format of {"messages": [...]}
|
||||
https://platform.openai.com/docs/guides/fine-tuning/example-format
|
||||
"""
|
||||
|
||||
def __init__(self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer):
|
||||
super().__init__(data)
|
||||
self._tokenizer = tokenizer
|
||||
self._data = [
|
||||
tokenizer.apply_chat_template(
|
||||
d["messages"],
|
||||
tools=d.get("tools", None),
|
||||
)
|
||||
for d in data
|
||||
]
|
||||
|
||||
def __getitem__(self, idx: int):
|
||||
messages = self._data[idx]["messages"]
|
||||
text = self._tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tools=self._data[idx].get("tools", None),
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
return text
|
||||
return self._data[idx]
|
||||
|
||||
def __len__(self):
|
||||
return len(self._data)
|
||||
|
||||
|
||||
class CompletionsDataset(Dataset):
|
||||
class CompletionsDataset:
|
||||
"""
|
||||
A dataset for prompt-completion data in the format of {"prompt": ..., "completion": ...}
|
||||
or using user-provided keys for prompt and completion values
|
||||
@@ -55,36 +61,41 @@ class CompletionsDataset(Dataset):
|
||||
self,
|
||||
data: List[Dict[str, str]],
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
prompt_key: str = "prompt",
|
||||
completion_key: str = "completion",
|
||||
prompt_key: str,
|
||||
completion_key: str,
|
||||
):
|
||||
super().__init__(data)
|
||||
self._tokenizer = tokenizer
|
||||
self._prompt_key = prompt_key
|
||||
self._completion_key = completion_key
|
||||
self._data = [
|
||||
tokenizer.apply_chat_template(
|
||||
[
|
||||
{"role": "user", "content": d[prompt_key]},
|
||||
{"role": "assistant", "content": d[completion_key]},
|
||||
],
|
||||
)
|
||||
for d in data
|
||||
]
|
||||
|
||||
def __getitem__(self, idx: int):
|
||||
data = self._data[idx]
|
||||
text = self._tokenizer.apply_chat_template(
|
||||
[
|
||||
{"role": "user", "content": data[self._prompt_key]},
|
||||
{"role": "assistant", "content": data[self._completion_key]},
|
||||
],
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
return text
|
||||
return self._data[idx]
|
||||
|
||||
def __len__(self):
|
||||
return len(self._data)
|
||||
|
||||
|
||||
def create_dataset(data, tokenizer: PreTrainedTokenizer = None):
|
||||
def create_dataset(
|
||||
data,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
prompt_feature: Optional[str] = None,
|
||||
completion_feature: Optional[str] = None,
|
||||
):
|
||||
prompt_feature = prompt_feature or "prompt"
|
||||
completion_feature = completion_feature or "completion"
|
||||
sample = data[0]
|
||||
|
||||
if "messages" in sample:
|
||||
return ChatDataset(data, tokenizer)
|
||||
elif "prompt" in sample and "completion" in sample:
|
||||
return CompletionsDataset(data, tokenizer)
|
||||
elif prompt_feature in sample and completion_feature in sample:
|
||||
return CompletionsDataset(data, tokenizer, prompt_feature, completion_feature)
|
||||
elif "text" in sample:
|
||||
return Dataset(data)
|
||||
return Dataset(data, tokenizer)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unsupported data format, check the supported formats here:\n"
|
||||
@@ -92,20 +103,30 @@ def create_dataset(data, tokenizer: PreTrainedTokenizer = None):
|
||||
)
|
||||
|
||||
|
||||
def load_local_dataset(data_path: Path, tokenizer: PreTrainedTokenizer):
|
||||
def load_local_dataset(
|
||||
data_path: Path,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
prompt_feature: Optional[str] = None,
|
||||
completion_feature: Optional[str] = None,
|
||||
):
|
||||
def load_subset(path):
|
||||
if not path.exists():
|
||||
return []
|
||||
with open(path, "r") as fid:
|
||||
data = [json.loads(l) for l in fid]
|
||||
return create_dataset(data, tokenizer)
|
||||
return create_dataset(data, tokenizer, prompt_feature, completion_feature)
|
||||
|
||||
names = ("train", "valid", "test")
|
||||
train, valid, test = [load_subset(data_path / f"{n}.jsonl") for n in names]
|
||||
return train, valid, test
|
||||
|
||||
|
||||
def load_hf_dataset(data_id: str, tokenizer: PreTrainedTokenizer):
|
||||
def load_hf_dataset(
|
||||
data_id: str,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
prompt_feature: Optional[str] = None,
|
||||
completion_feature: Optional[str] = None,
|
||||
):
|
||||
from datasets import exceptions, load_dataset
|
||||
|
||||
try:
|
||||
@@ -114,7 +135,13 @@ def load_hf_dataset(data_id: str, tokenizer: PreTrainedTokenizer):
|
||||
names = ("train", "valid", "test")
|
||||
|
||||
train, valid, test = [
|
||||
create_dataset(dataset[n], tokenizer) if n in dataset.keys() else []
|
||||
(
|
||||
create_dataset(
|
||||
dataset[n], tokenizer, prompt_feature, completion_feature
|
||||
)
|
||||
if n in dataset.keys()
|
||||
else []
|
||||
)
|
||||
for n in names
|
||||
]
|
||||
|
||||
@@ -143,7 +170,7 @@ def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
|
||||
if prompt_feature and completion_feature:
|
||||
return CompletionsDataset(ds, tokenizer, prompt_feature, completion_feature)
|
||||
elif text_feature:
|
||||
return Dataset(train_ds, text_key=text_feature)
|
||||
return Dataset(ds, tokenizer, text_key=text_feature)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Specify either a prompt and completion feature or a text "
|
||||
@@ -166,15 +193,22 @@ def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
|
||||
|
||||
|
||||
def load_dataset(args, tokenizer: PreTrainedTokenizer):
|
||||
if getattr(args, "hf_dataset", None) is not None:
|
||||
if getattr(args, "hf_dataset", False):
|
||||
train, valid, test = load_custom_hf_dataset(args, tokenizer)
|
||||
else:
|
||||
data_path = Path(args.data)
|
||||
|
||||
prompt_feature = getattr(args, "prompt_feature", None)
|
||||
completion_feature = getattr(args, "completion_feature", None)
|
||||
if data_path.exists():
|
||||
train, valid, test = load_local_dataset(data_path, tokenizer)
|
||||
train, valid, test = load_local_dataset(
|
||||
data_path, tokenizer, prompt_feature, completion_feature
|
||||
)
|
||||
else:
|
||||
print(f"Loading Hugging Face dataset {args.data}.")
|
||||
train, valid, test = load_hf_dataset(args.data, tokenizer)
|
||||
train, valid, test = load_hf_dataset(
|
||||
args.data, tokenizer, prompt_feature, completion_feature
|
||||
)
|
||||
|
||||
if args.train and len(train) == 0:
|
||||
raise ValueError(
|
||||
|
||||
@@ -100,14 +100,8 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
|
||||
while True:
|
||||
indices = np.random.permutation(len(batch_idx))
|
||||
for i in indices:
|
||||
# Encode batch
|
||||
batch = [tokenizer.encode(dataset[j]) for j in batch_idx[i]]
|
||||
for b in batch:
|
||||
if b[-1] != tokenizer.eos_token_id:
|
||||
b.append(tokenizer.eos_token_id)
|
||||
|
||||
batch = [dataset[j] for j in batch_idx[i]]
|
||||
lengths = [len(x) for x in batch]
|
||||
|
||||
if max(lengths) > max_seq_length:
|
||||
print(
|
||||
f"[WARNING] Some sequences are longer than {max_seq_length} tokens. "
|
||||
@@ -146,8 +140,8 @@ def evaluate(
|
||||
loss: callable = default_loss,
|
||||
iterate_batches: callable = iterate_batches,
|
||||
):
|
||||
all_losses = 0
|
||||
ntokens = 0
|
||||
all_losses = mx.array(0.0)
|
||||
ntokens = mx.array(0)
|
||||
|
||||
index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
|
||||
|
||||
@@ -165,8 +159,8 @@ def evaluate(
|
||||
ntokens += toks
|
||||
mx.eval(all_losses, ntokens)
|
||||
|
||||
all_losses = mx.distributed.all_sum(all_losses)
|
||||
ntokens = mx.distributed.all_sum(ntokens)
|
||||
all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu)
|
||||
ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu)
|
||||
|
||||
return (all_losses / ntokens).item()
|
||||
|
||||
@@ -278,9 +272,9 @@ def train(
|
||||
if it % args.steps_per_report == 0 or it == args.iters:
|
||||
stop = time.perf_counter()
|
||||
|
||||
train_loss = mx.distributed.all_sum(losses).item()
|
||||
train_loss = mx.distributed.all_sum(losses, stream=mx.cpu).item()
|
||||
train_loss /= steps * mx.distributed.init().size()
|
||||
n_tokens = mx.distributed.all_sum(n_tokens).item()
|
||||
n_tokens = mx.distributed.all_sum(n_tokens, stream=mx.cpu).item()
|
||||
learning_rate = optimizer.learning_rate.item()
|
||||
it_sec = args.steps_per_report / (stop - start)
|
||||
tokens_sec = float(n_tokens) / (stop - start)
|
||||
|
||||
@@ -94,12 +94,14 @@ def linear_to_lora_layers(
|
||||
"phimoe",
|
||||
"gemma",
|
||||
"gemma2",
|
||||
"helium",
|
||||
"starcoder2",
|
||||
"cohere",
|
||||
"cohere2",
|
||||
"minicpm",
|
||||
"deepseek",
|
||||
"olmo2",
|
||||
"internlm3",
|
||||
]:
|
||||
keys = set(["self_attn.q_proj", "self_attn.v_proj"])
|
||||
if model.model_type in ["mixtral", "phimoe"]:
|
||||
|
||||
@@ -2,10 +2,12 @@
|
||||
|
||||
import contextlib
|
||||
import copy
|
||||
import functools
|
||||
import glob
|
||||
import importlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
@@ -15,7 +17,17 @@ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type,
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
if os.getenv("MLXLM_USE_MODELSCOPE", "False").lower() == "true":
|
||||
try:
|
||||
from modelscope import snapshot_download
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Please run `pip install modelscope` to activate the ModelScope."
|
||||
)
|
||||
else:
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from mlx.utils import tree_flatten, tree_reduce
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
@@ -153,11 +165,12 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path
|
||||
Path: The path to the model.
|
||||
"""
|
||||
model_path = Path(path_or_hf_repo)
|
||||
|
||||
if not model_path.exists():
|
||||
try:
|
||||
model_path = Path(
|
||||
snapshot_download(
|
||||
repo_id=path_or_hf_repo,
|
||||
path_or_hf_repo,
|
||||
revision=revision,
|
||||
allow_patterns=[
|
||||
"*.json",
|
||||
@@ -207,12 +220,6 @@ def generate_step(
|
||||
kv_group_size: int = 64,
|
||||
quantized_kv_start: int = 0,
|
||||
prompt_progress_callback: Optional[Callable[int, int]] = None,
|
||||
temp: Optional[float] = None,
|
||||
repetition_penalty: Optional[float] = None,
|
||||
repetition_context_size: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
min_p: Optional[float] = None,
|
||||
min_tokens_to_keep: Optional[int] = None,
|
||||
) -> Generator[Tuple[mx.array, mx.array], None, None]:
|
||||
"""
|
||||
A generator producing token ids based on the given prompt from the model.
|
||||
@@ -256,25 +263,17 @@ def generate_step(
|
||||
elif len(prompt_cache) != len(model.layers):
|
||||
raise ValueError("Wrong number of layers in the prompt cache.")
|
||||
|
||||
if temp is not None or top_p is not None or min_tokens_to_keep is not None:
|
||||
print(
|
||||
"[Warning] Specifying sampling arguments to ``generate_step`` is "
|
||||
"deprecated. Pass in a ``sampler`` instead."
|
||||
)
|
||||
if repetition_penalty is not None:
|
||||
print(
|
||||
"[Warning] Specifying ``repetition_penalty`` is deprecated. "
|
||||
"Pass in ``logits_processors`` instead."
|
||||
)
|
||||
|
||||
sampler = sampler or make_sampler(
|
||||
temp or 0.0, top_p or 0.0, min_p or 0.0, min_tokens_to_keep or 1
|
||||
)
|
||||
logits_processors = logits_processors or make_logits_processors(
|
||||
None, repetition_penalty, repetition_context_size or 20
|
||||
)
|
||||
prompt_progress_callback = prompt_progress_callback or (lambda *_: None)
|
||||
|
||||
quantize_cache_fn = functools.partial(
|
||||
maybe_quantize_kv_cache,
|
||||
quantized_kv_start=quantized_kv_start,
|
||||
kv_group_size=kv_group_size,
|
||||
kv_bits=kv_bits,
|
||||
)
|
||||
|
||||
sampler = sampler or (lambda x: mx.argmax(x, axis=-1))
|
||||
|
||||
def _step(y):
|
||||
with mx.stream(generation_stream):
|
||||
logits = model(y[None], cache=prompt_cache)
|
||||
@@ -287,9 +286,7 @@ def generate_step(
|
||||
for processor in logits_processors:
|
||||
logits = processor(tokens, logits)
|
||||
|
||||
maybe_quantize_kv_cache(
|
||||
prompt_cache, quantized_kv_start, kv_group_size, kv_bits
|
||||
)
|
||||
quantize_cache_fn(prompt_cache)
|
||||
|
||||
logprobs = logits - mx.logsumexp(logits, keepdims=True)
|
||||
y = sampler(logprobs)
|
||||
@@ -300,9 +297,7 @@ def generate_step(
|
||||
prompt_processed_tokens = 0
|
||||
while y.size > prefill_step_size:
|
||||
model(y[:prefill_step_size][None], cache=prompt_cache)
|
||||
maybe_quantize_kv_cache(
|
||||
prompt_cache, quantized_kv_start, kv_group_size, kv_bits
|
||||
)
|
||||
quantize_cache_fn(prompt_cache)
|
||||
mx.eval([c.state for c in prompt_cache])
|
||||
prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens)
|
||||
prompt_processed_tokens += prefill_step_size
|
||||
@@ -311,12 +306,12 @@ def generate_step(
|
||||
|
||||
y, logprobs = _step(y)
|
||||
|
||||
mx.async_eval(y, logprobs)
|
||||
mx.eval(y, logprobs)
|
||||
n = 0
|
||||
while True:
|
||||
if n != max_tokens:
|
||||
next_y, next_logprobs = _step(y)
|
||||
mx.async_eval(next_y, next_logprobs)
|
||||
mx.eval(next_y, next_logprobs)
|
||||
if n == 0:
|
||||
mx.eval(y)
|
||||
prompt_progress_callback(total_prompt_tokens, total_prompt_tokens)
|
||||
@@ -329,10 +324,163 @@ def generate_step(
|
||||
n += 1
|
||||
|
||||
|
||||
def speculative_generate_step(
|
||||
prompt: mx.array,
|
||||
model: nn.Module,
|
||||
draft_model: nn.Module,
|
||||
*,
|
||||
num_draft_tokens=2,
|
||||
max_tokens: int = 256,
|
||||
sampler: Optional[Callable[mx.array, mx.array]] = None,
|
||||
logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
|
||||
prompt_cache: Optional[Any] = None,
|
||||
prefill_step_size: int = 512,
|
||||
kv_bits: Optional[int] = None,
|
||||
kv_group_size: int = 64,
|
||||
quantized_kv_start: int = 0,
|
||||
) -> Generator[Tuple[mx.array, mx.array], None, None]:
|
||||
"""
|
||||
A generator producing token ids based on the given prompt from the model.
|
||||
|
||||
Args:
|
||||
prompt (mx.array): The input prompt.
|
||||
model (nn.Module): The model to use for generation.
|
||||
draft_model (nn.Module): The draft model for speculative decoding.
|
||||
num_draft_tokens (int, optional): The number of draft tokens for
|
||||
speculative decoding. Default: ``2``.
|
||||
max_tokens (int): The maximum number of tokens. Use``-1`` for an infinite
|
||||
generator. Default: ``256``.
|
||||
sampler (Callable[mx.array, mx.array], optional): A sampler for sampling a
|
||||
token from a vector of log probabilities. Default: ``None``.
|
||||
logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
|
||||
A list of functions that take tokens and logits and return the processed
|
||||
logits. Default: ``None``.
|
||||
prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
|
||||
provided, the cache will be updated in place. The cache must be trimmable.
|
||||
prefill_step_size (int): Step size for processing the prompt.
|
||||
kv_bits (int, optional): Number of bits to use for KV cache quantization.
|
||||
None implies no cache quantization. Default: ``None``.
|
||||
kv_group_size (int): Group size for KV cache quantization. Default: ``64``.
|
||||
quantized_kv_start (int): Step to begin using a quantized KV cache.
|
||||
when ``kv_bits`` is non-None. Default: ``0``.
|
||||
|
||||
Yields:
|
||||
Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
|
||||
"""
|
||||
|
||||
y = prompt
|
||||
tokens = None
|
||||
|
||||
# Create the KV cache for generation
|
||||
if prompt_cache is None:
|
||||
model_cache = cache.make_prompt_cache(model)
|
||||
draft_cache = cache.make_prompt_cache(draft_model)
|
||||
elif len(prompt_cache) != (len(model.layers) + len(draft_model.layers)):
|
||||
raise ValueError("Wrong number of layers in the prompt cache.")
|
||||
else:
|
||||
model_cache = prompt_cache[: len(model.layers)]
|
||||
draft_cache = prompt_cache[len(model.layers) :]
|
||||
|
||||
sampler = sampler or (lambda x: mx.argmax(x, axis=-1))
|
||||
|
||||
quantize_cache_fn = functools.partial(
|
||||
maybe_quantize_kv_cache,
|
||||
quantized_kv_start=quantized_kv_start,
|
||||
kv_group_size=kv_group_size,
|
||||
kv_bits=kv_bits,
|
||||
)
|
||||
|
||||
def _step(model, cache, y, n_predict=1):
|
||||
with mx.stream(generation_stream):
|
||||
logits = model(y[None], cache=cache)
|
||||
logits = logits[:, -n_predict:, :]
|
||||
|
||||
quantize_cache_fn(cache)
|
||||
|
||||
logprobs = logits - mx.logsumexp(logits, keepdims=True)
|
||||
logprobs = logprobs.squeeze(0)
|
||||
y = sampler(logprobs)
|
||||
return y, logprobs
|
||||
|
||||
def _prefill(model, cache, y):
|
||||
while y.size > prefill_step_size:
|
||||
model(y[:prefill_step_size][None], cache=cache)
|
||||
quantize_cache_fn(cache)
|
||||
mx.eval([c.state for c in cache])
|
||||
y = y[prefill_step_size:]
|
||||
mx.metal.clear_cache()
|
||||
return y
|
||||
|
||||
def _rewind_cache(num_draft, num_accept):
|
||||
cache.trim_prompt_cache(model_cache, num_draft - num_accept)
|
||||
cache.trim_prompt_cache(draft_cache, max(num_draft - num_accept - 1, 0))
|
||||
|
||||
def _draft_generate(y, num_draft):
|
||||
if num_draft == 0:
|
||||
return mx.array([], mx.uint32)
|
||||
ys = []
|
||||
for _ in range(num_draft):
|
||||
y, _ = _step(draft_model, draft_cache, y)
|
||||
mx.async_eval(y)
|
||||
ys.append(y)
|
||||
return mx.concatenate(ys)
|
||||
|
||||
with mx.stream(generation_stream):
|
||||
draft_y = _prefill(draft_model, draft_cache, y)
|
||||
y = _prefill(model, model_cache, y)
|
||||
|
||||
ntoks = 0
|
||||
# Set these so the finally block doesn't raise
|
||||
num_draft = 0
|
||||
n = 0
|
||||
try:
|
||||
while True:
|
||||
num_draft = min(max_tokens - ntoks, num_draft_tokens)
|
||||
draft_tokens = _draft_generate(draft_y, num_draft)
|
||||
y = mx.concatenate([y, draft_tokens])
|
||||
|
||||
tokens, logprobs = _step(model, model_cache, y, num_draft + 1)
|
||||
mx.eval(tokens, draft_tokens)
|
||||
draft_tokens = draft_tokens.tolist()
|
||||
tokens = tokens.tolist()
|
||||
n = 0
|
||||
while n < num_draft:
|
||||
tn, dtn, lpn = tokens[n], draft_tokens[n], logprobs[n]
|
||||
if tn != dtn:
|
||||
break
|
||||
n += 1
|
||||
ntoks += 1
|
||||
yield tn, lpn
|
||||
if ntoks == max_tokens:
|
||||
break
|
||||
if ntoks < max_tokens:
|
||||
ntoks += 1
|
||||
yield tokens[n], logprobs[n]
|
||||
|
||||
if ntoks == max_tokens:
|
||||
break
|
||||
|
||||
y = mx.array([tokens[n]], mx.uint32)
|
||||
draft_y = y
|
||||
|
||||
# If we accpeted all the draft tokens, include the last
|
||||
# draft token in the next draft step since it hasn't been
|
||||
# processed yet by the draft model
|
||||
if n == num_draft:
|
||||
draft_y = mx.concatenate(
|
||||
[mx.array(draft_tokens[-1:], mx.uint32), draft_y]
|
||||
)
|
||||
|
||||
_rewind_cache(num_draft, n)
|
||||
finally:
|
||||
_rewind_cache(num_draft, n)
|
||||
|
||||
|
||||
def stream_generate(
|
||||
model: nn.Module,
|
||||
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
|
||||
prompt: Union[str, mx.array, List[int]],
|
||||
draft_model: Optional[nn.Module] = None,
|
||||
**kwargs,
|
||||
) -> Generator[GenerationResponse, None, None]:
|
||||
"""
|
||||
@@ -341,7 +489,11 @@ def stream_generate(
|
||||
Args:
|
||||
model (nn.Module): The model to use for generation.
|
||||
tokenizer (PreTrainedTokenizer): The tokenizer.
|
||||
prompt (Union[str, mx.array, List[int]]): The input prompt string or integer tokens.
|
||||
prompt (Union[str, mx.array, List[int]]): The input prompt string or
|
||||
integer tokens.
|
||||
draft_model (Optional[nn.Module]): An optional draft model. If provided
|
||||
then speculative decoding is used. The draft model must use the same
|
||||
tokenizer as the main model. Default: ``None``.
|
||||
kwargs: The remaining options get passed to :func:`generate_step`.
|
||||
See :func:`generate_step` for more details.
|
||||
|
||||
@@ -353,16 +505,28 @@ def stream_generate(
|
||||
tokenizer = TokenizerWrapper(tokenizer)
|
||||
|
||||
if not isinstance(prompt, mx.array):
|
||||
prompt = mx.array(
|
||||
prompt if isinstance(prompt, list) else tokenizer.encode(prompt)
|
||||
)
|
||||
if isinstance(prompt, str):
|
||||
# Try to infer if special tokens are needed
|
||||
add_special_tokens = tokenizer.bos_token is None or not prompt.startswith(
|
||||
tokenizer.bos_token
|
||||
)
|
||||
prompt = tokenizer.encode(prompt, add_special_tokens=add_special_tokens)
|
||||
prompt = mx.array(prompt)
|
||||
|
||||
detokenizer = tokenizer.detokenizer
|
||||
|
||||
if draft_model is None:
|
||||
kwargs.pop("num_draft_tokens", None)
|
||||
token_generator = generate_step(prompt, model, **kwargs)
|
||||
else:
|
||||
kwargs.pop("max_kv_size", None)
|
||||
token_generator = speculative_generate_step(
|
||||
prompt, model, draft_model, **kwargs
|
||||
)
|
||||
with wired_limit(model, [generation_stream]):
|
||||
detokenizer.reset()
|
||||
tic = time.perf_counter()
|
||||
for n, (token, logprobs) in enumerate(generate_step(prompt, model, **kwargs)):
|
||||
for n, (token, logprobs) in enumerate(token_generator):
|
||||
if n == 0:
|
||||
prompt_time = time.perf_counter() - tic
|
||||
prompt_tps = prompt.size / prompt_time
|
||||
@@ -401,7 +565,7 @@ def stream_generate(
|
||||
def generate(
|
||||
model: nn.Module,
|
||||
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
|
||||
prompt: str,
|
||||
prompt: Union[str, List[int]],
|
||||
verbose: bool = False,
|
||||
formatter: Optional[Callable] = None,
|
||||
**kwargs,
|
||||
@@ -412,7 +576,7 @@ def generate(
|
||||
Args:
|
||||
model (nn.Module): The language model.
|
||||
tokenizer (PreTrainedTokenizer): The tokenizer.
|
||||
prompt (str): The string prompt.
|
||||
prompt (Union[str, List[int]]): The input prompt string or integer tokens.
|
||||
verbose (bool): If ``True``, print tokens and timing information.
|
||||
Default: ``False``.
|
||||
kwargs: The remaining options get passed to :func:`stream_generate`.
|
||||
@@ -425,7 +589,6 @@ def generate(
|
||||
)
|
||||
if verbose:
|
||||
print("=" * 10)
|
||||
print("Prompt:", prompt)
|
||||
|
||||
text = ""
|
||||
for response in stream_generate(model, tokenizer, prompt, **kwargs):
|
||||
@@ -464,6 +627,8 @@ def load_config(model_path: Path) -> dict:
|
||||
def load_model(
|
||||
model_path: Path,
|
||||
lazy: bool = False,
|
||||
strict: bool = True,
|
||||
sequential_load: bool = False,
|
||||
model_config: dict = {},
|
||||
get_model_classes: Callable[[dict], Tuple[Type[nn.Module], Type]] = _get_classes,
|
||||
) -> nn.Module:
|
||||
@@ -475,6 +640,8 @@ def load_model(
|
||||
lazy (bool): If False eval the model parameters to make sure they are
|
||||
loaded in memory before returning, otherwise they will be loaded
|
||||
when needed. Default: ``False``
|
||||
strict (bool): Whether or not to raise an exception if weights don't
|
||||
match. Default: ``True``
|
||||
model_config (dict, optional): Optional configuration parameters for the
|
||||
model. Defaults to an empty dictionary.
|
||||
get_model_classes (Callable[[dict], Tuple[Type[nn.Module], Type]], optional):
|
||||
@@ -497,7 +664,7 @@ def load_model(
|
||||
# Try weight for back-compat
|
||||
weight_files = glob.glob(str(model_path / "weight*.safetensors"))
|
||||
|
||||
if not weight_files:
|
||||
if not weight_files and strict:
|
||||
logging.error(f"No safetensors found in {model_path}")
|
||||
raise FileNotFoundError(f"No safetensors found in {model_path}")
|
||||
|
||||
@@ -531,9 +698,18 @@ def load_model(
|
||||
class_predicate=class_predicate,
|
||||
)
|
||||
|
||||
model.load_weights(list(weights.items()))
|
||||
model.load_weights(list(weights.items()), strict=strict)
|
||||
|
||||
if mx.distributed.init().size() > 1:
|
||||
if not hasattr(model, "shard"):
|
||||
raise RuntimeError("Model doesn't support distributed inference.")
|
||||
model.shard()
|
||||
|
||||
if not lazy:
|
||||
weights.clear()
|
||||
if sequential_load:
|
||||
for layer in model.layers:
|
||||
mx.eval(layer.parameters())
|
||||
mx.eval(model.parameters())
|
||||
|
||||
model.eval()
|
||||
@@ -546,6 +722,7 @@ def load(
|
||||
model_config={},
|
||||
adapter_path: Optional[str] = None,
|
||||
lazy: bool = False,
|
||||
sequential_load: bool = False,
|
||||
) -> Tuple[nn.Module, TokenizerWrapper]:
|
||||
"""
|
||||
Load the model and tokenizer from a given path or a huggingface repository.
|
||||
@@ -558,9 +735,11 @@ def load(
|
||||
Defaults to an empty dictionary.
|
||||
adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers
|
||||
to the model. Default: ``None``.
|
||||
lazy (bool): If False eval the model parameters to make sure they are
|
||||
lazy (bool): If ``False`` eval the model parameters to make sure they are
|
||||
loaded in memory before returning, otherwise they will be loaded
|
||||
when needed. Default: ``False``
|
||||
sequential_load (bool): If True then load each layer sequentially to
|
||||
ensure that we are not wasting memory.
|
||||
Returns:
|
||||
Tuple[nn.Module, TokenizerWrapper]: A tuple containing the loaded model and tokenizer.
|
||||
|
||||
@@ -570,7 +749,7 @@ def load(
|
||||
"""
|
||||
model_path = get_model_path(path_or_hf_repo)
|
||||
|
||||
model, config = load_model(model_path, lazy)
|
||||
model, config = load_model(model_path, lazy=lazy, sequential_load=sequential_load)
|
||||
if adapter_path is not None:
|
||||
model = load_adapters(model, adapter_path)
|
||||
model.eval()
|
||||
@@ -584,7 +763,7 @@ def load(
|
||||
def fetch_from_hub(
|
||||
model_path: Path, lazy: bool = False
|
||||
) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:
|
||||
model, config = load_model(model_path, lazy)
|
||||
model, config = load_model(model_path, lazy=lazy)
|
||||
tokenizer = load_tokenizer(
|
||||
model_path, eos_token_ids=config.get("eos_token_id", None)
|
||||
)
|
||||
@@ -652,12 +831,12 @@ def upload_to_hub(path: str, upload_repo: str, hf_path: str):
|
||||
|
||||
model, tokenizer = load("{upload_repo}")
|
||||
|
||||
prompt="hello"
|
||||
prompt = "hello"
|
||||
|
||||
if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template is not None:
|
||||
if tokenizer.chat_template is not None:
|
||||
messages = [{{"role": "user", "content": prompt}}]
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
messages, add_generation_prompt=True
|
||||
)
|
||||
|
||||
response = generate(model, tokenizer, prompt=prompt, verbose=True)
|
||||
@@ -670,12 +849,10 @@ def upload_to_hub(path: str, upload_repo: str, hf_path: str):
|
||||
|
||||
api = HfApi()
|
||||
api.create_repo(repo_id=upload_repo, exist_ok=True)
|
||||
api.upload_folder(
|
||||
api.upload_large_folder(
|
||||
folder_path=path,
|
||||
repo_id=upload_repo,
|
||||
repo_type="model",
|
||||
multi_commits=True,
|
||||
multi_commits_verbose=True,
|
||||
)
|
||||
print(f"Upload successful, go to https://huggingface.co/{upload_repo} for details.")
|
||||
|
||||
|
||||
@@ -27,8 +27,8 @@ setup(
|
||||
packages=["mlx_lm", "mlx_lm.models", "mlx_lm.tuner"],
|
||||
python_requires=">=3.8",
|
||||
extras_require={
|
||||
"testing": ["datasets"],
|
||||
"evaluation": ["lm-eval"],
|
||||
"test": ["datasets"],
|
||||
"evaluate": ["lm-eval", "tqdm"],
|
||||
},
|
||||
entry_points={
|
||||
"console_scripts": [
|
||||
|
||||
@@ -36,7 +36,8 @@ class TestDatasets(unittest.TestCase):
|
||||
data = {"text": "This is an example for the model."}
|
||||
self.save_data(4 * [data])
|
||||
args = types.SimpleNamespace(train=True, test=False, data=self.test_dir)
|
||||
train, valid, test = datasets.load_dataset(args, None)
|
||||
tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_PATH)
|
||||
train, valid, test = datasets.load_dataset(args, tokenizer)
|
||||
self.assertEqual(len(train), 4)
|
||||
self.assertEqual(len(valid), 4)
|
||||
self.assertEqual(len(test), 0)
|
||||
@@ -82,6 +83,8 @@ class TestDatasets(unittest.TestCase):
|
||||
"name": "billsum",
|
||||
"prompt_feature": "text",
|
||||
"completion_feature": "summary",
|
||||
"train_split": "train[:2%]",
|
||||
"valid_split": "train[-2%:]",
|
||||
},
|
||||
test=False,
|
||||
train=True,
|
||||
|
||||
@@ -21,7 +21,7 @@ from mlx_lm.tuner.utils import build_schedule
|
||||
@contextmanager
|
||||
def swapped_with_identity(obj, func):
|
||||
old_func = getattr(obj, func)
|
||||
setattr(obj, func, lambda x: x)
|
||||
setattr(obj, func, lambda x, **kwargs: x)
|
||||
yield
|
||||
setattr(obj, func, old_func)
|
||||
|
||||
|
||||
@@ -682,6 +682,43 @@ class TestModels(unittest.TestCase):
|
||||
model, args.model_type, args.vocab_size, args.num_hidden_layers
|
||||
)
|
||||
|
||||
def test_deepseek_v3(self):
|
||||
from mlx_lm.models import deepseek_v3
|
||||
|
||||
args = deepseek_v3.ModelArgs(
|
||||
model_type="deepseek_v3",
|
||||
vocab_size=1024,
|
||||
hidden_size=128,
|
||||
intermediate_size=256,
|
||||
moe_intermediate_size=256,
|
||||
num_hidden_layers=4,
|
||||
num_attention_heads=4,
|
||||
num_key_value_heads=2,
|
||||
n_routed_experts=4,
|
||||
n_group=2,
|
||||
topk_group=1,
|
||||
num_experts_per_tok=2,
|
||||
n_shared_experts=1,
|
||||
kv_lora_rank=4,
|
||||
q_lora_rank=4,
|
||||
qk_rope_head_dim=32,
|
||||
v_head_dim=16,
|
||||
qk_nope_head_dim=32,
|
||||
rope_scaling={
|
||||
"beta_fast": 32,
|
||||
"beta_slow": 1,
|
||||
"factor": 40,
|
||||
"mscale": 1.0,
|
||||
"mscale_all_dim": 1.0,
|
||||
"original_max_position_embeddings": 4096,
|
||||
"type": "yarn",
|
||||
},
|
||||
)
|
||||
model = deepseek_v3.Model(args)
|
||||
self.model_test_runner(
|
||||
model, args.model_type, args.vocab_size, args.num_hidden_layers
|
||||
)
|
||||
|
||||
def test_gemma2(self):
|
||||
from mlx_lm.models import gemma2
|
||||
|
||||
@@ -890,6 +927,23 @@ class TestModels(unittest.TestCase):
|
||||
model, args.model_type, args.vocab_size, args.num_hidden_layers
|
||||
)
|
||||
|
||||
def test_internlm3(self):
|
||||
from mlx_lm.models import internlm3
|
||||
|
||||
args = internlm3.ModelArgs(
|
||||
model_type="internlm3",
|
||||
hidden_size=1024,
|
||||
num_hidden_layers=4,
|
||||
intermediate_size=2048,
|
||||
num_attention_heads=4,
|
||||
rms_norm_eps=1e-5,
|
||||
vocab_size=10_000,
|
||||
)
|
||||
model = internlm3.Model(args)
|
||||
self.model_test_runner(
|
||||
model, args.model_type, args.vocab_size, args.num_hidden_layers
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -28,6 +28,12 @@ class TestSampleUtils(unittest.TestCase):
|
||||
token = top_p_sampling(logits, 0.95, temperature).item()
|
||||
self.assertTrue(token in (1, 2, 3))
|
||||
|
||||
# Batch mode works
|
||||
probs = mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.8, 0.0, 0.1]])
|
||||
logits = mx.log(probs)
|
||||
tokens = top_p_sampling(logits, 0.5, temperature)
|
||||
self.assertEqual(tokens.tolist(), [0, 1])
|
||||
|
||||
def test_min_p_sampling(self):
|
||||
probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
|
||||
logits = mx.log(probs)
|
||||
@@ -42,6 +48,12 @@ class TestSampleUtils(unittest.TestCase):
|
||||
token = min_p_sampling(logits, 0.05)
|
||||
self.assertTrue(token in (0, 3))
|
||||
|
||||
# Batch mode works
|
||||
probs = mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.8, 0.0, 0.1]])
|
||||
logits = mx.log(probs)
|
||||
tokens = min_p_sampling(logits, 0.7)
|
||||
self.assertEqual(tokens.tolist(), [0, 1])
|
||||
|
||||
def test_top_k_sampling(self):
|
||||
probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
|
||||
logits = mx.log(probs)
|
||||
|
||||
@@ -80,6 +80,29 @@ class TestServer(unittest.TestCase):
|
||||
self.assertIn("id", response_body)
|
||||
self.assertIn("choices", response_body)
|
||||
|
||||
def test_handle_chat_completions_with_content_fragments(self):
|
||||
url = f"http://localhost:{self.port}/v1/chat/completions"
|
||||
chat_post_data = {
|
||||
"model": "chat_model",
|
||||
"max_tokens": 10,
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.85,
|
||||
"repetition_penalty": 1.2,
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{"type": "text", "text": "You are a helpful assistant."}
|
||||
],
|
||||
},
|
||||
{"role": "user", "content": [{"type": "text", "text": "Hello!"}]},
|
||||
],
|
||||
}
|
||||
response = requests.post(url, json=chat_post_data)
|
||||
response_body = response.text
|
||||
self.assertIn("id", response_body)
|
||||
self.assertIn("choices", response_body)
|
||||
|
||||
def test_handle_models(self):
|
||||
url = f"http://localhost:{self.port}/v1/models"
|
||||
response = requests.get(url)
|
||||
|
||||
@@ -17,7 +17,7 @@ class TestLoadModelCustomGetClasses(unittest.TestCase):
|
||||
self.config = args
|
||||
self.custom_attribute = "This is a custom model"
|
||||
|
||||
def load_weights(self, weights):
|
||||
def load_weights(self, weights, **kwargs):
|
||||
self.qwenWeights = weights
|
||||
|
||||
class CustomQwenConfig:
|
||||
|
||||
Reference in New Issue
Block a user