Merge branch 'ml-explore:main' into adding-support-for-mamba2

This commit is contained in:
Gökdeniz Gülmez
2025-01-13 20:16:04 +01:00
committed by GitHub
59 changed files with 1272 additions and 279 deletions

View File

@@ -32,7 +32,7 @@ jobs:
pip install --upgrade pip
pip install unittest-xml-reporting
cd llms/
pip install -e ".[testing]"
pip install -e ".[test]"
- run:
name: Run Python tests
command: |

View File

@@ -58,7 +58,7 @@ prompt = "Write a story about Einstein"
messages = [{"role": "user", "content": prompt}]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
messages, add_generation_prompt=True
)
text = generate(model, tokenizer, prompt=prompt, verbose=True)
@@ -115,7 +115,7 @@ prompt = "Write a story about Einstein"
messages = [{"role": "user", "content": prompt}]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
messages, add_generation_prompt=True
)
for response in stream_generate(model, tokenizer, prompt, max_tokens=512):

View File

@@ -266,14 +266,25 @@ Refer to the documentation for the model you are fine-tuning for more details.
{"prompt": "What is the capital of France?", "completion": "Paris."}
```
For the `completions` data format, a different key can be used for the prompt
and completion by specifying the following in the YAML config:
```yaml
prompt_feature: "input"
completion_feature: "output"
```
Here, `"input"` is the expected key instead of the default `"prompt"`, and
`"output"` is the expected key instead of `"completion"`.
`text`:
```jsonl
{"text": "This is an example for the model."}
```
Note, the format is automatically determined by the dataset. Note also, keys in
each line not expected by the loader will be ignored.
Note, the format is automatically determined by the dataset. Note also, keys
in each line not expected by the loader will be ignored.
> [!NOTE]
> Each example in the datasets must be on a single line. Do not put more than
@@ -295,7 +306,7 @@ Otherwise, provide a mapping of keys in the dataset to the features MLX LM
expects. Use a YAML config to specify the Hugging Face dataset arguments. For
example:
```
```yaml
hf_dataset:
name: "billsum"
prompt_feature: "text"

View File

@@ -1,3 +1,3 @@
# Copyright © 2023-2024 Apple Inc.
__version__ = "0.20.4"
__version__ = "0.21.0"

View File

@@ -110,29 +110,17 @@ def main():
if tokenizer.chat_template is None:
tokenizer.chat_template = tokenizer.default_chat_template
if not args.ignore_chat_template and (
hasattr(tokenizer, "apply_chat_template")
and tokenizer.chat_template is not None
):
if not args.ignore_chat_template and tokenizer.chat_template is not None:
messages = [{"role": "user", "content": args.prompt}]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
messages, add_generation_prompt=False, continue_final_message=True
)
# Treat the prompt as a prefix assuming that the suffix will be
# provided at generation time.
test_prompt = tokenizer.apply_chat_template(
[{"role": "user", "content": "<query>"}],
tokenize=False,
add_generation_prompt=True,
)
n = len(test_prompt) - test_prompt.index("<query>") - len("<query>")
prompt = prompt[:-n]
else:
prompt = args.prompt
prompt = tokenizer.encode(args.prompt)
cache = make_prompt_cache(model, args.max_kv_size)
y = mx.array(tokenizer.encode(prompt))
y = mx.array(prompt)
# Process the prompt
start = time.time()

View File

@@ -72,9 +72,7 @@ def main():
if query == "q":
break
messages = [{"role": "user", "content": query}]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
for response in stream_generate(
model,
tokenizer,

View File

@@ -1,4 +1,8 @@
# Adapted from a PyTorch implementation by David Grangier
# Copyright © 2024 Apple Inc.
"""
Adapted from a PyTorch implementation by David Grangier
"""
import argparse
import json
@@ -6,7 +10,7 @@ import logging
import os
from importlib.metadata import version
from pathlib import Path
from typing import Optional
from typing import Optional, Union
import lm_eval
import mlx.core as mx
@@ -73,15 +77,19 @@ class MLXLM(LM):
path_or_hf_repo: str,
batch_size: int = 16,
max_tokens: Optional[int] = None,
use_chat_template: Optional[bool] = None,
) -> None:
super().__init__()
self._batch_size = batch_size
self._model, self._tokenizer = load(path_or_hf_repo)
self._max_tokens = max_tokens or self._tokenizer.model_max_length
self._model, self.tokenizer = load(path_or_hf_repo)
self._max_tokens = max_tokens or self.tokenizer.model_max_length
self.use_chat_template = use_chat_template or (
self.tokenizer.chat_template is not None
)
def _score_fn(self, inputs, tokenize=True, step_size=32):
if tokenize:
inputs = self._tokenizer.encode(inputs)
inputs = self._tokenize(inputs)
inputs = _pad_inputs(inputs, self._max_tokens, truncate=False)
inputs = mx.array(inputs)
inputs, targets = inputs[..., :-1], inputs[..., 1:]
@@ -145,7 +153,12 @@ class MLXLM(LM):
return results
def _tokenize(self, texts):
return [tuple(self._tokenizer.encode(t)) for t in texts]
return [
tuple(
self.tokenizer.encode(t, add_special_tokens=not self.use_chat_template)
)
for t in texts
]
def loglikelihood(self, requests) -> list[tuple[float, bool]]:
"""Compute log-likelihood of generating a continuation from a context.
@@ -217,6 +230,9 @@ class MLXLM(LM):
)
return [(r[0], r[1] == r[2]) for r in results]
tokenizer_name = lm_eval.models.huggingface.HFLM.tokenizer_name
apply_chat_template = lm_eval.models.huggingface.HFLM.apply_chat_template
def loglikelihood_rolling(self, requests) -> list[float]:
"""Compute full log-likelihood of a string, with no truncation, for perplexity computation
- We will use the full max context length of the model.
@@ -277,23 +293,16 @@ class MLXLM(LM):
assert "until" in keys
untils = [x["until"] for x in options]
completions = []
for context, until in tqdm(zip(contexts, untils), total=len(contexts)):
if (
hasattr(self._tokenizer, "apply_chat_template")
and self._tokenizer.chat_template is not None
):
messages = [{"role": "user", "content": context}]
context = self._tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
for context, until in tqdm(zip(contexts, untils), total=len(contexts)):
context = self._tokenize(context)
max_tokens = min(
self._max_tokens,
self._tokenizer.model_max_length - len(self._tokenizer.encode(context)),
self.tokenizer.model_max_length - len(context),
)
text = ""
for response in stream_generate(
self._model, self._tokenizer, prompt=context, max_tokens=max_tokens
self._model, self.tokenizer, prompt=context, max_tokens=max_tokens
):
text += response.text
if any(u in text for u in until):
@@ -321,7 +330,28 @@ def main():
type=int,
help="Maximum nunber of tokens to generate. Defaults to the model's max context length.",
)
parser.add_argument(
"--limit",
default=1.0,
help="Limit the number of examples per task.",
type=float,
)
parser.add_argument("--seed", type=int, default=123, help="Random seed.")
parser.add_argument(
"--fewshot-as-multiturn",
action="store_true",
help="Whether to provide the fewshot examples as a multiturn "
"conversation or a single user turn.",
default=False,
)
parser.add_argument(
"--apply-chat-template",
action=argparse.BooleanOptionalAction,
help="Specifies whether to apply a chat template to the prompt. If "
"the model has a chat template, this defaults to `True`, "
"otherwise `False`.",
default=None,
)
args = parser.parse_args()
output_dir = Path(args.output_dir)
@@ -332,12 +362,19 @@ def main():
mx.random.seed(args.seed)
lm = MLXLM(args.model, batch_size=args.batch_size, max_tokens=args.max_tokens)
lm = MLXLM(
args.model,
batch_size=args.batch_size,
max_tokens=args.max_tokens,
use_chat_template=args.apply_chat_template,
)
results = lm_eval.simple_evaluate(
model=lm,
tasks=args.tasks,
fewshot_as_multiturn=args.fewshot_as_multiturn,
apply_chat_template=lm.use_chat_template,
num_fewshot=args.num_shots,
limit=args.limit,
random_seed=args.seed,
numpy_random_seed=args.seed,
torch_random_seed=args.seed,

View File

@@ -15,9 +15,7 @@ prompt_cache = make_prompt_cache(model)
# User turn
prompt = "Hi my name is <Name>."
messages = [{"role": "user", "content": prompt}]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
# Assistant response
response = generate(
@@ -32,9 +30,7 @@ response = generate(
# User turn
prompt = "What's my name?"
messages = [{"role": "user", "content": prompt}]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
# Assistant response
response = generate(

View File

@@ -14,7 +14,7 @@ conversation = [{"role": "user", "content": prompt}]
# Transform the prompt into the chat template
prompt = tokenizer.apply_chat_template(
conversation=conversation, tokenize=False, add_generation_prompt=True
conversation=conversation, add_generation_prompt=True
)
# Specify the maximum number of tokens

View File

@@ -0,0 +1,75 @@
# Copyright © 2024 Apple Inc.
"""
Run with:
```
/path/to/mpirun \
-np 2 \
--hostfile /path/to/hosts.txt \
python /path/to/pipeline_generate.py --prompt "hello world"
```
Make sure you can run MLX over MPI on two hosts. For more information see the
documentation:
https://ml-explore.github.io/mlx/build/html/usage/distributed.html).
"""
import argparse
import mlx.core as mx
from mlx_lm import load, stream_generate
parser = argparse.ArgumentParser(description="LLM pipelined inference example")
parser.add_argument(
"--prompt",
"-p",
default="Write a quicksort in C++.",
help="Message to be processed by the model ('-' reads from stdin)",
)
parser.add_argument(
"--max-tokens",
"-m",
type=int,
default=256,
help="Maximum number of tokens to generate",
)
args = parser.parse_args()
model_repo = "mlx-community/DeepSeek-V3-3bit"
model, tokenizer = load(model_repo, lazy=True)
messages = [{"role": "user", "content": args.prompt}]
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
group = mx.distributed.init()
rank = group.rank()
model.model.pipeline(group)
mx.eval(model.parameters())
# Synchronize processes before generation to avoid timeout if downloading
# model for the first time.
mx.eval(mx.distributed.all_sum(mx.array(1.0), stream=mx.cpu))
def rprint(*args, **kwargs):
if rank == 0:
print(*args, **kwargs)
for response in stream_generate(model, tokenizer, prompt, max_tokens=args.max_tokens):
rprint(response.text, end="", flush=True)
rprint()
rprint("=" * 10)
rprint(
f"Prompt: {response.prompt_tokens} tokens, "
f"{response.prompt_tps:.3f} tokens-per-sec"
)
rprint(
f"Generation: {response.generation_tokens} tokens, "
f"{response.generation_tps:.3f} tokens-per-sec"
)
rprint(f"Peak memory: {response.peak_memory:.3f} GB")

View File

@@ -1,7 +1,6 @@
# Copyright © 2023-2024 Apple Inc.
import argparse
import codecs
import json
import sys
@@ -44,10 +43,11 @@ def setup_arg_parser():
help="Optional path for the trained adapter weights and config.",
)
parser.add_argument(
"--eos-token",
"--extra-eos-token",
type=str,
default=None,
help="End of sequence token for tokenizer",
default=(),
nargs="+",
help="Add tokens in the list of eos tokens that stop generation.",
)
parser.add_argument(
"--system-prompt",
@@ -131,6 +131,18 @@ def setup_arg_parser():
type=int,
default=DEFAULT_QUANTIZED_KV_START,
)
parser.add_argument(
"--draft-model",
type=str,
help="A model to be used for speculative decoding.",
default=None,
)
parser.add_argument(
"--num-draft-tokens",
type=int,
help="Number of tokens to draft when using speculative decoding.",
default=2,
)
return parser
@@ -162,8 +174,6 @@ def main():
{} if not using_cache else json.loads(metadata["tokenizer_config"])
)
tokenizer_config["trust_remote_code"] = True
if args.eos_token is not None:
tokenizer_config["eos_token"] = args.eos_token
model_path = args.model
if using_cache:
@@ -182,6 +192,8 @@ def main():
adapter_path=args.adapter_path,
tokenizer_config=tokenizer_config,
)
for eos_token in args.extra_eos_token:
tokenizer.add_eos_token(eos_token)
if args.use_default_chat_template:
if tokenizer.chat_template is None:
@@ -189,22 +201,14 @@ def main():
elif using_cache:
tokenizer.chat_template = metadata["chat_template"]
prompt = codecs.decode(args.prompt, "unicode_escape")
if not args.ignore_chat_template and (
hasattr(tokenizer, "apply_chat_template")
and tokenizer.chat_template is not None
):
prompt = args.prompt.replace("\\n", "\n").replace("\\t", "\t")
prompt = sys.stdin.read() if prompt == "-" else prompt
if not args.ignore_chat_template and tokenizer.chat_template is not None:
if args.system_prompt is not None:
messages = [{"role": "system", "content": args.system_prompt}]
else:
messages = []
messages.append(
{
"role": "user",
"content": sys.stdin.read() if prompt == "-" else prompt,
}
)
messages.append({"role": "user", "content": prompt})
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
@@ -219,7 +223,16 @@ def main():
add_generation_prompt=True,
)
prompt = prompt[test_prompt.index("<query>") :]
prompt = tokenizer.encode(prompt, add_special_tokens=False)
else:
prompt = tokenizer.encode(prompt)
if args.draft_model is not None:
draft_model, draft_tokenizer = load(args.draft_model)
if draft_tokenizer.vocab_size != tokenizer.vocab_size:
raise ValueError("Draft model tokenizer does not match model tokenizer.")
else:
draft_model = None
sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep)
response = generate(
model,
@@ -233,6 +246,8 @@ def main():
kv_bits=args.kv_bits,
kv_group_size=args.kv_group_size,
quantized_kv_start=args.quantized_kv_start,
draft_model=draft_model,
num_draft_tokens=args.num_draft_tokens,
)
if not args.verbose:
print(response)

View File

@@ -2,6 +2,7 @@
import argparse
import math
import os
import re
import types
from pathlib import Path
@@ -57,6 +58,8 @@ CONFIG_DEFAULTS = {
"test": False,
"test_batches": 500,
"max_seq_length": 2048,
"config": None,
"grad_checkpoint": False,
"lr_schedule": None,
"lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
}
@@ -66,6 +69,7 @@ def build_parser():
parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")
parser.add_argument(
"--model",
type=str,
help="The path to the local model directory or Hugging Face repo.",
)
@@ -74,7 +78,6 @@ def build_parser():
"--train",
action="store_true",
help="Do training",
default=None,
)
parser.add_argument(
"--data",
@@ -88,7 +91,6 @@ def build_parser():
"--fine-tune-type",
type=str,
choices=["lora", "dora", "full"],
default="lora",
help="Type of fine-tuning to perform: lora, dora, or full.",
)
parser.add_argument(
@@ -133,7 +135,6 @@ def build_parser():
"--test",
action="store_true",
help="Evaluate on the test set after training",
default=None,
)
parser.add_argument(
"--test-batches",
@@ -148,16 +149,15 @@ def build_parser():
parser.add_argument(
"-c",
"--config",
default=None,
type=str,
help="A YAML configuration file with the training options",
)
parser.add_argument(
"--grad-checkpoint",
action="store_true",
help="Use gradient checkpointing to reduce memory use.",
default=None,
)
parser.add_argument("--seed", type=int, default=None, help="The PRNG seed")
parser.add_argument("--seed", type=int, help="The PRNG seed")
return parser
@@ -271,6 +271,7 @@ def run(args, training_callback: TrainingCallback = None):
def main():
os.environ["TOKENIZERS_PARALLELISM"] = "true"
parser = build_parser()
args = parser.parse_args()
config = args.config

View File

@@ -6,19 +6,18 @@ from transformers.commands.user import tabulate
def ask_for_confirmation(message: str) -> bool:
"""Ask user for confirmation with Y/N prompt.
Returns True for Y/yes, False for N/no/empty."""
y = ("y", "yes", "1")
n = ("n", "no", "0")
all_values = y + n + ("",)
full_message = f"{message} (Y/n) "
n = ("n", "no", "0", "")
full_message = f"{message} (y/n) "
while True:
answer = input(full_message).lower()
if answer == "":
return False
if answer in y:
return True
if answer in n:
return False
print(f"Invalid input. Must be one of {all_values}")
print(f"Invalid input. Must be one of: yes/no/y/n or empty for no")
def main():
@@ -43,9 +42,7 @@ def main():
args = parser.parse_args()
if args.scan:
print(
"Scanning Hugging Face cache for models with" f'pattern "{args.pattern}".'
)
print(f'Scanning Hugging Face cache for models with pattern "{args.pattern}".')
hf_cache_info = scan_cache_dir()
print(
tabulate(
@@ -86,35 +83,41 @@ def main():
if args.pattern in repo.repo_id
]
if repos:
print("\nFound the following models:")
print(
tabulate(
rows=[
[
repo.repo_id,
repo.size_on_disk_str, # Added size information
str(repo.repo_path),
]
for repo in repos
],
headers=[
"REPO ID",
"SIZE", # Added size header
"LOCAL PATH",
],
)
)
confirmed = ask_for_confirmation(f"Confirm deletion ?")
confirmed = ask_for_confirmation(
"\nAre you sure you want to delete these models?"
)
if confirmed:
for model_info in repos:
print(f"\nDeleting {model_info.repo_id}...")
for revision in sorted(
model_info.revisions, key=lambda revision: revision.commit_hash
):
strategy = hf_cache_info.delete_revisions(revision.commit_hash)
strategy.execute()
print("Model(s) deleted.")
print("\nModel(s) deleted successfully.")
else:
print("Deletion is cancelled. Do nothing.")
print("\nDeletion cancelled - no changes made.")
else:
print(f"No models found.")
print(f'No models found matching pattern "{args.pattern}"')
if __name__ == "__main__":

View File

@@ -23,7 +23,12 @@ class BaseModelArgs:
)
def create_causal_mask(N: int, offset: int = 0, window_size: Optional[int] = None):
def create_causal_mask(
N: int,
offset: int = 0,
window_size: Optional[int] = None,
lengths: Optional[mx.array] = None,
):
rinds = mx.arange(offset + N)
linds = mx.arange(offset, offset + N) if offset else rinds
linds = linds[:, None]
@@ -31,6 +36,9 @@ def create_causal_mask(N: int, offset: int = 0, window_size: Optional[int] = Non
mask = linds < rinds
if window_size is not None:
mask = mask | (linds > rinds + window_size)
if lengths is not None:
lengths = lengths[:, None, None, None]
mask = mask | (rinds >= lengths)
return mask * -1e9

View File

@@ -155,10 +155,12 @@ class CohereModel(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.embed_tokens(inputs)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
@@ -180,9 +182,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, cache)
out = self.model(inputs, mask, cache)
out = self.model.embed_tokens.as_linear(out)
out = out * self.model.args.logit_scale
return out

View File

@@ -6,7 +6,7 @@ from typing import Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_causal_mask, scaled_dot_product_attention
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .cache import KVCache, RotatingKVCache
@@ -151,20 +151,18 @@ class CohereModel(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.embed_tokens(inputs)
T = h.shape[1]
if T > 1:
offset = cache[0].offset if cache else 0
mask = create_causal_mask(T, offset).astype(h.dtype)
else:
mask = None
if cache is None:
cache = [None] * len(self.layers)
if mask is None:
j = self.args.sliding_window_pattern
mask = create_attention_mask(h, cache[j - 1 : j])
for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)
@@ -181,9 +179,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, cache)
out = self.model(inputs, mask, cache)
out = self.model.embed_tokens.as_linear(out)
out = out * self.model.args.logit_scale
return out

View File

@@ -197,10 +197,12 @@ class DBRX(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.wte(inputs)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
@@ -223,9 +225,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.transformer(inputs, cache)
out = self.transformer(inputs, mask, cache)
return self.lm_head(out)
@property

View File

@@ -211,8 +211,10 @@ class DeepseekModel(nn.Module):
self,
x: mx.array,
cache: Optional[Any] = None,
mask: Optional[mx.array] = None,
) -> mx.array:
h = self.embed_tokens(x)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
@@ -236,8 +238,9 @@ class Model(nn.Module):
self,
inputs: mx.array,
cache: Optional[Any] = None,
mask: Optional[mx.array] = None,
):
out = self.model(inputs, cache)
out = self.model(inputs, cache, mask)
return self.lm_head(out)
def sanitize(self, weights):

View File

@@ -370,8 +370,11 @@ class DeepseekV2Model(nn.Module):
self,
x: mx.array,
cache: Optional[Any] = None,
mask: Optional[mx.array] = None,
) -> mx.array:
h = self.embed_tokens(x)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
@@ -395,8 +398,9 @@ class Model(nn.Module):
self,
inputs: mx.array,
cache: Optional[Any] = None,
mask: Optional[mx.array] = None,
):
out = self.model(inputs, cache)
out = self.model(inputs, cache, mask)
return self.lm_head(out)
def sanitize(self, weights):

View File

@@ -0,0 +1,460 @@
# Copyright © 2024 Apple Inc.
import math
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .switch_layers import SwitchGLU
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str = "deepseek_v3"
vocab_size: int = 102400
hidden_size: int = 4096
intermediate_size: int = 11008
moe_intermediate_size: int = 1407
num_hidden_layers: int = 30
num_attention_heads: int = 32
num_key_value_heads: int = 32
n_shared_experts: Optional[int] = None
n_routed_experts: Optional[int] = None
routed_scaling_factor: float = 1.0
kv_lora_rank: int = 512
q_lora_rank: int = 1536
qk_rope_head_dim: int = 64
v_head_dim: int = 128
qk_nope_head_dim: int = 128
topk_method: str = "noaux_tc"
scoring_func: str = "sigmoid"
norm_topk_prob: bool = True
n_group: Optional[int] = None
topk_group: Optional[int] = None
num_experts_per_tok: Optional[int] = None
moe_layer_freq: int = 1
first_k_dense_replace: int = 0
max_position_embeddings: int = 2048
rms_norm_eps: float = 1e-6
rope_theta: float = 10000.0
rope_scaling: Dict = None
attention_bias: bool = False
def yarn_find_correction_dim(
num_rotations, dim, base=10000, max_position_embeddings=2048
):
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
2 * math.log(base)
)
def yarn_find_correction_range(
low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
):
low = math.floor(
yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
)
high = math.ceil(
yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
)
return max(low, 0), min(high, dim - 1)
def yarn_get_mscale(scale=1, mscale=1):
if scale <= 1:
return 1.0
return 0.1 * mscale * math.log(scale) + 1.0
def yarn_linear_ramp_mask(min_val, max_val, dim):
if min_val == max_val:
max_val += 0.001 # Prevent singularity
linear_func = (mx.arange(dim, dtype=mx.float32) - min_val) / (max_val - min_val)
return mx.clip(linear_func, 0, 1)
class DeepseekV3YarnRotaryEmbedding(nn.Module):
def __init__(
self,
dim,
max_position_embeddings=2048,
base=10000,
scaling_factor=1.0,
original_max_position_embeddings=4096,
beta_fast=32,
beta_slow=1,
mscale=1,
mscale_all_dim=0,
):
super().__init__()
self.mscale = yarn_get_mscale(scaling_factor, mscale) / yarn_get_mscale(
scaling_factor, mscale_all_dim
)
freq_extra = base ** (mx.arange(0, dim, 2, dtype=mx.float32) / dim)
freq_inter = scaling_factor * base ** (
mx.arange(0, dim, 2, dtype=mx.float32) / dim
)
low, high = yarn_find_correction_range(
beta_fast,
beta_slow,
dim,
base,
original_max_position_embeddings,
)
freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2)
self._freqs = (freq_inter * freq_extra) / (
freq_inter * freq_mask + freq_extra * (1 - freq_mask)
)
def __call__(self, x, offset=0):
if self.mscale != 1.0:
x = self.mscale * x
return mx.fast.rope(
x,
x.shape[-1],
traditional=True,
base=None,
scale=1.0,
offset=offset,
freqs=self._freqs,
)
class DeepseekV3Attention(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.q_lora_rank = config.q_lora_rank
self.qk_rope_head_dim = config.qk_rope_head_dim
self.kv_lora_rank = config.kv_lora_rank
self.v_head_dim = config.v_head_dim
self.qk_nope_head_dim = config.qk_nope_head_dim
self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
self.scale = self.q_head_dim**-0.5
if self.q_lora_rank is None:
self.q_proj = nn.Linear(
self.hidden_size, self.num_heads * self.q_head_dim, bias=False
)
else:
self.q_a_proj = nn.Linear(
self.hidden_size, self.q_lora_rank, bias=config.attention_bias
)
self.q_a_layernorm = nn.RMSNorm(self.q_lora_rank)
self.q_b_proj = nn.Linear(
self.q_lora_rank, self.num_heads * self.q_head_dim, bias=False
)
self.kv_a_proj_with_mqa = nn.Linear(
self.hidden_size,
self.kv_lora_rank + self.qk_rope_head_dim,
bias=config.attention_bias,
)
self.kv_a_layernorm = nn.RMSNorm(self.kv_lora_rank)
self.kv_b_proj = nn.Linear(
self.kv_lora_rank,
self.num_heads
* (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
bias=False,
)
self.o_proj = nn.Linear(
self.num_heads * self.v_head_dim,
self.hidden_size,
bias=config.attention_bias,
)
mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
scaling_factor = self.config.rope_scaling["factor"]
if mscale_all_dim:
mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
self.scale = self.scale * mscale * mscale
rope_kwargs = {
key: self.config.rope_scaling[key]
for key in [
"original_max_position_embeddings",
"beta_fast",
"beta_slow",
"mscale",
"mscale_all_dim",
]
if key in self.config.rope_scaling
}
self.rope = DeepseekV3YarnRotaryEmbedding(
dim=self.qk_rope_head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta,
**rope_kwargs,
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
if self.q_lora_rank is None:
q = self.q_proj(x)
else:
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(x)))
q = q.reshape(B, L, self.num_heads, self.q_head_dim).transpose(0, 2, 1, 3)
q_nope, q_pe = mx.split(q, [self.qk_nope_head_dim], axis=-1)
compressed_kv = self.kv_a_proj_with_mqa(x)
compressed_kv, k_pe = mx.split(compressed_kv, [self.kv_lora_rank], axis=-1)
k_pe = k_pe.reshape(B, L, 1, self.qk_rope_head_dim).transpose(0, 2, 1, 3)
kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
kv = kv.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1)
if cache is not None:
q_pe = self.rope(q_pe, cache.offset)
k_pe = self.rope(k_pe, cache.offset)
k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
keys, values = cache.update_and_fetch(
mx.concatenate([k_nope, k_pe], axis=-1), values
)
else:
q_pe = self.rope(q_pe)
k_pe = self.rope(k_pe)
k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
keys = mx.concatenate([k_nope, k_pe], axis=-1)
queries = mx.concatenate([q_nope, q_pe], axis=-1)
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
class DeepseekV3MLP(nn.Module):
def __init__(
self, config: ModelArgs, hidden_size: int = None, intermediate_size: int = None
):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
self.intermediate_size = (
config.intermediate_size if intermediate_size is None else intermediate_size
)
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
def __call__(self, x):
down_proj = self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
return down_proj
class MoEGate(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.config = config
self.top_k = config.num_experts_per_tok
self.norm_topk_prob = config.norm_topk_prob
self.n_routed_experts = config.n_routed_experts
self.routed_scaling_factor = config.routed_scaling_factor
self.topk_method = config.topk_method
self.n_group = config.n_group
self.topk_group = config.topk_group
self.weight = mx.zeros((self.n_routed_experts, config.hidden_size))
self.e_score_correction_bias = mx.zeros((self.n_routed_experts,))
def __call__(self, x):
gates = x @ self.weight.T
scores = mx.sigmoid(gates.astype(mx.float32))
assert self.topk_method == "noaux_tc", "Unsupported topk method."
bsz, seq_len = x.shape[:2]
scores = scores + self.e_score_correction_bias
scores = scores.reshape(bsz, seq_len, self.n_group, -1)
group_scores = mx.topk(scores, 2, axis=-1).sum(axis=-1)
k = self.n_group - self.topk_group
group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-1)[..., :k]
batch_idx = mx.expand_dims(mx.arange(bsz), (1, 2))
seq_idx = mx.expand_dims(mx.arange(seq_len), (0, 2))
scores[batch_idx, seq_idx, group_idx] = 0.0
scores = scores.reshape(bsz, seq_len, -1)
k = self.top_k
inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]
scores = mx.take_along_axis(scores, inds, axis=-1)
if self.top_k > 1 and self.norm_topk_prob:
denominator = scores.sum(axis=-1, keepdims=True) + 1e-20
scores = scores / denominator
scores = scores * self.routed_scaling_factor
return inds, scores
class DeepseekV3MoE(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.config = config
self.num_experts_per_tok = config.num_experts_per_tok
self.switch_mlp = SwitchGLU(
config.hidden_size, config.moe_intermediate_size, config.n_routed_experts
)
self.gate = MoEGate(config)
if config.n_shared_experts is not None:
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
self.shared_experts = DeepseekV3MLP(
config=config, intermediate_size=intermediate_size
)
def __call__(self, x):
inds, scores = self.gate(x)
y = self.switch_mlp(x, inds)
y = (y * scores[..., None]).sum(axis=-2).astype(y.dtype)
if self.config.n_shared_experts is not None:
y = y + self.shared_experts(x)
return y
class DeepseekV3DecoderLayer(nn.Module):
def __init__(self, config: ModelArgs, layer_idx: int):
super().__init__()
self.self_attn = DeepseekV3Attention(config)
self.mlp = (
DeepseekV3MoE(config)
if (
config.n_routed_experts is not None
and layer_idx >= config.first_k_dense_replace
and layer_idx % config.moe_layer_freq == 0
)
else DeepseekV3MLP(config)
)
self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = nn.RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
out = h + r
# Protect against overflow for fp16
if out.dtype == mx.float16:
out = mx.clip(out, a_min=None, a_max=mx.finfo(mx.float16).max - 1000)
return out
class DeepseekV3Model(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = [
DeepseekV3DecoderLayer(config, idx)
for idx in range(config.num_hidden_layers)
]
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.pipeline_rank = 0
self.pipeline_size = 1
def pipeline(self, group):
# Split layers in reverse so rank=0 gets the last layers and
# rank=pipeline_size-1 gets the first
self.pipeline_rank = group.rank()
self.pipeline_size = group.size()
layers_per_rank = (
len(self.layers) + self.pipeline_size - 1
) // self.pipeline_size
start = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
self.layers = self.layers[start : start + layers_per_rank]
def __call__(
self,
x: mx.array,
cache: Optional[Any] = None,
mask: Optional[mx.array] = None,
) -> mx.array:
h = self.embed_tokens(x)
pipeline_rank = self.pipeline_rank
pipeline_size = self.pipeline_size
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
# Receive from the previous process in the pipeline
if pipeline_rank < pipeline_size - 1:
h = mx.distributed.recv_like(h, (pipeline_rank + 1))
for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)
# Send to the next process in the pipeline
if pipeline_rank != 0:
h = mx.distributed.send(h, (pipeline_rank - 1) % pipeline_size)
# Broadcast h while keeping it in the graph
h = mx.distributed.all_gather(h)[: h.shape[0]]
return self.norm(h)
class Model(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.args = config
self.model_type = config.model_type
self.model = DeepseekV3Model(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
cache: Optional[Any] = None,
mask: Optional[mx.array] = None,
):
out = self.model(inputs, cache, mask)
return self.lm_head(out)
def sanitize(self, weights):
for l in range(self.args.num_hidden_layers):
prefix = f"model.layers.{l}"
for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]:
for k in ["weight", "scales", "biases"]:
if f"{prefix}.mlp.experts.0.{m}.{k}" in weights:
to_join = [
weights.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}")
for e in range(self.args.n_routed_experts)
]
weights[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join)
# Remove multi-token prediction layer
return {k: v for k, v in weights.items() if not k.startswith("model.layers.61")}
@property
def layers(self):
return self.model.layers

View File

@@ -123,9 +123,11 @@ class ExaoneModel(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.wte(inputs)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
@@ -149,9 +151,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.transformer(inputs, cache)
out = self.transformer(inputs, mask, cache)
if self.args.tie_word_embeddings:
out = self.transformer.wte.as_linear(out)
else:

View File

@@ -138,11 +138,13 @@ class GemmaModel(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.embed_tokens(inputs)
h = h * (self.args.hidden_size**0.5)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
@@ -164,9 +166,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, cache)
out = self.model(inputs, mask, cache)
out = self.model.embed_tokens.as_linear(out)
return out

View File

@@ -160,11 +160,13 @@ class GemmaModel(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.embed_tokens(inputs)
h = h * (self.args.hidden_size**0.5)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
@@ -187,9 +189,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, cache)
out = self.model(inputs, mask, cache)
out = self.model.embed_tokens.as_linear(out)
out = mx.tanh(out / self.final_logit_softcapping)
out = out * self.final_logit_softcapping

View File

@@ -126,6 +126,7 @@ class GPT2Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
_, L = inputs.shape
@@ -138,6 +139,7 @@ class GPT2Model(nn.Module):
position_ids = mx.array(np.arange(L))
hidden_states += self.wpe(position_ids)
if mask is None:
mask = create_attention_mask(hidden_states, cache)
if cache is None:
@@ -159,9 +161,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, cache)
out = self.model(inputs, mask, cache)
out = self.model.wte.as_linear(out)
return out

View File

@@ -137,6 +137,7 @@ class GPTBigCodeModel(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
B, L = inputs.shape
@@ -144,15 +145,16 @@ class GPTBigCodeModel(nn.Module):
hidden_states = self.wte(inputs)
mask = None
if hidden_states.shape[1] > 1:
position_ids = mx.array(np.arange(L))
hidden_states += self.wpe(position_ids)
if mask is not None and hidden_states.shape[1] > 1:
mask = create_attention_mask(hidden_states, cache)
if cache is None:
cache = [None] * len(self.h)
position_ids = mx.array(np.arange(L))
else:
position_ids = mx.array(np.arange(cache[0].offset, cache[0].offset + L))
hidden_states += self.wpe(position_ids)
for layer, c in zip(self.h, cache):
hidden_states = layer(hidden_states, mask, cache=c)
@@ -172,9 +174,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.transformer(inputs, cache)
out = self.transformer(inputs, mask, cache)
if self.args.tie_word_embeddings:
out = self.transformer.wte.as_linear(out)
else:

View File

@@ -146,12 +146,14 @@ class GPTNeoXModel(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
_, L = inputs.shape
hidden_states = self.embed_in(inputs)
if mask is None:
mask = create_attention_mask(hidden_states, cache)
if cache is None:
@@ -176,9 +178,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, cache)
out = self.model(inputs, mask, cache)
return out
def sanitize(self, weights):

View File

@@ -239,10 +239,12 @@ class HunYuanModel(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.embed_tokens(inputs)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
@@ -266,9 +268,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, cache)
out = self.model(inputs, mask, cache)
return self.model.embed_tokens.as_linear(out)
def sanitize(self, weights):

View File

@@ -193,10 +193,12 @@ class InternLM2Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.tok_embeddings(inputs)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
@@ -220,9 +222,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, cache)
out = self.model(inputs, mask, cache)
if self.args.tie_word_embeddings:
out = self.model.tok_embeddings.as_linear(out)
else:

View File

@@ -155,10 +155,12 @@ class LlamaModel(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.embed_tokens(inputs)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
@@ -182,9 +184,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, cache)
out = self.model(inputs, mask, cache)
if self.args.tie_word_embeddings:
out = self.model.embed_tokens.as_linear(out)
else:

View File

@@ -158,10 +158,12 @@ class MiniCPMModel(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.embed_tokens(inputs) * self.args.scale_emb
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
@@ -186,9 +188,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, cache)
out = self.model(inputs, mask, cache)
if not self.args.tie_word_embeddings:
out = self.lm_head(out / (self.args.hidden_size / self.args.dim_model_base))

View File

@@ -162,10 +162,12 @@ class MixtralModel(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.embed_tokens(inputs)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
@@ -188,9 +190,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, cache)
out = self.model(inputs, mask, cache)
return self.lm_head(out)
def sanitize(self, weights):

View File

@@ -176,10 +176,12 @@ class NemotronModel(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.embed_tokens(inputs)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
@@ -203,9 +205,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, cache)
out = self.model(inputs, mask, cache)
if self.args.tie_word_embeddings:
out = self.model.embed_tokens.as_linear(out)
else:

View File

@@ -124,10 +124,12 @@ class Transformer(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.wte(inputs)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
@@ -152,9 +154,10 @@ class OlmoModel(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
return self.transformer(inputs, cache)
return self.transformer(inputs, mask, cache)
class Model(nn.Module):
@@ -167,9 +170,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
return self.model(inputs, cache)
return self.model(inputs, mask, cache)
@property
def layers(self):

View File

@@ -163,9 +163,11 @@ class LlamaModel(nn.Module):
self,
inputs: mx.array,
cache=None,
mask=None,
):
h = self.embed_tokens(inputs)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
@@ -190,8 +192,9 @@ class Model(nn.Module):
self,
inputs: mx.array,
cache=None,
mask=None,
):
out = self.model(inputs, cache)
out = self.model(inputs, cache, mask)
if self.args.tie_word_embeddings:
out = self.model.embed_tokens.as_linear(out)
else:

View File

@@ -178,10 +178,12 @@ class OpenELMModel(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.token_embeddings(inputs)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
@@ -205,9 +207,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.transformer(inputs, cache)
out = self.transformer(inputs, mask, cache)
if self.args.share_input_output_layers:
out = self.transformer.token_embeddings.as_linear(out)
else:

View File

@@ -143,9 +143,10 @@ class PhiModel(nn.Module):
config.hidden_size, eps=config.layer_norm_eps
)
def __call__(self, x, cache):
def __call__(self, x, mask, cache):
x = self.embed_tokens(x)
if mask is None:
mask = create_attention_mask(x, cache)
if cache is None:
@@ -167,9 +168,10 @@ class Model(nn.Module):
def __call__(
self,
x: mx.array,
mask: mx.array = None,
cache=None,
) -> mx.array:
y = self.model(x, cache)
y = self.model(x, mask, cache)
return self.lm_head(y)
@property

View File

@@ -168,10 +168,12 @@ class Phi3Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.embed_tokens(inputs)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
@@ -194,9 +196,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, cache)
out = self.model(inputs, mask, cache)
return self.lm_head(out)
@property

View File

@@ -258,12 +258,14 @@ class Phi3Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.embed_tokens(inputs)
if self.mup_embedding_multiplier:
h = self.mup_embedding_multiplier * h
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
@@ -290,9 +292,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, cache)
out = self.model(inputs, mask, cache)
out = self.model.embed_tokens.as_linear(out)
if self.mup_width_multiplier:
out = out / self.mup_width_multiplier

View File

@@ -155,10 +155,12 @@ class PhiMoEModel(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
) -> mx.array:
h = self.embed_tokens(inputs)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
@@ -181,9 +183,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, cache)
out = self.model(inputs, mask, cache)
return self.lm_head(out)
def sanitize(self, weights):

View File

@@ -175,6 +175,8 @@ class Model(nn.Module):
mask: mx.array = None,
cache=None,
) -> mx.array:
if mask is None:
mask = create_attention_mask(x, cache)
y = self.transformer(x, mask, cache)

View File

@@ -174,9 +174,11 @@ class PlamoModel(nn.Module):
self,
inputs: mx.array,
cache: Optional[Any] = None,
mask: Optional[mx.array] = None,
) -> mx.array:
h = self.embed_tokens(inputs)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
@@ -202,8 +204,9 @@ class Model(nn.Module):
self,
inputs: mx.array,
cache: Optional[Any] = None,
mask: Optional[mx.array] = None,
) -> mx.array:
out = self.model(inputs, cache)
out = self.model(inputs, cache, mask)
return self.lm_head(out)
@property

View File

@@ -123,6 +123,7 @@ class QwenModel(nn.Module):
def __call__(self, inputs, mask=None, cache=None):
x = self.wte(inputs)
if mask is None:
mask = create_attention_mask(x, cache)
if cache is None:

View File

@@ -149,10 +149,12 @@ class Qwen2Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.embed_tokens(inputs)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
@@ -176,9 +178,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, cache)
out = self.model(inputs, mask, cache)
if self.args.tie_word_embeddings:
out = self.model.embed_tokens.as_linear(out)
else:

View File

@@ -187,10 +187,12 @@ class Qwen2MoeModel(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.embed_tokens(inputs)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
@@ -213,9 +215,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, cache)
out = self.model(inputs, mask, cache)
return self.lm_head(out)
def sanitize(self, weights):

View File

@@ -389,6 +389,7 @@ class Griffin(nn.Module):
def __call__(
self,
tokens,
mask: mx.array = None,
cache=None,
):
x = self.embed_tokens(tokens)
@@ -402,6 +403,7 @@ class Griffin(nn.Module):
if block.temporal_block_type != "recurrent":
mask_cache = [cache[i]]
if mask is None:
mask = create_attention_mask(x, mask_cache)
for i, block in enumerate(self.layers):
@@ -418,12 +420,12 @@ class Model(nn.Module):
self.model_type = config.model_type
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
def __call__(self, tokens: mx.array, cache=None) -> mx.array:
def __call__(self, tokens: mx.array, mask: mx.array = None, cache=None) -> mx.array:
"""
Args:
tokens: Sequence of input tokens.
"""
logits = self.model(tokens, cache=cache)
logits = self.model(tokens, mask=mask, cache=cache)
if "lm_head" in self:
logits = self.lm_head(logits)
else:

View File

@@ -199,7 +199,10 @@ class Model(nn.Module):
mask: mx.array = None,
cache=None,
) -> mx.array:
if mask is None:
mask = create_attention_mask(x, cache)
y = self.model(x, mask, cache)
return self.lm_head(y)

View File

@@ -125,10 +125,12 @@ class Starcoder2Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.embed_tokens(inputs)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
@@ -152,9 +154,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, cache)
out = self.model(inputs, mask, cache)
if self.args.tie_word_embeddings:
out = self.model.embed_tokens.as_linear(out)
else:

View File

@@ -1,4 +1,4 @@
mlx>=0.19.2
mlx>=0.22.0
numpy
transformers[sentencepiece]>=4.39.3
protobuf

View File

@@ -12,6 +12,7 @@ def make_sampler(
top_p: float = 0.0,
min_p: float = 0.0,
min_tokens_to_keep: int = 1,
top_k: int = -1,
) -> Callable[mx.array, mx.array]:
"""
Make a sampler function for use with ``generate_step``.
@@ -25,6 +26,8 @@ def make_sampler(
probability) that a token probability must have to be considered.
min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
be filtered by min_p sampling.
top_k (int, optional): The top k tokens ranked by probability to constrain
the sampling to.
Returns:
Callable[mx.array, mx.array]:
@@ -36,6 +39,8 @@ def make_sampler(
return lambda x: top_p_sampling(x, top_p, temp)
elif min_p != 0.0:
return lambda x: min_p_sampling(x, min_p, min_tokens_to_keep, temp)
elif top_k > 0:
return lambda x: top_k_sampling(x, top_k, temp)
else:
return lambda x: categorical_sampling(x, temp)
@@ -79,6 +84,33 @@ def make_logits_processors(
return logits_processors
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
def top_k_sampling(
logprobs: mx.array,
top_k: int,
temperature=1.0,
) -> mx.array:
"""
Sample from only the top K tokens ranked by probability.
Args:
logprobs: A vector of log probabilities.
top_k (int): Top k tokens to sample from.
"""
vocab_size = logprobs.shape[-1]
if not isinstance(top_k, int) or not (0 < top_k < vocab_size):
raise ValueError(
f"`top_k` has to be an integer in the (0, {vocab_size}] interval,"
f" but is {top_k}."
)
logprobs = logprobs * (1 / temperature)
mask_idx = mx.argpartition(-logprobs, kth=top_k - 1, axis=-1)[..., top_k:]
masked_logprobs = mx.put_along_axis(
logprobs, mask_idx, mx.array(-float("inf"), logprobs.dtype), axis=-1
)
return mx.random.categorical(masked_logprobs, axis=-1)
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
def min_p_sampling(
logprobs: mx.array,
@@ -87,7 +119,7 @@ def min_p_sampling(
temperature=1.0,
) -> mx.array:
"""
Apply min-p sampling to the logits.
Apply min-p sampling to the logprobs.
Min-p keeps all tokens that are above a minimum probability, scaled by the
probability of the most likely token. As a result, the filter is more

View File

@@ -590,14 +590,10 @@ class APIHandler(BaseHTTPRequestHandler):
# Determine response type
self.request_id = f"chatcmpl-{uuid.uuid4()}"
self.object_type = "chat.completion.chunk" if self.stream else "chat.completion"
if (
hasattr(self.tokenizer, "apply_chat_template")
and self.tokenizer.chat_template
):
if self.tokenizer.chat_template:
prompt = self.tokenizer.apply_chat_template(
body["messages"],
body.get("tools", None),
tokenize=True,
add_generation_prompt=True,
)
else:

View File

@@ -266,6 +266,18 @@ class TokenizerWrapper:
else {tokenizer.eos_token_id}
)
def add_eos_token(self, token: str):
token_id = None
try:
token_id = int(token)
except ValueError:
token_id = self._tokenizer.convert_tokens_to_ids(token)
if token_id is None:
raise ValueError(f"'{token}' is not a token for this tokenizer")
self._eos_token_ids.add(token_id)
def __getattr__(self, attr):
if attr == "detokenizer":
return self._detokenizer

View File

@@ -1,6 +1,6 @@
import json
from pathlib import Path
from typing import Dict, List
from typing import Dict, List, Optional
from transformers import PreTrainedTokenizer
@@ -10,41 +10,47 @@ class Dataset:
Light-weight wrapper to hold a dataset.
"""
def __init__(self, data: List[Dict[str, str]], text_key: str = "text"):
self._text_key = text_key
self._data = data
def __init__(
self,
data: List[Dict[str, str]],
tokenizer: PreTrainedTokenizer,
text_key: str = "text",
):
self._data = [tokenizer.encode(d[text_key]) for d in data]
for d in self._data:
if d[-1] != tokenizer.eos_token_id:
d.append(tokenizer.eos_token_id)
def __getitem__(self, idx: int):
return self._data[idx][self._text_key]
return self._data[idx]
def __len__(self):
if self._data is None:
return 0
return len(self._data)
class ChatDataset(Dataset):
class ChatDataset:
"""
A dataset for chat data in the format of {"messages": [...]}
https://platform.openai.com/docs/guides/fine-tuning/example-format
"""
def __init__(self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer):
super().__init__(data)
self._tokenizer = tokenizer
self._data = [
tokenizer.apply_chat_template(
d["messages"],
tools=d.get("tools", None),
)
for d in data
]
def __getitem__(self, idx: int):
messages = self._data[idx]["messages"]
text = self._tokenizer.apply_chat_template(
messages,
tools=self._data[idx].get("tools", None),
tokenize=False,
add_generation_prompt=True,
)
return text
return self._data[idx]
def __len__(self):
return len(self._data)
class CompletionsDataset(Dataset):
class CompletionsDataset:
"""
A dataset for prompt-completion data in the format of {"prompt": ..., "completion": ...}
or using user-provided keys for prompt and completion values
@@ -55,36 +61,41 @@ class CompletionsDataset(Dataset):
self,
data: List[Dict[str, str]],
tokenizer: PreTrainedTokenizer,
prompt_key: str = "prompt",
completion_key: str = "completion",
prompt_key: str,
completion_key: str,
):
super().__init__(data)
self._tokenizer = tokenizer
self._prompt_key = prompt_key
self._completion_key = completion_key
self._data = [
tokenizer.apply_chat_template(
[
{"role": "user", "content": d[prompt_key]},
{"role": "assistant", "content": d[completion_key]},
],
)
for d in data
]
def __getitem__(self, idx: int):
data = self._data[idx]
text = self._tokenizer.apply_chat_template(
[
{"role": "user", "content": data[self._prompt_key]},
{"role": "assistant", "content": data[self._completion_key]},
],
tokenize=False,
add_generation_prompt=True,
)
return text
return self._data[idx]
def __len__(self):
return len(self._data)
def create_dataset(data, tokenizer: PreTrainedTokenizer = None):
def create_dataset(
data,
tokenizer: PreTrainedTokenizer,
prompt_feature: Optional[str] = None,
completion_feature: Optional[str] = None,
):
prompt_feature = prompt_feature or "prompt"
completion_feature = completion_feature or "completion"
sample = data[0]
if "messages" in sample:
return ChatDataset(data, tokenizer)
elif "prompt" in sample and "completion" in sample:
return CompletionsDataset(data, tokenizer)
elif prompt_feature in sample and completion_feature in sample:
return CompletionsDataset(data, tokenizer, prompt_feature, completion_feature)
elif "text" in sample:
return Dataset(data)
return Dataset(data, tokenizer)
else:
raise ValueError(
"Unsupported data format, check the supported formats here:\n"
@@ -92,20 +103,30 @@ def create_dataset(data, tokenizer: PreTrainedTokenizer = None):
)
def load_local_dataset(data_path: Path, tokenizer: PreTrainedTokenizer):
def load_local_dataset(
data_path: Path,
tokenizer: PreTrainedTokenizer,
prompt_feature: Optional[str] = None,
completion_feature: Optional[str] = None,
):
def load_subset(path):
if not path.exists():
return []
with open(path, "r") as fid:
data = [json.loads(l) for l in fid]
return create_dataset(data, tokenizer)
return create_dataset(data, tokenizer, prompt_feature, completion_feature)
names = ("train", "valid", "test")
train, valid, test = [load_subset(data_path / f"{n}.jsonl") for n in names]
return train, valid, test
def load_hf_dataset(data_id: str, tokenizer: PreTrainedTokenizer):
def load_hf_dataset(
data_id: str,
tokenizer: PreTrainedTokenizer,
prompt_feature: Optional[str] = None,
completion_feature: Optional[str] = None,
):
from datasets import exceptions, load_dataset
try:
@@ -114,7 +135,13 @@ def load_hf_dataset(data_id: str, tokenizer: PreTrainedTokenizer):
names = ("train", "valid", "test")
train, valid, test = [
create_dataset(dataset[n], tokenizer) if n in dataset.keys() else []
(
create_dataset(
dataset[n], tokenizer, prompt_feature, completion_feature
)
if n in dataset.keys()
else []
)
for n in names
]
@@ -143,7 +170,7 @@ def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
if prompt_feature and completion_feature:
return CompletionsDataset(ds, tokenizer, prompt_feature, completion_feature)
elif text_feature:
return Dataset(train_ds, text_key=text_feature)
return Dataset(train_ds, tokenizer, text_key=text_feature)
else:
raise ValueError(
"Specify either a prompt and completion feature or a text "
@@ -166,15 +193,22 @@ def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
def load_dataset(args, tokenizer: PreTrainedTokenizer):
if getattr(args, "hf_dataset", None) is not None:
if getattr(args, "hf_dataset", False):
train, valid, test = load_custom_hf_dataset(args, tokenizer)
else:
data_path = Path(args.data)
prompt_feature = getattr(args, "prompt_feature", None)
completion_feature = getattr(args, "completion_feature", None)
if data_path.exists():
train, valid, test = load_local_dataset(data_path, tokenizer)
train, valid, test = load_local_dataset(
data_path, tokenizer, prompt_feature, completion_feature
)
else:
print(f"Loading Hugging Face dataset {args.data}.")
train, valid, test = load_hf_dataset(args.data, tokenizer)
train, valid, test = load_hf_dataset(
args.data, tokenizer, prompt_feature, completion_feature
)
if args.train and len(train) == 0:
raise ValueError(

View File

@@ -100,14 +100,8 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
while True:
indices = np.random.permutation(len(batch_idx))
for i in indices:
# Encode batch
batch = [tokenizer.encode(dataset[j]) for j in batch_idx[i]]
for b in batch:
if b[-1] != tokenizer.eos_token_id:
b.append(tokenizer.eos_token_id)
batch = [dataset[j] for j in batch_idx[i]]
lengths = [len(x) for x in batch]
if max(lengths) > max_seq_length:
print(
f"[WARNING] Some sequences are longer than {max_seq_length} tokens. "

View File

@@ -2,10 +2,12 @@
import contextlib
import copy
import functools
import glob
import importlib
import json
import logging
import os
import shutil
import time
from dataclasses import dataclass
@@ -15,7 +17,17 @@ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type,
import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download
if os.getenv("MLXLM_USE_MODELSCOPE", "False").lower() == "true":
try:
from modelscope import snapshot_download
except ImportError:
raise ImportError(
"Please run `pip install modelscope` to activate the ModelScope."
)
else:
from huggingface_hub import snapshot_download
from mlx.utils import tree_flatten, tree_reduce
from transformers import PreTrainedTokenizer
@@ -153,11 +165,12 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path
Path: The path to the model.
"""
model_path = Path(path_or_hf_repo)
if not model_path.exists():
try:
model_path = Path(
snapshot_download(
repo_id=path_or_hf_repo,
path_or_hf_repo,
revision=revision,
allow_patterns=[
"*.json",
@@ -207,12 +220,6 @@ def generate_step(
kv_group_size: int = 64,
quantized_kv_start: int = 0,
prompt_progress_callback: Optional[Callable[int, int]] = None,
temp: Optional[float] = None,
repetition_penalty: Optional[float] = None,
repetition_context_size: Optional[int] = None,
top_p: Optional[float] = None,
min_p: Optional[float] = None,
min_tokens_to_keep: Optional[int] = None,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
"""
A generator producing token ids based on the given prompt from the model.
@@ -256,24 +263,16 @@ def generate_step(
elif len(prompt_cache) != len(model.layers):
raise ValueError("Wrong number of layers in the prompt cache.")
if temp is not None or top_p is not None or min_tokens_to_keep is not None:
print(
"[Warning] Specifying sampling arguments to ``generate_step`` is "
"deprecated. Pass in a ``sampler`` instead."
)
if repetition_penalty is not None:
print(
"[Warning] Specifying ``repetition_penalty`` is deprecated. "
"Pass in ``logits_processors`` instead."
prompt_progress_callback = prompt_progress_callback or (lambda *_: None)
quantize_cache_fn = functools.partial(
maybe_quantize_kv_cache,
quantized_kv_start=quantized_kv_start,
kv_group_size=kv_group_size,
kv_bits=kv_bits,
)
sampler = sampler or make_sampler(
temp or 0.0, top_p or 0.0, min_p or 0.0, min_tokens_to_keep or 1
)
logits_processors = logits_processors or make_logits_processors(
None, repetition_penalty, repetition_context_size or 20
)
prompt_progress_callback = prompt_progress_callback or (lambda *_: None)
sampler = sampler or (lambda x: mx.argmax(x, axis=-1))
def _step(y):
with mx.stream(generation_stream):
@@ -287,9 +286,7 @@ def generate_step(
for processor in logits_processors:
logits = processor(tokens, logits)
maybe_quantize_kv_cache(
prompt_cache, quantized_kv_start, kv_group_size, kv_bits
)
quantize_cache_fn(prompt_cache)
logprobs = logits - mx.logsumexp(logits, keepdims=True)
y = sampler(logprobs)
@@ -300,9 +297,7 @@ def generate_step(
prompt_processed_tokens = 0
while y.size > prefill_step_size:
model(y[:prefill_step_size][None], cache=prompt_cache)
maybe_quantize_kv_cache(
prompt_cache, quantized_kv_start, kv_group_size, kv_bits
)
quantize_cache_fn(prompt_cache)
mx.eval([c.state for c in prompt_cache])
prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens)
prompt_processed_tokens += prefill_step_size
@@ -329,10 +324,162 @@ def generate_step(
n += 1
def speculative_generate_step(
prompt: mx.array,
model: nn.Module,
draft_model: nn.Module,
*,
num_draft_tokens=2,
max_tokens: int = 256,
sampler: Optional[Callable[mx.array, mx.array]] = None,
logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
prompt_cache: Optional[Any] = None,
prefill_step_size: int = 512,
kv_bits: Optional[int] = None,
kv_group_size: int = 64,
quantized_kv_start: int = 0,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
"""
A generator producing token ids based on the given prompt from the model.
Args:
prompt (mx.array): The input prompt.
model (nn.Module): The model to use for generation.
draft_model (nn.Module): The draft model for speculative decoding.
num_draft_tokens (int, optional): The number of draft tokens for
speculative decoding. Default: ``2``.
max_tokens (int): The maximum number of tokens. Use``-1`` for an infinite
generator. Default: ``256``.
sampler (Callable[mx.array, mx.array], optional): A sampler for sampling a
token from a vector of log probabilities. Default: ``None``.
logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
A list of functions that take tokens and logits and return the processed
logits. Default: ``None``.
prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
provided, the cache will be updated in place. The cache must be trimmable.
prefill_step_size (int): Step size for processing the prompt.
kv_bits (int, optional): Number of bits to use for KV cache quantization.
None implies no cache quantization. Default: ``None``.
kv_group_size (int): Group size for KV cache quantization. Default: ``64``.
quantized_kv_start (int): Step to begin using a quantized KV cache.
when ``kv_bits`` is non-None. Default: ``0``.
Yields:
Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
"""
y = prompt
tokens = None
# Create the KV cache for generation
if prompt_cache is None:
model_cache = cache.make_prompt_cache(model)
draft_cache = cache.make_prompt_cache(draft_model)
elif len(prompt_cache) != (len(model.layers) + len(draft_model.layers)):
raise ValueError("Wrong number of layers in the prompt cache.")
else:
model_cache = prompt_cache[: len(model.layers)]
draft_cache = prompt_cache[len(model.layers) :]
sampler = sampler or (lambda x: mx.argmax(x, axis=-1))
quantize_cache_fn = functools.partial(
maybe_quantize_kv_cache,
quantized_kv_start=quantized_kv_start,
kv_group_size=kv_group_size,
kv_bits=kv_bits,
)
def _step(model, cache, y, n_predict=1):
with mx.stream(generation_stream):
logits = model(y[None], cache=cache)
logits = logits[:, -n_predict:, :]
quantize_cache_fn(cache)
logprobs = logits - mx.logsumexp(logits, keepdims=True)
y = sampler(logprobs).squeeze(0)
return y, logprobs.squeeze(0)
def _prefill(model, cache, y):
while y.size > prefill_step_size:
model(y[:prefill_step_size][None], cache=cache)
quantize_cache_fn(cache)
mx.eval([c.state for c in cache])
y = y[prefill_step_size:]
mx.metal.clear_cache()
return y
def _rewind_cache(num_draft, num_accept):
cache.trim_prompt_cache(model_cache, num_draft - num_accept)
cache.trim_prompt_cache(draft_cache, max(num_draft - num_accept - 1, 0))
def _draft_generate(y, num_draft):
if num_draft == 0:
return mx.array([], mx.uint32)
ys = []
for _ in range(num_draft):
y, _ = _step(draft_model, draft_cache, y)
mx.async_eval(y)
ys.append(y)
return mx.concatenate(ys)
with mx.stream(generation_stream):
draft_y = _prefill(draft_model, draft_cache, y)
y = _prefill(model, model_cache, y)
ntoks = 0
# Set these so the finally block doesn't raise
num_draft = 0
n = 0
try:
while True:
num_draft = min(max_tokens - ntoks, num_draft_tokens)
draft_tokens = _draft_generate(draft_y, num_draft)
y = mx.concatenate([y, draft_tokens])
tokens, logprobs = _step(model, model_cache, y, num_draft + 1)
mx.eval(tokens, draft_tokens)
draft_tokens = draft_tokens.tolist()
tokens = tokens.tolist()
n = 0
while n < num_draft:
tn, dtn, lpn = tokens[n], draft_tokens[n], logprobs[n]
if tn != dtn:
break
n += 1
ntoks += 1
yield tn, lpn
if ntoks == max_tokens:
break
if ntoks < max_tokens:
ntoks += 1
yield tokens[n], logprobs[n]
if ntoks == max_tokens:
break
y = mx.array([tokens[n]], mx.uint32)
draft_y = y
# If we accpeted all the draft tokens, include the last
# draft token in the next draft step since it hasn't been
# processed yet by the draft model
if n == num_draft:
draft_y = mx.concatenate(
[mx.array(draft_tokens[-1:], mx.uint32), draft_y]
)
_rewind_cache(num_draft, n)
finally:
_rewind_cache(num_draft, n)
def stream_generate(
model: nn.Module,
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
prompt: Union[str, mx.array, List[int]],
draft_model: Optional[nn.Module] = None,
**kwargs,
) -> Generator[GenerationResponse, None, None]:
"""
@@ -341,7 +488,11 @@ def stream_generate(
Args:
model (nn.Module): The model to use for generation.
tokenizer (PreTrainedTokenizer): The tokenizer.
prompt (Union[str, mx.array, List[int]]): The input prompt string or integer tokens.
prompt (Union[str, mx.array, List[int]]): The input prompt string or
integer tokens.
draft_model (Optional[nn.Module]): An optional draft model. If provided
then speculative decoding is used. The draft model must use the same
tokenizer as the main model. Default: ``None``.
kwargs: The remaining options get passed to :func:`generate_step`.
See :func:`generate_step` for more details.
@@ -353,16 +504,28 @@ def stream_generate(
tokenizer = TokenizerWrapper(tokenizer)
if not isinstance(prompt, mx.array):
prompt = mx.array(
prompt if isinstance(prompt, list) else tokenizer.encode(prompt)
if isinstance(prompt, str):
# Try to infer if special tokens are needed
add_special_tokens = tokenizer.bos_token is None or not prompt.startswith(
tokenizer.bos_token
)
prompt = tokenizer.encode(prompt, add_special_tokens=add_special_tokens)
prompt = mx.array(prompt)
detokenizer = tokenizer.detokenizer
if draft_model is None:
kwargs.pop("num_draft_tokens", None)
token_generator = generate_step(prompt, model, **kwargs)
else:
kwargs.pop("max_kv_size", None)
token_generator = speculative_generate_step(
prompt, model, draft_model, **kwargs
)
with wired_limit(model, [generation_stream]):
detokenizer.reset()
tic = time.perf_counter()
for n, (token, logprobs) in enumerate(generate_step(prompt, model, **kwargs)):
for n, (token, logprobs) in enumerate(token_generator):
if n == 0:
prompt_time = time.perf_counter() - tic
prompt_tps = prompt.size / prompt_time
@@ -401,7 +564,7 @@ def stream_generate(
def generate(
model: nn.Module,
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
prompt: str,
prompt: Union[str, List[int]],
verbose: bool = False,
formatter: Optional[Callable] = None,
**kwargs,
@@ -412,7 +575,7 @@ def generate(
Args:
model (nn.Module): The language model.
tokenizer (PreTrainedTokenizer): The tokenizer.
prompt (str): The string prompt.
prompt (Union[str, List[int]]): The input prompt string or integer tokens.
verbose (bool): If ``True``, print tokens and timing information.
Default: ``False``.
kwargs: The remaining options get passed to :func:`stream_generate`.
@@ -425,7 +588,6 @@ def generate(
)
if verbose:
print("=" * 10)
print("Prompt:", prompt)
text = ""
for response in stream_generate(model, tokenizer, prompt, **kwargs):
@@ -558,7 +720,7 @@ def load(
Defaults to an empty dictionary.
adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers
to the model. Default: ``None``.
lazy (bool): If False eval the model parameters to make sure they are
lazy (bool): If ``False`` eval the model parameters to make sure they are
loaded in memory before returning, otherwise they will be loaded
when needed. Default: ``False``
Returns:
@@ -652,12 +814,12 @@ def upload_to_hub(path: str, upload_repo: str, hf_path: str):
model, tokenizer = load("{upload_repo}")
prompt="hello"
prompt = "hello"
if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template is not None:
if tokenizer.chat_template is not None:
messages = [{{"role": "user", "content": prompt}}]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
messages, add_generation_prompt=True
)
response = generate(model, tokenizer, prompt=prompt, verbose=True)
@@ -670,12 +832,10 @@ def upload_to_hub(path: str, upload_repo: str, hf_path: str):
api = HfApi()
api.create_repo(repo_id=upload_repo, exist_ok=True)
api.upload_folder(
api.upload_large_folder(
folder_path=path,
repo_id=upload_repo,
repo_type="model",
multi_commits=True,
multi_commits_verbose=True,
)
print(f"Upload successful, go to https://huggingface.co/{upload_repo} for details.")

View File

@@ -27,8 +27,8 @@ setup(
packages=["mlx_lm", "mlx_lm.models", "mlx_lm.tuner"],
python_requires=">=3.8",
extras_require={
"testing": ["datasets"],
"evaluation": ["lm-eval"],
"test": ["datasets"],
"evaluate": ["lm-eval", "tqdm"],
},
entry_points={
"console_scripts": [

View File

@@ -36,7 +36,8 @@ class TestDatasets(unittest.TestCase):
data = {"text": "This is an example for the model."}
self.save_data(4 * [data])
args = types.SimpleNamespace(train=True, test=False, data=self.test_dir)
train, valid, test = datasets.load_dataset(args, None)
tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_PATH)
train, valid, test = datasets.load_dataset(args, tokenizer)
self.assertEqual(len(train), 4)
self.assertEqual(len(valid), 4)
self.assertEqual(len(test), 0)
@@ -82,6 +83,8 @@ class TestDatasets(unittest.TestCase):
"name": "billsum",
"prompt_feature": "text",
"completion_feature": "summary",
"train_split": "train[:2%]",
"valid_split": "train[-2%:]",
},
test=False,
train=True,

View File

@@ -5,6 +5,7 @@ import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_map
from mlx_lm.models import rope_utils
from mlx_lm.models.base import create_causal_mask
from mlx_lm.models.cache import KVCache, RotatingKVCache, make_prompt_cache
@@ -128,6 +129,22 @@ class TestModels(unittest.TestCase):
self.assertEqual(cache.offset, 22)
self.assertTrue(mx.allclose(x, k[..., -2:, :]))
def test_causal_mask_lengths(self):
mx.random.seed(8)
B, N_q, T_q, N_kv, T_kv, D = (4, 8, 3, 2, 3, 2)
lengths = mx.array([1, 2, 3, 1])
q = mx.random.uniform(shape=(B, N_q, T_q, D))
k = mx.random.uniform(shape=(B, N_kv, T_kv, D))
v = k
mask = create_causal_mask(T_q, 0, lengths=lengths)
out1 = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask)
q[1, :, 2:] = mx.ones_like(q[1, :, 2:])
k[1, :, 2:] = mx.ones_like(k[1, :, 2:])
v[1, :, 2:] = mx.ones_like(v[1, :, 2:])
out2 = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask)
self.assertTrue(mx.allclose(out1[1, :, :2], out2[1, :, :2]))
def test_rope(self):
rope = rope_utils.initialize_rope(32, base=100, traditional=False)
self.assertTrue(isinstance(rope, nn.RoPE))
@@ -162,7 +179,13 @@ class TestModels(unittest.TestCase):
self.assertEqual(outputs.dtype, t)
cache = make_prompt_cache(model)
outputs = model(inputs, cache)
outputs = model(inputs, cache=cache)
self.assertEqual(outputs.shape, (1, 2, vocab_size))
self.assertEqual(outputs.dtype, t)
if model_type != "mamba":
mask = create_causal_mask(inputs.shape[1], 0).astype(t)
outputs = model(inputs, mask=mask)
self.assertEqual(outputs.shape, (1, 2, vocab_size))
self.assertEqual(outputs.dtype, t)
@@ -659,6 +682,43 @@ class TestModels(unittest.TestCase):
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_deepseek_v3(self):
from mlx_lm.models import deepseek_v3
args = deepseek_v3.ModelArgs(
model_type="deepseek_v3",
vocab_size=1024,
hidden_size=128,
intermediate_size=256,
moe_intermediate_size=256,
num_hidden_layers=4,
num_attention_heads=4,
num_key_value_heads=2,
n_routed_experts=4,
n_group=2,
topk_group=1,
num_experts_per_tok=2,
n_shared_experts=1,
kv_lora_rank=4,
q_lora_rank=4,
qk_rope_head_dim=32,
v_head_dim=16,
qk_nope_head_dim=32,
rope_scaling={
"beta_fast": 32,
"beta_slow": 1,
"factor": 40,
"mscale": 1.0,
"mscale_all_dim": 1.0,
"original_max_position_embeddings": 4096,
"type": "yarn",
},
)
model = deepseek_v3.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_gemma2(self):
from mlx_lm.models import gemma2

View File

@@ -1,7 +1,7 @@
import unittest
import mlx.core as mx
from mlx_lm.sample_utils import min_p_sampling, top_p_sampling
from mlx_lm.sample_utils import min_p_sampling, top_k_sampling, top_p_sampling
class TestSampleUtils(unittest.TestCase):
@@ -42,6 +42,27 @@ class TestSampleUtils(unittest.TestCase):
token = min_p_sampling(logits, 0.05)
self.assertTrue(token in (0, 3))
def test_top_k_sampling(self):
probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
logits = mx.log(probs)
token = top_k_sampling(logits, 1).item()
self.assertEqual(token, 0)
probs = mx.array([0.5, 0.0, 0.0, 0.5])[None]
tokens = set()
for _ in range(100):
token = top_k_sampling(logits, 2)
tokens.add(token.item())
self.assertEqual(tokens, {0, 3})
# Batch mode works
probs = mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.8, 0.0, 0.1]])
logits = mx.log(probs)
tokens = top_k_sampling(logits, 1)
self.assertEqual(tokens.tolist(), [0, 1])
if __name__ == "__main__":
unittest.main()

View File

@@ -17,7 +17,7 @@ class TestLoadModelCustomGetClasses(unittest.TestCase):
self.config = args
self.custom_attribute = "This is a custom model"
def load_weights(self, weights):
def load_weights(self, weights, **kwargs):
self.qwenWeights = weights
class CustomQwenConfig: