Merge branch 'main' into adding-support-for-mamba2

Gökdeniz Gülmez authored on 2024-12-10 14:32:44 +01:00; committed by GitHub
commit ddad2105ef
31 changed files with 1579 additions and 414 deletions

View File

@@ -289,4 +289,4 @@ if __name__ == "__main__":
tic = time.time()
save_adapters("final_adapters.safetensors", flux, args)
-print(f"Training successful. Saved final weights to {args.adapter_file}.")
+print("Training successful.")

View File

@@ -85,6 +85,8 @@ class Flux(nn.Module):
def sanitize(self, weights):
new_weights = {}
for k, w in weights.items():
+if k.startswith("model.diffusion_model."):
+k = k[22:]
if k.endswith(".scale"):
k = k[:-6] + ".weight"
for seq in ["img_mlp", "txt_mlp", "adaLN_modulation"]:

View File

@@ -7,7 +7,7 @@ import mlx.core as mx
class FluxSampler:
-def __init__(self, name: str, base_shift: float = 0.5, max_shift: float = 1.5):
+def __init__(self, name: str, base_shift: float = 0.5, max_shift: float = 1.15):
self._base_shift = base_shift
self._max_shift = max_shift
self._schnell = "schnell" in name
@@ -25,7 +25,7 @@ class FluxSampler:
):
t = mx.linspace(start, stop, num_steps + 1)
-if self._schnell:
+if not self._schnell:
t = self._time_shift(image_sequence_length, t)
return t.tolist()
@@ -50,6 +50,7 @@ class FluxSampler:
if noise is not None
else mx.random.normal(x.shape, dtype=x.dtype, key=key)
)
+t = t.reshape([-1] + [1] * (x.ndim - 1))
return x * (1 - t) + t * noise
def step(self, pred, x_t, t, t_prev):

View File

@@ -61,7 +61,7 @@ prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
-response = generate(model, tokenizer, prompt=prompt, verbose=True)
+text = generate(model, tokenizer, prompt=prompt, verbose=True)
```
To see a description of all the arguments you can do:
@@ -77,7 +77,7 @@ to see how to use the API in more detail.
The `mlx-lm` package also comes with functionality to quantize and optionally
upload models to the Hugging Face Hub.
-You can convert models in the Python API with:
+You can convert models using the Python API:
```python
from mlx_lm import convert
@@ -100,8 +100,9 @@ To see a description of all the arguments you can do:
#### Streaming
-For streaming generation, use the `stream_generate` function. This returns a
-generator object which streams the output text, token, and log probabilities.
+For streaming generation, use the `stream_generate` function. This yields
+a generation response object.
For example,
```python
@@ -117,8 +118,8 @@ prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
-for text, *_ in stream_generate(model, tokenizer, prompt, max_tokens=512):
-print(t, end="", flush=True)
+for response in stream_generate(model, tokenizer, prompt, max_tokens=512):
+print(response.text, end="", flush=True)
print()
```
@@ -162,6 +163,10 @@ mlx_lm.convert \
--upload-repo mlx-community/my-4bit-mistral
```
+Models can also be converted and quantized directly in the
+[mlx-my-repo](https://huggingface.co/spaces/mlx-community/mlx-my-repo) Hugging
+Face Space.
### Long Prompts and Generations
`mlx-lm` has some tools to scale efficiently to long prompts and generations:

View File

@@ -92,7 +92,7 @@ curl localhost:8080/v1/chat/completions \
- `system_fingerprint`: A unique identifier for the system.
-- `object`: Any of "chat.completions", "chat.completions.chunk" (for
+- `object`: Any of "chat.completion", "chat.completion.chunk" (for
streaming), or "text.completion".
- `model`: The model repo or path (e.g. `"mlx-community/Llama-3.2-3B-Instruct-4bit"`).

View File

@@ -1,3 +1,3 @@
# Copyright © 2023-2024 Apple Inc.
-__version__ = "0.19.3"
+__version__ = "0.20.2"

View File

@@ -8,7 +8,7 @@ import time
import mlx.core as mx
from .models.cache import make_prompt_cache, save_prompt_cache
-from .utils import load, maybe_quantize_kv_cache
+from .utils import generate_step, load
DEFAULT_QUANTIZED_KV_START = 5000
@@ -50,12 +50,6 @@ def setup_arg_parser():
action="store_true",
help="Use the default chat template",
)
-parser.add_argument(
-"--cache-limit-gb",
-type=int,
-default=None,
-help="Set the MLX cache limit in GB",
-)
parser.add_argument(
"--max-kv-size",
type=int,
@@ -99,9 +93,6 @@ def main():
parser = setup_arg_parser()
args = parser.parse_args()
-if args.cache_limit_gb is not None:
-mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024)
# Building tokenizer_config
tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}
if args.eos_token is not None:
@@ -144,26 +135,28 @@ def main():
y = mx.array(tokenizer.encode(prompt))
# Process the prompt
-processed = 0
-step_size = 512
start = time.time()
max_msg_len = 0
-while y.size > 0:
-model(y[:step_size][None], cache=cache)
-mx.eval([c.state for c in cache])
-mx.metal.clear_cache()
-processed += min(y.size, step_size)
-y = y[step_size:]
+def callback(processed, total_tokens):
current = time.time()
speed = processed / (current - start)
msg = f"\rProcessed {processed:6d} tokens ({speed:6.2f} tok/s)"
+nonlocal max_msg_len
max_msg_len = max(max_msg_len, len(msg))
print(msg + " " * (max_msg_len - len(msg)), end="", flush=True)
-maybe_quantize_kv_cache(
-cache, args.quantized_kv_start, args.kv_group_size, args.kv_bits
-)
+for _ in generate_step(
+y,
+model,
+max_tokens=0,
+prompt_cache=cache,
+kv_bits=args.kv_bits,
+kv_group_size=args.kv_group_size,
+quantized_kv_start=args.quantized_kv_start,
+prompt_progress_callback=callback,
+):
+pass
print()
print(f"Peak memory: {mx.metal.get_peak_memory() / 1e9:.3f} GB")

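The refactor above replaces the hand-rolled prefill loop with `generate_step`. As a rough, standalone sketch of that pattern (module paths and keyword names are taken from this diff and may not match other versions of the package):

```python
# Sketch only: prefill a reusable prompt cache with generate_step(max_tokens=0).
import mlx.core as mx
from mlx_lm import load
from mlx_lm.models.cache import make_prompt_cache, save_prompt_cache
from mlx_lm.utils import generate_step

model, tokenizer = load("mlx-community/Llama-3.2-3B-Instruct-4bit")
cache = make_prompt_cache(model)
tokens = mx.array(tokenizer.encode("A long document to reuse across requests..."))

def progress(processed, total_tokens):
    # Same (processed, total_tokens) signature as the callback in cache_prompt.py.
    print(f"\rProcessed {processed} of {total_tokens} tokens", end="", flush=True)

# max_tokens=0 means "process the prompt, generate nothing", which fills `cache`.
for _ in generate_step(tokens, model, max_tokens=0, prompt_cache=cache,
                       prompt_progress_callback=progress):
    pass

save_prompt_cache("prompt_cache.safetensors", cache)
```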
View File

@@ -5,7 +5,8 @@ import json
import mlx.core as mx
-from .models.cache import load_prompt_cache, make_prompt_cache, save_prompt_cache
+from .models.cache import make_prompt_cache
+from .sample_utils import make_sampler
from .utils import load, stream_generate
DEFAULT_TEMP = 0.0
@@ -74,16 +75,15 @@ def main():
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
-for response, *_ in stream_generate(
+for response in stream_generate(
model,
tokenizer,
prompt,
args.max_tokens,
-temp=args.temp,
-top_p=args.top_p,
+sampler=make_sampler(args.temp, args.top_p),
prompt_cache=prompt_cache,
):
-print(response, flush=True, end="")
+print(response.text, flush=True, end="")
print()

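For reference, a minimal sketch of the sampler-plus-response-object streaming API that chat.py now uses (import paths and argument names are inferred from this diff and the README above):

```python
# Sketch only: stream_generate now yields response objects instead of tuples.
from mlx_lm import load, stream_generate
from mlx_lm.sample_utils import make_sampler

model, tokenizer = load("mlx-community/Llama-3.2-3B-Instruct-4bit")
sampler = make_sampler(0.7, 0.9)  # temperature, top_p

for response in stream_generate(model, tokenizer, "Hello!", max_tokens=64, sampler=sampler):
    print(response.text, end="", flush=True)
print()
```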
llms/mlx_lm/evaluate.py (new file, 355 lines)
View File

@@ -0,0 +1,355 @@
# Adapted from a PyTorch implementation by David Grangier
import argparse
import json
import logging
import os
from importlib.metadata import version
from pathlib import Path
from typing import Optional
import lm_eval
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
from tqdm import tqdm
from .models.cache import make_prompt_cache
from .utils import load, stream_generate
PAD = 0
def _len_longest_common_prefix(a, b):
l = 0
for item_a, item_b in zip(a, b):
if item_a != item_b:
break
l += 1
return l
def _rstrip_until(s, untils):
"""Limit a string <s> to the first occurence of any substring in untils."""
l = len(s)
f = [s.find(u) for u in untils]
f = [l if x < 0 else x for x in f]
return s[: min(f)]
def _pad_inputs(
inputs,
maxlen,
genlen=0,
pad_left=False,
pad_multiple=32,
truncate=False,
):
# pad the prompts to the left with at least genlen tokens.
actual_maxlen = max(len(p) for p in inputs) + genlen
if actual_maxlen > maxlen:
if not truncate:
raise ValueError("Inputs are too long.")
else: # drop beginning
actual_maxlen = maxlen
inputs = [p[max(0, len(p) - maxlen) :] for p in inputs]
if pad_multiple > 0:
maxlen = (actual_maxlen + pad_multiple - 1) // pad_multiple
maxlen *= pad_multiple
assert PAD == 0
lr = np.array((1, 0) if pad_left else (0, 1))
return np.stack(
[np.pad(np.array(x, np.int32), lr * (maxlen - len(x))) for x in inputs],
axis=0,
)
@register_model("mlxlm")
class MLXLM(LM):
def __init__(
self,
path_or_hf_repo: str,
batch_size: int = 16,
max_tokens: Optional[int] = None,
) -> None:
super().__init__()
self._batch_size = batch_size
self._model, self._tokenizer = load(path_or_hf_repo)
self._max_tokens = max_tokens or self._tokenizer.model_max_length
def _score_fn(self, inputs, tokenize=True, step_size=32):
if tokenize:
inputs = self._tokenizer.encode(inputs)
inputs = _pad_inputs(inputs, self._max_tokens, truncate=False)
inputs = mx.array(inputs)
inputs, targets = inputs[..., :-1], inputs[..., 1:]
cache = make_prompt_cache(self._model)
mask = targets != PAD
scores, is_greedy = [], []
for i in range(0, inputs.shape[1], step_size):
logits = self._model(inputs[:, i : i + step_size], cache=cache)
log_probs = nn.log_softmax(logits.astype(mx.float32))
score = mx.take_along_axis(
log_probs, targets[:, i : i + step_size, mx.newaxis], axis=-1
)[..., 0]
ig = mask[:, i : i + step_size] * (
targets[:, i : i + step_size] == mx.argmax(logits, axis=-1)
)
mx.eval(score, ig)
mx.metal.clear_cache()
is_greedy.append(ig)
scores.append(score)
scores = mx.concatenate(scores, axis=1)
is_greedy = mx.concatenate(is_greedy, axis=1)
return scores, mask.sum(axis=-1), is_greedy
def _loglikelihood(self, texts, score_spans=None, tokenize=True):
# sort by length to get batches with little padding.
sorted_indices = sorted(range(len(texts)), key=lambda i: -len(texts[i]))
sorted_inputs = [texts[sorted_indices[i]] for i in range(len(texts))]
sorted_spans = None
if score_spans is not None:
sorted_spans = [score_spans[sorted_indices[i]] for i in range(len(texts))]
results = []
for i in tqdm(range(0, len(sorted_inputs), self._batch_size)):
batch = sorted_inputs[i : i + self._batch_size]
scores, length, is_greedy = self._score_fn(batch, tokenize=tokenize)
for j in range(len(batch)):
if sorted_spans is None: # full sequence score
mask = mx.arange(scores[j].shape[-1]) < length
score = (scores[j].astype(mx.float32) * mask).sum(axis=-1)
ig = (is_greedy[j].astype(mx.int32) * mask).sum(axis=-1)
else: # subsequence score
start, end = sorted_spans[i + j]
score = scores[j][start:end].astype(mx.float32).sum()
ig = is_greedy[j][start:end].astype(mx.int32).sum()
length = end - start
results.append((score.item(), ig.item(), length))
# reorder the outputs
inv_sort = np.argsort(sorted_indices)
results = [results[inv_sort[i]] for i in range(len(results))]
return results
def _tokenize(self, texts):
return [tuple(self._tokenizer.encode(t)) for t in texts]
def loglikelihood(self, requests) -> list[tuple[float, bool]]:
"""Compute log-likelihood of generating a continuation from a context.
Downstream tasks should attempt to use loglikelihood instead of other
LM calls whenever possible.
:param requests: list[Instance]
A list of Instance objects, with property `args` which returns a tuple (context, continuation).
`context: str`
Context string. Implementations of LM must be able to handle an
empty context string.
`continuation: str`
The continuation over which log likelihood will be calculated. If
there is a word boundary, the space should be in the continuation.
For example, context="hello" continuation=" world" is correct.
:return: list[tuple[float, bool]]
A list of pairs (logprob, isgreedy)
`logprob: float`
The log probability of `continuation`.
`isgreedy`:
Whether `continuation` would be generated by greedy sampling from `context`.
"""
logging.info("Estimating loglikelihood for %d pairs." % len(requests))
# tokenize prefix and prefix + completion for all requests.
tokenized = self._tokenize(
[t for r in requests for t in [r.args[0], r.args[0] + r.args[1]]]
)
# max length (prefix + completion) and longest common prefix per question.
length_stats = {}
for prefix, completed in zip(tokenized[0::2], tokenized[1::2]):
max_completed_l, min_prefix_l = length_stats.get(prefix, (0, 1e8))
length_stats[prefix] = (
max(max_completed_l, len(completed)),
min(min_prefix_l, _len_longest_common_prefix(prefix, completed)),
)
# truncate requests for completed sequences longer than model context.
shortened = []
completion_spans = []
long_completions = 0
for prefix, completed in zip(tokenized[0::2], tokenized[1::2]):
max_completed_l, prefix_l = length_stats[prefix]
# compute truncation length
truncation = max(0, max_completed_l - self._max_tokens - 1)
prefix_l = prefix_l - truncation
if prefix_l <= 0:
# completion too long, prefix is eliminated for some requests.
long_completions += 1
truncation = max(0, len(completed) - self._max_tokens - 1)
prefix_l = 1
# truncate the completed sequence
completed = completed[truncation:]
shortened.append(completed)
# scores do not include initial bos, subtract 1 from the span bounds
completion_spans.append((prefix_l - 1, len(completed) - 1))
if long_completions > 0:
logging.info(
f"Prefix eliminated for {long_completions} requests with "
+ "completion longer than context."
)
# model scoring, returns num_requests x (logp, is_greedy, length).
results = self._loglikelihood(
shortened,
score_spans=completion_spans,
tokenize=False,
)
return [(r[0], r[1] == r[2]) for r in results]
def loglikelihood_rolling(self, requests) -> list[float]:
"""Compute full log-likelihood of a string, with no truncation, for perplexity computation
- We will use the full max context length of the model.
- For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
the max context length.
- IMPORTANT: Each document's loglikelihood/perplexity is computed *separately*, unlike other implementations
which may simply concatenate multiple documents together.
- IMPORTANT: We maximize the amount of context for each prediction. Specifically, for inputs that we break into
multiple chunks, the last input will still have a full-sized context.
Example:
Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ]
Prefix: EOT
Max context length: 4
Resulting input/prediction pairs:
INPUT: EOT 0 1 2
PRED: 0 1 2 3
INPUT: 3 4 5 6
PRED: 4 5 6 7
INPUT: 5 6 7 8
PRED: 8 9
Observe that:
1. Each token is predicted exactly once
2. For the last pair, we provide the full context, but only score the last two tokens
:param requests: list[Instance]
A list of Instance objects with property `args` which returns a tuple (context,).
string: str
String for which we are computing overall loglikelihood
:return: list[tuple[float]]
A list of tuples (logprob,)
logprob: float
The log probability of `context` conditioned on the EOT token.
"""
logging.info(
"Estimating loglikelihood rolling for %d sequences." % len(requests)
)
inputs = [req.args[0] for req in requests]
return [t[0] for t in self._loglikelihood(inputs)]
def generate_until(self, requests) -> list[str]:
"""Generate greedily until a stopping sequence
:param requests: list[Instance]
A list of Instance objects with property `args` which returns a tuple (context, until).
context: str
Context string
until: [str]
The string sequences to generate until. These string sequences
may each span across multiple tokens, or may be part of one token.
:return: list[str]
A list of strings continuation
continuation: str
The generated continuation.
"""
logging.info("Generating continuation for %d sequences." % len(requests))
contexts, options = zip(*[req.args for req in requests])
# contrary to the doc the second element of the tuple contains
# {'do_sample': False, 'until': ['\n\n'], 'temperature': 0}
keys = list(options[0].keys())
assert "until" in keys
untils = [x["until"] for x in options]
completions = []
for context, until in tqdm(zip(contexts, untils), total=len(contexts)):
if (
hasattr(self._tokenizer, "apply_chat_template")
and self._tokenizer.chat_template is not None
):
messages = [{"role": "user", "content": context}]
context = self._tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
max_tokens = min(
self._max_tokens,
self._tokenizer.model_max_length - len(self._tokenizer.encode(context)),
)
text = ""
for response in stream_generate(
self._model, self._tokenizer, prompt=context, max_tokens=max_tokens
):
text += response.text
if any(u in text for u in until):
text = _rstrip_until(text, until)
completions.append(text)
break
else:
completions.append(text)
return completions
def main():
parser = argparse.ArgumentParser(
"Evaluate an MLX model using lm-evaluation-harness."
)
parser.add_argument("--model", help="Model to evaluate", required=True)
parser.add_argument("--tasks", nargs="+", required=True)
parser.add_argument(
"--output-dir", default=".", help="Output directory for result files."
)
parser.add_argument("--batch-size", type=int, default=16, help="Batch size")
parser.add_argument("--num-shots", type=int, default=0, help="Number of shots")
parser.add_argument(
"--max-tokens",
type=int,
help="Maximum nunber of tokens to generate. Defaults to the model's max context length.",
)
parser.add_argument("--seed", type=int, default=123, help="Random seed.")
args = parser.parse_args()
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
# Silence tokenizer warnings
os.environ["TOKENIZERS_PARALLELISM"] = "false"
mx.random.seed(args.seed)
lm = MLXLM(args.model, batch_size=args.batch_size, max_tokens=args.max_tokens)
results = lm_eval.simple_evaluate(
model=lm,
tasks=args.tasks,
num_fewshot=args.num_shots,
random_seed=args.seed,
numpy_random_seed=args.seed,
torch_random_seed=args.seed,
fewshot_random_seed=args.seed,
)
model_name = args.model.replace("/", "_")
task_names = "_".join(args.tasks)
ver = version("lm_eval")
filename = f"eval_{model_name}_{task_names}_{args.num_shots:02d}_v_{ver}.json"
output_path = output_dir / filename
output_path.write_text(json.dumps(results["results"], indent=4))
print("Results:")
for result in results["results"].values():
print(json.dumps(result, indent=4))

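A hedged usage sketch for this new module, mirroring its own `main()`; the import path is inferred from `llms/mlx_lm/evaluate.py` and the task name is only illustrative:

```python
# Sketch only: run an lm-evaluation-harness task against an MLX model.
import lm_eval
from mlx_lm.evaluate import MLXLM  # path assumed from llms/mlx_lm/evaluate.py

lm = MLXLM("mlx-community/Llama-3.2-3B-Instruct-4bit", batch_size=16)
results = lm_eval.simple_evaluate(model=lm, tasks=["arc_easy"], num_fewshot=0)
print(results["results"])
```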
View File

@@ -42,7 +42,6 @@ response = generate(
tokenizer,
prompt=prompt,
verbose=True,
-temp=0.0,
prompt_cache=prompt_cache,
)

View File

@@ -23,14 +23,6 @@ max_tokens = 1_000
# Specify if tokens and timing information will be printed
verbose = True
-# Some optional arguments for causal language model generation
-generation_args = {
-"temp": 0.7,
-"repetition_penalty": 1.2,
-"repetition_context_size": 20,
-"top_p": 0.95,
-}
# Generate a response with the specified settings
response = generate(
model=model,
@@ -38,5 +30,4 @@ response = generate(
prompt=prompt,
max_tokens=max_tokens,
verbose=verbose,
-**generation_args,
)

View File

@@ -7,6 +7,7 @@ import sys
import mlx.core as mx
from .models.cache import QuantizedKVCache, load_prompt_cache
+from .sample_utils import make_sampler
from .utils import generate, load
DEFAULT_PROMPT = "hello"
@@ -41,17 +42,17 @@ def setup_arg_parser():
type=str,
help="Optional path for the trained adapter weights and config.",
)
-parser.add_argument(
-"--trust-remote-code",
-action="store_true",
-help="Enable trusting remote code for tokenizer",
-)
parser.add_argument(
"--eos-token",
type=str,
default=None,
help="End of sequence token for tokenizer",
)
+parser.add_argument(
+"--system-prompt",
+default=None,
+help="System prompt to be used for the chat template",
+)
parser.add_argument(
"--prompt",
"-p",
@@ -76,7 +77,7 @@ def setup_arg_parser():
)
parser.add_argument(
"--min-tokens-to-keep",
-type=float,
+type=int,
default=DEFAULT_MIN_TOKENS_TO_KEEP,
help="Minimum tokens to keep for min-p sampling.",
)
@@ -97,11 +98,6 @@ def setup_arg_parser():
default=True,
help="Log verbose output when 'True' or 'T' or only print the response when 'False' or 'F'",
)
-parser.add_argument(
-"--colorize",
-action="store_true",
-help="Colorize output based on T[0] probability",
-)
parser.add_argument(
"--max-kv-size",
type=int,
@@ -137,33 +133,6 @@ def setup_arg_parser():
return parser
-def colorprint(color, s):
-color_codes = {
-"black": 30,
-"red": 31,
-"green": 32,
-"yellow": 33,
-"blue": 34,
-"magenta": 35,
-"cyan": 36,
-"white": 39,
-}
-ccode = color_codes.get(color, 30)
-print(f"\033[1m\033[{ccode}m{s}\033[0m", end="", flush=True)
-def colorprint_by_t0(s, t0):
-if t0 > 0.95:
-color = "white"
-elif t0 > 0.70:
-color = "green"
-elif t0 > 0.30:
-color = "yellow"
-else:
-color = "red"
-colorprint(color, s)
def main():
parser = setup_arg_parser()
args = parser.parse_args()
@@ -191,7 +160,6 @@ def main():
tokenizer_config = (
{} if not using_cache else json.loads(metadata["tokenizer_config"])
)
-if args.trust_remote_code:
tokenizer_config["trust_remote_code"] = True
if args.eos_token is not None:
tokenizer_config["eos_token"] = args.eos_token
@@ -224,12 +192,16 @@ def main():
hasattr(tokenizer, "apply_chat_template")
and tokenizer.chat_template is not None
):
-messages = [
+if args.system_prompt is not None:
+messages = [{"role": "system", "content": args.system_prompt}]
+else:
+messages = []
+messages.append(
{
"role": "user",
"content": sys.stdin.read() if args.prompt == "-" else args.prompt,
}
-]
+)
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
@@ -237,8 +209,9 @@ def main():
# Treat the prompt as a suffix assuming that the prefix is in the
# stored kv cache.
if using_cache:
+messages[-1]["content"] = "<query>"
test_prompt = tokenizer.apply_chat_template(
-[{"role": "user", "content": "<query>"}],
+messages,
tokenize=False,
add_generation_prompt=True,
)
@@ -246,21 +219,14 @@ def main():
else:
prompt = args.prompt
-if args.colorize and not args.verbose:
-raise ValueError("Cannot use --colorize with --verbose=False")
-formatter = colorprint_by_t0 if args.colorize else None
+sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep)
response = generate(
model,
tokenizer,
prompt,
-args.max_tokens,
+max_tokens=args.max_tokens,
verbose=args.verbose,
-formatter=formatter,
-temp=args.temp,
-top_p=args.top_p,
-min_p=args.min_p,
-min_tokens_to_keep=args.min_tokens_to_keep,
+sampler=sampler,
max_kv_size=args.max_kv_size,
prompt_cache=prompt_cache if using_cache else None,
kv_bits=args.kv_bits,

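The same consolidation applies to the one-shot `generate` call: individual sampling keywords move into a single `make_sampler` object. A brief sketch, with names taken from the diff above:

```python
# Sketch only: the new sampler-based generate() call.
from mlx_lm import generate, load
from mlx_lm.sample_utils import make_sampler

model, tokenizer = load("mlx-community/Llama-3.2-3B-Instruct-4bit")
# temperature, top_p, min_p, min_tokens_to_keep, mirroring generate.py above
sampler = make_sampler(0.6, 0.9, 0.0, 1)

text = generate(model, tokenizer, "Hello!", max_tokens=100, verbose=True, sampler=sampler)
```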
View File

@@ -0,0 +1,163 @@
# Copyright © 2024 Apple Inc.
from dataclasses import dataclass
from typing import Any, Dict, Optional, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .rope_utils import initialize_rope
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int
num_layers: int
intermediate_size: int
num_attention_heads: int
vocab_size: int
rope_theta: float
layer_norm_epsilon: float
num_key_value_heads: int
head_dim: Optional[int] = None
max_position_embeddings: Optional[int] = None
rope_traditional: bool = False
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
tie_word_embeddings: bool = True
attention_bias: bool = False
mlp_bias: bool = False
class AttentionModule(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
self.head_dim = head_dim = args.head_dim or (dim // n_heads)
self.scale = head_dim**-0.5
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.attention_bias)
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias)
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias)
self.out_proj = nn.Linear(n_heads * head_dim, dim, bias=args.attention_bias)
self.rope = initialize_rope(
self.head_dim,
args.rope_theta,
args.rope_traditional,
args.rope_scaling,
args.max_position_embeddings,
)
def __call__(
self, x: mx.array, mask: Optional[mx.array] = None, cache: Optional[Any] = None
) -> mx.array:
B, L, D = x.shape
q = self.q_proj(x).reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
k = self.k_proj(x).reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
v = self.v_proj(x).reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
q = self.rope(q, offset=cache.offset)
k = self.rope(k, offset=cache.offset)
k, v = cache.update_and_fetch(k, v)
else:
q = self.rope(q)
k = self.rope(k)
out = scaled_dot_product_attention(
q, k, v, cache=cache, scale=self.scale, mask=mask
)
out = out.transpose(0, 2, 1, 3).reshape(B, L, D)
return self.out_proj(out)
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.attention = AttentionModule(args)
class MLP(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.hidden_size
hidden_dim = args.intermediate_size
self.c_fc_0 = nn.Linear(dim, hidden_dim, bias=args.mlp_bias)
self.c_fc_1 = nn.Linear(dim, hidden_dim, bias=args.mlp_bias)
self.c_proj = nn.Linear(hidden_dim, dim, bias=args.mlp_bias)
def __call__(self, x: mx.array) -> mx.array:
return self.c_proj(nn.silu(self.c_fc_0(x)) * self.c_fc_1(x))
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.ln_1 = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
self.attn = Attention(args)
self.ln_2 = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
self.mlp = MLP(args)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
h = x + self.attn.attention(self.ln_1(x), mask, cache)
out = h + self.mlp(self.ln_2(h))
return out
class ExaoneModel(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.wte = nn.Embedding(args.vocab_size, args.hidden_size)
self.h = [TransformerBlock(args) for _ in range(args.num_layers)]
self.ln_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
def __call__(
self,
inputs: mx.array,
cache=None,
):
h = self.wte(inputs)
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.h)
for layer, c in zip(self.h, cache):
h = layer(h, mask, cache=c)
return self.ln_f(h)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.transformer = ExaoneModel(args)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
cache=None,
):
out = self.transformer(inputs, cache)
if self.args.tie_word_embeddings:
out = self.transformer.wte.as_linear(out)
else:
out = self.lm_head(out)
return out
@property
def layers(self):
return self.transformer.h

View File

@@ -0,0 +1,291 @@
# Copyright © 2023-2024 Apple Inc.
import math
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .switch_layers import SwitchGLU
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
vocab_size: int
hidden_size: int
num_hidden_layers: int
intermediate_size: int
num_attention_heads: int
num_key_value_heads: int
attention_bias: bool
moe_topk: int
num_experts: int
num_shared_expert: int
use_mixed_mlp_moe: bool
use_qk_norm: bool
rms_norm_eps: float
rope_theta: float
use_cla: bool
cla_share_factor: int = 2
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
tie_word_embeddings: bool = False
def __post_init__(self):
if self.rope_scaling:
required_keys = {"factor", "type"}
if not all(key in self.rope_scaling for key in required_keys):
raise ValueError(f"rope_scaling must contain keys {required_keys}")
class DynamicNTKAlphaRoPE(nn.Module):
def __init__(
self,
dims: int,
base: float = 10000,
scaling_alpha: float = 1.0,
):
super().__init__()
self.dims = dims
base = base * scaling_alpha ** (dims / (dims - 2))
self._freqs = base ** (mx.arange(0, self.dims, 2) / self.dims)
def __call__(self, x, offset: int = 0):
return mx.fast.rope(
x,
self.dims,
traditional=False,
base=None,
scale=1.0,
offset=offset,
freqs=self._freqs,
)
class Attention(nn.Module):
def __init__(self, kv_proj: bool, args: ModelArgs):
super().__init__()
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
assert args.num_key_value_heads is not None
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
head_dim = args.hidden_size // n_heads
self.scale = head_dim**-0.5
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.attention_bias)
if kv_proj:
self.k_proj = nn.Linear(
dim, n_kv_heads * head_dim, bias=args.attention_bias
)
self.v_proj = nn.Linear(
dim, n_kv_heads * head_dim, bias=args.attention_bias
)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=args.attention_bias)
self.use_qk_norm = args.use_qk_norm
if self.use_qk_norm:
self.query_layernorm = nn.RMSNorm(head_dim, args.rms_norm_eps)
self.key_layernorm = nn.RMSNorm(head_dim, args.rms_norm_eps)
self.rope = DynamicNTKAlphaRoPE(
head_dim,
base=args.rope_theta,
scaling_alpha=args.rope_scaling["alpha"],
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
kv_states=None,
) -> mx.array:
B, L, D = x.shape
queries = self.q_proj(x)
if kv_states is None:
keys, values = self.k_proj(x), self.v_proj(x)
kv_states = keys, values
else:
keys, values = kv_states
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
offset = cache.offset if cache else 0
queries = self.rope(queries, offset=offset)
keys = self.rope(keys, offset=offset)
if self.use_qk_norm:
queries = self.query_layernorm(queries)
keys = self.key_layernorm(keys)
if cache is not None:
keys, values = cache.update_and_fetch(keys, values)
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output), kv_states
class MLP(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
def __call__(self, x) -> mx.array:
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
class Gate(nn.Module):
def __init__(self, dim, num_experts):
super().__init__()
self.wg = nn.Linear(dim, num_experts, bias=False)
def __call__(self, x) -> mx.array:
return self.wg(x)
class MoeBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.hidden_size
intermediate_size = args.intermediate_size
self.use_shared_mlp = args.use_mixed_mlp_moe
if args.use_mixed_mlp_moe:
self.shared_mlp = MLP(dim, intermediate_size * args.num_shared_expert)
self.num_experts = num_experts = args.num_experts
self.top_k = args.moe_topk
self.gate = Gate(dim, num_experts)
self.switch_mlp = SwitchGLU(dim, intermediate_size, num_experts)
def __call__(
self,
x: mx.array,
):
gates = self.gate(x)
gates = mx.softmax(gates, axis=-1, precise=True)
k = self.top_k
inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1)[..., :k])
scores = mx.take_along_axis(gates, inds, axis=-1)
y = self.switch_mlp(x, inds)
y = (y * scores[..., None]).sum(axis=-2)
if self.use_shared_mlp:
shared_expert_output = self.shared_mlp(x)
y = y + shared_expert_output
return y
class DecoderLayer(nn.Module):
def __init__(self, args: ModelArgs, kv_proj: bool):
super().__init__()
self.hidden_size = args.hidden_size
self.self_attn = Attention(kv_proj, args)
self.mlp = MoeBlock(args)
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.post_attention_layernorm = nn.RMSNorm(
args.hidden_size, eps=args.rms_norm_eps
)
self.args = args
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
shared_kv_states: Optional[Tuple[mx.array, mx.array]] = None,
):
r, shared_kv_states = self.self_attn(
self.input_layernorm(x), mask, cache, shared_kv_states
)
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
out = h + r
return out, shared_kv_states
class HunYuanModel(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
self.num_hidden_layers = args.num_hidden_layers
assert self.vocab_size > 0
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [
DecoderLayer(args=args, kv_proj=(i % args.cla_share_factor) == 0)
for i in range(args.num_hidden_layers)
]
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
def __call__(
self,
inputs: mx.array,
cache=None,
):
h = self.embed_tokens(inputs)
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
for i, (layer, c) in enumerate(zip(self.layers, cache)):
if i % self.args.cla_share_factor == 0:
shared_kv_states = None
h, shared_kv_states = layer(h, mask, c, shared_kv_states)
return self.norm(h)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.model = HunYuanModel(args)
def __call__(
self,
inputs: mx.array,
cache=None,
):
out = self.model(inputs, cache)
return self.model.embed_tokens.as_linear(out)
def sanitize(self, weights):
if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights:
return weights
for l in range(self.args.num_hidden_layers):
prefix = f"model.layers.{l}"
for n in ["up_proj", "down_proj", "gate_proj"]:
for k in ["weight", "scales", "biases"]:
if f"{prefix}.mlp.experts.0.{n}.{k}" in weights:
to_join = [
weights.pop(f"{prefix}.mlp.experts.{e}.{n}.{k}")
for e in range(self.args.num_experts)
]
weights[f"{prefix}.mlp.switch_mlp.{n}.{k}"] = mx.stack(to_join)
return weights
@property
def layers(self):
return self.model.layers

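The expert routing in `MoeBlock` above picks the top-k gate scores with `argpartition`; a tiny illustrative example of that selection (the router values are made up):

```python
# Sketch only: top-k expert selection as done in MoeBlock.__call__ above.
import mlx.core as mx

gates = mx.softmax(mx.array([[1.0, 3.0, 0.5, 2.0]]), axis=-1)  # router probs for 4 experts
k = 2
inds = mx.argpartition(-gates, kth=k - 1, axis=-1)[..., :k]    # indices of the 2 best experts
scores = mx.take_along_axis(gates, inds, axis=-1)              # their gate probabilities
print(inds.tolist(), scores.tolist())
```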
View File

@@ -7,6 +7,7 @@ import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
+from .rope_utils import initialize_rope
@dataclass
@@ -32,117 +33,6 @@ class ModelArgs(BaseModelArgs):
if self.num_key_value_heads is None:
self.num_key_value_heads = self.num_attention_heads
if self.rope_scaling:
if not "factor" in self.rope_scaling:
raise ValueError(f"rope_scaling must contain 'factor'")
rope_type = self.rope_scaling.get("type") or self.rope_scaling.get(
"rope_type"
)
if rope_type is None:
raise ValueError(
f"rope_scaling must contain either 'type' or 'rope_type'"
)
if rope_type not in ["linear", "dynamic", "llama3"]:
raise ValueError(
"rope_scaling 'type' currently only supports 'linear', 'dynamic' or 'llama3'"
)
class DynamicNTKScalingRoPE(nn.Module):
"""Implements the rotary positional encoding with Dynamic NTK scaling and Llama 3 RoPE."""
def __init__(
self,
dims: int,
max_position_embeddings: int = 2048,
traditional: bool = False,
base: float = 10000,
scale: float = 1.0,
rope_type: str = "default",
rope_scaling: dict = None,
):
super().__init__()
self.dims = dims
self.max_position_embeddings = max_position_embeddings
self.traditional = traditional
self.scale = scale
self.rope_type = rope_type
self.rope_scaling = rope_scaling
self.base = base
self.compute_freqs()
def compute_freqs(self):
if self.rope_type != "llama3":
self._freqs = None
return
factor = self.rope_scaling["factor"]
low_freq_factor = self.rope_scaling.get("low_freq_factor", 1.0)
high_freq_factor = self.rope_scaling.get("high_freq_factor", 4.0)
old_context_len = self.rope_scaling.get(
"original_max_position_embeddings",
8192,
)
low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
freqs = self.base ** (mx.arange(0, self.dims, 2) / self.dims)
wavelens = 2 * mx.pi * freqs
freqs = mx.where(wavelens > low_freq_wavelen, freqs * factor, freqs)
is_medium_freq = (wavelens > high_freq_wavelen) & (wavelens < low_freq_wavelen)
smooth_factors = (old_context_len / wavelens - low_freq_factor) / (
high_freq_factor - low_freq_factor
)
smooth_freqs = freqs / ((1 - smooth_factors) / factor + smooth_factors)
self._freqs = mx.where(is_medium_freq, smooth_freqs, freqs)
self.base = None
def extra_repr(self):
return (
f"{self.dims}, traditional={self.traditional}, "
f"max_position_embeddings={self.max_position_embeddings}, "
f"scaling_factor={self.scale}, rope_type={self.rope_type}"
)
def __call__(self, x, offset: int = 0):
return mx.fast.rope(
x,
self.dims,
traditional=self.traditional,
base=self.base,
scale=self.scale,
offset=offset,
freqs=self._freqs,
)
def initialize_rope(args: ModelArgs):
head_dim = args.head_dim or args.hidden_size // args.num_attention_heads
rope_scaling = args.rope_scaling
rope_type = "default"
rope_scale = 1.0
if rope_scaling is not None:
rope_type = (
rope_scaling.get("type") or rope_scaling.get("rope_type") or "default"
)
if rope_type == "linear":
rope_scale = 1 / rope_scaling["factor"]
elif rope_type == "llama3":
rope_scale = 1.0 # The scaling is handled internally for llama3
return DynamicNTKScalingRoPE(
dims=head_dim,
max_position_embeddings=args.max_position_embeddings,
traditional=args.rope_traditional,
base=args.rope_theta,
scale=rope_scale,
rope_type=rope_type,
rope_scaling=rope_scaling,
)
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
@@ -165,7 +55,13 @@ class Attention(nn.Module):
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)
-self.rope = initialize_rope(args)
+self.rope = initialize_rope(
+self.head_dim,
+args.rope_theta,
+args.rope_traditional,
+args.rope_scaling,
+args.max_position_embeddings,
+)
def __call__(
self,

llms/mlx_lm/models/olmo2.py (new file, 209 lines)
View File

@@ -0,0 +1,209 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Any, Dict, Optional, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .rope_utils import initialize_rope
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int
num_hidden_layers: int
intermediate_size: int
num_attention_heads: int
rms_norm_eps: float
vocab_size: int
head_dim: Optional[int] = None
max_position_embeddings: Optional[int] = None
num_key_value_heads: Optional[int] = None
attention_bias: bool = False
mlp_bias: bool = False
rope_theta: float = 10000
rope_traditional: bool = False
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
tie_word_embeddings: bool = True
def __post_init__(self):
if self.num_key_value_heads is None:
self.num_key_value_heads = self.num_attention_heads
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
self.head_dim = head_dim = args.head_dim or args.hidden_size // n_heads
self.scale = head_dim**-0.5
if hasattr(args, "attention_bias"):
attention_bias = args.attention_bias
else:
attention_bias = False
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias)
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)
self.rope = initialize_rope(
self.head_dim,
args.rope_theta,
args.rope_traditional,
args.rope_scaling,
args.max_position_embeddings,
)
self.q_norm = nn.RMSNorm(n_heads * head_dim, args.rms_norm_eps)
self.k_norm = nn.RMSNorm(n_kv_heads * head_dim, args.rms_norm_eps)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
queries = self.q_norm(queries)
keys = self.k_norm(keys)
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
class MLP(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.hidden_size
hidden_dim = args.intermediate_size
if hasattr(args, "mlp_bias"):
mlp_bias = args.mlp_bias
else:
mlp_bias = False
self.gate_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
self.down_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias)
self.up_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
def __call__(self, x) -> mx.array:
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.num_attention_heads = args.num_attention_heads
self.hidden_size = args.hidden_size
self.self_attn = Attention(args)
self.mlp = MLP(args)
self.post_attention_layernorm = nn.RMSNorm(
args.hidden_size, eps=args.rms_norm_eps
)
self.post_feedforward_layernorm = nn.RMSNorm(
args.hidden_size, eps=args.rms_norm_eps
)
self.args = args
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.post_attention_layernorm(self.self_attn(x, mask, cache))
h = x + r
r = self.post_feedforward_layernorm(self.mlp(h))
out = h + r
return out
class LlamaModel(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
self.num_hidden_layers = args.num_hidden_layers
assert self.vocab_size > 0
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
]
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
def __call__(
self,
inputs: mx.array,
cache=None,
):
h = self.embed_tokens(inputs)
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
h = layer(h, mask, cache=c)
return self.norm(h)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.model = LlamaModel(args)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
cache=None,
):
out = self.model(inputs, cache)
if self.args.tie_word_embeddings:
out = self.model.embed_tokens.as_linear(out)
else:
out = self.lm_head(out)
return out
def sanitize(self, weights):
# Remove unused precomputed rotary freqs
return {
k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
}
@property
def layers(self):
return self.model.layers

View File

@@ -0,0 +1,91 @@
# Copyright © 2023-2024 Apple Inc.
from typing import Optional
import mlx.core as mx
import mlx.nn as nn
class Llama3RoPE(nn.Module):
def __init__(
self,
dims: int,
max_position_embeddings: int = 2048,
traditional: bool = False,
base: float = 10000,
scaling_config: dict = None,
):
super().__init__()
self.dims = dims
self.max_position_embeddings = max_position_embeddings
self.traditional = traditional
factor = scaling_config["factor"]
low_freq_factor = scaling_config.get("low_freq_factor", 1.0)
high_freq_factor = scaling_config.get("high_freq_factor", 4.0)
old_context_len = scaling_config.get(
"original_max_position_embeddings",
8192,
)
low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
freqs = base ** (mx.arange(0, dims, 2) / dims)
wavelens = 2 * mx.pi * freqs
freqs = mx.where(wavelens > low_freq_wavelen, freqs * factor, freqs)
is_medium_freq = (wavelens > high_freq_wavelen) & (wavelens < low_freq_wavelen)
smooth_factors = (old_context_len / wavelens - low_freq_factor) / (
high_freq_factor - low_freq_factor
)
smooth_freqs = freqs / ((1 - smooth_factors) / factor + smooth_factors)
self._freqs = mx.where(is_medium_freq, smooth_freqs, freqs)
def extra_repr(self):
return (
f"{self.dims}, traditional={self.traditional}, "
f"max_position_embeddings={self.max_position_embeddings}"
)
def __call__(self, x, offset: int = 0):
return mx.fast.rope(
x,
self.dims,
traditional=self.traditional,
base=None,
scale=1.0,
offset=offset,
freqs=self._freqs,
)
def initialize_rope(
dims,
base,
traditional,
scaling_config: Optional[dict] = None,
max_position_embeddings: Optional[int] = None,
):
if scaling_config is not None:
rope_type = scaling_config.get("type") or scaling_config.get(
"rope_type", "default"
)
else:
rope_type = "default"
if rope_type in ["default", "linear"]:
scale = 1 / scaling_config["factor"] if rope_type == "linear" else 1.0
return nn.RoPE(dims, traditional=traditional, base=base, scale=scale)
elif rope_type == "llama3":
return Llama3RoPE(
dims=dims,
max_position_embeddings=max_position_embeddings,
traditional=traditional,
base=base,
scaling_config=scaling_config,
)
else:
raise ValueError(f"Unsupported RoPE type {rope_type}")

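A hedged sketch of how `initialize_rope` (defined above) dispatches on the scaling config; the module path is inferred from this diff and the llama3 parameters are illustrative values, not taken from any real config:

```python
# Sketch only: plain RoPE vs. Llama 3 scaled RoPE via initialize_rope.
from mlx_lm.models.rope_utils import initialize_rope  # path assumed from this diff

# No scaling config -> nn.RoPE
rope = initialize_rope(dims=128, base=10000.0, traditional=False)

# Llama 3 style scaling -> Llama3RoPE
rope_llama3 = initialize_rope(
    dims=128,
    base=500000.0,
    traditional=False,
    scaling_config={"rope_type": "llama3", "factor": 8.0},
    max_position_embeddings=8192,
)
```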
View File

@@ -1,5 +1,6 @@
# Copyright © 2023-2024 Apple Inc.
+import math
from functools import partial
from typing import Callable, Dict, Optional
@@ -80,7 +81,7 @@ def make_logits_processors(
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
def min_p_sampling(
-logits: mx.array,
+logprobs: mx.array,
min_p: float,
min_tokens_to_keep: int = 1,
temperature=1.0,
@@ -93,7 +94,7 @@ def min_p_sampling(
aggressive given a very high-probability token.
Args:
-logits: The logits from the model's output.
+logprobs: A vector of log probabilities.
min_p (float): Minimum token probability. Typical values are in the
0.01-0.2 range, comparably selective as setting `top_p` in the
0.99-0.8 range.
@@ -111,28 +112,27 @@ def min_p_sampling(
)
# reference implementation: https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L531-L605
-# Softmax probabilities
-probs = mx.softmax(logits * (1 / temperature), axis=-1)
+logprobs = logprobs * (1 / temperature)
# Indices sorted in decreasing order
-sorted_indices = mx.argsort(-logits).squeeze(0)
-sorted_probs = probs[..., sorted_indices]
+sorted_indices = mx.argsort(-logprobs).squeeze(0)
+sorted_logprobs = logprobs[..., sorted_indices]
# Top probability
-top_probs = probs[..., sorted_indices[0]]
+top_logprobs = logprobs[..., sorted_indices[0]]
# Calculate the min_p threshold
-scaled_min_p = min_p * top_probs
+scaled_min_p = top_logprobs + math.log(min_p)
# Mask tokens that have a probability less than the scaled min_p
-tokens_to_remove = sorted_probs < scaled_min_p
+tokens_to_remove = sorted_logprobs < scaled_min_p
tokens_to_remove[..., :min_tokens_to_keep] = False
# Create pool of tokens with probability less than scaled min_p
-selected_probs = mx.where(tokens_to_remove, 0, sorted_probs)
+selected_logprobs = mx.where(tokens_to_remove, -float("inf"), sorted_logprobs)
# Return sampled token
-sorted_token = mx.random.categorical(mx.log(selected_probs))
+sorted_token = mx.random.categorical(selected_logprobs)
return sorted_indices[sorted_token]

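The switch from probabilities to log-probabilities above relies on the cutoff p < min_p * p_top being the same test as log p < log p_top + log(min_p). A tiny check of that identity (the numbers are arbitrary):

```python
# Sketch only: the min-p threshold is identical in probability and log space.
import math

min_p, p_top, p = 0.1, 0.6, 0.05
in_prob_space = p < min_p * p_top
in_log_space = math.log(p) < math.log(p_top) + math.log(min_p)
assert in_prob_space == in_log_space
```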
View File

@@ -27,6 +27,7 @@ from huggingface_hub import scan_cache_dir
from ._version import __version__
from .models.cache import make_prompt_cache
+from .sample_utils import make_logits_processors, make_sampler
from .utils import load, stream_generate
@@ -464,25 +465,24 @@ class APIHandler(BaseHTTPRequestHandler):
text = ""
tic = time.perf_counter()
-for n, (segment, token, logprobs) in enumerate(
-stream_generate(
+sampler = make_sampler(self.temperature)
+logits_processors = make_logits_processors(
+self.logit_bias, self.repetition_penalty, self.repetition_context_size
+)
+for gen_response in stream_generate(
model=self.model,
tokenizer=self.tokenizer,
prompt=prompt,
max_tokens=self.max_tokens,
-temp=self.temperature,
-repetition_penalty=self.repetition_penalty,
-repetition_context_size=self.repetition_context_size,
-logit_bias=self.logit_bias,
+sampler=sampler,
+logits_processors=logits_processors,
prompt_cache=self.prompt_cache.cache,
-),
):
-if n == 0:
-prompt_time = time.perf_counter() - tic
-tic = time.perf_counter()
+segment = gen_response.text
text += segment
logging.debug(text)
+token = gen_response.token
+logprobs = gen_response.logprobs
tokens.append(token)
if self.logprobs > 0:
@@ -523,13 +523,9 @@ class APIHandler(BaseHTTPRequestHandler):
self.prompt_cache.tokens.extend(tokens)
-gen_time = time.perf_counter() - tic
-prompt_tps = len(prompt) / prompt_time
-gen_tps = len(tokens) / gen_time
-peak_mem = mx.metal.get_peak_memory() / 1e9
-logging.debug(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
-logging.debug(f"Generation: {gen_tps:.3f} tokens-per-sec")
-logging.debug(f"Peak memory: {peak_mem:.3f} GB")
+logging.debug(f"Prompt: {gen_response.prompt_tps:.3f} tokens-per-sec")
+logging.debug(f"Generation: {gen_response.generation_tps:.3f} tokens-per-sec")
+logging.debug(f"Peak memory: {gen_response.peak_memory:.3f} GB")
if self.stream:
response = self.generate_response(segment, finish_reason)
@@ -593,9 +589,7 @@ class APIHandler(BaseHTTPRequestHandler):
# Determine response type
self.request_id = f"chatcmpl-{uuid.uuid4()}"
-self.object_type = (
-"chat.completions.chunk" if self.stream else "chat.completions"
-)
+self.object_type = "chat.completion.chunk" if self.stream else "chat.completion"
if (
hasattr(self.tokenizer, "apply_chat_template")
and self.tokenizer.chat_template

View File

@ -73,16 +73,16 @@ class NaiveStreamingDetokenizer(StreamingDetokenizer):
def reset(self): def reset(self):
self.offset = 0 self.offset = 0
self._tokens = [] self.tokens = []
self._text = "" self._text = ""
self._current_tokens = [] self._current_tokens = []
self._current_text = "" self._current_text = ""
def add_token(self, token): def add_token(self, token):
self._current_tokens.append(token) self._current_tokens.append(token)
self.tokens.append(token)
def finalize(self): def finalize(self):
self._tokens.extend(self._current_tokens)
self._text += self._tokenizer.decode(self._current_tokens) self._text += self._tokenizer.decode(self._current_tokens)
self._current_tokens = [] self._current_tokens = []
self._current_text = "" self._current_text = ""
@ -97,16 +97,11 @@ class NaiveStreamingDetokenizer(StreamingDetokenizer):
): ):
self._current_text = self._current_text[:-1] self._current_text = self._current_text[:-1]
if self._current_text and self._current_text[-1] == "\n": if self._current_text and self._current_text[-1] == "\n":
self._tokens.extend(self._current_tokens)
self._text += self._current_text self._text += self._current_text
self._current_tokens.clear() self._current_tokens.clear()
self._current_text = "" self._current_text = ""
return self._text + self._current_text return self._text + self._current_text
@property
def tokens(self):
return self._tokens
class SPMStreamingDetokenizer(StreamingDetokenizer): class SPMStreamingDetokenizer(StreamingDetokenizer):
"""A streaming detokenizer for SPM models. """A streaming detokenizer for SPM models.
@ -143,6 +138,7 @@ class SPMStreamingDetokenizer(StreamingDetokenizer):
self.text += text self.text += text
def add_token(self, token): def add_token(self, token):
self.tokens.append(token)
v = self.tokenmap[token] v = self.tokenmap[token]
if v.startswith(self._sep): if v.startswith(self._sep):
self._flush() self._flush()
@ -200,6 +196,7 @@ class BPEStreamingDetokenizer(StreamingDetokenizer):
return current_text return current_text
def add_token(self, token): def add_token(self, token):
self.tokens.append(token)
v = self.tokenmap[token] v = self.tokenmap[token]
is_added = token in self._added_ids is_added = token in self._added_ids
if is_added or self._byte_decoder[v[0]] == 32: if is_added or self._byte_decoder[v[0]] == 32:
@ -257,21 +254,33 @@ class TokenizerWrapper:
huggingface tokenizer. huggingface tokenizer.
""" """
def __init__(self, tokenizer, detokenizer_class=NaiveStreamingDetokenizer): def __init__(
self, tokenizer, detokenizer_class=NaiveStreamingDetokenizer, eos_token_ids=None
):
self._tokenizer = tokenizer self._tokenizer = tokenizer
self._detokenizer = detokenizer_class(tokenizer) self._detokenizer = detokenizer_class(tokenizer)
self._eos_token_ids = (
set(eos_token_ids)
if eos_token_ids is not None
else {tokenizer.eos_token_id}
)
def __getattr__(self, attr): def __getattr__(self, attr):
if attr == "detokenizer": if attr == "detokenizer":
return self._detokenizer return self._detokenizer
elif attr == "eos_token_ids":
return self._eos_token_ids
elif attr.startswith("_"): elif attr.startswith("_"):
return self.__getattribute__(attr) return self.__getattribute__(attr)
else: else:
return getattr(self._tokenizer, attr) return getattr(self._tokenizer, attr)
def __setattr__(self, attr, value): def __setattr__(self, attr, value):
if attr in {"detokenizer", "eos_token_ids"}:
if attr == "detokenizer": if attr == "detokenizer":
raise AttributeError("Cannot set the detokenizer.") raise AttributeError("Cannot set the detokenizer.")
elif attr == "eos_token_ids":
self._eos_token_ids = set(value) if value is not None else set()
elif attr.startswith("_"): elif attr.startswith("_"):
super().__setattr__(attr, value) super().__setattr__(attr, value)
else: else:
@ -318,7 +327,7 @@ def _is_bpe_decoder(decoder):
return isinstance(decoder, dict) and decoder.get("type", None) == "ByteLevel" return isinstance(decoder, dict) and decoder.get("type", None) == "ByteLevel"
def load_tokenizer(model_path, tokenizer_config_extra={}): def load_tokenizer(model_path, tokenizer_config_extra={}, eos_token_ids=None):
"""Load a huggingface tokenizer and try to infer the type of streaming """Load a huggingface tokenizer and try to infer the type of streaming
detokenizer to use. detokenizer to use.
@ -339,7 +348,10 @@ def load_tokenizer(model_path, tokenizer_config_extra={}):
elif _is_bpe_decoder(tokenizer_content["decoder"]): elif _is_bpe_decoder(tokenizer_content["decoder"]):
detokenizer_class = BPEStreamingDetokenizer detokenizer_class = BPEStreamingDetokenizer
if isinstance(eos_token_ids, int):
eos_token_ids = [eos_token_ids]
return TokenizerWrapper( return TokenizerWrapper(
AutoTokenizer.from_pretrained(model_path, **tokenizer_config_extra), AutoTokenizer.from_pretrained(model_path, **tokenizer_config_extra),
detokenizer_class, detokenizer_class,
eos_token_ids=eos_token_ids,
) )
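`load_tokenizer` and `TokenizerWrapper` now take an optional `eos_token_ids` (a single id or an iterable), and generation stops when a sampled token is in `tokenizer.eos_token_ids`. A small sketch, assuming a local model directory; the path and token ids are illustrative:

```python
# Sketch: attach several EOS ids to a tokenizer so generation can stop on any of them.
from pathlib import Path
from mlx_lm.tokenizer_utils import load_tokenizer

tokenizer = load_tokenizer(Path("path/to/model"), eos_token_ids=[128001, 128009])
print(128009 in tokenizer.eos_token_ids)  # True

# The set can also be replaced later through the wrapper's setter:
tokenizer.eos_token_ids = [128001]
```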

View File

@ -98,6 +98,7 @@ def linear_to_lora_layers(
"cohere", "cohere",
"minicpm", "minicpm",
"deepseek", "deepseek",
"olmo2",
]: ]:
keys = set(["self_attn.q_proj", "self_attn.v_proj"]) keys = set(["self_attn.q_proj", "self_attn.v_proj"])
if model.model_type in ["mixtral", "phimoe"]: if model.model_type in ["mixtral", "phimoe"]:
@ -150,6 +151,8 @@ def linear_to_lora_layers(
"mixer.out_proj", "mixer.out_proj",
] ]
) )
elif model.model_type == "exaone":
keys = set(["attn.attention.q_proj", "attn.attention.v_proj"])
else: else:
raise ValueError(f"Lora does not support {model.model_type}") raise ValueError(f"Lora does not support {model.model_type}")
@ -256,12 +259,14 @@ def remove_lora_layers(model: nn.Module) -> nn.Module:
return model return model
def print_trainable_parameters(model): def nparams(module):
def nparams(m): if hasattr(module, "bits"):
if isinstance(m, (nn.QuantizedLinear, nn.QuantizedEmbedding)): n = 0 if not hasattr(module, "bias") else module.bias.size
return m.weight.size * (32 // m.bits) return n + module.weight.size * 32 // module.bits
return sum(v.size for _, v in tree_flatten(m.parameters())) return sum(v.size for _, v in tree_flatten(module.parameters()))
def print_trainable_parameters(model):
leaf_modules = tree_flatten( leaf_modules = tree_flatten(
model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module) model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module)
) )
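`nparams` is now a module-level helper (it is also imported by `utils.py` for the bits-per-weight report). For a quantized layer it undoes the packing: the stored `weight.size` holds `32 // bits` logical elements per value. A small sketch of the arithmetic, assuming MLX's `nn.QuantizedLinear.from_linear`; the layer size is illustrative:

```python
# Sketch: count logical parameters before and after quantization with nparams.
import mlx.nn as nn
from mlx_lm.tuner.utils import nparams

layer = nn.Linear(256, 256, bias=False)  # 256 * 256 = 65536 weights
print(nparams(layer))                    # 65536

quantized = nn.QuantizedLinear.from_linear(layer, group_size=64, bits=4)
# The packed weight stores 32 // 4 = 8 elements per uint32, so
# weight.size * 32 // bits recovers the original count.
print(nparams(quantized))                # 65536
```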

View File

@ -8,6 +8,7 @@ import json
import logging import logging
import shutil import shutil
import time import time
from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from textwrap import dedent from textwrap import dedent
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type, Union from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type, Union
@ -15,7 +16,7 @@ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type,
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from mlx.utils import tree_flatten, tree_reduce from mlx.utils import tree_flatten, tree_map, tree_reduce
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
# Local imports # Local imports
@ -23,7 +24,7 @@ from .models import cache
from .sample_utils import make_logits_processors, make_sampler from .sample_utils import make_logits_processors, make_sampler
from .tokenizer_utils import TokenizerWrapper, load_tokenizer from .tokenizer_utils import TokenizerWrapper, load_tokenizer
from .tuner.utils import dequantize as dequantize_model from .tuner.utils import dequantize as dequantize_model
from .tuner.utils import load_adapters from .tuner.utils import load_adapters, nparams
# Constants # Constants
MODEL_REMAPPING = { MODEL_REMAPPING = {
@ -44,6 +45,32 @@ class ModelNotFoundError(Exception):
super().__init__(self.message) super().__init__(self.message)
@dataclass
class GenerationResponse:
"""
The output of :func:`stream_generate`.
Args:
text (str): The next segment of decoded text. This can be an empty string.
token (int): The next token.
logprobs (mx.array): A vector of log probabilities.
prompt_tokens (int): The number of tokens in the prompt.
prompt_tps (float): The prompt processing tokens-per-second.
generation_tokens (int): The number of generated tokens.
generation_tps (float): The tokens-per-second for generation.
peak_memory (float): The peak memory used so far in GB.
"""
text: str
token: int
logprobs: mx.array
prompt_tokens: int
prompt_tps: float
generation_tokens: int
generation_tps: float
peak_memory: float
@contextlib.contextmanager @contextlib.contextmanager
def wired_limit(model: nn.Module, streams: Optional[List[mx.Stream]] = None): def wired_limit(model: nn.Module, streams: Optional[List[mx.Stream]] = None):
""" """
@ -100,6 +127,17 @@ def _get_classes(config: dict):
return arch.Model, arch.ModelArgs return arch.Model, arch.ModelArgs
def compute_bits_per_weight(model):
model_bytes = tree_reduce(
lambda acc, x: acc + x.nbytes if isinstance(x, mx.array) else acc, model, 0
)
leaf_modules = tree_flatten(
model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module)
)
model_params = sum(nparams(m) for _, m in leaf_modules)
return model_bytes * 8 / model_params
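`compute_bits_per_weight` divides the bytes held in the model's arrays by the logical parameter count from `nparams`, so the figure includes the per-group scales and biases. For 4-bit weights with `group_size=64` and float16 scales that comes out to about 4 + 32/64 = 4.5 bits per weight. A toy sketch; the module name and sizes are illustrative:

```python
# Sketch: quantize a toy module and report its bits per weight.
import mlx.core as mx
import mlx.nn as nn
from mlx_lm.utils import compute_bits_per_weight

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(256, 256, bias=False)

model = Toy()
model.set_dtype(mx.float16)                # float16 scales/biases, as in typical checkpoints
nn.quantize(model, group_size=64, bits=4)
print(f"{compute_bits_per_weight(model):.2f} bits per weight")  # ~4.5
```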
def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path: def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path:
""" """
Ensures the model is available locally. If the path does not exist locally, Ensures the model is available locally. If the path does not exist locally,
@ -155,20 +193,23 @@ def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_
def generate_step( def generate_step(
prompt: mx.array, prompt: mx.array,
model: nn.Module, model: nn.Module,
temp: float = 0.0, *,
repetition_penalty: Optional[float] = None, max_tokens: int = 256,
repetition_context_size: Optional[int] = 20, sampler: Optional[Callable[mx.array, mx.array]] = None,
top_p: float = 1.0, logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
min_p: float = 0.0,
min_tokens_to_keep: int = 1,
prefill_step_size: int = 512,
max_kv_size: Optional[int] = None, max_kv_size: Optional[int] = None,
prompt_cache: Optional[Any] = None, prompt_cache: Optional[Any] = None,
logit_bias: Optional[Dict[int, float]] = None, prefill_step_size: int = 512,
logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
kv_bits: Optional[int] = None, kv_bits: Optional[int] = None,
kv_group_size: int = 64, kv_group_size: int = 64,
quantized_kv_start: int = 0, quantized_kv_start: int = 0,
prompt_progress_callback: Optional[Callable[int, int]] = None,
temp: Optional[float] = None,
repetition_penalty: Optional[float] = None,
repetition_context_size: Optional[int] = None,
top_p: Optional[float] = None,
min_p: Optional[float] = None,
min_tokens_to_keep: Optional[int] = None,
) -> Generator[Tuple[mx.array, mx.array], None, None]: ) -> Generator[Tuple[mx.array, mx.array], None, None]:
""" """
A generator producing token ids based on the given prompt from the model. A generator producing token ids based on the given prompt from the model.
@ -176,32 +217,25 @@ def generate_step(
Args: Args:
prompt (mx.array): The input prompt. prompt (mx.array): The input prompt.
model (nn.Module): The model to use for generation. model (nn.Module): The model to use for generation.
temp (float): The temperature for sampling, if 0 the argmax is used. max_tokens (int): The maximum number of tokens. Use ``-1`` for an infinite
Default: ``0``. generator. Default: ``256``.
repetition_penalty (float, optional): The penalty factor for repeating sampler (Callable[mx.array, mx.array], optional): A sampler for sampling a
tokens. token from a vector of log probabilities. Default: ``None``.
repetition_context_size (int, optional): The number of tokens to logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
consider for repetition penalty. Default: ``20``. A list of functions that take tokens and logits and return the processed
top_p (float, optional): Nulceus sampling, higher means model considers logits. Default: ``None``.
more less likely words.
min_p (float, optional): The minimum value (scaled by the top token's
probability) that a token probability must have to be considered.
min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
be filtered by min_p sampling.
prefill_step_size (int): Step size for processing the prompt.
max_kv_size (int, optional): Maximum size of the key-value cache. Old max_kv_size (int, optional): Maximum size of the key-value cache. Old
entries (except the first 4 tokens) will be overwritten. entries (except the first 4 tokens) will be overwritten.
prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
provided, the cache will be updated in place. provided, the cache will be updated in place.
logit_bias (dictionary, optional): Additive logit bias. prefill_step_size (int): Step size for processing the prompt.
logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
A list of functions that take tokens and logits and return the processed
logits. Default: ``None``.
kv_bits (int, optional): Number of bits to use for KV cache quantization. kv_bits (int, optional): Number of bits to use for KV cache quantization.
None implies no cache quantization. Default: ``None``. None implies no cache quantization. Default: ``None``.
kv_group_size (int): Group size for KV cache quantization. Default: ``64``. kv_group_size (int): Group size for KV cache quantization. Default: ``64``.
quantized_kv_start (int): Step to begin using a quantized KV cache. quantized_kv_start (int): Step to begin using a quantized KV cache.
when ``kv_bits`` is non-None. Default: ``0``. when ``kv_bits`` is non-None. Default: ``0``.
prompt_progress_callback (Callable[int, int]): A callback which takes the
prompt tokens processed so far and the total number of prompt tokens.
Yields: Yields:
Tuple[mx.array, mx.array]: One token and a vector of log probabilities. Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
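With the old `temp`/`top_p`/`repetition_penalty` keywords deprecated, sampling is configured by building a `sampler` with `make_sampler` and, if needed, a list of `logits_processors` with `make_logits_processors`, then passing both through. A minimal sketch of the new call; the model repo and settings are illustrative:

```python
# Sketch: the post-refactor way to control sampling.
from mlx_lm import load, generate
from mlx_lm.sample_utils import make_logits_processors, make_sampler

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")  # illustrative repo

sampler = make_sampler(temp=0.7, top_p=0.9)
logits_processors = make_logits_processors(
    repetition_penalty=1.1, repetition_context_size=20
)

text = generate(
    model,
    tokenizer,
    prompt="Write a haiku about autumn.",
    max_tokens=128,
    sampler=sampler,
    logits_processors=logits_processors,
)
```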
@ -219,11 +253,24 @@ def generate_step(
elif len(prompt_cache) != len(model.layers): elif len(prompt_cache) != len(model.layers):
raise ValueError("Wrong number of layers in the prompt cache.") raise ValueError("Wrong number of layers in the prompt cache.")
sampler = make_sampler(temp, top_p, min_p, min_tokens_to_keep) if temp is not None or top_p is not None or min_tokens_to_keep is not None:
logits_processors = logits_processors or [] print(
logits_processors.extend( "[Warning] Specifying sampling arguments to ``generate_step`` is "
make_logits_processors(logit_bias, repetition_penalty, repetition_context_size) "deprecated. Pass in a ``sampler`` instead."
) )
if repetition_penalty is not None:
print(
"[Warning] Specifying ``repetition_penalty`` is deprecated. "
"Pass in ``logits_processors`` instead."
)
sampler = sampler or make_sampler(
temp or 0.0, top_p or 0.0, min_p or 0.0, min_tokens_to_keep or 1
)
logits_processors = logits_processors or make_logits_processors(
None, repetition_penalty, repetition_context_size or 20
)
prompt_progress_callback = prompt_progress_callback or (lambda *_: None)
def _step(y): def _step(y):
with mx.stream(generation_stream): with mx.stream(generation_stream):
@ -245,9 +292,14 @@ def generate_step(
y = sampler(logprobs) y = sampler(logprobs)
return y, logprobs.squeeze(0) return y, logprobs.squeeze(0)
with mx.stream(generation_stream):
total_prompt_tokens = y.size
prompt_processed_tokens = 0
while y.size > prefill_step_size: while y.size > prefill_step_size:
model(y[:prefill_step_size][None], cache=prompt_cache) model(y[:prefill_step_size][None], cache=prompt_cache)
mx.eval([c.state for c in prompt_cache]) mx.eval([c.state for c in prompt_cache])
prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens)
prompt_processed_tokens += prefill_step_size
y = y[prefill_step_size:] y = y[prefill_step_size:]
mx.metal.clear_cache() mx.metal.clear_cache()
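The prompt is processed in chunks of `prefill_step_size` tokens, and the new `prompt_progress_callback` is invoked with the number of prompt tokens processed so far and the total, which makes long-prompt progress reporting straightforward. A small sketch, assuming `model` is loaded and `prompt_tokens` is an `mx.array` of token ids prepared elsewhere:

```python
# Sketch: report prefill progress for a long prompt.
from mlx_lm.utils import generate_step

def report_progress(processed, total):
    # invoked during prefill and once more when the whole prompt has been processed
    print(f"prompt: {processed}/{total} tokens", flush=True)

for token, logprobs in generate_step(
    prompt_tokens,
    model,
    max_tokens=32,
    prompt_progress_callback=report_progress,
):
    pass  # consume tokens as usual
```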
@ -256,70 +308,92 @@ def generate_step(
mx.async_eval(y, logprobs) mx.async_eval(y, logprobs)
n = 0 n = 0
while True: while True:
if n != max_tokens:
next_y, next_logprobs = _step(y) next_y, next_logprobs = _step(y)
mx.async_eval(next_y, next_logprobs) mx.async_eval(next_y, next_logprobs)
if n == 0:
mx.eval(y)
prompt_progress_callback(total_prompt_tokens, total_prompt_tokens)
if n == max_tokens:
break
yield y.item(), logprobs yield y.item(), logprobs
if n % 256 == 0: if n % 256 == 0:
mx.metal.clear_cache() mx.metal.clear_cache()
n += 1
y, logprobs = next_y, next_logprobs y, logprobs = next_y, next_logprobs
n += 1
def stream_generate( def stream_generate(
model: nn.Module, model: nn.Module,
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper], tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
prompt: Union[str, List[int]], prompt: Union[str, mx.array, List[int]],
max_tokens: int = 100,
**kwargs, **kwargs,
) -> Generator[Tuple[str, int, mx.array], None, None]: ) -> Generator[GenerationResponse, None, None]:
""" """
A generator producing text based on the given prompt from the model. A generator producing text based on the given prompt from the model.
Args: Args:
model (nn.Module): The model to use for generation. model (nn.Module): The model to use for generation.
tokenizer (PreTrainedTokenizer): The tokenizer. tokenizer (PreTrainedTokenizer): The tokenizer.
prompt (Union[str, List[int]]): The input prompt string or integer tokens. prompt (Union[str, mx.array, List[int]]): The input prompt string or integer tokens.
max_tokens (int): The maximum number of tokens. Default: ``100``.
kwargs: The remaining options get passed to :func:`generate_step`. kwargs: The remaining options get passed to :func:`generate_step`.
See :func:`generate_step` for more details. See :func:`generate_step` for more details.
Yields: Yields:
Tuple[str, int, mx.array]: GenerationResponse: An instance containing the generated text segment and
The next text segment, token, and vector of log probabilities. associated metadata. See :class:`GenerationResponse` for details.
""" """
if not isinstance(tokenizer, TokenizerWrapper): if not isinstance(tokenizer, TokenizerWrapper):
tokenizer = TokenizerWrapper(tokenizer) tokenizer = TokenizerWrapper(tokenizer)
prompt_tokens = mx.array( if not isinstance(prompt, mx.array):
prompt = mx.array(
prompt if isinstance(prompt, list) else tokenizer.encode(prompt) prompt if isinstance(prompt, list) else tokenizer.encode(prompt)
) )
detokenizer = tokenizer.detokenizer detokenizer = tokenizer.detokenizer
with wired_limit(model, [generation_stream]): with wired_limit(model, [generation_stream]):
detokenizer.reset() detokenizer.reset()
for n, (token, logits) in zip( tic = time.perf_counter()
range(max_tokens), for n, (token, logprobs) in enumerate(generate_step(prompt, model, **kwargs)):
generate_step(prompt_tokens, model, **kwargs), if n == 0:
): prompt_time = time.perf_counter() - tic
if token == tokenizer.eos_token_id: prompt_tps = prompt.size / prompt_time
tic = time.perf_counter()
if token in tokenizer.eos_token_ids:
break break
detokenizer.add_token(token) detokenizer.add_token(token)
if n == (max_tokens - 1): yield GenerationResponse(
break text=detokenizer.last_segment,
token=token,
yield detokenizer.last_segment, token, logits logprobs=logprobs,
prompt_tokens=prompt.size,
prompt_tps=prompt_tps,
generation_tokens=n + 1,
generation_tps=(n + 1) / (time.perf_counter() - tic),
peak_memory=mx.metal.get_peak_memory() / 1e9,
)
detokenizer.finalize() detokenizer.finalize()
yield detokenizer.last_segment, token, logits yield GenerationResponse(
text=detokenizer.last_segment,
token=token,
logprobs=logprobs,
prompt_tokens=prompt.size,
prompt_tps=prompt_tps,
generation_tokens=n + 1,
generation_tps=(n + 1) / (time.perf_counter() - tic),
peak_memory=mx.metal.get_peak_memory() / 1e9,
)
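`stream_generate` no longer takes `max_tokens` itself; the limit, like every other `generate_step` option, travels through `**kwargs`, and each iteration yields a `GenerationResponse` whose throughput fields are kept up to date. A short sketch, assuming `model` and `tokenizer` are already loaded; the prompt and KV-cache setting are illustrative:

```python
# Sketch: generation options are now plain kwargs forwarded to generate_step.
from mlx_lm import stream_generate

last = None
for response in stream_generate(
    model,
    tokenizer,
    "Explain KV cache quantization in one sentence.",
    max_tokens=64,
    kv_bits=8,  # forwarded to generate_step
):
    print(response.text, end="", flush=True)
    last = response

if last is not None:
    print()
    print(
        f"{last.generation_tokens} tokens at {last.generation_tps:.1f} tok/s, "
        f"peak memory {last.peak_memory:.2f} GB"
    )
```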
def generate( def generate(
model: nn.Module, model: nn.Module,
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper], tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
prompt: str, prompt: str,
max_tokens: int = 100,
verbose: bool = False, verbose: bool = False,
formatter: Optional[Callable] = None, formatter: Optional[Callable] = None,
**kwargs, **kwargs,
@ -331,67 +405,42 @@ def generate(
model (nn.Module): The language model. model (nn.Module): The language model.
tokenizer (PreTrainedTokenizer): The tokenizer. tokenizer (PreTrainedTokenizer): The tokenizer.
prompt (str): The string prompt. prompt (str): The string prompt.
max_tokens (int): The maximum number of tokens. Default: ``100``.
verbose (bool): If ``True``, print tokens and timing information. verbose (bool): If ``True``, print tokens and timing information.
Default: ``False``. Default: ``False``.
formatter (Optional[Callable]): A function which takes a token and a kwargs: The remaining options get passed to :func:`stream_generate`.
probability and displays it. See :func:`stream_generate` for more details.
kwargs: The remaining options get passed to :func:`generate_step`.
See :func:`generate_step` for more details.
""" """
if not isinstance(tokenizer, TokenizerWrapper): if formatter is not None:
tokenizer = TokenizerWrapper(tokenizer) print(
"[Warning] Text formatting is deprecated and no longer used. "
"The argument will be removed in a future version."
)
if verbose: if verbose:
print("=" * 10) print("=" * 10)
print("Prompt:", prompt) print("Prompt:", prompt)
prompt_tokens = mx.array(tokenizer.encode(prompt)) text = ""
detokenizer = tokenizer.detokenizer for response in stream_generate(model, tokenizer, prompt, **kwargs):
if verbose:
with wired_limit(model, [generation_stream]): print(response.text, end="", flush=True)
tic = time.perf_counter() text += response.text
detokenizer.reset()
for n, (token, logprobs) in zip(
range(max_tokens),
generate_step(prompt_tokens, model, **kwargs),
):
if n == 0:
prompt_time = time.perf_counter() - tic
tic = time.perf_counter()
if token == tokenizer.eos_token_id:
break
detokenizer.add_token(token)
if verbose: if verbose:
if formatter: print()
# We have to finalize so that the prob corresponds to the last segment
detokenizer.finalize()
prob = mx.exp(logprobs[token]).item()
formatter(detokenizer.last_segment, prob)
else:
print(detokenizer.last_segment, end="", flush=True)
token_count = n + 1
detokenizer.finalize()
if verbose:
gen_time = time.perf_counter() - tic
print(detokenizer.last_segment, flush=True)
print("=" * 10) print("=" * 10)
if token_count == 0: if len(text) == 0:
print("No tokens generated for this prompt") print("No text generated for this prompt")
return return
prompt_tps = prompt_tokens.size / prompt_time
gen_tps = (token_count - 1) / gen_time
print( print(
f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec" f"Prompt: {response.prompt_tokens} tokens, "
f"{response.prompt_tps:.3f} tokens-per-sec"
) )
print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec") print(
peak_mem = mx.metal.get_peak_memory() / 1e9 f"Generation: {response.generation_tokens} tokens, "
print(f"Peak memory: {peak_mem:.3f} GB") f"{response.generation_tps:.3f} tokens-per-sec"
)
return detokenizer.text print(f"Peak memory: {response.peak_memory:.3f} GB")
return text
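`generate` is now a thin wrapper over `stream_generate`: it accumulates `response.text`, optionally prints as it goes, and returns the full string, while `formatter` is deprecated and `max_tokens` is just another forwarded kwarg. A minimal sketch, assuming `model` and `tokenizer` are already loaded:

```python
# Sketch: generate() after the refactor; options are forwarded to generate_step.
from mlx_lm import generate

text = generate(
    model,
    tokenizer,
    prompt="List three prime numbers.",
    max_tokens=64,   # forwarded via stream_generate to generate_step
    verbose=True,    # prints the text plus prompt/generation tok/s and peak memory
)
```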
def load_config(model_path: Path) -> dict: def load_config(model_path: Path) -> dict:
@ -418,11 +467,11 @@ def load_model(
lazy (bool): If False eval the model parameters to make sure they are lazy (bool): If False eval the model parameters to make sure they are
loaded in memory before returning, otherwise they will be loaded loaded in memory before returning, otherwise they will be loaded
when needed. Default: ``False`` when needed. Default: ``False``
model_config (dict, optional): Configuration parameters for the model. model_config (dict, optional): Optional configuration parameters for the
Defaults to an empty dictionary. model. Defaults to an empty dictionary.
get_model_classes (Callable[[dict], Tuple[Type[nn.Module], Type]], optional): get_model_classes (Callable[[dict], Tuple[Type[nn.Module], Type]], optional):
A function that returns the model class and model args class given a config. A function that returns the model class and model args class given a config.
Defaults to the _get_classes function. Defaults to the ``_get_classes`` function.
Returns: Returns:
nn.Module: The loaded and initialized model. nn.Module: The loaded and initialized model.
@ -431,7 +480,6 @@ def load_model(
FileNotFoundError: If the weight files (.safetensors) are not found. FileNotFoundError: If the weight files (.safetensors) are not found.
ValueError: If the model class or args class are not found or cannot be instantiated. ValueError: If the model class or args class are not found or cannot be instantiated.
""" """
config = load_config(model_path) config = load_config(model_path)
config.update(model_config) config.update(model_config)
@ -458,15 +506,20 @@ def load_model(
weights = model.sanitize(weights) weights = model.sanitize(weights)
if (quantization := config.get("quantization", None)) is not None: if (quantization := config.get("quantization", None)) is not None:
# Handle legacy models which may not have everything quantized
def class_predicate(p, m): def class_predicate(p, m):
# Handle custom per layer quantizations
if p in config["quantization"]:
return config["quantization"][p]
if not hasattr(m, "to_quantized"): if not hasattr(m, "to_quantized"):
return False return False
# Handle legacy models which may not have everything quantized
return f"{p}.scales" in weights return f"{p}.scales" in weights
nn.quantize( nn.quantize(
model, model,
**quantization, group_size=quantization["group_size"],
bits=quantization["bits"],
class_predicate=class_predicate, class_predicate=class_predicate,
) )
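The loader's `class_predicate` now consults the config's `quantization` section per parameter path, so models converted with a custom `quant_predicate` (see `quantize_model` below) round-trip their per-layer settings. A hypothetical fragment of such a config; the layer names and values are illustrative, not from this commit:

```python
# Hypothetical "quantization" section of a converted model's config.json:
# global defaults, an 8-bit lm_head, and an unquantized embedding table.
quantization = {
    "group_size": 64,
    "bits": 4,
    "lm_head": {"group_size": 64, "bits": 8},
    "model.embed_tokens": False,
}
```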
@ -476,7 +529,7 @@ def load_model(
mx.eval(model.parameters()) mx.eval(model.parameters())
model.eval() model.eval()
return model return model, config
def load( def load(
@ -509,11 +562,13 @@ def load(
""" """
model_path = get_model_path(path_or_hf_repo) model_path = get_model_path(path_or_hf_repo)
model = load_model(model_path, lazy, model_config) model, config = load_model(model_path, lazy)
if adapter_path is not None: if adapter_path is not None:
model = load_adapters(model, adapter_path) model = load_adapters(model, adapter_path)
model.eval() model.eval()
tokenizer = load_tokenizer(model_path, tokenizer_config) tokenizer = load_tokenizer(
model_path, tokenizer_config, eos_token_ids=config.get("eos_token_id", None)
)
return model, tokenizer return model, tokenizer
@ -521,9 +576,10 @@ def load(
def fetch_from_hub( def fetch_from_hub(
model_path: Path, lazy: bool = False model_path: Path, lazy: bool = False
) -> Tuple[nn.Module, dict, PreTrainedTokenizer]: ) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:
model = load_model(model_path, lazy) model, config = load_model(model_path, lazy)
config = load_config(model_path) tokenizer = load_tokenizer(
tokenizer = load_tokenizer(model_path) model_path, eos_token_ids=config.get("eos_token_id", None)
)
return model, config, tokenizer return model, config, tokenizer
@ -669,7 +725,13 @@ def save_weights(
def quantize_model( def quantize_model(
model: nn.Module, config: dict, q_group_size: int, q_bits: int model: nn.Module,
config: dict,
q_group_size: int,
q_bits: int,
quant_predicate: Optional[
Callable[[str, nn.Module, dict], Union[bool, dict]]
] = None,
) -> Tuple: ) -> Tuple:
""" """
Applies quantization to the model weights. Applies quantization to the model weights.
@ -679,17 +741,37 @@ def quantize_model(
config (dict): Model configuration. config (dict): Model configuration.
q_group_size (int): Group size for quantization. q_group_size (int): Group size for quantization.
q_bits (int): Bits per weight for quantization. q_bits (int): Bits per weight for quantization.
quant_predicate (Callable): A callable that decides how
to quantize each layer based on the path.
Accepts the layer `path`, the `module` and the model `config`.
Returns either a bool to signify quantize/no quantize or
a dict of quantization parameters to pass to `to_quantized`.
Returns: Returns:
Tuple: Tuple containing quantized weights and config. Tuple: Tuple containing quantized weights and config.
""" """
quantized_config = copy.deepcopy(config) quantized_config = copy.deepcopy(config)
nn.quantize(model, q_group_size, q_bits)
quantized_config["quantization"] = {"group_size": q_group_size, "bits": q_bits} quantized_config["quantization"] = {"group_size": q_group_size, "bits": q_bits}
# Add any custom quantization parameters to the config as we go
def _class_predicate(p, m):
bool_or_params = quant_predicate(p, m, config)
quantized_config["quantization"][p] = bool_or_params
return bool_or_params
nn.quantize(
model,
q_group_size,
q_bits,
class_predicate=_class_predicate if quant_predicate else None,
)
# support hf model tree #957 # support hf model tree #957
quantized_config["quantization_config"] = quantized_config["quantization"] quantized_config["quantization_config"] = quantized_config["quantization"]
quantized_weights = dict(tree_flatten(model.parameters())) quantized_weights = dict(tree_flatten(model.parameters()))
bpw = compute_bits_per_weight(model)
print(f"[INFO] Quantized model with {bpw:.3f} bits per weight.")
return quantized_weights, quantized_config return quantized_weights, quantized_config
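`quantize_model` records whatever `quant_predicate` returns for each path in the saved config, which is what the loader's per-path `class_predicate` reads back. A sketch of a mixed-precision predicate passed through `convert`; the repo name and path checks are illustrative:

```python
# Sketch: 8-bit embeddings/lm_head, 4-bit everything else.
from mlx_lm import convert

def quant_predicate(path, module, config):
    if not hasattr(module, "to_quantized"):
        return False
    if "lm_head" in path or "embed_tokens" in path:
        return {"group_size": 64, "bits": 8}
    return {"group_size": 64, "bits": 4}

convert(
    "mistralai/Mistral-7B-Instruct-v0.3",  # illustrative source repo
    mlx_path="mlx_model_mixed",
    quantize=True,
    quant_predicate=quant_predicate,
)
```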
@ -726,6 +808,9 @@ def convert(
upload_repo: str = None, upload_repo: str = None,
revision: Optional[str] = None, revision: Optional[str] = None,
dequantize: bool = False, dequantize: bool = False,
quant_predicate: Optional[
Callable[[str, nn.Module, dict], Union[bool, dict]]
] = None,
): ):
# Check the save path is empty # Check the save path is empty
if isinstance(mlx_path, str): if isinstance(mlx_path, str):
@ -751,7 +836,9 @@ def convert(
if quantize: if quantize:
print("[INFO] Quantizing") print("[INFO] Quantizing")
model.load_weights(list(weights.items())) model.load_weights(list(weights.items()))
weights, config = quantize_model(model, config, q_group_size, q_bits) weights, config = quantize_model(
model, config, q_group_size, q_bits, quant_predicate=quant_predicate
)
if dequantize: if dequantize:
print("[INFO] Dequantizing") print("[INFO] Dequantizing")

View File

@ -28,12 +28,14 @@ setup(
python_requires=">=3.8", python_requires=">=3.8",
extras_require={ extras_require={
"testing": ["datasets"], "testing": ["datasets"],
"evaluation": ["lm-eval"],
}, },
entry_points={ entry_points={
"console_scripts": [ "console_scripts": [
"mlx_lm.cache_prompt = mlx_lm.cache_prompt:main", "mlx_lm.cache_prompt = mlx_lm.cache_prompt:main",
"mlx_lm.chat = mlx_lm.chat:main", "mlx_lm.chat = mlx_lm.chat:main",
"mlx_lm.convert = mlx_lm.convert:main", "mlx_lm.convert = mlx_lm.convert:main",
"mlx_lm.evaluate = mlx_lm.evaluate:main",
"mlx_lm.fuse = mlx_lm.fuse:main", "mlx_lm.fuse = mlx_lm.fuse:main",
"mlx_lm.generate = mlx_lm.generate:main", "mlx_lm.generate = mlx_lm.generate:main",
"mlx_lm.lora = mlx_lm.lora:main", "mlx_lm.lora = mlx_lm.lora:main",

View File

@ -2,6 +2,7 @@
import unittest import unittest
from mlx_lm.sample_utils import make_logits_processors
from mlx_lm.utils import generate, load from mlx_lm.utils import generate, load
@ -25,8 +26,8 @@ class TestGenerate(unittest.TestCase):
self.tokenizer, self.tokenizer,
"hello", "hello",
max_tokens=5, max_tokens=5,
logits_processors=make_logits_processors(logit_bias),
verbose=False, verbose=False,
logit_bias=logit_bias,
) )
self.assertEqual(text, "!!!!!") self.assertEqual(text, "!!!!!")
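`logit_bias` is no longer a keyword of `generate`; the test now builds the equivalent processor with `make_logits_processors` and passes it via `logits_processors`. The same pattern outside the test, assuming `model` and `tokenizer` are loaded; the token id lookup is illustrative:

```python
# Sketch: replace the old logit_bias kwarg with a logits processor.
from mlx_lm import generate
from mlx_lm.sample_utils import make_logits_processors

bang_id = tokenizer.encode("!")[-1]                   # illustrative token id lookup
processors = make_logits_processors({bang_id: 20.0})  # additive logit bias

text = generate(
    model,
    tokenizer,
    prompt="hello",
    max_tokens=5,
    logits_processors=processors,
)
```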

View File

@ -2,7 +2,9 @@
import unittest import unittest
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_map from mlx.utils import tree_map
from mlx_lm.models import rope_utils
from mlx_lm.models.cache import KVCache, RotatingKVCache, make_prompt_cache from mlx_lm.models.cache import KVCache, RotatingKVCache, make_prompt_cache
@ -126,6 +128,26 @@ class TestModels(unittest.TestCase):
self.assertEqual(cache.offset, 22) self.assertEqual(cache.offset, 22)
self.assertTrue(mx.allclose(x, k[..., -2:, :])) self.assertTrue(mx.allclose(x, k[..., -2:, :]))
def test_rope(self):
rope = rope_utils.initialize_rope(32, base=100, traditional=False)
self.assertTrue(isinstance(rope, nn.RoPE))
rope = rope_utils.initialize_rope(
32,
base=100,
traditional=False,
scaling_config={"rope_type": "linear", "factor": 10.0},
)
self.assertTrue(isinstance(rope, nn.RoPE))
rope = rope_utils.initialize_rope(
32,
base=100,
traditional=False,
scaling_config={"rope_type": "llama3", "factor": 2.0},
)
self.assertTrue(isinstance(rope, rope_utils.Llama3RoPE))
def model_test_runner(self, model, model_type, vocab_size, num_layers): def model_test_runner(self, model, model_type, vocab_size, num_layers):
self.assertEqual(len(model.layers), num_layers) self.assertEqual(len(model.layers), num_layers)
@ -760,6 +782,75 @@ class TestModels(unittest.TestCase):
model, args.model_type, args.vocab_size, args.num_hidden_layers model, args.model_type, args.vocab_size, args.num_hidden_layers
) )
def test_hunyuan(self):
from mlx_lm.models import hunyuan
args = hunyuan.ModelArgs(
model_type="hunyuan",
hidden_size=128,
attention_bias=False,
intermediate_size=256,
num_attention_heads=4,
num_hidden_layers=4,
num_key_value_heads=2,
rms_norm_eps=1e-4,
rope_theta=1000,
vocab_size=1000,
moe_topk=2,
num_experts=2,
num_shared_expert=1,
use_mixed_mlp_moe=True,
use_qk_norm=True,
rope_scaling={
"alpha": 1000.0,
"factor": 1.0,
"type": "dynamic",
},
use_cla=True,
cla_share_factor=2,
)
model = hunyuan.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_olmo2(self):
from mlx_lm.models import olmo2
args = olmo2.ModelArgs(
model_type="olmo2",
hidden_size=128,
attention_bias=False,
intermediate_size=256,
num_attention_heads=4,
num_hidden_layers=4,
num_key_value_heads=2,
rms_norm_eps=1e-4,
rope_theta=1000,
vocab_size=1000,
)
model = olmo2.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_exaone(self):
from mlx_lm.models import exaone
args = exaone.ModelArgs(
model_type="exaone",
hidden_size=128,
num_layers=4,
intermediate_size=256,
num_attention_heads=8,
num_key_value_heads=2,
vocab_size=1000,
layer_norm_epsilon=1e-4,
rope_theta=10000,
)
model = exaone.Model(args)
self.model_test_runner(model, args.model_type, args.vocab_size, args.num_layers)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -121,21 +121,20 @@ class TestPromptCache(unittest.TestCase):
def test_cache_with_generate(self): def test_cache_with_generate(self):
model, tokenizer = load(HF_MODEL_PATH) model, tokenizer = load(HF_MODEL_PATH)
prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0] prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0]
results = zip(range(4), generate_step(prompt, model)) results = list(generate_step(prompt, model, max_tokens=4))
toks, all_logits = zip(*(r[1] for r in results)) toks, all_logits = zip(*results)
prompt_cache = make_prompt_cache(model) prompt_cache = make_prompt_cache(model)
i = 0 i = 0
for _, (tok, logits) in zip( for tok, logits in generate_step(
range(2), generate_step(prompt, model, prompt_cache=prompt_cache) prompt, model, prompt_cache=prompt_cache, max_tokens=2
): ):
self.assertEqual(tok, toks[i]) self.assertEqual(tok, toks[i])
self.assertTrue(mx.allclose(logits, all_logits[i])) self.assertTrue(mx.allclose(logits, all_logits[i]))
i += 1 i += 1
for _, (tok, logits) in zip( for tok, logits in generate_step(
range(1), mx.array([toks[i]]), model, prompt_cache=prompt_cache, max_tokens=1
generate_step(mx.array([toks[i]]), model, prompt_cache=prompt_cache),
): ):
i += 1 i += 1
self.assertEqual(tok, toks[i]) self.assertEqual(tok, toks[i])
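Since `max_tokens` is now a `generate_step` parameter, the `zip(range(n), ...)` idiom is gone; combined with a prompt cache, the call looks like the sketch below (assumes `model` and `tokenizer` are loaded):

```python
# Sketch: cap generate_step with max_tokens instead of zip(range(n), ...).
import mlx.core as mx
from mlx_lm.models.cache import make_prompt_cache
from mlx_lm.utils import generate_step

prompt = mx.array(tokenizer.encode("this is a prompt"))
cache = make_prompt_cache(model)  # updated in place as generation proceeds
tokens = [tok for tok, _ in generate_step(prompt, model, prompt_cache=cache, max_tokens=4)]
```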

View File

@ -1,10 +1,10 @@
import unittest import unittest
import mlx.core as mx import mlx.core as mx
from mlx_lm.sample_utils import top_p_sampling from mlx_lm.sample_utils import min_p_sampling, top_p_sampling
class TestSamplingUtils(unittest.TestCase): class TestSampleUtils(unittest.TestCase):
def test_top_p_sampling(self): def test_top_p_sampling(self):
probs = mx.array([0.9, 0.0, 0.0, 0.1])[None] probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
logits = mx.log(probs) logits = mx.log(probs)
@ -28,6 +28,20 @@ class TestSamplingUtils(unittest.TestCase):
token = top_p_sampling(logits, 0.95, temperature).item() token = top_p_sampling(logits, 0.95, temperature).item()
self.assertTrue(token in (1, 2, 3)) self.assertTrue(token in (1, 2, 3))
def test_min_p_sampling(self):
probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
logits = mx.log(probs)
temperature = 1.0
token = min_p_sampling(logits, 0.8)
self.assertEqual(token, 0)
probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
logits = mx.log(probs)
temperature = 1.0
for _ in range(5):
token = min_p_sampling(logits, 0.05)
self.assertTrue(token in (0, 3))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -34,10 +34,11 @@ class TestTokenizers(unittest.TestCase):
detokenizer = tokenizer.detokenizer detokenizer = tokenizer.detokenizer
detokenizer.reset() detokenizer.reset()
text = "" text = ""
for t in tokens: for e, t in enumerate(tokens):
detokenizer.add_token(t) detokenizer.add_token(t)
seg = detokenizer.last_segment seg = detokenizer.last_segment
text += seg text += seg
self.assertEqual(detokenizer.tokens, tokens[: e + 1])
detokenizer.finalize() detokenizer.finalize()
text += detokenizer.last_segment text += detokenizer.last_segment
self.assertEqual(text, expected_text) self.assertEqual(text, expected_text)
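The streaming detokenizers now append every id passed to `add_token` to `self.tokens`, which is what the added assertion checks. A small sketch of the incremental decoding loop, assuming a `tokenizer` loaded through `mlx_lm`:

```python
# Sketch: incremental detokenization; .tokens tracks every id added so far.
ids = tokenizer.encode("streaming detokenizers keep their tokens now")
detokenizer = tokenizer.detokenizer
detokenizer.reset()

text = ""
for i, t in enumerate(ids):
    detokenizer.add_token(t)
    text += detokenizer.last_segment
    assert detokenizer.tokens == ids[: i + 1]

detokenizer.finalize()
text += detokenizer.last_segment
```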

View File

@ -32,7 +32,7 @@ class TestLoadModelCustomGetClasses(unittest.TestCase):
return CustomQwenModel, CustomQwenConfig return CustomQwenModel, CustomQwenConfig
model_path = get_model_path(HF_MODEL_PATH) model_path = get_model_path(HF_MODEL_PATH)
model = load_model(model_path, get_model_classes=custom_get_classes) model, _ = load_model(model_path, get_model_classes=custom_get_classes)
self.assertIsInstance(model, CustomQwenModel) self.assertIsInstance(model, CustomQwenModel)
self.assertTrue(hasattr(model, "custom_attribute")) self.assertTrue(hasattr(model, "custom_attribute"))
@ -41,7 +41,7 @@ class TestLoadModelCustomGetClasses(unittest.TestCase):
def test_load_model_with_default_get_classes(self): def test_load_model_with_default_get_classes(self):
model_path = get_model_path(HF_MODEL_PATH) model_path = get_model_path(HF_MODEL_PATH)
model = load_model(model_path) model, _ = load_model(model_path)
self.assertIsInstance(model, Qwen2Model) self.assertIsInstance(model, Qwen2Model)

View File

@ -76,6 +76,7 @@ def train_epoch(model, train_iter, optimizer, epoch):
samples_per_sec = [] samples_per_sec = []
model.train(True) model.train(True)
train_iter.reset()
for batch_counter, batch in enumerate(train_iter): for batch_counter, batch in enumerate(train_iter):
x = mx.array(batch["audio"]) x = mx.array(batch["audio"])
y = mx.array(batch["label"]) y = mx.array(batch["label"])
@ -111,6 +112,7 @@ def test_epoch(model, test_iter):
model.train(False) model.train(False)
accs = [] accs = []
throughput = [] throughput = []
test_iter.reset()
for batch_counter, batch in enumerate(test_iter): for batch_counter, batch in enumerate(test_iter):
x = mx.array(batch["audio"]) x = mx.array(batch["audio"])
y = mx.array(batch["label"]) y = mx.array(batch["label"])

View File

@ -174,11 +174,6 @@ def load_torch_weights_and_config(
"*.txt", "*.txt",
], ],
) )
else:
raise RuntimeError(
f"Model {name_or_path} is not found in {available_models()},"
"on Hugging Face or as a local path."
)
if name_or_path.endswith(".pt"): if name_or_path.endswith(".pt"):
checkpoint = torch.load(name_or_path, map_location="cpu", weights_only=False) checkpoint = torch.load(name_or_path, map_location="cpu", weights_only=False)