mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
fixes
This commit is contained in:
parent
e34ecb79b4
commit
8c0b4ee7f3
@ -131,6 +131,18 @@ def setup_arg_parser():
|
|||||||
type=int,
|
type=int,
|
||||||
default=DEFAULT_QUANTIZED_KV_START,
|
default=DEFAULT_QUANTIZED_KV_START,
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--draft-model",
|
||||||
|
type=str,
|
||||||
|
help="A model to be used for speculative decoding.",
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-draft-tokens",
|
||||||
|
type=int,
|
||||||
|
help="Number of tokens to draft when using speculative decoding.",
|
||||||
|
default=2,
|
||||||
|
)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -211,11 +223,16 @@ def main():
|
|||||||
add_generation_prompt=True,
|
add_generation_prompt=True,
|
||||||
)
|
)
|
||||||
prompt = prompt[test_prompt.index("<query>") :]
|
prompt = prompt[test_prompt.index("<query>") :]
|
||||||
|
|
||||||
prompt = tokenizer.encode(prompt, add_special_tokens=False)
|
prompt = tokenizer.encode(prompt, add_special_tokens=False)
|
||||||
else:
|
else:
|
||||||
prompt = tokenizer.encode(prompt)
|
prompt = tokenizer.encode(prompt)
|
||||||
|
|
||||||
|
if args.draft_model is not None:
|
||||||
|
draft_model, draft_tokenizer = load(args.draft_model)
|
||||||
|
if draft_tokenizer.vocab_size != tokenizer.vocab_size:
|
||||||
|
raise ValueError("Draft model tokenizer does not match model tokenizer.")
|
||||||
|
else:
|
||||||
|
draft_model = None
|
||||||
sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep)
|
sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep)
|
||||||
response = generate(
|
response = generate(
|
||||||
model,
|
model,
|
||||||
@ -229,6 +246,8 @@ def main():
|
|||||||
kv_bits=args.kv_bits,
|
kv_bits=args.kv_bits,
|
||||||
kv_group_size=args.kv_group_size,
|
kv_group_size=args.kv_group_size,
|
||||||
quantized_kv_start=args.quantized_kv_start,
|
quantized_kv_start=args.quantized_kv_start,
|
||||||
|
draft_model=draft_model,
|
||||||
|
num_draft_tokens=args.num_draft_tokens,
|
||||||
)
|
)
|
||||||
if not args.verbose:
|
if not args.verbose:
|
||||||
print(response)
|
print(response)
|
||||||
|
@ -319,6 +319,8 @@ def speculative_generate_step(
|
|||||||
*,
|
*,
|
||||||
num_draft_tokens=2,
|
num_draft_tokens=2,
|
||||||
max_tokens: int = 256,
|
max_tokens: int = 256,
|
||||||
|
sampler: Optional[Callable[mx.array, mx.array]] = None,
|
||||||
|
logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
|
||||||
prompt_cache: Optional[Any] = None,
|
prompt_cache: Optional[Any] = None,
|
||||||
prefill_step_size: int = 512,
|
prefill_step_size: int = 512,
|
||||||
kv_bits: Optional[int] = None,
|
kv_bits: Optional[int] = None,
|
||||||
@ -336,6 +338,11 @@ def speculative_generate_step(
|
|||||||
speculative decoding. Default: ``2``.
|
speculative decoding. Default: ``2``.
|
||||||
max_tokens (int): The maximum number of tokens. Use ``-1`` for an infinite
|
max_tokens (int): The maximum number of tokens. Use ``-1`` for an infinite
|
||||||
generator. Default: ``256``.
|
generator. Default: ``256``.
|
||||||
|
sampler (Callable[mx.array, mx.array], optional): A sampler for sampling a
|
||||||
|
token from a vector of log probabilities. Default: ``None``.
|
||||||
|
logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
|
||||||
|
A list of functions that take tokens and logits and return the processed
|
||||||
|
logits. Default: ``None``.
|
||||||
prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
|
prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
|
||||||
provided, the cache will be updated in place. The cache must be trimmable.
|
provided, the cache will be updated in place. The cache must be trimmable.
|
||||||
prefill_step_size (int): Step size for processing the prompt.
|
prefill_step_size (int): Step size for processing the prompt.
|
||||||
@ -362,6 +369,15 @@ def speculative_generate_step(
|
|||||||
model_cache = prompt_cache[: len(model.layers)]
|
model_cache = prompt_cache[: len(model.layers)]
|
||||||
draft_cache = prompt_cache[len(model.layers) :]
|
draft_cache = prompt_cache[len(model.layers) :]
|
||||||
|
|
||||||
|
sampler = sampler or (lambda x: mx.argmax(x, axis=-1))
|
||||||
|
|
||||||
|
quantize_cache_fn = functools.partial(
|
||||||
|
maybe_quantize_kv_cache,
|
||||||
|
quantized_kv_start=quantized_kv_start,
|
||||||
|
kv_group_size=kv_group_size,
|
||||||
|
kv_bits=kv_bits,
|
||||||
|
)
|
||||||
|
|
||||||
def _step(model, cache, y, n_predict=1):
|
def _step(model, cache, y, n_predict=1):
|
||||||
with mx.stream(generation_stream):
|
with mx.stream(generation_stream):
|
||||||
logits = model(y[None], cache=cache)
|
logits = model(y[None], cache=cache)
|
||||||
@ -370,7 +386,7 @@ def speculative_generate_step(
|
|||||||
quantize_cache_fn(cache)
|
quantize_cache_fn(cache)
|
||||||
|
|
||||||
logprobs = logits - mx.logsumexp(logits, keepdims=True)
|
logprobs = logits - mx.logsumexp(logits, keepdims=True)
|
||||||
y = mx.argmax(logprobs, axis=-1).squeeze(0)
|
y = sampler(logprobs).squeeze(0)
|
||||||
return y, logprobs.squeeze(0)
|
return y, logprobs.squeeze(0)
|
||||||
|
|
||||||
def _prefill(model, cache, y):
|
def _prefill(model, cache, y):
|
||||||
@ -401,6 +417,9 @@ def speculative_generate_step(
|
|||||||
y = _prefill(model, model_cache, y)
|
y = _prefill(model, model_cache, y)
|
||||||
|
|
||||||
ntoks = 0
|
ntoks = 0
|
||||||
|
# Set these so the finally block doesn't raise
|
||||||
|
num_draft = 0
|
||||||
|
n = 0
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
num_draft = min(max_tokens - ntoks, num_draft_tokens)
|
num_draft = min(max_tokens - ntoks, num_draft_tokens)
|
||||||
@ -484,8 +503,10 @@ def stream_generate(
|
|||||||
detokenizer = tokenizer.detokenizer
|
detokenizer = tokenizer.detokenizer
|
||||||
|
|
||||||
if draft_model is None:
|
if draft_model is None:
|
||||||
|
kwargs.pop("num_draft_tokens")
|
||||||
token_generator = generate_step(prompt, model, **kwargs)
|
token_generator = generate_step(prompt, model, **kwargs)
|
||||||
else:
|
else:
|
||||||
|
kwargs.pop("max_kv_size")
|
||||||
token_generator = speculative_generate_step(
|
token_generator = speculative_generate_step(
|
||||||
prompt, model, draft_model, **kwargs
|
prompt, model, draft_model, **kwargs
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user