* fix rotating kv cache for chat use case
* reorg + fixes to caching; unify prompt caching across types and use cases, e.g. caching during a chat
* nit in chat
* fix tests
* fix tests
* fix tests
* docs
* chat command
* comments + docs
* Define meta_state on all Cache implementations
* fixes + trim_prompt_cache api
* fix default model
---------
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
* Unify attention mask creation in LLMs.
Currently, each model implementation in `mlx-examples/llms/models` has ad-hoc
code to create a mask for the attention mechanism. This usually takes the form:
```
mask = None
if h.shape[1] > 1:
    mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
    mask = mask.astype(h.dtype)
```
This correctly creates a mask only if the input consists of more than one token.
But this code assumes the multi-token input is at the beginning of inference.
If, for example, we are evaluating multiple tokens because of speculative
decoding or prompt cache reuse, this mask will not have the correct shape and
will cause an exception to be raised in the attention computation.
Some of the models correctly implement the mask creation with code like this:
```
mask = None
if h.shape[1] > 1:
    mask = create_additive_causal_mask(
        h.shape[1], cache[0].offset if cache is not None else 0
    )
    mask = mask.astype(h.dtype)
```
This commit unifies the attention mask creation for all models with a new
function `create_attention_mask`, reducing code duplication and helping all
models support inference performance enhancements like those mentioned above.
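For reference, a minimal sketch of what such a unified helper can look like, assuming the per-layer caches expose an `offset` attribute and reusing the `create_additive_causal_mask` helper shown above (the actual function in `mlx_lm` may differ in details):
```
def create_attention_mask(h, cache=None):
    T = h.shape[1]
    if T > 1:
        # Start the causal mask at the cache offset so multi-token
        # evaluation also works mid-stream (speculative decoding,
        # prompt cache reuse).
        offset = cache[0].offset if cache is not None else 0
        mask = create_additive_causal_mask(T, offset)
        return mask.astype(h.dtype)
    # Single-token decoding needs no mask.
    return None
```
Each model can then replace its ad-hoc code with `mask = create_attention_mask(h, cache)`.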
* Allow batches in LLM key-value cache
The current implementation of the LLM key-value cache assumes that
the input batch is of size 1. Input batching (evaluating multiple
alternative inputs at the same time) can be a valuable tool for
speculative sampling and other techniques.
This change removes the hard-coded batch size from the code that
resizes the key-value cache.
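A rough sketch of the batch-aware resize, assuming a `KVCache`-style class that grows its buffers in fixed-size steps (names and details are illustrative, not necessarily the exact implementation):
```
import mlx.core as mx

class KVCache:
    step = 256

    def __init__(self, head_dim, n_kv_heads):
        self.n_kv_heads = n_kv_heads
        self.head_dim = head_dim
        self.keys = None
        self.values = None
        self.offset = 0

    def update_and_fetch(self, keys, values):
        prev = self.offset
        if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]:
            # Take the batch size from the incoming keys instead of
            # hard-coding 1, so batched inputs are supported.
            B = keys.shape[0]
            n_steps = (self.step + keys.shape[2] - 1) // self.step
            shape = (B, self.n_kv_heads, n_steps * self.step, self.head_dim)
            new_k = mx.zeros(shape, keys.dtype)
            new_v = mx.zeros(shape, values.dtype)
            if self.keys is not None:
                self.keys = mx.concatenate([self.keys[..., :prev, :], new_k], axis=2)
                self.values = mx.concatenate([self.values[..., :prev, :], new_v], axis=2)
            else:
                self.keys, self.values = new_k, new_v
        self.offset += keys.shape[2]
        self.keys[..., prev : self.offset, :] = keys
        self.values[..., prev : self.offset, :] = values
        return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
```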
* Simplify causal mask creation
Use the same codepath regardless of whether there's an offset or
not. Addresses [this comment](https://github.com/ml-explore/mlx-examples/pull/911#discussion_r1691459717).
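With that change the helper always builds the mask from the (possibly zero) offset, roughly like this sketch (illustrative; the real function may differ slightly):
```
import mlx.core as mx

def create_additive_causal_mask(N, offset=0):
    # Single codepath: token i sits at absolute position offset + i and may
    # attend to every absolute position <= offset + i.
    rinds = mx.arange(offset + N)
    linds = mx.arange(offset, offset + N)
    mask = linds[:, None] < rinds[None]
    return mask * -1e9
```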
* Use old-style type annotation to avoid linter error
* use nn.RMSNorm, use sdpa, cleanup
* bump mlx versions
* minor update
* use fast layer norm
* version bump
* update requirement for whisper
* update requirement for gguf
Mixtral models throw the following exception:
```
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/mlx_lm/generate.py", line 119, in <module>
    main(args)
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/mlx_lm/generate.py", line 96, in main
    model, tokenizer = load(args.model, tokenizer_config=tokenizer_config)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/mlx_lm/utils.py", line 278, in load
    model = load_model(model_path)
            ^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/mlx_lm/utils.py", line 221, in load_model
    model_class, model_args_class = _get_classes(config=config)
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/mlx_lm/utils.py", line 46, in _get_classes
    arch = importlib.import_module(f"mlx_lm.models.{model_type}")
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/importlib/__init__.py", line 126, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<frozen importlib._bootstrap>", line 1204, in _gcd_import
  File "<frozen importlib._bootstrap>", line 1176, in _find_and_load
  File "<frozen importlib._bootstrap>", line 1147, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 690, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 940, in exec_module
  File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/mlx_lm/models/mixtral.py", line 11, in <module>
    @dataclass
     ^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/dataclasses.py", line 1230, in dataclass
    return wrap(cls)
           ^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/dataclasses.py", line 1220, in wrap
    return _process_class(cls, init, repr, eq, order, unsafe_hash,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/dataclasses.py", line 1027, in _process_class
    _init_fn(all_init_fields,
  File "/opt/homebrew/anaconda3/lib/python3.11/dataclasses.py", line 545, in _init_fn
    raise TypeError(f'non-default argument {f.name!r} '
TypeError: non-default argument 'model_type' follows default argument
```
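A minimal illustration of the underlying dataclass rule (not the actual `ModelArgs` definition): a field without a default may not follow one that has a default.
```
from dataclasses import dataclass

# This ordering raises "non-default argument 'model_type' follows default argument":
#
# @dataclass
# class ModelArgs:
#     vocab_size: int = 32000
#     model_type: str
#
# Reordering the fields (or giving model_type a default) resolves it:
@dataclass
class ModelArgs:
    model_type: str
    vocab_size: int = 32000
```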
* lazy model import in mlx_lm
* change lora loading
* fix olmo lora
* remove a bunch of unused stuff from plamo
* move phixtral to mlx-lm and out of llms/