From b1186e2a81c678087bcaab43202d040b126523c3 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Thu, 29 Aug 2024 15:05:17 -0700
Subject: [PATCH] Docs on prompt scaling (#963)

* docs on prompt scaling

* remove unused var

* nits
---
 llms/README.md              | 42 ++++++++++++++++++++++++++++++++++---
 llms/mlx_lm/cache_prompt.py |  6 +++++-
 llms/mlx_lm/generate.py     | 11 ++++------
 llms/mlx_lm/version.py      |  2 +-
 4 files changed, 49 insertions(+), 12 deletions(-)

diff --git a/llms/README.md b/llms/README.md
index 497c0277..79f26d41 100644
--- a/llms/README.md
+++ b/llms/README.md
@@ -38,7 +38,9 @@ To see a description of all the arguments you can do:
 >>> help(generate)
 ```
 
-Check out the [generation example](https://github.com/ml-explore/mlx-examples/tree/main/llms/mlx_lm/examples/generate_response.py) to see how to use the API in more detail.
+Check out the [generation
+example](https://github.com/ml-explore/mlx-examples/tree/main/llms/mlx_lm/examples/generate_response.py)
+to see how to use the API in more detail.
 
 The `mlx-lm` package also comes with functionality to quantize and optionally
 upload models to the Hugging Face Hub.
@@ -122,10 +124,44 @@ mlx_lm.convert \
     --upload-repo mlx-community/my-4bit-mistral
 ```
 
+### Long Prompts and Generations
+
+MLX LM has some tools to scale efficiently to long prompts and generations:
+
+- A rotating fixed-size key-value cache.
+- Prompt caching.
+
+To use the rotating key-value cache, pass the argument `--max-kv-size n` where
+`n` can be any integer. Smaller values like `512` will use very little RAM but
+result in worse quality. Larger values like `4096` or higher will use more RAM
+but have better quality.
+
+Caching prompts can substantially speed up reusing the same long context with
+different queries. To cache a prompt, use `mlx_lm.cache_prompt`. For example:
+
+```bash
+cat prompt.txt | mlx_lm.cache_prompt \
+  --model mistralai/Mistral-7B-Instruct-v0.3 \
+  --prompt - \
+  --kv-cache-file mistral_prompt.safetensors
+```
+
+Then use the cached prompt with `mlx_lm.generate`:
+
+```bash
+mlx_lm.generate \
+  --kv-cache-file mistral_prompt.safetensors \
+  --prompt "\nSummarize the above text."
+```
+
+The cached prompt is treated as a prefix to the supplied prompt. Note that
+when using a cached prompt, the model is read from the cache and need not be
+supplied explicitly.
+
 ### Supported Models
 
-The example supports Hugging Face format Mistral, Llama, and Phi-2 style
-models. If the model you want to run is not supported, file an
+MLX LM supports thousands of Hugging Face format LLMs. If the model you want to
+run is not supported, file an
 [issue](https://github.com/ml-explore/mlx-examples/issues/new) or better yet,
 submit a pull request.
diff --git a/llms/mlx_lm/cache_prompt.py b/llms/mlx_lm/cache_prompt.py
index ad045f1a..fe088118 100644
--- a/llms/mlx_lm/cache_prompt.py
+++ b/llms/mlx_lm/cache_prompt.py
@@ -56,7 +56,7 @@ def setup_arg_parser():
     parser.add_argument(
         "--max-kv-size",
         type=int,
-        default=1024,
+        default=None,
         help="Set the maximum key-value cache size",
     )
     parser.add_argument(
@@ -147,3 +147,7 @@ def main():
     metadata["tokenizer_config"] = json.dumps(tokenizer_config)
     metadata["max_kv_size"] = str(args.max_kv_size)
     mx.save_safetensors(args.kv_cache_file, cache_dict, metadata)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py
index 4aa4001a..54f6f4d2 100644
--- a/llms/mlx_lm/generate.py
+++ b/llms/mlx_lm/generate.py
@@ -12,7 +12,6 @@ DEFAULT_MAX_TOKENS = 100
 DEFAULT_TEMP = 0.6
 DEFAULT_TOP_P = 1.0
 DEFAULT_SEED = 0
-DEFAULT_MAX_KV_SIZE = 1024
 
 
 def setup_arg_parser():
@@ -81,6 +80,7 @@ def setup_arg_parser():
         "--max-kv-size",
         type=int,
         help="Set the maximum key-value cache size",
+        default=None,
     )
     parser.add_argument(
         "--kv-cache-file",
@@ -199,12 +199,9 @@ def main():
 
     # Determine the max kv size from the kv cache or passed arguments
     max_kv_size = args.max_kv_size
-    if max_kv_size is None:
-        max_kv_size = (
-            int(metadata["max_kv_size"])
-            if cache_history is not None
-            else DEFAULT_MAX_KV_SIZE
-        )
+    if cache_history is not None:
+        max_kv_size = metadata["max_kv_size"]
+        max_kv_size = int(max_kv_size) if max_kv_size.isdigit() else None
 
     generate(
         model,
diff --git a/llms/mlx_lm/version.py b/llms/mlx_lm/version.py
index 41237905..87e86846 100644
--- a/llms/mlx_lm/version.py
+++ b/llms/mlx_lm/version.py
@@ -1,3 +1,3 @@
 # Copyright © 2023-2024 Apple Inc.
 
-__version__ = "0.17.1"
+__version__ = "0.18.0"
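The rotating key-value cache documented in the README hunk above is driven entirely by the `--max-kv-size` flag on `mlx_lm.generate`. As a rough sketch of how it combines with a long generation (the model name reuses the one from the patch; the prompt and token count are arbitrary):

```bash
# Cap the KV cache at 512 entries: the cache rotates instead of growing, so
# memory stays roughly constant over a long generation, at some quality cost.
mlx_lm.generate \
  --model mistralai/Mistral-7B-Instruct-v0.3 \
  --max-kv-size 512 \
  --max-tokens 1000 \
  --prompt "Write a detailed story about a long journey."
```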
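The `generate.py` hunk above also changes how `max_kv_size` is recovered from a cached prompt: `cache_prompt.py` stores it in the safetensors metadata as a string, including the literal `"None"` when no rotating cache was used, so the new code only accepts the value when it parses as an integer. A minimal sketch of that round-trip, assuming only `mlx.core`; the file name and placeholder array are illustrative:

```python
import mlx.core as mx

# cache_prompt.py saves every metadata value as a string, so an unset
# --max-kv-size is serialized as the literal string "None".
metadata = {"max_kv_size": str(None)}
arrays = {"placeholder": mx.zeros((1,))}  # stand-in for the real KV tensors
mx.save_safetensors("example_cache.safetensors", arrays, metadata)

# generate.py loads the metadata back and keeps max_kv_size only when the
# string is all digits; "None" fails isdigit() and falls back to None.
arrays, metadata = mx.load("example_cache.safetensors", return_metadata=True)
raw = metadata["max_kv_size"]
max_kv_size = int(raw) if raw.isdigit() else None
print(max_kv_size)  # None; would be 512 for a cache built with --max-kv-size 512
```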