Commit Graph

65 Commits

Author SHA1 Message Date
Awni Hannun
538339b599 gemma2 (#855) 2024-06-27 10:06:28 -07:00
Yi Wang
a7598e9456 Fix mypy errors with models/{qwen2,qwen2_moe,startcoder2}.py (#835)
* Fix starcoder.py

* Fix qwen2

* Remvoe unnecessary assert not None
2024-06-14 09:44:50 -07:00
Yi Wang
6da07fb1b0 make models/phi3.py and models/phi3small.py compatible with mypy (#833) 2024-06-12 06:53:55 -07:00
JosefAlbers
fda41545a6 Su-RoPE(Rotary Position Embedding) for Phi-3 (#813)
* Su-RoPE

* nits

* Update su_rope.py

* Update su_rope.py

Per GPT4: "The error TypeError: 'type' object is not subscriptable is caused by using the type hint list[float] in a version of Python that does not support it. This syntax is only available in Python 3.9 and later."

* Ran isort

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-06-11 06:20:04 -07:00
Yi Wang
a54dfd698e Correct the type annotation of cache in llama.py (#828)
* Update

* Fix isort
2024-06-10 15:18:34 -07:00
Yi Wang
bb8227f181 Correct type annotation of llama.ModelArgs.num_key_value_heads (#827) 2024-06-10 14:47:31 -07:00
Derek Lewis
89b0b75250 GPT2 Support (#798)
* GPT-2 model support

* Add test for gpt2 model

* Fix weight sanitizing for quantization

* use approx gelu

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-06-02 16:33:20 -07:00
Awni Hannun
81318ad4a8 Port of phi3small (#794)
* start port of phi3small

* fix phi3

* use block sparsity

* compile activation

* nits in readme / mlx lm version
2024-05-31 12:54:14 -07:00
Awni Hannun
09aaeac72c fix moe conversion (#802) 2024-05-31 12:36:05 -07:00
Chen Xin
aac98ca6f4 support internlm2 (#797)
* support internlm2

* only attention projections

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-05-27 06:22:21 -07:00
Awni Hannun
ca7ce60c91 Rename block sparse to gather (#793)
* rename block sparse to gather

* pin mlx version
2024-05-23 19:47:35 -07:00
Prince Canuma
69700d8431 Add support for Phi-3 Medium (#790)
* update to support phi-3 medium

* fuse qkv split
2024-05-22 16:47:06 -07:00
Prince Canuma
b044ce2acf Add support for ibm granite (#758)
* add support for granite 3-8B config

* add gpt_bigcode

* add positional embedding condition.

* add support for granite 3-8B config

* add gpt_bigcode

* add positional embedding condition.

* remove unused function

* rebase fix

* move position emebedding to mask creation

* add to tuner and format

* add support for granite 3-8B config

* add gpt_bigcode

* add positional embedding condition.

* add support for granite 3-8B config

* add gpt_bigcode

* add positional embedding condition.

* rebase fix

* move position emebedding to mask creation

* add to tuner and format

* refactor mask

* remove dropout layers
2024-05-21 20:16:31 -07:00
Angelos Katharopoulos
9f671228cd Block sparse MM MoEs (#782)
- Adds SwitchLinear
- Adds QuantizedSwitchLinear
2024-05-21 15:58:08 -07:00
Awni Hannun
69181e0058 Support non incremental kv cache growth (#766) 2024-05-15 12:56:24 -07:00
Awni Hannun
fad9598372 Fix llama cache check (#763)
* fix llama cache check

* add test
2024-05-08 08:35:54 -07:00
Awni Hannun
ee60e2a9d5 Kv cache (#643)
* in place kv_cache

* fix

* fix kv cache size

* partially fix kv cache dtype

* step kv cache

* multiple of step size

* more teests + kv cache

* more kv cache

* udpate all models to use kv cache
2024-05-08 08:18:13 -07:00
Kevin Wang
c0019c4908 Pad mask with zeros for non-square attention matrices (#715)
* Pad mask with zeros for non-square attention matrices

The current implementation of the mask assumes the attention matrix is square, which is true if there is no cache. However, if one wishes to produce multiple tokens at a time, such as in speculative decoding implementations, a rectangular mask is necessary.

This change pads the bottom of the mask with zeros so multi-token decoding with a cache works correctly.

* Directly create mask instead of padding

* Update llama.py
2024-05-04 16:32:25 -07:00
Awni Hannun
92430df0a0 Fix lora for qwen moe (#743)
* fix lora for qwen moe

* use max seq length in test as well
2024-05-02 21:55:09 -07:00
Thomas Lazarus
5513c4e57d Fixes Typo in Starcoder2 (#740) 2024-04-29 13:14:45 -07:00
Prince Canuma
c012eb173f Add support for OpenELM (#719)
* add openELM

* update splitting logic

* update qkv logic and, transformer and MLP block

* code formatting and fix args

* fix array slicing and remove unused var :)

* add to tuner

* use mx.split for slicing qkv

* merge with phi3

* remove rope scaling logic

* code formatting
2024-04-25 16:49:28 -07:00
Gökdeniz Gülmez
2c1c9e9024 MiniCPM implementation (#685)
* Added support for the MiniCPM architecture

* Added support for the MiniCPM architecture

* Updated utils.py and LORA.md

* Updated utils.py and LORA.md

* Update implementation details for MiniCPM architecture

* Cleaning up

* fixed the missing lm.head layer problem

* Refactor Model class to dynamically handle tied and untied word embeddings

* Quick update

* added a dynamic rope scaling base calucaltion

* Added support for the MiniCPM architecture

* Added support for the MiniCPM architecture

* Updated utils.py and LORA.md

* Updated utils.py and LORA.md

* Update implementation details for MiniCPM architecture

* Cleaning up

* fixed the missing lm.head layer problem

* Refactor Model class to dynamically handle tied and untied word embeddings

* added a dynamic rope scaling base calucaltion

* quick fix and clean up

* clean up again

* removed the MiniCPMNorm class as its not used

* forgot something, sorry

* format

* version bump

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-04-25 15:29:28 -07:00
Prince Canuma
abcd891851 Add support for phi-3 (#712)
* Add phi-3 modelling

* fix rope scaling warning

* add tests and update tuner utils

* update name and remove sanitize

* fix lora
2024-04-23 09:20:00 -07:00
Awni Hannun
2146bcd7ee Quantize embedding / Update quantize API (#680)
* more async eval

* quantize embedding / update quantize api

* more updates for quantize

* update for quantize embeddings

* update sd quant API

* update sdxl quants

* error for datasets < batch_size

* async

* fix config loading

* fix quant

* fix tests

* fix req

* remove lm head if tie weights is true

* fix test
2024-04-18 18:16:10 -07:00
dmdaksh
7d7e236061 - Removed unused Python imports (#683)
- bert/model.py:10: tree_unflatten
  - bert/model.py:2: dataclass
  - bert/model.py:8: numpy
  - cifar/resnet.py:6: Any
  - clip/model.py:15: tree_flatten
  - clip/model.py:9: Union
  - gcn/main.py:8: download_cora
  - gcn/main.py:9: cross_entropy
  - llms/gguf_llm/models.py:12: tree_flatten, tree_unflatten
  - llms/gguf_llm/models.py:9: numpy
  - llms/mixtral/mixtral.py:12: tree_map
  - llms/mlx_lm/models/dbrx.py:2: Dict, Union
  - llms/mlx_lm/tuner/trainer.py:5: partial
  - llms/speculative_decoding/decoder.py:1: dataclass, field
  - llms/speculative_decoding/decoder.py:2: Optional
  - llms/speculative_decoding/decoder.py:5: mlx.nn
  - llms/speculative_decoding/decoder.py:6: numpy
  - llms/speculative_decoding/main.py:2: glob
  - llms/speculative_decoding/main.py:3: json
  - llms/speculative_decoding/main.py:5: Path
  - llms/speculative_decoding/main.py:8: mlx.nn
  - llms/speculative_decoding/model.py:6: tree_unflatten
  - llms/speculative_decoding/model.py:7: AutoTokenizer
  - llms/tests/test_lora.py:13: yaml_loader
  - lora/lora.py:14: tree_unflatten
  - lora/models.py:11: numpy
  - lora/models.py:3: glob
  - speechcommands/kwt.py:1: Any
  - speechcommands/main.py:7: mlx.data
  - stable_diffusion/stable_diffusion/model_io.py:4: partial
  - whisper/benchmark.py:5: sys
  - whisper/test.py:5: subprocess
  - whisper/whisper/audio.py:6: Optional
  - whisper/whisper/decoding.py:8: mlx.nn
2024-04-16 07:50:32 -07:00
Awni Hannun
d3f8e4aee9 Fix argpartition call in Mixtral and other MOES (#676)
* Update mixtral.py

* fix all moes

---------

Co-authored-by: yuhai-china <yuhai.china@gmail.com>
2024-04-12 11:00:56 -07:00
Awni Hannun
c68aa3c7c3 Stable lm 2 (#666)
* stable lm 2

* test and lora

* version bump

* merge stable models
2024-04-08 14:18:55 -07:00
Awni Hannun
c386dd5f5a Fix for cohere plus (#650)
* fix for cohere plus

* version bump
2024-04-05 14:11:24 -07:00
Prince Canuma
d661440dbb Add support for qwen2moe (#640)
* add sparsemoe block and update decoder logic

* update file name to match HF

* update name

* Code formatting

* update gates calculation

* add support for Qwen2MoE.

* fix pytest

* code formatting and fix missing comma in utils

* Remove decoder sparse step.

Co-authored-by: bozheng-hit <dsoul0621@gmail.com>

* remove gate layer anti-quantisation

* remove unused argument

---------

Co-authored-by: bozheng-hit <dsoul0621@gmail.com>
2024-04-02 11:33:29 -07:00
Chime Ogbuji
f6283ef7ce Configurable LR schedulers (#604)
* Initial config handler and test

* Added means to run from CLI

* Update lora config loading and tests

* Constrain scheduler config (warmup and minimum LR) for each kind

* Update reference to moved schedule_config module

* Minor fix

* Fix typos

* Moved build_schedule and tests

* nits in schedule config

* flake

* fix path

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-03-29 13:41:10 -07:00
Awni Hannun
b80adbcc3e DBRX (#628)
* dbrx

* format

* format

* comments

* change scores slightly

* remove inadvertant import
2024-03-28 21:03:53 -07:00
Awni Hannun
b8a348c1b8 Switch to fast RMS/LN Norm (#603)
* use nn.RMSNorm, use sdpa, cleanup

* bump mlx versions

* minor update

* use fast layer norm

* version bump

* update requirement for whisper

* update requirement for gguf
2024-03-23 07:13:51 -07:00
Awni Hannun
e4b19bb9e1 Make attention faster for a some models (#574)
* make attention faster for a couple models

* remove unused generation flags

* add comment on lora

* include text files as well
2024-03-14 21:35:54 -07:00
Prince Canuma
76c3244cc5 Add support for Cohere's Command-R (#565)
* initial commit for command-R

* update mlp, layernorm, lm_head and model args

* add custom layernorm

* add default to tie_word_embeddings

* add layernorm weight type and refactor

* update layernorm (bias conditional) in model/layers

* fix layer norm use traditional rope

* add test

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-03-13 07:03:36 -07:00
Anchen
3535408c99 chore(mlx-lm): fix tie_word_embeddings for qwen2 (#566)
* chore: fix tie_word_embeddings for qwen2

* chore: default tie_word_embeddings to True
2024-03-12 21:34:32 -07:00
Awni Hannun
8b05bb6d18 [mlx-lm] Use sdpa in llama / mistral model (#515)
* use sdpa

* update a few more models

* version

* fix stablelm type
2024-03-07 17:41:23 -08:00
Awni Hannun
7cdd1b69ac Enable unit testing in Circle and start some MLX LM tests (#545)
* add a few tests for mlx lm

* add a few tests for mlx lm

* add a few tests for mlx lm

* more tests / cleanup
2024-03-07 09:31:57 -08:00
Anchen
8a178f8716 chore: enable tie_word_embeddings config for qwen2 (#544) 2024-03-07 06:11:35 -08:00
Muhtasham Oblokulov
5de7c2ac33 Add tips on porting LLMs from HuggingFace (#523)
* Add tips on porting LLMs from HuggingFace

* Add CONTRIBUTING.md  to mlx-examples-llms

* Refactor imports and update comment in starcoder2.py

* Update llms/mlx_lm/models/starcoder2.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* nits

* nits

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-03-05 17:43:15 -08:00
Prince Canuma
3fdf85e79d Starcoder2: Update config and change GQA to use repeat (#520)
* update config

* change gqa to use repeat instead of concante

* contribution
2024-03-03 06:12:03 -08:00
Anchen
1e3daea3bb chore(mlx-lm): add missing model_type for starcoder2 (#522) 2024-03-03 06:07:45 -08:00
Muhtasham Oblokulov
81e2a80026 Add Starcoder 2 (#502)
* Add Starcoder2 model and update utils.py

* Refactor model arguments and modules in starcoder2.py

* Refactor FeedForward class to MLP in starcoder2.py

* Fix typo

* pre-commit

* Refactor starcoder2.py: Update model arguments and modules

* Fix LM head and MLP layers

* Rename  input layer norm

* Update bias in linear layers

* Refactor token embeddings in Starcoder2Model

* Rename to standard HF attention layer name

* Add LayerNorm

* Add transposed token embeddings (like in Gemma)

* Refactor MLP and TransformerBlock classes

* Add tie_word_embeddings option to ModelArgs and update Model implementation

* Add conditional check for tying word embeddings in Starcoder2Model

* Fix bias in lm_head linear layer

* Remove unused LayerNorm in stablelm

* Update transformers dependency to use GitHub repository

* fix lm head bug, revert transformer req

* Update RoPE initialization in Attention class

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-03-02 19:39:23 -08:00
Ashish
261f1280f6 Update to StableLM code (#514)
* StableLM now part of Transformers as stablelm rather than stablelm_epoch; changed config to match new changes

* removing old file

* reference new stablelm
2024-03-01 09:53:38 -08:00
Awni Hannun
f24edfa9dc [mlx-lm] Add precompiled normalizations (#451)
* add precompiled normalizations

* nits
2024-02-22 12:40:55 -08:00
Awni Hannun
ab9172baac Gemma support (#474)
* gemma support

* format

* lora support for gemma
2024-02-21 08:47:13 -08:00
Awni Hannun
8fd953ee2b Support for slerp merging models (#455)
* support for slerp merging models

* docs

* update docs

* format'
2024-02-19 20:37:15 -08:00
devonthomas35
cc671cd1c7 Mixtral: Fix non-default arg follows default exception (#450)
Mixtral models throw the following exception
```
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/mlx_lm/generate.py", line 119, in <module>
    main(args)
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/mlx_lm/generate.py", line 96, in main
    model, tokenizer = load(args.model, tokenizer_config=tokenizer_config)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/mlx_lm/utils.py", line 278, in load
    model = load_model(model_path)
            ^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/mlx_lm/utils.py", line 221, in load_model
    model_class, model_args_class = _get_classes(config=config)
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/mlx_lm/utils.py", line 46, in _get_classes
    arch = importlib.import_module(f"mlx_lm.models.{model_type}")
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/importlib/__init__.py", line 126, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<frozen importlib._bootstrap>", line 1204, in _gcd_import
  File "<frozen importlib._bootstrap>", line 1176, in _find_and_load
  File "<frozen importlib._bootstrap>", line 1147, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 690, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 940, in exec_module
  File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/mlx_lm/models/mixtral.py", line 11, in <module>
    @dataclass
     ^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/dataclasses.py", line 1230, in dataclass
    return wrap(cls)
           ^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/dataclasses.py", line 1220, in wrap
    return _process_class(cls, init, repr, eq, order, unsafe_hash,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/dataclasses.py", line 1027, in _process_class
    _init_fn(all_init_fields,
  File "/opt/homebrew/anaconda3/lib/python3.11/dataclasses.py", line 545, in _init_fn
    raise TypeError(f'non-default argument {f.name!r} '
TypeError: non-default argument 'model_type' follows default argument
```
2024-02-18 13:30:26 -08:00
Angelos Katharopoulos
f71e965d57 Change gqa to use repeat instead of concatenate (#443) 2024-02-14 17:40:11 -08:00
Awni Hannun
06ddb8414d Fix Qwen2 and SD (#441)
* fix qwen2

* version bump

* fix list shape
2024-02-14 13:43:12 -08:00
Awni Hannun
d4666615bb Lazy import + refactor Lora layer addition (#426)
* lazy model import in mlx_lm

* change lora loading

* fix olmo lora

* remove a bunch of unused stuff from plamo

* move phixtral to mlx-lm and out of llms/
2024-02-12 10:51:02 -08:00