mlx-examples/llms/mlx_lm
otriscon 46da74fea2
Unify attention mask in LLMs (#911)
* Unify attention mask creation in LLMs.

Currently, each model implementation in `mlx-examples/llms/models` has ad-hoc
code to create a mask for the attention mechanism. This usually takes the form:

```
    mask = None
    if h.shape[1] > 1:
        mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
        mask = mask.astype(h.dtype)
```

This correctly creates a mask only if the input consists of more than one token.
But this code assumes the multi-token input is at the beginning of inference.
If, for example, we are evaluating multiple tokens because of speculative
decoding or prompt cache reuse, this mask will not have the correct shape and
the attention computation will raise an exception.

Some of the models correctly implement the mask creation with code like this:

```
    mask = None
    if h.shape[1] > 1:
        mask = create_additive_causal_mask(
            h.shape[1], cache[0].offset if cache is not None else 0
        )
        mask = mask.astype(h.dtype)
```

This commit unifies the attention mask creation for all models with a new
function `create_attention_mask`, reducing code duplication and helping all
models support inference performance enhancements like those mentioned above.
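
The unified logic described above can be sketched roughly as follows. This is
a plain-Python stand-in for illustration only, not the MLX implementation: the
names `causal_mask` and `attention_mask`, the nested-list representation, and
the `-1e9` fill value are assumptions made for this sketch. The point it shows
is that the mask needs one row per new token and one column per cached or new
token, so its width depends on the cache offset.

```python
from typing import List, Optional

NEG_INF = -1e9  # assumed fill value added to masked attention scores


def causal_mask(n_new: int, offset: int = 0) -> List[List[float]]:
    """Additive causal mask with n_new rows and offset + n_new columns.

    Row i describes the new token at absolute position offset + i, which
    may attend to every key at position <= offset + i.
    """
    total = offset + n_new
    return [
        [0.0 if key <= offset + query else NEG_INF for key in range(total)]
        for query in range(n_new)
    ]


def attention_mask(n_new: int, offset: int = 0) -> Optional[List[List[float]]]:
    """Return None for a single-token step, else an offset-aware mask."""
    if n_new <= 1:
        return None
    return causal_mask(n_new, offset)


# Two new tokens evaluated on top of three cached tokens: the first new
# token cannot see the second, so only the last entry of row 0 is masked.
mask = attention_mask(2, offset=3)
```

With `offset=0` the same code path yields the usual square causal mask, so no
special case is needed for the start of inference.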

* Allow batches in LLM key-value cache

The current implementation of the LLM key-value cache assumes that
the input batch is of size 1. Input batching (evaluating multiple
alternative inputs at the same time) can be a valuable tool for
speculative sampling and other techniques.

This change removes the hard-coded batch size from the code that
resizes the key-value cache.
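
As a rough illustration of this change, the sketch below computes the shape of
the block a cache might allocate when it grows. The helper name
`new_block_shape`, the `(batch, n_kv_heads, seq_len, head_dim)` layout, and the
step size of 256 are assumptions for this example rather than details taken
from the commit. The only point is that the batch dimension is read from the
incoming keys instead of being fixed at 1.

```python
from typing import Tuple

Shape = Tuple[int, int, int, int]


def new_block_shape(keys_shape: Shape, step: int = 256) -> Shape:
    """Shape of the cache block to allocate for incoming keys.

    keys_shape is assumed to be (batch, n_kv_heads, n_new_tokens, head_dim).
    The sequence dimension is rounded up to a multiple of `step` so the
    cache does not have to be resized on every token.
    """
    batch, n_kv_heads, n_new, head_dim = keys_shape
    n_steps = (step + n_new - 1) // step  # ceiling division
    # The batch size comes from the input rather than a hard-coded 1.
    return (batch, n_kv_heads, n_steps * step, head_dim)


# Four alternative inputs evaluated together, ten new tokens each.
shape = new_block_shape((4, 8, 10, 64))
```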

* Simplify causal mask creation

Use the same codepath regardless of whether there's an offset or
not. Addresses [this comment](https://github.com/ml-explore/mlx-examples/pull/911#discussion_r1691459717).

* Use old-style type annotation to avoid linter error
2024-07-25 16:45:22 -07:00
examples Example of response generation with optional arguments (#853) 2024-07-09 06:49:59 -07:00
models Unify attention mask in LLMs (#911) 2024-07-25 16:45:22 -07:00
tuner Add GPT-neox model (#863) 2024-07-11 06:13:17 -07:00
__init__.py mlx_lm: Add Streaming Capability to Generate Function (#807) 2024-06-03 09:04:39 -07:00
convert.py Create executables for generate, lora, server, merge, convert (#682) 2024-04-16 16:08:49 -07:00
fuse.py Block sparse MM MoEs (#782) 2024-05-21 15:58:08 -07:00
generate.py mlx_lm: Add Streaming Capability to Generate Function (#807) 2024-06-03 09:04:39 -07:00
gguf.py fix(mlx-lm): type hints in gguf.py (#621) 2024-03-26 07:56:01 -07:00
LORA.md Configuration-based use of HF hub-hosted datasets for training (#701) 2024-06-26 10:20:50 -07:00
lora.py Pass use_dora parameter to linear_to_lora_layers (#885) 2024-07-11 14:34:34 -07:00
MANAGE.md Add model management functionality for local caches (#736) 2024-05-03 12:20:13 -07:00
manage.py Add model management functionality for local caches (#736) 2024-05-03 12:20:13 -07:00
MERGE.md Create executables for generate, lora, server, merge, convert (#682) 2024-04-16 16:08:49 -07:00
merge.py Create executables for generate, lora, server, merge, convert (#682) 2024-04-16 16:08:49 -07:00
py.typed Add py.typed to support PEP-561 (type-hinting) (#389) 2024-01-30 21:17:38 -08:00
README.md feat: move lora into mlx-lm (#337) 2024-01-23 08:44:37 -08:00
requirements.txt Example of response generation with optional arguments (#853) 2024-07-09 06:49:59 -07:00
sample_utils.py Use async eval (#670) 2024-04-11 13:18:23 -07:00
SERVER.md Logprobs info to completion API (#806) 2024-06-23 10:35:13 -07:00
server.py keep the server in a valid state (#889) 2024-07-15 18:35:36 -07:00
tokenizer_utils.py fix yi (#852) 2024-06-27 06:38:19 -07:00
UPLOAD.md Mlx llm package (#301) 2024-01-12 10:25:56 -08:00
utils.py support load model by custom get_model_classes (#899) 2024-07-25 11:01:17 -07:00
version.py Configuration-based use of HF hub-hosted datasets for training (#701) 2024-06-26 10:20:50 -07:00

Generate Text with MLX and 🤗 Hugging Face

This is an example of large language model text generation that can pull models from the Hugging Face Hub.

For more information on this example, see the README in the parent directory.

This package also supports fine-tuning with LoRA or QLoRA. For more information, see the LoRA documentation.