From 7b07b14e6742b2c162a20587f15abed972618c02 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Thu, 13 Feb 2025 19:19:53 -0800
Subject: [PATCH 1/2] add logits processor to spec gen (#1260)

---
 llms/mlx_lm/utils.py | 45 ++++++++++++++++++++++++++++++++++++++-------
 1 file changed, 38 insertions(+), 7 deletions(-)

diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 64813123..78a2e802 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -382,8 +382,8 @@ def speculative_generate_step(
           and a bool indicating if the token was generated by the draft model
     """
 
-    y = prompt
-    tokens = None
+    y = prompt.astype(mx.uint32)
+    prev_tokens = None
 
     # Create the KV cache for generation
     if prompt_cache is None:
@@ -404,17 +404,41 @@
         kv_bits=kv_bits,
     )
 
+    def _process_and_sample(tokens, logits):
+        if logits_processors:
+            for processor in logits_processors:
+                logits = processor(tokens, logits)
+
+        logprobs = logits - mx.logsumexp(logits, keepdims=True)
+        logprobs = logprobs.squeeze(0)
+        y = sampler(logprobs)
+        return y, logprobs
+
     def _step(model, cache, y, n_predict=1):
         with mx.stream(generation_stream):
             logits = model(y[None], cache=cache)
             logits = logits[:, -n_predict:, :]
 
             quantize_cache_fn(cache)
-
-            logprobs = logits - mx.logsumexp(logits, keepdims=True)
-            logprobs = logprobs.squeeze(0)
-            y = sampler(logprobs)
-            return y, logprobs
+            if logits_processors:
+                nonlocal prev_tokens
+                out_y, out_logprobs = [], []
+                if n_predict > 1:
+                    y = y[: -(n_predict - 1)]
+                for i in range(n_predict):
+                    prev_tokens = (
+                        mx.concat([prev_tokens, y]) if prev_tokens is not None else y
+                    )
+                    y, logprobs = _process_and_sample(
+                        prev_tokens, logits[:, i : i + 1, :]
+                    )
+                    out_y.append(y)
+                    out_logprobs.append(logprobs)
+                return mx.concatenate(out_y, axis=0), mx.concatenate(
+                    out_logprobs, axis=0
+                )
+            else:
+                return _process_and_sample(None, logits)
 
     def _prefill(model, cache, y):
         while y.size > prefill_step_size:
@@ -451,9 +475,14 @@
         while True:
             num_draft = min(max_tokens - ntoks, num_draft_tokens)
             draft_tokens = _draft_generate(draft_y, num_draft)
+            if prev_tokens is not None:
+                prev_tokens = prev_tokens[
+                    : prev_tokens.size - draft_y.size - num_draft + 1
+                ]
 
             y = mx.concatenate([y, draft_tokens])
 
             tokens, logprobs = _step(model, model_cache, y, num_draft + 1)
+            mx.eval(tokens, draft_tokens)
             draft_tokens = draft_tokens.tolist()
             tokens = tokens.tolist()
@@ -485,6 +514,8 @@
                     [mx.array(draft_tokens[-1:], mx.uint32), draft_y]
                 )
 
+            if prev_tokens is not None and n < num_draft:
+                prev_tokens = prev_tokens[: -(num_draft - n)]
             _rewind_cache(num_draft, n)
     finally:
         _rewind_cache(num_draft, n)
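For context, each entry in `logits_processors` threaded through `_process_and_sample` above is a callable that takes the token history and the current logits and returns processed logits, the same contract documented in the README patch below. The following is a minimal sketch of such a processor, assuming only `mlx.core`; the `ban_token_processor` helper and its masking behaviour are illustrative and not part of this patch.

```python
import mlx.core as mx


def ban_token_processor(banned_id: int):
    """Build a logits processor that masks out a single token id."""

    def processor(tokens: mx.array, logits: mx.array) -> mx.array:
        # `tokens` is the token history accumulated so far and `logits`
        # holds the scores for the position being sampled. Setting the
        # banned id to -inf keeps the sampler from ever selecting it.
        mask = mx.arange(logits.shape[-1]) == banned_id
        return mx.where(mask, mx.array(float("-inf")), logits)

    return processor


# Processors are applied in order; pass them as a list.
logits_processors = [ban_token_processor(0)]
```

Because `_step` passes the accumulated `prev_tokens` as the history, history-dependent processors such as repetition penalties also work during speculative decoding.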
From 96bf37008e91de86538bdacf3a12a479a322902b Mon Sep 17 00:00:00 2001
From: Matthias Neumayer
Date: Fri, 14 Feb 2025 04:32:56 +0100
Subject: [PATCH 2/2] Update README.md to include how to set temperature (#1280)

* Update README.md to include how to set temperature

* nits

---------

Co-authored-by: Awni Hannun
---
 llms/README.md | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/llms/README.md b/llms/README.md
index 4f7451c1..e2d1db59 100644
--- a/llms/README.md
+++ b/llms/README.md
@@ -123,6 +123,18 @@ for response in stream_generate(model, tokenizer, prompt, max_tokens=512):
 print()
 ```
 
+#### Sampling
+
+The `generate` and `stream_generate` functions accept `sampler` and
+`logits_processors` keyword arguments. A sampler is any callable which accepts
+a possibly batched logits array and returns an array of sampled tokens. The
+`logits_processors` must be a list of callables which take the token history
+and current logits as input and return the processed logits. The logits
+processors are applied in order.
+
+Some standard sampling functions and logits processors are provided in
+`mlx_lm.sample_utils`.
+
 ### Command Line
 
 You can also use `mlx-lm` from the command line with:
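To make the sampling section concrete, below is a minimal sketch of setting the temperature through a sampler and adding a repetition penalty through a logits processor. It assumes the `make_sampler` and `make_logits_processors` helpers from `mlx_lm.sample_utils`; the model name and parameter values are only illustrative.

```python
from mlx_lm import load, generate
from mlx_lm.sample_utils import make_logits_processors, make_sampler

# Any mlx-lm compatible checkpoint works here; this one is illustrative.
model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

messages = [{"role": "user", "content": "Write a story about Einstein"}]
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)

# Temperature and nucleus sampling are configured through the sampler;
# the repetition penalty is applied as a logits processor.
sampler = make_sampler(temp=0.7, top_p=0.9)
logits_processors = make_logits_processors(repetition_penalty=1.1)

text = generate(
    model,
    tokenizer,
    prompt=prompt,
    max_tokens=256,
    sampler=sampler,
    logits_processors=logits_processors,
    verbose=True,
)
```

`stream_generate` accepts the same two keyword arguments, so the sampler and processor list can be reused unchanged for streaming.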