Merge branch 'ml-explore:main' into adding-dpo-training

commit ab35c87911
Gökdeniz Gülmez, 2025-02-18 17:17:49 +01:00 (committed by GitHub)
2 changed files with 50 additions and 7 deletions


@@ -123,6 +123,18 @@ for response in stream_generate(model, tokenizer, prompt, max_tokens=512):
     print()
 ```
 
+#### Sampling
+
+The `generate` and `stream_generate` functions accept `sampler` and
+`logits_processors` keyword arguments. A sampler is any callable which accepts
+a possibly batched logits array and returns an array of sampled tokens. The
+`logits_processors` must be a list of callables which take the token history
+and current logits as input and return the processed logits. The logits
+processors are applied in order.
+
+Some standard sampling functions and logits processors are provided in
+`mlx_lm.sample_utils`.
+
 ### Command Line
 
 You can also use `mlx-lm` from the command line with:
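
For reference, the `sampler` / `logits_processors` interface documented in the hunk above can be exercised roughly as follows. This is a minimal sketch, not part of the commit; the model name is only an example, and `make_sampler` / `make_logits_processors` are the helpers from `mlx_lm.sample_utils` mentioned in the added text:

```python
from mlx_lm import load, generate
from mlx_lm.sample_utils import make_logits_processors, make_sampler

# Example model; any mlx-lm compatible model repo works here.
model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

# Temperature/top-p sampling plus a repetition penalty over the token history.
sampler = make_sampler(temp=0.7, top_p=0.9)
logits_processors = make_logits_processors(repetition_penalty=1.1)

text = generate(
    model,
    tokenizer,
    prompt="Write a short story about Einstein.",
    sampler=sampler,
    logits_processors=logits_processors,
    max_tokens=256,
)
print(text)
```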


@@ -382,8 +382,8 @@ def speculative_generate_step(
         and a bool indicating if the token was generated by the draft model
     """
 
-    y = prompt
-    tokens = None
+    y = prompt.astype(mx.uint32)
+    prev_tokens = None
 
     # Create the KV cache for generation
     if prompt_cache is None:
@@ -404,17 +404,41 @@ def speculative_generate_step(
         kv_bits=kv_bits,
     )
 
+    def _process_and_sample(tokens, logits):
+        if logits_processors:
+            for processor in logits_processors:
+                logits = processor(tokens, logits)
+        logprobs = logits - mx.logsumexp(logits, keepdims=True)
+        logprobs = logprobs.squeeze(0)
+        y = sampler(logprobs)
+        return y, logprobs
+
     def _step(model, cache, y, n_predict=1):
         with mx.stream(generation_stream):
             logits = model(y[None], cache=cache)
             logits = logits[:, -n_predict:, :]
 
             quantize_cache_fn(cache)
 
-            logprobs = logits - mx.logsumexp(logits, keepdims=True)
-            logprobs = logprobs.squeeze(0)
-            y = sampler(logprobs)
-            return y, logprobs
+            if logits_processors:
+                nonlocal prev_tokens
+                out_y, out_logprobs = [], []
+                if n_predict > 1:
+                    y = y[: -(n_predict - 1)]
+                for i in range(n_predict):
+                    prev_tokens = (
+                        mx.concat([prev_tokens, y]) if prev_tokens is not None else y
+                    )
+                    y, logprobs = _process_and_sample(
+                        prev_tokens, logits[:, i : i + 1, :]
+                    )
+                    out_y.append(y)
+                    out_logprobs.append(logprobs)
+                return mx.concatenate(out_y, axis=0), mx.concatenate(
+                    out_logprobs, axis=0
+                )
+            else:
+                return _process_and_sample(None, logits)
 
     def _prefill(model, cache, y):
         while y.size > prefill_step_size:
@@ -451,9 +475,14 @@ def speculative_generate_step(
         while True:
             num_draft = min(max_tokens - ntoks, num_draft_tokens)
             draft_tokens = _draft_generate(draft_y, num_draft)
+            if prev_tokens is not None:
+                prev_tokens = prev_tokens[
+                    : prev_tokens.size - draft_y.size - num_draft + 1
+                ]
+
             y = mx.concatenate([y, draft_tokens])
 
             tokens, logprobs = _step(model, model_cache, y, num_draft + 1)
             mx.eval(tokens, draft_tokens)
             draft_tokens = draft_tokens.tolist()
             tokens = tokens.tolist()
@@ -485,6 +514,8 @@ def speculative_generate_step(
                     [mx.array(draft_tokens[-1:], mx.uint32), draft_y]
                 )
 
+            if prev_tokens is not None and n < num_draft:
+                prev_tokens = prev_tokens[: -(num_draft - n)]
             _rewind_cache(num_draft, n)
     finally:
         _rewind_cache(num_draft, n)
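
With the `prev_tokens` bookkeeping above, `speculative_generate_step` can feed the accumulated token history to each logits processor. For illustration, a processor compatible with the `(token_history, logits)` contract could look like this minimal sketch (the `no_immediate_repeat` helper is hypothetical, not part of this change):

```python
import mlx.core as mx

def no_immediate_repeat(tokens, logits):
    # Hypothetical logits processor: heavily penalize sampling the most
    # recent token again. `tokens` is the token history (None when there is
    # no history yet) and `logits` are the current logits; the processed
    # logits are returned.
    if tokens is None or tokens.size == 0:
        return logits
    last = tokens[-1].item()
    mask = mx.arange(logits.shape[-1]) == last
    return mx.where(mask, logits - 1e9, logits)
```

A processor like this would be passed as `logits_processors=[no_immediate_repeat]` to `generate` or `stream_generate`, exactly as described in the documentation change in the first file of this diff.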