Docs on prompt scaling (#963)

* docs on prompt scaling
* remove unused var
* nits
@@ -38,7 +38,9 @@ To see a description of all the arguments you can do:
 >>> help(generate)
 ```
 
-Check out the [generation example](https://github.com/ml-explore/mlx-examples/tree/main/llms/mlx_lm/examples/generate_response.py) to see how to use the API in more detail.
+Check out the [generation
+example](https://github.com/ml-explore/mlx-examples/tree/main/llms/mlx_lm/examples/generate_response.py)
+to see how to use the API in more detail.
 
 The `mlx-lm` package also comes with functionality to quantize and optionally
 upload models to the Hugging Face Hub.
@@ -122,10 +124,44 @@ mlx_lm.convert \
     --upload-repo mlx-community/my-4bit-mistral
 ```
 
+### Long Prompts and Generations
+
+MLX LM has some tools to scale efficiently to long prompts and generations:
+
+- A rotating fixed-size key-value cache.
+- Prompt caching.
+
+To use the rotating key-value cache, pass the argument `--max-kv-size n` where
+`n` can be any integer. Smaller values like `512` will use very little RAM but
+result in worse quality. Larger values like `4096` or higher will use more RAM
+but give better quality.
+
+Caching prompts can substantially speed up reusing the same long context with
+different queries. To cache a prompt, use `mlx_lm.cache_prompt`. For example:
+
+```bash
+cat prompt.txt | mlx_lm.cache_prompt \
+  --model mistralai/Mistral-7B-Instruct-v0.3 \
+  --prompt - \
+  --kv-cache-file mistral_prompt.safetensors
+```
+
+Then use the cached prompt with `mlx_lm.generate`:
+
+```bash
+mlx_lm.generate \
+    --kv-cache-file mistral_prompt.safetensors \
+    --prompt "\nSummarize the above text."
+```
+
+The cached prompt is treated as a prefix to the supplied prompt. Note also that
+when using a cached prompt, the model is read from the cache and need not be
+supplied explicitly.
+
 ### Supported Models
 
-The example supports Hugging Face format Mistral, Llama, and Phi-2 style
-models.  If the model you want to run is not supported, file an
+MLX LM supports thousands of Hugging Face format LLMs. If the model you want to
+run is not supported, file an
 [issue](https://github.com/ml-explore/mlx-examples/issues/new) or better yet,
 submit a pull request.
 
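A note on the rotating fixed-size cache the README hunk above describes: once the cache reaches `--max-kv-size` entries, new keys and values overwrite the oldest ones in place, so memory stays bounded while the oldest context is discarded. Below is a toy Python sketch of that rotation; the class name and the `keep` prefix are illustrative assumptions, and this is a simplification rather than mlx_lm's actual implementation, which operates on `mx.array` tensors.

```python
# Toy sketch of a rotating fixed-size key-value cache (illustration only).
class ToyRotatingCache:
    def __init__(self, max_size, keep=4):
        self.max_size = max_size  # upper bound on cached entries (--max-kv-size)
        self.keep = keep          # always retain the first few entries
        self.entries = []
        self._next = keep         # rotation pointer into the overwrite region

    def append(self, entry):
        if len(self.entries) < self.max_size:
            self.entries.append(entry)  # still filling: just grow
        else:
            # Full: overwrite the oldest rotating slot, never the kept prefix.
            self.entries[self._next] = entry
            self._next += 1
            if self._next == self.max_size:
                self._next = self.keep


cache = ToyRotatingCache(max_size=8)
for token in range(20):
    cache.append(token)
print(cache.entries)  # [0, 1, 2, 3, 16, 17, 18, 19]
```

The bounded buffer is what makes a small `--max-kv-size` like `512` cheap on RAM, at the cost of quality on contexts longer than the cache.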
@@ -56,7 +56,7 @@ def setup_arg_parser():
     parser.add_argument(
         "--max-kv-size",
         type=int,
-        default=1024,
+        default=None,
         help="Set the maximum key-value cache size",
     )
     parser.add_argument(
@@ -147,3 +147,7 @@ def main():
     metadata["tokenizer_config"] = json.dumps(tokenizer_config)
     metadata["max_kv_size"] = str(args.max_kv_size)
     mx.save_safetensors(args.kv_cache_file, cache_dict, metadata)
+
+
+if __name__ == "__main__":
+    main()
 
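A subtlety in the hunk above: because `--max-kv-size` now defaults to `None`, `str(args.max_kv_size)` can write the literal string `"None"` into the cache file, and safetensors metadata values always come back as strings. A minimal round-trip sketch, assuming MLX's `mx.save_safetensors(file, arrays, metadata)` and `mx.load(file, return_metadata=True)`; the file name and array are illustrative:

```python
import mlx.core as mx

# Save a dummy cache with string-only metadata (a safetensors requirement).
cache_dict = {"layer_0_keys": mx.zeros((1, 4, 8))}
metadata = {"max_kv_size": str(None)}  # stores the literal string "None"
mx.save_safetensors("toy_cache.safetensors", cache_dict, metadata)

# Load it back; the metadata values are still strings.
arrays, meta = mx.load("toy_cache.safetensors", return_metadata=True)
print(meta["max_kv_size"])            # -> "None"
print(meta["max_kv_size"].isdigit())  # -> False
```

This is why the reader in `generate.py` (next hunk) parses the value with `isdigit()` instead of calling `int()` unconditionally.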
@@ -12,7 +12,6 @@ DEFAULT_MAX_TOKENS = 100
 DEFAULT_TEMP = 0.6
 DEFAULT_TOP_P = 1.0
 DEFAULT_SEED = 0
-DEFAULT_MAX_KV_SIZE = 1024
 
 
 def setup_arg_parser():
@@ -81,6 +80,7 @@ def setup_arg_parser():
         "--max-kv-size",
         type=int,
         help="Set the maximum key-value cache size",
+        default=None,
     )
     parser.add_argument(
         "--kv-cache-file",
@@ -199,12 +199,9 @@ def main():
 
     # Determine the max kv size from the kv cache or passed arguments
     max_kv_size = args.max_kv_size
-    if max_kv_size is None:
-        max_kv_size = (
-            int(metadata["max_kv_size"])
-            if cache_history is not None
-            else DEFAULT_MAX_KV_SIZE
-        )
+    if cache_history is not None:
+        max_kv_size = metadata["max_kv_size"]
+        max_kv_size = int(max_kv_size) if max_kv_size.isdigit() else None
 
     generate(
         model,
 
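The rewritten logic above makes the cache file authoritative: when a cached prompt is supplied, its stored `max_kv_size` overrides the CLI flag, and a stored `"None"` parses back to Python `None` (an unbounded cache). A small standalone sketch of the same resolution rule, with illustrative names:

```python
def resolve_max_kv_size(cli_value, cache_metadata=None):
    # No cached prompt: the CLI flag (possibly None) is used as-is.
    if cache_metadata is None:
        return cli_value
    # A cached prompt is present: its metadata wins over the CLI flag.
    stored = cache_metadata["max_kv_size"]
    return int(stored) if stored.isdigit() else None

print(resolve_max_kv_size(512))                            # 512
print(resolve_max_kv_size(512, {"max_kv_size": "1024"}))   # 1024
print(resolve_max_kv_size(None, {"max_kv_size": "None"}))  # None
```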
@@ -1,3 +1,3 @@
 # Copyright © 2023-2024 Apple Inc.
 
-__version__ = "0.17.1"
+__version__ = "0.18.0"