Merge branch 'ml-explore:main' into adding-support-for-mamba2

Gökdeniz Gülmez, 2025-01-13 20:16:04 +01:00, committed by GitHub
59 changed files with 1272 additions and 279 deletions

View File

@@ -32,7 +32,7 @@ jobs:
pip install --upgrade pip pip install --upgrade pip
pip install unittest-xml-reporting pip install unittest-xml-reporting
cd llms/ cd llms/
pip install -e ".[testing]" pip install -e ".[test]"
- run: - run:
name: Run Python tests name: Run Python tests
command: | command: |

View File

@@ -58,7 +58,7 @@ prompt = "Write a story about Einstein"
messages = [{"role": "user", "content": prompt}] messages = [{"role": "user", "content": prompt}]
prompt = tokenizer.apply_chat_template( prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True messages, add_generation_prompt=True
) )
text = generate(model, tokenizer, prompt=prompt, verbose=True) text = generate(model, tokenizer, prompt=prompt, verbose=True)
@@ -115,7 +115,7 @@ prompt = "Write a story about Einstein"
messages = [{"role": "user", "content": prompt}] messages = [{"role": "user", "content": prompt}]
prompt = tokenizer.apply_chat_template( prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True messages, add_generation_prompt=True
) )
for response in stream_generate(model, tokenizer, prompt, max_tokens=512): for response in stream_generate(model, tokenizer, prompt, max_tokens=512):
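
Taken together, these README updates reflect that `apply_chat_template` is now called without `tokenize=False`, so the returned prompt is already tokenized and can be passed straight to `generate` or `stream_generate`. A minimal end-to-end sketch of the updated usage (the model repo name is only a placeholder):

```python
from mlx_lm import load, generate, stream_generate

# Placeholder repo; any chat model converted for MLX works the same way.
model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

messages = [{"role": "user", "content": "Write a story about Einstein"}]
# Without tokenize=False the template is applied and token ids are returned,
# which generate/stream_generate accept directly.
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)

text = generate(model, tokenizer, prompt=prompt, verbose=True)

for response in stream_generate(model, tokenizer, prompt, max_tokens=512):
    print(response.text, end="", flush=True)
```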

View File

@@ -266,14 +266,25 @@ Refer to the documentation for the model you are fine-tuning for more details.
{"prompt": "What is the capital of France?", "completion": "Paris."} {"prompt": "What is the capital of France?", "completion": "Paris."}
``` ```
For the `completions` data format, a different key can be used for the prompt
and completion by specifying the following in the YAML config:
```yaml
prompt_feature: "input"
completion_feature: "output"
```
Here, `"input"` is the expected key instead of the default `"prompt"`, and
`"output"` is the expected key instead of `"completion"`.
`text`: `text`:
```jsonl ```jsonl
{"text": "This is an example for the model."} {"text": "This is an example for the model."}
``` ```
Note, the format is automatically determined by the dataset. Note also, keys in Note, the format is automatically determined by the dataset. Note also, keys
each line not expected by the loader will be ignored. in each line not expected by the loader will be ignored.
> [!NOTE] > [!NOTE]
> Each example in the datasets must be on a single line. Do not put more than > Each example in the datasets must be on a single line. Do not put more than
@@ -295,7 +306,7 @@ Otherwise, provide a mapping of keys in the dataset to the features MLX LM
expects. Use a YAML config to specify the Hugging Face dataset arguments. For expects. Use a YAML config to specify the Hugging Face dataset arguments. For
example: example:
``` ```yaml
hf_dataset: hf_dataset:
name: "billsum" name: "billsum"
prompt_feature: "text" prompt_feature: "text"

View File

@@ -1,3 +1,3 @@
# Copyright © 2023-2024 Apple Inc. # Copyright © 2023-2024 Apple Inc.
__version__ = "0.20.4" __version__ = "0.21.0"

View File

@@ -110,29 +110,17 @@ def main():
if tokenizer.chat_template is None: if tokenizer.chat_template is None:
tokenizer.chat_template = tokenizer.default_chat_template tokenizer.chat_template = tokenizer.default_chat_template
if not args.ignore_chat_template and ( if not args.ignore_chat_template and tokenizer.chat_template is not None:
hasattr(tokenizer, "apply_chat_template")
and tokenizer.chat_template is not None
):
messages = [{"role": "user", "content": args.prompt}] messages = [{"role": "user", "content": args.prompt}]
prompt = tokenizer.apply_chat_template( prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True messages, add_generation_prompt=False, continue_final_message=True
) )
# Treat the prompt as a prefix assuming that the suffix will be
# provided at generation time.
test_prompt = tokenizer.apply_chat_template(
[{"role": "user", "content": "<query>"}],
tokenize=False,
add_generation_prompt=True,
)
n = len(test_prompt) - test_prompt.index("<query>") - len("<query>")
prompt = prompt[:-n]
else: else:
prompt = args.prompt prompt = tokenizer.encode(args.prompt)
cache = make_prompt_cache(model, args.max_kv_size) cache = make_prompt_cache(model, args.max_kv_size)
y = mx.array(tokenizer.encode(prompt)) y = mx.array(prompt)
# Process the prompt # Process the prompt
start = time.time() start = time.time()

View File

@@ -72,9 +72,7 @@ def main():
if query == "q": if query == "q":
break break
messages = [{"role": "user", "content": query}] messages = [{"role": "user", "content": query}]
prompt = tokenizer.apply_chat_template( prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
messages, tokenize=False, add_generation_prompt=True
)
for response in stream_generate( for response in stream_generate(
model, model,
tokenizer, tokenizer,

View File

@@ -1,4 +1,8 @@
# Adapted from a PyTorch implementation by David Grangier # Copyright © 2024 Apple Inc.
"""
Adapted from a PyTorch implementation by David Grangier
"""
import argparse import argparse
import json import json
@@ -6,7 +10,7 @@ import logging
import os import os
from importlib.metadata import version from importlib.metadata import version
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional, Union
import lm_eval import lm_eval
import mlx.core as mx import mlx.core as mx
@@ -73,15 +77,19 @@ class MLXLM(LM):
path_or_hf_repo: str, path_or_hf_repo: str,
batch_size: int = 16, batch_size: int = 16,
max_tokens: Optional[int] = None, max_tokens: Optional[int] = None,
use_chat_template: Optional[bool] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self._batch_size = batch_size self._batch_size = batch_size
self._model, self._tokenizer = load(path_or_hf_repo) self._model, self.tokenizer = load(path_or_hf_repo)
self._max_tokens = max_tokens or self._tokenizer.model_max_length self._max_tokens = max_tokens or self.tokenizer.model_max_length
self.use_chat_template = use_chat_template or (
self.tokenizer.chat_template is not None
)
def _score_fn(self, inputs, tokenize=True, step_size=32): def _score_fn(self, inputs, tokenize=True, step_size=32):
if tokenize: if tokenize:
inputs = self._tokenizer.encode(inputs) inputs = self._tokenize(inputs)
inputs = _pad_inputs(inputs, self._max_tokens, truncate=False) inputs = _pad_inputs(inputs, self._max_tokens, truncate=False)
inputs = mx.array(inputs) inputs = mx.array(inputs)
inputs, targets = inputs[..., :-1], inputs[..., 1:] inputs, targets = inputs[..., :-1], inputs[..., 1:]
@@ -145,7 +153,12 @@ class MLXLM(LM):
return results return results
def _tokenize(self, texts): def _tokenize(self, texts):
return [tuple(self._tokenizer.encode(t)) for t in texts] return [
tuple(
self.tokenizer.encode(t, add_special_tokens=not self.use_chat_template)
)
for t in texts
]
def loglikelihood(self, requests) -> list[tuple[float, bool]]: def loglikelihood(self, requests) -> list[tuple[float, bool]]:
"""Compute log-likelihood of generating a continuation from a context. """Compute log-likelihood of generating a continuation from a context.
@@ -217,6 +230,9 @@ class MLXLM(LM):
) )
return [(r[0], r[1] == r[2]) for r in results] return [(r[0], r[1] == r[2]) for r in results]
tokenizer_name = lm_eval.models.huggingface.HFLM.tokenizer_name
apply_chat_template = lm_eval.models.huggingface.HFLM.apply_chat_template
def loglikelihood_rolling(self, requests) -> list[float]: def loglikelihood_rolling(self, requests) -> list[float]:
"""Compute full log-likelihood of a string, with no truncation, for perplexity computation """Compute full log-likelihood of a string, with no truncation, for perplexity computation
- We will use the full max context length of the model. - We will use the full max context length of the model.
@@ -277,23 +293,16 @@ class MLXLM(LM):
assert "until" in keys assert "until" in keys
untils = [x["until"] for x in options] untils = [x["until"] for x in options]
completions = [] completions = []
for context, until in tqdm(zip(contexts, untils), total=len(contexts)):
if (
hasattr(self._tokenizer, "apply_chat_template")
and self._tokenizer.chat_template is not None
):
messages = [{"role": "user", "content": context}]
context = self._tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
for context, until in tqdm(zip(contexts, untils), total=len(contexts)):
context = self._tokenize(context)
max_tokens = min( max_tokens = min(
self._max_tokens, self._max_tokens,
self._tokenizer.model_max_length - len(self._tokenizer.encode(context)), self.tokenizer.model_max_length - len(context),
) )
text = "" text = ""
for response in stream_generate( for response in stream_generate(
self._model, self._tokenizer, prompt=context, max_tokens=max_tokens self._model, self.tokenizer, prompt=context, max_tokens=max_tokens
): ):
text += response.text text += response.text
if any(u in text for u in until): if any(u in text for u in until):
@@ -321,7 +330,28 @@ def main():
type=int, type=int,
help="Maximum nunber of tokens to generate. Defaults to the model's max context length.", help="Maximum nunber of tokens to generate. Defaults to the model's max context length.",
) )
parser.add_argument(
"--limit",
default=1.0,
help="Limit the number of examples per task.",
type=float,
)
parser.add_argument("--seed", type=int, default=123, help="Random seed.") parser.add_argument("--seed", type=int, default=123, help="Random seed.")
parser.add_argument(
"--fewshot-as-multiturn",
action="store_true",
help="Whether to provide the fewshot examples as a multiturn "
"conversation or a single user turn.",
default=False,
)
parser.add_argument(
"--apply-chat-template",
action=argparse.BooleanOptionalAction,
help="Specifies whether to apply a chat template to the prompt. If "
"the model has a chat template, this defaults to `True`, "
"otherwise `False`.",
default=None,
)
args = parser.parse_args() args = parser.parse_args()
output_dir = Path(args.output_dir) output_dir = Path(args.output_dir)
@@ -332,12 +362,19 @@ def main():
mx.random.seed(args.seed) mx.random.seed(args.seed)
lm = MLXLM(args.model, batch_size=args.batch_size, max_tokens=args.max_tokens) lm = MLXLM(
args.model,
batch_size=args.batch_size,
max_tokens=args.max_tokens,
use_chat_template=args.apply_chat_template,
)
results = lm_eval.simple_evaluate( results = lm_eval.simple_evaluate(
model=lm, model=lm,
tasks=args.tasks, tasks=args.tasks,
fewshot_as_multiturn=args.fewshot_as_multiturn,
apply_chat_template=lm.use_chat_template,
num_fewshot=args.num_shots, num_fewshot=args.num_shots,
limit=args.limit,
random_seed=args.seed, random_seed=args.seed,
numpy_random_seed=args.seed, numpy_random_seed=args.seed,
torch_random_seed=args.seed, torch_random_seed=args.seed,
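
The new evaluation options are passed straight through to `lm_eval`. A rough programmatic equivalent of what the updated `main()` does, with placeholder model and task names and an assumed import path for `MLXLM`:

```python
import lm_eval

from mlx_lm.evaluate import MLXLM  # import path assumed

lm = MLXLM(
    "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",  # placeholder model
    batch_size=16,
    max_tokens=None,
    use_chat_template=None,  # None -> use the template iff the tokenizer has one
)

results = lm_eval.simple_evaluate(
    model=lm,
    tasks=["arc_easy"],  # placeholder task list
    fewshot_as_multiturn=False,
    apply_chat_template=lm.use_chat_template,
    num_fewshot=None,
    limit=1.0,  # matches the new --limit default
    random_seed=123,
    numpy_random_seed=123,
    torch_random_seed=123,
)
```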

View File

@@ -15,9 +15,7 @@ prompt_cache = make_prompt_cache(model)
# User turn # User turn
prompt = "Hi my name is <Name>." prompt = "Hi my name is <Name>."
messages = [{"role": "user", "content": prompt}] messages = [{"role": "user", "content": prompt}]
prompt = tokenizer.apply_chat_template( prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
messages, tokenize=False, add_generation_prompt=True
)
# Assistant response # Assistant response
response = generate( response = generate(
@@ -32,9 +30,7 @@ response = generate(
# User turn # User turn
prompt = "What's my name?" prompt = "What's my name?"
messages = [{"role": "user", "content": prompt}] messages = [{"role": "user", "content": prompt}]
prompt = tokenizer.apply_chat_template( prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
messages, tokenize=False, add_generation_prompt=True
)
# Assistant response # Assistant response
response = generate( response = generate(

View File

@@ -14,7 +14,7 @@ conversation = [{"role": "user", "content": prompt}]
# Transform the prompt into the chat template # Transform the prompt into the chat template
prompt = tokenizer.apply_chat_template( prompt = tokenizer.apply_chat_template(
conversation=conversation, tokenize=False, add_generation_prompt=True conversation=conversation, add_generation_prompt=True
) )
# Specify the maximum number of tokens # Specify the maximum number of tokens

View File

@@ -0,0 +1,75 @@
# Copyright © 2024 Apple Inc.
"""
Run with:
```
/path/to/mpirun \
-np 2 \
--hostfile /path/to/hosts.txt \
python /path/to/pipeline_generate.py --prompt "hello world"
```
Make sure you can run MLX over MPI on two hosts. For more information see the
documentation:
https://ml-explore.github.io/mlx/build/html/usage/distributed.html
"""
import argparse
import mlx.core as mx
from mlx_lm import load, stream_generate
parser = argparse.ArgumentParser(description="LLM pipelined inference example")
parser.add_argument(
"--prompt",
"-p",
default="Write a quicksort in C++.",
help="Message to be processed by the model ('-' reads from stdin)",
)
parser.add_argument(
"--max-tokens",
"-m",
type=int,
default=256,
help="Maximum number of tokens to generate",
)
args = parser.parse_args()
model_repo = "mlx-community/DeepSeek-V3-3bit"
model, tokenizer = load(model_repo, lazy=True)
messages = [{"role": "user", "content": args.prompt}]
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
group = mx.distributed.init()
rank = group.rank()
model.model.pipeline(group)
mx.eval(model.parameters())
# Synchronize processes before generation to avoid timeout if downloading
# model for the first time.
mx.eval(mx.distributed.all_sum(mx.array(1.0), stream=mx.cpu))
def rprint(*args, **kwargs):
if rank == 0:
print(*args, **kwargs)
for response in stream_generate(model, tokenizer, prompt, max_tokens=args.max_tokens):
rprint(response.text, end="", flush=True)
rprint()
rprint("=" * 10)
rprint(
f"Prompt: {response.prompt_tokens} tokens, "
f"{response.prompt_tps:.3f} tokens-per-sec"
)
rprint(
f"Generation: {response.generation_tokens} tokens, "
f"{response.generation_tps:.3f} tokens-per-sec"
)
rprint(f"Peak memory: {response.peak_memory:.3f} GB")

View File

@@ -1,7 +1,6 @@
# Copyright © 2023-2024 Apple Inc. # Copyright © 2023-2024 Apple Inc.
import argparse import argparse
import codecs
import json import json
import sys import sys
@@ -44,10 +43,11 @@ def setup_arg_parser():
help="Optional path for the trained adapter weights and config.", help="Optional path for the trained adapter weights and config.",
) )
parser.add_argument( parser.add_argument(
"--eos-token", "--extra-eos-token",
type=str, type=str,
default=None, default=(),
help="End of sequence token for tokenizer", nargs="+",
help="Add tokens in the list of eos tokens that stop generation.",
) )
parser.add_argument( parser.add_argument(
"--system-prompt", "--system-prompt",
@@ -131,6 +131,18 @@ def setup_arg_parser():
type=int, type=int,
default=DEFAULT_QUANTIZED_KV_START, default=DEFAULT_QUANTIZED_KV_START,
) )
parser.add_argument(
"--draft-model",
type=str,
help="A model to be used for speculative decoding.",
default=None,
)
parser.add_argument(
"--num-draft-tokens",
type=int,
help="Number of tokens to draft when using speculative decoding.",
default=2,
)
return parser return parser
@@ -162,8 +174,6 @@ def main():
{} if not using_cache else json.loads(metadata["tokenizer_config"]) {} if not using_cache else json.loads(metadata["tokenizer_config"])
) )
tokenizer_config["trust_remote_code"] = True tokenizer_config["trust_remote_code"] = True
if args.eos_token is not None:
tokenizer_config["eos_token"] = args.eos_token
model_path = args.model model_path = args.model
if using_cache: if using_cache:
@@ -182,6 +192,8 @@ def main():
adapter_path=args.adapter_path, adapter_path=args.adapter_path,
tokenizer_config=tokenizer_config, tokenizer_config=tokenizer_config,
) )
for eos_token in args.extra_eos_token:
tokenizer.add_eos_token(eos_token)
if args.use_default_chat_template: if args.use_default_chat_template:
if tokenizer.chat_template is None: if tokenizer.chat_template is None:
@@ -189,22 +201,14 @@ def main():
elif using_cache: elif using_cache:
tokenizer.chat_template = metadata["chat_template"] tokenizer.chat_template = metadata["chat_template"]
prompt = codecs.decode(args.prompt, "unicode_escape") prompt = args.prompt.replace("\\n", "\n").replace("\\t", "\t")
prompt = sys.stdin.read() if prompt == "-" else prompt
if not args.ignore_chat_template and ( if not args.ignore_chat_template and tokenizer.chat_template is not None:
hasattr(tokenizer, "apply_chat_template")
and tokenizer.chat_template is not None
):
if args.system_prompt is not None: if args.system_prompt is not None:
messages = [{"role": "system", "content": args.system_prompt}] messages = [{"role": "system", "content": args.system_prompt}]
else: else:
messages = [] messages = []
messages.append( messages.append({"role": "user", "content": prompt})
{
"role": "user",
"content": sys.stdin.read() if prompt == "-" else prompt,
}
)
prompt = tokenizer.apply_chat_template( prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True messages, tokenize=False, add_generation_prompt=True
) )
@@ -219,7 +223,16 @@ def main():
add_generation_prompt=True, add_generation_prompt=True,
) )
prompt = prompt[test_prompt.index("<query>") :] prompt = prompt[test_prompt.index("<query>") :]
prompt = tokenizer.encode(prompt, add_special_tokens=False)
else:
prompt = tokenizer.encode(prompt)
if args.draft_model is not None:
draft_model, draft_tokenizer = load(args.draft_model)
if draft_tokenizer.vocab_size != tokenizer.vocab_size:
raise ValueError("Draft model tokenizer does not match model tokenizer.")
else:
draft_model = None
sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep) sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep)
response = generate( response = generate(
model, model,
@@ -233,6 +246,8 @@ def main():
kv_bits=args.kv_bits, kv_bits=args.kv_bits,
kv_group_size=args.kv_group_size, kv_group_size=args.kv_group_size,
quantized_kv_start=args.quantized_kv_start, quantized_kv_start=args.quantized_kv_start,
draft_model=draft_model,
num_draft_tokens=args.num_draft_tokens,
) )
if not args.verbose: if not args.verbose:
print(response) print(response)
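
The same speculative-decoding options can be used from Python by loading a draft model and forwarding it to `generate`. A minimal sketch with placeholder model names; as the new check above enforces, the draft tokenizer must match the target model's vocabulary:

```python
from mlx_lm import load, generate

model, tokenizer = load("mlx-community/Llama-3.1-8B-Instruct-4bit")  # placeholder
draft_model, draft_tokenizer = load("mlx-community/Llama-3.2-1B-Instruct-4bit")  # placeholder

if draft_tokenizer.vocab_size != tokenizer.vocab_size:
    raise ValueError("Draft model tokenizer does not match model tokenizer.")

response = generate(
    model,
    tokenizer,
    prompt="Write a quicksort in C++.",
    draft_model=draft_model,
    num_draft_tokens=2,  # the new CLI default
    verbose=True,
)
```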

View File

@@ -2,6 +2,7 @@
import argparse import argparse
import math import math
import os
import re import re
import types import types
from pathlib import Path from pathlib import Path
@@ -57,6 +58,8 @@ CONFIG_DEFAULTS = {
"test": False, "test": False,
"test_batches": 500, "test_batches": 500,
"max_seq_length": 2048, "max_seq_length": 2048,
"config": None,
"grad_checkpoint": False,
"lr_schedule": None, "lr_schedule": None,
"lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0}, "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
} }
@@ -66,6 +69,7 @@ def build_parser():
parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.") parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")
parser.add_argument( parser.add_argument(
"--model", "--model",
type=str,
help="The path to the local model directory or Hugging Face repo.", help="The path to the local model directory or Hugging Face repo.",
) )
@@ -74,7 +78,6 @@ def build_parser():
"--train", "--train",
action="store_true", action="store_true",
help="Do training", help="Do training",
default=None,
) )
parser.add_argument( parser.add_argument(
"--data", "--data",
@@ -88,7 +91,6 @@ def build_parser():
"--fine-tune-type", "--fine-tune-type",
type=str, type=str,
choices=["lora", "dora", "full"], choices=["lora", "dora", "full"],
default="lora",
help="Type of fine-tuning to perform: lora, dora, or full.", help="Type of fine-tuning to perform: lora, dora, or full.",
) )
parser.add_argument( parser.add_argument(
@@ -133,7 +135,6 @@ def build_parser():
"--test", "--test",
action="store_true", action="store_true",
help="Evaluate on the test set after training", help="Evaluate on the test set after training",
default=None,
) )
parser.add_argument( parser.add_argument(
"--test-batches", "--test-batches",
@@ -148,16 +149,15 @@ def build_parser():
parser.add_argument( parser.add_argument(
"-c", "-c",
"--config", "--config",
default=None, type=str,
help="A YAML configuration file with the training options", help="A YAML configuration file with the training options",
) )
parser.add_argument( parser.add_argument(
"--grad-checkpoint", "--grad-checkpoint",
action="store_true", action="store_true",
help="Use gradient checkpointing to reduce memory use.", help="Use gradient checkpointing to reduce memory use.",
default=None,
) )
parser.add_argument("--seed", type=int, default=None, help="The PRNG seed") parser.add_argument("--seed", type=int, help="The PRNG seed")
return parser return parser
@@ -271,6 +271,7 @@ def run(args, training_callback: TrainingCallback = None):
def main(): def main():
os.environ["TOKENIZERS_PARALLELISM"] = "true"
parser = build_parser() parser = build_parser()
args = parser.parse_args() args = parser.parse_args()
config = args.config config = args.config

View File

@@ -6,19 +6,18 @@ from transformers.commands.user import tabulate
def ask_for_confirmation(message: str) -> bool: def ask_for_confirmation(message: str) -> bool:
"""Ask user for confirmation with Y/N prompt.
Returns True for Y/yes, False for N/no/empty."""
y = ("y", "yes", "1") y = ("y", "yes", "1")
n = ("n", "no", "0") n = ("n", "no", "0", "")
all_values = y + n + ("",) full_message = f"{message} (y/n) "
full_message = f"{message} (Y/n) "
while True: while True:
answer = input(full_message).lower() answer = input(full_message).lower()
if answer == "":
return False
if answer in y: if answer in y:
return True return True
if answer in n: if answer in n:
return False return False
print(f"Invalid input. Must be one of {all_values}") print(f"Invalid input. Must be one of: yes/no/y/n or empty for no")
def main(): def main():
@@ -43,9 +42,7 @@ def main():
args = parser.parse_args() args = parser.parse_args()
if args.scan: if args.scan:
print( print(f'Scanning Hugging Face cache for models with pattern "{args.pattern}".')
"Scanning Hugging Face cache for models with" f'pattern "{args.pattern}".'
)
hf_cache_info = scan_cache_dir() hf_cache_info = scan_cache_dir()
print( print(
tabulate( tabulate(
@@ -86,35 +83,41 @@ def main():
if args.pattern in repo.repo_id if args.pattern in repo.repo_id
] ]
if repos: if repos:
print("\nFound the following models:")
print( print(
tabulate( tabulate(
rows=[ rows=[
[ [
repo.repo_id, repo.repo_id,
repo.size_on_disk_str, # Added size information
str(repo.repo_path), str(repo.repo_path),
] ]
for repo in repos for repo in repos
], ],
headers=[ headers=[
"REPO ID", "REPO ID",
"SIZE", # Added size header
"LOCAL PATH", "LOCAL PATH",
], ],
) )
) )
confirmed = ask_for_confirmation(f"Confirm deletion ?") confirmed = ask_for_confirmation(
"\nAre you sure you want to delete these models?"
)
if confirmed: if confirmed:
for model_info in repos: for model_info in repos:
print(f"\nDeleting {model_info.repo_id}...")
for revision in sorted( for revision in sorted(
model_info.revisions, key=lambda revision: revision.commit_hash model_info.revisions, key=lambda revision: revision.commit_hash
): ):
strategy = hf_cache_info.delete_revisions(revision.commit_hash) strategy = hf_cache_info.delete_revisions(revision.commit_hash)
strategy.execute() strategy.execute()
print("Model(s) deleted.") print("\nModel(s) deleted successfully.")
else: else:
print("Deletion is cancelled. Do nothing.") print("\nDeletion cancelled - no changes made.")
else: else:
print(f"No models found.") print(f'No models found matching pattern "{args.pattern}"')
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -23,7 +23,12 @@ class BaseModelArgs:
) )
def create_causal_mask(N: int, offset: int = 0, window_size: Optional[int] = None): def create_causal_mask(
N: int,
offset: int = 0,
window_size: Optional[int] = None,
lengths: Optional[mx.array] = None,
):
rinds = mx.arange(offset + N) rinds = mx.arange(offset + N)
linds = mx.arange(offset, offset + N) if offset else rinds linds = mx.arange(offset, offset + N) if offset else rinds
linds = linds[:, None] linds = linds[:, None]
@@ -31,6 +36,9 @@ def create_causal_mask(N: int, offset: int = 0, window_size: Optional[int] = Non
mask = linds < rinds mask = linds < rinds
if window_size is not None: if window_size is not None:
mask = mask | (linds > rinds + window_size) mask = mask | (linds > rinds + window_size)
if lengths is not None:
lengths = lengths[:, None, None, None]
mask = mask | (rinds >= lengths)
return mask * -1e9 return mask * -1e9
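
The new `lengths` argument makes it possible to mask out padded positions in a batch on top of the causal (and optional sliding-window) structure. A small usage sketch, with the import path assumed:

```python
import mlx.core as mx

from mlx_lm.models.base import create_causal_mask  # import path assumed

# Two sequences padded to length 4; the first one is only 2 tokens long.
lengths = mx.array([2, 4])
mask = create_causal_mask(4, lengths=lengths)
# mask broadcasts to shape (2, 1, 4, 4): besides the causal structure, every
# key position at or beyond a sequence's true length is set to -1e9.
```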

View File

@@ -155,10 +155,12 @@ class CohereModel(nn.Module):
def __call__( def __call__(
self, self,
inputs: mx.array, inputs: mx.array,
mask: mx.array = None,
cache=None, cache=None,
): ):
h = self.embed_tokens(inputs) h = self.embed_tokens(inputs)
if mask is None:
mask = create_attention_mask(h, cache) mask = create_attention_mask(h, cache)
if cache is None: if cache is None:
@@ -180,9 +182,10 @@ class Model(nn.Module):
def __call__( def __call__(
self, self,
inputs: mx.array, inputs: mx.array,
mask: mx.array = None,
cache=None, cache=None,
): ):
out = self.model(inputs, cache) out = self.model(inputs, mask, cache)
out = self.model.embed_tokens.as_linear(out) out = self.model.embed_tokens.as_linear(out)
out = out * self.model.args.logit_scale out = out * self.model.args.logit_scale
return out return out

View File

@@ -6,7 +6,7 @@ from typing import Optional, Tuple
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from .base import BaseModelArgs, create_causal_mask, scaled_dot_product_attention from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .cache import KVCache, RotatingKVCache from .cache import KVCache, RotatingKVCache
@@ -151,20 +151,18 @@ class CohereModel(nn.Module):
def __call__( def __call__(
self, self,
inputs: mx.array, inputs: mx.array,
mask: mx.array = None,
cache=None, cache=None,
): ):
h = self.embed_tokens(inputs) h = self.embed_tokens(inputs)
T = h.shape[1]
if T > 1:
offset = cache[0].offset if cache else 0
mask = create_causal_mask(T, offset).astype(h.dtype)
else:
mask = None
if cache is None: if cache is None:
cache = [None] * len(self.layers) cache = [None] * len(self.layers)
if mask is None:
j = self.args.sliding_window_pattern
mask = create_attention_mask(h, cache[j - 1 : j])
for layer, c in zip(self.layers, cache): for layer, c in zip(self.layers, cache):
h = layer(h, mask, c) h = layer(h, mask, c)
@@ -181,9 +179,10 @@ class Model(nn.Module):
def __call__( def __call__(
self, self,
inputs: mx.array, inputs: mx.array,
mask: mx.array = None,
cache=None, cache=None,
): ):
out = self.model(inputs, cache) out = self.model(inputs, mask, cache)
out = self.model.embed_tokens.as_linear(out) out = self.model.embed_tokens.as_linear(out)
out = out * self.model.args.logit_scale out = out * self.model.args.logit_scale
return out return out

View File

@@ -197,10 +197,12 @@ class DBRX(nn.Module):
def __call__( def __call__(
self, self,
inputs: mx.array, inputs: mx.array,
mask: mx.array = None,
cache=None, cache=None,
): ):
h = self.wte(inputs) h = self.wte(inputs)
if mask is None:
mask = create_attention_mask(h, cache) mask = create_attention_mask(h, cache)
if cache is None: if cache is None:
@@ -223,9 +225,10 @@ class Model(nn.Module):
def __call__( def __call__(
self, self,
inputs: mx.array, inputs: mx.array,
mask: mx.array = None,
cache=None, cache=None,
): ):
out = self.transformer(inputs, cache) out = self.transformer(inputs, mask, cache)
return self.lm_head(out) return self.lm_head(out)
@property @property

View File

@@ -211,8 +211,10 @@ class DeepseekModel(nn.Module):
self, self,
x: mx.array, x: mx.array,
cache: Optional[Any] = None, cache: Optional[Any] = None,
mask: Optional[mx.array] = None,
) -> mx.array: ) -> mx.array:
h = self.embed_tokens(x) h = self.embed_tokens(x)
if mask is None:
mask = create_attention_mask(h, cache) mask = create_attention_mask(h, cache)
if cache is None: if cache is None:
@@ -236,8 +238,9 @@ class Model(nn.Module):
self, self,
inputs: mx.array, inputs: mx.array,
cache: Optional[Any] = None, cache: Optional[Any] = None,
mask: Optional[mx.array] = None,
): ):
out = self.model(inputs, cache) out = self.model(inputs, cache, mask)
return self.lm_head(out) return self.lm_head(out)
def sanitize(self, weights): def sanitize(self, weights):

View File

@@ -370,8 +370,11 @@ class DeepseekV2Model(nn.Module):
self, self,
x: mx.array, x: mx.array,
cache: Optional[Any] = None, cache: Optional[Any] = None,
mask: Optional[mx.array] = None,
) -> mx.array: ) -> mx.array:
h = self.embed_tokens(x) h = self.embed_tokens(x)
if mask is None:
mask = create_attention_mask(h, cache) mask = create_attention_mask(h, cache)
if cache is None: if cache is None:
@@ -395,8 +398,9 @@ class Model(nn.Module):
self, self,
inputs: mx.array, inputs: mx.array,
cache: Optional[Any] = None, cache: Optional[Any] = None,
mask: Optional[mx.array] = None,
): ):
out = self.model(inputs, cache) out = self.model(inputs, cache, mask)
return self.lm_head(out) return self.lm_head(out)
def sanitize(self, weights): def sanitize(self, weights):

View File

@@ -0,0 +1,460 @@
# Copyright © 2024 Apple Inc.
import math
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .switch_layers import SwitchGLU
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str = "deepseek_v3"
vocab_size: int = 102400
hidden_size: int = 4096
intermediate_size: int = 11008
moe_intermediate_size: int = 1407
num_hidden_layers: int = 30
num_attention_heads: int = 32
num_key_value_heads: int = 32
n_shared_experts: Optional[int] = None
n_routed_experts: Optional[int] = None
routed_scaling_factor: float = 1.0
kv_lora_rank: int = 512
q_lora_rank: int = 1536
qk_rope_head_dim: int = 64
v_head_dim: int = 128
qk_nope_head_dim: int = 128
topk_method: str = "noaux_tc"
scoring_func: str = "sigmoid"
norm_topk_prob: bool = True
n_group: Optional[int] = None
topk_group: Optional[int] = None
num_experts_per_tok: Optional[int] = None
moe_layer_freq: int = 1
first_k_dense_replace: int = 0
max_position_embeddings: int = 2048
rms_norm_eps: float = 1e-6
rope_theta: float = 10000.0
rope_scaling: Dict = None
attention_bias: bool = False
def yarn_find_correction_dim(
num_rotations, dim, base=10000, max_position_embeddings=2048
):
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
2 * math.log(base)
)
def yarn_find_correction_range(
low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
):
low = math.floor(
yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
)
high = math.ceil(
yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
)
return max(low, 0), min(high, dim - 1)
def yarn_get_mscale(scale=1, mscale=1):
if scale <= 1:
return 1.0
return 0.1 * mscale * math.log(scale) + 1.0
def yarn_linear_ramp_mask(min_val, max_val, dim):
if min_val == max_val:
max_val += 0.001 # Prevent singularity
linear_func = (mx.arange(dim, dtype=mx.float32) - min_val) / (max_val - min_val)
return mx.clip(linear_func, 0, 1)
class DeepseekV3YarnRotaryEmbedding(nn.Module):
def __init__(
self,
dim,
max_position_embeddings=2048,
base=10000,
scaling_factor=1.0,
original_max_position_embeddings=4096,
beta_fast=32,
beta_slow=1,
mscale=1,
mscale_all_dim=0,
):
super().__init__()
self.mscale = yarn_get_mscale(scaling_factor, mscale) / yarn_get_mscale(
scaling_factor, mscale_all_dim
)
freq_extra = base ** (mx.arange(0, dim, 2, dtype=mx.float32) / dim)
freq_inter = scaling_factor * base ** (
mx.arange(0, dim, 2, dtype=mx.float32) / dim
)
low, high = yarn_find_correction_range(
beta_fast,
beta_slow,
dim,
base,
original_max_position_embeddings,
)
freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2)
self._freqs = (freq_inter * freq_extra) / (
freq_inter * freq_mask + freq_extra * (1 - freq_mask)
)
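# _freqs blends the unscaled frequencies (freq_extra) with the scaled ones
# (freq_inter); the expression amounts to linearly interpolating their
# inverses, with freq_mask ramping per dimension across the YaRN correction
# range [low, high].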
def __call__(self, x, offset=0):
if self.mscale != 1.0:
x = self.mscale * x
return mx.fast.rope(
x,
x.shape[-1],
traditional=True,
base=None,
scale=1.0,
offset=offset,
freqs=self._freqs,
)
class DeepseekV3Attention(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.q_lora_rank = config.q_lora_rank
self.qk_rope_head_dim = config.qk_rope_head_dim
self.kv_lora_rank = config.kv_lora_rank
self.v_head_dim = config.v_head_dim
self.qk_nope_head_dim = config.qk_nope_head_dim
self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
self.scale = self.q_head_dim**-0.5
if self.q_lora_rank is None:
self.q_proj = nn.Linear(
self.hidden_size, self.num_heads * self.q_head_dim, bias=False
)
else:
self.q_a_proj = nn.Linear(
self.hidden_size, self.q_lora_rank, bias=config.attention_bias
)
self.q_a_layernorm = nn.RMSNorm(self.q_lora_rank)
self.q_b_proj = nn.Linear(
self.q_lora_rank, self.num_heads * self.q_head_dim, bias=False
)
self.kv_a_proj_with_mqa = nn.Linear(
self.hidden_size,
self.kv_lora_rank + self.qk_rope_head_dim,
bias=config.attention_bias,
)
self.kv_a_layernorm = nn.RMSNorm(self.kv_lora_rank)
self.kv_b_proj = nn.Linear(
self.kv_lora_rank,
self.num_heads
* (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
bias=False,
)
self.o_proj = nn.Linear(
self.num_heads * self.v_head_dim,
self.hidden_size,
bias=config.attention_bias,
)
mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
scaling_factor = self.config.rope_scaling["factor"]
if mscale_all_dim:
mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
self.scale = self.scale * mscale * mscale
rope_kwargs = {
key: self.config.rope_scaling[key]
for key in [
"original_max_position_embeddings",
"beta_fast",
"beta_slow",
"mscale",
"mscale_all_dim",
]
if key in self.config.rope_scaling
}
self.rope = DeepseekV3YarnRotaryEmbedding(
dim=self.qk_rope_head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta,
**rope_kwargs,
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
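# Multi-head latent attention: queries optionally pass through a low-rank
# (q_lora_rank) bottleneck, keys and values are reconstructed from a
# compressed kv_lora_rank latent, and only the qk_rope_head_dim slice of each
# head (shared across heads for keys) carries rotary position information.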
if self.q_lora_rank is None:
q = self.q_proj(x)
else:
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(x)))
q = q.reshape(B, L, self.num_heads, self.q_head_dim).transpose(0, 2, 1, 3)
q_nope, q_pe = mx.split(q, [self.qk_nope_head_dim], axis=-1)
compressed_kv = self.kv_a_proj_with_mqa(x)
compressed_kv, k_pe = mx.split(compressed_kv, [self.kv_lora_rank], axis=-1)
k_pe = k_pe.reshape(B, L, 1, self.qk_rope_head_dim).transpose(0, 2, 1, 3)
kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
kv = kv.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1)
if cache is not None:
q_pe = self.rope(q_pe, cache.offset)
k_pe = self.rope(k_pe, cache.offset)
k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
keys, values = cache.update_and_fetch(
mx.concatenate([k_nope, k_pe], axis=-1), values
)
else:
q_pe = self.rope(q_pe)
k_pe = self.rope(k_pe)
k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
keys = mx.concatenate([k_nope, k_pe], axis=-1)
queries = mx.concatenate([q_nope, q_pe], axis=-1)
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
class DeepseekV3MLP(nn.Module):
def __init__(
self, config: ModelArgs, hidden_size: int = None, intermediate_size: int = None
):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
self.intermediate_size = (
config.intermediate_size if intermediate_size is None else intermediate_size
)
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
def __call__(self, x):
down_proj = self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
return down_proj
class MoEGate(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.config = config
self.top_k = config.num_experts_per_tok
self.norm_topk_prob = config.norm_topk_prob
self.n_routed_experts = config.n_routed_experts
self.routed_scaling_factor = config.routed_scaling_factor
self.topk_method = config.topk_method
self.n_group = config.n_group
self.topk_group = config.topk_group
self.weight = mx.zeros((self.n_routed_experts, config.hidden_size))
self.e_score_correction_bias = mx.zeros((self.n_routed_experts,))
def __call__(self, x):
gates = x @ self.weight.T
scores = mx.sigmoid(gates.astype(mx.float32))
assert self.topk_method == "noaux_tc", "Unsupported topk method."
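# noaux_tc routing, in brief: add the per-expert correction bias, score each
# expert group by the sum of its two best experts, zero out the lowest-scoring
# (n_group - topk_group) groups, then take the top_k experts from the
# remaining groups and optionally renormalize their scores before applying
# routed_scaling_factor.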
bsz, seq_len = x.shape[:2]
scores = scores + self.e_score_correction_bias
scores = scores.reshape(bsz, seq_len, self.n_group, -1)
group_scores = mx.topk(scores, 2, axis=-1).sum(axis=-1)
k = self.n_group - self.topk_group
group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-1)[..., :k]
batch_idx = mx.expand_dims(mx.arange(bsz), (1, 2))
seq_idx = mx.expand_dims(mx.arange(seq_len), (0, 2))
scores[batch_idx, seq_idx, group_idx] = 0.0
scores = scores.reshape(bsz, seq_len, -1)
k = self.top_k
inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]
scores = mx.take_along_axis(scores, inds, axis=-1)
if self.top_k > 1 and self.norm_topk_prob:
denominator = scores.sum(axis=-1, keepdims=True) + 1e-20
scores = scores / denominator
scores = scores * self.routed_scaling_factor
return inds, scores
class DeepseekV3MoE(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.config = config
self.num_experts_per_tok = config.num_experts_per_tok
self.switch_mlp = SwitchGLU(
config.hidden_size, config.moe_intermediate_size, config.n_routed_experts
)
self.gate = MoEGate(config)
if config.n_shared_experts is not None:
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
self.shared_experts = DeepseekV3MLP(
config=config, intermediate_size=intermediate_size
)
def __call__(self, x):
inds, scores = self.gate(x)
y = self.switch_mlp(x, inds)
y = (y * scores[..., None]).sum(axis=-2).astype(y.dtype)
if self.config.n_shared_experts is not None:
y = y + self.shared_experts(x)
return y
class DeepseekV3DecoderLayer(nn.Module):
def __init__(self, config: ModelArgs, layer_idx: int):
super().__init__()
self.self_attn = DeepseekV3Attention(config)
self.mlp = (
DeepseekV3MoE(config)
if (
config.n_routed_experts is not None
and layer_idx >= config.first_k_dense_replace
and layer_idx % config.moe_layer_freq == 0
)
else DeepseekV3MLP(config)
)
self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = nn.RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
out = h + r
# Protect against overflow for fp16
if out.dtype == mx.float16:
out = mx.clip(out, a_min=None, a_max=mx.finfo(mx.float16).max - 1000)
return out
class DeepseekV3Model(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = [
DeepseekV3DecoderLayer(config, idx)
for idx in range(config.num_hidden_layers)
]
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.pipeline_rank = 0
self.pipeline_size = 1
def pipeline(self, group):
# Split layers in reverse so rank=0 gets the last layers and
# rank=pipeline_size-1 gets the first
self.pipeline_rank = group.rank()
self.pipeline_size = group.size()
layers_per_rank = (
len(self.layers) + self.pipeline_size - 1
) // self.pipeline_size
start = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
self.layers = self.layers[start : start + layers_per_rank]
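# Illustrative walk-through (not from the diff): with 61 decoder layers and a
# 2-process group, layers_per_rank = (61 + 1) // 2 = 31, so rank 1 keeps
# layers[0:31] and rank 0 keeps layers[31:62], i.e. the final 30 layers.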
def __call__(
self,
x: mx.array,
cache: Optional[Any] = None,
mask: Optional[mx.array] = None,
) -> mx.array:
h = self.embed_tokens(x)
pipeline_rank = self.pipeline_rank
pipeline_size = self.pipeline_size
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
# Receive from the previous process in the pipeline
if pipeline_rank < pipeline_size - 1:
h = mx.distributed.recv_like(h, (pipeline_rank + 1))
for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)
# Send to the next process in the pipeline
if pipeline_rank != 0:
h = mx.distributed.send(h, (pipeline_rank - 1) % pipeline_size)
# Broadcast h while keeping it in the graph
h = mx.distributed.all_gather(h)[: h.shape[0]]
return self.norm(h)
class Model(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.args = config
self.model_type = config.model_type
self.model = DeepseekV3Model(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
cache: Optional[Any] = None,
mask: Optional[mx.array] = None,
):
out = self.model(inputs, cache, mask)
return self.lm_head(out)
def sanitize(self, weights):
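# Stack the per-expert gate/down/up projection weights from the reference
# checkpoint into the fused SwitchGLU parameters used here
# (model.layers.N.mlp.switch_mlp.*).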
for l in range(self.args.num_hidden_layers):
prefix = f"model.layers.{l}"
for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]:
for k in ["weight", "scales", "biases"]:
if f"{prefix}.mlp.experts.0.{m}.{k}" in weights:
to_join = [
weights.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}")
for e in range(self.args.n_routed_experts)
]
weights[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join)
# Remove multi-token prediction layer
return {k: v for k, v in weights.items() if not k.startswith("model.layers.61")}
@property
def layers(self):
return self.model.layers

View File

@@ -123,9 +123,11 @@ class ExaoneModel(nn.Module):
def __call__( def __call__(
self, self,
inputs: mx.array, inputs: mx.array,
mask: mx.array = None,
cache=None, cache=None,
): ):
h = self.wte(inputs) h = self.wte(inputs)
if mask is None:
mask = create_attention_mask(h, cache) mask = create_attention_mask(h, cache)
if cache is None: if cache is None:
@@ -149,9 +151,10 @@ class Model(nn.Module):
def __call__( def __call__(
self, self,
inputs: mx.array, inputs: mx.array,
mask: mx.array = None,
cache=None, cache=None,
): ):
out = self.transformer(inputs, cache) out = self.transformer(inputs, mask, cache)
if self.args.tie_word_embeddings: if self.args.tie_word_embeddings:
out = self.transformer.wte.as_linear(out) out = self.transformer.wte.as_linear(out)
else: else:

View File

@@ -138,11 +138,13 @@ class GemmaModel(nn.Module):
def __call__( def __call__(
self, self,
inputs: mx.array, inputs: mx.array,
mask: mx.array = None,
cache=None, cache=None,
): ):
h = self.embed_tokens(inputs) h = self.embed_tokens(inputs)
h = h * (self.args.hidden_size**0.5) h = h * (self.args.hidden_size**0.5)
if mask is None:
mask = create_attention_mask(h, cache) mask = create_attention_mask(h, cache)
if cache is None: if cache is None:
@@ -164,9 +166,10 @@ class Model(nn.Module):
def __call__( def __call__(
self, self,
inputs: mx.array, inputs: mx.array,
mask: mx.array = None,
cache=None, cache=None,
): ):
out = self.model(inputs, cache) out = self.model(inputs, mask, cache)
out = self.model.embed_tokens.as_linear(out) out = self.model.embed_tokens.as_linear(out)
return out return out

View File

@@ -160,11 +160,13 @@ class GemmaModel(nn.Module):
def __call__( def __call__(
self, self,
inputs: mx.array, inputs: mx.array,
mask: mx.array = None,
cache=None, cache=None,
): ):
h = self.embed_tokens(inputs) h = self.embed_tokens(inputs)
h = h * (self.args.hidden_size**0.5) h = h * (self.args.hidden_size**0.5)
if mask is None:
mask = create_attention_mask(h, cache) mask = create_attention_mask(h, cache)
if cache is None: if cache is None:
@@ -187,9 +189,10 @@ class Model(nn.Module):
def __call__( def __call__(
self, self,
inputs: mx.array, inputs: mx.array,
mask: mx.array = None,
cache=None, cache=None,
): ):
out = self.model(inputs, cache) out = self.model(inputs, mask, cache)
out = self.model.embed_tokens.as_linear(out) out = self.model.embed_tokens.as_linear(out)
out = mx.tanh(out / self.final_logit_softcapping) out = mx.tanh(out / self.final_logit_softcapping)
out = out * self.final_logit_softcapping out = out * self.final_logit_softcapping

View File

@@ -126,6 +126,7 @@ class GPT2Model(nn.Module):
def __call__( def __call__(
self, self,
inputs: mx.array, inputs: mx.array,
mask: mx.array = None,
cache=None, cache=None,
): ):
_, L = inputs.shape _, L = inputs.shape
@@ -138,6 +139,7 @@ class GPT2Model(nn.Module):
position_ids = mx.array(np.arange(L)) position_ids = mx.array(np.arange(L))
hidden_states += self.wpe(position_ids) hidden_states += self.wpe(position_ids)
if mask is None:
mask = create_attention_mask(hidden_states, cache) mask = create_attention_mask(hidden_states, cache)
if cache is None: if cache is None:
@@ -159,9 +161,10 @@ class Model(nn.Module):
def __call__( def __call__(
self, self,
inputs: mx.array, inputs: mx.array,
mask: mx.array = None,
cache=None, cache=None,
): ):
out = self.model(inputs, cache) out = self.model(inputs, mask, cache)
out = self.model.wte.as_linear(out) out = self.model.wte.as_linear(out)
return out return out

View File

@@ -137,6 +137,7 @@ class GPTBigCodeModel(nn.Module):
def __call__( def __call__(
self, self,
inputs: mx.array, inputs: mx.array,
mask: mx.array = None,
cache=None, cache=None,
): ):
B, L = inputs.shape B, L = inputs.shape
@@ -144,15 +145,16 @@ class GPTBigCodeModel(nn.Module):
hidden_states = self.wte(inputs) hidden_states = self.wte(inputs)
mask = None mask = None
if hidden_states.shape[1] > 1: if mask is None and hidden_states.shape[1] > 1:
position_ids = mx.array(np.arange(L))
hidden_states += self.wpe(position_ids)
mask = create_attention_mask(hidden_states, cache) mask = create_attention_mask(hidden_states, cache)
if cache is None: if cache is None:
cache = [None] * len(self.h) cache = [None] * len(self.h)
position_ids = mx.array(np.arange(L))
else:
position_ids = mx.array(np.arange(cache[0].offset, cache[0].offset + L))
hidden_states += self.wpe(position_ids)
for layer, c in zip(self.h, cache): for layer, c in zip(self.h, cache):
hidden_states = layer(hidden_states, mask, cache=c) hidden_states = layer(hidden_states, mask, cache=c)
@@ -172,9 +174,10 @@ class Model(nn.Module):
def __call__( def __call__(
self, self,
inputs: mx.array, inputs: mx.array,
mask: mx.array = None,
cache=None, cache=None,
): ):
out = self.transformer(inputs, cache) out = self.transformer(inputs, mask, cache)
if self.args.tie_word_embeddings: if self.args.tie_word_embeddings:
out = self.transformer.wte.as_linear(out) out = self.transformer.wte.as_linear(out)
else: else:

View File

@@ -146,12 +146,14 @@ class GPTNeoXModel(nn.Module):
def __call__( def __call__(
self, self,
inputs: mx.array, inputs: mx.array,
mask: mx.array = None,
cache=None, cache=None,
): ):
_, L = inputs.shape _, L = inputs.shape
hidden_states = self.embed_in(inputs) hidden_states = self.embed_in(inputs)
if mask is None:
mask = create_attention_mask(hidden_states, cache) mask = create_attention_mask(hidden_states, cache)
if cache is None: if cache is None:
@@ -176,9 +178,10 @@ class Model(nn.Module):
def __call__( def __call__(
self, self,
inputs: mx.array, inputs: mx.array,
mask: mx.array = None,
cache=None, cache=None,
): ):
out = self.model(inputs, cache) out = self.model(inputs, mask, cache)
return out return out
def sanitize(self, weights): def sanitize(self, weights):

View File

@@ -239,10 +239,12 @@ class HunYuanModel(nn.Module):
def __call__( def __call__(
self, self,
inputs: mx.array, inputs: mx.array,
mask: mx.array = None,
cache=None, cache=None,
): ):
h = self.embed_tokens(inputs) h = self.embed_tokens(inputs)
if mask is None:
mask = create_attention_mask(h, cache) mask = create_attention_mask(h, cache)
if cache is None: if cache is None:
@@ -266,9 +268,10 @@ class Model(nn.Module):
def __call__( def __call__(
self, self,
inputs: mx.array, inputs: mx.array,
mask: mx.array = None,
cache=None, cache=None,
): ):
out = self.model(inputs, cache) out = self.model(inputs, mask, cache)
return self.model.embed_tokens.as_linear(out) return self.model.embed_tokens.as_linear(out)
def sanitize(self, weights): def sanitize(self, weights):

View File

@@ -193,10 +193,12 @@ class InternLM2Model(nn.Module):
def __call__( def __call__(
self, self,
inputs: mx.array, inputs: mx.array,
mask: mx.array = None,
cache=None, cache=None,
): ):
h = self.tok_embeddings(inputs) h = self.tok_embeddings(inputs)
if mask is None:
mask = create_attention_mask(h, cache) mask = create_attention_mask(h, cache)
if cache is None: if cache is None:
@@ -220,9 +222,10 @@ class Model(nn.Module):
def __call__( def __call__(
self, self,
inputs: mx.array, inputs: mx.array,
mask: mx.array = None,
cache=None, cache=None,
): ):
out = self.model(inputs, cache) out = self.model(inputs, mask, cache)
if self.args.tie_word_embeddings: if self.args.tie_word_embeddings:
out = self.model.tok_embeddings.as_linear(out) out = self.model.tok_embeddings.as_linear(out)
else: else:

View File

@@ -155,10 +155,12 @@ class LlamaModel(nn.Module):
def __call__( def __call__(
self, self,
inputs: mx.array, inputs: mx.array,
mask: mx.array = None,
cache=None, cache=None,
): ):
h = self.embed_tokens(inputs) h = self.embed_tokens(inputs)
if mask is None:
mask = create_attention_mask(h, cache) mask = create_attention_mask(h, cache)
if cache is None: if cache is None:
@@ -182,9 +184,10 @@ class Model(nn.Module):
def __call__( def __call__(
self, self,
inputs: mx.array, inputs: mx.array,
mask: mx.array = None,
cache=None, cache=None,
): ):
out = self.model(inputs, cache) out = self.model(inputs, mask, cache)
if self.args.tie_word_embeddings: if self.args.tie_word_embeddings:
out = self.model.embed_tokens.as_linear(out) out = self.model.embed_tokens.as_linear(out)
else: else:

View File

@@ -158,10 +158,12 @@ class MiniCPMModel(nn.Module):
def __call__( def __call__(
self, self,
inputs: mx.array, inputs: mx.array,
mask: mx.array = None,
cache=None, cache=None,
): ):
h = self.embed_tokens(inputs) * self.args.scale_emb h = self.embed_tokens(inputs) * self.args.scale_emb
if mask is None:
mask = create_attention_mask(h, cache) mask = create_attention_mask(h, cache)
if cache is None: if cache is None:
@@ -186,9 +188,10 @@ class Model(nn.Module):
def __call__( def __call__(
self, self,
inputs: mx.array, inputs: mx.array,
mask: mx.array = None,
cache=None, cache=None,
): ):
out = self.model(inputs, cache) out = self.model(inputs, mask, cache)
if not self.args.tie_word_embeddings: if not self.args.tie_word_embeddings:
out = self.lm_head(out / (self.args.hidden_size / self.args.dim_model_base)) out = self.lm_head(out / (self.args.hidden_size / self.args.dim_model_base))

View File

@@ -162,10 +162,12 @@ class MixtralModel(nn.Module):
def __call__( def __call__(
self, self,
inputs: mx.array, inputs: mx.array,
mask: mx.array = None,
cache=None, cache=None,
): ):
h = self.embed_tokens(inputs) h = self.embed_tokens(inputs)
if mask is None:
mask = create_attention_mask(h, cache) mask = create_attention_mask(h, cache)
if cache is None: if cache is None:
@@ -188,9 +190,10 @@ class Model(nn.Module):
def __call__( def __call__(
self, self,
inputs: mx.array, inputs: mx.array,
mask: mx.array = None,
cache=None, cache=None,
): ):
out = self.model(inputs, cache) out = self.model(inputs, mask, cache)
return self.lm_head(out) return self.lm_head(out)
def sanitize(self, weights): def sanitize(self, weights):

View File

@@ -176,10 +176,12 @@ class NemotronModel(nn.Module):
def __call__( def __call__(
self, self,
inputs: mx.array, inputs: mx.array,
mask: mx.array = None,
cache=None, cache=None,
): ):
h = self.embed_tokens(inputs) h = self.embed_tokens(inputs)
if mask is None:
mask = create_attention_mask(h, cache) mask = create_attention_mask(h, cache)
if cache is None: if cache is None:
@@ -203,9 +205,10 @@ class Model(nn.Module):
def __call__( def __call__(
self, self,
inputs: mx.array, inputs: mx.array,
mask: mx.array = None,
cache=None, cache=None,
): ):
out = self.model(inputs, cache) out = self.model(inputs, mask, cache)
if self.args.tie_word_embeddings: if self.args.tie_word_embeddings:
out = self.model.embed_tokens.as_linear(out) out = self.model.embed_tokens.as_linear(out)
else: else:

View File

@@ -124,10 +124,12 @@ class Transformer(nn.Module):
def __call__( def __call__(
self, self,
inputs: mx.array, inputs: mx.array,
mask: mx.array = None,
cache=None, cache=None,
): ):
h = self.wte(inputs) h = self.wte(inputs)
if mask is None:
mask = create_attention_mask(h, cache) mask = create_attention_mask(h, cache)
if cache is None: if cache is None:
@@ -152,9 +154,10 @@ class OlmoModel(nn.Module):
def __call__( def __call__(
self, self,
inputs: mx.array, inputs: mx.array,
mask: mx.array = None,
cache=None, cache=None,
): ):
return self.transformer(inputs, cache) return self.transformer(inputs, mask, cache)
class Model(nn.Module): class Model(nn.Module):
@@ -167,9 +170,10 @@ class Model(nn.Module):
def __call__( def __call__(
self, self,
inputs: mx.array, inputs: mx.array,
mask: mx.array = None,
cache=None, cache=None,
): ):
return self.model(inputs, cache) return self.model(inputs, mask, cache)
@property @property
def layers(self): def layers(self):

View File

@@ -163,9 +163,11 @@ class LlamaModel(nn.Module):
self, self,
inputs: mx.array, inputs: mx.array,
cache=None, cache=None,
mask=None,
): ):
h = self.embed_tokens(inputs) h = self.embed_tokens(inputs)
if mask is None:
mask = create_attention_mask(h, cache) mask = create_attention_mask(h, cache)
if cache is None: if cache is None:
@@ -190,8 +192,9 @@ class Model(nn.Module):
self, self,
inputs: mx.array, inputs: mx.array,
cache=None, cache=None,
mask=None,
): ):
out = self.model(inputs, cache) out = self.model(inputs, cache, mask)
if self.args.tie_word_embeddings: if self.args.tie_word_embeddings:
out = self.model.embed_tokens.as_linear(out) out = self.model.embed_tokens.as_linear(out)
else: else:

View File

@@ -178,10 +178,12 @@ class OpenELMModel(nn.Module):
def __call__( def __call__(
self, self,
inputs: mx.array, inputs: mx.array,
mask: mx.array = None,
cache=None, cache=None,
): ):
h = self.token_embeddings(inputs) h = self.token_embeddings(inputs)
if mask is None:
mask = create_attention_mask(h, cache) mask = create_attention_mask(h, cache)
if cache is None: if cache is None:
@@ -205,9 +207,10 @@ class Model(nn.Module):
def __call__( def __call__(
self, self,
inputs: mx.array, inputs: mx.array,
mask: mx.array = None,
cache=None, cache=None,
): ):
out = self.transformer(inputs, cache) out = self.transformer(inputs, mask, cache)
if self.args.share_input_output_layers: if self.args.share_input_output_layers:
out = self.transformer.token_embeddings.as_linear(out) out = self.transformer.token_embeddings.as_linear(out)
else: else:

View File

@@ -143,9 +143,10 @@ class PhiModel(nn.Module):
config.hidden_size, eps=config.layer_norm_eps config.hidden_size, eps=config.layer_norm_eps
) )
def __call__(self, x, cache): def __call__(self, x, mask, cache):
x = self.embed_tokens(x) x = self.embed_tokens(x)
if mask is None:
mask = create_attention_mask(x, cache) mask = create_attention_mask(x, cache)
if cache is None: if cache is None:
@@ -167,9 +168,10 @@ class Model(nn.Module):
def __call__( def __call__(
self, self,
x: mx.array, x: mx.array,
mask: mx.array = None,
cache=None, cache=None,
) -> mx.array: ) -> mx.array:
y = self.model(x, cache) y = self.model(x, mask, cache)
return self.lm_head(y) return self.lm_head(y)
@property @property

View File

@@ -168,10 +168,12 @@ class Phi3Model(nn.Module):
def __call__( def __call__(
self, self,
inputs: mx.array, inputs: mx.array,
mask: mx.array = None,
cache=None, cache=None,
): ):
h = self.embed_tokens(inputs) h = self.embed_tokens(inputs)
if mask is None:
mask = create_attention_mask(h, cache) mask = create_attention_mask(h, cache)
if cache is None: if cache is None:
@@ -194,9 +196,10 @@ class Model(nn.Module):
def __call__( def __call__(
self, self,
inputs: mx.array, inputs: mx.array,
mask: mx.array = None,
cache=None, cache=None,
): ):
out = self.model(inputs, cache) out = self.model(inputs, mask, cache)
return self.lm_head(out) return self.lm_head(out)
@property @property

View File

@@ -258,12 +258,14 @@ class Phi3Model(nn.Module):
def __call__( def __call__(
self, self,
inputs: mx.array, inputs: mx.array,
mask: mx.array = None,
cache=None, cache=None,
): ):
h = self.embed_tokens(inputs) h = self.embed_tokens(inputs)
if self.mup_embedding_multiplier: if self.mup_embedding_multiplier:
h = self.mup_embedding_multiplier * h h = self.mup_embedding_multiplier * h
if mask is None:
mask = create_attention_mask(h, cache) mask = create_attention_mask(h, cache)
if cache is None: if cache is None:
@@ -290,9 +292,10 @@ class Model(nn.Module):
def __call__( def __call__(
self, self,
inputs: mx.array, inputs: mx.array,
mask: mx.array = None,
cache=None, cache=None,
): ):
out = self.model(inputs, cache) out = self.model(inputs, mask, cache)
out = self.model.embed_tokens.as_linear(out) out = self.model.embed_tokens.as_linear(out)
if self.mup_width_multiplier: if self.mup_width_multiplier:
out = out / self.mup_width_multiplier out = out / self.mup_width_multiplier

View File

@@ -155,10 +155,12 @@ class PhiMoEModel(nn.Module):
def __call__( def __call__(
self, self,
inputs: mx.array, inputs: mx.array,
mask: mx.array = None,
cache=None, cache=None,
) -> mx.array: ) -> mx.array:
h = self.embed_tokens(inputs) h = self.embed_tokens(inputs)
if mask is None:
mask = create_attention_mask(h, cache) mask = create_attention_mask(h, cache)
if cache is None: if cache is None:
@@ -181,9 +183,10 @@ class Model(nn.Module):
def __call__( def __call__(
self, self,
inputs: mx.array, inputs: mx.array,
mask: mx.array = None,
cache=None, cache=None,
): ):
out = self.model(inputs, cache) out = self.model(inputs, mask, cache)
return self.lm_head(out) return self.lm_head(out)
def sanitize(self, weights): def sanitize(self, weights):

View File

@@ -175,6 +175,8 @@ class Model(nn.Module):
mask: mx.array = None, mask: mx.array = None,
cache=None, cache=None,
) -> mx.array: ) -> mx.array:
if mask is None:
mask = create_attention_mask(x, cache) mask = create_attention_mask(x, cache)
y = self.transformer(x, mask, cache) y = self.transformer(x, mask, cache)

View File

@@ -174,9 +174,11 @@ class PlamoModel(nn.Module):
self, self,
inputs: mx.array, inputs: mx.array,
cache: Optional[Any] = None, cache: Optional[Any] = None,
mask: Optional[mx.array] = None,
) -> mx.array: ) -> mx.array:
h = self.embed_tokens(inputs) h = self.embed_tokens(inputs)
if mask is None:
mask = create_attention_mask(h, cache) mask = create_attention_mask(h, cache)
if cache is None: if cache is None:
@@ -202,8 +204,9 @@ class Model(nn.Module):
self, self,
inputs: mx.array, inputs: mx.array,
cache: Optional[Any] = None, cache: Optional[Any] = None,
mask: Optional[mx.array] = None,
) -> mx.array: ) -> mx.array:
out = self.model(inputs, cache) out = self.model(inputs, cache, mask)
return self.lm_head(out) return self.lm_head(out)
@property @property

View File

@@ -123,6 +123,7 @@ class QwenModel(nn.Module):
def __call__(self, inputs, mask=None, cache=None): def __call__(self, inputs, mask=None, cache=None):
x = self.wte(inputs) x = self.wte(inputs)
if mask is None:
mask = create_attention_mask(x, cache) mask = create_attention_mask(x, cache)
if cache is None: if cache is None:

View File

@@ -149,10 +149,12 @@ class Qwen2Model(nn.Module):
def __call__( def __call__(
self, self,
inputs: mx.array, inputs: mx.array,
mask: mx.array = None,
cache=None, cache=None,
): ):
h = self.embed_tokens(inputs) h = self.embed_tokens(inputs)
if mask is None:
mask = create_attention_mask(h, cache) mask = create_attention_mask(h, cache)
if cache is None: if cache is None:
@@ -176,9 +178,10 @@ class Model(nn.Module):
def __call__( def __call__(
self, self,
inputs: mx.array, inputs: mx.array,
mask: mx.array = None,
cache=None, cache=None,
): ):
out = self.model(inputs, cache) out = self.model(inputs, mask, cache)
if self.args.tie_word_embeddings: if self.args.tie_word_embeddings:
out = self.model.embed_tokens.as_linear(out) out = self.model.embed_tokens.as_linear(out)
else: else:

View File

@@ -187,10 +187,12 @@ class Qwen2MoeModel(nn.Module):
def __call__( def __call__(
self, self,
inputs: mx.array, inputs: mx.array,
mask: mx.array = None,
cache=None, cache=None,
): ):
h = self.embed_tokens(inputs) h = self.embed_tokens(inputs)
if mask is None:
mask = create_attention_mask(h, cache) mask = create_attention_mask(h, cache)
if cache is None: if cache is None:
@@ -213,9 +215,10 @@ class Model(nn.Module):
def __call__( def __call__(
self, self,
inputs: mx.array, inputs: mx.array,
mask: mx.array = None,
cache=None, cache=None,
): ):
out = self.model(inputs, cache) out = self.model(inputs, mask, cache)
return self.lm_head(out) return self.lm_head(out)
def sanitize(self, weights): def sanitize(self, weights):

View File

@@ -389,6 +389,7 @@ class Griffin(nn.Module):
def __call__( def __call__(
self, self,
tokens, tokens,
mask: mx.array = None,
cache=None, cache=None,
): ):
x = self.embed_tokens(tokens) x = self.embed_tokens(tokens)
@@ -402,6 +403,7 @@ class Griffin(nn.Module):
if block.temporal_block_type != "recurrent": if block.temporal_block_type != "recurrent":
mask_cache = [cache[i]] mask_cache = [cache[i]]
if mask is None:
mask = create_attention_mask(x, mask_cache) mask = create_attention_mask(x, mask_cache)
for i, block in enumerate(self.layers): for i, block in enumerate(self.layers):
@@ -418,12 +420,12 @@ class Model(nn.Module):
self.model_type = config.model_type self.model_type = config.model_type
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
def __call__(self, tokens: mx.array, cache=None) -> mx.array: def __call__(self, tokens: mx.array, mask: mx.array = None, cache=None) -> mx.array:
""" """
Args: Args:
tokens: Sequence of input tokens. tokens: Sequence of input tokens.
""" """
logits = self.model(tokens, cache=cache) logits = self.model(tokens, mask=mask, cache=cache)
if "lm_head" in self: if "lm_head" in self:
logits = self.lm_head(logits) logits = self.lm_head(logits)
else: else:

View File

@@ -199,7 +199,10 @@ class Model(nn.Module):
mask: mx.array = None, mask: mx.array = None,
cache=None, cache=None,
) -> mx.array: ) -> mx.array:
if mask is None:
mask = create_attention_mask(x, cache) mask = create_attention_mask(x, cache)
y = self.model(x, mask, cache) y = self.model(x, mask, cache)
return self.lm_head(y) return self.lm_head(y)

View File

@@ -125,10 +125,12 @@ class Starcoder2Model(nn.Module):
def __call__( def __call__(
self, self,
inputs: mx.array, inputs: mx.array,
mask: mx.array = None,
cache=None, cache=None,
): ):
h = self.embed_tokens(inputs) h = self.embed_tokens(inputs)
if mask is None:
mask = create_attention_mask(h, cache) mask = create_attention_mask(h, cache)
if cache is None: if cache is None:
@@ -152,9 +154,10 @@ class Model(nn.Module):
def __call__( def __call__(
self, self,
inputs: mx.array, inputs: mx.array,
mask: mx.array = None,
cache=None, cache=None,
): ):
out = self.model(inputs, cache) out = self.model(inputs, mask, cache)
if self.args.tie_word_embeddings: if self.args.tie_word_embeddings:
out = self.model.embed_tokens.as_linear(out) out = self.model.embed_tokens.as_linear(out)
else: else:
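
The model edits above all follow the same pattern: each `__call__` gains an optional `mask` argument and only falls back to `create_attention_mask` when none is supplied. A minimal sketch of passing an explicit causal mask from the caller's side; the model id and the dtype cast are illustrative assumptions, not part of this diff:

```python
import mlx.core as mx

from mlx_lm import load
from mlx_lm.models.base import create_causal_mask

model, tokenizer = load("mlx-community/Llama-3.2-1B-Instruct-4bit")  # illustrative model id

tokens = mx.array(tokenizer.encode("hello world"))[None]

# Build the causal mask explicitly instead of relying on the in-model fallback.
mask = create_causal_mask(tokens.shape[1], 0).astype(mx.float16)  # dtype assumed to match the model
logits = model(tokens, mask=mask)
print(logits.shape)  # (1, num_tokens, vocab_size)
```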

View File

@@ -1,4 +1,4 @@
mlx>=0.19.2 mlx>=0.22.0
numpy numpy
transformers[sentencepiece]>=4.39.3 transformers[sentencepiece]>=4.39.3
protobuf protobuf

View File

@@ -12,6 +12,7 @@ def make_sampler(
top_p: float = 0.0, top_p: float = 0.0,
min_p: float = 0.0, min_p: float = 0.0,
min_tokens_to_keep: int = 1, min_tokens_to_keep: int = 1,
top_k: int = -1,
) -> Callable[mx.array, mx.array]: ) -> Callable[mx.array, mx.array]:
""" """
Make a sampler function for use with ``generate_step``. Make a sampler function for use with ``generate_step``.
@@ -25,6 +26,8 @@ def make_sampler(
probability) that a token probability must have to be considered. probability) that a token probability must have to be considered.
min_tokens_to_keep (int, optional): Minimum number of tokens that cannot min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
be filtered by min_p sampling. be filtered by min_p sampling.
top_k (int, optional): The top k tokens ranked by probability to constrain
the sampling to.
Returns: Returns:
Callable[mx.array, mx.array]: Callable[mx.array, mx.array]:
@@ -36,6 +39,8 @@ def make_sampler(
return lambda x: top_p_sampling(x, top_p, temp) return lambda x: top_p_sampling(x, top_p, temp)
elif min_p != 0.0: elif min_p != 0.0:
return lambda x: min_p_sampling(x, min_p, min_tokens_to_keep, temp) return lambda x: min_p_sampling(x, min_p, min_tokens_to_keep, temp)
elif top_k > 0:
return lambda x: top_k_sampling(x, top_k, temp)
else: else:
return lambda x: categorical_sampling(x, temp) return lambda x: categorical_sampling(x, temp)
@@ -79,6 +84,33 @@ def make_logits_processors(
return logits_processors return logits_processors
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
def top_k_sampling(
logprobs: mx.array,
top_k: int,
temperature=1.0,
) -> mx.array:
"""
Sample from only the top K tokens ranked by probability.
Args:
logprobs: A vector of log probabilities.
top_k (int): Top k tokens to sample from.
"""
vocab_size = logprobs.shape[-1]
if not isinstance(top_k, int) or not (0 < top_k < vocab_size):
raise ValueError(
f"`top_k` has to be an integer in the (0, {vocab_size}] interval,"
f" but is {top_k}."
)
logprobs = logprobs * (1 / temperature)
mask_idx = mx.argpartition(-logprobs, kth=top_k - 1, axis=-1)[..., top_k:]
masked_logprobs = mx.put_along_axis(
logprobs, mask_idx, mx.array(-float("inf"), logprobs.dtype), axis=-1
)
return mx.random.categorical(masked_logprobs, axis=-1)
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
def min_p_sampling( def min_p_sampling(
logprobs: mx.array, logprobs: mx.array,
@@ -87,7 +119,7 @@ def min_p_sampling(
temperature=1.0, temperature=1.0,
) -> mx.array: ) -> mx.array:
""" """
Apply min-p sampling to the logits. Apply min-p sampling to the logprobs.
Min-p keeps all tokens that are above a minimum probability, scaled by the Min-p keeps all tokens that are above a minimum probability, scaled by the
probability of the most likely token. As a result, the filter is more probability of the most likely token. As a result, the filter is more

View File

@@ -590,14 +590,10 @@ class APIHandler(BaseHTTPRequestHandler):
# Determine response type # Determine response type
self.request_id = f"chatcmpl-{uuid.uuid4()}" self.request_id = f"chatcmpl-{uuid.uuid4()}"
self.object_type = "chat.completion.chunk" if self.stream else "chat.completion" self.object_type = "chat.completion.chunk" if self.stream else "chat.completion"
if ( if self.tokenizer.chat_template:
hasattr(self.tokenizer, "apply_chat_template")
and self.tokenizer.chat_template
):
prompt = self.tokenizer.apply_chat_template( prompt = self.tokenizer.apply_chat_template(
body["messages"], body["messages"],
body.get("tools", None), body.get("tools", None),
tokenize=True,
add_generation_prompt=True, add_generation_prompt=True,
) )
else: else:
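
The server change assumes every wrapped tokenizer exposes a `chat_template` attribute, so the `hasattr` guard can go. A sketch of exercising this path against a locally running `mlx_lm.server`; the port, endpoint path, and response shape follow the OpenAI-compatible API the server mimics and should be treated as assumptions:

```python
import json
from urllib.request import Request, urlopen

# Assumes a server started with something like: mlx_lm.server --model <repo-id>
payload = {
    "messages": [{"role": "user", "content": "Say hello in one sentence."}],
    "max_tokens": 64,
}
req = Request(
    "http://127.0.0.1:8080/v1/chat/completions",  # default host/port assumed
    data=json.dumps(payload).encode(),
    headers={"Content-Type": "application/json"},
)
with urlopen(req) as resp:
    body = json.loads(resp.read())
print(body["choices"][0]["message"]["content"])
```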

View File

@@ -266,6 +266,18 @@ class TokenizerWrapper:
else {tokenizer.eos_token_id} else {tokenizer.eos_token_id}
) )
def add_eos_token(self, token: str):
token_id = None
try:
token_id = int(token)
except ValueError:
token_id = self._tokenizer.convert_tokens_to_ids(token)
if token_id is None:
raise ValueError(f"'{token}' is not a token for this tokenizer")
self._eos_token_ids.add(token_id)
def __getattr__(self, attr): def __getattr__(self, attr):
if attr == "detokenizer": if attr == "detokenizer":
return self._detokenizer return self._detokenizer
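
The new `TokenizerWrapper.add_eos_token` accepts either a token string or a numeric id string and adds it to the set of stop tokens used during generation. A small sketch; the model id and token names are illustrative:

```python
from mlx_lm import load

model, tokenizer = load("mlx-community/Llama-3.2-1B-Instruct-4bit")  # illustrative model id

# Register an extra stop token by its string form...
tokenizer.add_eos_token("<|eot_id|>")  # assumes this token exists in the vocabulary
# ...or by its integer id passed as a string.
tokenizer.add_eos_token("128009")      # numeric strings are parsed as token ids
```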

View File

@@ -1,6 +1,6 @@
import json import json
from pathlib import Path from pathlib import Path
from typing import Dict, List from typing import Dict, List, Optional
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
@@ -10,41 +10,47 @@ class Dataset:
Light-weight wrapper to hold a dataset. Light-weight wrapper to hold a dataset.
""" """
def __init__(self, data: List[Dict[str, str]], text_key: str = "text"): def __init__(
self._text_key = text_key self,
self._data = data data: List[Dict[str, str]],
tokenizer: PreTrainedTokenizer,
text_key: str = "text",
):
self._data = [tokenizer.encode(d[text_key]) for d in data]
for d in self._data:
if d[-1] != tokenizer.eos_token_id:
d.append(tokenizer.eos_token_id)
def __getitem__(self, idx: int): def __getitem__(self, idx: int):
return self._data[idx][self._text_key] return self._data[idx]
def __len__(self): def __len__(self):
if self._data is None:
return 0
return len(self._data) return len(self._data)
class ChatDataset(Dataset): class ChatDataset:
""" """
A dataset for chat data in the format of {"messages": [...]} A dataset for chat data in the format of {"messages": [...]}
https://platform.openai.com/docs/guides/fine-tuning/example-format https://platform.openai.com/docs/guides/fine-tuning/example-format
""" """
def __init__(self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer): def __init__(self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer):
super().__init__(data) self._data = [
self._tokenizer = tokenizer tokenizer.apply_chat_template(
d["messages"],
tools=d.get("tools", None),
)
for d in data
]
def __getitem__(self, idx: int): def __getitem__(self, idx: int):
messages = self._data[idx]["messages"] return self._data[idx]
text = self._tokenizer.apply_chat_template(
messages, def __len__(self):
tools=self._data[idx].get("tools", None), return len(self._data)
tokenize=False,
add_generation_prompt=True,
)
return text
class CompletionsDataset(Dataset): class CompletionsDataset:
""" """
A dataset for prompt-completion data in the format of {"prompt": ..., "completion": ...} A dataset for prompt-completion data in the format of {"prompt": ..., "completion": ...}
or using user-provided keys for prompt and completion values or using user-provided keys for prompt and completion values
@@ -55,36 +61,41 @@ class CompletionsDataset(Dataset):
self, self,
data: List[Dict[str, str]], data: List[Dict[str, str]],
tokenizer: PreTrainedTokenizer, tokenizer: PreTrainedTokenizer,
prompt_key: str = "prompt", prompt_key: str,
completion_key: str = "completion", completion_key: str,
): ):
super().__init__(data) self._data = [
self._tokenizer = tokenizer tokenizer.apply_chat_template(
self._prompt_key = prompt_key [
self._completion_key = completion_key {"role": "user", "content": d[prompt_key]},
{"role": "assistant", "content": d[completion_key]},
],
)
for d in data
]
def __getitem__(self, idx: int): def __getitem__(self, idx: int):
data = self._data[idx] return self._data[idx]
text = self._tokenizer.apply_chat_template(
[ def __len__(self):
{"role": "user", "content": data[self._prompt_key]}, return len(self._data)
{"role": "assistant", "content": data[self._completion_key]},
],
tokenize=False,
add_generation_prompt=True,
)
return text
def create_dataset(data, tokenizer: PreTrainedTokenizer = None): def create_dataset(
data,
tokenizer: PreTrainedTokenizer,
prompt_feature: Optional[str] = None,
completion_feature: Optional[str] = None,
):
prompt_feature = prompt_feature or "prompt"
completion_feature = completion_feature or "completion"
sample = data[0] sample = data[0]
if "messages" in sample: if "messages" in sample:
return ChatDataset(data, tokenizer) return ChatDataset(data, tokenizer)
elif "prompt" in sample and "completion" in sample: elif prompt_feature in sample and completion_feature in sample:
return CompletionsDataset(data, tokenizer) return CompletionsDataset(data, tokenizer, prompt_feature, completion_feature)
elif "text" in sample: elif "text" in sample:
return Dataset(data) return Dataset(data, tokenizer)
else: else:
raise ValueError( raise ValueError(
"Unsupported data format, check the supported formats here:\n" "Unsupported data format, check the supported formats here:\n"
@@ -92,20 +103,30 @@ def create_dataset(data, tokenizer: PreTrainedTokenizer = None):
) )
def load_local_dataset(data_path: Path, tokenizer: PreTrainedTokenizer): def load_local_dataset(
data_path: Path,
tokenizer: PreTrainedTokenizer,
prompt_feature: Optional[str] = None,
completion_feature: Optional[str] = None,
):
def load_subset(path): def load_subset(path):
if not path.exists(): if not path.exists():
return [] return []
with open(path, "r") as fid: with open(path, "r") as fid:
data = [json.loads(l) for l in fid] data = [json.loads(l) for l in fid]
return create_dataset(data, tokenizer) return create_dataset(data, tokenizer, prompt_feature, completion_feature)
names = ("train", "valid", "test") names = ("train", "valid", "test")
train, valid, test = [load_subset(data_path / f"{n}.jsonl") for n in names] train, valid, test = [load_subset(data_path / f"{n}.jsonl") for n in names]
return train, valid, test return train, valid, test
def load_hf_dataset(data_id: str, tokenizer: PreTrainedTokenizer): def load_hf_dataset(
data_id: str,
tokenizer: PreTrainedTokenizer,
prompt_feature: Optional[str] = None,
completion_feature: Optional[str] = None,
):
from datasets import exceptions, load_dataset from datasets import exceptions, load_dataset
try: try:
@@ -114,7 +135,13 @@ def load_hf_dataset(data_id: str, tokenizer: PreTrainedTokenizer):
names = ("train", "valid", "test") names = ("train", "valid", "test")
train, valid, test = [ train, valid, test = [
create_dataset(dataset[n], tokenizer) if n in dataset.keys() else [] (
create_dataset(
dataset[n], tokenizer, prompt_feature, completion_feature
)
if n in dataset.keys()
else []
)
for n in names for n in names
] ]
@@ -143,7 +170,7 @@ def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
if prompt_feature and completion_feature: if prompt_feature and completion_feature:
return CompletionsDataset(ds, tokenizer, prompt_feature, completion_feature) return CompletionsDataset(ds, tokenizer, prompt_feature, completion_feature)
elif text_feature: elif text_feature:
return Dataset(train_ds, text_key=text_feature) return Dataset(train_ds, tokenizer, text_key=text_feature)
else: else:
raise ValueError( raise ValueError(
"Specify either a prompt and completion feature or a text " "Specify either a prompt and completion feature or a text "
@@ -166,15 +193,22 @@ def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
def load_dataset(args, tokenizer: PreTrainedTokenizer): def load_dataset(args, tokenizer: PreTrainedTokenizer):
if getattr(args, "hf_dataset", None) is not None: if getattr(args, "hf_dataset", False):
train, valid, test = load_custom_hf_dataset(args, tokenizer) train, valid, test = load_custom_hf_dataset(args, tokenizer)
else: else:
data_path = Path(args.data) data_path = Path(args.data)
prompt_feature = getattr(args, "prompt_feature", None)
completion_feature = getattr(args, "completion_feature", None)
if data_path.exists(): if data_path.exists():
train, valid, test = load_local_dataset(data_path, tokenizer) train, valid, test = load_local_dataset(
data_path, tokenizer, prompt_feature, completion_feature
)
else: else:
print(f"Loading Hugging Face dataset {args.data}.") print(f"Loading Hugging Face dataset {args.data}.")
train, valid, test = load_hf_dataset(args.data, tokenizer) train, valid, test = load_hf_dataset(
args.data, tokenizer, prompt_feature, completion_feature
)
if args.train and len(train) == 0: if args.train and len(train) == 0:
raise ValueError( raise ValueError(
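
With these changes the dataset wrappers tokenize everything up front, so indexing a dataset returns token ids rather than raw text, and custom prompt/completion keys flow through `create_dataset`. A small sketch, assuming a local `train.jsonl` whose records use "input"/"output" keys; the paths and model id are illustrative:

```python
import json
from pathlib import Path

from transformers import AutoTokenizer

from mlx_lm.tuner.datasets import create_dataset

tokenizer = AutoTokenizer.from_pretrained("mlx-community/Llama-3.2-1B-Instruct-4bit")  # illustrative id

with open(Path("data") / "train.jsonl") as fid:  # path assumed
    data = [json.loads(line) for line in fid]

# Custom keys select a CompletionsDataset; each item is already a list of token ids.
train = create_dataset(data, tokenizer, prompt_feature="input", completion_feature="output")
print(len(train), train[0][:8])
```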

View File

@@ -100,14 +100,8 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
while True: while True:
indices = np.random.permutation(len(batch_idx)) indices = np.random.permutation(len(batch_idx))
for i in indices: for i in indices:
# Encode batch batch = [dataset[j] for j in batch_idx[i]]
batch = [tokenizer.encode(dataset[j]) for j in batch_idx[i]]
for b in batch:
if b[-1] != tokenizer.eos_token_id:
b.append(tokenizer.eos_token_id)
lengths = [len(x) for x in batch] lengths = [len(x) for x in batch]
if max(lengths) > max_seq_length: if max(lengths) > max_seq_length:
print( print(
f"[WARNING] Some sequences are longer than {max_seq_length} tokens. " f"[WARNING] Some sequences are longer than {max_seq_length} tokens. "

View File

@@ -2,10 +2,12 @@
import contextlib import contextlib
import copy import copy
import functools
import glob import glob
import importlib import importlib
import json import json
import logging import logging
import os
import shutil import shutil
import time import time
from dataclasses import dataclass from dataclasses import dataclass
@@ -15,7 +17,17 @@ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type,
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
if os.getenv("MLXLM_USE_MODELSCOPE", "False").lower() == "true":
try:
from modelscope import snapshot_download
except ImportError:
raise ImportError(
"Please run `pip install modelscope` to activate the ModelScope."
)
else:
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from mlx.utils import tree_flatten, tree_reduce from mlx.utils import tree_flatten, tree_reduce
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
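
A small sketch of opting into the new ModelScope path. The flag is read once when `mlx_lm.utils` is imported, so it must be set first; the model id is illustrative and `pip install modelscope` is required:

```python
import os

# Must be set before importing mlx_lm, since the flag is read at import time.
os.environ["MLXLM_USE_MODELSCOPE"] = "true"

from mlx_lm import load  # noqa: E402

model, tokenizer = load("Qwen/Qwen2.5-0.5B-Instruct")  # illustrative ModelScope model id
```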
@@ -153,11 +165,12 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path
Path: The path to the model. Path: The path to the model.
""" """
model_path = Path(path_or_hf_repo) model_path = Path(path_or_hf_repo)
if not model_path.exists(): if not model_path.exists():
try: try:
model_path = Path( model_path = Path(
snapshot_download( snapshot_download(
repo_id=path_or_hf_repo, path_or_hf_repo,
revision=revision, revision=revision,
allow_patterns=[ allow_patterns=[
"*.json", "*.json",
@@ -207,12 +220,6 @@ def generate_step(
kv_group_size: int = 64, kv_group_size: int = 64,
quantized_kv_start: int = 0, quantized_kv_start: int = 0,
prompt_progress_callback: Optional[Callable[int, int]] = None, prompt_progress_callback: Optional[Callable[int, int]] = None,
temp: Optional[float] = None,
repetition_penalty: Optional[float] = None,
repetition_context_size: Optional[int] = None,
top_p: Optional[float] = None,
min_p: Optional[float] = None,
min_tokens_to_keep: Optional[int] = None,
) -> Generator[Tuple[mx.array, mx.array], None, None]: ) -> Generator[Tuple[mx.array, mx.array], None, None]:
""" """
A generator producing token ids based on the given prompt from the model. A generator producing token ids based on the given prompt from the model.
@@ -256,24 +263,16 @@ def generate_step(
elif len(prompt_cache) != len(model.layers): elif len(prompt_cache) != len(model.layers):
raise ValueError("Wrong number of layers in the prompt cache.") raise ValueError("Wrong number of layers in the prompt cache.")
if temp is not None or top_p is not None or min_tokens_to_keep is not None: prompt_progress_callback = prompt_progress_callback or (lambda *_: None)
print(
"[Warning] Specifying sampling arguments to ``generate_step`` is " quantize_cache_fn = functools.partial(
"deprecated. Pass in a ``sampler`` instead." maybe_quantize_kv_cache,
) quantized_kv_start=quantized_kv_start,
if repetition_penalty is not None: kv_group_size=kv_group_size,
print( kv_bits=kv_bits,
"[Warning] Specifying ``repetition_penalty`` is deprecated. "
"Pass in ``logits_processors`` instead."
) )
sampler = sampler or make_sampler( sampler = sampler or (lambda x: mx.argmax(x, axis=-1))
temp or 0.0, top_p or 0.0, min_p or 0.0, min_tokens_to_keep or 1
)
logits_processors = logits_processors or make_logits_processors(
None, repetition_penalty, repetition_context_size or 20
)
prompt_progress_callback = prompt_progress_callback or (lambda *_: None)
def _step(y): def _step(y):
with mx.stream(generation_stream): with mx.stream(generation_stream):
@@ -287,9 +286,7 @@ def generate_step(
for processor in logits_processors: for processor in logits_processors:
logits = processor(tokens, logits) logits = processor(tokens, logits)
maybe_quantize_kv_cache( quantize_cache_fn(prompt_cache)
prompt_cache, quantized_kv_start, kv_group_size, kv_bits
)
logprobs = logits - mx.logsumexp(logits, keepdims=True) logprobs = logits - mx.logsumexp(logits, keepdims=True)
y = sampler(logprobs) y = sampler(logprobs)
@@ -300,9 +297,7 @@ def generate_step(
prompt_processed_tokens = 0 prompt_processed_tokens = 0
while y.size > prefill_step_size: while y.size > prefill_step_size:
model(y[:prefill_step_size][None], cache=prompt_cache) model(y[:prefill_step_size][None], cache=prompt_cache)
maybe_quantize_kv_cache( quantize_cache_fn(prompt_cache)
prompt_cache, quantized_kv_start, kv_group_size, kv_bits
)
mx.eval([c.state for c in prompt_cache]) mx.eval([c.state for c in prompt_cache])
prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens) prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens)
prompt_processed_tokens += prefill_step_size prompt_processed_tokens += prefill_step_size
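
With the deprecated keyword arguments removed, sampling and repetition-penalty options are passed to `generate_step` (and therefore to `generate`/`stream_generate`) as explicit `sampler` and `logits_processors` objects. A sketch, with an illustrative model id:

```python
from mlx_lm import generate, load
from mlx_lm.sample_utils import make_logits_processors, make_sampler

model, tokenizer = load("mlx-community/Llama-3.2-1B-Instruct-4bit")  # illustrative model id

sampler = make_sampler(0.7, top_p=0.9)
logits_processors = make_logits_processors(repetition_penalty=1.1)

text = generate(
    model,
    tokenizer,
    prompt="Write a limerick about GPUs.",
    max_tokens=128,
    sampler=sampler,
    logits_processors=logits_processors,
)
print(text)
```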
@@ -329,10 +324,162 @@ def generate_step(
n += 1 n += 1
def speculative_generate_step(
prompt: mx.array,
model: nn.Module,
draft_model: nn.Module,
*,
num_draft_tokens=2,
max_tokens: int = 256,
sampler: Optional[Callable[mx.array, mx.array]] = None,
logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
prompt_cache: Optional[Any] = None,
prefill_step_size: int = 512,
kv_bits: Optional[int] = None,
kv_group_size: int = 64,
quantized_kv_start: int = 0,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
"""
A generator producing token ids for the given prompt, using speculative decoding with a draft model.
Args:
prompt (mx.array): The input prompt.
model (nn.Module): The model to use for generation.
draft_model (nn.Module): The draft model for speculative decoding.
num_draft_tokens (int, optional): The number of draft tokens for
speculative decoding. Default: ``2``.
max_tokens (int): The maximum number of tokens. Use ``-1`` for an infinite
generator. Default: ``256``.
sampler (Callable[mx.array, mx.array], optional): A sampler for sampling a
token from a vector of log probabilities. Default: ``None``.
logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
A list of functions that take tokens and logits and return the processed
logits. Default: ``None``.
prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
provided, the cache will be updated in place. The cache must be trimmable.
prefill_step_size (int): Step size for processing the prompt.
kv_bits (int, optional): Number of bits to use for KV cache quantization.
None implies no cache quantization. Default: ``None``.
kv_group_size (int): Group size for KV cache quantization. Default: ``64``.
quantized_kv_start (int): Step to begin using a quantized KV cache
when ``kv_bits`` is non-None. Default: ``0``.
Yields:
Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
"""
y = prompt
tokens = None
# Create the KV cache for generation
if prompt_cache is None:
model_cache = cache.make_prompt_cache(model)
draft_cache = cache.make_prompt_cache(draft_model)
elif len(prompt_cache) != (len(model.layers) + len(draft_model.layers)):
raise ValueError("Wrong number of layers in the prompt cache.")
else:
model_cache = prompt_cache[: len(model.layers)]
draft_cache = prompt_cache[len(model.layers) :]
sampler = sampler or (lambda x: mx.argmax(x, axis=-1))
quantize_cache_fn = functools.partial(
maybe_quantize_kv_cache,
quantized_kv_start=quantized_kv_start,
kv_group_size=kv_group_size,
kv_bits=kv_bits,
)
def _step(model, cache, y, n_predict=1):
with mx.stream(generation_stream):
logits = model(y[None], cache=cache)
logits = logits[:, -n_predict:, :]
quantize_cache_fn(cache)
logprobs = logits - mx.logsumexp(logits, keepdims=True)
y = sampler(logprobs).squeeze(0)
return y, logprobs.squeeze(0)
def _prefill(model, cache, y):
while y.size > prefill_step_size:
model(y[:prefill_step_size][None], cache=cache)
quantize_cache_fn(cache)
mx.eval([c.state for c in cache])
y = y[prefill_step_size:]
mx.metal.clear_cache()
return y
def _rewind_cache(num_draft, num_accept):
cache.trim_prompt_cache(model_cache, num_draft - num_accept)
cache.trim_prompt_cache(draft_cache, max(num_draft - num_accept - 1, 0))
def _draft_generate(y, num_draft):
if num_draft == 0:
return mx.array([], mx.uint32)
ys = []
for _ in range(num_draft):
y, _ = _step(draft_model, draft_cache, y)
mx.async_eval(y)
ys.append(y)
return mx.concatenate(ys)
with mx.stream(generation_stream):
draft_y = _prefill(draft_model, draft_cache, y)
y = _prefill(model, model_cache, y)
ntoks = 0
# Set these so the finally block doesn't raise
num_draft = 0
n = 0
try:
while True:
num_draft = min(max_tokens - ntoks, num_draft_tokens)
draft_tokens = _draft_generate(draft_y, num_draft)
y = mx.concatenate([y, draft_tokens])
tokens, logprobs = _step(model, model_cache, y, num_draft + 1)
mx.eval(tokens, draft_tokens)
draft_tokens = draft_tokens.tolist()
tokens = tokens.tolist()
n = 0
while n < num_draft:
tn, dtn, lpn = tokens[n], draft_tokens[n], logprobs[n]
if tn != dtn:
break
n += 1
ntoks += 1
yield tn, lpn
if ntoks == max_tokens:
break
if ntoks < max_tokens:
ntoks += 1
yield tokens[n], logprobs[n]
if ntoks == max_tokens:
break
y = mx.array([tokens[n]], mx.uint32)
draft_y = y
# If we accepted all the draft tokens, include the last
# draft token in the next draft step since it hasn't been
# processed yet by the draft model
if n == num_draft:
draft_y = mx.concatenate(
[mx.array(draft_tokens[-1:], mx.uint32), draft_y]
)
_rewind_cache(num_draft, n)
finally:
_rewind_cache(num_draft, n)
def stream_generate( def stream_generate(
model: nn.Module, model: nn.Module,
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper], tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
prompt: Union[str, mx.array, List[int]], prompt: Union[str, mx.array, List[int]],
draft_model: Optional[nn.Module] = None,
**kwargs, **kwargs,
) -> Generator[GenerationResponse, None, None]: ) -> Generator[GenerationResponse, None, None]:
""" """
@@ -341,7 +488,11 @@ def stream_generate(
Args: Args:
model (nn.Module): The model to use for generation. model (nn.Module): The model to use for generation.
tokenizer (PreTrainedTokenizer): The tokenizer. tokenizer (PreTrainedTokenizer): The tokenizer.
prompt (Union[str, mx.array, List[int]]): The input prompt string or integer tokens. prompt (Union[str, mx.array, List[int]]): The input prompt string or
integer tokens.
draft_model (Optional[nn.Module]): An optional draft model. If provided
then speculative decoding is used. The draft model must use the same
tokenizer as the main model. Default: ``None``.
kwargs: The remaining options get passed to :func:`generate_step`. kwargs: The remaining options get passed to :func:`generate_step`.
See :func:`generate_step` for more details. See :func:`generate_step` for more details.
@@ -353,16 +504,28 @@ def stream_generate(
tokenizer = TokenizerWrapper(tokenizer) tokenizer = TokenizerWrapper(tokenizer)
if not isinstance(prompt, mx.array): if not isinstance(prompt, mx.array):
prompt = mx.array( if isinstance(prompt, str):
prompt if isinstance(prompt, list) else tokenizer.encode(prompt) # Try to infer if special tokens are needed
add_special_tokens = tokenizer.bos_token is None or not prompt.startswith(
tokenizer.bos_token
) )
prompt = tokenizer.encode(prompt, add_special_tokens=add_special_tokens)
prompt = mx.array(prompt)
detokenizer = tokenizer.detokenizer detokenizer = tokenizer.detokenizer
if draft_model is None:
kwargs.pop("num_draft_tokens", None)
token_generator = generate_step(prompt, model, **kwargs)
else:
kwargs.pop("max_kv_size", None)
token_generator = speculative_generate_step(
prompt, model, draft_model, **kwargs
)
with wired_limit(model, [generation_stream]): with wired_limit(model, [generation_stream]):
detokenizer.reset() detokenizer.reset()
tic = time.perf_counter() tic = time.perf_counter()
for n, (token, logprobs) in enumerate(generate_step(prompt, model, **kwargs)): for n, (token, logprobs) in enumerate(token_generator):
if n == 0: if n == 0:
prompt_time = time.perf_counter() - tic prompt_time = time.perf_counter() - tic
prompt_tps = prompt.size / prompt_time prompt_tps = prompt.size / prompt_time
@@ -401,7 +564,7 @@ def stream_generate(
def generate( def generate(
model: nn.Module, model: nn.Module,
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper], tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
prompt: str, prompt: Union[str, List[int]],
verbose: bool = False, verbose: bool = False,
formatter: Optional[Callable] = None, formatter: Optional[Callable] = None,
**kwargs, **kwargs,
@@ -412,7 +575,7 @@ def generate(
Args: Args:
model (nn.Module): The language model. model (nn.Module): The language model.
tokenizer (PreTrainedTokenizer): The tokenizer. tokenizer (PreTrainedTokenizer): The tokenizer.
prompt (str): The string prompt. prompt (Union[str, List[int]]): The input prompt string or integer tokens.
verbose (bool): If ``True``, print tokens and timing information. verbose (bool): If ``True``, print tokens and timing information.
Default: ``False``. Default: ``False``.
kwargs: The remaining options get passed to :func:`stream_generate`. kwargs: The remaining options get passed to :func:`stream_generate`.
@@ -425,7 +588,6 @@ def generate(
) )
if verbose: if verbose:
print("=" * 10) print("=" * 10)
print("Prompt:", prompt)
text = "" text = ""
for response in stream_generate(model, tokenizer, prompt, **kwargs): for response in stream_generate(model, tokenizer, prompt, **kwargs):
@@ -558,7 +720,7 @@ def load(
Defaults to an empty dictionary. Defaults to an empty dictionary.
adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers
to the model. Default: ``None``. to the model. Default: ``None``.
lazy (bool): If False eval the model parameters to make sure they are lazy (bool): If ``False`` eval the model parameters to make sure they are
loaded in memory before returning, otherwise they will be loaded loaded in memory before returning, otherwise they will be loaded
when needed. Default: ``False`` when needed. Default: ``False``
Returns: Returns:
@@ -654,10 +816,10 @@ def upload_to_hub(path: str, upload_repo: str, hf_path: str):
prompt = "hello" prompt = "hello"
if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template is not None: if tokenizer.chat_template is not None:
messages = [{{"role": "user", "content": prompt}}] messages = [{{"role": "user", "content": prompt}}]
prompt = tokenizer.apply_chat_template( prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True messages, add_generation_prompt=True
) )
response = generate(model, tokenizer, prompt=prompt, verbose=True) response = generate(model, tokenizer, prompt=prompt, verbose=True)
@@ -670,12 +832,10 @@ def upload_to_hub(path: str, upload_repo: str, hf_path: str):
api = HfApi() api = HfApi()
api.create_repo(repo_id=upload_repo, exist_ok=True) api.create_repo(repo_id=upload_repo, exist_ok=True)
api.upload_folder( api.upload_large_folder(
folder_path=path, folder_path=path,
repo_id=upload_repo, repo_id=upload_repo,
repo_type="model", repo_type="model",
multi_commits=True,
multi_commits_verbose=True,
) )
print(f"Upload successful, go to https://huggingface.co/{upload_repo} for details.") print(f"Upload successful, go to https://huggingface.co/{upload_repo} for details.")

View File

@@ -27,8 +27,8 @@ setup(
packages=["mlx_lm", "mlx_lm.models", "mlx_lm.tuner"], packages=["mlx_lm", "mlx_lm.models", "mlx_lm.tuner"],
python_requires=">=3.8", python_requires=">=3.8",
extras_require={ extras_require={
"testing": ["datasets"], "test": ["datasets"],
"evaluation": ["lm-eval"], "evaluate": ["lm-eval", "tqdm"],
}, },
entry_points={ entry_points={
"console_scripts": [ "console_scripts": [

View File

@@ -36,7 +36,8 @@ class TestDatasets(unittest.TestCase):
data = {"text": "This is an example for the model."} data = {"text": "This is an example for the model."}
self.save_data(4 * [data]) self.save_data(4 * [data])
args = types.SimpleNamespace(train=True, test=False, data=self.test_dir) args = types.SimpleNamespace(train=True, test=False, data=self.test_dir)
train, valid, test = datasets.load_dataset(args, None) tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_PATH)
train, valid, test = datasets.load_dataset(args, tokenizer)
self.assertEqual(len(train), 4) self.assertEqual(len(train), 4)
self.assertEqual(len(valid), 4) self.assertEqual(len(valid), 4)
self.assertEqual(len(test), 0) self.assertEqual(len(test), 0)
@@ -82,6 +83,8 @@ class TestDatasets(unittest.TestCase):
"name": "billsum", "name": "billsum",
"prompt_feature": "text", "prompt_feature": "text",
"completion_feature": "summary", "completion_feature": "summary",
"train_split": "train[:2%]",
"valid_split": "train[-2%:]",
}, },
test=False, test=False,
train=True, train=True,

View File

@@ -5,6 +5,7 @@ import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from mlx.utils import tree_map from mlx.utils import tree_map
from mlx_lm.models import rope_utils from mlx_lm.models import rope_utils
from mlx_lm.models.base import create_causal_mask
from mlx_lm.models.cache import KVCache, RotatingKVCache, make_prompt_cache from mlx_lm.models.cache import KVCache, RotatingKVCache, make_prompt_cache
@@ -128,6 +129,22 @@ class TestModels(unittest.TestCase):
self.assertEqual(cache.offset, 22) self.assertEqual(cache.offset, 22)
self.assertTrue(mx.allclose(x, k[..., -2:, :])) self.assertTrue(mx.allclose(x, k[..., -2:, :]))
def test_causal_mask_lengths(self):
mx.random.seed(8)
B, N_q, T_q, N_kv, T_kv, D = (4, 8, 3, 2, 3, 2)
lengths = mx.array([1, 2, 3, 1])
q = mx.random.uniform(shape=(B, N_q, T_q, D))
k = mx.random.uniform(shape=(B, N_kv, T_kv, D))
v = k
mask = create_causal_mask(T_q, 0, lengths=lengths)
out1 = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask)
q[1, :, 2:] = mx.ones_like(q[1, :, 2:])
k[1, :, 2:] = mx.ones_like(k[1, :, 2:])
v[1, :, 2:] = mx.ones_like(v[1, :, 2:])
out2 = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask)
self.assertTrue(mx.allclose(out1[1, :, :2], out2[1, :, :2]))
def test_rope(self): def test_rope(self):
rope = rope_utils.initialize_rope(32, base=100, traditional=False) rope = rope_utils.initialize_rope(32, base=100, traditional=False)
self.assertTrue(isinstance(rope, nn.RoPE)) self.assertTrue(isinstance(rope, nn.RoPE))
@@ -162,7 +179,13 @@ class TestModels(unittest.TestCase):
self.assertEqual(outputs.dtype, t) self.assertEqual(outputs.dtype, t)
cache = make_prompt_cache(model) cache = make_prompt_cache(model)
outputs = model(inputs, cache) outputs = model(inputs, cache=cache)
self.assertEqual(outputs.shape, (1, 2, vocab_size))
self.assertEqual(outputs.dtype, t)
if model_type != "mamba":
mask = create_causal_mask(inputs.shape[1], 0).astype(t)
outputs = model(inputs, mask=mask)
self.assertEqual(outputs.shape, (1, 2, vocab_size)) self.assertEqual(outputs.shape, (1, 2, vocab_size))
self.assertEqual(outputs.dtype, t) self.assertEqual(outputs.dtype, t)
@@ -659,6 +682,43 @@ class TestModels(unittest.TestCase):
model, args.model_type, args.vocab_size, args.num_hidden_layers model, args.model_type, args.vocab_size, args.num_hidden_layers
) )
def test_deepseek_v3(self):
from mlx_lm.models import deepseek_v3
args = deepseek_v3.ModelArgs(
model_type="deepseek_v3",
vocab_size=1024,
hidden_size=128,
intermediate_size=256,
moe_intermediate_size=256,
num_hidden_layers=4,
num_attention_heads=4,
num_key_value_heads=2,
n_routed_experts=4,
n_group=2,
topk_group=1,
num_experts_per_tok=2,
n_shared_experts=1,
kv_lora_rank=4,
q_lora_rank=4,
qk_rope_head_dim=32,
v_head_dim=16,
qk_nope_head_dim=32,
rope_scaling={
"beta_fast": 32,
"beta_slow": 1,
"factor": 40,
"mscale": 1.0,
"mscale_all_dim": 1.0,
"original_max_position_embeddings": 4096,
"type": "yarn",
},
)
model = deepseek_v3.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_gemma2(self): def test_gemma2(self):
from mlx_lm.models import gemma2 from mlx_lm.models import gemma2

View File

@@ -1,7 +1,7 @@
import unittest import unittest
import mlx.core as mx import mlx.core as mx
from mlx_lm.sample_utils import min_p_sampling, top_p_sampling from mlx_lm.sample_utils import min_p_sampling, top_k_sampling, top_p_sampling
class TestSampleUtils(unittest.TestCase): class TestSampleUtils(unittest.TestCase):
@@ -42,6 +42,27 @@ class TestSampleUtils(unittest.TestCase):
token = min_p_sampling(logits, 0.05) token = min_p_sampling(logits, 0.05)
self.assertTrue(token in (0, 3)) self.assertTrue(token in (0, 3))
def test_top_k_sampling(self):
probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
logits = mx.log(probs)
token = top_k_sampling(logits, 1).item()
self.assertEqual(token, 0)
probs = mx.array([0.5, 0.0, 0.0, 0.5])[None]
tokens = set()
for _ in range(100):
token = top_k_sampling(logits, 2)
tokens.add(token.item())
self.assertEqual(tokens, {0, 3})
# Batch mode works
probs = mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.8, 0.0, 0.1]])
logits = mx.log(probs)
tokens = top_k_sampling(logits, 1)
self.assertEqual(tokens.tolist(), [0, 1])
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@@ -17,7 +17,7 @@ class TestLoadModelCustomGetClasses(unittest.TestCase):
self.config = args self.config = args
self.custom_attribute = "This is a custom model" self.custom_attribute = "This is a custom model"
def load_weights(self, weights): def load_weights(self, weights, **kwargs):
self.qwenWeights = weights self.qwenWeights = weights
class CustomQwenConfig: class CustomQwenConfig: