From 8c0b4ee7f347d002bd89672509793592bb5af253 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 13 Dec 2024 20:21:34 -0800 Subject: [PATCH] fixes --- llms/mlx_lm/generate.py | 21 ++++++++++++++++++++- llms/mlx_lm/utils.py | 23 ++++++++++++++++++++++- 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py index 26481d6b..0d286c75 100644 --- a/llms/mlx_lm/generate.py +++ b/llms/mlx_lm/generate.py @@ -131,6 +131,18 @@ def setup_arg_parser(): type=int, default=DEFAULT_QUANTIZED_KV_START, ) + parser.add_argument( + "--draft-model", + type=str, + help="A model to be used for speculative decoding.", + default=None, + ) + parser.add_argument( + "--num-draft-tokens", + type=int, + help="Number of tokens to draft when using speculative decoding.", + default=2, + ) return parser @@ -211,11 +223,16 @@ def main(): add_generation_prompt=True, ) prompt = prompt[test_prompt.index("<query>") :] - prompt = tokenizer.encode(prompt, add_special_tokens=False) else: prompt = tokenizer.encode(prompt) + if args.draft_model is not None: + draft_model, draft_tokenizer = load(args.draft_model) + if draft_tokenizer.vocab_size != tokenizer.vocab_size: + raise ValueError("Draft model tokenizer does not match model tokenizer.") + else: + draft_model = None sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep) response = generate( model, @@ -229,6 +246,8 @@ def main(): kv_bits=args.kv_bits, kv_group_size=args.kv_group_size, quantized_kv_start=args.quantized_kv_start, + draft_model=draft_model, + num_draft_tokens=args.num_draft_tokens, ) if not args.verbose: print(response) diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index cc420bf0..421e076b 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -319,6 +319,8 @@ def speculative_generate_step( *, num_draft_tokens=2, max_tokens: int = 256, + sampler: Optional[Callable[mx.array, mx.array]] = None, + logits_processors: 
Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None, prompt_cache: Optional[Any] = None, prefill_step_size: int = 512, kv_bits: Optional[int] = None, @@ -336,6 +338,11 @@ def speculative_generate_step( speculative decoding. Default: ``2``. max_tokens (int): The maximum number of tokens. Use``-1`` for an infinite generator. Default: ``256``. + sampler (Callable[mx.array, mx.array], optional): A sampler for sampling a + token from a vector of log probabilities. Default: ``None``. + logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional): + A list of functions that take tokens and logits and return the processed + logits. Default: ``None``. prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if provided, the cache will be updated in place. The cache must be trimmable. prefill_step_size (int): Step size for processing the prompt. @@ -362,6 +369,15 @@ def speculative_generate_step( model_cache = prompt_cache[: len(model.layers)] draft_cache = prompt_cache[len(model.layers) :] + sampler = sampler or (lambda x: mx.argmax(x, axis=-1)) + + quantize_cache_fn = functools.partial( + maybe_quantize_kv_cache, + quantized_kv_start=quantized_kv_start, + kv_group_size=kv_group_size, + kv_bits=kv_bits, + ) + def _step(model, cache, y, n_predict=1): with mx.stream(generation_stream): logits = model(y[None], cache=cache) @@ -370,7 +386,7 @@ def speculative_generate_step( quantize_cache_fn(cache) logprobs = logits - mx.logsumexp(logits, keepdims=True) - y = mx.argmax(logprobs, axis=-1).squeeze(0) + y = sampler(logprobs).squeeze(0) return y, logprobs.squeeze(0) def _prefill(model, cache, y): @@ -401,6 +417,9 @@ def speculative_generate_step( y = _prefill(model, model_cache, y) ntoks = 0 + # Set these so the finally block doesn't raise + num_draft = 0 + n = 0 try: while True: num_draft = min(max_tokens - ntoks, num_draft_tokens) @@ -484,8 +503,10 @@ def stream_generate( detokenizer = tokenizer.detokenizer if draft_model is None: + 
kwargs.pop("num_draft_tokens") token_generator = generate_step(prompt, model, **kwargs) else: + kwargs.pop("max_kv_size") token_generator = speculative_generate_step( prompt, model, draft_model, **kwargs )