Better support for rotating KV-cache and add a stop word list argument to generate and stream_generate

nath1295
2024-09-27 13:43:45 +01:00
parent 3ef1011aff
commit 7e98499ee3

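Below is a minimal usage sketch of the new stop, return_cache, and cache_history arguments. The load() helper and the model path are illustrative assumptions and not part of this diff:

    from mlx_lm import load  # assumed loader; any (model, tokenizer) pair works

    model, tokenizer = load("path/to/model")  # hypothetical model path

    # Stop as soon as a stop word appears; the stop word itself stays in the output.
    for chunk in stream_generate(model, tokenizer, "Q: Name a fruit.\nA:", max_tokens=64, stop=["\n"]):
        print(chunk, end="", flush=True)

    # Keep the last step output so the prompt cache can be saved and reused later.
    text, last_step = generate(model, tokenizer, "Hello", max_tokens=32, return_cache=True)
    save_cache(last_step, "prompt_cache.safetensors")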

@@ -145,23 +145,120 @@ def make_kv_caches(
else:
return [KVCache(model.head_dim, n) for n in kv_heads]
def max_common_prefix(ls1: List[int], ls2: List[int]) -> int:
"""Find the maximum number of shared tokens from the start of the lists.
Args:
ls1 (List[int]): Token ID list 1.
ls2 (List[int]): Token ID list 2.
Returns:
int: The number of shared tokens in the prefix.
"""
import itertools
return sum(1 for _ in itertools.takewhile(lambda x: x[0] == x[1], zip(ls1, ls2)))
class StepOutput(NamedTuple):
token: int
logprobs: mx.array
token_ids: List[int]
cache: Optional[List[Union[KVCache, RotatingKVCache]]] = None
class CacheHistory(NamedTuple):
cache_history: List[Tuple[mx.array, mx.array]]
history_tokens: List[int]
cache: List[Tuple[mx.array, mx.array]]
token_ids: List[int]
def convert_cache_to_history(cache: StepOutput) -> CacheHistory:
"""Helper function to convert the output of "generate_step" into reusable cache history.
Args:
cache (StepOutput): Output of "generate_step".
Returns:
CacheHistory: Reusable cache history.
"""
cache_list = [(c.state[0][..., : c.offset, :], c.state[1][..., : c.offset, :]) for c in cache.cache]
return CacheHistory(cache=cache_list, token_ids=cache.token_ids)
def save_cache(cache: Union[StepOutput, CacheHistory], filename: str, metadata: Optional[Dict[str, str]] = None) -> None:
"""Saving a prompt cache into a disk.
Args:
cache (Union[StepOutput, CacheHistory]): Output of "generate_step" or formatted cache history.
filename (str): File path of the prompt cache safetensors file.
metadata (Optional[Dict[str, str]], optional): Metadata with string keys and values to store in the cache file. Defaults to None.
"""
import orjson
cache = cache if isinstance(cache, CacheHistory) else convert_cache_to_history(cache=cache)
metadata = dict() if not isinstance(metadata, dict) else metadata
metadata['token_ids'] = orjson.dumps(cache.token_ids).decode()
cache_dict = {}
for i, c in enumerate(cache.cache):
cache_dict[f'{i}_key'] = c[0]
cache_dict[f'{i}_value'] = c[1]
mx.save_safetensors(file=filename, arrays=cache_dict, metadata=metadata)
mx.metal.clear_cache()
def load_cache(filename: str) -> Tuple[CacheHistory, Dict[str, str]]:
"""Loading prompt cache from a safetnesors file into ram for generation.
Args:
filename (str): File path of the prompt cache safetensors file.
Returns:
Tuple[CacheHistory, Dict[str, str]]: Reusable cache history and the metadata.
"""
import orjson
cache_dict, metadata = mx.load(filename, return_metadata=True)
# Loading cache
num_layers = int(len(cache_dict) / 2)
cache = []
for i in range(num_layers):
cache.append((cache_dict[f'{i}_key'], cache_dict[f'{i}_value']))
token_ids = orjson.loads(metadata.pop('token_ids'))
ch = CacheHistory(cache=cache, token_ids=token_ids)
mx.metal.clear_cache()
return ch, metadata
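# Hedged sketch of the save/load round trip; the file name and metadata are
# illustrative only:
#
#   text, last_step = generate(model, tokenizer, prompt, return_cache=True)
#   save_cache(last_step, 'prompt_cache.safetensors', metadata={'note': 'turn 1'})
#   ...
#   history, meta = load_cache('prompt_cache.safetensors')
#   # history.cache holds per-layer (keys, values) arrays and history.token_ids
#   # the token ids they were computed from.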
def find_max_prefix_num(new: List[int], baseline: List[int]) -> int:
"""Helper function to find the maximum number of tokens shared in the prefix of two prompts.
Args:
new (List[int]): First prompt token ids.
baseline (List[int]): Second prompt token ids.
Returns:
int: The maximum number of tokens shared in the prefix of two prompts.
"""
from itertools import takewhile
return len(list(takewhile(lambda x: x[0] == x[1], zip(new, baseline))))
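# Worked example of the prefix matching used for cache reuse:
#   find_max_prefix_num([1, 2, 3, 9], [1, 2, 3, 4]) -> 3
#   find_max_prefix_num([5, 6], [7, 8]) -> 0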
def get_kv_caches(
model: nn.Module,
promp_tokens: List[int],
max_kv_size: Optional[int] = None,
cache_history: Optional[Union[CacheHistory, StepOutput]] = None
) -> Tuple[List[Union[KVCache, RotatingKVCache]], int]:
"""Helper function to setup the kv cache in "generate_step".
Args:
model (nn.Module): The LLM model.
promp_tokens (List[int]): Prompt tokens ids.
max_kv_size (Optional[int], optional): Maximum size of the key-value cache. Old entries (except the first 4 tokens) will be overwritten.. Defaults to None.
cache_history (Optional[Union[CacheHistory, StepOutput]], optional): Reusable prompt cache history or previous generation step output. Defaults to None.
Returns:
Tuple[List[Union[KVCache, RotatingKVCache]], int]: List of KV cache for model generation and the number of tokens reused from the cache history.
"""
cache = make_kv_caches(model=model, max_kv_size=max_kv_size)
max_prefix = 0
if cache_history is not None:
if isinstance(cache_history, StepOutput):
cache_history = convert_cache_to_history(cache=cache_history)
if len(cache_history.cache) != len(cache):
raise ValueError("Wrong number of layers in the cache history")
cache_size = cache_history.cache[0][0].shape[2]
if (max_kv_size is None) or (cache_size <= max_kv_size):
max_prefix = find_max_prefix_num(promp_tokens, cache_history.token_ids)
# Leave at least one token to evaluate during generation.
max_prefix = max_prefix - 1 if len(promp_tokens) == max_prefix else max_prefix
# Set the history in the cache objects and evaluate them to prepare for
# generation.
for c, h in zip(cache, cache_history.cache):
c.update_and_fetch(h[0][:, :, :max_prefix, :], h[1][:, :, :max_prefix, :])
mx.eval([c.state for c in cache])
return cache, max_prefix
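# Hedged sketch of how this is used for cache reuse; history is assumed to be
# a CacheHistory or StepOutput from a previous turn:
#
#   cache, reused = get_kv_caches(model=model, promp_tokens=new_ids,
#                                 max_kv_size=None, cache_history=history)
#   remaining = new_ids[reused:]  # only these tokens still need a forward pass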
def generate_step(
prompt: mx.array,
model: nn.Module,
@@ -173,10 +270,10 @@ def generate_step(
min_tokens_to_keep: int = 1,
logit_bias: Optional[Dict[int, float]] = None,
prefill_step_size: int = 512,
return_cache: bool = False,
verbose: bool = False,
max_kv_size: Optional[int] = None,
cache_history: Optional[CacheHistory] = None,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
cache_history: Optional[Union[CacheHistory, StepOutput]] = None
) -> Generator[StepOutput, None, None]:
"""
A generator producing token ids based on the given prompt from the model.
@@ -197,13 +294,12 @@ def generate_step(
be filtered by min_p sampling.
logit_bias (dictionary, optional): Additive logit bias.
prefill_step_size (int): Step size for processing the prompt.
return_cache (bool, optional): Whether to yield the cache as well for prompt caching. Defaults to False.
max_kv_size (int, optional): Maximum size of the key-value cache. Old
entries (except the first 4 tokens) will be overwritten.
cache_history (Optional[CacheHistory]): KV cache history to reuse.
cache_history (Optional[Union[CacheHistory, StepOutput]], optional): Reusable prompt cache history or previous generation step output. Defaults to None.
Yields:
Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
Generator[StepOutput, None, None]: A generator producing
one token and a vector of log probabilities.
"""
@@ -223,7 +319,6 @@ def generate_step(
token = min_p_sampling(logits, min_p, min_tokens_to_keep, temp)
else:
token = categorical_sampling(logits, temp)
return token, logprobs
if repetition_penalty and (
@@ -232,27 +327,15 @@ def generate_step(
raise ValueError(
f"repetition_penalty must be a non-negative float, got {repetition_penalty}"
)
tokens: List[int] = prompt.tolist()
token_count = len(tokens)
y = prompt
# Create the KV cache for generation
cache = make_kv_caches(model, max_kv_size)
# Create the KV cache for generation and get the number of tokens being reused.
cache, max_prefix = get_kv_caches(model=model, promp_tokens=tokens, max_kv_size=max_kv_size, cache_history=cache_history)
y = y[max_prefix:]
if cache_history is not None:
if len(cache_history.cache_history) != len(cache):
raise ValueError("Wrong number of layers in the cache history")
max_prefix = max_common_prefix(prompt, cache_history.history_tokens)
max_prefix = max_prefix - 1 if max_prefix == len(y) else max_prefix
y = y[max_prefix:]
ch = list(map(lambda x: list(map(lambda y: y[:, :, :max_prefix, :], x)), cache_history.cache_history))
# Set the history in the cache objects and evaluate them to prepare for
# generation.
for c, h in zip(cache, ch):
c.update_and_fetch(h[0], h[1])
mx.eval([c.state for c in cache])
repetition_context = prompt.tolist()
repetition_context = tokens
if repetition_context_size:
repetition_context = repetition_context[-repetition_context_size:]
@@ -276,142 +359,166 @@ def generate_step(
repetition_context = repetition_context[-repetition_context_size:]
return y, logprobs.squeeze(0)
while y.size > prefill_step_size:
model(y[:prefill_step_size][None], cache=cache)
mx.eval([c.state for c in cache])
y = y[prefill_step_size:]
# Split the prompt into prefill batches for preprocessing.
num_batches = y.shape[0] // prefill_step_size
if num_batches != (y.size / prefill_step_size):
num_batches += 1
batches = [(i * prefill_step_size, min((i + 1) * prefill_step_size, y.size)) for i in range(num_batches)]
num_tokens = y.size
y, logprobs = _step(y)
# Prompt preprocessing
if verbose:
from tqdm import tqdm
batches = tqdm(batches)
pp_start = time.perf_counter()
for b in batches:
if verbose:
batches.set_description(f'Processing prompt ({b[1]}/{y.size})')
if (b[1] - b[0]) >= prefill_step_size:
model(y[b[0]:b[1]][None], cache=cache)
mx.eval([c.state for c in cache])
mx.metal.clear_cache() # Clear the MLX buffer cache; otherwise it grows very quickly with longer prompts.
else:
y = y[b[0]:b[1]]
y, logprobs = _step(y)
pp_end = time.perf_counter() - pp_start
if verbose:
print(f'Prompt preprocessing time for {num_tokens} tokens: {pp_end:.4}s ({num_tokens/pp_end:.4f} tok/sec)')
mx.async_eval(y)
while True:
next_y, next_logprobs = _step(y)
mx.async_eval(next_y)
if return_cache:
yield y.item(), logprobs, cache
else:
yield y.item(), logprobs
token = y.item()
tokens.append(token)
token_count += 1
if max_kv_size is not None:
# Mirror the rotating KV cache: once full, drop the oldest token id after the first 4 kept tokens.
if token_count >= max_kv_size:
token_count -= 1
tokens = tokens[:4] + tokens[5:]
yield StepOutput(token=token, logprobs=logprobs, token_ids=tokens, cache=cache)
y, logprobs = next_y, next_logprobs
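# Hedged sketch: drive generate_step directly and feed its last StepOutput back
# in as cache_history for a follow-up prompt that shares a prefix with the first
# one (prompts and token limits are illustrative):
#
#   ids_1 = mx.array(tokenizer.encode(prompt_1))
#   last = None
#   for step, _ in zip(generate_step(ids_1, model), range(128)):
#       last = step
#       if step.token == tokenizer.eos_token_id:
#           break
#   ids_2 = mx.array(tokenizer.encode(prompt_2))  # prompt_2 extends prompt_1
#   for step, _ in zip(generate_step(ids_2, model, cache_history=last), range(128)):
#       ...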
def stream_generate(
model: nn.Module,
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
prompt: str,
max_tokens: int = 100,
stop: Optional[List[str]] = None,
return_cache: bool = False,
verbose: bool = False,
**kwargs,
) -> Union[str, Generator[str, None, None]]:
) -> Generator[Union[str, Tuple[str, StepOutput]], None, None]:
"""
A generator producing text based on the given prompt from the model.
Args:
prompt (mx.array): The input prompt.
model (nn.Module): The model to use for generation.
max_tokens (int): The maximum number of tokens.
max_tokens (int): The maximum number of tokens to generate.
stop (Optional[List[str]], optional): List of stop words that end generation; the stop word itself is included in the output. Defaults to None.
return_cache (bool, optional): Whether to also yield the last StepOutput for prompt caching. Defaults to False.
verbose (bool, optional): Whether to print prompt processing time and stats. Defaults to False.
kwargs: The remaining options get passed to :func:`generate_step`.
See :func:`generate_step` for more details.
Yields:
Generator[Tuple[mx.array, mx.array]]: A generator producing text.
Generator[Union[str, Tuple[str, StepOutput]], None, None]: A generator producing text.
"""
if not isinstance(tokenizer, TokenizerWrapper):
tokenizer = TokenizerWrapper(tokenizer)
prompt_tokens = tokenizer.encode(prompt)
prompt_tokens = mx.array(prompt_tokens)
tokens = tokenizer.encode(prompt)
prompt_tokens = mx.array(tokens)
detokenizer = tokenizer.detokenizer
# Placeholder for the cache
cache = None
detokenizer.reset()
stop = [] if stop is None else list(filter(lambda x: x != '', stop))
output_text = ''
contain_stop = False
for step_output, n in zip(
generate_step(prompt_tokens, model, return_cache=return_cache, **kwargs),
generate_step(prompt_tokens, model, verbose=verbose, **kwargs),
range(max_tokens),
):
if return_cache:
token, _, cache = step_output
else:
token, _ = step_output
if token == tokenizer.eos_token_id:
):
if (step_output.token == tokenizer.eos_token_id) or contain_stop:
break
detokenizer.add_token(token)
detokenizer.add_token(step_output.token)
tokens.append(step_output.token)
last_segment = detokenizer.last_segment
output_text += last_segment
if any(x in output_text for x in stop):
contain_stop = True
# Yield the newly decoded segment (and the step output when return_cache is set).
if return_cache:
yield detokenizer.last_segment, cache
yield last_segment, step_output
else:
yield detokenizer.last_segment
yield last_segment
detokenizer.finalize()
if return_cache:
yield detokenizer.last_segment, cache
yield detokenizer.last_segment, step_output
else:
yield detokenizer.last_segment
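# Hedged usage sketch: stream with stop words and keep the final StepOutput so
# the prompt cache can be reused or saved for the next turn (the stop word is
# illustrative):
#
#   last = None
#   for chunk, last in stream_generate(model, tokenizer, prompt, max_tokens=256,
#                                      stop=['\nUser:'], return_cache=True):
#       print(chunk, end='', flush=True)
#   history = convert_cache_to_history(last)  # or save_cache(last, 'cache.safetensors')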
def generate(
model: nn.Module,
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
prompt: str,
max_tokens: int = 100,
stop: Optional[List[str]] = None,
return_cache: bool = False,
verbose: bool = False,
formatter: Optional[Callable] = None,
return_cache: bool = False,
**kwargs,
) -> Union[str, Generator[str, None, None]]:
"""
Generate a complete response from the model.
Args:
model (nn.Module): The language model.
tokenizer (PreTrainedTokenizer): The tokenizer.
prompt (str): The string prompt.
max_tokens (int): The maximum number of tokens. Default: ``100``.
verbose (bool): If ``True``, print tokens and timing information.
Default: ``False``.
formatter (Optional[Callable]): A function which takes a token and a
probability and displays it.
kwargs: The remaining options get passed to :func:`generate_step`.
See :func:`generate_step` for more details.
model (nn.Module): The language model.
tokenizer (PreTrainedTokenizer): The tokenizer.
prompt (str): The string prompt.
max_tokens (int): The maximum number of tokens. Default: ``100``.
stop (Optional[List[str]], optional): List of stop words that end generation; the stop word itself is included in the output. Defaults to None.
return_cache (bool, optional): Whether to also return the last StepOutput for prompt caching. Defaults to False.
verbose (bool, optional): Whether to print prompt processing time and stats. Defaults to False.
formatter (Optional[Callable]): A function which takes a token and a
probability and displays it.
kwargs: The remaining options get passed to :func:`generate_step`.
See :func:`generate_step` for more details.
"""
if not isinstance(tokenizer, TokenizerWrapper):
tokenizer = TokenizerWrapper(tokenizer)
if verbose:
print("=" * 10)
print("Prompt:", prompt)
prompt_tokens = mx.array(tokenizer.encode(prompt))
detokenizer = tokenizer.detokenizer
tic = time.perf_counter()
detokenizer.reset()
cache = None
stop = [] if stop is None else list(filter(lambda x: x != '', stop))
output_text = ''
contain_stop = False
for step_output, n in zip(
generate_step(prompt_tokens, model, **kwargs),
generate_step(prompt_tokens, model, verbose=verbose, **kwargs),
range(max_tokens),
):
if return_cache:
token, logprobs, cache = step_output
else:
token, logprobs = step_output
if n == 0:
prompt_time = time.perf_counter() - tic
tic = time.perf_counter()
if token == tokenizer.eos_token_id:
if (step_output.token == tokenizer.eos_token_id) or contain_stop:
break
detokenizer.add_token(token)
detokenizer.add_token(step_output.token)
output_text += detokenizer.last_segment
if any(x in output_text for x in stop):
contain_stop = True
if verbose:
if formatter:
# We have to finalize so that the prob corresponds to the last segment
detokenizer.finalize()
formatter(detokenizer.last_segment, mx.exp(logprobs[token]).item())
else:
print(detokenizer.last_segment, end="", flush=True)
formatter(detokenizer.last_segment, mx.exp(step_output.logprobs[step_output.token]).item())
token_count = n + 1
detokenizer.finalize()
@@ -423,40 +530,12 @@ def generate(
if token_count == 0:
print("No tokens generated for this prompt")
return
prompt_tps = prompt_tokens.size / prompt_time
gen_tps = (token_count - 1) / gen_time
print(f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec")
print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec")
peak_mem = mx.metal.get_peak_memory() / 2**30
print(f"Peak memory: {peak_mem:.3f} GB")
if return_cache:
return detokenizer.text, cache
else:
return detokenizer.text
def format_cache(tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper], text: str, cache: KVCache) -> CacheHistory:
# Code copied from the logic that reformats the cache into a safetensors file.
cache_dict = {}
for i, c in enumerate(cache):
cache_dict[f"{i}_keys"] = c.state[0][..., : c.offset, :]
cache_dict[f"{i}_values"] = c.state[1][..., : c.offset, :]
# Convert the cache_dict into the format the original generate_step function accepts.
cache_per_layer = {}
for k, x in cache_dict.items():
layer, kv_type = k.split("_")
if layer not in cache_per_layer:
cache_per_layer[layer] = {}
cache_per_layer[layer][kv_type] = x
cache_history = [None] * len(cache_per_layer)
for layer, c in cache_per_layer.items():
cache_history[int(layer)] = (c["keys"], c["values"])
# Converting the original prompt + newly generated text to tokens.
tokens = tokenizer.encode(text=text)
return CacheHistory(cache_history=cache_history, history_tokens=tokens)
return (detokenizer.text, step_output) if return_cache else detokenizer.text
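# Hedged sketch of multi-turn reuse with generate; the prompt strings are
# illustrative and cache_history is forwarded to generate_step via **kwargs:
#
#   reply, last = generate(model, tokenizer, turn_1, stop=['\nUser:'], return_cache=True)
#   turn_2 = turn_1 + reply + '\nUser: And then?\nAssistant:'
#   reply_2 = generate(model, tokenizer, turn_2, cache_history=last)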
def load_config(model_path: Path) -> dict: