Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-12-16 02:08:55 +08:00)
Update utils.py
Enable saving the KV cache as a safetensors file after a text completion: once generate_step has finished producing all the tokens, the key-value cache is collected into a dict and written with mx.save_safetensors to a user-specified file location, similar to cache_prompt.
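For illustration, a minimal sketch of how the new option might be used from Python (the model name, prompt, and output path below are placeholders; load and generate are the usual mlx_lm entry points):

    from mlx_lm import load, generate

    model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
    text = generate(
        model,
        tokenizer,
        prompt="Write a haiku about the ocean.",
        max_tokens=100,
        verbose=True,
        save_cache_path="final_kv_cache.safetensors",  # new option from this commit
    )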
@@ -15,7 +15,7 @@ import mlx.core as mx
 import mlx.nn as nn
 from huggingface_hub import snapshot_download
 from mlx.utils import tree_flatten
-from transformers import PreTrainedTokenizer
+from transformers import PreTrainedTokenizer, AddedToken

 # Local imports
 from .models.base import KVCache, RotatingKVCache
@@ -309,6 +309,12 @@ def stream_generate(
         yield detokenizer.last_segment


+class TokenizerConfigEncoder(json.JSONEncoder):
+    def default(self, obj):
+        if isinstance(obj, AddedToken):
+            return str(obj)
+        return super().default(obj)
+

 def generate(
     model: nn.Module,
     tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
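Aside: this hunk only defines TokenizerConfigEncoder; the diff does not show it being used (the metadata hunk below serializes the tokenizer config with a dict comprehension instead). Presumably it is meant for JSON-serializing configs that contain AddedToken objects, roughly as in this hypothetical snippet (assumes the class above is in scope):

    import json
    from transformers import AddedToken

    # AddedToken is not JSON-serializable by default; the custom encoder
    # falls back to str() for it and defers everything else to JSONEncoder.
    config = {"eos_token": AddedToken("</s>", lstrip=False, rstrip=False)}
    print(json.dumps(config, cls=TokenizerConfigEncoder))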
@@ -316,6 +322,7 @@ def generate(
     max_tokens: int = 100,
     verbose: bool = False,
     formatter: Optional[Callable] = None,
+    save_cache_path: Optional[str] = None,
     **kwargs,
 ) -> Union[str, Generator[str, None, None]]:
     """
@@ -330,6 +337,7 @@ def generate(
            Default: ``False``.
        formatter (Optional[Callable]): A function which takes a token and a
            probability and displays it.
+       save_cache_path (Optional[str]): If provided, save the final KV cache to this path.
        kwargs: The remaining options get passed to :func:`generate_step`.
          See :func:`generate_step` for more details.
     """
@@ -346,10 +354,11 @@ def generate(
     tic = time.perf_counter()
     detokenizer.reset()

-    for (token, logprobs), n in zip(
-        generate_step(prompt_tokens, model, **kwargs),
-        range(max_tokens),
-    ):
+    # Get the generate_step generator
+    gen_step = generate_step(prompt_tokens, model, **kwargs)
+
+    # Actual generation loop is here
+    for (token, logprobs), n in zip(gen_step, range(max_tokens)):
         if n == 0:
             prompt_time = time.perf_counter() - tic
             tic = time.perf_counter()
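The loop is refactored so the generator object stays bound to gen_step; the final hunk below reaches into that suspended generator's frame to pull out its cache variable. A tiny, self-contained illustration of the CPython mechanism it relies on (not mlx_lm code):

    def counter():
        total = 0
        while True:
            total += 1
            yield total

    gen = counter()
    next(gen)
    next(gen)
    # Peek at the suspended generator's locals (CPython-specific behavior).
    print(gen.gi_frame.f_locals["total"])  # prints 2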
@@ -359,7 +368,6 @@ def generate(

         if verbose:
             if formatter:
-                # We have to finalize so that the prob corresponds to the last segment
                 detokenizer.finalize()
                 formatter(detokenizer.last_segment, mx.exp(logprobs[token]).item())
             else:
@@ -382,6 +390,32 @@ def generate(
         peak_mem = mx.metal.get_peak_memory() / 2**30
         print(f"Peak memory: {peak_mem:.3f} GB")

+    # Save the final KV cache if requested
+    if save_cache_path:
+        # Extract the final cache state
+        cache = gen_step.gi_frame.f_locals['cache']
+
+        # Prepare the cache data
+        cache_dict = {}
+        for i, c in enumerate(cache):
+            cache_dict[f"{i}_keys"] = c.state[0][..., :c.offset, :]
+            cache_dict[f"{i}_values"] = c.state[1][..., :c.offset, :]
+
+        # Prepare metadata
+        metadata = {
+            "model": model.__class__.__name__,
+            "chat_template": str(getattr(tokenizer, "chat_template", None)),
+            "tokenizer_config": json.dumps({
+                k: (str(v) if isinstance(v, AddedToken) else v)
+                for k, v in tokenizer.init_kwargs.items()
+            }),
+            "max_kv_size": str(kwargs.get("max_kv_size", "None")),
+        }
+
+        # Save to safetensors file
+        mx.save_safetensors(save_cache_path, cache_dict, metadata=metadata)
+        print(f"Final KV cache saved to {save_cache_path}")
+
     return detokenizer.text
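Once written, the file can be read back with mx.load, similar to how prompt caches produced by cache_prompt are consumed. A rough sketch of inspecting a saved cache (the file name is a placeholder, and reconstructing KVCache objects from these arrays is not part of this commit):

    import mlx.core as mx

    # Load the arrays and metadata written by generate(..., save_cache_path=...).
    cache_data, metadata = mx.load("final_kv_cache.safetensors", return_metadata=True)

    num_layers = len(cache_data) // 2  # one keys/values pair per layer
    for i in range(num_layers):
        keys = cache_data[f"{i}_keys"]      # typically (batch, n_kv_heads, seq_len, head_dim)
        values = cache_data[f"{i}_values"]
        print(f"layer {i}: keys {keys.shape}, values {values.shape}")

    print("model:", metadata["model"])
    print("max_kv_size:", metadata["max_kv_size"])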