mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00
Docs on prompt scaling (#963)
* docs on prompt scaling * remove unused var * nits
This commit is contained in:
parent
1003a8b2dd
commit
b1186e2a81
@ -38,7 +38,9 @@ To see a description of all the arguments you can do:
|
|||||||
>>> help(generate)
|
>>> help(generate)
|
||||||
```
|
```
|
||||||
|
|
||||||
Check out the [generation example](https://github.com/ml-explore/mlx-examples/tree/main/llms/mlx_lm/examples/generate_response.py) to see how to use the API in more detail.
|
Check out the [generation
|
||||||
|
example](https://github.com/ml-explore/mlx-examples/tree/main/llms/mlx_lm/examples/generate_response.py)
|
||||||
|
to see how to use the API in more detail.
|
||||||
|
|
||||||
The `mlx-lm` package also comes with functionality to quantize and optionally
|
The `mlx-lm` package also comes with functionality to quantize and optionally
|
||||||
upload models to the Hugging Face Hub.
|
upload models to the Hugging Face Hub.
|
||||||
@ -122,10 +124,44 @@ mlx_lm.convert \
|
|||||||
--upload-repo mlx-community/my-4bit-mistral
|
--upload-repo mlx-community/my-4bit-mistral
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Long Prompts and Generations
|
||||||
|
|
||||||
|
MLX LM has some tools to scale efficiently to long prompts and generations:
|
||||||
|
|
||||||
|
- A rotating fixed-size key-value cache.
|
||||||
|
- Prompt caching
|
||||||
|
|
||||||
|
To use the rotating key-value cache pass the argument `--max-kv-size n` where
|
||||||
|
`n` can be any integer. Smaller values like `512` will use very little RAM but
|
||||||
|
result in worse quality. Larger values like `4096` or higher will use more RAM
|
||||||
|
but have better quality.
|
||||||
|
|
||||||
|
Caching prompts can substantially speedup reusing the same long context with
|
||||||
|
different queries. To cache a prompt use `mlx_lm.cache_prompt`. For example:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cat prompt.txt | mlx_lm.cache_prompt \
|
||||||
|
--model mistralai/Mistral-7B-Instruct-v0.3 \
|
||||||
|
--prompt - \
|
||||||
|
--kv-cache-file mistral_prompt.safetensors
|
||||||
|
```
|
||||||
|
|
||||||
|
Then use the cached prompt with `mlx_lm.generate`:
|
||||||
|
|
||||||
|
```
|
||||||
|
mlx_lm.generate \
|
||||||
|
--kv-cache-file mistral_prompt.safetensors \
|
||||||
|
--prompt "\nSummarize the above text."
|
||||||
|
```
|
||||||
|
|
||||||
|
The cached prompt is treated as a prefix to the supplied prompt. Also notice
|
||||||
|
when using a cached prompt, the model to use is read from the cache and need
|
||||||
|
not be supplied explicitly.
|
||||||
|
|
||||||
### Supported Models
|
### Supported Models
|
||||||
|
|
||||||
The example supports Hugging Face format Mistral, Llama, and Phi-2 style
|
MLX LM supports thousands of Hugging Face format LLMs. If the model you want to
|
||||||
models. If the model you want to run is not supported, file an
|
run is not supported, file an
|
||||||
[issue](https://github.com/ml-explore/mlx-examples/issues/new) or better yet,
|
[issue](https://github.com/ml-explore/mlx-examples/issues/new) or better yet,
|
||||||
submit a pull request.
|
submit a pull request.
|
||||||
|
|
||||||
|
@ -56,7 +56,7 @@ def setup_arg_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max-kv-size",
|
"--max-kv-size",
|
||||||
type=int,
|
type=int,
|
||||||
default=1024,
|
default=None,
|
||||||
help="Set the maximum key-value cache size",
|
help="Set the maximum key-value cache size",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -147,3 +147,7 @@ def main():
|
|||||||
metadata["tokenizer_config"] = json.dumps(tokenizer_config)
|
metadata["tokenizer_config"] = json.dumps(tokenizer_config)
|
||||||
metadata["max_kv_size"] = str(args.max_kv_size)
|
metadata["max_kv_size"] = str(args.max_kv_size)
|
||||||
mx.save_safetensors(args.kv_cache_file, cache_dict, metadata)
|
mx.save_safetensors(args.kv_cache_file, cache_dict, metadata)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
@ -12,7 +12,6 @@ DEFAULT_MAX_TOKENS = 100
|
|||||||
DEFAULT_TEMP = 0.6
|
DEFAULT_TEMP = 0.6
|
||||||
DEFAULT_TOP_P = 1.0
|
DEFAULT_TOP_P = 1.0
|
||||||
DEFAULT_SEED = 0
|
DEFAULT_SEED = 0
|
||||||
DEFAULT_MAX_KV_SIZE = 1024
|
|
||||||
|
|
||||||
|
|
||||||
def setup_arg_parser():
|
def setup_arg_parser():
|
||||||
@ -81,6 +80,7 @@ def setup_arg_parser():
|
|||||||
"--max-kv-size",
|
"--max-kv-size",
|
||||||
type=int,
|
type=int,
|
||||||
help="Set the maximum key-value cache size",
|
help="Set the maximum key-value cache size",
|
||||||
|
default=None,
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--kv-cache-file",
|
"--kv-cache-file",
|
||||||
@ -199,12 +199,9 @@ def main():
|
|||||||
|
|
||||||
# Determine the max kv size from the kv cache or passed arguments
|
# Determine the max kv size from the kv cache or passed arguments
|
||||||
max_kv_size = args.max_kv_size
|
max_kv_size = args.max_kv_size
|
||||||
if max_kv_size is None:
|
if cache_history is not None:
|
||||||
max_kv_size = (
|
max_kv_size = metadata["max_kv_size"]
|
||||||
int(metadata["max_kv_size"])
|
max_kv_size = int(max_kv_size) if max_kv_size.isdigit() else None
|
||||||
if cache_history is not None
|
|
||||||
else DEFAULT_MAX_KV_SIZE
|
|
||||||
)
|
|
||||||
|
|
||||||
generate(
|
generate(
|
||||||
model,
|
model,
|
||||||
|
@ -1,3 +1,3 @@
|
|||||||
# Copyright © 2023-2024 Apple Inc.
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
__version__ = "0.17.1"
|
__version__ = "0.18.0"
|
||||||
|
Loading…
Reference in New Issue
Block a user