diff --git a/flux/dreambooth.py b/flux/dreambooth.py index ffdb02d7..f82178b9 100644 --- a/flux/dreambooth.py +++ b/flux/dreambooth.py @@ -289,4 +289,4 @@ if __name__ == "__main__": tic = time.time() save_adapters("final_adapters.safetensors", flux, args) - print(f"Training successful. Saved final weights to {args.adapter_file}.") + print("Training successful.") diff --git a/flux/flux/model.py b/flux/flux/model.py index 18ea70b0..d8ad9d9b 100644 --- a/flux/flux/model.py +++ b/flux/flux/model.py @@ -85,6 +85,8 @@ class Flux(nn.Module): def sanitize(self, weights): new_weights = {} for k, w in weights.items(): + if k.startswith("model.diffusion_model."): + k = k[22:] if k.endswith(".scale"): k = k[:-6] + ".weight" for seq in ["img_mlp", "txt_mlp", "adaLN_modulation"]: diff --git a/flux/flux/sampler.py b/flux/flux/sampler.py index 3bff1ca2..6f293edc 100644 --- a/flux/flux/sampler.py +++ b/flux/flux/sampler.py @@ -7,7 +7,7 @@ import mlx.core as mx class FluxSampler: - def __init__(self, name: str, base_shift: float = 0.5, max_shift: float = 1.5): + def __init__(self, name: str, base_shift: float = 0.5, max_shift: float = 1.15): self._base_shift = base_shift self._max_shift = max_shift self._schnell = "schnell" in name @@ -25,7 +25,7 @@ class FluxSampler: ): t = mx.linspace(start, stop, num_steps + 1) - if self._schnell: + if not self._schnell: t = self._time_shift(image_sequence_length, t) return t.tolist() @@ -50,6 +50,7 @@ class FluxSampler: if noise is not None else mx.random.normal(x.shape, dtype=x.dtype, key=key) ) + t = t.reshape([-1] + [1] * (x.ndim - 1)) return x * (1 - t) + t * noise def step(self, pred, x_t, t, t_prev): diff --git a/llms/README.md b/llms/README.md index eeb3ed6a..4fff4207 100644 --- a/llms/README.md +++ b/llms/README.md @@ -61,7 +61,7 @@ prompt = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) -response = generate(model, tokenizer, prompt=prompt, verbose=True) +text = generate(model, tokenizer, prompt=prompt, verbose=True) ``` To see a description of all the arguments you can do: @@ -77,7 +77,7 @@ to see how to use the API in more detail. The `mlx-lm` package also comes with functionality to quantize and optionally upload models to the Hugging Face Hub. -You can convert models in the Python API with: +You can convert models using the Python API: ```python from mlx_lm import convert @@ -100,8 +100,9 @@ To see a description of all the arguments you can do: #### Streaming -For streaming generation, use the `stream_generate` function. This returns a -generator object which streams the output text, token, and log probabilities. +For streaming generation, use the `stream_generate` function. This yields +a generation response object. + For example, ```python @@ -117,8 +118,8 @@ prompt = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) -for text, *_ in stream_generate(model, tokenizer, prompt, max_tokens=512): - print(t, end="", flush=True) +for response in stream_generate(model, tokenizer, prompt, max_tokens=512): + print(response.text, end="", flush=True) print() ``` @@ -162,6 +163,10 @@ mlx_lm.convert \ --upload-repo mlx-community/my-4bit-mistral ``` +Models can also be converted and quantized directly in the +[mlx-my-repo]https://huggingface.co/spaces/mlx-community/mlx-my-repo) Hugging +Face Space. 
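Taken together, the changes in this patch replace the per-call sampling keywords (`temp`, `top_p`, ...) with a `sampler` built by `make_sampler`, and `stream_generate` now yields `GenerationResponse` objects (defined below in `llms/mlx_lm/utils.py`). The following is a minimal sketch of the updated calling convention; the model repo is illustrative, and the statistics are read from the last yielded response:

```python
from mlx_lm import load, stream_generate
from mlx_lm.sample_utils import make_sampler

# Illustrative model repo; any mlx-lm compatible model works the same way.
model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Write a haiku about the sea."}],
    tokenize=False,
    add_generation_prompt=True,
)

# Sampling settings are bundled into a sampler instead of being passed
# as individual keyword arguments.
sampler = make_sampler(0.7, 0.95)

for response in stream_generate(model, tokenizer, prompt, max_tokens=256, sampler=sampler):
    print(response.text, end="", flush=True)
print()

# Each GenerationResponse also carries prompt and generation statistics.
print(f"Prompt: {response.prompt_tokens} tokens, {response.prompt_tps:.1f} tokens-per-sec")
print(f"Generation: {response.generation_tokens} tokens, {response.generation_tps:.1f} tokens-per-sec")
print(f"Peak memory: {response.peak_memory:.2f} GB")
```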
+ ### Long Prompts and Generations `mlx-lm` has some tools to scale efficiently to long prompts and generations: diff --git a/llms/mlx_lm/SERVER.md b/llms/mlx_lm/SERVER.md index 2976a09f..e544c6fa 100644 --- a/llms/mlx_lm/SERVER.md +++ b/llms/mlx_lm/SERVER.md @@ -92,7 +92,7 @@ curl localhost:8080/v1/chat/completions \ - `system_fingerprint`: A unique identifier for the system. -- `object`: Any of "chat.completions", "chat.completions.chunk" (for +- `object`: Any of "chat.completion", "chat.completion.chunk" (for streaming), or "text.completion". - `model`: The model repo or path (e.g. `"mlx-community/Llama-3.2-3B-Instruct-4bit"`). diff --git a/llms/mlx_lm/_version.py b/llms/mlx_lm/_version.py index 3811616f..0f885fba 100644 --- a/llms/mlx_lm/_version.py +++ b/llms/mlx_lm/_version.py @@ -1,3 +1,3 @@ # Copyright © 2023-2024 Apple Inc. -__version__ = "0.19.3" +__version__ = "0.20.2" diff --git a/llms/mlx_lm/cache_prompt.py b/llms/mlx_lm/cache_prompt.py index 987b640d..9d7d1603 100644 --- a/llms/mlx_lm/cache_prompt.py +++ b/llms/mlx_lm/cache_prompt.py @@ -8,7 +8,7 @@ import time import mlx.core as mx from .models.cache import make_prompt_cache, save_prompt_cache -from .utils import load, maybe_quantize_kv_cache +from .utils import generate_step, load DEFAULT_QUANTIZED_KV_START = 5000 @@ -50,12 +50,6 @@ def setup_arg_parser(): action="store_true", help="Use the default chat template", ) - parser.add_argument( - "--cache-limit-gb", - type=int, - default=None, - help="Set the MLX cache limit in GB", - ) parser.add_argument( "--max-kv-size", type=int, @@ -99,9 +93,6 @@ def main(): parser = setup_arg_parser() args = parser.parse_args() - if args.cache_limit_gb is not None: - mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024) - # Building tokenizer_config tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None} if args.eos_token is not None: @@ -144,26 +135,28 @@ def main(): y = mx.array(tokenizer.encode(prompt)) # Process the prompt - processed = 0 - step_size = 512 start = time.time() max_msg_len = 0 - while y.size > 0: - model(y[:step_size][None], cache=cache) - mx.eval([c.state for c in cache]) - mx.metal.clear_cache() - processed += min(y.size, step_size) - y = y[step_size:] + def callback(processed, total_tokens): current = time.time() speed = processed / (current - start) msg = f"\rProcessed {processed:6d} tokens ({speed:6.2f} tok/s)" + nonlocal max_msg_len max_msg_len = max(max_msg_len, len(msg)) print(msg + " " * (max_msg_len - len(msg)), end="", flush=True) - maybe_quantize_kv_cache( - cache, args.quantized_kv_start, args.kv_group_size, args.kv_bits - ) + for _ in generate_step( + y, + model, + max_tokens=0, + prompt_cache=cache, + kv_bits=args.kv_bits, + kv_group_size=args.kv_group_size, + quantized_kv_start=args.quantized_kv_start, + prompt_progress_callback=callback, + ): + pass print() print(f"Peak memory: {mx.metal.get_peak_memory() / 1e9:.3f} GB") diff --git a/llms/mlx_lm/chat.py b/llms/mlx_lm/chat.py index c03056a6..7795d8d7 100644 --- a/llms/mlx_lm/chat.py +++ b/llms/mlx_lm/chat.py @@ -5,7 +5,8 @@ import json import mlx.core as mx -from .models.cache import load_prompt_cache, make_prompt_cache, save_prompt_cache +from .models.cache import make_prompt_cache +from .sample_utils import make_sampler from .utils import load, stream_generate DEFAULT_TEMP = 0.0 @@ -74,16 +75,15 @@ def main(): prompt = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) - for response, *_ in stream_generate( + for response in 
stream_generate( model, tokenizer, prompt, args.max_tokens, - temp=args.temp, - top_p=args.top_p, + sampler=make_sampler(args.temp, args.top_p), prompt_cache=prompt_cache, ): - print(response, flush=True, end="") + print(response.text, flush=True, end="") print() diff --git a/llms/mlx_lm/evaluate.py b/llms/mlx_lm/evaluate.py new file mode 100644 index 00000000..423d5823 --- /dev/null +++ b/llms/mlx_lm/evaluate.py @@ -0,0 +1,355 @@ +# Adapted from a PyTorch implementation by David Grangier + +import argparse +import json +import logging +import os +from importlib.metadata import version +from pathlib import Path +from typing import Optional + +import lm_eval +import mlx.core as mx +import mlx.nn as nn +import numpy as np +from lm_eval.api.model import LM +from lm_eval.api.registry import register_model +from tqdm import tqdm + +from .models.cache import make_prompt_cache +from .utils import load, stream_generate + +PAD = 0 + + +def _len_longest_common_prefix(a, b): + l = 0 + for item_a, item_b in zip(a, b): + if item_a != item_b: + break + l += 1 + return l + + +def _rstrip_until(s, untils): + """Limit a string to the first occurence of any substring in untils.""" + l = len(s) + f = [s.find(u) for u in untils] + f = [l if x < 0 else x for x in f] + return s[: min(f)] + + +def _pad_inputs( + inputs, + maxlen, + genlen=0, + pad_left=False, + pad_multiple=32, + truncate=False, +): + # pad the prompts to the left with at least genlen tokens. + actual_maxlen = max(len(p) for p in inputs) + genlen + if actual_maxlen > maxlen: + if not truncate: + raise ValueError("Inputs are too long.") + else: # drop begining + actual_maxlen = maxlen + inputs = [p[max(0, len(p) - maxlen) :] for p in inputs] + if pad_multiple > 0: + maxlen = (actual_maxlen + pad_multiple - 1) // pad_multiple + maxlen *= pad_multiple + assert PAD == 0 + lr = np.array((1, 0) if pad_left else (0, 1)) + return np.stack( + [np.pad(np.array(x, np.int32), lr * (maxlen - len(x))) for x in inputs], + axis=0, + ) + + +@register_model("mlxlm") +class MLXLM(LM): + def __init__( + self, + path_or_hf_repo: str, + batch_size: int = 16, + max_tokens: Optional[int] = None, + ) -> None: + super().__init__() + self._batch_size = batch_size + self._model, self._tokenizer = load(path_or_hf_repo) + self._max_tokens = max_tokens or self._tokenizer.model_max_length + + def _score_fn(self, inputs, tokenize=True, step_size=32): + if tokenize: + inputs = self._tokenizer.encode(inputs) + inputs = _pad_inputs(inputs, self._max_tokens, truncate=False) + inputs = mx.array(inputs) + inputs, targets = inputs[..., :-1], inputs[..., 1:] + + cache = make_prompt_cache(self._model) + + mask = targets != PAD + + scores, is_greedy = [], [] + for i in range(0, inputs.shape[1], step_size): + logits = self._model(inputs[:, i : i + step_size], cache=cache) + + log_probs = nn.log_softmax(logits.astype(mx.float32)) + score = mx.take_along_axis( + log_probs, targets[:, i : i + step_size, mx.newaxis], axis=-1 + )[..., 0] + ig = mask[:, i : i + step_size] * ( + targets[:, i : i + step_size] == mx.argmax(logits, axis=-1) + ) + + mx.eval(score, ig) + mx.metal.clear_cache() + + is_greedy.append(ig) + scores.append(score) + + scores = mx.concatenate(scores, axis=1) + is_greedy = mx.concatenate(is_greedy, axis=1) + + return scores, mask.sum(axis=-1), is_greedy + + def _loglikelihood(self, texts, score_spans=None, tokenize=True): + # sort by length to get batches with little padding. 
+ sorted_indices = sorted(range(len(texts)), key=lambda i: -len(texts[i])) + sorted_inputs = [texts[sorted_indices[i]] for i in range(len(texts))] + sorted_spans = None + if score_spans is not None: + sorted_spans = [score_spans[sorted_indices[i]] for i in range(len(texts))] + + results = [] + for i in tqdm(range(0, len(sorted_inputs), self._batch_size)): + batch = sorted_inputs[i : i + self._batch_size] + scores, length, is_greedy = self._score_fn(batch, tokenize=tokenize) + for j in range(len(batch)): + if sorted_spans is None: # full sequence score + mask = mx.arange(scores[j].shape[-1]) < length + score = (scores[j].astype(mx.float32) * mask).sum(axis=-1) + ig = (is_greedy[j].astype(mx.int32) * mask).sum(axis=-1) + else: # subsequence score + start, end = sorted_spans[i + j] + score = scores[j][start:end].astype(mx.float32).sum() + ig = is_greedy[j][start:end].astype(mx.int32).sum() + length = end - start + + results.append((score.item(), ig.item(), length)) + + # reorder the outputs + inv_sort = np.argsort(sorted_indices) + results = [results[inv_sort[i]] for i in range(len(results))] + + return results + + def _tokenize(self, texts): + return [tuple(self._tokenizer.encode(t)) for t in texts] + + def loglikelihood(self, requests) -> list[tuple[float, bool]]: + """Compute log-likelihood of generating a continuation from a context. + Downstream tasks should attempt to use loglikelihood instead of other + LM calls whenever possible. + :param requests: list[Instance] + A list of Instance objects, with property `args` which returns a tuple (context, continuation). + `context: str` + Context string. Implementations of LM must be able to handle an + empty context string. + `continuation: str` + The continuation over which log likelihood will be calculated. If + there is a word boundary, the space should be in the continuation. + For example, context="hello" continuation=" world" is correct. + :return: list[tuple[float, bool]] + A list of pairs (logprob, isgreedy) + `logprob: float` + The log probability of `continuation`. + `isgreedy`: + Whether `continuation` would be generated by greedy sampling from `context`. + """ + logging.info("Estimating loglikelihood for %d pairs." % len(requests)) + + # tokenize prefix and prefix + completion for all requests. + tokenized = self._tokenize( + [t for r in requests for t in [r.args[0], r.args[0] + r.args[1]]] + ) + + # max length (prefix + completion) and longest common prefix per question. + length_stats = {} + for prefix, completed in zip(tokenized[0::2], tokenized[1::2]): + max_completed_l, min_prefix_l = length_stats.get(prefix, (0, 1e8)) + length_stats[prefix] = ( + max(max_completed_l, len(completed)), + min(min_prefix_l, _len_longest_common_prefix(prefix, completed)), + ) + + # truncate requests for completed sequences longer than model context. + shortened = [] + completion_spans = [] + long_completions = 0 + for prefix, completed in zip(tokenized[0::2], tokenized[1::2]): + max_completed_l, prefix_l = length_stats[prefix] + # compute truncation length + truncation = max(0, max_completed_l - self._max_tokens - 1) + prefix_l = prefix_l - truncation + if prefix_l <= 0: + # completion too long, prefix is eliminated for some requests. 
+ long_completions += 1 + truncation = max(0, len(completed) - self._max_tokens - 1) + prefix_l = 1 + # truncate the completed sequence + completed = completed[truncation:] + shortened.append(completed) + # scores do not include initial bos, substract 1 to span bounds + completion_spans.append((prefix_l - 1, len(completed) - 1)) + + if long_completions > 0: + logging.info( + f"Prefix eliminated for {long_completions} requests with " + + "completion longer than context." + ) + + # model scoring, returns num_requests x (logp, is_greedy, length). + results = self._loglikelihood( + shortened, + score_spans=completion_spans, + tokenize=False, + ) + return [(r[0], r[1] == r[2]) for r in results] + + def loglikelihood_rolling(self, requests) -> list[float]: + """Compute full log-likelihood of a string, with no truncation, for perplexity computation + - We will use the full max context length of the model. + - For inputs that exceed the max context length, we divide the tokenized string into chunks of up to + the max context length. + - IMPORTANT: Each document's loglikelihood/perplexity is computed *separately*, unlike other implementations + which may simply concatenate multiple documents together. + - IMPORTANT: We maximize the amount of context for each prediction. Specifically, for inputs that we break into + multiple chunks, the last input will still a full-sized context. + Example: + Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ] + Prefix: EOT + Max context length: 4 + Resulting input/prediction pairs: + INPUT: EOT 0 1 2 + PRED: 0 1 2 3 + INPUT: 3 4 5 6 + PRED: 4 5 6 7 + INPUT: 5 6 7 8 + PRED: 8 9 + Observe that: + 1. Each token is predicted exactly once + 2. For the last pair, we provide the full context, but only score the last two tokens + :param requests: list[Instance] + A list of Instance objects with property `args` which returns a tuple (context,). + string: str + String for which we are computing overall loglikelihood + :return: list[tuple[float]] + A list of tuples (logprob,) + logprob: float + The log probability of `context` conditioned on the EOT token. + """ + logging.info( + "Estimating loglikelihood rolling for %d sequences." % len(requests) + ) + inputs = [req.args[0] for req in requests] + return [t[0] for t in self._loglikelihood(inputs)] + + def generate_until(self, requests) -> list[str]: + """Generate greedily until a stopping sequence + :param requests: list[Instance] + A list of Instance objects with property `args` which returns a tuple (context, until). + context: str + Context string + until: [str] + The string sequences to generate until. These string sequences + may each span across multiple tokens, or may be part of one token. + :return: list[str] + A list of strings continuation + continuation: str + The generated continuation. + """ + logging.info("Generating continuation for %d sequences." 
% len(requests)) + contexts, options = zip(*[req.args for req in requests]) + # contrary to the doc the second element of the tuple contains + # {'do_sample': False, 'until': ['\n\n'], 'temperature': 0} + keys = list(options[0].keys()) + assert "until" in keys + untils = [x["until"] for x in options] + completions = [] + for context, until in tqdm(zip(contexts, untils), total=len(contexts)): + if ( + hasattr(self._tokenizer, "apply_chat_template") + and self._tokenizer.chat_template is not None + ): + messages = [{"role": "user", "content": context}] + context = self._tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + + max_tokens = min( + self._max_tokens, + self._tokenizer.model_max_length - len(self._tokenizer.encode(context)), + ) + text = "" + for response in stream_generate( + self._model, self._tokenizer, prompt=context, max_tokens=max_tokens + ): + text += response.text + if any(u in text for u in until): + text = _rstrip_until(text, until) + completions.append(text) + break + else: + completions.append(text) + return completions + + +def main(): + parser = argparse.ArgumentParser( + "Evaluate an MLX model using lm-evaluation-harness." + ) + parser.add_argument("--model", help="Model to evaluate", required=True) + parser.add_argument("--tasks", nargs="+", required=True) + parser.add_argument( + "--output-dir", default=".", help="Output directory for result files." + ) + parser.add_argument("--batch-size", type=int, default=16, help="Batch size") + parser.add_argument("--num-shots", type=int, default=0, help="Number of shots") + parser.add_argument( + "--max-tokens", + type=int, + help="Maximum nunber of tokens to generate. Defaults to the model's max context length.", + ) + parser.add_argument("--seed", type=int, default=123, help="Random seed.") + args = parser.parse_args() + + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + # Silence tokenizer warnings + os.environ["TOKENIZERS_PARALLELISM"] = "false" + + mx.random.seed(args.seed) + + lm = MLXLM(args.model, batch_size=args.batch_size, max_tokens=args.max_tokens) + + results = lm_eval.simple_evaluate( + model=lm, + tasks=args.tasks, + num_fewshot=args.num_shots, + random_seed=args.seed, + numpy_random_seed=args.seed, + torch_random_seed=args.seed, + fewshot_random_seed=args.seed, + ) + + model_name = args.model.replace("/", "_") + task_names = "_".join(args.tasks) + ver = version("lm_eval") + filename = f"eval_{model_name}_{task_names}_{args.num_shots:02d}_v_{ver}.json" + output_path = output_dir / filename + output_path.write_text(json.dumps(results["results"], indent=4)) + print("Results:") + for result in results["results"].values(): + print(json.dumps(result, indent=4)) diff --git a/llms/mlx_lm/examples/chat.py b/llms/mlx_lm/examples/chat.py index 3bf01688..c7512b3c 100644 --- a/llms/mlx_lm/examples/chat.py +++ b/llms/mlx_lm/examples/chat.py @@ -42,7 +42,6 @@ response = generate( tokenizer, prompt=prompt, verbose=True, - temp=0.0, prompt_cache=prompt_cache, ) diff --git a/llms/mlx_lm/examples/generate_response.py b/llms/mlx_lm/examples/generate_response.py index 25730617..e6535b47 100644 --- a/llms/mlx_lm/examples/generate_response.py +++ b/llms/mlx_lm/examples/generate_response.py @@ -23,14 +23,6 @@ max_tokens = 1_000 # Specify if tokens and timing information will be printed verbose = True -# Some optional arguments for causal language model generation -generation_args = { - "temp": 0.7, - "repetition_penalty": 1.2, - "repetition_context_size": 
20, - "top_p": 0.95, -} - # Generate a response with the specified settings response = generate( model=model, @@ -38,5 +30,4 @@ response = generate( prompt=prompt, max_tokens=max_tokens, verbose=verbose, - **generation_args, ) diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py index 51169def..0c1b4acd 100644 --- a/llms/mlx_lm/generate.py +++ b/llms/mlx_lm/generate.py @@ -7,6 +7,7 @@ import sys import mlx.core as mx from .models.cache import QuantizedKVCache, load_prompt_cache +from .sample_utils import make_sampler from .utils import generate, load DEFAULT_PROMPT = "hello" @@ -41,17 +42,17 @@ def setup_arg_parser(): type=str, help="Optional path for the trained adapter weights and config.", ) - parser.add_argument( - "--trust-remote-code", - action="store_true", - help="Enable trusting remote code for tokenizer", - ) parser.add_argument( "--eos-token", type=str, default=None, help="End of sequence token for tokenizer", ) + parser.add_argument( + "--system-prompt", + default=None, + help="System prompt to be used for the chat template", + ) parser.add_argument( "--prompt", "-p", @@ -76,7 +77,7 @@ def setup_arg_parser(): ) parser.add_argument( "--min-tokens-to-keep", - type=float, + type=int, default=DEFAULT_MIN_TOKENS_TO_KEEP, help="Minimum tokens to keep for min-p sampling.", ) @@ -97,11 +98,6 @@ def setup_arg_parser(): default=True, help="Log verbose output when 'True' or 'T' or only print the response when 'False' or 'F'", ) - parser.add_argument( - "--colorize", - action="store_true", - help="Colorize output based on T[0] probability", - ) parser.add_argument( "--max-kv-size", type=int, @@ -137,33 +133,6 @@ def setup_arg_parser(): return parser -def colorprint(color, s): - color_codes = { - "black": 30, - "red": 31, - "green": 32, - "yellow": 33, - "blue": 34, - "magenta": 35, - "cyan": 36, - "white": 39, - } - ccode = color_codes.get(color, 30) - print(f"\033[1m\033[{ccode}m{s}\033[0m", end="", flush=True) - - -def colorprint_by_t0(s, t0): - if t0 > 0.95: - color = "white" - elif t0 > 0.70: - color = "green" - elif t0 > 0.30: - color = "yellow" - else: - color = "red" - colorprint(color, s) - - def main(): parser = setup_arg_parser() args = parser.parse_args() @@ -191,8 +160,7 @@ def main(): tokenizer_config = ( {} if not using_cache else json.loads(metadata["tokenizer_config"]) ) - if args.trust_remote_code: - tokenizer_config["trust_remote_code"] = True + tokenizer_config["trust_remote_code"] = True if args.eos_token is not None: tokenizer_config["eos_token"] = args.eos_token @@ -224,12 +192,16 @@ def main(): hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template is not None ): - messages = [ + if args.system_prompt is not None: + messages = [{"role": "system", "content": args.system_prompt}] + else: + messages = [] + messages.append( { "role": "user", "content": sys.stdin.read() if args.prompt == "-" else args.prompt, } - ] + ) prompt = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) @@ -237,8 +209,9 @@ def main(): # Treat the prompt as a suffix assuming that the prefix is in the # stored kv cache. 
if using_cache: + messages[-1]["content"] = "" test_prompt = tokenizer.apply_chat_template( - [{"role": "user", "content": ""}], + messages, tokenize=False, add_generation_prompt=True, ) @@ -246,21 +219,14 @@ def main(): else: prompt = args.prompt - if args.colorize and not args.verbose: - raise ValueError("Cannot use --colorize with --verbose=False") - formatter = colorprint_by_t0 if args.colorize else None - + sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep) response = generate( model, tokenizer, prompt, - args.max_tokens, + max_tokens=args.max_tokens, verbose=args.verbose, - formatter=formatter, - temp=args.temp, - top_p=args.top_p, - min_p=args.min_p, - min_tokens_to_keep=args.min_tokens_to_keep, + sampler=sampler, max_kv_size=args.max_kv_size, prompt_cache=prompt_cache if using_cache else None, kv_bits=args.kv_bits, diff --git a/llms/mlx_lm/models/exaone.py b/llms/mlx_lm/models/exaone.py new file mode 100644 index 00000000..eaed5dd8 --- /dev/null +++ b/llms/mlx_lm/models/exaone.py @@ -0,0 +1,163 @@ +# Copyright © 2024 Apple Inc. + +from dataclasses import dataclass +from typing import Any, Dict, Optional, Union + +import mlx.core as mx +import mlx.nn as nn + +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention +from .rope_utils import initialize_rope + + +@dataclass +class ModelArgs(BaseModelArgs): + model_type: str + hidden_size: int + num_layers: int + intermediate_size: int + num_attention_heads: int + vocab_size: int + rope_theta: float + layer_norm_epsilon: float + num_key_value_heads: int + head_dim: Optional[int] = None + max_position_embeddings: Optional[int] = None + rope_traditional: bool = False + rope_scaling: Optional[Dict[str, Union[float, str]]] = None + tie_word_embeddings: bool = True + attention_bias: bool = False + mlp_bias: bool = False + + +class AttentionModule(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + dim = args.hidden_size + self.n_heads = n_heads = args.num_attention_heads + self.n_kv_heads = n_kv_heads = args.num_key_value_heads + self.head_dim = head_dim = args.head_dim or (dim // n_heads) + self.scale = head_dim**-0.5 + + self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.attention_bias) + self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias) + self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias) + self.out_proj = nn.Linear(n_heads * head_dim, dim, bias=args.attention_bias) + + self.rope = initialize_rope( + self.head_dim, + args.rope_theta, + args.rope_traditional, + args.rope_scaling, + args.max_position_embeddings, + ) + + def __call__( + self, x: mx.array, mask: Optional[mx.array] = None, cache: Optional[Any] = None + ) -> mx.array: + B, L, D = x.shape + q = self.q_proj(x).reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) + k = self.k_proj(x).reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + v = self.v_proj(x).reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + + if cache is not None: + q = self.rope(q, offset=cache.offset) + k = self.rope(k, offset=cache.offset) + k, v = cache.update_and_fetch(k, v) + else: + q = self.rope(q) + k = self.rope(k) + + out = scaled_dot_product_attention( + q, k, v, cache=cache, scale=self.scale, mask=mask + ) + out = out.transpose(0, 2, 1, 3).reshape(B, L, D) + return self.out_proj(out) + + +class Attention(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.attention = AttentionModule(args) + + +class MLP(nn.Module): + def 
__init__(self, args: ModelArgs): + super().__init__() + dim = args.hidden_size + hidden_dim = args.intermediate_size + self.c_fc_0 = nn.Linear(dim, hidden_dim, bias=args.mlp_bias) + self.c_fc_1 = nn.Linear(dim, hidden_dim, bias=args.mlp_bias) + self.c_proj = nn.Linear(hidden_dim, dim, bias=args.mlp_bias) + + def __call__(self, x: mx.array) -> mx.array: + return self.c_proj(nn.silu(self.c_fc_0(x)) * self.c_fc_1(x)) + + +class TransformerBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.ln_1 = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon) + self.attn = Attention(args) + self.ln_2 = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon) + self.mlp = MLP(args) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + ) -> mx.array: + h = x + self.attn.attention(self.ln_1(x), mask, cache) + out = h + self.mlp(self.ln_2(h)) + return out + + +class ExaoneModel(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.wte = nn.Embedding(args.vocab_size, args.hidden_size) + self.h = [TransformerBlock(args) for _ in range(args.num_layers)] + self.ln_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon) + + def __call__( + self, + inputs: mx.array, + cache=None, + ): + h = self.wte(inputs) + mask = create_attention_mask(h, cache) + + if cache is None: + cache = [None] * len(self.h) + + for layer, c in zip(self.h, cache): + h = layer(h, mask, cache=c) + + return self.ln_f(h) + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.model_type = args.model_type + self.transformer = ExaoneModel(args) + if not args.tie_word_embeddings: + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + + def __call__( + self, + inputs: mx.array, + cache=None, + ): + out = self.transformer(inputs, cache) + if self.args.tie_word_embeddings: + out = self.transformer.wte.as_linear(out) + else: + out = self.lm_head(out) + return out + + @property + def layers(self): + return self.transformer.h diff --git a/llms/mlx_lm/models/hunyuan.py b/llms/mlx_lm/models/hunyuan.py new file mode 100644 index 00000000..b098c20d --- /dev/null +++ b/llms/mlx_lm/models/hunyuan.py @@ -0,0 +1,291 @@ +# Copyright © 2023-2024 Apple Inc. 
+ +import math +from dataclasses import dataclass +from typing import Any, Dict, Optional, Tuple, Union + +import mlx.core as mx +import mlx.nn as nn + +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention +from .switch_layers import SwitchGLU + + +@dataclass +class ModelArgs(BaseModelArgs): + model_type: str + vocab_size: int + hidden_size: int + num_hidden_layers: int + intermediate_size: int + num_attention_heads: int + num_key_value_heads: int + attention_bias: bool + moe_topk: int + num_experts: int + num_shared_expert: int + use_mixed_mlp_moe: bool + use_qk_norm: bool + rms_norm_eps: float + rope_theta: float + use_cla: bool + cla_share_factor: 2 + rope_scaling: Optional[Dict[str, Union[float, str]]] = None + tie_word_embeddings: bool = False + + def __post_init__(self): + + if self.rope_scaling: + required_keys = {"factor", "type"} + if not all(key in self.rope_scaling for key in required_keys): + raise ValueError(f"rope_scaling must contain keys {required_keys}") + + +class DynamicNTKAlphaRoPE(nn.Module): + def __init__( + self, + dims: int, + base: float = 10000, + scaling_alpha: float = 1.0, + ): + super().__init__() + self.dims = dims + base = base * scaling_alpha ** (dims / (dims - 2)) + self._freqs = base ** (mx.arange(0, self.dims, 2) / self.dims) + + def __call__(self, x, offset: int = 0): + return mx.fast.rope( + x, + self.dims, + traditional=False, + base=None, + scale=1.0, + offset=offset, + freqs=self._freqs, + ) + + +class Attention(nn.Module): + def __init__(self, kv_proj: bool, args: ModelArgs): + super().__init__() + + dim = args.hidden_size + self.n_heads = n_heads = args.num_attention_heads + assert args.num_key_value_heads is not None + self.n_kv_heads = n_kv_heads = args.num_key_value_heads + + head_dim = args.hidden_size // n_heads + self.scale = head_dim**-0.5 + + self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.attention_bias) + if kv_proj: + self.k_proj = nn.Linear( + dim, n_kv_heads * head_dim, bias=args.attention_bias + ) + self.v_proj = nn.Linear( + dim, n_kv_heads * head_dim, bias=args.attention_bias + ) + self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=args.attention_bias) + self.use_qk_norm = args.use_qk_norm + if self.use_qk_norm: + self.query_layernorm = nn.RMSNorm(head_dim, args.rms_norm_eps) + self.key_layernorm = nn.RMSNorm(head_dim, args.rms_norm_eps) + + self.rope = DynamicNTKAlphaRoPE( + head_dim, + base=args.rope_theta, + scaling_alpha=args.rope_scaling["alpha"], + ) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + kv_states=None, + ) -> mx.array: + B, L, D = x.shape + + queries = self.q_proj(x) + + if kv_states is None: + keys, values = self.k_proj(x), self.v_proj(x) + kv_states = keys, values + else: + keys, values = kv_states + + # Prepare the queries, keys and values for the attention computation + queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + + offset = cache.offset if cache else 0 + queries = self.rope(queries, offset=offset) + keys = self.rope(keys, offset=offset) + if self.use_qk_norm: + queries = self.query_layernorm(queries) + keys = self.key_layernorm(keys) + + if cache is not None: + keys, values = cache.update_and_fetch(keys, values) + + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask + ) + output = 
output.transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.o_proj(output), kv_states + + +class MLP(nn.Module): + def __init__(self, dim, hidden_dim): + super().__init__() + self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) + self.down_proj = nn.Linear(hidden_dim, dim, bias=False) + self.up_proj = nn.Linear(dim, hidden_dim, bias=False) + + def __call__(self, x) -> mx.array: + return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) + + +class Gate(nn.Module): + def __init__(self, dim, num_experts): + super().__init__() + self.wg = nn.Linear(dim, num_experts, bias=False) + + def __call__(self, x) -> mx.array: + return self.wg(x) + + +class MoeBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + dim = args.hidden_size + intermediate_size = args.intermediate_size + self.use_shared_mlp = args.use_mixed_mlp_moe + + if args.use_mixed_mlp_moe: + self.shared_mlp = MLP(dim, intermediate_size * args.num_shared_expert) + + self.num_experts = num_experts = args.num_experts + self.top_k = args.moe_topk + + self.gate = Gate(dim, num_experts) + self.switch_mlp = SwitchGLU(dim, intermediate_size, num_experts) + + def __call__( + self, + x: mx.array, + ): + gates = self.gate(x) + gates = mx.softmax(gates, axis=-1, precise=True) + + k = self.top_k + inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1)[..., :k]) + scores = mx.take_along_axis(gates, inds, axis=-1) + + y = self.switch_mlp(x, inds) + y = (y * scores[..., None]).sum(axis=-2) + + if self.use_shared_mlp: + shared_expert_output = self.shared_mlp(x) + y = y + shared_expert_output + + return y + + +class DecoderLayer(nn.Module): + def __init__(self, args: ModelArgs, kv_proj: bool): + super().__init__() + self.hidden_size = args.hidden_size + self.self_attn = Attention(kv_proj, args) + self.mlp = MoeBlock(args) + + self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.post_attention_layernorm = nn.RMSNorm( + args.hidden_size, eps=args.rms_norm_eps + ) + self.args = args + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + shared_kv_states: Optional[Tuple[mx.array, mx.array]] = None, + ): + r, shared_kv_states = self.self_attn( + self.input_layernorm(x), mask, cache, shared_kv_states + ) + h = x + r + r = self.mlp(self.post_attention_layernorm(h)) + out = h + r + return out, shared_kv_states + + +class HunYuanModel(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.vocab_size = args.vocab_size + self.num_hidden_layers = args.num_hidden_layers + assert self.vocab_size > 0 + self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [ + DecoderLayer(args=args, kv_proj=(i % args.cla_share_factor) == 0) + for i in range(args.num_hidden_layers) + ] + self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + + def __call__( + self, + inputs: mx.array, + cache=None, + ): + h = self.embed_tokens(inputs) + + mask = create_attention_mask(h, cache) + + if cache is None: + cache = [None] * len(self.layers) + + for i, (layer, c) in enumerate(zip(self.layers, cache)): + if i % self.args.cla_share_factor == 0: + shared_kv_states = None + h, shared_kv_states = layer(h, mask, c, shared_kv_states) + + return self.norm(h) + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.model_type = args.model_type + self.model = HunYuanModel(args) + + def __call__( + self, + inputs: mx.array, + 
cache=None, + ): + out = self.model(inputs, cache) + return self.model.embed_tokens.as_linear(out) + + def sanitize(self, weights): + if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights: + return weights + for l in range(self.args.num_hidden_layers): + prefix = f"model.layers.{l}" + for n in ["up_proj", "down_proj", "gate_proj"]: + for k in ["weight", "scales", "biases"]: + if f"{prefix}.mlp.experts.0.{n}.{k}" in weights: + to_join = [ + weights.pop(f"{prefix}.mlp.experts.{e}.{n}.{k}") + for e in range(self.args.num_experts) + ] + weights[f"{prefix}.mlp.switch_mlp.{n}.{k}"] = mx.stack(to_join) + return weights + + @property + def layers(self): + return self.model.layers diff --git a/llms/mlx_lm/models/llama.py b/llms/mlx_lm/models/llama.py index 438278e5..290cb83e 100644 --- a/llms/mlx_lm/models/llama.py +++ b/llms/mlx_lm/models/llama.py @@ -7,6 +7,7 @@ import mlx.core as mx import mlx.nn as nn from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention +from .rope_utils import initialize_rope @dataclass @@ -32,117 +33,6 @@ class ModelArgs(BaseModelArgs): if self.num_key_value_heads is None: self.num_key_value_heads = self.num_attention_heads - if self.rope_scaling: - if not "factor" in self.rope_scaling: - raise ValueError(f"rope_scaling must contain 'factor'") - rope_type = self.rope_scaling.get("type") or self.rope_scaling.get( - "rope_type" - ) - if rope_type is None: - raise ValueError( - f"rope_scaling must contain either 'type' or 'rope_type'" - ) - if rope_type not in ["linear", "dynamic", "llama3"]: - raise ValueError( - "rope_scaling 'type' currently only supports 'linear', 'dynamic' or 'llama3'" - ) - - -class DynamicNTKScalingRoPE(nn.Module): - """Implements the rotary positional encoding with Dynamic NTK scaling and Llama 3 RoPE.""" - - def __init__( - self, - dims: int, - max_position_embeddings: int = 2048, - traditional: bool = False, - base: float = 10000, - scale: float = 1.0, - rope_type: str = "default", - rope_scaling: dict = None, - ): - super().__init__() - self.dims = dims - self.max_position_embeddings = max_position_embeddings - self.traditional = traditional - self.scale = scale - self.rope_type = rope_type - self.rope_scaling = rope_scaling - self.base = base - self.compute_freqs() - - def compute_freqs(self): - if self.rope_type != "llama3": - self._freqs = None - return - factor = self.rope_scaling["factor"] - low_freq_factor = self.rope_scaling.get("low_freq_factor", 1.0) - high_freq_factor = self.rope_scaling.get("high_freq_factor", 4.0) - old_context_len = self.rope_scaling.get( - "original_max_position_embeddings", - 8192, - ) - - low_freq_wavelen = old_context_len / low_freq_factor - high_freq_wavelen = old_context_len / high_freq_factor - - freqs = self.base ** (mx.arange(0, self.dims, 2) / self.dims) - wavelens = 2 * mx.pi * freqs - - freqs = mx.where(wavelens > low_freq_wavelen, freqs * factor, freqs) - is_medium_freq = (wavelens > high_freq_wavelen) & (wavelens < low_freq_wavelen) - smooth_factors = (old_context_len / wavelens - low_freq_factor) / ( - high_freq_factor - low_freq_factor - ) - smooth_freqs = freqs / ((1 - smooth_factors) / factor + smooth_factors) - self._freqs = mx.where(is_medium_freq, smooth_freqs, freqs) - self.base = None - - def extra_repr(self): - return ( - f"{self.dims}, traditional={self.traditional}, " - f"max_position_embeddings={self.max_position_embeddings}, " - f"scaling_factor={self.scale}, rope_type={self.rope_type}" - ) - - def __call__(self, x, offset: int = 0): - return 
mx.fast.rope( - x, - self.dims, - traditional=self.traditional, - base=self.base, - scale=self.scale, - offset=offset, - freqs=self._freqs, - ) - - -def initialize_rope(args: ModelArgs): - head_dim = args.head_dim or args.hidden_size // args.num_attention_heads - - rope_scaling = args.rope_scaling - rope_type = "default" - rope_scale = 1.0 - - if rope_scaling is not None: - rope_type = ( - rope_scaling.get("type") or rope_scaling.get("rope_type") or "default" - ) - if rope_type == "linear": - rope_scale = 1 / rope_scaling["factor"] - elif rope_type == "llama3": - rope_scale = 1.0 # The scaling is handled internally for llama3 - - return DynamicNTKScalingRoPE( - dims=head_dim, - max_position_embeddings=args.max_position_embeddings, - traditional=args.rope_traditional, - base=args.rope_theta, - scale=rope_scale, - rope_type=rope_type, - rope_scaling=rope_scaling, - ) - class Attention(nn.Module): def __init__(self, args: ModelArgs): @@ -165,7 +55,13 @@ class Attention(nn.Module): self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias) self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias) - self.rope = initialize_rope(args) + self.rope = initialize_rope( + self.head_dim, + args.rope_theta, + args.rope_traditional, + args.rope_scaling, + args.max_position_embeddings, + ) def __call__( self, diff --git a/llms/mlx_lm/models/olmo2.py b/llms/mlx_lm/models/olmo2.py new file mode 100644 index 00000000..64d7e116 --- /dev/null +++ b/llms/mlx_lm/models/olmo2.py @@ -0,0 +1,209 @@ +# Copyright © 2023-2024 Apple Inc. + +from dataclasses import dataclass +from typing import Any, Dict, Optional, Union + +import mlx.core as mx +import mlx.nn as nn + +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention +from .rope_utils import initialize_rope + + +@dataclass +class ModelArgs(BaseModelArgs): + model_type: str + hidden_size: int + num_hidden_layers: int + intermediate_size: int + num_attention_heads: int + rms_norm_eps: float + vocab_size: int + head_dim: Optional[int] = None + max_position_embeddings: Optional[int] = None + num_key_value_heads: Optional[int] = None + attention_bias: bool = False + mlp_bias: bool = False + rope_theta: float = 10000 + rope_traditional: bool = False + rope_scaling: Optional[Dict[str, Union[float, str]]] = None + tie_word_embeddings: bool = True + + def __post_init__(self): + if self.num_key_value_heads is None: + self.num_key_value_heads = self.num_attention_heads + + +class Attention(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + + dim = args.hidden_size + self.n_heads = n_heads = args.num_attention_heads + self.n_kv_heads = n_kv_heads = args.num_key_value_heads + + self.head_dim = head_dim = args.head_dim or args.hidden_size // n_heads + + self.scale = head_dim**-0.5 + if hasattr(args, "attention_bias"): + attention_bias = args.attention_bias + else: + attention_bias = False + + self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias) + self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias) + self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias) + self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias) + + self.rope = initialize_rope( + self.head_dim, + args.rope_theta, + args.rope_traditional, + args.rope_scaling, + args.max_position_embeddings, + ) + + self.q_norm = nn.RMSNorm(n_heads * head_dim, args.rms_norm_eps) + self.k_norm = nn.RMSNorm(n_kv_heads * head_dim, args.rms_norm_eps) + + def __call__( + self, + x: 
mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + ) -> mx.array: + B, L, D = x.shape + + queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) + queries = self.q_norm(queries) + keys = self.k_norm(keys) + + # Prepare the queries, keys and values for the attention computation + queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + + if cache is not None: + queries = self.rope(queries, offset=cache.offset) + keys = self.rope(keys, offset=cache.offset) + keys, values = cache.update_and_fetch(keys, values) + else: + queries = self.rope(queries) + keys = self.rope(keys) + + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask + ) + + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.o_proj(output) + + +class MLP(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + + dim = args.hidden_size + hidden_dim = args.intermediate_size + if hasattr(args, "mlp_bias"): + mlp_bias = args.mlp_bias + else: + mlp_bias = False + + self.gate_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias) + self.down_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias) + self.up_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias) + + def __call__(self, x) -> mx.array: + return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) + + +class TransformerBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.num_attention_heads = args.num_attention_heads + self.hidden_size = args.hidden_size + self.self_attn = Attention(args) + self.mlp = MLP(args) + self.post_attention_layernorm = nn.RMSNorm( + args.hidden_size, eps=args.rms_norm_eps + ) + self.post_feedforward_layernorm = nn.RMSNorm( + args.hidden_size, eps=args.rms_norm_eps + ) + self.args = args + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + ) -> mx.array: + r = self.post_attention_layernorm(self.self_attn(x, mask, cache)) + h = x + r + r = self.post_feedforward_layernorm(self.mlp(h)) + out = h + r + return out + + +class LlamaModel(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.vocab_size = args.vocab_size + self.num_hidden_layers = args.num_hidden_layers + assert self.vocab_size > 0 + self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [ + TransformerBlock(args=args) for _ in range(args.num_hidden_layers) + ] + self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + + def __call__( + self, + inputs: mx.array, + cache=None, + ): + h = self.embed_tokens(inputs) + + mask = create_attention_mask(h, cache) + + if cache is None: + cache = [None] * len(self.layers) + + for layer, c in zip(self.layers, cache): + h = layer(h, mask, cache=c) + + return self.norm(h) + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.model_type = args.model_type + self.model = LlamaModel(args) + if not args.tie_word_embeddings: + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + + def __call__( + self, + inputs: mx.array, + cache=None, + ): + out = self.model(inputs, cache) + if self.args.tie_word_embeddings: + out = self.model.embed_tokens.as_linear(out) + else: + out = self.lm_head(out) + return out + + def sanitize(self, weights): + # 
Remove unused precomputed rotary freqs + return { + k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k + } + + @property + def layers(self): + return self.model.layers diff --git a/llms/mlx_lm/models/rope_utils.py b/llms/mlx_lm/models/rope_utils.py new file mode 100644 index 00000000..d30b432d --- /dev/null +++ b/llms/mlx_lm/models/rope_utils.py @@ -0,0 +1,91 @@ +# Copyright © 2023-2024 Apple Inc. + +from typing import Optional + +import mlx.core as mx +import mlx.nn as nn + + +class Llama3RoPE(nn.Module): + + def __init__( + self, + dims: int, + max_position_embeddings: int = 2048, + traditional: bool = False, + base: float = 10000, + scaling_config: dict = None, + ): + super().__init__() + self.dims = dims + self.max_position_embeddings = max_position_embeddings + self.traditional = traditional + + factor = scaling_config["factor"] + low_freq_factor = scaling_config.get("low_freq_factor", 1.0) + high_freq_factor = scaling_config.get("high_freq_factor", 4.0) + old_context_len = scaling_config.get( + "original_max_position_embeddings", + 8192, + ) + + low_freq_wavelen = old_context_len / low_freq_factor + high_freq_wavelen = old_context_len / high_freq_factor + + freqs = base ** (mx.arange(0, dims, 2) / dims) + wavelens = 2 * mx.pi * freqs + + freqs = mx.where(wavelens > low_freq_wavelen, freqs * factor, freqs) + is_medium_freq = (wavelens > high_freq_wavelen) & (wavelens < low_freq_wavelen) + smooth_factors = (old_context_len / wavelens - low_freq_factor) / ( + high_freq_factor - low_freq_factor + ) + smooth_freqs = freqs / ((1 - smooth_factors) / factor + smooth_factors) + self._freqs = mx.where(is_medium_freq, smooth_freqs, freqs) + + def extra_repr(self): + return ( + f"{self.dims}, traditional={self.traditional}, " + f"max_position_embeddings={self.max_position_embeddings}" + ) + + def __call__(self, x, offset: int = 0): + return mx.fast.rope( + x, + self.dims, + traditional=self.traditional, + base=None, + scale=1.0, + offset=offset, + freqs=self._freqs, + ) + + +def initialize_rope( + dims, + base, + traditional, + scaling_config: Optional[dict] = None, + max_position_embeddings: Optional[int] = None, +): + if scaling_config is not None: + rope_type = scaling_config.get("type") or scaling_config.get( + "rope_type", "default" + ) + else: + rope_type = "default" + + if rope_type in ["default", "linear"]: + scale = 1 / scaling_config["factor"] if rope_type == "linear" else 1.0 + return nn.RoPE(dims, traditional=traditional, base=base, scale=scale) + + elif rope_type == "llama3": + return Llama3RoPE( + dims=dims, + max_position_embeddings=max_position_embeddings, + traditional=traditional, + base=base, + scaling_config=scaling_config, + ) + else: + raise ValueError(f"Unsupported RoPE type {rope_type}") diff --git a/llms/mlx_lm/sample_utils.py b/llms/mlx_lm/sample_utils.py index c27b52d8..f9868422 100644 --- a/llms/mlx_lm/sample_utils.py +++ b/llms/mlx_lm/sample_utils.py @@ -1,5 +1,6 @@ # Copyright © 2023-2024 Apple Inc. +import math from functools import partial from typing import Callable, Dict, Optional @@ -80,7 +81,7 @@ def make_logits_processors( @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) def min_p_sampling( - logits: mx.array, + logprobs: mx.array, min_p: float, min_tokens_to_keep: int = 1, temperature=1.0, @@ -93,7 +94,7 @@ def min_p_sampling( aggressive given a very high-probability token. Args: - logits: The logits from the model's output. + logprobs: A vector of log probabilities. min_p (float): Minimum token probability. 
Typical values are in the 0.01-0.2 range, comparably selective as setting `top_p` in the 0.99-0.8 range. @@ -111,28 +112,27 @@ def min_p_sampling( ) # reference implementation: https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L531-L605 - # Softmax probabilities - probs = mx.softmax(logits * (1 / temperature), axis=-1) + logprobs = logprobs * (1 / temperature) # Indices sorted in decreasing order - sorted_indices = mx.argsort(-logits).squeeze(0) - sorted_probs = probs[..., sorted_indices] + sorted_indices = mx.argsort(-logprobs).squeeze(0) + sorted_logprobs = logprobs[..., sorted_indices] # Top probability - top_probs = probs[..., sorted_indices[0]] + top_logprobs = logprobs[..., sorted_indices[0]] # Calculate the min_p threshold - scaled_min_p = min_p * top_probs + scaled_min_p = top_logprobs + math.log(min_p) # Mask tokens that have a probability less than the scaled min_p - tokens_to_remove = sorted_probs < scaled_min_p + tokens_to_remove = sorted_logprobs < scaled_min_p tokens_to_remove[..., :min_tokens_to_keep] = False # Create pool of tokens with probability less than scaled min_p - selected_probs = mx.where(tokens_to_remove, 0, sorted_probs) + selected_logprobs = mx.where(tokens_to_remove, -float("inf"), sorted_logprobs) # Return sampled token - sorted_token = mx.random.categorical(mx.log(selected_probs)) + sorted_token = mx.random.categorical(selected_logprobs) return sorted_indices[sorted_token] diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py index c1365b36..ce09cf45 100644 --- a/llms/mlx_lm/server.py +++ b/llms/mlx_lm/server.py @@ -27,6 +27,7 @@ from huggingface_hub import scan_cache_dir from ._version import __version__ from .models.cache import make_prompt_cache +from .sample_utils import make_logits_processors, make_sampler from .utils import load, stream_generate @@ -464,25 +465,24 @@ class APIHandler(BaseHTTPRequestHandler): text = "" tic = time.perf_counter() - for n, (segment, token, logprobs) in enumerate( - stream_generate( - model=self.model, - tokenizer=self.tokenizer, - prompt=prompt, - max_tokens=self.max_tokens, - temp=self.temperature, - repetition_penalty=self.repetition_penalty, - repetition_context_size=self.repetition_context_size, - logit_bias=self.logit_bias, - prompt_cache=self.prompt_cache.cache, - ), + sampler = make_sampler(self.temperature) + logits_processors = make_logits_processors( + self.logit_bias, self.repetition_penalty, self.repetition_context_size + ) + for gen_response in stream_generate( + model=self.model, + tokenizer=self.tokenizer, + prompt=prompt, + max_tokens=self.max_tokens, + sampler=sampler, + logits_processors=logits_processors, + prompt_cache=self.prompt_cache.cache, ): - if n == 0: - prompt_time = time.perf_counter() - tic - tic = time.perf_counter() - + segment = gen_response.text text += segment logging.debug(text) + token = gen_response.token + logprobs = gen_response.logprobs tokens.append(token) if self.logprobs > 0: @@ -523,13 +523,9 @@ class APIHandler(BaseHTTPRequestHandler): self.prompt_cache.tokens.extend(tokens) - gen_time = time.perf_counter() - tic - prompt_tps = len(prompt) / prompt_time - gen_tps = len(tokens) / gen_time - peak_mem = mx.metal.get_peak_memory() / 1e9 - logging.debug(f"Prompt: {prompt_tps:.3f} tokens-per-sec") - logging.debug(f"Generation: {gen_tps:.3f} tokens-per-sec") - logging.debug(f"Peak memory: {peak_mem:.3f} GB") + logging.debug(f"Prompt: {gen_response.prompt_tps:.3f} tokens-per-sec") + logging.debug(f"Generation: 
{gen_response.generation_tps:.3f} tokens-per-sec") + logging.debug(f"Peak memory: {gen_response.peak_memory:.3f} GB") if self.stream: response = self.generate_response(segment, finish_reason) @@ -593,9 +589,7 @@ class APIHandler(BaseHTTPRequestHandler): # Determine response type self.request_id = f"chatcmpl-{uuid.uuid4()}" - self.object_type = ( - "chat.completions.chunk" if self.stream else "chat.completions" - ) + self.object_type = "chat.completion.chunk" if self.stream else "chat.completion" if ( hasattr(self.tokenizer, "apply_chat_template") and self.tokenizer.chat_template diff --git a/llms/mlx_lm/tokenizer_utils.py b/llms/mlx_lm/tokenizer_utils.py index 9d390733..10a257f6 100644 --- a/llms/mlx_lm/tokenizer_utils.py +++ b/llms/mlx_lm/tokenizer_utils.py @@ -73,16 +73,16 @@ class NaiveStreamingDetokenizer(StreamingDetokenizer): def reset(self): self.offset = 0 - self._tokens = [] + self.tokens = [] self._text = "" self._current_tokens = [] self._current_text = "" def add_token(self, token): self._current_tokens.append(token) + self.tokens.append(token) def finalize(self): - self._tokens.extend(self._current_tokens) self._text += self._tokenizer.decode(self._current_tokens) self._current_tokens = [] self._current_text = "" @@ -97,16 +97,11 @@ class NaiveStreamingDetokenizer(StreamingDetokenizer): ): self._current_text = self._current_text[:-1] if self._current_text and self._current_text[-1] == "\n": - self._tokens.extend(self._current_tokens) self._text += self._current_text self._current_tokens.clear() self._current_text = "" return self._text + self._current_text - @property - def tokens(self): - return self._tokens - class SPMStreamingDetokenizer(StreamingDetokenizer): """A streaming detokenizer for SPM models. @@ -143,6 +138,7 @@ class SPMStreamingDetokenizer(StreamingDetokenizer): self.text += text def add_token(self, token): + self.tokens.append(token) v = self.tokenmap[token] if v.startswith(self._sep): self._flush() @@ -200,6 +196,7 @@ class BPEStreamingDetokenizer(StreamingDetokenizer): return current_text def add_token(self, token): + self.tokens.append(token) v = self.tokenmap[token] is_added = token in self._added_ids if is_added or self._byte_decoder[v[0]] == 32: @@ -257,21 +254,33 @@ class TokenizerWrapper: huggingface tokenizer. 
""" - def __init__(self, tokenizer, detokenizer_class=NaiveStreamingDetokenizer): + def __init__( + self, tokenizer, detokenizer_class=NaiveStreamingDetokenizer, eos_token_ids=None + ): self._tokenizer = tokenizer self._detokenizer = detokenizer_class(tokenizer) + self._eos_token_ids = ( + set(eos_token_ids) + if eos_token_ids is not None + else {tokenizer.eos_token_id} + ) def __getattr__(self, attr): if attr == "detokenizer": return self._detokenizer + elif attr == "eos_token_ids": + return self._eos_token_ids elif attr.startswith("_"): return self.__getattribute__(attr) else: return getattr(self._tokenizer, attr) def __setattr__(self, attr, value): - if attr == "detokenizer": - raise AttributeError("Cannot set the detokenizer.") + if attr in {"detokenizer", "eos_token_ids"}: + if attr == "detokenizer": + raise AttributeError("Cannot set the detokenizer.") + elif attr == "eos_token_ids": + self._eos_token_ids = set(value) if value is not None else set() elif attr.startswith("_"): super().__setattr__(attr, value) else: @@ -318,7 +327,7 @@ def _is_bpe_decoder(decoder): return isinstance(decoder, dict) and decoder.get("type", None) == "ByteLevel" -def load_tokenizer(model_path, tokenizer_config_extra={}): +def load_tokenizer(model_path, tokenizer_config_extra={}, eos_token_ids=None): """Load a huggingface tokenizer and try to infer the type of streaming detokenizer to use. @@ -339,7 +348,10 @@ def load_tokenizer(model_path, tokenizer_config_extra={}): elif _is_bpe_decoder(tokenizer_content["decoder"]): detokenizer_class = BPEStreamingDetokenizer + if isinstance(eos_token_ids, int): + eos_token_ids = [eos_token_ids] return TokenizerWrapper( AutoTokenizer.from_pretrained(model_path, **tokenizer_config_extra), detokenizer_class, + eos_token_ids=eos_token_ids, ) diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py index a44663fb..213bcad7 100644 --- a/llms/mlx_lm/tuner/utils.py +++ b/llms/mlx_lm/tuner/utils.py @@ -98,6 +98,7 @@ def linear_to_lora_layers( "cohere", "minicpm", "deepseek", + "olmo2", ]: keys = set(["self_attn.q_proj", "self_attn.v_proj"]) if model.model_type in ["mixtral", "phimoe"]: @@ -150,6 +151,8 @@ def linear_to_lora_layers( "mixer.out_proj", ] ) + elif model.model_type == "exaone": + keys = set(["attn.attention.q_proj", "attn.attention.v_proj"]) else: raise ValueError(f"Lora does not support {model.model_type}") @@ -256,12 +259,14 @@ def remove_lora_layers(model: nn.Module) -> nn.Module: return model -def print_trainable_parameters(model): - def nparams(m): - if isinstance(m, (nn.QuantizedLinear, nn.QuantizedEmbedding)): - return m.weight.size * (32 // m.bits) - return sum(v.size for _, v in tree_flatten(m.parameters())) +def nparams(module): + if hasattr(module, "bits"): + n = 0 if not hasattr(module, "bias") else module.bias.size + return n + module.weight.size * 32 // module.bits + return sum(v.size for _, v in tree_flatten(module.parameters())) + +def print_trainable_parameters(model): leaf_modules = tree_flatten( model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module) ) diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index d4afd428..d81bb66a 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -8,6 +8,7 @@ import json import logging import shutil import time +from dataclasses import dataclass from pathlib import Path from textwrap import dedent from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type, Union @@ -15,7 +16,7 @@ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type, 
import mlx.core as mx import mlx.nn as nn from huggingface_hub import snapshot_download -from mlx.utils import tree_flatten, tree_reduce +from mlx.utils import tree_flatten, tree_map, tree_reduce from transformers import PreTrainedTokenizer # Local imports @@ -23,7 +24,7 @@ from .models import cache from .sample_utils import make_logits_processors, make_sampler from .tokenizer_utils import TokenizerWrapper, load_tokenizer from .tuner.utils import dequantize as dequantize_model -from .tuner.utils import load_adapters +from .tuner.utils import load_adapters, nparams # Constants MODEL_REMAPPING = { @@ -44,6 +45,32 @@ class ModelNotFoundError(Exception): super().__init__(self.message) +@dataclass +class GenerationResponse: + """ + The output of :func:`stream_generate`. + + Args: + text (str): The next segment of decoded text. This can be an empty string. + token (int): The next token. + logprobs (mx.array): A vector of log probabilities. + prompt_tokens (int): The number of tokens in the prompt. + prompt_tps (float): The prompt processing tokens-per-second. + generation_tokens (int): The number of generated tokens. + generation_tps (float): The tokens-per-second for generation. + peak_memory (float): The peak memory used so far in GB. + """ + + text: str + token: int + logprobs: mx.array + prompt_tokens: int + prompt_tps: float + generation_tokens: int + generation_tps: float + peak_memory: float + + @contextlib.contextmanager def wired_limit(model: nn.Module, streams: Optional[List[mx.Stream]] = None): """ @@ -100,6 +127,17 @@ def _get_classes(config: dict): return arch.Model, arch.ModelArgs +def compute_bits_per_weight(model): + model_bytes = tree_reduce( + lambda acc, x: acc + x.nbytes if isinstance(x, mx.array) else acc, model, 0 + ) + leaf_modules = tree_flatten( + model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module) + ) + model_params = sum(nparams(m) for _, m in leaf_modules) + return model_bytes * 8 / model_params + + def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path: """ Ensures the model is available locally. If the path does not exist locally, @@ -155,20 +193,23 @@ def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_ def generate_step( prompt: mx.array, model: nn.Module, - temp: float = 0.0, - repetition_penalty: Optional[float] = None, - repetition_context_size: Optional[int] = 20, - top_p: float = 1.0, - min_p: float = 0.0, - min_tokens_to_keep: int = 1, - prefill_step_size: int = 512, + *, + max_tokens: int = 256, + sampler: Optional[Callable[mx.array, mx.array]] = None, + logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None, max_kv_size: Optional[int] = None, prompt_cache: Optional[Any] = None, - logit_bias: Optional[Dict[int, float]] = None, - logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None, + prefill_step_size: int = 512, kv_bits: Optional[int] = None, kv_group_size: int = 64, quantized_kv_start: int = 0, + prompt_progress_callback: Optional[Callable[int, int]] = None, + temp: Optional[float] = None, + repetition_penalty: Optional[float] = None, + repetition_context_size: Optional[int] = None, + top_p: Optional[float] = None, + min_p: Optional[float] = None, + min_tokens_to_keep: Optional[int] = None, ) -> Generator[Tuple[mx.array, mx.array], None, None]: """ A generator producing token ids based on the given prompt from the model. @@ -176,32 +217,25 @@ def generate_step( Args: prompt (mx.array): The input prompt. 
model (nn.Module): The model to use for generation. - temp (float): The temperature for sampling, if 0 the argmax is used. - Default: ``0``. - repetition_penalty (float, optional): The penalty factor for repeating - tokens. - repetition_context_size (int, optional): The number of tokens to - consider for repetition penalty. Default: ``20``. - top_p (float, optional): Nulceus sampling, higher means model considers - more less likely words. - min_p (float, optional): The minimum value (scaled by the top token's - probability) that a token probability must have to be considered. - min_tokens_to_keep (int, optional): Minimum number of tokens that cannot - be filtered by min_p sampling. - prefill_step_size (int): Step size for processing the prompt. + max_tokens (int): The maximum number of tokens to generate. Use ``-1`` for an infinite + generator. Default: ``256``. + sampler (Callable[mx.array, mx.array], optional): A sampler for sampling a + token from a vector of log probabilities. Default: ``None``. + logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional): + A list of functions that take tokens and logits and return the processed + logits. Default: ``None``. max_kv_size (int, optional): Maximum size of the key-value cache. Old entries (except the first 4 tokens) will be overwritten. prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if provided, the cache will be updated in place. - logit_bias (dictionary, optional): Additive logit bias. - logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional): - A list of functions that take tokens and logits and return the processed - logits. Default: ``None``. + prefill_step_size (int): Step size for processing the prompt. kv_bits (int, optional): Number of bits to use for KV cache quantization. - None implies no cache quantization. Default: ``None``. + None implies no cache quantization. Default: ``None``. kv_group_size (int): Group size for KV cache quantization. Default: ``64``. quantized_kv_start (int): Step to begin using a quantized KV cache. - when ``kv_bits`` is non-None. Default: ``0``. + when ``kv_bits`` is non-None. Default: ``0``. + prompt_progress_callback (Callable[int, int]): A callback which takes the + number of prompt tokens processed so far and the total number of prompt tokens. Yields: Tuple[mx.array, mx.array]: One token and a vector of log probabilities. @@ -219,11 +253,24 @@ def generate_step( elif len(prompt_cache) != len(model.layers): raise ValueError("Wrong number of layers in the prompt cache.") - sampler = make_sampler(temp, top_p, min_p, min_tokens_to_keep) - logits_processors = logits_processors or [] - logits_processors.extend( - make_logits_processors(logit_bias, repetition_penalty, repetition_context_size) + if temp is not None or top_p is not None or min_tokens_to_keep is not None: + print( + "[Warning] Specifying sampling arguments to ``generate_step`` is " + "deprecated. Pass in a ``sampler`` instead." + ) + if repetition_penalty is not None: + print( + "[Warning] Specifying ``repetition_penalty`` is deprecated. " + "Pass in ``logits_processors`` instead." 
+ ) + + sampler = sampler or make_sampler( + temp or 0.0, top_p or 0.0, min_p or 0.0, min_tokens_to_keep or 1 ) + logits_processors = logits_processors or make_logits_processors( + None, repetition_penalty, repetition_context_size or 20 + ) + prompt_progress_callback = prompt_progress_callback or (lambda *_: None) def _step(y): with mx.stream(generation_stream): @@ -245,81 +292,108 @@ def generate_step( y = sampler(logprobs) return y, logprobs.squeeze(0) - while y.size > prefill_step_size: - model(y[:prefill_step_size][None], cache=prompt_cache) - mx.eval([c.state for c in prompt_cache]) - y = y[prefill_step_size:] - mx.metal.clear_cache() + with mx.stream(generation_stream): + total_prompt_tokens = y.size + prompt_processed_tokens = 0 + while y.size > prefill_step_size: + model(y[:prefill_step_size][None], cache=prompt_cache) + mx.eval([c.state for c in prompt_cache]) + prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens) + prompt_processed_tokens += prefill_step_size + y = y[prefill_step_size:] + mx.metal.clear_cache() - y, logprobs = _step(y) + y, logprobs = _step(y) mx.async_eval(y, logprobs) n = 0 while True: - next_y, next_logprobs = _step(y) - mx.async_eval(next_y, next_logprobs) + if n != max_tokens: + next_y, next_logprobs = _step(y) + mx.async_eval(next_y, next_logprobs) + if n == 0: + mx.eval(y) + prompt_progress_callback(total_prompt_tokens, total_prompt_tokens) + if n == max_tokens: + break yield y.item(), logprobs if n % 256 == 0: mx.metal.clear_cache() - n += 1 y, logprobs = next_y, next_logprobs + n += 1 def stream_generate( model: nn.Module, tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper], - prompt: Union[str, List[int]], - max_tokens: int = 100, + prompt: Union[str, mx.array, List[int]], **kwargs, -) -> Generator[Tuple[str, int, mx.array], None, None]: +) -> Generator[GenerationResponse, None, None]: """ A generator producing text based on the given prompt from the model. Args: model (nn.Module): The model to use for generation. tokenizer (PreTrainedTokenizer): The tokenizer. - prompt (Union[str, List[int]]): The input prompt string or integer tokens. - max_tokens (int): The maximum number of tokens. Default: ``100``. + prompt (Union[str, mx.array, List[int]]): The input prompt string or integer tokens. kwargs: The remaining options get passed to :func:`generate_step`. See :func:`generate_step` for more details. Yields: - Tuple[str, int, mx.array]: - The next text segment, token, and vector of log probabilities. + GenerationResponse: An instance containing the generated text segment and + associated metadata. See :class:`GenerationResponse` for details. 
""" if not isinstance(tokenizer, TokenizerWrapper): tokenizer = TokenizerWrapper(tokenizer) - prompt_tokens = mx.array( - prompt if isinstance(prompt, list) else tokenizer.encode(prompt) - ) + if not isinstance(prompt, mx.array): + prompt = mx.array( + prompt if isinstance(prompt, list) else tokenizer.encode(prompt) + ) + detokenizer = tokenizer.detokenizer with wired_limit(model, [generation_stream]): detokenizer.reset() - for n, (token, logits) in zip( - range(max_tokens), - generate_step(prompt_tokens, model, **kwargs), - ): - if token == tokenizer.eos_token_id: + tic = time.perf_counter() + for n, (token, logprobs) in enumerate(generate_step(prompt, model, **kwargs)): + if n == 0: + prompt_time = time.perf_counter() - tic + prompt_tps = prompt.size / prompt_time + tic = time.perf_counter() + if token in tokenizer.eos_token_ids: break detokenizer.add_token(token) - if n == (max_tokens - 1): - break - - yield detokenizer.last_segment, token, logits + yield GenerationResponse( + text=detokenizer.last_segment, + token=token, + logprobs=logprobs, + prompt_tokens=prompt.size, + prompt_tps=prompt_tps, + generation_tokens=n + 1, + generation_tps=(n + 1) / (time.perf_counter() - tic), + peak_memory=mx.metal.get_peak_memory() / 1e9, + ) detokenizer.finalize() - yield detokenizer.last_segment, token, logits + yield GenerationResponse( + text=detokenizer.last_segment, + token=token, + logprobs=logprobs, + prompt_tokens=prompt.size, + prompt_tps=prompt_tps, + generation_tokens=n + 1, + generation_tps=(n + 1) / (time.perf_counter() - tic), + peak_memory=mx.metal.get_peak_memory() / 1e9, + ) def generate( model: nn.Module, tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper], prompt: str, - max_tokens: int = 100, verbose: bool = False, formatter: Optional[Callable] = None, **kwargs, @@ -331,67 +405,42 @@ def generate( model (nn.Module): The language model. tokenizer (PreTrainedTokenizer): The tokenizer. prompt (str): The string prompt. - max_tokens (int): The maximum number of tokens. Default: ``100``. verbose (bool): If ``True``, print tokens and timing information. Default: ``False``. - formatter (Optional[Callable]): A function which takes a token and a - probability and displays it. - kwargs: The remaining options get passed to :func:`generate_step`. - See :func:`generate_step` for more details. + kwargs: The remaining options get passed to :func:`stream_generate`. + See :func:`stream_generate` for more details. """ - if not isinstance(tokenizer, TokenizerWrapper): - tokenizer = TokenizerWrapper(tokenizer) - + if formatter is not None: + print( + "[Warning] Text formatting is deprecated and no longer used. " + "The argument will be removed in a future version." 
+ ) if verbose: print("=" * 10) print("Prompt:", prompt) - prompt_tokens = mx.array(tokenizer.encode(prompt)) - detokenizer = tokenizer.detokenizer - - with wired_limit(model, [generation_stream]): - tic = time.perf_counter() - detokenizer.reset() - for n, (token, logprobs) in zip( - range(max_tokens), - generate_step(prompt_tokens, model, **kwargs), - ): - if n == 0: - prompt_time = time.perf_counter() - tic - tic = time.perf_counter() - if token == tokenizer.eos_token_id: - break - detokenizer.add_token(token) - - if verbose: - if formatter: - # We have to finalize so that the prob corresponds to the last segment - detokenizer.finalize() - prob = mx.exp(logprobs[token]).item() - formatter(detokenizer.last_segment, prob) - else: - print(detokenizer.last_segment, end="", flush=True) - - token_count = n + 1 - detokenizer.finalize() - + text = "" + for response in stream_generate(model, tokenizer, prompt, **kwargs): if verbose: - gen_time = time.perf_counter() - tic - print(detokenizer.last_segment, flush=True) - print("=" * 10) - if token_count == 0: - print("No tokens generated for this prompt") - return - prompt_tps = prompt_tokens.size / prompt_time - gen_tps = (token_count - 1) / gen_time - print( - f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec" - ) - print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec") - peak_mem = mx.metal.get_peak_memory() / 1e9 - print(f"Peak memory: {peak_mem:.3f} GB") + print(response.text, end="", flush=True) + text += response.text - return detokenizer.text + if verbose: + print() + print("=" * 10) + if len(text) == 0: + print("No text generated for this prompt") + return + print( + f"Prompt: {response.prompt_tokens} tokens, " + f"{response.prompt_tps:.3f} tokens-per-sec" + ) + print( + f"Generation: {response.generation_tokens} tokens, " + f"{response.generation_tps:.3f} tokens-per-sec" + ) + print(f"Peak memory: {response.peak_memory:.3f} GB") + return text def load_config(model_path: Path) -> dict: @@ -418,11 +467,11 @@ def load_model( lazy (bool): If False eval the model parameters to make sure they are loaded in memory before returning, otherwise they will be loaded when needed. Default: ``False`` - model_config (dict, optional): Configuration parameters for the model. - Defaults to an empty dictionary. + model_config (dict, optional): Optional configuration parameters for the + model. Defaults to an empty dictionary. get_model_classes (Callable[[dict], Tuple[Type[nn.Module], Type]], optional): A function that returns the model class and model args class given a config. - Defaults to the _get_classes function. + Defaults to the ``_get_classes`` function. Returns: nn.Module: The loaded and initialized model. @@ -431,7 +480,6 @@ def load_model( FileNotFoundError: If the weight files (.safetensors) are not found. ValueError: If the model class or args class are not found or cannot be instantiated. 
""" - config = load_config(model_path) config.update(model_config) @@ -458,15 +506,20 @@ def load_model( weights = model.sanitize(weights) if (quantization := config.get("quantization", None)) is not None: - # Handle legacy models which may not have everything quantized + def class_predicate(p, m): + # Handle custom per layer quantizations + if p in config["quantization"]: + return config["quantization"][p] if not hasattr(m, "to_quantized"): return False + # Handle legacy models which may not have everything quantized return f"{p}.scales" in weights nn.quantize( model, - **quantization, + group_size=quantization["group_size"], + bits=quantization["bits"], class_predicate=class_predicate, ) @@ -476,7 +529,7 @@ def load_model( mx.eval(model.parameters()) model.eval() - return model + return model, config def load( @@ -509,11 +562,13 @@ def load( """ model_path = get_model_path(path_or_hf_repo) - model = load_model(model_path, lazy, model_config) + model, config = load_model(model_path, lazy) if adapter_path is not None: model = load_adapters(model, adapter_path) model.eval() - tokenizer = load_tokenizer(model_path, tokenizer_config) + tokenizer = load_tokenizer( + model_path, tokenizer_config, eos_token_ids=config.get("eos_token_id", None) + ) return model, tokenizer @@ -521,9 +576,10 @@ def load( def fetch_from_hub( model_path: Path, lazy: bool = False ) -> Tuple[nn.Module, dict, PreTrainedTokenizer]: - model = load_model(model_path, lazy) - config = load_config(model_path) - tokenizer = load_tokenizer(model_path) + model, config = load_model(model_path, lazy) + tokenizer = load_tokenizer( + model_path, eos_token_ids=config.get("eos_token_id", None) + ) return model, config, tokenizer @@ -669,7 +725,13 @@ def save_weights( def quantize_model( - model: nn.Module, config: dict, q_group_size: int, q_bits: int + model: nn.Module, + config: dict, + q_group_size: int, + q_bits: int, + quant_predicate: Optional[ + Callable[[str, nn.Module, dict], Union[bool, dict]] + ] = None, ) -> Tuple: """ Applies quantization to the model weights. @@ -679,17 +741,37 @@ def quantize_model( config (dict): Model configuration. q_group_size (int): Group size for quantization. q_bits (int): Bits per weight for quantization. + quant_predicate (Callable): A callable that decides how + to quantize each layer based on the path. + Accepts the layer `path`, the `module` and the model `config`. + Returns either a bool to signify quantize/no quantize or + a dict of quantization parameters to pass to `to_quantized`. Returns: Tuple: Tuple containing quantized weights and config. 
""" quantized_config = copy.deepcopy(config) - nn.quantize(model, q_group_size, q_bits) quantized_config["quantization"] = {"group_size": q_group_size, "bits": q_bits} + + # Add any custom quantization parameters to the config as we go + def _class_predicate(p, m): + bool_or_params = quant_predicate(p, m, config) + quantized_config["quantization"][p] = bool_or_params + return bool_or_params + + nn.quantize( + model, + q_group_size, + q_bits, + class_predicate=_class_predicate if quant_predicate else None, + ) # support hf model tree #957 quantized_config["quantization_config"] = quantized_config["quantization"] quantized_weights = dict(tree_flatten(model.parameters())) + bpw = compute_bits_per_weight(model) + print(f"[INFO] Quantized model with {bpw:.3f} bits per weight.") + return quantized_weights, quantized_config @@ -726,6 +808,9 @@ def convert( upload_repo: str = None, revision: Optional[str] = None, dequantize: bool = False, + quant_predicate: Optional[ + Callable[[str, nn.Module, dict], Union[bool, dict]] + ] = None, ): # Check the save path is empty if isinstance(mlx_path, str): @@ -751,7 +836,9 @@ def convert( if quantize: print("[INFO] Quantizing") model.load_weights(list(weights.items())) - weights, config = quantize_model(model, config, q_group_size, q_bits) + weights, config = quantize_model( + model, config, q_group_size, q_bits, quant_predicate=quant_predicate + ) if dequantize: print("[INFO] Dequantizing") diff --git a/llms/setup.py b/llms/setup.py index 1c696dc0..b88dcd33 100644 --- a/llms/setup.py +++ b/llms/setup.py @@ -28,12 +28,14 @@ setup( python_requires=">=3.8", extras_require={ "testing": ["datasets"], + "evaluation": ["lm-eval"], }, entry_points={ "console_scripts": [ "mlx_lm.cache_prompt = mlx_lm.cache_prompt:main", "mlx_lm.chat = mlx_lm.chat:main", "mlx_lm.convert = mlx_lm.convert:main", + "mlx_lm.evaluate = mlx_lm.evaluate:main", "mlx_lm.fuse = mlx_lm.fuse:main", "mlx_lm.generate = mlx_lm.generate:main", "mlx_lm.lora = mlx_lm.lora:main", diff --git a/llms/tests/test_generate.py b/llms/tests/test_generate.py index e0a372a9..f2345394 100644 --- a/llms/tests/test_generate.py +++ b/llms/tests/test_generate.py @@ -2,6 +2,7 @@ import unittest +from mlx_lm.sample_utils import make_logits_processors from mlx_lm.utils import generate, load @@ -25,8 +26,8 @@ class TestGenerate(unittest.TestCase): self.tokenizer, "hello", max_tokens=5, + logits_processors=make_logits_processors(logit_bias), verbose=False, - logit_bias=logit_bias, ) self.assertEqual(text, "!!!!!") diff --git a/llms/tests/test_models.py b/llms/tests/test_models.py index 1efde5ae..374a5113 100644 --- a/llms/tests/test_models.py +++ b/llms/tests/test_models.py @@ -2,7 +2,9 @@ import unittest import mlx.core as mx +import mlx.nn as nn from mlx.utils import tree_map +from mlx_lm.models import rope_utils from mlx_lm.models.cache import KVCache, RotatingKVCache, make_prompt_cache @@ -126,6 +128,26 @@ class TestModels(unittest.TestCase): self.assertEqual(cache.offset, 22) self.assertTrue(mx.allclose(x, k[..., -2:, :])) + def test_rope(self): + rope = rope_utils.initialize_rope(32, base=100, traditional=False) + self.assertTrue(isinstance(rope, nn.RoPE)) + + rope = rope_utils.initialize_rope( + 32, + base=100, + traditional=False, + scaling_config={"rope_type": "linear", "factor": 10.0}, + ) + self.assertTrue(isinstance(rope, nn.RoPE)) + + rope = rope_utils.initialize_rope( + 32, + base=100, + traditional=False, + scaling_config={"rope_type": "llama3", "factor": 2.0}, + ) + self.assertTrue(isinstance(rope, 
rope_utils.Llama3RoPE)) + def model_test_runner(self, model, model_type, vocab_size, num_layers): self.assertEqual(len(model.layers), num_layers) @@ -760,6 +782,75 @@ class TestModels(unittest.TestCase): model, args.model_type, args.vocab_size, args.num_hidden_layers ) + def test_hunyuan(self): + from mlx_lm.models import hunyuan + + args = hunyuan.ModelArgs( + model_type="hunyuan", + hidden_size=128, + attention_bias=False, + intermediate_size=256, + num_attention_heads=4, + num_hidden_layers=4, + num_key_value_heads=2, + rms_norm_eps=1e-4, + rope_theta=1000, + vocab_size=1000, + moe_topk=2, + num_experts=2, + num_shared_expert=1, + use_mixed_mlp_moe=True, + use_qk_norm=True, + rope_scaling={ + "alpha": 1000.0, + "factor": 1.0, + "type": "dynamic", + }, + use_cla=True, + cla_share_factor=2, + ) + model = hunyuan.Model(args) + self.model_test_runner( + model, args.model_type, args.vocab_size, args.num_hidden_layers + ) + + def test_olmo2(self): + from mlx_lm.models import olmo2 + + args = olmo2.ModelArgs( + model_type="olmo2", + hidden_size=128, + attention_bias=False, + intermediate_size=256, + num_attention_heads=4, + num_hidden_layers=4, + num_key_value_heads=2, + rms_norm_eps=1e-4, + rope_theta=1000, + vocab_size=1000, + ) + model = olmo2.Model(args) + self.model_test_runner( + model, args.model_type, args.vocab_size, args.num_hidden_layers + ) + + def test_exaone(self): + from mlx_lm.models import exaone + + args = exaone.ModelArgs( + model_type="exaone", + hidden_size=128, + num_layers=4, + intermediate_size=256, + num_attention_heads=8, + num_key_value_heads=2, + vocab_size=1000, + layer_norm_epsilon=1e-4, + rope_theta=10000, + ) + model = exaone.Model(args) + self.model_test_runner(model, args.model_type, args.vocab_size, args.num_layers) + if __name__ == "__main__": unittest.main() diff --git a/llms/tests/test_prompt_cache.py b/llms/tests/test_prompt_cache.py index 0867ab56..de5694d5 100644 --- a/llms/tests/test_prompt_cache.py +++ b/llms/tests/test_prompt_cache.py @@ -121,21 +121,20 @@ class TestPromptCache(unittest.TestCase): def test_cache_with_generate(self): model, tokenizer = load(HF_MODEL_PATH) prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0] - results = zip(range(4), generate_step(prompt, model)) - toks, all_logits = zip(*(r[1] for r in results)) + results = list(generate_step(prompt, model, max_tokens=4)) + toks, all_logits = zip(*results) prompt_cache = make_prompt_cache(model) i = 0 - for _, (tok, logits) in zip( - range(2), generate_step(prompt, model, prompt_cache=prompt_cache) + for tok, logits in generate_step( + prompt, model, prompt_cache=prompt_cache, max_tokens=2 ): self.assertEqual(tok, toks[i]) self.assertTrue(mx.allclose(logits, all_logits[i])) i += 1 - for _, (tok, logits) in zip( - range(1), - generate_step(mx.array([toks[i]]), model, prompt_cache=prompt_cache), + for tok, logits in generate_step( + mx.array([toks[i]]), model, prompt_cache=prompt_cache, max_tokens=1 ): i += 1 self.assertEqual(tok, toks[i]) diff --git a/llms/tests/test_sample_utils.py b/llms/tests/test_sample_utils.py index ec0e2cb7..ebc90ce8 100644 --- a/llms/tests/test_sample_utils.py +++ b/llms/tests/test_sample_utils.py @@ -1,10 +1,10 @@ import unittest import mlx.core as mx -from mlx_lm.sample_utils import top_p_sampling +from mlx_lm.sample_utils import min_p_sampling, top_p_sampling -class TestSamplingUtils(unittest.TestCase): +class TestSampleUtils(unittest.TestCase): def test_top_p_sampling(self): probs = mx.array([0.9, 0.0, 0.0, 0.1])[None] logits = 
mx.log(probs) @@ -28,6 +28,20 @@ class TestSamplingUtils(unittest.TestCase): token = top_p_sampling(logits, 0.95, temperature).item() self.assertTrue(token in (1, 2, 3)) + def test_min_p_sampling(self): + probs = mx.array([0.9, 0.0, 0.0, 0.1])[None] + logits = mx.log(probs) + temperature = 1.0 + token = min_p_sampling(logits, 0.8) + self.assertEqual(token, 0) + + probs = mx.array([0.9, 0.0, 0.0, 0.1])[None] + logits = mx.log(probs) + temperature = 1.0 + for _ in range(5): + token = min_p_sampling(logits, 0.05) + self.assertTrue(token in (0, 3)) + if __name__ == "__main__": unittest.main() diff --git a/llms/tests/test_tokenizers.py b/llms/tests/test_tokenizers.py index 9c30d51e..db6b9f9e 100644 --- a/llms/tests/test_tokenizers.py +++ b/llms/tests/test_tokenizers.py @@ -34,10 +34,11 @@ class TestTokenizers(unittest.TestCase): detokenizer = tokenizer.detokenizer detokenizer.reset() text = "" - for t in tokens: + for e, t in enumerate(tokens): detokenizer.add_token(t) seg = detokenizer.last_segment text += seg + self.assertEqual(detokenizer.tokens, tokens[: e + 1]) detokenizer.finalize() text += detokenizer.last_segment self.assertEqual(text, expected_text) diff --git a/llms/tests/test_utils_load_model.py b/llms/tests/test_utils_load_model.py index 73ee1352..5821f9e9 100644 --- a/llms/tests/test_utils_load_model.py +++ b/llms/tests/test_utils_load_model.py @@ -32,7 +32,7 @@ class TestLoadModelCustomGetClasses(unittest.TestCase): return CustomQwenModel, CustomQwenConfig model_path = get_model_path(HF_MODEL_PATH) - model = load_model(model_path, get_model_classes=custom_get_classes) + model, _ = load_model(model_path, get_model_classes=custom_get_classes) self.assertIsInstance(model, CustomQwenModel) self.assertTrue(hasattr(model, "custom_attribute")) @@ -41,7 +41,7 @@ class TestLoadModelCustomGetClasses(unittest.TestCase): def test_load_model_with_default_get_classes(self): model_path = get_model_path(HF_MODEL_PATH) - model = load_model(model_path) + model, _ = load_model(model_path) self.assertIsInstance(model, Qwen2Model) diff --git a/speechcommands/main.py b/speechcommands/main.py index 0d8da9fd..ed328f4c 100644 --- a/speechcommands/main.py +++ b/speechcommands/main.py @@ -76,6 +76,7 @@ def train_epoch(model, train_iter, optimizer, epoch): samples_per_sec = [] model.train(True) + train_iter.reset() for batch_counter, batch in enumerate(train_iter): x = mx.array(batch["audio"]) y = mx.array(batch["label"]) @@ -111,6 +112,7 @@ def test_epoch(model, test_iter): model.train(False) accs = [] throughput = [] + test_iter.reset() for batch_counter, batch in enumerate(test_iter): x = mx.array(batch["audio"]) y = mx.array(batch["label"]) diff --git a/whisper/convert.py b/whisper/convert.py index 301fd5b4..7369fafa 100644 --- a/whisper/convert.py +++ b/whisper/convert.py @@ -174,11 +174,6 @@ def load_torch_weights_and_config( "*.txt", ], ) - else: - raise RuntimeError( - f"Model {name_or_path} is not found in {available_models()}," - "on Hugging Face or as a local path." - ) if name_or_path.endswith(".pt"): checkpoint = torch.load(name_or_path, map_location="cpu", weights_only=False)
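
A minimal usage sketch (not part of the patch) of the sampler-based API that this change deprecates `temp`/`top_p` keyword arguments in favor of. The model path, prompt, and sampling values below are illustrative assumptions only.

```python
# Sketch: pass a sampler built with make_sampler instead of temp/top_p kwargs.
from mlx_lm import load, generate
from mlx_lm.sample_utils import make_sampler

# Any mlx-lm compatible model works here; this repo id is just an example.
model, tokenizer = load("mlx-community/Llama-3.2-3B-Instruct-4bit")

# Build the sampler once; generate() forwards it through stream_generate()
# down to generate_step().
sampler = make_sampler(temp=0.7, top_p=0.9)

text = generate(
    model,
    tokenizer,
    prompt="Write a haiku about the ocean.",
    max_tokens=128,
    sampler=sampler,
    verbose=True,
)
```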
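A sketch (not part of the patch) of the new `quant_predicate` hook on `convert`, which per the docstring above may return either a bool or a dict of parameters for `to_quantized`. The layer-name patterns, bit widths, and paths below are assumptions for demonstration.

```python
# Sketch: mixed-precision quantization via the quant_predicate hook.
from mlx_lm import convert

def quant_predicate(path, module, config):
    # Keep embeddings and the output head at 8 bits (illustrative choice);
    # everything else falls back to the q_bits/q_group_size passed to convert().
    if "embed_tokens" in path or "lm_head" in path:
        return {"group_size": 64, "bits": 8}
    return True

convert(
    "path/to/hf_model",   # placeholder Hugging Face repo or local path
    mlx_path="mlx_model_mixed",
    quantize=True,
    q_group_size=64,
    q_bits=4,
    quant_predicate=quant_predicate,
)
```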