mirror of https://github.com/ml-explore/mlx-examples.git
Better support for the rotating KV cache, and add a stop word list as an argument for generate and stream_generate
@@ -145,23 +145,120 @@ def make_kv_caches(
    else:
        return [KVCache(model.head_dim, n) for n in kv_heads]


def max_common_prefix(ls1: List[int], ls2: List[int]) -> int:
    """Find the maximum number of shared tokens from the start of the lists.

    Args:
        ls1 (List[int]): Token ID list 1.
        ls2 (List[int]): Token ID list 2.

    Returns:
        int: The number of shared tokens in the prefix.
    """
    import itertools
    return sum(1 for _ in itertools.takewhile(lambda x: x[0] == x[1], zip(ls1, ls2)))

class StepOutput(NamedTuple):
    token: int
    logprobs: mx.array
    token_ids: List[int]
    cache: Optional[List[Union[KVCache, RotatingKVCache]]] = None


class CacheHistory(NamedTuple):
    cache_history: List[Tuple[mx.array, mx.array]]
    history_tokens: List[int]
    cache: List[Tuple[mx.array, mx.array]]
    token_ids: List[int]

def convert_cache_to_history(cache: StepOutput) -> CacheHistory:
    """Helper function to convert the output of "generate_step" into reusable cache history.

    Args:
        cache (StepOutput): Output of "generate_step".

    Returns:
        CacheHistory: Reusable cache history.
    """
    cache_list = [(c.state[0][..., : c.offset, :], c.state[1][..., : c.offset, :]) for c in cache.cache]
    return CacheHistory(cache=cache_list, token_ids=cache.token_ids)

def save_cache(cache: Union[StepOutput, CacheHistory], filename: str, metadata: Optional[Dict[str, str]] = None) -> None:
    """Save a prompt cache to disk.

    Args:
        cache (Union[StepOutput, CacheHistory]): Output of "generate_step" or formatted cache history.
        filename (str): File path of the prompt cache safetensors file.
        metadata (Optional[Dict[str, str]], optional): String keys and values metadata for the cache file. Defaults to None.
    """
    import orjson
    cache = cache if isinstance(cache, CacheHistory) else convert_cache_to_history(cache=cache)
    metadata = dict() if not isinstance(metadata, dict) else metadata
    metadata['token_ids'] = orjson.dumps(cache.token_ids).decode()
    cache_dict = {}
    for i, c in enumerate(cache.cache):
        cache_dict[f'{i}_key'] = c[0]
        cache_dict[f'{i}_value'] = c[1]
    mx.save_safetensors(file=filename, arrays=cache_dict, metadata=metadata)
    mx.metal.clear_cache()

def load_cache(filename: str) -> Tuple[CacheHistory, Dict[str, str]]:
    """Load a prompt cache from a safetensors file into RAM for generation.

    Args:
        filename (str): File path of the prompt cache safetensors file.

    Returns:
        Tuple[CacheHistory, Dict[str, str]]: Reusable cache history and the metadata.
    """
    import orjson
    cache_dict, metadata = mx.load(filename, return_metadata=True)

    # Loading cache
    num_layers = int(len(cache_dict) / 2)
    cache = []
    for i in range(num_layers):
        cache.append((cache_dict[f'{i}_key'], cache_dict[f'{i}_value']))
    token_ids = orjson.loads(metadata.pop('token_ids'))
    ch = CacheHistory(cache=cache, token_ids=token_ids)
    mx.metal.clear_cache()
    return ch, metadata

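A minimal save/load round-trip sketch (not part of the commit), assuming `model` and `tokenizer` come from this module's `load()` and the file name is arbitrary:

# Cache a long prompt once, then persist it for reuse in a later session.
prompt_ids = tokenizer.encode("You are a helpful assistant. ...")
last_step = None
for step in generate_step(mx.array(prompt_ids), model, verbose=True):
    last_step = step  # StepOutput carrying token_ids and the live cache
    break  # one step is enough to materialize the processed prompt cache

history = convert_cache_to_history(cache=last_step)
save_cache(history, filename="prompt_cache.safetensors", metadata={"note": "system prompt"})

history, metadata = load_cache("prompt_cache.safetensors")
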
def find_max_prefix_num(new: List[int], baseline: List[int]) -> int:
    """Helper function to find the maximum number of tokens shared in the prefix of two prompts.

    Args:
        new (List[int]): First prompt token ids.
        baseline (List[int]): Second prompt token ids.

    Returns:
        int: The maximum number of tokens shared in the prefix of two prompts.
    """
    from itertools import takewhile
    return len(list(takewhile(lambda x: x[0] == x[1], zip(new, baseline))))

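As an illustration (not from the diff), the shared-prefix count is what decides how many cached KV entries can be reused:

# find_max_prefix_num stops counting at the first mismatch:
#   new      = [1, 2, 3, 9]
#   baseline = [1, 2, 3, 4, 5]
# -> 3 shared prefix tokens, so 3 cached KV entries are reusable.
assert find_max_prefix_num([1, 2, 3, 9], [1, 2, 3, 4, 5]) == 3
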
def get_kv_caches(
    model: nn.Module,
    promp_tokens: List[int],
    max_kv_size: Optional[int] = None,
    cache_history: Optional[Union[CacheHistory, StepOutput]] = None
) -> Tuple[List[Union[KVCache, RotatingKVCache]], int]:
    """Helper function to set up the KV cache in "generate_step".

    Args:
        model (nn.Module): The LLM model.
        promp_tokens (List[int]): Prompt token ids.
        max_kv_size (Optional[int], optional): Maximum size of the key-value cache. Old entries (except the first 4 tokens) will be overwritten. Defaults to None.
        cache_history (Optional[Union[CacheHistory, StepOutput]], optional): Reusable prompt cache history or previous generation step output. Defaults to None.

    Returns:
        Tuple[List[Union[KVCache, RotatingKVCache]], int]: List of KV caches for model generation and the number of tokens reused from the cache history.
    """
    cache = make_kv_caches(model=model, max_kv_size=max_kv_size)
    max_prefix = 0

    if cache_history is not None:
        if isinstance(cache_history, StepOutput):
            cache_history = convert_cache_to_history(cache=cache_history)
        if len(cache_history.cache) != len(cache):
            raise ValueError("Wrong number of layers in the cache history")
        cache_size = cache_history.cache[0][0].shape[2]
        if (max_kv_size is None) or (cache_size <= max_kv_size):
            max_prefix = find_max_prefix_num(promp_tokens, cache_history.token_ids)
            # Leave at least one token to evaluate during generation.
            max_prefix = max_prefix - 1 if len(promp_tokens) == max_prefix else max_prefix

            # Set the history in the cache objects and evaluate them to prepare for
            # generation.
            for c, h in zip(cache, cache_history.cache):
                c.update_and_fetch(h[0][:, :, :max_prefix, :], h[1][:, :, :max_prefix, :])
            mx.eval([c.state for c in cache])
    return cache, max_prefix

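A short sketch (illustrative only) of calling get_kv_caches with a previously saved history; the prompt text and file name are made up, and the parameter names follow the function's own signature:

# Reuse a saved cache for a new prompt that extends the cached one.
history, _ = load_cache("prompt_cache.safetensors")
new_prompt_ids = tokenizer.encode("You are a helpful assistant. ... Now answer the question: ...")
cache, reused = get_kv_caches(
    model=model,
    promp_tokens=new_prompt_ids,   # note: the parameter is spelled `promp_tokens` in this diff
    max_kv_size=None,
    cache_history=history,
)
print(f"Reusing {reused} of {len(new_prompt_ids)} prompt tokens from the cache")
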
def generate_step(
    prompt: mx.array,
    model: nn.Module,
@@ -173,10 +270,10 @@ def generate_step(
    min_tokens_to_keep: int = 1,
    logit_bias: Optional[Dict[int, float]] = None,
    prefill_step_size: int = 512,
    return_cache: bool = False,
    verbose: bool = False,
    max_kv_size: Optional[int] = None,
    cache_history: Optional[CacheHistory] = None,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
    cache_history: Optional[Union[CacheHistory, StepOutput]] = None
) -> Generator[StepOutput, None, None]:
    """
    A generator producing token ids based on the given prompt from the model.

@@ -197,13 +294,12 @@ def generate_step(
            be filtered by min_p sampling.
        logit_bias (dictionary, optional): Additive logit bias.
        prefill_step_size (int): Step size for processing the prompt.
        return_cache (bool, optional): Whether to yield the cache as well for prompt caching. Defaults to False.
        max_kv_size (int, optional): Maximum size of the key-value cache. Old
            entries (except the first 4 tokens) will be overwritten.
        cache_history (Optional[CacheHistory]): KV cache history to reuse.
        cache_history (Optional[Union[CacheHistory, StepOutput]], optional): Reusable prompt cache history or previous generation step output. Defaults to None.

    Yields:
        Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
        Generator[StepOutput, None, None]: A generator producing
        one token and a vector of log probabilities.
    """

@@ -223,7 +319,6 @@ def generate_step(
                token = min_p_sampling(logits, min_p, min_tokens_to_keep, temp)
            else:
                token = categorical_sampling(logits, temp)

        return token, logprobs

    if repetition_penalty and (
@@ -232,27 +327,15 @@ def generate_step(
        raise ValueError(
            f"repetition_penalty must be a non-negative float, got {repetition_penalty}"
        )
    tokens: List[int] = prompt.tolist()
    token_count = len(tokens)
    y = prompt

    # Create the KV cache for generation
    cache = make_kv_caches(model, max_kv_size)
    # Create the KV cache for generation and get the number of tokens being reused.
    cache, max_prefix = get_kv_caches(model=model, promp_tokens=tokens, max_kv_size=max_kv_size, cache_history=cache_history)
    y = y[max_prefix:]

    if cache_history is not None:
        if len(cache_history.cache_history) != len(cache):
            raise ValueError("Wrong number of layers in the cache history")

        max_prefix = max_common_prefix(prompt, cache_history.history_tokens)
        max_prefix = max_prefix - 1 if max_prefix == len(y) else max_prefix
        y = y[max_prefix:]
        ch = list(map(lambda x: list(map(lambda y: y[:, :, :max_prefix, :], x)), cache_history.cache_history))
        # Set the history in the cache objects and evaluate them to prepare for
        # generation.
        for c, h in zip(cache, ch):
            c.update_and_fetch(h[0], h[1])
        mx.eval([c.state for c in cache])

    repetition_context = prompt.tolist()
    repetition_context = tokens

    if repetition_context_size:
        repetition_context = repetition_context[-repetition_context_size:]
@@ -276,142 +359,166 @@ def generate_step(
                repetition_context = repetition_context[-repetition_context_size:]
        return y, logprobs.squeeze(0)

    while y.size > prefill_step_size:
        model(y[:prefill_step_size][None], cache=cache)
        mx.eval([c.state for c in cache])
        y = y[prefill_step_size:]
    # Getting preprocessing batches
    num_batches = y.shape[0] // prefill_step_size
    if num_batches != (y.size / prefill_step_size):
        num_batches += 1
    batches = [(i * prefill_step_size, min((i + 1) * prefill_step_size, y.size)) for i in range(num_batches)]
    num_tokens = y.size

    y, logprobs = _step(y)
    # Prompt preprocessing
    if verbose:
        from tqdm import tqdm
        batches = tqdm(batches)
    pp_start = time.perf_counter()
    for b in batches:
        if verbose:
            batches.set_description(f'Processing prompt ({b[1]}/{y.size})')
        if (b[1] - b[0]) >= prefill_step_size:
            model(y[b[0]:b[1]][None], cache=cache)
            mx.eval([c.state for c in cache])
            mx.metal.clear_cache()  # Clear the MLX cache, otherwise it grows very quickly with longer prompts.
        else:
            y = y[b[0]:b[1]]
            y, logprobs = _step(y)
    pp_end = time.perf_counter() - pp_start
    if verbose:
        print(f'Prompt preprocessing time for {num_tokens} tokens: {pp_end:.4}s ({num_tokens/pp_end:.4f} tok/sec)')

    mx.async_eval(y)
    while True:
        next_y, next_logprobs = _step(y)
        mx.async_eval(next_y)
        if return_cache:
            yield y.item(), logprobs, cache
        else:
            yield y.item(), logprobs
        token = y.item()
        tokens.append(token)
        token_count += 1
        if max_kv_size is not None:
            # Trim off token ids if max_kv_size is set.
            if token_count >= max_kv_size:
                token_count -= 1
                tokens = tokens[:4] + tokens[5:]
        yield StepOutput(token=token, logprobs=logprobs, token_ids=tokens, cache=cache)
        y, logprobs = next_y, next_logprobs


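A brief sketch (not part of the commit) of driving generate_step directly and keeping the final StepOutput so it can later be turned into a reusable cache; the prompt and the 64-token limit are arbitrary:

prompt_ids = mx.array(tokenizer.encode("Write a haiku about the sea."))
last_step = None
for step, _ in zip(generate_step(prompt_ids, model, verbose=True), range(64)):
    if step.token == tokenizer.eos_token_id:
        break
    last_step = step

# last_step.cache / last_step.token_ids can now feed convert_cache_to_history or save_cache.
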
def stream_generate(
    model: nn.Module,
    tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
    prompt: str,
    max_tokens: int = 100,
    stop: Optional[List[str]] = None,
    return_cache: bool = False,
    verbose: bool = False,
    **kwargs,
) -> Union[str, Generator[str, None, None]]:
) -> Generator[Union[str, Tuple[str, StepOutput]], None, None]:
    """
    A generator producing text based on the given prompt from the model.

    Args:
        prompt (str): The input prompt.
        model (nn.Module): The model to use for generation.
        max_tokens (int): The maximum number of tokens.
        max_tokens (int): The maximum number of tokens to generate.
        stop (Optional[List[str]], optional): List of words to stop generation. The stop words will be returned. Defaults to None.
        return_cache (bool, optional): Whether to return the last step output.
        verbose (bool, optional): Whether to print prompt processing time and stats. Defaults to False.
        kwargs: The remaining options get passed to :func:`generate_step`.
            See :func:`generate_step` for more details.

    Yields:
        Generator[Tuple[mx.array, mx.array]]: A generator producing text.
        Generator[Union[str, Tuple[str, StepOutput]], None, None]: A generator producing text.
    """
    if not isinstance(tokenizer, TokenizerWrapper):
        tokenizer = TokenizerWrapper(tokenizer)

    prompt_tokens = tokenizer.encode(prompt)
    prompt_tokens = mx.array(prompt_tokens)
    tokens = tokenizer.encode(prompt)
    prompt_tokens = mx.array(tokens)
    detokenizer = tokenizer.detokenizer

    # Placeholder for cache
    cache = None

    detokenizer.reset()
    stop = [] if stop is None else list(filter(lambda x: x != '', stop))
    output_text = ''
    contain_stop = False
    for step_output, n in zip(
        generate_step(prompt_tokens, model, return_cache=return_cache, **kwargs),
        generate_step(prompt_tokens, model, verbose=verbose, **kwargs),
        range(max_tokens),
    ):
        if return_cache:
            token, _, cache = step_output
        else:
            token, _ = step_output
        if token == tokenizer.eos_token_id:
    ):
        if (step_output.token == tokenizer.eos_token_id) or contain_stop:
            break
        detokenizer.add_token(token)
        detokenizer.add_token(step_output.token)
        tokens.append(step_output.token)
        last_segment = detokenizer.last_segment
        output_text += last_segment

        if any([x in output_text for x in stop]):
            contain_stop = True

        # Yield the last segment if streaming
        if return_cache:
            yield detokenizer.last_segment, cache
            yield last_segment, step_output
        else:
            yield detokenizer.last_segment
            yield last_segment

    detokenizer.finalize()
    if return_cache:
        yield detokenizer.last_segment, cache
        yield detokenizer.last_segment, step_output
    else:
        yield detokenizer.last_segment


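An illustrative call (not from the commit) of the new stop-word behaviour in stream_generate; the prompt and stop strings here are made up:

for chunk in stream_generate(
    model,
    tokenizer,
    prompt="List three fruits:",
    max_tokens=200,
    stop=["\n\n", "###"],  # generation halts once a stop string appears in the output
    verbose=True,
):
    print(chunk, end="", flush=True)
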
def generate(
    model: nn.Module,
    tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
    prompt: str,
    max_tokens: int = 100,
    stop: Optional[List[str]] = None,
    return_cache: bool = False,
    verbose: bool = False,
    formatter: Optional[Callable] = None,
    return_cache: bool = False,
    **kwargs,
) -> Union[str, Generator[str, None, None]]:
    """
    Generate a complete response from the model.

    Args:
        model (nn.Module): The language model.
        tokenizer (PreTrainedTokenizer): The tokenizer.
        prompt (str): The string prompt.
        max_tokens (int): The maximum number of tokens. Default: ``100``.
        verbose (bool): If ``True``, print tokens and timing information.
            Default: ``False``.
        formatter (Optional[Callable]): A function which takes a token and a
            probability and displays it.
        kwargs: The remaining options get passed to :func:`generate_step`.
            See :func:`generate_step` for more details.
        model (nn.Module): The language model.
        tokenizer (PreTrainedTokenizer): The tokenizer.
        prompt (str): The string prompt.
        max_tokens (int): The maximum number of tokens. Default: ``100``.
        stop (Optional[List[str]], optional): List of words to stop generation. The stop words will be returned. Defaults to None.
        return_cache (bool, optional): Whether to return the last step output.
        verbose (bool, optional): Whether to print prompt processing time and stats. Defaults to False.
        formatter (Optional[Callable]): A function which takes a token and a
            probability and displays it.
        kwargs: The remaining options get passed to :func:`generate_step`.
            See :func:`generate_step` for more details.
    """
    if not isinstance(tokenizer, TokenizerWrapper):
        tokenizer = TokenizerWrapper(tokenizer)

    if verbose:
        print("=" * 10)
        print("Prompt:", prompt)

    prompt_tokens = mx.array(tokenizer.encode(prompt))
    detokenizer = tokenizer.detokenizer

    tic = time.perf_counter()
    detokenizer.reset()
    cache = None
    stop = [] if stop is None else list(filter(lambda x: x != '', stop))
    output_text = ''
    contain_stop = False

    for step_output, n in zip(
        generate_step(prompt_tokens, model, **kwargs),
        generate_step(prompt_tokens, model, verbose=verbose, **kwargs),
        range(max_tokens),
    ):
        if return_cache:
            token, logprobs, cache = step_output
        else:
            token, logprobs = step_output
        if n == 0:
            prompt_time = time.perf_counter() - tic
            tic = time.perf_counter()
        if token == tokenizer.eos_token_id:
        if (step_output.token == tokenizer.eos_token_id) or contain_stop:
            break
        detokenizer.add_token(token)
        detokenizer.add_token(step_output.token)
        output_text += detokenizer.last_segment
        if any([x in output_text for x in stop]):
            contain_stop = True

        if verbose:
            if formatter:
                # We have to finalize so that the prob corresponds to the last segment
                detokenizer.finalize()
                formatter(detokenizer.last_segment, mx.exp(logprobs[token]).item())
            else:
                print(detokenizer.last_segment, end="", flush=True)
                formatter(detokenizer.last_segment, mx.exp(step_output.logprobs[step_output.token]).item())

    token_count = n + 1
    detokenizer.finalize()
@@ -423,40 +530,12 @@ def generate(
        if token_count == 0:
            print("No tokens generated for this prompt")
            return
        prompt_tps = prompt_tokens.size / prompt_time
        gen_tps = (token_count - 1) / gen_time
        print(f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec")
        print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec")
        peak_mem = mx.metal.get_peak_memory() / 2**30
        print(f"Peak memory: {peak_mem:.3f} GB")

    if return_cache:
        return detokenizer.text, cache
    else:
        return detokenizer.text


def format_cache(tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper], text: str, cache: KVCache) -> CacheHistory:
    # Code copied from reformatting the cache into a safetensors file.
    cache_dict = {}
    for i, c in enumerate(cache):
        cache_dict[f"{i}_keys"] = c.state[0][..., : c.offset, :]
        cache_dict[f"{i}_values"] = c.state[1][..., : c.offset, :]

    # Converting the cache_dict to something the original generate_step function accepts.
    cache_per_layer = {}
    for k, x in cache_dict.items():
        layer, kv_type = k.split("_")
        if layer not in cache_per_layer:
            cache_per_layer[layer] = {}
        cache_per_layer[layer][kv_type] = x

    cache_history = [None] * len(cache_per_layer)
    for layer, c in cache_per_layer.items():
        cache_history[int(layer)] = (c["keys"], c["values"])

    # Converting the original prompt + newly generated text to tokens.
    tokens = tokenizer.encode(text=text)
    return CacheHistory(cache_history=cache_history, history_tokens=tokens)
    return (detokenizer.text, step_output) if return_cache else detokenizer.text


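A hedged end-to-end sketch combining the new arguments: generate a reply, keep the last step output via return_cache, and persist it with save_cache for later reuse. The prompt, stop string, and file name are invented for illustration:

text, last_step = generate(
    model,
    tokenizer,
    prompt="Summarize the plot of Hamlet.",
    max_tokens=256,
    stop=["</s>"],
    return_cache=True,
    verbose=True,
)
save_cache(last_step, filename="hamlet_cache.safetensors", metadata={"prompt": "Hamlet summary"})
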
def load_config(model_path: Path) -> dict: