Update utils.py
Enable saving the KV cache as a safetensors file after a text completion. Once generate_step has finished producing all the tokens, the key/value cache is collected into a dict and written with mx.save_safetensors to a user-specified file location, similar to cache_prompt.
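
A minimal usage sketch of the new option (not part of the commit): the model id, prompt, and output path are placeholders, and reading the file back with mx.load(..., return_metadata=True) assumes a recent mlx version; only the save_cache_path keyword comes from this change.

# Hypothetical end-to-end use of save_cache_path; names are illustrative.
import mlx.core as mx
from mlx_lm import load, generate

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")  # placeholder model id
text = generate(
    model,
    tokenizer,
    prompt="Write a haiku about the sea.",
    max_tokens=100,
    save_cache_path="completion_cache.safetensors",  # option added by this commit
)

# Inspect what was written; return_metadata=True is assumed to be available.
cache_arrays, metadata = mx.load("completion_cache.safetensors", return_metadata=True)
print(sorted(cache_arrays)[:4], metadata["model"])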
@@ -15,7 +15,7 @@ import mlx.core as mx
 import mlx.nn as nn
 from huggingface_hub import snapshot_download
 from mlx.utils import tree_flatten
-from transformers import PreTrainedTokenizer
+from transformers import PreTrainedTokenizer, AddedToken
 
 # Local imports
 from .models.base import KVCache, RotatingKVCache
@@ -309,6 +309,12 @@ def stream_generate(
         yield detokenizer.last_segment
 
 
+class TokenizerConfigEncoder(json.JSONEncoder):
+    def default(self, obj):
+        if isinstance(obj, AddedToken):
+            return str(obj)
+        return super().default(obj)
+
 def generate(
     model: nn.Module,
     tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
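
The TokenizerConfigEncoder added above converts AddedToken objects in a tokenizer config into plain strings, which json.dumps would otherwise reject (the commit message notes this follows cache_prompt). A small sketch of the encoder in isolation; the example dict is made up, and the save path later in this diff actually inlines the same conversion as a dict comprehension instead of using the encoder.

# Sketch: serializing a config dict that contains AddedToken values.
import json
from transformers import AddedToken

example_kwargs = {"eos_token": AddedToken("</s>", lstrip=False), "model_max_length": 32768}
print(json.dumps(example_kwargs, cls=TokenizerConfigEncoder))  # AddedToken -> "</s>"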
@@ -316,6 +322,7 @@ def generate(
     max_tokens: int = 100,
     verbose: bool = False,
     formatter: Optional[Callable] = None,
+    save_cache_path: Optional[str] = None,
     **kwargs,
 ) -> Union[str, Generator[str, None, None]]:
     """
@@ -330,6 +337,7 @@ def generate(
             Default: ``False``.
         formatter (Optional[Callable]): A function which takes a token and a
             probability and displays it.
+        save_cache_path (Optional[str]): If provided, save the final KV cache to this path.
         kwargs: The remaining options get passed to :func:`generate_step`.
             See :func:`generate_step` for more details.
     """
@@ -346,10 +354,11 @@ def generate(
     tic = time.perf_counter()
     detokenizer.reset()
 
-    for (token, logprobs), n in zip(
-        generate_step(prompt_tokens, model, **kwargs),
-        range(max_tokens),
-    ):
+    # Get the generate_step generator
+    gen_step = generate_step(prompt_tokens, model, **kwargs)
+
+    # Actual generation loop is here
+    for (token, logprobs), n in zip(gen_step, range(max_tokens)):
         if n == 0:
             prompt_time = time.perf_counter() - tic
             tic = time.perf_counter()
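
Splitting the loop binds the generate_step generator to gen_step so that, after the loop, the save path can read the cache list out of the still-suspended generator frame via gen_step.gi_frame.f_locals. A toy sketch of that introspection pattern (the counter generator is invented purely for illustration; gi_frame becomes None once a generator runs to completion, so the diff relies on the loop stopping before gen_step is exhausted):

# Toy example (not from the diff): reading a suspended generator's locals.
def counter():
    total = 0
    while True:
        total += 1
        yield total

gen = counter()
next(gen)
next(gen)
print(gen.gi_frame.f_locals["total"])  # -> 2, valid only while gen is suspended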
@@ -359,7 +368,6 @@ def generate(
 
         if verbose:
             if formatter:
-                # We have to finalize so that the prob corresponds to the last segment
                 detokenizer.finalize()
                 formatter(detokenizer.last_segment, mx.exp(logprobs[token]).item())
             else:
@@ -382,6 +390,32 @@ def generate(
         peak_mem = mx.metal.get_peak_memory() / 2**30
         print(f"Peak memory: {peak_mem:.3f} GB")
 
+    # Save the final KV cache if requested
+    if save_cache_path:
+        # Extract the final cache state
+        cache = gen_step.gi_frame.f_locals['cache']
+
+        # Prepare the cache data
+        cache_dict = {}
+        for i, c in enumerate(cache):
+            cache_dict[f"{i}_keys"] = c.state[0][..., :c.offset, :]
+            cache_dict[f"{i}_values"] = c.state[1][..., :c.offset, :]
+
+        # Prepare metadata
+        metadata = {
+            "model": model.__class__.__name__,
+            "chat_template": str(getattr(tokenizer, "chat_template", None)),
+            "tokenizer_config": json.dumps({
+                k: (str(v) if isinstance(v, AddedToken) else v)
+                for k, v in tokenizer.init_kwargs.items()
+            }),
+            "max_kv_size": str(kwargs.get("max_kv_size", "None")),
+        }
+
+        # Save to safetensors file
+        mx.save_safetensors(save_cache_path, cache_dict, metadata=metadata)
+        print(f"Final KV cache saved to {save_cache_path}")
+
     return detokenizer.text
 
 
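
To consume the saved file, a reader would re-pair the "{i}_keys"/"{i}_values" arrays written above. A rough sketch under stated assumptions: load_kv_cache is a hypothetical helper, not part of this commit; mx.load(..., return_metadata=True) is assumed to be supported for safetensors; and how the pairs would be fed back into generate_step (e.g. to seed new KVCache objects) is not shown by this diff.

# Hypothetical reader for the file written by generate(save_cache_path=...).
import mlx.core as mx

def load_kv_cache(path):
    arrays, metadata = mx.load(path, return_metadata=True)
    n_layers = len(arrays) // 2
    # Re-pair the per-layer keys/values using the naming scheme from the diff.
    history = [(arrays[f"{i}_keys"], arrays[f"{i}_values"]) for i in range(n_layers)]
    return history, metadata

history, metadata = load_kv_cache("completion_cache.safetensors")
print(len(history), history[0][0].shape, metadata.get("max_kv_size"))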