Merge branch 'ml-explore:main' into adding-dpo-training
commit ab35c87911
@@ -123,6 +123,18 @@ for response in stream_generate(model, tokenizer, prompt, max_tokens=512):
 print()
 ```
 
+#### Sampling
+
+The `generate` and `stream_generate` functions accept `sampler` and
+`logits_processors` keyword arguments. A sampler is any callable which accepts
+a possibly batched logits array and returns an array of sampled tokens. The
+`logits_processors` must be a list of callables which take the token history
+and current logits as input and return the processed logits. The logits
+processors are applied in order.
+
+Some standard sampling functions and logits processors are provided in
+`mlx_lm.sample_utils`.
+
 ### Command Line
 
 You can also use `mlx-lm` from the command line with:
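To make the contract described above concrete, here is a minimal sketch of a custom sampler and logits processor passed to `generate`. The helpers `temperature_sampler` and `repetition_penalty` are invented for this example (ready-made equivalents live in `mlx_lm.sample_utils`), and the model name is just an illustrative choice; it assumes an `mlx-lm` version that forwards these keyword arguments as the README states.

```python
import mlx.core as mx
from mlx_lm import load, generate

def temperature_sampler(logprobs: mx.array, temp: float = 0.8) -> mx.array:
    # A sampler: takes (possibly batched) log-probabilities, returns sampled tokens.
    return mx.random.categorical(logprobs * (1 / temp))

def repetition_penalty(tokens: mx.array, logits: mx.array, penalty: float = 1.1) -> mx.array:
    # A logits processor: takes the token history and the current logits,
    # returns the processed logits. Processors in the list are applied in order.
    if tokens is None or tokens.size == 0:
        return logits
    selected = logits[:, tokens]
    logits[:, tokens] = mx.where(selected < 0, selected * penalty, selected / penalty)
    return logits

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
text = generate(
    model,
    tokenizer,
    prompt="Write a haiku about autumn.",
    max_tokens=128,
    sampler=temperature_sampler,
    logits_processors=[repetition_penalty],
)
print(text)
```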
@@ -382,8 +382,8 @@ def speculative_generate_step(
           and a bool indicating if the token was generated by the draft model
     """
 
-    y = prompt
-    tokens = None
+    y = prompt.astype(mx.uint32)
+    prev_tokens = None
 
     # Create the KV cache for generation
     if prompt_cache is None:
@@ -404,17 +404,41 @@ def speculative_generate_step(
             kv_bits=kv_bits,
         )
 
+    def _process_and_sample(tokens, logits):
+        if logits_processors:
+            for processor in logits_processors:
+                logits = processor(tokens, logits)
+
+        logprobs = logits - mx.logsumexp(logits, keepdims=True)
+        logprobs = logprobs.squeeze(0)
+        y = sampler(logprobs)
+        return y, logprobs
+
     def _step(model, cache, y, n_predict=1):
         with mx.stream(generation_stream):
             logits = model(y[None], cache=cache)
             logits = logits[:, -n_predict:, :]
 
             quantize_cache_fn(cache)
 
-            logprobs = logits - mx.logsumexp(logits, keepdims=True)
-            logprobs = logprobs.squeeze(0)
-            y = sampler(logprobs)
-            return y, logprobs
+            if logits_processors:
+                nonlocal prev_tokens
+                out_y, out_logprobs = [], []
+                if n_predict > 1:
+                    y = y[: -(n_predict - 1)]
+                for i in range(n_predict):
+                    prev_tokens = (
+                        mx.concat([prev_tokens, y]) if prev_tokens is not None else y
+                    )
+                    y, logprobs = _process_and_sample(
+                        prev_tokens, logits[:, i : i + 1, :]
+                    )
+                    out_y.append(y)
+                    out_logprobs.append(logprobs)
+                return mx.concatenate(out_y, axis=0), mx.concatenate(
+                    out_logprobs, axis=0
+                )
+            else:
+                return _process_and_sample(None, logits)
 
     def _prefill(model, cache, y):
         while y.size > prefill_step_size:
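As a rough illustration of what the new `_step` branch does when `logits_processors` are supplied, the standalone sketch below (dummy logits, a pass-through processor, greedy sampling; none of it is from the commit) walks the `n_predict` verified positions one at a time, growing the token history handed to the processors by one token per position, which is the role `prev_tokens` plays above.

```python
import mlx.core as mx

def passthrough(tokens, logits):
    # Stand-in logits processor: receives the token history and current logits.
    return logits

greedy = lambda logprobs: mx.argmax(logprobs, axis=-1)

vocab, n_predict = 8, 3
y = mx.array([1, 2, 3, 4], dtype=mx.uint32)       # last verified token + draft tokens
logits = mx.random.normal((1, n_predict, vocab))  # stand-in for the model output

prev_tokens, out = None, []
# Drop the trailing draft tokens so the history only ever contains tokens that
# precede the position currently being processed (mirrors `y = y[: -(n_predict - 1)]`).
y = y[: -(n_predict - 1)]
for i in range(n_predict):
    prev_tokens = mx.concatenate([prev_tokens, y]) if prev_tokens is not None else y
    step_logits = passthrough(prev_tokens, logits[:, i : i + 1, :])
    logprobs = step_logits - mx.logsumexp(step_logits, axis=-1, keepdims=True)
    y = greedy(logprobs.squeeze(0))               # shape (1,): the sampled token
    out.append(y)
print(mx.concatenate(out, axis=0))                # one token per verified position
```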
@@ -451,9 +475,14 @@ def speculative_generate_step(
         while True:
             num_draft = min(max_tokens - ntoks, num_draft_tokens)
             draft_tokens = _draft_generate(draft_y, num_draft)
+            if prev_tokens is not None:
+                prev_tokens = prev_tokens[
+                    : prev_tokens.size - draft_y.size - num_draft + 1
+                ]
             y = mx.concatenate([y, draft_tokens])
 
             tokens, logprobs = _step(model, model_cache, y, num_draft + 1)
 
             mx.eval(tokens, draft_tokens)
             draft_tokens = draft_tokens.tolist()
             tokens = tokens.tolist()
@@ -485,6 +514,8 @@ def speculative_generate_step(
                     [mx.array(draft_tokens[-1:], mx.uint32), draft_y]
                 )
 
+            if prev_tokens is not None and n < num_draft:
+                prev_tokens = prev_tokens[: -(num_draft - n)]
             _rewind_cache(num_draft, n)
     finally:
         _rewind_cache(num_draft, n)
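A detail that is easy to miss in the two hunks above: `prev_tokens` is the token history handed to the logits processors, so whenever drafted tokens are rejected it has to be rolled back in step with the KV cache rewind. The toy snippet below, with made-up numbers that are not from the commit, only illustrates the trimming arithmetic of `prev_tokens[: -(num_draft - n)]`.

```python
import mlx.core as mx

# Toy history: 4 accepted tokens followed by num_draft = 4 drafted proposals.
prev_tokens = mx.array([11, 12, 13, 14, 21, 22, 23, 24], dtype=mx.uint32)
num_draft, n = 4, 2  # the target model accepted only n = 2 of the 4 draft tokens

if n < num_draft:
    # The last (num_draft - n) = 2 entries never became output, so drop them
    # before the next round of logits processing.
    prev_tokens = prev_tokens[: -(num_draft - n)]

print(prev_tokens)  # the rejected draft tokens 23 and 24 are gone
```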