
# Prompt Lookup Decoding

This example implements prompt lookup decoding for LLM text generation. It uses Mistral as the model, but the model can be changed with minimal adjustments to `decoder.py`. Prompt lookup decoding is a variation of speculative decoding in which the draft model is replaced with a simple prompt lookup function that generates the draft. For input-grounded tasks such as summarization, document QA, or code editing, this technique can provide significant speedups with no effect on output quality.
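
To give a sense of how the draft is produced, here is a minimal, self-contained sketch of one lookup step (a hypothetical `prompt_lookup_draft` helper for illustration only; the actual logic lives in `decoder.py` and may differ in detail): the last few generated tokens are matched as an n-gram against the earlier context, and the tokens that followed the match are proposed as the draft.

```python
def prompt_lookup_draft(tokens, ngram_max=3, ngram_min=1, n_draft=10):
    """Propose up to n_draft candidate tokens by n-gram matching.

    tokens: the prompt plus everything generated so far, as a list of token ids.
    """
    # Try the longest n-grams first; longer matches tend to give better drafts.
    for n in range(ngram_max, ngram_min - 1, -1):
        if len(tokens) <= n:
            continue
        ngram = tokens[-n:]
        # Scan backwards through the context (excluding the trailing n-gram itself).
        for i in range(len(tokens) - n - 1, -1, -1):
            if tokens[i : i + n] == ngram:
                # Propose the tokens that followed the matched n-gram.
                return tokens[i + n : i + n + n_draft]
    return []  # No match: fall back to plain autoregressive decoding.


# Example: the 2-gram [2, 3] recurs, so the tokens after its earlier
# occurrence are proposed as the draft.
print(prompt_lookup_draft([1, 2, 3, 4, 2, 3]))  # [4, 2, 3]
```

As in standard speculative decoding, the draft tokens are verified by the full model, so a bad draft only costs the rejected positions and the final output is unchanged.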

## Setup

Install the dependencies:

```shell
pip install -r requirements.txt
```

Next, download a Mistral model and tokenizer:

```shell
curl -O https://files.mistral-7b-v0-1.mistral.ai/mistral-7B-v0.1.tar
tar -xf mistral-7B-v0.1.tar
```

Then, convert the weights with:

```shell
python convert.py --torch-path <path_to_torch>
```

To generate a 4-bit quantized model, use `-q`. For a full list of options, run:

```shell
python convert.py --help
```

By default, the conversion script will make the directory `mlx_model/` and save the converted `weights.npz`, `tokenizer.model`, and `config.json` there.
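
For example, after conversion the output directory should look roughly like this:

```
mlx_model/
├── config.json
├── tokenizer.model
└── weights.npz
```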

> [!TIP]
> Alternatively, you can download a few converted checkpoints from the MLX Community organization on Hugging Face and skip the conversion step.

## Run

```shell
python decoder.py --prompt "[INST] Repeat the following phrase 10 times: 'The quick brown fox jumps over the lazy dog.'. Don't say anything else. [/INST]" --max-tokens 250
```

Alternatively, you can call the API from Python:

```python
from decoder import PromptLookupDecoder

prompt = "[INST] Repeat the following phrase 10 times: 'The quick brown fox jumps over the lazy dog.'. Don't say anything else. [/INST] "

engine = PromptLookupDecoder("mlx_model")

engine.prompt_lookup(
    prompt=prompt,
    max_tokens=250,
    n_draft=10,
    ngram_max=3,
    ngram_min=1,
    temp=0.0,
    seed=0,
    color=True,
)
```
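
Here, `n_draft` sets how many draft tokens are proposed per lookup, while `ngram_max` and `ngram_min` bound the n-gram lengths searched for in the prompt; see `decoder.py` for the exact semantics of each argument.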