Allow prompt callback to generate_step (#1133)

* allow prompt callback and use in cache_prompt

* nit

* comments

* bump version
Authored by Awni Hannun on 2024-12-03 16:17:14 -08:00, committed by GitHub
parent 0ca162cfb2
commit 1963df8565
5 changed files with 48 additions and 48 deletions
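For context, the new prompt_progress_callback argument to generate_step can be wired up roughly as follows. This is a minimal sketch and not part of the commit; the model repo and prompt text are placeholders.

    # Minimal usage sketch (not part of this commit); model repo and prompt are placeholders.
    import mlx.core as mx
    from mlx_lm import load
    from mlx_lm.utils import generate_step

    model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
    prompt = mx.array(tokenizer.encode("Summarize the plot of Hamlet."))

    def on_prompt_progress(processed, total):
        # Called during prompt processing with (tokens processed so far, total prompt
        # tokens), and once with (total, total) when the prompt has been fully processed.
        print(f"\rPrefill: {processed}/{total} tokens", end="", flush=True)

    for token, logprobs in generate_step(
        prompt,
        model,
        max_tokens=32,
        prompt_progress_callback=on_prompt_progress,
    ):
        print(tokenizer.decode([token]), end="", flush=True)
    print()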

View File

@@ -1,3 +1,3 @@
 # Copyright © 2023-2024 Apple Inc.
 
-__version__ = "0.20.1"
+__version__ = "0.20.2"

View File

@@ -8,7 +8,7 @@ import time
 import mlx.core as mx
 
 from .models.cache import make_prompt_cache, save_prompt_cache
-from .utils import load, maybe_quantize_kv_cache
+from .utils import generate_step, load
 
 DEFAULT_QUANTIZED_KV_START = 5000
@@ -50,12 +50,6 @@ def setup_arg_parser():
         action="store_true",
         help="Use the default chat template",
     )
-    parser.add_argument(
-        "--cache-limit-gb",
-        type=int,
-        default=None,
-        help="Set the MLX cache limit in GB",
-    )
     parser.add_argument(
         "--max-kv-size",
         type=int,
@@ -99,9 +93,6 @@ def main():
     parser = setup_arg_parser()
     args = parser.parse_args()
 
-    if args.cache_limit_gb is not None:
-        mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024)
-
     # Building tokenizer_config
     tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}
     if args.eos_token is not None:
@@ -144,26 +135,28 @@ def main():
     y = mx.array(tokenizer.encode(prompt))
 
     # Process the prompt
-    processed = 0
-    step_size = 512
     start = time.time()
     max_msg_len = 0
-    while y.size > 0:
-        model(y[:step_size][None], cache=cache)
-        mx.eval([c.state for c in cache])
-        mx.metal.clear_cache()
-        processed += min(y.size, step_size)
-        y = y[step_size:]
+
+    def callback(processed, total_tokens):
         current = time.time()
         speed = processed / (current - start)
         msg = f"\rProcessed {processed:6d} tokens ({speed:6.2f} tok/s)"
+        nonlocal max_msg_len
         max_msg_len = max(max_msg_len, len(msg))
         print(msg + " " * (max_msg_len - len(msg)), end="", flush=True)
 
-        maybe_quantize_kv_cache(
-            cache, args.quantized_kv_start, args.kv_group_size, args.kv_bits
-        )
+    for _ in generate_step(
+        y,
+        model,
+        max_tokens=0,
+        prompt_cache=cache,
+        kv_bits=args.kv_bits,
+        kv_group_size=args.kv_group_size,
+        quantized_kv_start=args.quantized_kv_start,
+        prompt_progress_callback=callback,
+    ):
+        pass
 
     print()
     print(f"Peak memory: {mx.metal.get_peak_memory() / 1e9:.3f} GB")

View File

@@ -77,7 +77,7 @@ def setup_arg_parser():
     )
     parser.add_argument(
         "--min-tokens-to-keep",
-        type=float,
+        type=int,
         default=DEFAULT_MIN_TOKENS_TO_KEEP,
         help="Minimum tokens to keep for min-p sampling.",
     )

View File

@@ -183,6 +183,7 @@ def generate_step(
     prompt: mx.array,
     model: nn.Module,
     *,
+    max_tokens: int = 256,
     sampler: Optional[Callable[mx.array, mx.array]] = None,
     logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
     max_kv_size: Optional[int] = None,
@@ -191,6 +192,7 @@ def generate_step(
     kv_bits: Optional[int] = None,
     kv_group_size: int = 64,
     quantized_kv_start: int = 0,
+    prompt_progress_callback: Optional[Callable[int, int]] = None,
     temp: Optional[float] = None,
     repetition_penalty: Optional[float] = None,
     repetition_context_size: Optional[int] = None,
@@ -204,21 +206,25 @@ def generate_step(
     Args:
         prompt (mx.array): The input prompt.
         model (nn.Module): The model to use for generation.
-        prefill_step_size (int): Step size for processing the prompt.
-        max_kv_size (int, optional): Maximum size of the key-value cache. Old
-          entries (except the first 4 tokens) will be overwritten.
-        prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
-          provided, the cache will be updated in place.
+        max_tokens (int): The maximum number of tokens. Use ``-1`` for an infinite
+          generator. Default: ``256``.
         sampler (Callable[mx.array, mx.array], optional): A sampler for sampling a
           token from a vector of log probabilities. Default: ``None``.
         logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
           A list of functions that take tokens and logits and return the processed
           logits. Default: ``None``.
+        max_kv_size (int, optional): Maximum size of the key-value cache. Old
+          entries (except the first 4 tokens) will be overwritten.
+        prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
+          provided, the cache will be updated in place.
+        prefill_step_size (int): Step size for processing the prompt.
         kv_bits (int, optional): Number of bits to use for KV cache quantization.
           None implies no cache quantization. Default: ``None``.
         kv_group_size (int): Group size for KV cache quantization. Default: ``64``.
         quantized_kv_start (int): Step to begin using a quantized KV cache
           when ``kv_bits`` is non-None. Default: ``0``.
+        prompt_progress_callback (Callable[int, int]): A callback which takes the
+          prompt tokens processed so far and the total number of prompt tokens.
 
     Yields:
         Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
@@ -253,6 +259,7 @@ def generate_step(
     logits_processors = logits_processors or make_logits_processors(
         None, repetition_penalty, repetition_context_size or 20
     )
+    prompt_progress_callback = prompt_progress_callback or (lambda *_: None)
 
     def _step(y):
         with mx.stream(generation_stream):
@@ -275,9 +282,13 @@ def generate_step(
         return y, logprobs.squeeze(0)
 
     with mx.stream(generation_stream):
+        total_prompt_tokens = y.size
+        prompt_processed_tokens = 0
         while y.size > prefill_step_size:
             model(y[:prefill_step_size][None], cache=prompt_cache)
             mx.eval([c.state for c in prompt_cache])
+            prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens)
+            prompt_processed_tokens += prefill_step_size
             y = y[prefill_step_size:]
             mx.metal.clear_cache()
@@ -286,20 +297,25 @@ def generate_step(
     mx.async_eval(y, logprobs)
     n = 0
     while True:
-        next_y, next_logprobs = _step(y)
-        mx.async_eval(next_y, next_logprobs)
+        if n != max_tokens:
+            next_y, next_logprobs = _step(y)
+            mx.async_eval(next_y, next_logprobs)
+        if n == 0:
+            mx.eval(y)
+            prompt_progress_callback(total_prompt_tokens, total_prompt_tokens)
+        if n == max_tokens:
+            break
         yield y.item(), logprobs
         if n % 256 == 0:
             mx.metal.clear_cache()
-        n += 1
         y, logprobs = next_y, next_logprobs
+        n += 1
 
 
 def stream_generate(
     model: nn.Module,
     tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
     prompt: Union[str, mx.array, List[int]],
-    max_tokens: int = 100,
     **kwargs,
 ) -> Generator[GenerationResponse, None, None]:
     """
@@ -309,7 +325,6 @@ def stream_generate(
         model (nn.Module): The model to use for generation.
         tokenizer (PreTrainedTokenizer): The tokenizer.
         prompt (Union[str, mx.array, List[int]]): The input prompt string or integer tokens.
-        max_tokens (int): The maximum number of tokens. Default: ``100``.
         kwargs: The remaining options get passed to :func:`generate_step`.
           See :func:`generate_step` for more details.
@@ -330,10 +345,7 @@ def stream_generate(
     with wired_limit(model, [generation_stream]):
         detokenizer.reset()
         tic = time.perf_counter()
-        for n, (token, logprobs) in zip(
-            range(max_tokens),
-            generate_step(prompt, model, **kwargs),
-        ):
+        for n, (token, logprobs) in enumerate(generate_step(prompt, model, **kwargs)):
             if n == 0:
                 prompt_time = time.perf_counter() - tic
                 prompt_tps = prompt.size / prompt_time
@@ -343,9 +355,6 @@ def stream_generate(
             detokenizer.add_token(token)
 
-            if n == (max_tokens - 1):
-                break
-
             yield GenerationResponse(
                 text=detokenizer.last_segment,
                 token=token,
@@ -385,7 +394,6 @@ def generate(
         model (nn.Module): The language model.
         tokenizer (PreTrainedTokenizer): The tokenizer.
         prompt (str): The string prompt.
-        max_tokens (int): The maximum number of tokens. Default: ``100``.
         verbose (bool): If ``True``, print tokens and timing information.
           Default: ``False``.
         kwargs: The remaining options get passed to :func:`stream_generate`.
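
With max_tokens now an argument of generate_step, callers of generate and stream_generate pass it through **kwargs; a minimal sketch (the model repo and prompt are placeholders):

    # max_tokens now flows through **kwargs down to generate_step
    # (placeholder model repo and prompt).
    from mlx_lm import load, generate

    model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
    text = generate(model, tokenizer, prompt="Write one sentence about MLX.",
                    max_tokens=64, verbose=True)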

View File

@@ -121,21 +121,20 @@ class TestPromptCache(unittest.TestCase):
     def test_cache_with_generate(self):
         model, tokenizer = load(HF_MODEL_PATH)
         prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0]
-        results = zip(range(4), generate_step(prompt, model))
-        toks, all_logits = zip(*(r[1] for r in results))
+        results = list(generate_step(prompt, model, max_tokens=4))
+        toks, all_logits = zip(*results)
 
         prompt_cache = make_prompt_cache(model)
         i = 0
-        for _, (tok, logits) in zip(
-            range(2), generate_step(prompt, model, prompt_cache=prompt_cache)
+        for tok, logits in generate_step(
+            prompt, model, prompt_cache=prompt_cache, max_tokens=2
         ):
             self.assertEqual(tok, toks[i])
             self.assertTrue(mx.allclose(logits, all_logits[i]))
             i += 1
 
-        for _, (tok, logits) in zip(
-            range(1),
-            generate_step(mx.array([toks[i]]), model, prompt_cache=prompt_cache),
+        for tok, logits in generate_step(
+            mx.array([toks[i]]), model, prompt_cache=prompt_cache, max_tokens=1
         ):
             i += 1
             self.assertEqual(tok, toks[i])