mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 18:26:37 +08:00
fixes
This commit is contained in:
parent
e34ecb79b4
commit
8c0b4ee7f3
@ -131,6 +131,18 @@ def setup_arg_parser():
|
||||
type=int,
|
||||
default=DEFAULT_QUANTIZED_KV_START,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--draft-model",
|
||||
type=str,
|
||||
help="A model to be used for speculative decoding.",
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-draft-tokens",
|
||||
type=int,
|
||||
help="Number of tokens to draft when using speculative decoding.",
|
||||
default=2,
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
@ -211,11 +223,16 @@ def main():
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
prompt = prompt[test_prompt.index("<query>") :]
|
||||
|
||||
prompt = tokenizer.encode(prompt, add_special_tokens=False)
|
||||
else:
|
||||
prompt = tokenizer.encode(prompt)
|
||||
|
||||
if args.draft_model is not None:
|
||||
draft_model, draft_tokenizer = load(args.draft_model)
|
||||
if draft_tokenizer.vocab_size != tokenizer.vocab_size:
|
||||
raise ValueError("Draft model tokenizer does not match model tokenizer.")
|
||||
else:
|
||||
draft_model = None
|
||||
sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep)
|
||||
response = generate(
|
||||
model,
|
||||
@ -229,6 +246,8 @@ def main():
|
||||
kv_bits=args.kv_bits,
|
||||
kv_group_size=args.kv_group_size,
|
||||
quantized_kv_start=args.quantized_kv_start,
|
||||
draft_model=draft_model,
|
||||
num_draft_tokens=args.num_draft_tokens,
|
||||
)
|
||||
if not args.verbose:
|
||||
print(response)
|
||||
|
@ -319,6 +319,8 @@ def speculative_generate_step(
|
||||
*,
|
||||
num_draft_tokens=2,
|
||||
max_tokens: int = 256,
|
||||
sampler: Optional[Callable[mx.array, mx.array]] = None,
|
||||
logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
|
||||
prompt_cache: Optional[Any] = None,
|
||||
prefill_step_size: int = 512,
|
||||
kv_bits: Optional[int] = None,
|
||||
@ -336,6 +338,11 @@ def speculative_generate_step(
|
||||
speculative decoding. Default: ``2``.
|
||||
max_tokens (int): The maximum number of tokens. Use``-1`` for an infinite
|
||||
generator. Default: ``256``.
|
||||
sampler (Callable[mx.array, mx.array], optional): A sampler for sampling a
|
||||
token from a vector of log probabilities. Default: ``None``.
|
||||
logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
|
||||
A list of functions that take tokens and logits and return the processed
|
||||
logits. Default: ``None``.
|
||||
prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
|
||||
provided, the cache will be updated in place. The cache must be trimmable.
|
||||
prefill_step_size (int): Step size for processing the prompt.
|
||||
@ -362,6 +369,15 @@ def speculative_generate_step(
|
||||
model_cache = prompt_cache[: len(model.layers)]
|
||||
draft_cache = prompt_cache[len(model.layers) :]
|
||||
|
||||
sampler = sampler or (lambda x: mx.argmax(x, axis=-1))
|
||||
|
||||
quantize_cache_fn = functools.partial(
|
||||
maybe_quantize_kv_cache,
|
||||
quantized_kv_start=quantized_kv_start,
|
||||
kv_group_size=kv_group_size,
|
||||
kv_bits=kv_bits,
|
||||
)
|
||||
|
||||
def _step(model, cache, y, n_predict=1):
|
||||
with mx.stream(generation_stream):
|
||||
logits = model(y[None], cache=cache)
|
||||
@ -370,7 +386,7 @@ def speculative_generate_step(
|
||||
quantize_cache_fn(cache)
|
||||
|
||||
logprobs = logits - mx.logsumexp(logits, keepdims=True)
|
||||
y = mx.argmax(logprobs, axis=-1).squeeze(0)
|
||||
y = sampler(logprobs).squeeze(0)
|
||||
return y, logprobs.squeeze(0)
|
||||
|
||||
def _prefill(model, cache, y):
|
||||
@ -401,6 +417,9 @@ def speculative_generate_step(
|
||||
y = _prefill(model, model_cache, y)
|
||||
|
||||
ntoks = 0
|
||||
# Set these so the finally block doesn't raise
|
||||
num_draft = 0
|
||||
n = 0
|
||||
try:
|
||||
while True:
|
||||
num_draft = min(max_tokens - ntoks, num_draft_tokens)
|
||||
@ -484,8 +503,10 @@ def stream_generate(
|
||||
detokenizer = tokenizer.detokenizer
|
||||
|
||||
if draft_model is None:
|
||||
kwargs.pop("num_draft_tokens")
|
||||
token_generator = generate_step(prompt, model, **kwargs)
|
||||
else:
|
||||
kwargs.pop("max_kv_size")
|
||||
token_generator = speculative_generate_step(
|
||||
prompt, model, draft_model, **kwargs
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user