Commit Graph

15 Commits

Author SHA1 Message Date
Alex Barron
d4ef909d4a
Length masking for batch inputs (#1173)
* length masking

* add mask to mlx_lm model interface

* remove lengths

* fix test:

* comment + fix
2024-12-18 19:43:52 -08:00
Alex Barron
85ffd2c96a
Quantized KV Cache (#1075)
* add QuantizedKVCache

* simplify

* add tests

* single sdpa function

* fix sed

* in place

* fix tests

* support different k and v head dims
2024-10-31 16:59:52 -07:00
Awni Hannun
fca087be49
More cache improvements (#1015)
* fix rotating kv cache for chat use case

* reorg + fixes to caching, unify prompt caching across types and use cases for e.g. caching during a chat

* nit in chat

* fix tests

* fix tests

* fix tests

* docs

* chat command

* comments + docs

* Define meta_state on all Cache implementations

* fixes + trim_prompt_cache api

* fix default model

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-10-07 20:45:51 -07:00
Awni Hannun
7be292c0c9
Handle longer prompt/generation (#931)
* rebase

* nits

* nit

* fix rotating cache with step prefill

* update version
2024-08-16 15:28:39 -07:00
otriscon
46da74fea2
Unify attention mask in LLMs (#911)
* Unify attention mask creation in LLMs.

Currently, each model implementation in `mlx-examples/llms/models` has ad-hoc
code to create a mask for the attention mechanism. This usually takes the form:

```
    mask = None
    if h.shape[1] > 1:
        mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
        mask = mask.astype(h.dtype)
```

This correctly creates a mask only if the input consists of more than one token.
But this code assumes the multi-token input is at the beginning of inference.
If, for example, we are evaluating multiple tokens because of speculative
decoding or prompt cache reuse, this mask will not have the correct shape and
and will cause the raising of an exception in the attention computation.

Some of the models correctly implement the mask creation with code like this:

```
    mask = None
    if h.shape[1] > 1:
        mask = create_additive_causal_mask(
            h.shape[1], cache[0].offset if cache is not None else 0
        )
        mask = mask.astype(h.dtype)
```

This commit unifies the attention mask creation for all models with a new
function `create_attention_mask`, reducing code duplication and helping all
models support inference performance enhancements like those mentioned above.

* Allow batches in LLM key-value cache

The current implementation of the LLM key-value cache assumes that
the input batch is of size 1. Input batching (evaluating multiple
alterative inputs at the same time) can be a valuable tool for
speculative sampling and other techniques.

This change removes the hard-coded batch size from the code that
resizes the key-value cache.

* Simplify causal mask creation

Use the same codepath regardless of whether there's an offset or
not. Addresses [this comment](https://github.com/ml-explore/mlx-examples/pull/911#discussion_r1691459717).

* Use old-style type annotation to avoid linter error
2024-07-25 16:45:22 -07:00
Awni Hannun
09aaeac72c
fix moe conversion (#802) 2024-05-31 12:36:05 -07:00
Angelos Katharopoulos
9f671228cd
Block sparse MM MoEs (#782)
- Adds SwitchLinear
- Adds QuantizedSwitchLinear
2024-05-21 15:58:08 -07:00
Awni Hannun
ee60e2a9d5
Kv cache (#643)
* in place kv_cache

* fix

* fix kv cache size

* partially fix kv cache dtype

* step kv cache

* multiple of step size

* more teests + kv cache

* more kv cache

* udpate all models to use kv cache
2024-05-08 08:18:13 -07:00
Awni Hannun
d3f8e4aee9
Fix argpartition call in Mixtral and other MOES (#676)
* Update mixtral.py

* fix all moes

---------

Co-authored-by: yuhai-china <yuhai.china@gmail.com>
2024-04-12 11:00:56 -07:00
Awni Hannun
b80adbcc3e
DBRX (#628)
* dbrx

* format

* format

* comments

* change scores slightly

* remove inadvertant import
2024-03-28 21:03:53 -07:00
Awni Hannun
b8a348c1b8
Switch to fast RMS/LN Norm (#603)
* use nn.RMSNorm, use sdpa, cleanup

* bump mlx versions

* minor update

* use fast layer norm

* version bump

* update requirement for whisper

* update requirement for gguf
2024-03-23 07:13:51 -07:00
Awni Hannun
e4b19bb9e1
Make attention faster for a some models (#574)
* make attention faster for a couple models

* remove unused generation flags

* add comment on lora

* include text files as well
2024-03-14 21:35:54 -07:00
Awni Hannun
f24edfa9dc
[mlx-lm] Add precompiled normalizations (#451)
* add precompiled normalizations

* nits
2024-02-22 12:40:55 -08:00
Awni Hannun
8fd953ee2b
Support for slerp merging models (#455)
* support for slerp merging models

* docs

* update docs

* format'
2024-02-19 20:37:15 -08:00
Awni Hannun
d4666615bb
Lazy import + refactor Lora layer addition (#426)
* lazy model import in mlx_lm

* change lora loading

* fix olmo lora

* remove a bunch of unused stuff from plamo

* move phixtral to mlx-lm and out of llms/
2024-02-12 10:51:02 -08:00