Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-07-18 16:31:12 +08:00
Allow prompt callback to generate_step (#1133)

* allow prompt callback and use in cache_prompt
* nit
* comments
* bump version

parent 0ca162cfb2
commit 1963df8565
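
Below, for orientation, is a minimal usage sketch of the new prompt progress callback (not part of the commit; the model path and prompt text are placeholders). Per the signature and docstring added in this diff, generate_step now accepts prompt_progress_callback and calls it with the number of prompt tokens processed so far and the total prompt token count: once per prefill chunk while the prompt is evaluated, and a final time once the whole prompt has been consumed.

# Hypothetical sketch; the model path below is a placeholder, not from this commit.
import mlx.core as mx

from mlx_lm.utils import generate_step, load

model, tokenizer = load("mlx-community/some-model")  # placeholder model path
prompt = mx.array(tokenizer.encode("an example prompt"))

def report_progress(processed, total):
    # Receives (prompt tokens processed so far, total prompt tokens).
    print(f"\rPrefill: {processed}/{total}", end="", flush=True)

tokens = []
for token, _logprobs in generate_step(
    prompt,
    model,
    max_tokens=32,
    prompt_progress_callback=report_progress,
):
    tokens.append(token)

print()
print(tokenizer.decode(tokens))
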
@@ -1,3 +1,3 @@
 # Copyright © 2023-2024 Apple Inc.
 
-__version__ = "0.20.1"
+__version__ = "0.20.2"
@@ -8,7 +8,7 @@ import time
 import mlx.core as mx
 
 from .models.cache import make_prompt_cache, save_prompt_cache
-from .utils import load, maybe_quantize_kv_cache
+from .utils import generate_step, load
 
 DEFAULT_QUANTIZED_KV_START = 5000
 
@@ -50,12 +50,6 @@ def setup_arg_parser():
         action="store_true",
         help="Use the default chat template",
     )
-    parser.add_argument(
-        "--cache-limit-gb",
-        type=int,
-        default=None,
-        help="Set the MLX cache limit in GB",
-    )
     parser.add_argument(
         "--max-kv-size",
         type=int,
@@ -99,9 +93,6 @@ def main():
     parser = setup_arg_parser()
     args = parser.parse_args()
 
-    if args.cache_limit_gb is not None:
-        mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024)
-
     # Building tokenizer_config
     tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}
     if args.eos_token is not None:
@@ -144,26 +135,28 @@ def main():
     y = mx.array(tokenizer.encode(prompt))
 
     # Process the prompt
-    processed = 0
-    step_size = 512
     start = time.time()
     max_msg_len = 0
-    while y.size > 0:
 
-        model(y[:step_size][None], cache=cache)
-        mx.eval([c.state for c in cache])
-        mx.metal.clear_cache()
-        processed += min(y.size, step_size)
-        y = y[step_size:]
+    def callback(processed, total_tokens):
         current = time.time()
         speed = processed / (current - start)
         msg = f"\rProcessed {processed:6d} tokens ({speed:6.2f} tok/s)"
+        nonlocal max_msg_len
         max_msg_len = max(max_msg_len, len(msg))
         print(msg + " " * (max_msg_len - len(msg)), end="", flush=True)
 
-        maybe_quantize_kv_cache(
-            cache, args.quantized_kv_start, args.kv_group_size, args.kv_bits
-        )
+    for _ in generate_step(
+        y,
+        model,
+        max_tokens=0,
+        prompt_cache=cache,
+        kv_bits=args.kv_bits,
+        kv_group_size=args.kv_group_size,
+        quantized_kv_start=args.quantized_kv_start,
+        prompt_progress_callback=callback,
+    ):
+        pass
 
     print()
     print(f"Peak memory: {mx.metal.get_peak_memory() / 1e9:.3f} GB")
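
The hunk above replaces cache_prompt.py's hand-rolled 512-token prefill loop: the prompt is now pushed through generate_step with max_tokens=0, which fills `cache` in place, forwards the KV-quantization settings as keyword arguments, and reports progress through the local callback, so the script only has to drain a generator that yields no tokens.
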
@@ -77,7 +77,7 @@ def setup_arg_parser():
     )
     parser.add_argument(
         "--min-tokens-to-keep",
-        type=float,
+        type=int,
         default=DEFAULT_MIN_TOKENS_TO_KEEP,
         help="Minimum tokens to keep for min-p sampling.",
     )
@@ -183,6 +183,7 @@ def generate_step(
     prompt: mx.array,
     model: nn.Module,
     *,
+    max_tokens: int = 256,
     sampler: Optional[Callable[mx.array, mx.array]] = None,
     logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
     max_kv_size: Optional[int] = None,
@@ -191,6 +192,7 @@ def generate_step(
     kv_bits: Optional[int] = None,
     kv_group_size: int = 64,
     quantized_kv_start: int = 0,
+    prompt_progress_callback: Optional[Callable[int, int]] = None,
     temp: Optional[float] = None,
     repetition_penalty: Optional[float] = None,
     repetition_context_size: Optional[int] = None,
@@ -204,21 +206,25 @@ def generate_step(
     Args:
         prompt (mx.array): The input prompt.
         model (nn.Module): The model to use for generation.
-        prefill_step_size (int): Step size for processing the prompt.
-        max_kv_size (int, optional): Maximum size of the key-value cache. Old
-          entries (except the first 4 tokens) will be overwritten.
-        prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
-          provided, the cache will be updated in place.
+        max_tokens (int): The maximum number of tokens. Use ``-1`` for an infinite
+          generator. Default: ``256``.
         sampler (Callable[mx.array, mx.array], optional): A sampler for sampling a
           token from a vector of log probabilities. Default: ``None``.
         logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
           A list of functions that take tokens and logits and return the processed
           logits. Default: ``None``.
+        max_kv_size (int, optional): Maximum size of the key-value cache. Old
+          entries (except the first 4 tokens) will be overwritten.
+        prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
+          provided, the cache will be updated in place.
+        prefill_step_size (int): Step size for processing the prompt.
         kv_bits (int, optional): Number of bits to use for KV cache quantization.
           None implies no cache quantization. Default: ``None``.
         kv_group_size (int): Group size for KV cache quantization. Default: ``64``.
         quantized_kv_start (int): Step to begin using a quantized KV cache.
          when ``kv_bits`` is non-None. Default: ``0``.
+        prompt_progress_callback (Callable[int, int]): A callback which takes the
+          prompt tokens processed so far and the total number of prompt tokens.
 
     Yields:
         Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
@@ -253,6 +259,7 @@ def generate_step(
     logits_processors = logits_processors or make_logits_processors(
         None, repetition_penalty, repetition_context_size or 20
     )
+    prompt_progress_callback = prompt_progress_callback or (lambda *_: None)
 
     def _step(y):
         with mx.stream(generation_stream):
@@ -275,9 +282,13 @@ def generate_step(
             return y, logprobs.squeeze(0)
 
     with mx.stream(generation_stream):
+        total_prompt_tokens = y.size
+        prompt_processed_tokens = 0
         while y.size > prefill_step_size:
             model(y[:prefill_step_size][None], cache=prompt_cache)
             mx.eval([c.state for c in prompt_cache])
+            prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens)
+            prompt_processed_tokens += prefill_step_size
             y = y[prefill_step_size:]
             mx.metal.clear_cache()
 
@@ -286,20 +297,25 @@ def generate_step(
     mx.async_eval(y, logprobs)
     n = 0
     while True:
-        next_y, next_logprobs = _step(y)
-        mx.async_eval(next_y, next_logprobs)
+        if n != max_tokens:
+            next_y, next_logprobs = _step(y)
+            mx.async_eval(next_y, next_logprobs)
+        if n == 0:
+            mx.eval(y)
+            prompt_progress_callback(total_prompt_tokens, total_prompt_tokens)
+        if n == max_tokens:
+            break
         yield y.item(), logprobs
         if n % 256 == 0:
             mx.metal.clear_cache()
-        n += 1
         y, logprobs = next_y, next_logprobs
+        n += 1
 
 
 def stream_generate(
     model: nn.Module,
     tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
     prompt: Union[str, mx.array, List[int]],
-    max_tokens: int = 100,
     **kwargs,
 ) -> Generator[GenerationResponse, None, None]:
     """
@@ -309,7 +325,6 @@ def stream_generate(
         model (nn.Module): The model to use for generation.
         tokenizer (PreTrainedTokenizer): The tokenizer.
         prompt (Union[str, mx.array, List[int]]): The input prompt string or integer tokens.
-        max_tokens (int): The maximum number of tokens. Default: ``100``.
         kwargs: The remaining options get passed to :func:`generate_step`.
           See :func:`generate_step` for more details.
 
@@ -330,10 +345,7 @@ def stream_generate(
     with wired_limit(model, [generation_stream]):
         detokenizer.reset()
         tic = time.perf_counter()
-        for n, (token, logprobs) in zip(
-            range(max_tokens),
-            generate_step(prompt, model, **kwargs),
-        ):
+        for n, (token, logprobs) in enumerate(generate_step(prompt, model, **kwargs)):
             if n == 0:
                 prompt_time = time.perf_counter() - tic
                 prompt_tps = prompt.size / prompt_time
@@ -343,9 +355,6 @@ def stream_generate(
 
             detokenizer.add_token(token)
 
-            if n == (max_tokens - 1):
-                break
-
             yield GenerationResponse(
                 text=detokenizer.last_segment,
                 token=token,
@@ -385,7 +394,6 @@ def generate(
         model (nn.Module): The language model.
         tokenizer (PreTrainedTokenizer): The tokenizer.
         prompt (str): The string prompt.
-        max_tokens (int): The maximum number of tokens. Default: ``100``.
         verbose (bool): If ``True``, print tokens and timing information.
           Default: ``False``.
         kwargs: The remaining options get passed to :func:`stream_generate`.
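
With max_tokens now owned by generate_step, stream_generate and generate no longer declare it themselves; it travels through **kwargs, so for callers that never set it the effective limit becomes generate_step's default of 256 rather than the previous 100. A short sketch of restoring the old cap explicitly (reusing the placeholder model and tokenizer loaded in the earlier example):

from mlx_lm.utils import stream_generate

# Assumes `model` and `tokenizer` from the earlier sketch; max_tokens is forwarded to generate_step.
for response in stream_generate(model, tokenizer, "an example prompt", max_tokens=100):
    print(response.text, end="", flush=True)
print()
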
@@ -121,21 +121,20 @@ class TestPromptCache(unittest.TestCase):
     def test_cache_with_generate(self):
         model, tokenizer = load(HF_MODEL_PATH)
         prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0]
-        results = zip(range(4), generate_step(prompt, model))
-        toks, all_logits = zip(*(r[1] for r in results))
+        results = list(generate_step(prompt, model, max_tokens=4))
+        toks, all_logits = zip(*results)
 
         prompt_cache = make_prompt_cache(model)
         i = 0
-        for _, (tok, logits) in zip(
-            range(2), generate_step(prompt, model, prompt_cache=prompt_cache)
+        for tok, logits in generate_step(
+            prompt, model, prompt_cache=prompt_cache, max_tokens=2
         ):
             self.assertEqual(tok, toks[i])
             self.assertTrue(mx.allclose(logits, all_logits[i]))
             i += 1
 
-        for _, (tok, logits) in zip(
-            range(1),
-            generate_step(mx.array([toks[i]]), model, prompt_cache=prompt_cache),
+        for tok, logits in generate_step(
+            mx.array([toks[i]]), model, prompt_cache=prompt_cache, max_tokens=1
         ):
             i += 1
             self.assertEqual(tok, toks[i])