Example of response generation with optional arguments (#853)

* Generate response with optional arguments

* Reference response generation example

* Include transformers and sentencepiece

* Update example to run Mistral-7B-Instruct-v0.3

* Link to generation example

* Style changes from pre-commit

This commit is contained in:
Alex Wozniakowski 2024-07-09 06:49:59 -07:00 committed by GitHub
parent 68e88d42fb
commit 63800c8feb
3 changed files with 43 additions and 1 deletion

View File

@@ -38,6 +38,8 @@ To see a description of all the arguments you can do:
>>> help(generate)
```
Check out the [generation example](https://github.com/ml-explore/mlx-examples/tree/main/llms/mlx_lm/examples/generate_response.py) to see how to use the API in more detail.
The `mlx-lm` package also comes with functionality to quantize and optionally
upload models to the Hugging Face Hub.
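
As a rough sketch of that quantization and upload path, the snippet below uses the `convert` function from `mlx_lm`; the keyword names (`quantize`, `upload_repo`) and the destination repo are assumptions here and may differ in your installed version.

```python
from mlx_lm import convert

# Quantize the weights and optionally push the converted model to the
# Hugging Face Hub; drop "upload_repo" to keep the result local only.
# The repo name below is a placeholder.
convert(
    "mistralai/Mistral-7B-Instruct-v0.3",
    quantize=True,
    upload_repo="my-username/Mistral-7B-Instruct-v0.3-4bit",
)
```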

View File

@@ -0,0 +1,40 @@
from mlx_lm import generate, load

# Specify the checkpoint
checkpoint = "mistralai/Mistral-7B-Instruct-v0.3"

# Load the corresponding model and tokenizer
model, tokenizer = load(path_or_hf_repo=checkpoint)

# Specify the prompt and conversation history
prompt = "Why is the sky blue?"
conversation = [{"role": "user", "content": prompt}]

# Transform the prompt into the chat template
prompt = tokenizer.apply_chat_template(
    conversation=conversation, tokenize=False, add_generation_prompt=True
)

# Specify the maximum number of tokens
max_tokens = 1_000

# Specify if tokens and timing information will be printed
verbose = True

# Some optional arguments for causal language model generation
generation_args = {
    "temp": 0.7,
    "repetition_penalty": 1.2,
    "repetition_context_size": 20,
    "top_p": 0.95,
}

# Generate a response with the specified settings
response = generate(
    model=model,
    tokenizer=tokenizer,
    prompt=prompt,
    max_tokens=max_tokens,
    verbose=verbose,
    **generation_args,
)
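
As a possible follow-up to the example above (not part of the committed file), the conversation list can be extended with the returned text to generate a second turn; the extra user message is purely illustrative.

```python
# Illustrative second turn: append the assistant's reply and a new user
# message, re-apply the chat template, and generate again with the same
# optional arguments.
conversation.append({"role": "assistant", "content": response})
conversation.append({"role": "user", "content": "Summarize that in one sentence."})
prompt = tokenizer.apply_chat_template(
    conversation=conversation, tokenize=False, add_generation_prompt=True
)
response = generate(
    model=model,
    tokenizer=tokenizer,
    prompt=prompt,
    max_tokens=max_tokens,
    verbose=verbose,
    **generation_args,
)
```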

View File

@@ -1,6 +1,6 @@
mlx>=0.14.1
numpy
transformers>=4.39.3
transformers[sentencepiece]>=4.39.3
protobuf
pyyaml
jinja2