Examples in the MLX framework
Unify attention mask in LLMs (#911)
* Unify attention mask creation in LLMs.

Currently, each model implementation in `mlx-examples/llms/models` has ad-hoc
code to create a mask for the attention mechanism. This usually takes the form:

```
    mask = None
    if h.shape[1] > 1:
        mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
        mask = mask.astype(h.dtype)
```

This correctly creates a mask only when the input consists of more than one
token, but it assumes that multi-token input occurs only at the beginning of
inference. If we are instead evaluating multiple tokens mid-stream, for
example because of speculative decoding or prompt cache reuse, the mask will
not have the correct shape and will cause an exception in the attention
computation. Concretely, with five tokens already in the cache and three new
tokens to evaluate, the attention scores have shape `(3, 8)`, but the code
above builds a `(3, 3)` mask.

Some of the models correctly implement the mask creation with code like this:

```
    mask = None
    if h.shape[1] > 1:
        mask = create_additive_causal_mask(
            h.shape[1], cache[0].offset if cache is not None else 0
        )
        mask = mask.astype(h.dtype)
```

This commit unifies the attention mask creation for all models with a new
function `create_attention_mask`, reducing code duplication and helping all
models support inference performance enhancements like those mentioned above.
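
A minimal sketch of what such a unified helper can look like, assuming the
cache interface from the snippet above (a list whose entries expose an
`offset` count of tokens already processed); the merged implementation may
differ in details:

```
def create_attention_mask(h, cache=None):
    # Only multi-token inputs need a causal mask; a single decoded
    # token may attend to everything already in the cache.
    T = h.shape[1]
    if T <= 1:
        return None
    # Offset the mask by however many tokens the cache already holds,
    # so prompt cache reuse and speculative decoding get correct shapes.
    offset = cache[0].offset if cache is not None else 0
    mask = create_additive_causal_mask(T, offset)
    return mask.astype(h.dtype)
```

Here `create_additive_causal_mask` is the offset-aware helper sketched under
the causal-mask simplification below.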

* Allow batches in LLM key-value cache

The current implementation of the LLM key-value cache assumes that
the input batch is of size 1. Input batching (evaluating multiple
alternative inputs at the same time) can be a valuable tool for
speculative sampling and other techniques.

This change removes the hard-coded batch size from the code that
resizes the key-value cache.
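
A sketch of the idea, with the batch size read from the incoming keys rather
than hard-coded to 1; the `(B, n_kv_heads, T, head_dim)` layout and the fixed
growth step are illustrative assumptions, not necessarily the merged code:

```
import mlx.core as mx

class KVCache:
    def __init__(self, head_dim, n_kv_heads, step=256):
        self.head_dim = head_dim
        self.n_kv_heads = n_kv_heads
        self.step = step  # grow capacity in fixed-size chunks
        self.keys = None
        self.values = None
        self.offset = 0

    def update_and_fetch(self, keys, values):
        prev = self.offset
        if self.keys is None or prev + keys.shape[2] > self.keys.shape[2]:
            # Take the batch size from the input instead of assuming 1.
            B = keys.shape[0]
            n_steps = (self.step + keys.shape[2] - 1) // self.step
            shape = (B, self.n_kv_heads, n_steps * self.step, self.head_dim)
            new_k = mx.zeros(shape, keys.dtype)
            new_v = mx.zeros(shape, values.dtype)
            if self.keys is not None:
                self.keys = mx.concatenate([self.keys, new_k], axis=2)
                self.values = mx.concatenate([self.values, new_v], axis=2)
            else:
                self.keys, self.values = new_k, new_v
        # Write the new entries and return views over the valid prefix.
        self.offset += keys.shape[2]
        self.keys[..., prev : self.offset, :] = keys
        self.values[..., prev : self.offset, :] = values
        return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
```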

* Simplify causal mask creation

Use the same codepath regardless of whether there's an offset or
not. Addresses [this comment](https://github.com/ml-explore/mlx-examples/pull/911#discussion_r1691459717).
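
One way to get a single codepath is to compute the mask directly from
absolute positions: rows are the `N` query positions starting at `offset`,
columns are all `offset + N` key positions, and `offset == 0` falls out of
the same expression. A sketch:

```
import mlx.core as mx

def create_additive_causal_mask(N: int, offset: int = 0):
    rinds = mx.arange(offset + N)          # key positions 0 .. offset+N-1
    linds = mx.arange(offset, offset + N)  # query positions offset .. offset+N-1
    # Disallow attending to future positions with a large negative bias.
    mask = linds[:, None] < rinds[None]
    return mask * -1e9
```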

* Use old-style type annotation to avoid linter error

MLX Examples

This repo contains a variety of standalone examples using the MLX framework.

The MNIST example is a good starting point to learn how to use MLX.

Some more useful examples are listed below.

Text Models

  • MLX LM, a package for LLM text generation, fine-tuning, and more.
  • Transformer language model training.
  • Minimal examples of large scale text generation with LLaMA, Mistral, and more in the LLMs directory.
  • A mixture-of-experts (MoE) language model with Mixtral 8x7B.
  • Parameter efficient fine-tuning with LoRA or QLoRA.
  • Text-to-text multi-task Transformers with T5.
  • Bidirectional language understanding with BERT.

Image Models

  • Image classification using ResNets on CIFAR-10.
  • Generating images with Stable Diffusion.
  • Image segmentation with the Segment Anything Model (SAM).
  • A convolutional variational autoencoder (CVAE) on MNIST.

Audio Models

  • Speech recognition with OpenAI's Whisper.
  • Keyword spotting on the Speech Commands dataset.

Multimodal Models

  • Joint text and image embeddings with CLIP.
  • Text generation from image and text inputs with LLaVA.

Other Models

  • Semi-supervised learning on graph-structured data with GCN.
  • Real NVP normalizing flow for density estimation and sampling.

Hugging Face

Note: You can now directly download a few converted checkpoints from the MLX Community organization on Hugging Face. We encourage you to join the community and contribute new models.
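
For instance, with the MLX LM package a converted checkpoint can be
downloaded and run in a few lines (the model name below is illustrative; any
MLX Community checkpoint works the same way):

```
from mlx_lm import load, generate

# Fetches the converted weights from the Hugging Face Hub on first use.
model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.2-4bit")
print(generate(model, tokenizer, prompt="Why is the sky blue?", max_tokens=100))
```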

Contributing

We are grateful for all of our contributors. If you contribute to MLX Examples and wish to be acknowledged, please add your name to the list in your pull request.

Citing MLX Examples

The MLX software suite was initially developed with equal contribution by Awni Hannun, Jagrit Digani, Angelos Katharopoulos, and Ronan Collobert. If you find MLX Examples useful in your research and wish to cite it, please use the following BibTeX entry:

@software{mlx2023,
  author = {Awni Hannun and Jagrit Digani and Angelos Katharopoulos and Ronan Collobert},
  title = {{MLX}: Efficient and flexible machine learning on Apple silicon},
  url = {https://github.com/ml-explore},
  version = {0.0},
  year = {2023},
}