Docs on prompt scaling (#963)

* docs on prompt scaling

* remove unused var

* nits
Awni Hannun 2024-08-29 15:05:17 -07:00 committed by GitHub
parent 1003a8b2dd
commit b1186e2a81
4 changed files with 49 additions and 12 deletions

@@ -38,7 +38,9 @@ To see a description of all the arguments you can do:
>>> help(generate)
```
-Check out the [generation example](https://github.com/ml-explore/mlx-examples/tree/main/llms/mlx_lm/examples/generate_response.py) to see how to use the API in more detail.
+Check out the [generation
+example](https://github.com/ml-explore/mlx-examples/tree/main/llms/mlx_lm/examples/generate_response.py)
+to see how to use the API in more detail.

The `mlx-lm` package also comes with functionality to quantize and optionally
upload models to the Hugging Face Hub.
@@ -122,10 +124,44 @@ mlx_lm.convert \
    --upload-repo mlx-community/my-4bit-mistral
```

### Long Prompts and Generations
MLX LM has some tools to scale efficiently to long prompts and generations:
- A rotating fixed-size key-value cache.
- Prompt caching.
To use the rotating key-value cache, pass the argument `--max-kv-size n`, where
`n` can be any integer. Smaller values like `512` use very little RAM but
result in worse quality. Larger values like `4096` or higher use more RAM but
give better quality.
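For example, a minimal sketch of capping the cache at 1024 entries (the model
and prompt below are only placeholders):

```bash
mlx_lm.generate \
    --model mistralai/Mistral-7B-Instruct-v0.3 \
    --prompt "Write a detailed story about Albert Einstein." \
    --max-kv-size 1024
```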
Caching prompts can substantially speed up reusing the same long context with
different queries. To cache a prompt, use `mlx_lm.cache_prompt`. For example:
```bash
cat prompt.txt | mlx_lm.cache_prompt \
--model mistralai/Mistral-7B-Instruct-v0.3 \
--prompt - \
--kv-cache-file mistral_prompt.safetensors
```
Then use the cached prompt with `mlx_lm.generate`:
```bash
mlx_lm.generate \
--kv-cache-file mistral_prompt.safetensors \
--prompt "\nSummarize the above text."
```
The cached prompt is treated as a prefix to the supplied prompt. Also notice
that when using a cached prompt, the model is read from the cache and need not
be supplied explicitly.
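Since the cache is just a file, the same long context can be reused for any
number of follow-up queries, for instance (the question here is only an
illustration):

```bash
mlx_lm.generate \
    --kv-cache-file mistral_prompt.safetensors \
    --prompt "\nList the three most important points made in the above text."
```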

### Supported Models

-The example supports Hugging Face format Mistral, Llama, and Phi-2 style
-models. If the model you want to run is not supported, file an
+MLX LM supports thousands of Hugging Face format LLMs. If the model you want to
+run is not supported, file an
[issue](https://github.com/ml-explore/mlx-examples/issues/new) or better yet,
submit a pull request.

@@ -56,7 +56,7 @@ def setup_arg_parser():
     parser.add_argument(
         "--max-kv-size",
         type=int,
-        default=1024,
+        default=None,
         help="Set the maximum key-value cache size",
     )
     parser.add_argument(
@@ -147,3 +147,7 @@ def main():
     metadata["tokenizer_config"] = json.dumps(tokenizer_config)
     metadata["max_kv_size"] = str(args.max_kv_size)
     mx.save_safetensors(args.kv_cache_file, cache_dict, metadata)
+if __name__ == "__main__":
+    main()

@@ -12,7 +12,6 @@ DEFAULT_MAX_TOKENS = 100
 DEFAULT_TEMP = 0.6
 DEFAULT_TOP_P = 1.0
 DEFAULT_SEED = 0
-DEFAULT_MAX_KV_SIZE = 1024


 def setup_arg_parser():
@@ -81,6 +80,7 @@ def setup_arg_parser():
         "--max-kv-size",
         type=int,
         help="Set the maximum key-value cache size",
+        default=None,
     )
     parser.add_argument(
         "--kv-cache-file",
@@ -199,12 +199,9 @@ def main():
     # Determine the max kv size from the kv cache or passed arguments
     max_kv_size = args.max_kv_size
-    if max_kv_size is None:
-        max_kv_size = (
-            int(metadata["max_kv_size"])
-            if cache_history is not None
-            else DEFAULT_MAX_KV_SIZE
-        )
+    if cache_history is not None:
+        max_kv_size = metadata["max_kv_size"]
+        max_kv_size = int(max_kv_size) if max_kv_size.isdigit() else None

     generate(
         model,

@@ -1,3 +1,3 @@
 # Copyright © 2023-2024 Apple Inc.

-__version__ = "0.17.1"
+__version__ = "0.18.0"