diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 5621609d..2d93fe84 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -15,7 +15,7 @@ import mlx.core as mx
 import mlx.nn as nn
 from huggingface_hub import snapshot_download
 from mlx.utils import tree_flatten
-from transformers import PreTrainedTokenizer
+from transformers import AddedToken, PreTrainedTokenizer
 
 # Local imports
 from .models.base import KVCache, RotatingKVCache
@@ -309,6 +309,12 @@ def stream_generate(
         yield detokenizer.last_segment
 
 
+class TokenizerConfigEncoder(json.JSONEncoder):
+    def default(self, obj):
+        if isinstance(obj, AddedToken):
+            return str(obj)
+        return super().default(obj)
+
 def generate(
     model: nn.Module,
     tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
@@ -316,6 +322,7 @@
     max_tokens: int = 100,
     verbose: bool = False,
     formatter: Optional[Callable] = None,
+    save_cache_path: Optional[str] = None,
     **kwargs,
 ) -> Union[str, Generator[str, None, None]]:
     """
@@ -330,6 +337,7 @@
            Default: ``False``.
        formatter (Optional[Callable]): A function which takes a token and a
            probability and displays it.
+       save_cache_path (Optional[str]): If provided, save the final KV cache to this path.
       kwargs: The remaining options get passed to :func:`generate_step`.
          See :func:`generate_step` for more details.
    """
@@ -346,10 +354,11 @@
     tic = time.perf_counter()
     detokenizer.reset()
 
-    for (token, logprobs), n in zip(
-        generate_step(prompt_tokens, model, **kwargs),
-        range(max_tokens),
-    ):
+    # Keep a handle on the generator so the final KV cache can be read back later
+    gen_step = generate_step(prompt_tokens, model, **kwargs)
+
+    # Actual generation loop
+    for (token, logprobs), n in zip(gen_step, range(max_tokens)):
         if n == 0:
             prompt_time = time.perf_counter() - tic
             tic = time.perf_counter()
@@ -359,7 +368,6 @@
         if verbose:
             if formatter:
-                # We have to finalize so that the prob corresponds to the last segment
                 detokenizer.finalize()
                 formatter(detokenizer.last_segment, mx.exp(logprobs[token]).item())
             else:
                 print(detokenizer.last_segment, end="", flush=True)
@@ -382,6 +390,31 @@
         peak_mem = mx.metal.get_peak_memory() / 2**30
         print(f"Peak memory: {peak_mem:.3f} GB")
 
+    # Save the final KV cache if requested
+    if save_cache_path:
+        # Pull the per-layer cache out of the live generator frame (CPython-specific)
+        cache = gen_step.gi_frame.f_locals["cache"]
+
+        # Keep only the filled portion of each layer's keys and values
+        cache_dict = {}
+        for i, c in enumerate(cache):
+            cache_dict[f"{i}_keys"] = c.state[0][..., :c.offset, :]
+            cache_dict[f"{i}_values"] = c.state[1][..., :c.offset, :]
+
+        # Prepare metadata
+        metadata = {
+            "model": model.__class__.__name__,
+            "chat_template": str(getattr(tokenizer, "chat_template", None)),
+            "tokenizer_config": json.dumps(
+                tokenizer.init_kwargs, cls=TokenizerConfigEncoder
+            ),
+            "max_kv_size": str(kwargs.get("max_kv_size", "None")),
+        }
+
+        # Save to a safetensors file
+        mx.save_safetensors(save_cache_path, cache_dict, metadata=metadata)
+        print(f"Final KV cache saved to {save_cache_path}")
+
     return detokenizer.text
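
For reference, a minimal usage sketch of the new flag. The model path and prompt are placeholders; it assumes the package-level mlx_lm.load/generate entry points and uses mx.load(..., return_metadata=True) to read back the safetensors metadata written above:

from mlx_lm import load, generate
import mlx.core as mx

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.2-4bit")  # placeholder model

text = generate(
    model,
    tokenizer,
    prompt="Write a haiku about the ocean.",
    max_tokens=64,
    save_cache_path="kv_cache.safetensors",
)

# The file holds one keys/values pair per layer, plus the metadata dict.
cache_data, metadata = mx.load("kv_cache.safetensors", return_metadata=True)
print(sorted(cache_data))          # e.g. ['0_keys', '0_values', '10_keys', ...]
print(metadata["model"], metadata["max_kv_size"])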
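
And a sketch of reading such a file back into per-layer cache objects, e.g. for reusing a long prompt. This is an assumption-laden illustration, not part of this diff: it presumes KVCache from mlx_lm.models.base takes (head_dim, n_kv_heads) and exposes keys/values/offset attributes, which matches the class imported at the top of utils.py but is not a stable API:

import mlx.core as mx
from mlx_lm.models.base import KVCache

def load_kv_cache(path):
    # Rebuild per-layer KVCache objects from a file written by generate()
    data, metadata = mx.load(path, return_metadata=True)
    cache = []
    for i in range(len(data) // 2):  # one keys/values pair per layer
        keys, values = data[f"{i}_keys"], data[f"{i}_values"]
        # keys shape: (batch, n_kv_heads, seq_len, head_dim) -- assumed layout
        c = KVCache(head_dim=keys.shape[-1], n_kv_heads=keys.shape[1])
        c.keys, c.values = keys, values
        c.offset = keys.shape[2]
        cache.append(c)
    return cache, metadata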