* fix rotating kv cache for chat use case
* reorg + fixes to caching; unify prompt caching across cache types and use cases, e.g. caching during a chat
* nit in chat
* fix tests
* fix tests
* fix tests
* docs
* chat command
* comments + docs
* Define meta_state on all Cache implementations
* fixes + trim_prompt_cache api
* fix default model
---------
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
* initial commit
* initial commit
* Adding first lines
* adding x, and dt projection layers
* adding the clamping mechanism
* First successful inference
* last commit for today - added a custom generate function and it works as expected; next I will try training and then loading a model from the hub
* clean up
* save up
* almost
* update
* update
* fixed cache handling
* fixed loading
* added a separate generate_step method in the model and also in the utils to automatically use the generate_step method from the model class
* quick update
* still not working
* save
* still not working
* initial commit
* utils.py: `logits = logits[:, -1, :]` raises `TypeError: tuple indices must be integers or slices, not tuple`
* update
* update
* Fixing the batched depthwise convolution and multi-token input
* fixing generate and logits outputs
* Done!
* Fixing the cache handling, generating works now trying training
* update ACKNOWLEDGEMENTS
* removing the model_type conditional from the _step loop in generate_step, adding MambaCache to base.py to make generation during training easier, and removing mamba from tuner/utils.
* quick clean up
* update trainer/utils for correct initialisation of the LoRA layers, but not working yet.
* clean up
* Further update to trainer/utils for correct layer selection. Successful training
* removing extra mamba-infer.py file
* clean up; reformatting will come later
* reformat and big clean up, final commit
* some speedups and cleanups
* fix test
* nits
* nits
---------
Co-authored-by: Awni Hannun <awni@apple.com>
* feat: Nemotron
https://huggingface.co/nvidia/Minitron-4B-Base
This is basically Llama with partial RoPE and LayerNorm instead of
RMSNorm. Also they add 1 to the LayerNorm weight for some reason.
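A minimal sketch of that "+1" norm in mlx, assuming an `mlx.nn`-style module (the class name `LayerNorm1P` is illustrative, not necessarily what the model file uses):
```python
import mlx.core as mx
import mlx.nn as nn

class LayerNorm1P(nn.LayerNorm):
    # The checkpoint stores the norm weight zero-centered, so 1 is added back
    # at call time; otherwise this behaves like a standard LayerNorm.
    def __call__(self, x: mx.array) -> mx.array:
        return mx.fast.layer_norm(x, 1.0 + self.weight, self.bias, self.eps)
```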
* fixup! feat: Nemotron
* nits
---------
Co-authored-by: Awni Hannun <awni@apple.com>
* use fast rope
* fix llama
* use fast rope for llama3.1
* requires unreleased mlx
* fix su
* fix deepseek v2
* only one of base or freqs
* nit
* fix
* hard code freqs
* feat: deepseek v1
DeepSeek is still releasing models on the DeepSeek V1 architecture.
```sh
mlx_lm.convert --hf-path deepseek-ai/DeepSeek-Prover-V1.5-RL --mlx-path DeepSeek-Prover-V1.5-RL-8bit --q-bits 8 -q
mlx_lm.generate --model DeepSeek-Prover-V1.5-RL-8bit --ignore-chat-template --max-tokens 512 --prompt 'import Mathlib
import Aesop
set_option maxHeartbeats 0
open BigOperators Real Nat Topology Rat
/-- The second and fourth terms of a geometric sequence are $2$ and $6$. Which of the following is a possible first term?
Show that it is $\frac{2\sqrt{3}}{3}$.-/
theorem amc12b_2003_p6 (a r : ℝ) (u : ℕ → ℝ) (h₀ : ∀ k, u k = a * r ^ k) (h₁ : u 1 = 2)
(h₂ : u 3 = 6) : u 0 = 2 / Real.sqrt 3 ∨ u 0 = -(2 / Real.sqrt 3) := by'
```
* nits
* nits
* nits
---------
Co-authored-by: Awni Hannun <awni@apple.com>
* Unify attention mask creation in LLMs.
Currently, each model implementation in `mlx-examples/llms/models` has ad-hoc
code to create a mask for the attention mechanism. This usually takes the form:
```python
mask = None
if h.shape[1] > 1:
    mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
    mask = mask.astype(h.dtype)
```
This correctly creates a mask only if the input consists of more than one token.
But this code assumes the multi-token input is at the beginning of inference.
If, for example, we are evaluating multiple tokens because of speculative
decoding or prompt cache reuse, this mask will not have the correct shape
and will cause an exception in the attention computation.
Some of the models correctly implement the mask creation with code like this:
```python
mask = None
if h.shape[1] > 1:
    mask = create_additive_causal_mask(
        h.shape[1], cache[0].offset if cache is not None else 0
    )
    mask = mask.astype(h.dtype)
```
This commit unifies the attention mask creation for all models with a new
function `create_attention_mask`, reducing code duplication and helping all
models support inference performance enhancements like those mentioned above.
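For illustration, a minimal sketch of what such a unified helper can look like, assuming an mlx_lm-style cache object exposing an `offset` attribute (the real `create_attention_mask` signature may differ):
```python
import mlx.core as mx

def create_attention_mask(h: mx.array, cache=None):
    T = h.shape[1]
    if T <= 1:
        return None  # single-token decoding needs no mask
    # Keys cover the cached prefix plus the new tokens; queries only the new tokens.
    offset = cache[0].offset if cache is not None else 0
    queries = mx.arange(offset, offset + T)
    keys = mx.arange(offset + T)
    mask = queries[:, None] < keys[None]  # True where attention must be blocked
    return (mask * -1e9).astype(h.dtype)
```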
* Allow batches in LLM key-value cache
The current implementation of the LLM key-value cache assumes that
the input batch is of size 1. Input batching (evaluating multiple
alternative inputs at the same time) can be a valuable tool for
speculative sampling and other techniques.
This change removes the hard-coded batch size from the code that
resizes the key-value cache.
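A minimal sketch of a step-allocated key-value cache that takes its batch dimension from the incoming keys instead of hard-coding 1; the class and method names here are assumptions for illustration, not the exact implementation:
```python
import mlx.core as mx

class KVCache:
    step = 256  # grow the cache buffers in fixed-size chunks

    def __init__(self):
        self.keys = None
        self.values = None
        self.offset = 0

    def update_and_fetch(self, keys: mx.array, values: mx.array):
        B, n_heads, T, k_dim = keys.shape  # batch size comes from the input
        prev = self.offset
        if self.keys is None or prev + T > self.keys.shape[2]:
            n_steps = (T + self.step - 1) // self.step
            new_k = mx.zeros((B, n_heads, n_steps * self.step, k_dim), keys.dtype)
            new_v = mx.zeros((B, n_heads, n_steps * self.step, values.shape[3]), values.dtype)
            if self.keys is None:
                self.keys, self.values = new_k, new_v
            else:
                self.keys = mx.concatenate([self.keys[..., :prev, :], new_k], axis=2)
                self.values = mx.concatenate([self.values[..., :prev, :], new_v], axis=2)
        self.offset += T
        self.keys[..., prev : self.offset, :] = keys
        self.values[..., prev : self.offset, :] = values
        return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
```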
* Simplify causal mask creation
Use the same codepath regardless of whether there's an offset or
not. Addresses [this comment](https://github.com/ml-explore/mlx-examples/pull/911#discussion_r1691459717).
* Use old-style type annotation to avoid linter error
* add dynamicNTK scaling rope
* remove unused var
* fix rope base
* llama3.1 fixes
* TODO for rope eval
* vectorise llama3 base freq calculation
* removed the arbitrary 2.0 rope_scale default case
* fix slow llama3.1 generation by evaluating stateless part of DynamicNTKScalingRoPE in init
* nits + format
* use mx.pi
* fix tests and add test for 3.1
---------
Co-authored-by: Prince Canuma <prince.gdt@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
* Add logit soft capping to gemma, and fix precision issues
Gemma was babbling nonsense - it turned out to be due to missing logit soft capping plus precision issues causing NaNs, so I implemented the soft capping and added more float32 inference. gemma-27b-it-4bit now works flawlessly (or near-flawlessly; sliding-window attention is still missing).
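For reference, a minimal sketch of logit soft capping as described for Gemma 2; the cap value of 30.0 mirrors the `final_logit_softcapping` config field and is illustrative here:
```python
import mlx.core as mx

def soft_cap(logits: mx.array, cap: float = 30.0) -> mx.array:
    # Smoothly squash logits into (-cap, cap) instead of letting them grow
    # unbounded, which helps avoid overflow and NaNs in reduced precision.
    return cap * mx.tanh(logits / cap)
```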
* get rid of comments
* get rid of last comments (sry lol)
* nits
---------
Co-authored-by: Awni Hannun <awni@apple.com>
* Su-RoPE
* nits
* Update su_rope.py
* Update su_rope.py
Per GPT4: "The error TypeError: 'type' object is not subscriptable is caused by using the type hint list[float] in a version of Python that does not support it. This syntax is only available in Python 3.9 and later."
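A small illustration of the portable spelling (the function here is hypothetical):
```python
from typing import List

# `list[float]` as an annotation requires Python 3.9+ (PEP 585);
# `List[float]` from typing also works on earlier versions.
def scale_freqs(freqs: List[float], factor: float) -> List[float]:
    return [f * factor for f in freqs]
```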
* Ran isort
---------
Co-authored-by: Awni Hannun <awni@apple.com>
* GPT-2 model support
* Add test for gpt2 model
* Fix weight sanitizing for quantization
* use approx gelu
---------
Co-authored-by: Awni Hannun <awni@apple.com>
* add support for granite 3-8B config
* add gpt_bigcode
* add positional embedding condition.
* remove unused function
* rebase fix
* move position embedding to mask creation
* add to tuner and format
* refactor mask
* remove dropout layers
* Pad mask with zeros for non-square attention matrices
The current implementation of the mask assumes the attention matrix is square, which is true if there is no cache. However, if one wishes to produce multiple tokens at a time, such as in speculative decoding implementations, a rectangular mask is necessary.
This change pads the bottom of the mask with zeros so multi-token decoding with a cache works correctly.
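One way to build such a rectangular mask, sketched in mlx by padding the cached key positions with zeros (the commit's exact construction may differ):
```python
import mlx.core as mx
import mlx.nn as nn

def rectangular_causal_mask(n_queries: int, cache_len: int, dtype=mx.float16):
    # Square causal mask over the new tokens...
    mask = nn.MultiHeadAttention.create_additive_causal_mask(n_queries).astype(dtype)
    if cache_len > 0:
        # ...padded with zero columns so each new token can also attend to every
        # cached position, giving shape (n_queries, cache_len + n_queries).
        pad = mx.zeros((n_queries, cache_len), dtype=dtype)
        mask = mx.concatenate([pad, mask], axis=1)
    return mask
```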
* Directly create mask instead of padding
* Update llama.py
* Added support for the MiniCPM architecture
* Updated utils.py and LORA.md
* Update implementation details for MiniCPM architecture
* Cleaning up
* fixed the missing lm.head layer problem
* Refactor Model class to dynamically handle tied and untied word embeddings
* Quick update
* added a dynamic rope scaling base calculation
* quick fix and clean up
* clean up again
* removed the MiniCPMNorm class as it's not used
* forgot something, sorry
* format
* version bump
---------
Co-authored-by: Awni Hannun <awni@apple.com>
* Initial config handler and test
* Added means to run from CLI
* Update lora config loading and tests
* Constrain scheduler config (warmup and minimum LR) for each kind
* Update reference to moved schedule_config module
* Minor fix
* Fix typos
* Moved build_schedule and tests
* nits in schedule config
* flake
* fix path
---------
Co-authored-by: Awni Hannun <awni@apple.com>
* use nn.RMSNorm, use sdpa, cleanup
* bump mlx versions
* minor update
* use fast layer norm
* version bump
* update requirement for whisper
* update requirement for gguf
* Add Starcoder2 model and update utils.py
* Refactor model arguments and modules in starcoder2.py
* Refactor FeedForward class to MLP in starcoder2.py
* Fix typo
* pre-commit
* Refactor starcoder2.py: Update model arguments and modules
* Fix LM head and MLP layers
* Rename input layer norm
* Update bias in linear layers
* Refactor token embeddings in Starcoder2Model
* Rename to standard HF attention layer name
* Add LayerNorm
* Add transposed token embeddings (like in Gemma)
* Refactor MLP and TransformerBlock classes
* Add tie_word_embeddings option to ModelArgs and update Model implementation
* Add conditional check for tying word embeddings in Starcoder2Model
* Fix bias in lm_head linear layer
* Remove unused LayerNorm in stablelm
* Update transformers dependency to use GitHub repository
* fix lm head bug, revert transformer req
* Update RoPE initialization in Attention class
---------
Co-authored-by: Awni Hannun <awni@apple.com>
* StableLM is now part of Transformers as stablelm rather than stablelm_epoch; changed the config to match
* removing old file
* reference new stablelm
Mixtral models throw the following exception
```
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/mlx_lm/generate.py", line 119, in <module>
    main(args)
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/mlx_lm/generate.py", line 96, in main
    model, tokenizer = load(args.model, tokenizer_config=tokenizer_config)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/mlx_lm/utils.py", line 278, in load
    model = load_model(model_path)
            ^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/mlx_lm/utils.py", line 221, in load_model
    model_class, model_args_class = _get_classes(config=config)
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/mlx_lm/utils.py", line 46, in _get_classes
    arch = importlib.import_module(f"mlx_lm.models.{model_type}")
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/importlib/__init__.py", line 126, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<frozen importlib._bootstrap>", line 1204, in _gcd_import
  File "<frozen importlib._bootstrap>", line 1176, in _find_and_load
  File "<frozen importlib._bootstrap>", line 1147, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 690, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 940, in exec_module
  File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/mlx_lm/models/mixtral.py", line 11, in <module>
    @dataclass
     ^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/dataclasses.py", line 1230, in dataclass
    return wrap(cls)
           ^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/dataclasses.py", line 1220, in wrap
    return _process_class(cls, init, repr, eq, order, unsafe_hash,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/dataclasses.py", line 1027, in _process_class
    _init_fn(all_init_fields,
  File "/opt/homebrew/anaconda3/lib/python3.11/dataclasses.py", line 545, in _init_fn
    raise TypeError(f'non-default argument {f.name!r} '
TypeError: non-default argument 'model_type' follows default argument
```
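The root cause is a plain Python dataclass rule; a minimal sketch of the constraint (field names are illustrative):
```python
from dataclasses import dataclass

# Fields without defaults must be declared before fields with defaults, so a
# non-default `model_type` has to come first (or be given a default value).
@dataclass
class ModelArgs:
    model_type: str           # non-default field first
    vocab_size: int = 32000   # defaulted fields follow
```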
* lazy model import in mlx_lm
* change lora loading
* fix olmo lora
* remove a bunch of unused stuff from plamo
* move phixtral to mlx-lm and out of llms/
* initial commit
* style fixes
* update of ACKNOWLEDGMENTS
* fixed comment
* minor refactoring; removed unused imports
* added cifar and cvae to top-level README.md
* removed mention of cuda/mps in argparse
* fixed training status output
* load_weights() with strict=True
* pretrained model update
* fixed imports and style
* requires mlx>=0.0.9
* updated with results using mlx 0.0.9
* removed mention of private repo
* simplify and combine to one file, more consistency with other examples
* few more nits
* nits
* spell
* format
---------
Co-authored-by: Awni Hannun <awni@apple.com>