Update utils.py

Enable saving the KV cache as a safetensors file after a text completion: once generate_step has finished producing tokens, the key-value cache is converted to a dict and written with mx.save_safetensors to a user-specified file location, similar to cache_prompt.
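
For context, a minimal sketch of how the new option might be called, assuming mlx_lm's top-level `load`/`generate` helpers; the model repo, prompt, and output path are illustrative, not part of this commit:

```python
from mlx_lm import load, generate

# Hypothetical model repo; any mlx-lm compatible model works the same way.
model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

text = generate(
    model,
    tokenizer,
    prompt="Write a haiku about autumn.",
    max_tokens=100,
    save_cache_path="final_cache.safetensors",  # new option from this commit
)
```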
mark authored 2024-09-26 16:58:02 +01:00, committed by GitHub
parent dafda90980
commit e5c98f4715


@@ -15,7 +15,7 @@ import mlx.core as mx
 import mlx.nn as nn
 from huggingface_hub import snapshot_download
 from mlx.utils import tree_flatten
-from transformers import PreTrainedTokenizer
+from transformers import PreTrainedTokenizer, AddedToken

 # Local imports
 from .models.base import KVCache, RotatingKVCache
@@ -309,6 +309,12 @@ def stream_generate(
         yield detokenizer.last_segment


+class TokenizerConfigEncoder(json.JSONEncoder):
+    def default(self, obj):
+        if isinstance(obj, AddedToken):
+            return str(obj)
+        return super().default(obj)
+
 def generate(
     model: nn.Module,
     tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
@@ -316,6 +322,7 @@ def generate(
     max_tokens: int = 100,
     verbose: bool = False,
     formatter: Optional[Callable] = None,
+    save_cache_path: Optional[str] = None,
     **kwargs,
 ) -> Union[str, Generator[str, None, None]]:
     """
@@ -330,6 +337,7 @@ def generate(
           Default: ``False``.
       formatter (Optional[Callable]): A function which takes a token and a
           probability and displays it.
+      save_cache_path (Optional[str]): If provided, save the final KV cache to this path.
       kwargs: The remaining options get passed to :func:`generate_step`.
          See :func:`generate_step` for more details.
    """
@@ -346,10 +354,11 @@ def generate(
     tic = time.perf_counter()
     detokenizer.reset()

-    for (token, logprobs), n in zip(
-        generate_step(prompt_tokens, model, **kwargs),
-        range(max_tokens),
-    ):
+    # Get the generate_step generator
+    gen_step = generate_step(prompt_tokens, model, **kwargs)
+
+    # Actual generation loop is here
+    for (token, logprobs), n in zip(gen_step, range(max_tokens)):
         if n == 0:
             prompt_time = time.perf_counter() - tic
             tic = time.perf_counter()
@@ -359,7 +368,6 @@ def generate(
         if verbose:
             if formatter:
-                # We have to finalize so that the prob corresponds to the last segment
                 detokenizer.finalize()
                 formatter(detokenizer.last_segment, mx.exp(logprobs[token]).item())
             else:
@@ -382,6 +390,32 @@ def generate(
         peak_mem = mx.metal.get_peak_memory() / 2**30
         print(f"Peak memory: {peak_mem:.3f} GB")

+    # Save the final KV cache if requested
+    if save_cache_path:
+        # Extract the final cache state
+        cache = gen_step.gi_frame.f_locals['cache']
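+        # NOTE: gi_frame.f_locals reads the generator's locals via CPython
+        # frame introspection; this works because gen_step is only suspended,
+        # not closed, when the generation loop above finishes.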
+
+        # Prepare the cache data
+        cache_dict = {}
+        for i, c in enumerate(cache):
+            cache_dict[f"{i}_keys"] = c.state[0][..., :c.offset, :]
+            cache_dict[f"{i}_values"] = c.state[1][..., :c.offset, :]
+
+        # Prepare metadata
+        metadata = {
+            "model": model.__class__.__name__,
+            "chat_template": str(getattr(tokenizer, "chat_template", None)),
+            "tokenizer_config": json.dumps(
+                {
+                    k: (str(v) if isinstance(v, AddedToken) else v)
+                    for k, v in tokenizer.init_kwargs.items()
+                }
+            ),
+            "max_kv_size": str(kwargs.get("max_kv_size", "None")),
+        }
+
+        # Save to safetensors file
+        mx.save_safetensors(save_cache_path, cache_dict, metadata=metadata)
+        print(f"Final KV cache saved to {save_cache_path}")
+
     return detokenizer.text
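
A side note on the diff: TokenizerConfigEncoder is defined in the second hunk, but the metadata block serializes tokenizer.init_kwargs with an inline dict comprehension rather than the encoder. If the encoder were wired in, the serialization would reduce to a single call; a sketch of that equivalent (the encoder also covers AddedToken values nested below the top level, which the comprehension does not):

```python
# Equivalent serialization using the encoder defined in this commit:
# json.dumps calls TokenizerConfigEncoder.default() for any object it
# cannot encode natively, such as transformers' AddedToken instances.
tokenizer_config = json.dumps(tokenizer.init_kwargs, cls=TokenizerConfigEncoder)
```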
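
The commit adds the save path but no corresponding load path. Mirroring how cache_prompt's output was consumed at the time, reloading the file might look roughly like the sketch below; the `cache_history` keyword (forwarded to generate_step via `**kwargs`) and the model name are assumptions, not part of this diff:

```python
import mlx.core as mx
from mlx_lm import load, generate

# Reload the tensors; return_metadata=True also recovers the string
# metadata (model name, chat_template, tokenizer_config, max_kv_size).
kv_cache, metadata = mx.load("final_cache.safetensors", return_metadata=True)

# Rebuild the per-layer (keys, values) pairs in layer order, using the
# f"{i}_keys" / f"{i}_values" naming scheme the commit writes.
cache_history = [
    (kv_cache[f"{i}_keys"], kv_cache[f"{i}_values"])
    for i in range(len(kv_cache) // 2)
]

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

# Assumed: generate_step accepts a cache_history kwarg to prefill its
# KV cache, as the cache_prompt workflow relied on in this era of mlx-lm.
text = generate(
    model,
    tokenizer,
    prompt="Continue the conversation.",
    cache_history=cache_history,
)
```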