diff --git a/llms/README.md b/llms/README.md
index 5c59c796..08970ee9 100644
--- a/llms/README.md
+++ b/llms/README.md
@@ -170,6 +170,12 @@ The cached prompt is treated as a prefix to the supplied prompt. Also notice
 when using a cached prompt, the model to use is read from the cache and need
 not be supplied explicitly.
 
+Prompt caching can also be used in the Python API in order to avoid
+recomputing the prompt. This is useful in multi-turn dialogues or across
+requests that use the same context. See the
+[example](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/examples/chat.py)
+for more usage details.
+
 ### Supported Models
 
 MLX LM supports thousands of Hugging Face format LLMs. If the model you want to
diff --git a/llms/mlx_lm/examples/chat.py b/llms/mlx_lm/examples/chat.py
index 1dc78957..3bf01688 100644
--- a/llms/mlx_lm/examples/chat.py
+++ b/llms/mlx_lm/examples/chat.py
@@ -5,7 +5,7 @@ An example of a multi-turn chat with prompt caching.
 """
 
 from mlx_lm import generate, load
-from mlx_lm.models.cache import make_prompt_cache
+from mlx_lm.models.cache import load_prompt_cache, make_prompt_cache, save_prompt_cache
 
 model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
 
@@ -45,3 +45,9 @@ response = generate(
     temp=0.0,
     prompt_cache=prompt_cache,
 )
+
+# Save the prompt cache to disk to reuse it at a later time
+save_prompt_cache("mistral_prompt.safetensors", prompt_cache)
+
+# Load the prompt cache from disk
+prompt_cache = load_prompt_cache("mistral_prompt.safetensors")
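For reference, a minimal sketch of a follow-up session that reuses the saved cache. It only uses the `load`, `generate`, `load_prompt_cache`, and keyword arguments exercised by this patch; the follow-up prompt, the `apply_chat_template` call, and the `verbose` flag are illustrative since the full `generate` call is not shown in the hunk above.

```python
# Sketch: resume a chat in a new process using the cache saved by chat.py.
# The prompt content and some generate() arguments below are assumptions
# mirroring the existing example, not part of this patch.

from mlx_lm import generate, load
from mlx_lm.models.cache import load_prompt_cache

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

# Restore the cached conversation state from disk
prompt_cache = load_prompt_cache("mistral_prompt.safetensors")

# Continue the dialogue without recomputing the earlier turns
messages = [{"role": "user", "content": "Summarize our conversation so far."}]
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)

response = generate(
    model,
    tokenizer,
    prompt=prompt,
    verbose=True,
    temp=0.0,
    prompt_cache=prompt_cache,
)
```

Because the cached model identity and prior turns live in the `.safetensors` file, the new session pays only for the new tokens rather than reprocessing the earlier context.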