mlx-examples/llms
otriscon 46da74fea2
Unify attention mask in LLMs (#911)
* Unify attention mask creation in LLMs.

Currently, each model implementation in `mlx-examples/llms/models` has ad-hoc
code to create a mask for the attention mechanism. This usually takes the form:

```
    mask = None
    if h.shape[1] > 1:
        mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
        mask = mask.astype(h.dtype)
```

This correctly creates a mask only if the input consists of more than one token.
But this code assumes the multi-token input is at the beginning of inference.
If, for example, we are evaluating multiple tokens because of speculative
decoding or prompt cache reuse, this mask will not have the correct shape and
will raise an exception in the attention computation.

Some of the models correctly implement the mask creation with code like this:

```
    mask = None
    if h.shape[1] > 1:
        mask = create_additive_causal_mask(
            h.shape[1], cache[0].offset if cache is not None else 0
        )
        mask = mask.astype(h.dtype)
```

This commit unifies the attention mask creation for all models with a new
function `create_attention_mask`, reducing code duplication and helping all
models support inference performance enhancements like those mentioned above.
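
As a rough illustration of why the offset matters, the following is a minimal pure-Python sketch of an additive causal mask that accounts for tokens already in the cache. It is not the code from this commit: the function name `causal_mask_sketch` and the use of nested lists instead of MLX arrays are assumptions made so the example runs without MLX installed.

```python
def causal_mask_sketch(n_new, offset=0, neg=-1e9):
    """Additive causal mask for n_new query tokens.

    offset is the number of tokens already held in the key-value cache.
    The result has n_new rows and (offset + n_new) columns: one column
    per cached or new key position.
    """
    mask = []
    for i in range(n_new):
        query_pos = offset + i  # absolute position of this query token
        row = [0.0 if key_pos <= query_pos else neg
               for key_pos in range(offset + n_new)]
        mask.append(row)
    return mask


# Prompt processing with an empty cache: a square 3 x 3 mask.
prompt_mask = causal_mask_sketch(3)

# Two new tokens evaluated on top of 3 cached tokens: a 2 x 5 mask.
# A 2 x 2 mask built without the offset would not match the 5 key positions.
cached_mask = causal_mask_sketch(2, offset=3)
```

With an offset, each new token can attend to every cached position plus the new positions up to and including its own, which is the shape the attention computation expects.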

* Allow batches in LLM key-value cache

The current implementation of the LLM key-value cache assumes that
the input batch is of size 1. Input batching (evaluating multiple
alternative inputs at the same time) can be a valuable tool for
speculative sampling and other techniques.

This change removes the hard-coded batch size from the code that
resizes the key-value cache.
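
To make the change concrete, here is a hedged sketch of how the shape of a newly allocated cache block could be derived from the incoming keys instead of assuming a batch size of 1. The function name, the `(batch, heads, tokens, head_dim)` layout, and the `step` of 256 are illustrative assumptions, not taken from the commit.

```python
def new_block_shape(keys_shape, step=256):
    """Shape of the buffer to allocate when the cache has to grow.

    keys_shape is assumed to be (batch, n_kv_heads, n_new_tokens, head_dim).
    The token dimension is rounded up to a multiple of step.
    """
    batch, n_kv_heads, n_new, head_dim = keys_shape
    n_steps = (step + n_new - 1) // step  # ceiling division
    # The batch size comes from the input rather than being fixed at 1.
    return (batch, n_kv_heads, n_steps * step, head_dim)


single = new_block_shape((1, 8, 10, 64))
batched = new_block_shape((4, 8, 300, 64))
```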

* Simplify causal mask creation

Use the same codepath regardless of whether there's an offset or
not. Addresses [this comment](https://github.com/ml-explore/mlx-examples/pull/911#discussion_r1691459717).

* Use old-style type annotation to avoid linter error
2024-07-25 16:45:22 -07:00
gguf_llm fixed the requirements (#803) 2024-05-29 06:14:19 -07:00
llama Quantize embedding / Update quantize API (#680) 2024-04-18 18:16:10 -07:00
mistral Quantize embedding / Update quantize API (#680) 2024-04-18 18:16:10 -07:00
mixtral Quantize embedding / Update quantize API (#680) 2024-04-18 18:16:10 -07:00
mlx_lm Unify attention mask in LLMs (#911) 2024-07-25 16:45:22 -07:00
speculative_decoding Fix incorrect type annotation (#720) 2024-04-24 15:52:43 -07:00
tests support load model by custom get_model_classes (#899) 2024-07-25 11:01:17 -07:00
CONTRIBUTING.md Enable unit testing in Circle and start some MLX LM tests (#545) 2024-03-07 09:31:57 -08:00
MANIFEST.in Mlx llm package (#301) 2024-01-12 10:25:56 -08:00
README.md Example of response generation with optional arguments (#853) 2024-07-09 06:49:59 -07:00
setup.py Configuration-based use of HF hub-hosted datasets for training (#701) 2024-06-26 10:20:50 -07:00

Generate Text with LLMs and MLX

The easiest way to get started is to install the mlx-lm package:

With pip:

pip install mlx-lm

With conda:

conda install -c conda-forge mlx-lm

The mlx-lm package also has:

Python API

You can use mlx-lm as a module:

from mlx_lm import load, generate

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

response = generate(model, tokenizer, prompt="hello", verbose=True)

To see a description of all the arguments, run:

>>> help(generate)

Check out the generation example to see how to use the API in more detail.

The mlx-lm package also comes with functionality to quantize and optionally upload models to the Hugging Face Hub.

You can convert models in the Python API with:

from mlx_lm import convert

repo = "mistralai/Mistral-7B-Instruct-v0.3"
upload_repo = "mlx-community/My-Mistral-7B-Instruct-v0.3-4bit"

convert(repo, quantize=True, upload_repo=upload_repo)

This will generate a 4-bit quantized Mistral 7B and upload it to the repo mlx-community/My-Mistral-7B-Instruct-v0.3-4bit. It will also save the converted model in the path mlx_model by default.

To see a description of all the arguments, run:

>>> help(convert)

Streaming

For streaming generation, use the stream_generate function. It returns a generator object that streams the output text. For example:

from mlx_lm import load, stream_generate

repo = "mlx-community/Mistral-7B-Instruct-v0.3-4bit"
model, tokenizer = load(repo)

prompt = "Write a story about Einstein"

for t in stream_generate(model, tokenizer, prompt, max_tokens=512):
    print(t, end="", flush=True)
print()
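
If you want the full text as well as the live output, you can accumulate the chunks as they arrive. The sketch below assumes, as in the example above, that the generator yields text chunks; `fake_stream` is a hypothetical stand-in for `stream_generate(model, tokenizer, prompt, max_tokens=512)` so that the snippet runs without a model.

```python
def fake_stream():
    # Hypothetical stand-in for stream_generate(...); yields text chunks.
    yield from ["Once ", "upon ", "a ", "time"]


def print_and_collect(stream):
    """Print each chunk as it arrives and return the concatenated text."""
    chunks = []
    for t in stream:
        print(t, end="", flush=True)
        chunks.append(t)
    print()
    return "".join(chunks)


text = print_and_collect(fake_stream())
```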

Command Line

You can also use mlx-lm from the command line with:

mlx_lm.generate --model mistralai/Mistral-7B-Instruct-v0.3 --prompt "hello"

This will download a Mistral 7B model from the Hugging Face Hub and generate text using the given prompt.

For a full list of options run:

mlx_lm.generate --help

To quantize a model from the command line run:

mlx_lm.convert --hf-path mistralai/Mistral-7B-Instruct-v0.3 -q

For more options run:

mlx_lm.convert --help

You can upload new models to Hugging Face by specifying --upload-repo to convert. For example, to upload a quantized Mistral-7B model to the MLX Hugging Face community you can do:

mlx_lm.convert \
    --hf-path mistralai/Mistral-7B-Instruct-v0.3 \
    -q \
    --upload-repo mlx-community/my-4bit-mistral

Supported Models

The example supports Hugging Face format Mistral, Llama, and Phi-2 style models. If the model you want to run is not supported, file an issue or, better yet, submit a pull request.

Here are a few examples of Hugging Face models that work with this example:

Most Mistral, Llama, Phi-2, and Mixtral style models should work out of the box.

For some models (such as Qwen and plamo) the tokenizer requires you to enable the trust_remote_code option. You can do this by passing --trust-remote-code on the command line. If you don't specify the flag explicitly, you will be prompted in the terminal to trust remote code when running the model.

For Qwen models you must also specify the eos_token. You can do this by passing --eos-token "<|endoftext|>" on the command line.

These options can also be set in the Python API. For example:

model, tokenizer = load(
    "qwen/Qwen-7B",
    tokenizer_config={"eos_token": "<|endoftext|>", "trust_remote_code": True},
)
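
When these options come from user input, it can help to build the tokenizer_config dictionary in one place. The helper below is a hypothetical convenience function, not part of mlx-lm; it only assembles the dictionary that is passed to load in the example above.

```python
def build_tokenizer_config(eos_token=None, trust_remote_code=False):
    """Assemble a tokenizer_config dict, including only the options that are set."""
    config = {}
    if trust_remote_code:
        config["trust_remote_code"] = True
    if eos_token is not None:
        config["eos_token"] = eos_token
    return config


qwen_config = build_tokenizer_config(eos_token="<|endoftext|>", trust_remote_code=True)
# The result can then be passed as load("qwen/Qwen-7B", tokenizer_config=qwen_config).
```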