135 Commits

Author SHA1 Message Date
Angelos Katharopoulos
bc08025f41 Add optional quantization types 2024-12-17 22:24:41 -08:00
Billel Mokeddem
845efddc8c Fix decoding manually added tokens (#1164)
* Fix decoding manually added tokens

* fix + test

* nit

* nit

* no lag bpe

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-12-17 09:54:29 -08:00
Prince Canuma
dfa4dd6c93 Add support for cohere2 (#1157)
* add support for cohere2

* revert act_fn to silu

* fix tests and sliding window attention

* add tests

* add to tuner

* fix sliding window

* add coauthor :)

Co-authored-by: n8programs <43304488+N8python@users.noreply.github.com>

* Add rotating kvcache to save space

* some nits

* style

* nits

---------

Co-authored-by: n8programs <43304488+N8python@users.noreply.github.com>
Co-authored-by: N8 <n8@n8programs.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-12-16 08:01:03 -08:00
Ikko Eltociear Ashimine
fc0674d2d8 chore: update evaluate.py (#1159)
occurence -> occurrence
2024-12-15 06:06:29 -08:00
Awni Hannun
9f2ea5892e Bpe stream without space (#1154)
* bpe streaming detokenization without space

* version bump
2024-12-12 13:13:50 -08:00
Awni Hannun
2ba0e36683 [mlx-lm] Use top p in server (#1144)
* use top p in server

* couple other fixes
2024-12-12 11:12:21 -08:00
Angelos Katharopoulos
19abf3dcaa Replace unicode errors instead of raising exception (#1146) 2024-12-12 11:10:41 -08:00
madroid
06af3c9b0e Add finish_reason in GenerationResponse (#1153) 2024-12-12 10:37:40 -08:00
Awni Hannun
77b42b7c8b fix llava (#1149) 2024-12-12 10:37:26 -08:00
Alex Barron
135c5818c1 Fix max_tokens (#1148) 2024-12-10 11:26:04 -08:00
madroid
12083c4b7e Support for multiple EOS tokens (#1141)
* Support for multiple EOS tokens

* Change _eos_token_ids type from list to set

* Remove model_config & add eos_token_id

* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-12-09 08:53:58 -08:00
n8programs
5687d5b99b Adds EXAONE architecture. (#1145)
* Adds EXAONE architecture.

* nits + format

* format

* clean up and fix rope

* clean up and fix rope

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-12-09 07:58:25 -08:00
hehua2008
893b3f085e Change Flux default max_shift to 1.15 to match the official one (#1137) 2024-12-08 23:29:48 -08:00
Peter Sibley
ed91bbc4dc Fix final message at end of flux training (#1143) 2024-12-08 23:01:53 -08:00
hehua2008
1fd6aae871 Fix flux training with batch size (#1135)
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-12-08 22:09:04 -08:00
Alex Barron
2211b27388 Mixed Quantizations (#1132)
* saving/loading mixed quantizations

* comment

* add bits per weight

* more concise bpw

* count bias too
2024-12-08 14:21:50 -08:00
Alex Barron
cd8cf28c39 mlx_lm.evaluate (#1140)
* Add evaluation script

* only write top level results

* add lm eval version

* typo

* create output dir

* relative import

* comment

---------

Co-authored-by: David Grangier <dgrangier@users.noreply.github.com>
2024-12-08 12:20:10 -08:00
vb
1727959a27 Add mentions of MLX-my-repo. (#1129)
* Add mentions of MLX-my-repo.

* simplify

* move

* move

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-12-03 19:21:39 -08:00
Awni Hannun
1963df8565 Allow prompt callback to generate_step (#1133)
* allow prompt callback and use in cache_prompt

* nit

* comments

* bump version
2024-12-03 16:17:14 -08:00
sakares saengkaew
0ca162cfb2 Fix data_iter in prepare_dataset from speechcommands example (#1113) 2024-12-02 23:56:07 -08:00
Angelos Katharopoulos
eb9277f574 Allow loading from diffusers ckpt (#1117) 2024-12-02 13:15:50 -08:00
hehua2008
2a9294a5f0 Fix bug in FluxSampler.timesteps method (#1131) 2024-12-02 13:15:19 -08:00
Awni Hannun
8801beb66f Add olmo2 (#1128)
* add olmo2

* add olmo2
2024-12-02 11:42:58 -08:00
Neil Mehta
cefe793ae0 Accept mx.array type for prompt argument for stream_generate (#1125)
* Accept mx.array type for prompt argument for stream_generate

* Fix formatting
2024-11-26 16:51:55 -08:00
Awni Hannun
cfc29c29f4 Put prompt processing in same stream (#1122)
* put prompt processing in same stream

* patch
2024-11-25 09:47:00 -08:00
madroid
a5e173802e docs: update stream_generate return type annotation (#1121)
Improve documentation clarity by:
1. Fix return type annotation to correctly reflect GenerationResponse
2. Simplify docstring by referencing GenerationResponse class
3. Remove redundant field descriptions
2024-11-25 08:10:14 -08:00
Remixer Dec
adaab81029 Allow converting models from local directories (#1118) 2024-11-24 16:41:06 -08:00
Kevin Conner
0ffdb6dd20 Fix object property value in mlx_lm.server chat completions response to match OpenAI spec (#1119)
These were "chat.completions" and "chat.completions.chunk"
but should be "chat.completion" and "chat.completion.chunk"
for compatibility with clients expecting an OpenAI API.

In particular, this solves a problem in which aider 0.64.1 reports
hitting a token limit on any completion request, no matter how small,
despite apparently correct counts in the usage property.

Refer to:

https://platform.openai.com/docs/api-reference/chat/object

> object string
> The object type, which is always chat.completion.

https://platform.openai.com/docs/api-reference/chat/streaming

> object string
> The object type, which is always chat.completion.chunk.
2024-11-24 16:37:37 -08:00
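
A minimal illustration of the corrected field values described in the commit above (hypothetical payloads; only the "object" fields matter, everything else is placeholder data):

```python
# Hypothetical payloads sketching the corrected "object" values; ids and
# content are placeholders, field names follow the OpenAI chat API.
completion_response = {
    "id": "chatcmpl-123",
    "object": "chat.completion",        # previously "chat.completions"
    "choices": [{"index": 0, "message": {"role": "assistant", "content": "Hi"}}],
}

stream_chunk = {
    "id": "chatcmpl-123",
    "object": "chat.completion.chunk",  # previously "chat.completions.chunk"
    "choices": [{"index": 0, "delta": {"content": "Hi"}}],
}
```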
Awni Hannun
0f135396ae Generation refactor: part 2 (#1099)
* unify with stream_generate

* fixes

* nit

* some cleanup, warnings, tests

* fix test + faster min p + test

* version
2024-11-23 11:47:06 -08:00
Awni Hannun
004eb4cc9d Tencent HunYuan MOE model (#1100)
* hunyuan

* fix

* format str

* default trust remote code for tokenizer, allow system prompt to be configurable
2024-11-23 11:06:26 -08:00
Angelos Katharopoulos
042280ce50 Fix format (#1115) 2024-11-20 16:15:53 -08:00
Valentin Roussellet
60c7b80350 Pass seed to sd img2img (#1114) 2024-11-20 15:21:52 -08:00
Alban Lecocq
bd6d910ca3 [MLX LM] Fix f-string formatting in memory warning message (#1105)
* Fix missing f-prefix for string interpolation in model size warning
* Ensures proper display of memory values in MB for model and max size
2024-11-13 06:14:03 -08:00
madroid
1e07660184 FLUX: save train config (#1049) 2024-11-08 17:15:19 -08:00
Awni Hannun
657b4cc0aa [MLX LM] Sampler refactor + a few improvements (#1094)
* starting

* refactor sampler/processor and a few improvements

* fix stream

* fix stream generate

* fix eos handling in stream generate
2024-11-07 16:15:24 -08:00
Angelos Katharopoulos
ed9e81dd58 Fix rotating kv cache size (#1093) 2024-11-05 10:24:24 -08:00
Awni Hannun
6fd1f70f73 fix spm decoder multi-byte (#1092) 2024-11-05 06:06:26 -08:00
Anthony Wu
4394633ce0 mlx_whisper: add support for audio input from stdin (#1012)
* add support for audio and input name from stdin

* refactored to stdin - arg, and output-name template

* fix bugs, add test coverage

* fix doc to match arg rename

* some nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-11-04 14:02:13 -08:00
ilyasch2
3b526f0aa1 Add support for falcon-mamba (#1074)
* Add support for falcon-mamba

* nits

* nit

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-11-04 12:23:30 -08:00
Anchen
82e3338987 chore(mlx-lm): add max token arg for mlx_lm.chat (#1089)
* chore(mlx-lm): add max token arg for mlx_lm.chat

* chore: update the default max token value
2024-11-04 06:06:34 -08:00
Angelos Katharopoulos
331148d8ec Enable distributed LoRA training (#821) 2024-11-02 18:02:31 -07:00
Awni Hannun
29c954f4cb fix (#1082) 2024-11-02 13:51:38 -07:00
Awni Hannun
0f799947d0 fix (#1079) 2024-11-01 16:30:32 -07:00
Awni Hannun
e510987870 Clear cache every now and then (#1081)
* clear cache every now and then

* don't need user arg anymore
2024-11-01 14:15:32 -07:00
Awni Hannun
8160e0c4e5 Whisper improvements (#1080)
* use safetensors in whisper

* speed up decoder

* version
2024-11-01 10:52:28 -07:00
Alex Barron
85ffd2c96a Quantized KV Cache (#1075)
* add QuantizedKVCache

* simplify

* add tests

* single sdpa function

* fix sed

* in place

* fix tests

* support different k and v head dims
2024-10-31 16:59:52 -07:00
Awni Hannun
9f34fdbda4 Wire models in MLX LM (#1069)
* wired in MLX LM

* fix synch

* comment + nit

* version

* mlx lm version

* bump to 0.19.2
2024-10-31 08:17:14 -07:00
Awni Hannun
8fe9539af7 Fix detokenizer space match for quote (#1072)
* fix + test

* remove transformer flax/torch warning

* format
2024-10-27 15:06:07 -07:00
hschaeufler
ab4bf05c6e Update lora_config.yaml with new param: num_layers (#1068) 2024-10-26 09:34:46 -07:00
Saurav Maheshkar
4971462bf0 feat(clip): add linear probe evaluation script (#960) 2024-10-24 21:56:17 -07:00
Awni Hannun
9000e280ae fix mamba models conversion (#1065) 2024-10-22 15:44:08 -07:00
madroid
d1d480867b LoRA: update tools datasets docs (#1063)
* LoRA: update tools datasets docs

* nits

* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-10-22 12:19:11 -07:00
Awni Hannun
66e7bcb886 override dtype with quant (#1062) 2024-10-22 09:56:45 -07:00
aronson
743763bc2e Handle empty string case in maybe_trim_space (#1055)
* Handle empty string case in maybe_trim_space

* nit

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-10-20 20:46:43 -07:00
madroid
f491d473a3 FLUX: Optimize dataset loading logic (#1038) 2024-10-15 10:37:45 -07:00
Zak B. Elep
3d62b058a4 fix: typo on flux model preloading (#1050) 2024-10-15 09:13:01 -07:00
madroid
bbd2003047 FLUX: update README.md (#1036) 2024-10-14 11:21:41 -07:00
Awni Hannun
605c4854f1 Prompt caching in mlx_lm.server (#1026)
* caching in server

* nits

* fix tests

* don't throw if no metal

* comments
2024-10-14 10:57:22 -07:00
Awni Hannun
8dca1a2f60 Tokenizer updates + tests (#1024)
* tokenizer updates + tests

* nit

* add can_trim_prompt_cache

* nits
2024-10-14 10:48:46 -07:00
Awni Hannun
6c368f2124 bump mac tests to use py39 (#1047) 2024-10-14 10:40:36 -07:00
Awni Hannun
c799133998 Make llm async eval less brittle (#1040)
* Make llm async eval less brittle

* nit
2024-10-14 10:25:24 -07:00
Seitaro Sugawara
1e0cda68c6 Update README.md (#1045)
* Update README.md

A small typo was fixed in the musicgen README.md.

* Update musicgen/README.md

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-10-14 06:21:25 -07:00
Shunta Saito
7612c646f3 Fix PLaMo model to support Grouped Query Attention (#1037) 2024-10-12 15:26:50 -07:00
Ivan Fioravanti
d8611dd69f Small typo fixed in flux README.md (#1035) 2024-10-12 06:14:01 -07:00
Angelos Katharopoulos
a5f2bab070 Add FLUX finetuning (#1028) 2024-10-11 21:17:41 -07:00
Alex Barron
d72fdeb4ee MusicGen (#1020)
* Add MusicGen model

* add benchmarks

* change to from_pretrained

* symlinks

* add readme and requirements

* fix readme

* readme
2024-10-11 10:16:20 -07:00
Awni Hannun
4360e7ccec clear cache during prompt processing (#1027) 2024-10-09 16:48:32 -07:00
Awni Hannun
b7373cb44f fix long prompt generations (#1023) 2024-10-09 11:09:36 -07:00
Awni Hannun
fca087be49 More cache improvements (#1015)
* fix rotating kv cache for chat use case

* reorg + fixes to caching, unify prompt caching across types and use cases for e.g. caching during a chat

* nit in chat

* fix tests

* fix tests

* fix tests

* docs

* chat command

* comments + docs

* Define meta_state on all Cache implementations

* fixes + trim_prompt_cache api

* fix default model

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-10-07 20:45:51 -07:00
Awni Hannun
9bc53fc210 convert (#1006) 2024-10-02 13:13:33 -07:00
madroid
36c1d8e8dc Server: support function calling (#1003) 2024-10-02 12:36:07 -07:00
nathan
0866e23a67 repetition_penalty and logits_bias just using logits_processors (#1004)
* refactor of repetition_penalty and logits_bias to use logits_processor

* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-09-30 08:49:03 -07:00
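
As a rough sketch of the refactor in the commit above, a repetition penalty can be expressed as a logits processor, i.e. a callable applied to the logits before sampling. The `(generated_tokens, logits) -> logits` signature below is an assumption for illustration, not the exact API:

```python
import mlx.core as mx

def make_repetition_penalty_sketch(penalty: float = 1.1):
    # Returns a logits-processor-style callable that down-weights tokens
    # which have already been generated.
    def processor(tokens: mx.array, logits: mx.array) -> mx.array:
        if len(tokens) == 0:
            return logits
        selected = logits[:, tokens]
        # Positive logits are divided, negative logits are multiplied, so the
        # penalty always pushes repeated tokens toward lower probability.
        logits[:, tokens] = mx.where(
            selected < 0, selected * penalty, selected / penalty
        )
        return logits

    return processor
```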
Zai Thottakath
418d9a5511 Feature: QDoRA (#891)
* feat: QDoRA with tests and a small bug fix for recalculation of self.m

* some simplifications and fixes

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-09-30 08:01:11 -07:00
madroid
aa1c8abdc6 LoRA: Support HuggingFace dataset via data parameter (#996)
* LoRA: support huggingface dataset via `data` argument

* LoRA: Extract the load_custom_hf_dataset function

* LoRA: split small functions

* fix spelling errors

* handle load hf dataset error

* fix pre-commit lint

* update data argument help

* nits and doc

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-09-30 07:36:21 -07:00
Gökdeniz Gülmez
50e5ca81a8 Adding full finetuning (#903)
* Adding full model weights finetuning

* Updating the LORA.md and ACKNOWLEDGMENTS.md files.

* removing --use-dora and --full-training and adding --fine-tune-type

* some clean up

* reformatting and fixing dora training

* updated CONFIG_DEFAULTS

* update config example

* update in the config example file

* Update LORA.md

* merge and commit

* adding argument for dora linear layer

* clean up

* clean up in the example yaml file

* fix

* final fix before sending

* small addition to the md file

* fix for loading the fully trained model by saving all the files and configs correctly

* clean up

* removing the unnecessary files

* changing lora layers back to 16

* removed max file size

* nits

* resolve merge

* some consistency changes

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-09-29 17:12:47 -07:00
madroid
7ec2021bb9 LoRA: support tools(function calling) format datasets (#995)
* LoRA: support fine-tuning tools datasets

* LoRA: Split small function

* LoRA: add tools format to lora docs

* LoRA: pre-commit fix

* Revert "LoRA: pre-commit fix"

This reverts commit b94b7e0fe7.

* Revert "LoRA: Split small function"

This reverts commit 3f6a5f19fd.

* LoRA: remove ToolsDataset

In a JSONL file, not all data is required to include the tools value.

* nit in readme

* nit in readme

* nit in readme

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-09-28 10:41:36 -07:00
nathan
ace2bb5890 Add logits_processor option to generate_step function (#983)
* Add logits_processor option for the generation as in huggingface transformers library

* concatenation correction

* Rename the tokens variable for clarity

* remove the logit_bias argument from generate_step method

* fix the variable name

* nits + test

* test

* add back logit bias + test

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-09-28 10:08:49 -07:00
jamesm131
d812516d3d Add /v1/models endpoint to mlx_lm.server (#984)
* Add 'models' endpoint to server

* Add test for new 'models' server endpoint

* Check hf_cache for mlx models

* update tests to check hf_cache for models

* simplify test

* doc

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-09-28 07:21:11 -07:00
Gökdeniz Gülmez
76710f61af Adding support for mamba (#940)
* initial commit

* initial commit

* Adding first lines

* adding x, and dt projection layers

* adding the clamping mechanism

* First successful inference

* last commit for today - added custom generate function and it works as expected, will try training and then loading a model from the hub

* clean up

* save up

* almost

* update

* update

* fixed cache handling

* fixed loading

* added separate generate_step method in the model and also in the utils to automatically use the generate_step method in the model class

* quick update

* still not working

* save

* still not working

* initial commit

* utils.py logits = logits[:, -1, :] TypeError: tuple indices must be integers or slices, not tuple

* update

* update

* Fixing the batched depthwise convolution and multi-token input

* fixing generate and logits outputs

* Done!

* Fixing the cache handling, generating works now trying training

* update ACKNOWLEDGEMENTS

* removing the model_type if-statements in the _step loop in generate_step, adding MambaCache in base.py for training and easier generation, and removing mamba in tuner/utils.

* quick clean up

* update trainer/utils for right initialisation of the layers for LoRA, but not working.

* clean up

* Further update to trainer/utils for correct layer selection. Successful training

* removing extra mamba-infer.py file

* clean up, reformating will come later

* reformat and big clean up, final commit

* some speedups and cleanups

* fix test

* nits

* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-09-28 07:02:53 -07:00
Cheng
e776c970f7 Fix llava model when using text-only prompt (#998) 2024-09-25 07:19:41 -07:00
Awni Hannun
9bb2dd62f3 Encodec (#991)
* initial encodec

* works

* nits

* use fast group norm

* fix for rnn layer

* fix mlx version

* use custom LSTM kernel

* audio encodec

* fix example, support batched inference

* nits
2024-09-23 11:39:25 -07:00
Angelos Katharopoulos
796d5e40e4 Fix export to gguf (#993) 2024-09-20 13:33:45 -07:00
Awni Hannun
f530f56df2 don't use internal exception (#990) 2024-09-17 16:22:48 -07:00
Awni Hannun
6c2369e4b9 Fix bug in upload + docs nit (#981)
* fix bug in upload + docs nit

* nit
2024-09-07 14:46:57 -07:00
Awni Hannun
c3e3411756 Update LLM generation docs to use chat template (#973)
* fix docs

* add template to model cards as well

* revert

* version
2024-09-07 06:06:15 -07:00
Angelos Katharopoulos
324184d670 Fix the cache_prompt (#979) 2024-09-06 20:19:27 -07:00
madroid
bd29aec299 Support HuggingFace model tree (#957)
* Hub: Update quantization configuration fields

* Hub: add base_model metadata

* Hub: add quantization_config for model tree Quantized type

* Hub: update quantization_config value

* Hub: remove config print
2024-09-04 06:19:32 -07:00
Chime Ogbuji
83a209e200 Add prompt piping (#962)
* Initial commit of --prompt-only and prompt from STDIN feature

* Switch to using --verbose instead of --prompt-only

* Fix capitalization typo

* Fix reference to changed option name

* Update exception text
2024-09-03 13:29:10 -07:00
James Zhao
bf921afcbe Make sure to import the correct "version" module when installing mlx_whisper and mlx_lm from local source code. (#969)
* Make sure to import the correct "version" module when installing the
mlx_whisper package from local source code.

* Make sure to import the correct "version" module when installing the mlx_lm package from local source code

* fix

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-09-03 13:16:21 -07:00
Awni Hannun
3c6e8b11af fix (#965) 2024-08-30 05:56:27 -07:00
L
fc93c55723 feat(mlx_lm): Nemotron (#949)
* feat: Nemotron

https://huggingface.co/nvidia/Minitron-4B-Base

This is basically Llama with partial RoPE and LayerNorm instead of
RMSNorm. Also they add 1 to the LayerNorm weight for some reason.

* fixup! feat: Nemotron

* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-08-29 21:08:57 -07:00
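
A minimal sketch of the "+1 on the LayerNorm weight" detail mentioned above (class name, dims, and eps are illustrative assumptions):

```python
import mlx.core as mx
import mlx.nn as nn

class LayerNorm1P(nn.Module):
    # Hypothetical sketch: a LayerNorm whose learned weight is stored
    # zero-centered, so 1 is added back when the norm is applied.
    def __init__(self, dims: int, eps: float = 1e-5):
        super().__init__()
        self.weight = mx.zeros((dims,))
        self.bias = mx.zeros((dims,))
        self.eps = eps

    def __call__(self, x: mx.array) -> mx.array:
        mu = x.mean(axis=-1, keepdims=True)
        var = ((x - mu) ** 2).mean(axis=-1, keepdims=True)
        x_hat = (x - mu) * mx.rsqrt(var + self.eps)
        return (1 + self.weight) * x_hat + self.bias
```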
Awni Hannun
b1186e2a81 Docs on prompt scaling (#963)
* docs on prompt scaling

* remove unused var

* nits
2024-08-29 15:05:17 -07:00
Angelos Katharopoulos
1003a8b2dd Add the ability to load the KV cache from a file (#956) 2024-08-28 22:11:45 -07:00
Angelos Katharopoulos
7f8c961287 Fix setattr for the TokenizerWrapper (#961) 2024-08-28 14:47:33 -07:00
Nripesh Niketan
bf21789b17 chore: update black pre-commit hooks to latest versions (#955) 2024-08-26 07:54:23 -07:00
Prince Canuma
b5e18ef1e3 Add Phi-3.5-MoE (#946)
* add phimoe

* add phimoe to tunner

* add switch_mlp

* fix SuScaled args

* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-08-24 06:52:33 -07:00
Awni Hannun
6731254e76 Use fast rope (#945)
* use fast rope

* fix llama

* use fast rope for llama3.1

* requires unreleased mlx

* fix su

* fix deepseek v2

* only one of base or freqs

* nit

* fix

* hard code freqs
2024-08-23 13:18:51 -07:00
Awni Hannun
58591a1b41 fine tune deepseek (#932) 2024-08-22 10:41:21 -07:00
L
0164d2058b feat: DeepSeek MoE v1 (#942)
* feat: deepseek v1

DeepSeek is still releasing models on the DeepSeek V1 architecture.

```sh
mlx_lm.convert --hf-path deepseek-ai/DeepSeek-Prover-V1.5-RL --mlx-path DeepSeek-Prover-V1.5-RL-8bit --q-bits 8 -q
mlx_lm.generate --model DeepSeek-Prover-V1.5-RL-8bit --ignore-chat-template --max-tokens 512 --prompt 'import Mathlib
import Aesop

set_option maxHeartbeats 0

open BigOperators Real Nat Topology Rat

/-- The second and fourth terms of a geometric sequence are $2$ and $6$. Which of the following is a possible first term?
Show that it is $\frac{2\sqrt{3}}{3}$.-/
theorem amc12b_2003_p6 (a r : ℝ) (u : ℕ → ℝ) (h₀ : ∀ k, u k = a * r ^ k) (h₁ : u 1 = 2)
  (h₂ : u 3 = 6) : u 0 = 2 / Real.sqrt 3 ∨ u 0 = -(2 / Real.sqrt 3) := by'
```

* nits

* nits

* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-08-17 07:18:09 -07:00
Awni Hannun
7be292c0c9 Handle longer prompt/generation (#931)
* rebase

* nits

* nit

* fix rotating cache with step prefill

* update version
2024-08-16 15:28:39 -07:00
madroid
e196fa3208 Whisper: Support command line (#746)
* Whisper: Add CLI command

* Whisper: Prevent precision loss when converting to words dictionary

* Whisper: disable json ensure_ascii

* Whisper: add cli setup config

* Whisper: pre-commit

* Whisper: Adjust the _ in the command line arguments to -

* nits

* version + readme

* nit

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-08-16 10:35:44 -07:00
Zai Thottakath
4e01700816 Allow the entire model to be targeted for LoRA and DoRA fine tuning: LoRA and DoRA embeddings with small DoRALinear bug fix (#914)
* feature: LoRA adapter for Embeddings

* feature: wire in LoRAEmbedding into the tuner. Allow the embedding and non model.layers Linear layers to be targeted for fine tuning

* feature: DoRA adapter for Embeddings

* feature: wire in DoRAEmbedding

* bugfix: ensure self.m is recalculated when the linear layer is changed in DoRALinear.from_linear

* refactor: prefer from_base over from_linear or from_embedding. prefer fuse over to_linear or to_embedding

* cleanup: remove unused imports in test_dora.py

* refactor: remove unnecessary non_layer_modules

* cleanup: remove wrong comments for lora embedding dropout. remove unnecessary parens in dora embedding dropout

* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-08-16 07:38:36 -07:00
Chime Ogbuji
c50971e860 Min P implementation (#926)
* Min P implementation

* Change default to 0 (no min_p)

* nits

* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-08-15 15:45:02 -07:00
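
The min-p idea referenced above keeps only tokens whose probability is at least `min_p` times that of the most likely token. A minimal, hypothetical sketch (not the repository's exact implementation):

```python
import mlx.core as mx

def min_p_sample_sketch(logits: mx.array, min_p: float = 0.05) -> mx.array:
    # Filter out tokens far less likely than the top token, then sample
    # from the renormalized distribution.
    probs = mx.softmax(logits, axis=-1)
    threshold = min_p * mx.max(probs, axis=-1, keepdims=True)
    filtered = mx.where(probs >= threshold, probs, 0.0)
    filtered = filtered / filtered.sum(axis=-1, keepdims=True)
    return mx.random.categorical(mx.log(filtered))
```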
Awni Hannun
9b83004631 Faster sampling with mx.compile (#937)
* faster sampling with compile

* fix test
2024-08-15 11:29:09 -07:00
Awni Hannun
95840f32e2 Fix whisper conversion for safetensors models (#935)
* fix whisper conversion for safetensors only. error in mlx lm for existing paths

* fix tests
2024-08-14 10:22:04 -07:00
Awni Hannun
33905447f9 Whisper updates to allow HF models (#923)
* simplify conversion and update convert for HF models

* use npz for compat

* fixes

* fixes

* fix gguf

* allow user supplied path
2024-08-09 11:11:58 -07:00
tidely
df744c98e6 Predict stop sequence matches during streaming (#541)
* Predict stop sequence matches during streaming

Check for overlap of stop sequences and the tokens array for potential sequence matches after more tokens get generated. Generate tokens until we can confirm that the stop sequence is not met.

* fix typo

* Change sequence_overlap logic

* range isn't inclusive, add 1 to max_overlap

* Add test_server.py

Added a test for the sequence_overlap method

* nits

* eos sequence

* finalize

---------

Co-authored-by: Y4hL <43219534+Y4hL@users.noreply.github.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-08-06 15:24:15 -07:00
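
A minimal sketch of the overlap check described above: streamed text must be held back whenever a suffix of the generated tokens matches a prefix of a stop sequence, since the next tokens may complete the match. The helper name mirrors the `sequence_overlap` mentioned in the commit; the body is an illustrative reconstruction:

```python
def sequence_overlap(s1, s2) -> bool:
    # True if some suffix of s1 (tokens generated so far) equals a prefix of
    # s2 (a stop sequence), i.e. a stop match could still complete later.
    max_overlap = min(len(s1), len(s2))
    # range is exclusive at the top, hence the + 1
    return any(s1[-i:] == s2[:i] for i in range(1, max_overlap + 1))
```

For example, with a stop sequence `[7, 8, 9]` and a generated tail ending in `[7, 8]`, the check returns True and output is withheld until the next token either completes or breaks the match.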
Khush Gupta
8fa12b0058 Adapters loading (#902)
* Added functionality to load in adapters through post-requests so you do not need to restart the server

* ran pre-commit

* nits

* fix test

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-08-01 16:18:18 -07:00
madroid
85dc76f6e0 Server: support stream_options (#913)
* Server: support stream_options

see https://x.com/OpenAIDevs/status/1787573348496773423

* Server: support stream_options

* Server: check None type
2024-07-26 08:58:52 -07:00
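
For reference, a hypothetical request body using the OpenAI-style `stream_options` field the commit refers to (the model name is a placeholder):

```python
# Hypothetical chat request: include_usage asks the server to emit a final
# usage chunk at the end of the stream.
payload = {
    "model": "mlx-community/Some-Model-4bit",  # placeholder
    "messages": [{"role": "user", "content": "Hello"}],
    "stream": True,
    "stream_options": {"include_usage": True},
}
```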
otriscon
46da74fea2 Unify attention mask in LLMs (#911)
* Unify attention mask creation in LLMs.

Currently, each model implementation in `mlx-examples/llms/models` has ad-hoc
code to create a mask for the attention mechanism. This usually takes the form:

```
    mask = None
    if h.shape[1] > 1:
        mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
        mask = mask.astype(h.dtype)
```

This correctly creates a mask only if the input consists of more than one token.
But this code assumes the multi-token input is at the beginning of inference.
If, for example, we are evaluating multiple tokens because of speculative
decoding or prompt cache reuse, this mask will not have the correct shape and
will cause an exception to be raised in the attention computation.

Some of the models correctly implement the mask creation with code like this:

```
    mask = None
    if h.shape[1] > 1:
        mask = create_additive_causal_mask(
            h.shape[1], cache[0].offset if cache is not None else 0
        )
        mask = mask.astype(h.dtype)
```

This commit unifies the attention mask creation for all models with a new
function `create_attention_mask`, reducing code duplication and helping all
models support inference performance enhancements like those mentioned above.

* Allow batches in LLM key-value cache

The current implementation of the LLM key-value cache assumes that
the input batch is of size 1. Input batching (evaluating multiple
alternative inputs at the same time) can be a valuable tool for
speculative sampling and other techniques.

This change removes the hard-coded batch size from the code that
resizes the key-value cache.

* Simplify causal mask creation

Use the same codepath regardless of whether there's an offset or
not. Addresses [this comment](https://github.com/ml-explore/mlx-examples/pull/911#discussion_r1691459717).

* Use old-style type annotation to avoid linter error
2024-07-25 16:45:22 -07:00
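
A rough sketch of a unified helper along the lines described above; the name `create_attention_mask` comes from the commit, while the body is an illustrative reconstruction that offsets the causal mask by the number of tokens already held in the KV cache:

```python
import mlx.core as mx

def create_attention_mask_sketch(h: mx.array, cache=None):
    # Only multi-token inputs need a mask; single-token decode steps do not.
    T = h.shape[1]
    if T <= 1:
        return None
    offset = cache[0].offset if cache is not None else 0
    # Queries correspond to positions [offset, offset + T); keys cover
    # everything up to offset + T, so cached tokens stay visible.
    rinds = mx.arange(offset + T)
    linds = mx.arange(offset, offset + T)
    mask = linds[:, None] < rinds[None]
    return (mask * -1e9).astype(h.dtype)
```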
Anchen
7a3ab1620a support load model by custom get_model_classes (#899)
* feature(mlx_lm): support load model by custom get classes

* rename the param
2024-07-25 11:01:17 -07:00
Alex Cheema
cd8efc7fbc Add support for Llama-3.1 (#907)
* add dynamicNTK scaling rope

* remove unused var

* fix rope base

* llama3.1 fixes

* TODO for rope eval

* vectorise llama3 base freq calculation

* removed the arbitrary 2.0 rope_scale default case

* fix slow llama3.1 generation by evaluating stateless part of DynamicNTKScalingRoPE in init

* nits + format

* use mx.pi

* fix tests and add test for 3.1

---------

Co-authored-by: Prince Canuma <prince.gdt@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-07-23 13:21:32 -07:00
M. Ali Bayram
47060a8130 refactor: add force_download parameter to get_model_path function (#800) 2024-07-23 13:10:20 -07:00
Prince Canuma
3f337e0f0a Add Mistral NeMo (fix) (#895)
* fix head_dim

* Update llms/mlx_lm/models/llama.py

* fix kv error

* formatting

* Delete test.py

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-07-22 06:09:24 -07:00
Prince Canuma
3d365b612a Add support for InternLM-2.5 (#871)
* fix internlm-2

* formatting

* add dynamic ntk rope

* formatting

* move dynamic scaling rope to internlm2.py

* add default max_position_embeddings
2024-07-17 16:38:22 -07:00
Anchen
561dcf5643 Add support for deepseek coder v2 lite (#882)
* feat: add support for deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct

* fix softmax + some cleanup

* more nits

* fix rope

* fix original_max_position_embeddings in rope

* fix original_max_position_embeddings in rope config

* add group greedy

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-07-17 07:23:28 -07:00
Awni Hannun
f0c6c6e226 keep the server in a valid state (#889) 2024-07-15 18:35:36 -07:00
JosefAlbers
bfc1f2763b longrope (#886) 2024-07-12 07:19:11 -07:00
Chime Ogbuji
8bf397e450 Pass use_dora parameter to linear_to_lora_layers (#885) 2024-07-11 14:34:34 -07:00
nicolov
fbe3247772 Add GPT-neox model (#863) 2024-07-11 06:13:17 -07:00
James A Capozzoli
9717307ff0 Validation with full data set, results in NaN validation score (#879)
* CLI arguments may set num_batches to -1

The CLI arguments allow you to validate with the entire dataset by passing a value of negative one, but this quickly results in a division by zero, causing `NaN` to appear as the validation score!

* Must properly assemble the mini batches when validating with entire dataset.

Tested locally, a validation of a novel took about an hour, with a loss of 0.928. Thanks @awni for the correction!

* Set up the pre-commit hooks and run them so that black may format lora.py.
2024-07-10 08:36:11 -07:00
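
A tiny, hypothetical sketch of the kind of guard the fix implies: when `-1` is passed, derive the batch count from the dataset size instead of letting the negative value reach a division:

```python
def resolve_num_batches(num_batches: int, dataset_size: int, batch_size: int) -> int:
    # A negative value means "use the whole validation set"; compute how many
    # mini batches that actually is rather than dividing by -1 later.
    if num_batches < 0:
        return max(dataset_size // batch_size, 1)
    return num_batches
```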
Alex Wozniakowski
63800c8feb Example of response generation with optional arguments (#853)
* Generate response with optional arguments

* Reference response generation example

* Include transformers and sentencepiece

* Update example to run Mistral-7B-Instruct-v0.3

* Link to generation example

* Style changes from pre-commit
2024-07-09 06:49:59 -07:00
Awni Hannun
68e88d42fb Fix server for openai package (#877)
* fix

* fixes for 9b
2024-07-08 12:34:31 -07:00
Awni Hannun
20e221f7f7 Add recurrent gemma (#856)
* add recurrent gemma

* fix window cache
2024-07-07 12:10:04 -07:00
n8programs
1e05aef344 Add logit soft capping to gemma, and fix precision issues (#857)
* Add logit soft capping to gemma, and fix precision issues

Gemma was babbling nonsense - so I figured out it was due to not having logit softcapping and precision issues causing NaNs (so I implemented the softcapping and added more float32 inference). gemma-27b-it-4bit now works flawlessly (or near-flawlessly, no sliding-window attention).

* get rid of comments

* get rid of last comments (sry lol)

* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-07-02 07:52:39 -07:00
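
For reference, logit soft capping of the kind described above squashes scores smoothly into a bounded range before the softmax. A minimal sketch (the cap value is illustrative; Gemma 2 uses model-specific constants):

```python
import mlx.core as mx

def soft_cap(scores: mx.array, cap: float = 30.0) -> mx.array:
    # tanh keeps the scores in (-cap, cap), which avoids the overflow/NaN
    # issues mentioned in the commit at reduced precision.
    return cap * mx.tanh(scores / cap)
```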
Angelos Katharopoulos
f212b770d8 Server loads the model on demand from the request (#851) 2024-06-27 11:37:57 -07:00
Awni Hannun
538339b599 gemma2 (#855) 2024-06-27 10:06:28 -07:00
Awni Hannun
9f10728145 fix yi (#852) 2024-06-27 06:38:19 -07:00
Volodymyr Kyrylov
7979b84a9e transformer_lm: add --dataset enwik8 (#838)
* transformer_lm: add --dataset enwik8

* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-06-26 11:59:01 -07:00
Chime Ogbuji
df6bc09d74 Configuration-based use of HF hub-hosted datasets for training (#701)
* Add hf_dataset configuration for using HF hub-hosted datasets for (Q)LoRA training

* Pre-commit formatting

* Fix YAML config example

* Print DS info

* Include name

* Add hf_dataset parameter default

* Remove TextHFDataset and CompletionsHFDataset and use Dataset and CompletionsDataset instead, adding a text_key constructor argument to the former (and changing it to work with a provided data structure instead of just from a JSON file), and prompt_key and completion_key arguments to the latter with defaults for backwards compatibility.

* nits

* update docs

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-06-26 10:20:50 -07:00
Chime Ogbuji
1d701a1831 Logprobs info to completion API (#806)
* Initial implementation

* Fix handling of return_step_logits in return

* Fixed OpenAI parameter expectations and logprob structure and datatypes

* pre-commit black formatting

* Remove unused parameter

* fix log probs

* fix colorize

* nits in server

* nits in server

* Fix top_logprobs structure (a dict) and include tokens in logprobs response

* nits

* fix types

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-06-23 10:35:13 -07:00
Yi Wang
a7598e9456 Fix mypy errors with models/{qwen2,qwen2_moe,starcoder2}.py (#835)
* Fix starcoder.py

* Fix qwen2

* Remove unnecessary assert not None
2024-06-14 09:44:50 -07:00
Awni Hannun
d8b073e3a7 Add eos token to lora fine-tunes (#818)
* add eos token to lora fine-tunes

* Comment
2024-06-12 07:44:21 -07:00
Nada Amin
3cc58e17fb Tweaks to run dspy-produced calls to the server, with gemma template. (#810)
* Tweaks to run dspy-produced calls to the server, with gemma template.

following comment https://github.com/stanfordnlp/dspy/issues/385#issuecomment-1998939936

can try it out with:
```sh
python -m server --model mlx-community/gemma-1.1-7b-it-4bit --port 1143
```
modulo patching the relative imports in server.py
```
-from .tokenizer_utils import TokenizerWrapper
-from .utils import generate_step, load
+from mlx_lm.tokenizer_utils import TokenizerWrapper
+from mlx_lm.utils import generate_step, load
```

and then, ont the dspy side:
```python
import dspy
lm = dspy.OpenAI(model_type="chat", api_base="http://localhost:11434/v1/", api_key="not_needed", max_tokens=250)
lm("hello")
```

* simpler way to validate float or int

* remove logic that works around incompatible templates, too gemma specific

* tweak messages for common denominator

* use generate.py workaround for DBXR

* put behind flag

* oops

* Solution to chat template issue: pass in a custom template!

The template should likely adhere to the OpenAI chat model.
Here is such a template for Gemma.

--chat-template "{{ bos_token }}{% set extra_system = '' %}{% for message in messages %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{% if role == 'system' %}{% set extra_system = extra_system + message['content'] %}{% else %}{% if role == 'user' and extra_system %}{% set message_system = 'System: ' + extra_system %}{% else %}{% set message_system = '' %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message_system + message['content'] | trim + '<end_of_turn>\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}"

* remove convoluted solution

* Tweak for when None is provided explicitly, and must be set to [] too.

For example, the outlines library provides None explicitly.

* style

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-06-12 07:17:06 -07:00
Yi Wang
6da07fb1b0 make models/phi3.py and models/phi3small.py compatible with mypy (#833) 2024-06-12 06:53:55 -07:00
149 changed files with 13930 additions and 1962 deletions


@@ -26,13 +26,13 @@ jobs:
       - run:
           name: Install dependencies
           command: |
-            brew install python@3.8
-            python3.8 -m venv env
+            brew install python@3.9
+            python3.9 -m venv env
             source env/bin/activate
             pip install --upgrade pip
             pip install unittest-xml-reporting
             cd llms/
-            pip install -e .
+            pip install -e ".[testing]"
       - run:
           name: Run Python tests
           command: |

.gitignore vendored

@@ -6,6 +6,9 @@ __pycache__/
 # C extensions
 *.so
+# Vim
+*.swp
 # Distribution / packaging
 .Python
 build/


@@ -1,6 +1,6 @@
 repos:
   - repo: https://github.com/psf/black-pre-commit-mirror
-    rev: 24.3.0
+    rev: 24.8.0
     hooks:
       - id: black
   - repo: https://github.com/pycqa/isort


@@ -14,3 +14,4 @@ MLX Examples was developed with contributions from the following individuals:
 - Markus Enzweiler: Added the `cvae` examples.
 - Prince Canuma: Helped add support for `Starcoder2` models.
 - Shiyu Li: Added the `Segment Anything Model`.
+- Gökdeniz Gülmez: Added support for `MiniCPM`, `Mamba` and support for `full-fine-tuning`.


@@ -20,18 +20,22 @@ Some more useful examples are listed below.
 ### Image Models
+- Generating images
+  - [FLUX](flux)
+  - [Stable Diffusion or SDXL](stable_diffusion)
 - Image classification using [ResNets on CIFAR-10](cifar).
-- Generating images with [Stable Diffusion or SDXL](stable_diffusion).
 - Convolutional variational autoencoder [(CVAE) on MNIST](cvae).
 ### Audio Models
 - Speech recognition with [OpenAI's Whisper](whisper).
+- Audio compression and generation with [Meta's EnCodec](encodec).
 ### Multimodal models
 - Joint text and image embeddings with [CLIP](clip).
 - Text generation from image and text inputs with [LLaVA](llava).
+- Image segmentation with [Segment Anything (SAM)](segment_anything).
 ### Other Models


@@ -63,7 +63,7 @@ def save_weights(save_path: Union[str, Path], weights: Dict[str, Any]) -> None:
     )
-def get_model_path(path_or_hf_repo: str) -> Path:
+def get_model_path(path_or_hf_repo: str, force_download: bool = False) -> Path:
     model_path = Path(path_or_hf_repo)
     if not model_path.exists():
         model_path = Path(
@@ -74,6 +74,7 @@ def get_model_path(path_or_hf_repo: str) -> Path:
                     "*.json",
                     "*.txt",
                 ],
+                force_download=force_download,
             )
         )
     return model_path
@@ -107,9 +108,15 @@ if __name__ == "__main__":
         type=str,
         default="float32",
     )
+    parser.add_argument(
+        "-f",
+        "--force-download",
+        help="Force download the model from Hugging Face.",
+        action="store_true",
+    )
     args = parser.parse_args()
-    torch_path = get_model_path(args.hf_repo)
+    torch_path = get_model_path(args.hf_repo, args.force_download)
     mlx_path = Path(args.mlx_path)
     mlx_path.mkdir(parents=True, exist_ok=True)

clip/linear_probe.py Normal file

@@ -0,0 +1,56 @@
# Mirror of the Linear Probe Evaluation Script
# from the official CLIP Repository.
import mlx.core as mx
import numpy as np
from image_processor import CLIPImageProcessor
from mlx.data.datasets import load_cifar10
from model import CLIPModel
from PIL import Image
from sklearn.linear_model import LogisticRegression
from tqdm import tqdm
def get_cifar10(batch_size, root=None):
tr = load_cifar10(root=root).batch(batch_size)
test = load_cifar10(root=root, train=False).batch(batch_size)
return tr, test
def get_features(model, image_proc, iter):
all_features = []
all_labels = []
for batch in tqdm(iter):
image, label = batch["image"], batch["label"]
x = image_proc([Image.fromarray(im) for im in image])
y = mx.array(label)
image_embeds = model.get_image_features(x)
mx.eval(image_embeds)
all_features.append(image_embeds)
all_labels.append(y)
return mx.concatenate(all_features), mx.concatenate(all_labels)
if __name__ == "__main__":
model = CLIPModel.from_pretrained("mlx_model")
image_proc = CLIPImageProcessor.from_pretrained("mlx_model")
train_iter, test_iter = get_cifar10(batch_size=256)
train_features, train_labels = get_features(model, image_proc, train_iter)
test_features, test_labels = get_features(model, image_proc, test_iter)
# Perform logistic regression
# NOTE: The value of C should be determined via a hyperparameter sweep
# using a validation split
classifier = LogisticRegression(random_state=0, C=0.316, max_iter=1000, verbose=1)
classifier.fit(train_features, train_labels)
# Evaluate using the logistic regression classifier
predictions = classifier.predict(test_features)
accuracy = (test_labels.squeeze() == predictions).mean().item() * 100
print(f"Accuracy = {accuracy:.3f}")


@@ -1,4 +1,5 @@
mlx
mlx-data
numpy
transformers
torch

encodec/README.md Normal file

@@ -0,0 +1,84 @@
# EnCodec
An example of Meta's EnCodec model in MLX.[^1] EnCodec is used to compress and
generate audio.
### Setup
Install the requirements:
```
pip install -r requirements.txt
```
Optionally install FFmpeg and SciPy for loading and saving audio files,
respectively.
Install [FFmpeg](https://ffmpeg.org/):
```
# on macOS using Homebrew (https://brew.sh/)
brew install ffmpeg
```
Install SciPy:
```
pip install scipy
```
### Example
An example using the model:
```python
import mlx.core as mx
from encodec import EncodecModel
from utils import load_audio, save_audio
# Load the 48 KHz model and preprocessor.
model, processor = EncodecModel.from_pretrained("mlx-community/encodec-48khz-float32")
# Load an audio file
audio = load_audio("path/to/audio", model.sampling_rate, model.channels)
# Preprocess the audio (this can also be a list of arrays for batched
# processing).
feats, mask = processor(audio)
# Encode at the given bandwidth. A lower bandwidth results in more
# compression but lower reconstruction quality.
@mx.compile
def encode(feats, mask):
return model.encode(feats, mask, bandwidth=3)
# Decode to reconstruct the audio
@mx.compile
def decode(codes, scales, mask):
return model.decode(codes, scales, mask)
codes, scales = encode(feats, mask)
reconstructed = decode(codes, scales, mask)
# Trim any padding:
reconstructed = reconstructed[0, : len(audio)]
# Save the audio as a wave file
save_audio("reconstructed.wav", reconstructed, model.sampling_rate)
```
The 24 KHz, 32 KHz, and 48 KHz MLX formatted models are available in the
[Hugging Face MLX Community](https://huggingface.co/collections/mlx-community/encodec-66e62334038300b07a43b164)
in several data types.
### Optional
To convert models, use the `convert.py` script. To see the options, run:
```bash
python convert.py -h
```
[^1]: Refer to the [arXiv paper](https://arxiv.org/abs/2210.13438) and
[code](https://github.com/facebookresearch/encodec) for more details.


@@ -0,0 +1,31 @@
# Copyright © 2024 Apple Inc.
import time
import mlx.core as mx
from encodec import EncodecModel
model, processor = EncodecModel.from_pretrained("mlx-community/encodec-48khz-float32")
audio = mx.random.uniform(shape=(288000, 2))
feats, mask = processor(audio)
mx.eval(model, feats, mask)
@mx.compile
def fun():
codes, scales = model.encode(feats, mask, bandwidth=3)
reconstructed = model.decode(codes, scales, mask)
return reconstructed
for _ in range(5):
mx.eval(fun())
tic = time.time()
for _ in range(10):
mx.eval(fun())
toc = time.time()
ms = 1000 * (toc - tic) / 10
print(f"Time per it: {ms:.3f}")


@@ -0,0 +1,34 @@
# Copyright © 2024 Apple Inc.
import time
import numpy as np
import torch
from transformers import AutoProcessor, EncodecModel
processor = AutoProcessor.from_pretrained("facebook/encodec_48khz")
audio = np.random.uniform(size=(2, 288000)).astype(np.float32)
pt_model = EncodecModel.from_pretrained("facebook/encodec_48khz").to("mps")
pt_inputs = processor(
raw_audio=audio, sampling_rate=processor.sampling_rate, return_tensors="pt"
).to("mps")
def fun():
pt_encoded = pt_model.encode(pt_inputs["input_values"], pt_inputs["padding_mask"])
pt_audio = pt_model.decode(
pt_encoded.audio_codes, pt_encoded.audio_scales, pt_inputs["padding_mask"]
)
torch.mps.synchronize()
for _ in range(5):
fun()
tic = time.time()
for _ in range(10):
fun()
toc = time.time()
ms = 1000 * (toc - tic) / 10
print(f"Time per it: {ms:.3f}")

encodec/convert.py Normal file

@@ -0,0 +1,212 @@
# Copyright © 2024 Apple Inc.
import argparse
import json
from pathlib import Path
from textwrap import dedent
from types import SimpleNamespace
from typing import Any, Dict, Union
import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download
import encodec
def fetch_from_hub(hf_repo: str) -> Path:
model_path = Path(
snapshot_download(
repo_id=hf_repo,
allow_patterns=["*.json", "*.safetensors"],
)
)
return model_path
def upload_to_hub(path: str, upload_repo: str, hf_path: str):
"""
Uploads the model to Hugging Face hub.
Args:
path (str): Local path to the model.
upload_repo (str): Name of the HF repo to upload to.
hf_path (str): Path to the original Hugging Face model.
"""
import os
from huggingface_hub import HfApi, ModelCard, logging
content = dedent(
f"""
---
language: en
license: other
library: mlx
tags:
- mlx
---
The Model [{upload_repo}](https://huggingface.co/{upload_repo}) was
converted to MLX format from
[{hf_path}](https://huggingface.co/{hf_path}).
This model is intended to be used with the [EnCodec MLX
example](https://github.com/ml-explore/mlx-examples/tree/main/encodec).
"""
)
card = ModelCard(content)
card.save(os.path.join(path, "README.md"))
logging.set_verbosity_info()
api = HfApi()
api.create_repo(repo_id=upload_repo, exist_ok=True)
api.upload_folder(
folder_path=path,
repo_id=upload_repo,
repo_type="model",
multi_commits=True,
multi_commits_verbose=True,
)
print(f"Upload successful, go to https://huggingface.co/{upload_repo} for details.")
def save_weights(save_path: Union[str, Path], weights: Dict[str, Any]) -> None:
if isinstance(save_path, str):
save_path = Path(save_path)
save_path.mkdir(parents=True, exist_ok=True)
total_size = sum(v.nbytes for v in weights.values())
index_data = {"metadata": {"total_size": total_size}, "weight_map": {}}
mx.save_safetensors(
str(save_path / "model.safetensors"), weights, metadata={"format": "mlx"}
)
for weight_name in weights.keys():
index_data["weight_map"][weight_name] = "model.safetensors"
index_data["weight_map"] = {
k: index_data["weight_map"][k] for k in sorted(index_data["weight_map"])
}
with open(save_path / "model.safetensors.index.json", "w") as f:
json.dump(index_data, f, indent=4)
def save_config(
config: dict,
config_path: Union[str, Path],
) -> None:
"""Save the model configuration to the ``config_path``.
The final configuration will be sorted before saving for better readability.
Args:
config (dict): The model configuration.
config_path (Union[str, Path]): Model configuration file path.
"""
# Clean unused keys
config.pop("_name_or_path", None)
# sort the config for better readability
config = dict(sorted(config.items()))
# write the updated config to the config_path (if provided)
with open(config_path, "w") as fid:
json.dump(config, fid, indent=4)
def convert(
upload: bool,
model: str,
dtype: str = None,
):
hf_repo = f"facebook/encodec_{model}"
mlx_repo = f"mlx-community/encodec-{model}-{dtype}"
path = fetch_from_hub(hf_repo)
save_path = Path("mlx_models")
weights = mx.load(str(Path(path) / "model.safetensors"))
with open(path / "config.json", "r") as fid:
config = SimpleNamespace(**json.load(fid))
model = encodec.EncodecModel(config)
new_weights = {}
for k, v in weights.items():
basename, pname = k.rsplit(".", 1)
if pname == "weight_v":
g = weights[basename + ".weight_g"]
v = g * (v / mx.linalg.norm(v, axis=(1, 2), keepdims=True))
k = basename + ".weight"
elif pname in ["weight_g", "embed_avg", "cluster_size", "inited"]:
continue
elif "lstm" in basename:
w_or_b, ih_or_hh, ln = pname.split("_")
if w_or_b == "weight":
new_pname = "Wx" if ih_or_hh == "ih" else "Wh"
elif w_or_b == "bias" and ih_or_hh == "ih":
continue
else:
v = v + weights[k.replace("_hh_", "_ih_")]
new_pname = "bias"
k = basename + "." + ln[1:] + "." + new_pname
if "conv.weight" in k:
# Possibly a transposed conv which has a different order
if "decoder" in k:
ln = int(k.split(".")[2])
if "conv" in model.decoder.layers[ln] and isinstance(
model.decoder.layers[ln].conv, nn.ConvTranspose1d
):
v = mx.moveaxis(v, 0, 2)
else:
v = mx.moveaxis(v, 1, 2)
else:
v = mx.moveaxis(v, 1, 2)
new_weights[k] = v
weights = new_weights
model.load_weights(list(weights.items()))
if dtype is not None:
t = getattr(mx, dtype)
weights = {k: v.astype(t) for k, v in weights.items()}
if isinstance(save_path, str):
save_path = Path(save_path)
save_weights(save_path, weights)
save_config(vars(config), config_path=save_path / "config.json")
if upload:
upload_to_hub(save_path, mlx_repo, hf_repo)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert EnCodec weights to MLX.")
parser.add_argument(
"--model",
type=str,
default="48khz",
help="",
choices=["24khz", "32khz", "48khz"],
)
parser.add_argument(
"--upload",
action="store_true",
help="Upload the weights to Hugging Face.",
)
parser.add_argument(
"--dtype",
type=str,
help="Data type to convert the model to.",
default="float32",
choices=["float32", "bfloat16", "float16"],
)
args = parser.parse_args()
convert(upload=args.upload, model=args.model, dtype=args.dtype)

encodec/encodec.py Normal file

@@ -0,0 +1,741 @@
# Copyright © 2024 Apple Inc.
import functools
import json
import math
from pathlib import Path
from types import SimpleNamespace
from typing import List, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
import numpy as np
_lstm_kernel = mx.fast.metal_kernel(
name="lstm",
input_names=["x", "h_in", "cell", "hidden_size", "time_step", "num_time_steps"],
output_names=["hidden_state", "cell_state"],
header="""
template <typename T>
T sigmoid(T x) {
auto y = 1 / (1 + metal::exp(-metal::abs(x)));
return (x < 0) ? 1 - y : y;
}
""",
source="""
uint b = thread_position_in_grid.x;
uint d = hidden_size * 4;
uint elem = b * d + thread_position_in_grid.y;
uint index = elem;
uint x_index = b * num_time_steps * d + time_step * d + index;
auto i = sigmoid(h_in[index] + x[x_index]);
index += hidden_size;
x_index += hidden_size;
auto f = sigmoid(h_in[index] + x[x_index]);
index += hidden_size;
x_index += hidden_size;
auto g = metal::precise::tanh(h_in[index] + x[x_index]);
index += hidden_size;
x_index += hidden_size;
auto o = sigmoid(h_in[index] + x[x_index]);
cell_state[elem] = f * cell[elem] + i * g;
hidden_state[elem] = o * metal::precise::tanh(cell_state[elem]);
""",
)
def lstm_custom(x, h_in, cell, time_step):
assert x.ndim == 3, "Input to LSTM must have 3 dimensions."
out_shape = cell.shape
return _lstm_kernel(
inputs=[x, h_in, cell, out_shape[-1], time_step, x.shape[-2]],
output_shapes=[out_shape, out_shape],
output_dtypes=[h_in.dtype, h_in.dtype],
grid=(x.shape[0], h_in.size // 4, 1),
threadgroup=(256, 1, 1),
)
class LSTM(nn.Module):
def __init__(
self,
input_size: int,
hidden_size: int,
bias: bool = True,
):
super().__init__()
self.hidden_size = hidden_size
self.Wx = mx.zeros((4 * hidden_size, input_size))
self.Wh = mx.zeros((4 * hidden_size, hidden_size))
self.bias = mx.zeros((4 * hidden_size,)) if bias else None
def __call__(self, x, hidden=None, cell=None):
if self.bias is not None:
x = mx.addmm(self.bias, x, self.Wx.T)
else:
x = x @ self.Wx.T
all_hidden = []
B = x.shape[0]
cell = cell or mx.zeros((B, self.hidden_size), x.dtype)
for t in range(x.shape[-2]):
if hidden is None:
hidden = mx.zeros((B, self.hidden_size * 4), x.dtype)
else:
hidden = hidden @ self.Wh.T
hidden, cell = lstm_custom(x, hidden, cell, t)
all_hidden.append(hidden)
return mx.stack(all_hidden, axis=-2)
class EncodecConv1d(nn.Module):
"""Conv1d with asymmetric or causal padding and normalization."""
def __init__(
self,
config,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int = 1,
dilation: int = 1,
):
super().__init__()
self.causal = config.use_causal_conv
self.pad_mode = config.pad_mode
self.norm_type = config.norm_type
self.conv = nn.Conv1d(
in_channels, out_channels, kernel_size, stride, dilation=dilation
)
if self.norm_type == "time_group_norm":
self.norm = nn.GroupNorm(1, out_channels, pytorch_compatible=True)
self.stride = stride
# Effective kernel size with dilations.
self.kernel_size = (kernel_size - 1) * dilation + 1
self.padding_total = kernel_size - stride
def _get_extra_padding_for_conv1d(
self,
hidden_states: mx.array,
) -> mx.array:
length = hidden_states.shape[1]
n_frames = (length - self.kernel_size + self.padding_total) / self.stride + 1
n_frames = int(math.ceil(n_frames)) - 1
ideal_length = n_frames * self.stride + self.kernel_size - self.padding_total
return ideal_length - length
def _pad1d(
self,
hidden_states: mx.array,
paddings: Tuple[int, int],
mode: str = "zero",
value: float = 0.0,
):
if mode != "reflect":
return mx.pad(
hidden_states, paddings, mode="constant", constant_values=value
)
length = hidden_states.shape[1]
prefix = hidden_states[:, 1 : paddings[0] + 1][:, ::-1]
suffix = hidden_states[:, max(length - (paddings[1] + 1), 0) : -1][:, ::-1]
return mx.concatenate([prefix, hidden_states, suffix], axis=1)
def __call__(self, hidden_states):
extra_padding = self._get_extra_padding_for_conv1d(hidden_states)
if self.causal:
# Left padding for causal
hidden_states = self._pad1d(
hidden_states, (self.padding_total, extra_padding), mode=self.pad_mode
)
else:
# Asymmetric padding required for odd strides
padding_right = self.padding_total // 2
padding_left = self.padding_total - padding_right
hidden_states = self._pad1d(
hidden_states,
(padding_left, padding_right + extra_padding),
mode=self.pad_mode,
)
hidden_states = self.conv(hidden_states)
if self.norm_type == "time_group_norm":
hidden_states = self.norm(hidden_states)
return hidden_states
class EncodecConvTranspose1d(nn.Module):
"""ConvTranspose1d with asymmetric or causal padding and normalization."""
def __init__(
self,
config,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int = 1,
):
super().__init__()
self.causal = config.use_causal_conv
self.trim_right_ratio = config.trim_right_ratio
self.norm_type = config.norm_type
self.conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride)
if config.norm_type == "time_group_norm":
self.norm = nn.GroupNorm(1, out_channels, pytorch_compatible=True)
self.padding_total = kernel_size - stride
def __call__(self, hidden_states):
hidden_states = self.conv(hidden_states)
if self.norm_type == "time_group_norm":
hidden_states = self.norm(hidden_states)
if self.causal:
padding_right = math.ceil(self.padding_total * self.trim_right_ratio)
else:
padding_right = self.padding_total // 2
padding_left = self.padding_total - padding_right
end = hidden_states.shape[1] - padding_right
hidden_states = hidden_states[:, padding_left:end, :]
return hidden_states
class EncodecLSTM(nn.Module):
def __init__(self, config, dimension):
super().__init__()
self.lstm = [LSTM(dimension, dimension) for _ in range(config.num_lstm_layers)]
def __call__(self, hidden_states):
h = hidden_states
for lstm in self.lstm:
h = lstm(h)
return h + hidden_states
class EncodecResnetBlock(nn.Module):
"""
Residual block from SEANet model as used by EnCodec.
"""
def __init__(self, config, dim: int, dilations: List[int]):
super().__init__()
kernel_sizes = (config.residual_kernel_size, 1)
if len(kernel_sizes) != len(dilations):
raise ValueError("Number of kernel sizes should match number of dilations")
hidden = dim // config.compress
block = []
for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
in_chs = dim if i == 0 else hidden
out_chs = dim if i == len(kernel_sizes) - 1 else hidden
block += [nn.ELU()]
block += [
EncodecConv1d(config, in_chs, out_chs, kernel_size, dilation=dilation)
]
self.block = block
if getattr(config, "use_conv_shortcut", True):
self.shortcut = EncodecConv1d(config, dim, dim, kernel_size=1)
else:
self.shortcut = nn.Identity()
def __call__(self, hidden_states):
residual = hidden_states
for layer in self.block:
hidden_states = layer(hidden_states)
return self.shortcut(residual) + hidden_states
class EncodecEncoder(nn.Module):
"""SEANet encoder as used by EnCodec."""
def __init__(self, config):
super().__init__()
model = [
EncodecConv1d(
config, config.audio_channels, config.num_filters, config.kernel_size
)
]
scaling = 1
for ratio in reversed(config.upsampling_ratios):
current_scale = scaling * config.num_filters
for j in range(config.num_residual_layers):
model += [
EncodecResnetBlock(
config, current_scale, [config.dilation_growth_rate**j, 1]
)
]
model += [nn.ELU()]
model += [
EncodecConv1d(
config,
current_scale,
current_scale * 2,
kernel_size=ratio * 2,
stride=ratio,
)
]
scaling *= 2
model += [EncodecLSTM(config, scaling * config.num_filters)]
model += [nn.ELU()]
model += [
EncodecConv1d(
config,
scaling * config.num_filters,
config.hidden_size,
config.last_kernel_size,
)
]
self.layers = model
def __call__(self, hidden_states):
for layer in self.layers:
hidden_states = layer(hidden_states)
return hidden_states
class EncodecDecoder(nn.Module):
"""SEANet decoder as used by EnCodec."""
def __init__(self, config):
super().__init__()
scaling = int(2 ** len(config.upsampling_ratios))
model = [
EncodecConv1d(
config,
config.hidden_size,
scaling * config.num_filters,
config.kernel_size,
)
]
model += [EncodecLSTM(config, scaling * config.num_filters)]
for ratio in config.upsampling_ratios:
current_scale = scaling * config.num_filters
model += [nn.ELU()]
model += [
EncodecConvTranspose1d(
config,
current_scale,
current_scale // 2,
kernel_size=ratio * 2,
stride=ratio,
)
]
for j in range(config.num_residual_layers):
model += [
EncodecResnetBlock(
config, current_scale // 2, (config.dilation_growth_rate**j, 1)
)
]
scaling //= 2
model += [nn.ELU()]
model += [
EncodecConv1d(
config,
config.num_filters,
config.audio_channels,
config.last_kernel_size,
)
]
self.layers = model
def __call__(self, hidden_states):
for layer in self.layers:
hidden_states = layer(hidden_states)
return hidden_states
class EncodecEuclideanCodebook(nn.Module):
"""Codebook with Euclidean distance."""
def __init__(self, config):
super().__init__()
self.embed = mx.zeros((config.codebook_size, config.codebook_dim))
def quantize(self, hidden_states):
embed = self.embed.T
scaled_states = hidden_states.square().sum(axis=1, keepdims=True)
dist = -(
scaled_states
- 2 * hidden_states @ embed
+ embed.square().sum(axis=0, keepdims=True)
)
embed_ind = dist.argmax(axis=-1)
return embed_ind
def encode(self, hidden_states):
shape = hidden_states.shape
hidden_states = hidden_states.reshape((-1, shape[-1]))
embed_ind = self.quantize(hidden_states)
embed_ind = embed_ind.reshape(*shape[:-1])
return embed_ind
def decode(self, embed_ind):
return self.embed[embed_ind]
class EncodecVectorQuantization(nn.Module):
"""
Vector quantization implementation. Currently supports only euclidean distance.
"""
def __init__(self, config):
super().__init__()
self.codebook = EncodecEuclideanCodebook(config)
def encode(self, hidden_states):
return self.codebook.encode(hidden_states)
def decode(self, embed_ind):
return self.codebook.decode(embed_ind)
class EncodecResidualVectorQuantizer(nn.Module):
"""Residual Vector Quantizer."""
def __init__(self, config):
super().__init__()
self.codebook_size = config.codebook_size
hop_length = np.prod(config.upsampling_ratios)
self.frame_rate = math.ceil(config.sampling_rate / hop_length)
self.num_quantizers = int(
1000 * config.target_bandwidths[-1] // (self.frame_rate * 10)
)
self.layers = [
EncodecVectorQuantization(config) for _ in range(self.num_quantizers)
]
def get_num_quantizers_for_bandwidth(
self, bandwidth: Optional[float] = None
) -> int:
"""Return num_quantizers based on specified target bandwidth."""
bw_per_q = math.log2(self.codebook_size) * self.frame_rate
num_quantizers = self.num_quantizers
if bandwidth is not None and bandwidth > 0.0:
num_quantizers = int(max(1, math.floor(bandwidth * 1000 / bw_per_q)))
return num_quantizers
def encode(
self, embeddings: mx.array, bandwidth: Optional[float] = None
) -> mx.array:
"""
Encode a given input array with the specified frame rate at the given
bandwidth. The RVQ encode method sets the appropriate number of
quantizers to use and returns indices for each quantizer.
"""
num_quantizers = self.get_num_quantizers_for_bandwidth(bandwidth)
residual = embeddings
all_indices = []
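# Each codebook quantizes the residual left by the previous one, so
# later codebooks progressively refine the reconstruction.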
for layer in self.layers[:num_quantizers]:
indices = layer.encode(residual)
quantized = layer.decode(indices)
residual = residual - quantized
all_indices.append(indices)
out_indices = mx.stack(all_indices, axis=1)
return out_indices
def decode(self, codes: mx.array) -> mx.array:
"""Decode the given codes to the quantized representation."""
quantized_out = None
for i, indices in enumerate(codes.split(codes.shape[1], axis=1)):
layer = self.layers[i]
quantized = layer.decode(indices.squeeze(1))
if quantized_out is None:
quantized_out = quantized
else:
quantized_out = quantized + quantized_out
return quantized_out
class EncodecModel(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.encoder = EncodecEncoder(config)
self.decoder = EncodecDecoder(config)
self.quantizer = EncodecResidualVectorQuantizer(config)
def _encode_frame(
self, input_values: mx.array, bandwidth: float, padding_mask: mx.array
) -> Tuple[mx.array, Optional[mx.array]]:
"""
Encodes the given input using the underlying VQVAE.
"""
length = input_values.shape[1]
duration = length / self.config.sampling_rate
if (
self.config.chunk_length_s is not None
and duration > 1e-5 + self.config.chunk_length_s
):
raise RuntimeError(
f"Duration of frame ({duration}) is longer than chunk {self.config.chunk_length_s}"
)
scale = None
if self.config.normalize:
# Zero out the padded samples before computing the normalization scale
input_values = input_values * padding_mask[..., None]
mono = mx.sum(input_values, axis=2, keepdims=True) / input_values.shape[2]
scale = mono.square().mean(axis=1, keepdims=True).sqrt() + 1e-8
input_values = input_values / scale
embeddings = self.encoder(input_values)
codes = self.quantizer.encode(embeddings, bandwidth)
return codes, scale
def encode(
self,
input_values: mx.array,
padding_mask: mx.array = None,
bandwidth: Optional[float] = None,
) -> Tuple[mx.array, Optional[mx.array]]:
"""
Encodes the input audio waveform into discrete codes.
Args:
input_values (mx.array): The input audio waveform with shape
``(batch_size, channels, sequence_length)``.
padding_mask (mx.array): Padding mask used to pad the ``input_values``.
bandwidth (float, optional): The target bandwidth. Must be one of
``config.target_bandwidths``. If ``None``, the smallest possible
bandwidth is used. Bandwidth is expressed in kbps, e.g. a 6 kbps
target is passed as ``bandwidth == 6.0``.
Returns:
A list of frames containing the discrete encoded codes for the
input audio waveform, along with rescaling factors for each chunk
when ``config.normalize==True``. Each frame is a tuple ``(codebook,
scale)``, with ``codebook`` of shape ``(batch_size, num_codebooks,
frames)``.
"""
if bandwidth is None:
bandwidth = self.config.target_bandwidths[0]
if bandwidth not in self.config.target_bandwidths:
raise ValueError(
f"This model doesn't support the bandwidth {bandwidth}. "
f"Select one of {self.config.target_bandwidths}."
)
_, input_length, channels = input_values.shape
if channels < 1 or channels > 2:
raise ValueError(
f"Number of audio channels must be 1 or 2, but got {channels}"
)
chunk_length = self.chunk_length
if chunk_length is None:
chunk_length = input_length
stride = input_length
else:
stride = self.chunk_stride
if padding_mask is None:
padding_mask = mx.ones(input_values.shape[:2], dtype=mx.bool_)
encoded_frames = []
scales = []
step = chunk_length - stride
if (input_length % stride) != step:
raise ValueError(
"The input length is not properly padded for batched chunked "
"encoding. Make sure to pad the input correctly."
)
for offset in range(0, input_length - step, stride):
mask = padding_mask[:, offset : offset + chunk_length].astype(mx.bool_)
frame = input_values[:, offset : offset + chunk_length]
encoded_frame, scale = self._encode_frame(frame, bandwidth, mask)
encoded_frames.append(encoded_frame)
scales.append(scale)
encoded_frames = mx.stack(encoded_frames)
return (encoded_frames, scales)
@staticmethod
def _linear_overlap_add(frames: List[mx.array], stride: int):
if len(frames) == 0:
raise ValueError("`frames` cannot be an empty list.")
dtype = frames[0].dtype
N, frame_length, C = frames[0].shape
total_size = stride * (len(frames) - 1) + frames[-1].shape[1]
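# Triangular weights peaking at the middle of each frame give a linear
# cross-fade wherever consecutive frames overlap.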
time_vec = mx.linspace(0, 1, frame_length + 2, dtype=dtype)[1:-1]
weight = 0.5 - (time_vec - 0.5).abs()
weight = weight[:, None]
sum_weight = mx.zeros((total_size, 1), dtype=dtype)
out = mx.zeros((N, total_size, C), dtype=dtype)
offset = 0
for frame in frames:
frame_length = frame.shape[1]
out[:, offset : offset + frame_length] += weight[:frame_length] * frame
sum_weight[offset : offset + frame_length] += weight[:frame_length]
offset += stride
return out / sum_weight
def _decode_frame(
self, codes: mx.array, scale: Optional[mx.array] = None
) -> mx.array:
embeddings = self.quantizer.decode(codes)
outputs = self.decoder(embeddings)
if scale is not None:
outputs = outputs * scale
return outputs
@property
def channels(self):
return self.config.audio_channels
@property
def sampling_rate(self):
return self.config.sampling_rate
@property
def chunk_length(self):
if self.config.chunk_length_s is None:
return None
else:
return int(self.config.chunk_length_s * self.config.sampling_rate)
@property
def chunk_stride(self):
if self.config.chunk_length_s is None or self.config.overlap is None:
return None
else:
return max(1, int((1.0 - self.config.overlap) * self.chunk_length))
def decode(
self,
audio_codes: mx.array,
audio_scales: Union[mx.array, List[mx.array]],
padding_mask: Optional[mx.array] = None,
) -> Tuple[mx.array, mx.array]:
"""
Decodes the given frames into an output audio waveform.
Note that the output might be a bit bigger than the input. In that
case, any extra steps at the end should be trimmed.
Args:
audio_codes (mx.array): Discrete code embeddings of shape
``(batch_size, nb_chunks, chunk_length)``.
audio_scales (mx.array): Scaling factor for each input.
padding_mask (mx.array): Padding mask.
"""
chunk_length = self.chunk_length
if chunk_length is None:
if audio_codes.shape[1] != 1:
raise ValueError(f"Expected one frame, got {len(audio_codes)}")
audio_values = self._decode_frame(audio_codes[:, 0], audio_scales[0])
else:
decoded_frames = []
for frame, scale in zip(audio_codes, audio_scales):
frames = self._decode_frame(frame, scale)
decoded_frames.append(frames)
audio_values = self._linear_overlap_add(
decoded_frames, self.chunk_stride or 1
)
# truncate based on padding mask
if padding_mask is not None and padding_mask.shape[1] < audio_values.shape[1]:
audio_values = audio_values[:, : padding_mask.shape[1]]
return audio_values
@classmethod
def from_pretrained(cls, path_or_repo: str):
from huggingface_hub import snapshot_download
path = Path(path_or_repo)
if not path.exists():
path = Path(
snapshot_download(
repo_id=path_or_repo,
allow_patterns=["*.json", "*.safetensors", "*.model"],
)
)
with open(path / "config.json", "r") as f:
config = SimpleNamespace(**json.load(f))
model = EncodecModel(config)
model.load_weights(str(path / "model.safetensors"))
processor = functools.partial(
preprocess_audio,
sampling_rate=config.sampling_rate,
chunk_length=model.chunk_length,
chunk_stride=model.chunk_stride,
)
mx.eval(model)
return model, processor
def preprocess_audio(
raw_audio: Union[mx.array, List[mx.array]],
sampling_rate: int = 24000,
chunk_length: Optional[int] = None,
chunk_stride: Optional[int] = None,
):
r"""
Prepare inputs for the EnCodec model.
Args:
raw_audio (mx.array or List[mx.array]): The sequence or batch of
sequences to be processed.
sampling_rate (int): The sampling rate at which the audio waveform
should be digitized.
chunk_length (int, optional): The model's chunk length.
chunk_stride (int, optional): The model's chunk stride.
"""
if not isinstance(raw_audio, list):
raw_audio = [raw_audio]
raw_audio = [x[..., None] if x.ndim == 1 else x for x in raw_audio]
max_length = max(array.shape[0] for array in raw_audio)
if chunk_length is not None:
max_length += chunk_length - (max_length % chunk_stride)
inputs = []
masks = []
for x in raw_audio:
length = x.shape[0]
mask = mx.ones((length,), dtype=mx.bool_)
difference = max_length - length
if difference > 0:
mask = mx.pad(mask, (0, difference))
x = mx.pad(x, ((0, difference), (0, 0)))
inputs.append(x)
masks.append(mask)
return mx.stack(inputs), mx.stack(masks)

encodec/example.py
@@ -0,0 +1,39 @@
# Copyright © 2024 Apple Inc.
import mlx.core as mx
from utils import load_audio, save_audio
from encodec import EncodecModel
# Load the 48 kHz model and preprocessor.
model, processor = EncodecModel.from_pretrained("mlx-community/encodec-48khz-float32")
# Load an audio file
audio = load_audio("/path/to/audio", model.sampling_rate, model.channels)
# Preprocess the audio (this can also be a list of arrays for batched
# processing).
feats, mask = processor(audio)
# Encode at the given bandwidth. A lower bandwidth results in more
# compression but lower reconstruction quality.
@mx.compile
def encode(feats, mask):
return model.encode(feats, mask, bandwidth=3)
# Decode to reconstruct the audio
@mx.compile
def decode(codes, scales, mask):
return model.decode(codes, scales, mask)
codes, scales = encode(feats, mask)
reconstructed = decode(codes, scales, mask)
# Trim any padding:
reconstructed = reconstructed[0, : len(audio)]
# Save the audio as a wave file
save_audio("reconstructed.wav", reconstructed, model.sampling_rate)

encodec/requirements.txt
@@ -0,0 +1,3 @@
mlx>=0.18
numpy
huggingface_hub

encodec/test.py
@@ -0,0 +1,67 @@
# Copyright © 2024 Apple Inc.
import mlx.core as mx
import numpy as np
import torch
from transformers import AutoProcessor
from transformers import EncodecModel as PTEncodecModel
from encodec import EncodecModel, preprocess_audio
def compare_processors():
np.random.seed(0)
audio_length = 95500
audio = np.random.uniform(size=(2, audio_length)).astype(np.float32)
processor = AutoProcessor.from_pretrained("facebook/encodec_48khz")
pt_inputs = processor(
raw_audio=audio, sampling_rate=processor.sampling_rate, return_tensors="pt"
)
mx_inputs = preprocess_audio(
mx.array(audio).T,
processor.sampling_rate,
processor.chunk_length,
processor.chunk_stride,
)
assert np.array_equal(pt_inputs["input_values"], mx_inputs[0].moveaxis(2, 1))
assert np.array_equal(pt_inputs["padding_mask"], mx_inputs[1])
def compare_models():
pt_model = PTEncodecModel.from_pretrained("facebook/encodec_48khz")
mx_model, _ = EncodecModel.from_pretrained("mlx-community/encodec-48khz-float32")
np.random.seed(0)
audio_length = 190560
audio = np.random.uniform(size=(1, audio_length, 2)).astype(np.float32)
mask = np.ones((1, audio_length), dtype=np.int32)
pt_encoded = pt_model.encode(
torch.tensor(audio).moveaxis(2, 1), torch.tensor(mask)[None]
)
mx_encoded = mx_model.encode(mx.array(audio), mx.array(mask))
pt_codes = pt_encoded.audio_codes.numpy()
mx_codes = mx_encoded[0]
assert np.array_equal(pt_codes, mx_codes), "Encoding codes mismatch"
for mx_scale, pt_scale in zip(mx_encoded[1], pt_encoded.audio_scales):
if mx_scale is not None:
pt_scale = pt_scale.numpy()
assert np.allclose(pt_scale, mx_scale, atol=1e-3, rtol=1e-4)
pt_audio = pt_model.decode(
pt_encoded.audio_codes, pt_encoded.audio_scales, torch.tensor(mask)[None]
)
pt_audio = pt_audio[0].squeeze().T.detach().numpy()
mx_audio = mx_model.decode(*mx_encoded, mx.array(mask))
mx_audio = mx_audio.squeeze()
assert np.allclose(
pt_audio, mx_audio, atol=1e-4, rtol=1e-4
), "Decoding audio mismatch"
if __name__ == "__main__":
compare_processors()
compare_models()

encodec/utils.py
@@ -0,0 +1,52 @@
# Copyright © 2024 Apple Inc.
import mlx.core as mx
import numpy as np
def save_audio(file: str, audio: mx.array, sampling_rate: int):
"""
Save audio to a wave (.wav) file.
"""
from scipy.io.wavfile import write
audio = (audio * 32767).astype(mx.int16)
write(file, sampling_rate, np.array(audio))
def load_audio(file: str, sampling_rate: int, channels: int):
"""
Read audio into an mx.array, resampling if necessary.
Args:
file (str): The audio file to open.
sampling_rate (int): The sample rate to resample the audio at if needed.
channels (int): The number of audio channels.
Returns:
An mx.array containing the audio waveform in float32.
"""
from subprocess import CalledProcessError, run
# This launches a subprocess to decode audio while down-mixing
# and resampling as necessary. Requires the ffmpeg CLI in PATH.
# fmt: off
cmd = [
"ffmpeg",
"-nostdin",
"-threads", "0",
"-i", file,
"-f", "s16le",
"-ac", str(channels),
"-acodec", "pcm_s16le",
"-ar", str(sampling_rate),
"-"
]
# fmt: on
try:
out = run(cmd, capture_output=True, check=True).stdout
except CalledProcessError as e:
raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
out = mx.array(np.frombuffer(out, np.int16))
return out.reshape(-1, channels).astype(mx.float32) / 32767.0

flux/README.md
@@ -0,0 +1,212 @@
FLUX
====
FLUX implementation in MLX. The implementation is ported directly from
[https://github.com/black-forest-labs/flux](https://github.com/black-forest-labs/flux)
and the model weights are downloaded directly from the Hugging Face Hub.
The goal of this example is to be clean, educational and to allow for
experimentation with finetuning FLUX models as well as adding extra
functionality such as in-/outpainting, guidance with custom losses etc.
![MLX image](static/generated-mlx.png)
*Image generated using FLUX-dev in MLX and the prompt 'An image in the style of
tron emanating futuristic technology with the word "MLX" in the center with
capital red letters.'*
Installation
------------
The dependencies are minimal, namely:
- `huggingface-hub` to download the checkpoints.
- `regex` for the tokenization
- `tqdm`, `PIL`, and `numpy` for the scripts
- `sentencepiece` for the T5 tokenizer
- `datasets` for using an HF dataset directly
You can install all of the above from the provided `requirements.txt` as follows:
pip install -r requirements.txt
Usage
---------
You can use the following command to generate an image. Use `--output` to choose where the image is saved (it defaults to `out.png`).
```shell
python txt2image.py --model schnell \
--n-images 1 \
--image-size 256x512 \
--verbose \
'A photo of an astronaut riding a horse on Mars.'
```
For the full list of parameters, run the script with `--help`:
```shell
python txt2image.py --help
```
Inference
---------
Inference in this example is similar to the stable diffusion example. The main
class to get you started is `FluxPipeline` from the `flux` module.
```python
import mlx.core as mx
from flux import FluxPipeline
# This will download all the weights from HF hub
flux = FluxPipeline("flux-schnell")
# Make a generator that returns the latent variables from the reverse diffusion
# process
latent_generator = flux.generate_latents(
"A photo of an astronaut riding a horse on Mars",
num_steps=4,
latent_size=(32, 64), # 256x512 image
)
# The first return value of the generator contains the conditioning and the
# random noise at the beginning of the diffusion process.
conditioning = next(latent_generator)
(
x_T, # The initial noise
x_positions, # The integer positions used for image positional encoding
t5_conditioning, # The T5 features from the text prompt
t5_positions, # Integer positions for text (normally all 0s)
clip_conditioning, # The clip text features from the text prompt
) = conditioning
# Returning the conditioning as the first output from the generator allows us
# to unload T5 and clip before running the diffusion transformer.
mx.eval(conditioning)
# Evaluate each diffusion step
for x_t in latent_generator:
mx.eval(x_t)
# Note that we need to pass the latent size because it is collapsed and
# patchified in x_t and we need to unwrap it.
img = flux.decode(x_t, latent_size=(32, 64))
```
The above is essentially what the `txt2image.py` script implements, apart from
some additional logic to quantize the model and/or load trained adapters. One
can use the script as follows:
```shell
python txt2image.py \
--n-images 4 \
--n-rows 2 \
--image-size 256x512 \
'A photo of an astronaut riding a horse on Mars.'
```
### Experimental Options
FLUX pads the prompt to a specific size of 512 tokens for the dev model and
256 for the schnell model. Skipping the padding results in faster generation,
but it is not clear how it affects the generated images. To disable the padding
in this example, pass `--no-t5-padding` to the `txt2image.py` script or
instantiate the pipeline with `FluxPipeline("flux-schnell", t5_padding=False)`.
Finetuning
----------
The `dreambooth.py` script supports LoRA finetuning of FLUX-dev (and schnell,
though your mileage may vary) on a provided image dataset. The dataset folder
must contain a `train.jsonl` file with the following format; a small helper
script for generating one is sketched after the example:
```jsonl
{"image": "path-to-image-relative-to-dataset", "prompt": "Prompt to use with this image"}
{"image": "path-to-image-relative-to-dataset", "prompt": "Prompt to use with this image"}
...
```
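If all of your images share the same prompt, a few lines of Python are enough
to generate this file. The sketch below is illustrative only; the folder path,
glob pattern, and prompt are placeholders to adapt to your own dataset.

```python
import json
from pathlib import Path


def make_index(folder: str, prompt: str):
    """Write a train.jsonl listing every .jpg in `folder` with a single prompt."""
    folder = Path(folder)
    with open(folder / "train.jsonl", "w") as f:
        for image in sorted(folder.glob("*.jpg")):
            f.write(json.dumps({"image": image.name, "prompt": prompt}) + "\n")


make_index("path/to/dreambooth/dataset/dog6", "A photo of sks dog")
```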
The training script by default trains for 600 iterations with a batch size of
1, gradient accumulation of 4 and LoRA rank of 8. Run `python dreambooth.py
--help` for the list of hyperparameters you can tune.
> [!Note]
> FLUX finetuning requires approximately 50GB of RAM. QLoRA is coming soon and
> should reduce this number significantly.
### Training Example
This is a step-by-step finetuning example. We will be using the data from
[https://github.com/google/dreambooth](https://github.com/google/dreambooth).
In particular, we will use `dog6`, which is a popular example for showcasing
dreambooth [^1].
The training images are the following 5 images [^2]:
![dog6](static/dog6.png)
We start by making the following `train.jsonl` file and placing it in the same
folder as the images.
```jsonl
{"image": "00.jpg", "prompt": "A photo of sks dog"}
{"image": "01.jpg", "prompt": "A photo of sks dog"}
{"image": "02.jpg", "prompt": "A photo of sks dog"}
{"image": "03.jpg", "prompt": "A photo of sks dog"}
{"image": "04.jpg", "prompt": "A photo of sks dog"}
```
Subsequently we finetune FLUX using the following command:
```shell
python dreambooth.py \
--progress-prompt 'A photo of an sks dog lying on the sand at a beach in Greece' \
--progress-every 600 --iterations 1200 --learning-rate 0.0001 \
--lora-rank 4 --grad-accumulate 8 \
path/to/dreambooth/dataset/dog6
```
Or you can directly use the pre-processed Hugging Face dataset [mlx-community/dreambooth-dog6](https://huggingface.co/datasets/mlx-community/dreambooth-dog6) for fine-tuning.
```shell
python dreambooth.py \
--progress-prompt 'A photo of an sks dog lying on the sand at a beach in Greece' \
--progress-every 600 --iterations 1200 --learning-rate 0.0001 \
--lora-rank 4 --grad-accumulate 8 \
mlx-community/dreambooth-dog6
```
The training requires approximately 50GB of RAM and takes a bit more than an
hour on an M2 Ultra.
### Using the Adapter
The adapters are saved in `mlx_output` and can be used directly by the
`txt2image.py` script. For instance,
```shell
python txt2image.py --model dev --save-raw --image-size 512x512 --n-images 1 \
--adapter mlx_output/final_adapters.safetensors \
--fuse-adapter \
--no-t5-padding \
'A photo of an sks dog lying on the sand at a beach in Greece'
```
generates an image that looks like the following,
![dog image](static/dog-r4-g8-1200.png)
and of course we can pass `--image-size 512x1024` to get larger images with
different aspect ratios,
![wide dog image](static/dog-r4-g8-1200-512x1024.png)
The arguments that are relevant to the adapters are of course `--adapter` and
`--fuse-adapter`. The first defines the path to an adapter to apply to the
model and the second fuses the adapter back into the model to get a bit more
speed during generation.
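For reference, the sketch below shows roughly what loading and fusing an
adapter looks like in code. It is a minimal sketch rather than the exact logic
of `txt2image.py`, and it assumes the adapter was saved by `dreambooth.py` (so
its safetensors metadata contains `lora_rank` and `lora_blocks`) and a recent
MLX with `return_metadata` support for safetensors.

```python
import mlx.core as mx

from flux import FluxPipeline

flux = FluxPipeline("flux-dev")

# Load the adapter weights along with the metadata written by dreambooth.py.
weights, metadata = mx.load(
    "mlx_output/final_adapters.safetensors", return_metadata=True
)

# Recreate the LoRA layers with the same rank and block count, then load.
flux.linear_to_lora_layers(int(metadata["lora_rank"]), int(metadata["lora_blocks"]))
flux.flow.load_weights(list(weights.items()), strict=False)

# Optionally fold the adapters back into the base weights for faster sampling.
flux.fuse_lora_layers()
```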
[^1]: Refer to the [arXiv paper](https://arxiv.org/abs/2208.12242) for more details.
[^2]: The images are from Unsplash, by https://unsplash.com/@alvannee.

flux/dreambooth.py
@@ -0,0 +1,292 @@
# Copyright © 2024 Apple Inc.
import argparse
import time
from functools import partial
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
from mlx.nn.utils import average_gradients
from mlx.utils import tree_flatten, tree_map, tree_reduce
from PIL import Image
from flux import FluxPipeline, Trainer, load_dataset, save_config
def generate_progress_images(iteration, flux, args):
"""Generate images to monitor the progress of the finetuning."""
out_dir = Path(args.output_dir)
out_dir.mkdir(parents=True, exist_ok=True)
out_file = out_dir / f"{iteration:07d}_progress.png"
print(f"Generating {str(out_file)}", flush=True)
# Generate some images and arrange them in a grid
n_rows = 2
n_images = 4
x = flux.generate_images(
args.progress_prompt,
n_images,
args.progress_steps,
)
x = mx.pad(x, [(0, 0), (4, 4), (4, 4), (0, 0)])
B, H, W, C = x.shape
x = x.reshape(n_rows, B // n_rows, H, W, C).transpose(0, 2, 1, 3, 4)
x = x.reshape(n_rows * H, B // n_rows * W, C)
x = mx.pad(x, [(4, 4), (4, 4), (0, 0)])
x = (x * 255).astype(mx.uint8)
# Save them to disk
im = Image.fromarray(np.array(x))
im.save(out_file)
def save_adapters(adapter_name, flux, args):
out_dir = Path(args.output_dir)
out_dir.mkdir(parents=True, exist_ok=True)
out_file = out_dir / adapter_name
print(f"Saving {str(out_file)}")
mx.save_safetensors(
str(out_file),
dict(tree_flatten(flux.flow.trainable_parameters())),
metadata={
"lora_rank": str(args.lora_rank),
"lora_blocks": str(args.lora_blocks),
},
)
def setup_arg_parser():
"""Set up and return the argument parser."""
parser = argparse.ArgumentParser(
description="Finetune Flux to generate images with a specific subject"
)
parser.add_argument(
"--model",
default="dev",
choices=[
"dev",
"schnell",
],
help="Which flux model to train",
)
parser.add_argument(
"--guidance", type=float, default=4.0, help="The guidance factor to use."
)
parser.add_argument(
"--iterations",
type=int,
default=600,
help="How many iterations to train for",
)
parser.add_argument(
"--batch-size",
type=int,
default=1,
help="The batch size to use when training the stable diffusion model",
)
parser.add_argument(
"--resolution",
type=lambda x: tuple(map(int, x.split("x"))),
default=(512, 512),
help="The resolution of the training images",
)
parser.add_argument(
"--num-augmentations",
type=int,
default=5,
help="Augment the images by random cropping and panning",
)
parser.add_argument(
"--progress-prompt",
required=True,
help="Use this prompt when generating images for evaluation",
)
parser.add_argument(
"--progress-steps",
type=int,
default=50,
help="Use this many steps when generating images for evaluation",
)
parser.add_argument(
"--progress-every",
type=int,
default=50,
help="Generate images every PROGRESS_EVERY steps",
)
parser.add_argument(
"--checkpoint-every",
type=int,
default=50,
help="Save the model every CHECKPOINT_EVERY steps",
)
parser.add_argument(
"--lora-blocks",
type=int,
default=-1,
help="Train the last LORA_BLOCKS transformer blocks",
)
parser.add_argument(
"--lora-rank", type=int, default=8, help="LoRA rank for finetuning"
)
parser.add_argument(
"--warmup-steps", type=int, default=100, help="Learning rate warmup"
)
parser.add_argument(
"--learning-rate", type=float, default="1e-4", help="Learning rate for training"
)
parser.add_argument(
"--grad-accumulate",
type=int,
default=4,
help="Accumulate gradients for that many iterations before applying them",
)
parser.add_argument(
"--output-dir", default="mlx_output", help="Folder to save the checkpoints in"
)
parser.add_argument("dataset")
return parser
if __name__ == "__main__":
parser = setup_arg_parser()
args = parser.parse_args()
output_path = Path(args.output_dir)
output_path.mkdir(parents=True, exist_ok=True)
save_config(vars(args), output_path / "adapter_config.json")
# Load the model and set it up for LoRA training. We use the same random
# state when creating the LoRA layers so all workers will have the same
# initial weights.
mx.random.seed(0x0F0F0F0F)
flux = FluxPipeline("flux-" + args.model)
flux.flow.freeze()
flux.linear_to_lora_layers(args.lora_rank, args.lora_blocks)
# Reset the seed to a different seed per worker if we are in distributed
# mode so that each worker is working on different data, diffusion step and
# random noise.
mx.random.seed(0xF0F0F0F0 + mx.distributed.init().rank())
# Report how many parameters we are training
trainable_params = tree_reduce(
lambda acc, x: acc + x.size, flux.flow.trainable_parameters(), 0
)
print(f"Training {trainable_params / 1024 ** 2:.3f}M parameters", flush=True)
# Set up the optimizer and training steps. The steps are a bit verbose to
# support gradient accumulation together with compilation.
warmup = optim.linear_schedule(0, args.learning_rate, args.warmup_steps)
cosine = optim.cosine_decay(
args.learning_rate, args.iterations // args.grad_accumulate
)
lr_schedule = optim.join_schedules([warmup, cosine], [args.warmup_steps])
optimizer = optim.Adam(learning_rate=lr_schedule)
state = [flux.flow.state, optimizer.state, mx.random.state]
@partial(mx.compile, inputs=state, outputs=state)
def single_step(x, t5_feat, clip_feat, guidance):
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
x, t5_feat, clip_feat, guidance
)
grads = average_gradients(grads)
optimizer.update(flux.flow, grads)
return loss
@partial(mx.compile, inputs=state, outputs=state)
def compute_loss_and_grads(x, t5_feat, clip_feat, guidance):
return nn.value_and_grad(flux.flow, flux.training_loss)(
x, t5_feat, clip_feat, guidance
)
@partial(mx.compile, inputs=state, outputs=state)
def compute_loss_and_accumulate_grads(x, t5_feat, clip_feat, guidance, prev_grads):
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
x, t5_feat, clip_feat, guidance
)
grads = tree_map(lambda a, b: a + b, prev_grads, grads)
return loss, grads
@partial(mx.compile, inputs=state, outputs=state)
def grad_accumulate_and_step(x, t5_feat, clip_feat, guidance, prev_grads):
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
x, t5_feat, clip_feat, guidance
)
grads = tree_map(
lambda a, b: (a + b) / args.grad_accumulate,
prev_grads,
grads,
)
grads = average_gradients(grads)
optimizer.update(flux.flow, grads)
return loss
# We simply route to the appropriate step based on whether we have
# gradients from a previous step and whether we should be performing an
# update or simply computing and accumulating gradients in this step.
def step(x, t5_feat, clip_feat, guidance, prev_grads, perform_step):
if prev_grads is None:
if perform_step:
return single_step(x, t5_feat, clip_feat, guidance), None
else:
return compute_loss_and_grads(x, t5_feat, clip_feat, guidance)
else:
if perform_step:
return (
grad_accumulate_and_step(
x, t5_feat, clip_feat, guidance, prev_grads
),
None,
)
else:
return compute_loss_and_accumulate_grads(
x, t5_feat, clip_feat, guidance, prev_grads
)
dataset = load_dataset(args.dataset)
trainer = Trainer(flux, dataset, args)
trainer.encode_dataset()
guidance = mx.full((args.batch_size,), args.guidance, dtype=flux.dtype)
# An initial generation to compare
generate_progress_images(0, flux, args)
grads = None
losses = []
tic = time.time()
for i, batch in zip(range(args.iterations), trainer.iterate(args.batch_size)):
loss, grads = step(*batch, guidance, grads, (i + 1) % args.grad_accumulate == 0)
mx.eval(loss, grads, state)
losses.append(loss.item())
if (i + 1) % 10 == 0:
toc = time.time()
peak_mem = mx.metal.get_peak_memory() / 1024**3
print(
f"Iter: {i + 1} Loss: {sum(losses) / 10:.3f} "
f"It/s: {10 / (toc - tic):.3f} "
f"Peak mem: {peak_mem:.3f} GB",
flush=True,
)
if (i + 1) % args.progress_every == 0:
generate_progress_images(i + 1, flux, args)
if (i + 1) % args.checkpoint_every == 0:
save_adapters(f"{i + 1:07d}_adapters.safetensors", flux, args)
if (i + 1) % 10 == 0:
losses = []
tic = time.time()
save_adapters("final_adapters.safetensors", flux, args)
print("Training successful.")

flux/flux/__init__.py
@@ -0,0 +1,16 @@
# Copyright © 2024 Apple Inc.
from .datasets import Dataset, load_dataset
from .flux import FluxPipeline
from .lora import LoRALinear
from .sampler import FluxSampler
from .trainer import Trainer
from .utils import (
load_ae,
load_clip,
load_clip_tokenizer,
load_flow_model,
load_t5,
load_t5_tokenizer,
save_config,
)

flux/flux/autoencoder.py
@@ -0,0 +1,357 @@
# Copyright © 2024 Apple Inc.
from dataclasses import dataclass
from typing import List
import mlx.core as mx
import mlx.nn as nn
from mlx.nn.layers.upsample import upsample_nearest
@dataclass
class AutoEncoderParams:
resolution: int
in_channels: int
ch: int
out_ch: int
ch_mult: List[int]
num_res_blocks: int
z_channels: int
scale_factor: float
shift_factor: float
class AttnBlock(nn.Module):
def __init__(self, in_channels: int):
super().__init__()
self.in_channels = in_channels
self.norm = nn.GroupNorm(
num_groups=32,
dims=in_channels,
eps=1e-6,
affine=True,
pytorch_compatible=True,
)
self.q = nn.Linear(in_channels, in_channels)
self.k = nn.Linear(in_channels, in_channels)
self.v = nn.Linear(in_channels, in_channels)
self.proj_out = nn.Linear(in_channels, in_channels)
def __call__(self, x: mx.array) -> mx.array:
B, H, W, C = x.shape
y = x.reshape(B, 1, -1, C)
y = self.norm(y)
q = self.q(y)
k = self.k(y)
v = self.v(y)
y = mx.fast.scaled_dot_product_attention(q, k, v, scale=C ** (-0.5))
y = self.proj_out(y)
return x + y.reshape(B, H, W, C)
class ResnetBlock(nn.Module):
def __init__(self, in_channels: int, out_channels: int):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.norm1 = nn.GroupNorm(
num_groups=32,
dims=in_channels,
eps=1e-6,
affine=True,
pytorch_compatible=True,
)
self.conv1 = nn.Conv2d(
in_channels, out_channels, kernel_size=3, stride=1, padding=1
)
self.norm2 = nn.GroupNorm(
num_groups=32,
dims=out_channels,
eps=1e-6,
affine=True,
pytorch_compatible=True,
)
self.conv2 = nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=1, padding=1
)
if self.in_channels != self.out_channels:
self.nin_shortcut = nn.Linear(in_channels, out_channels)
def __call__(self, x):
h = x
h = self.norm1(h)
h = nn.silu(h)
h = self.conv1(h)
h = self.norm2(h)
h = nn.silu(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
x = self.nin_shortcut(x)
return x + h
class Downsample(nn.Module):
def __init__(self, in_channels: int):
super().__init__()
self.conv = nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=2, padding=0
)
def __call__(self, x: mx.array):
x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)])
x = self.conv(x)
return x
class Upsample(nn.Module):
def __init__(self, in_channels: int):
super().__init__()
self.conv = nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=1, padding=1
)
def __call__(self, x: mx.array):
x = upsample_nearest(x, (2, 2))
x = self.conv(x)
return x
class Encoder(nn.Module):
def __init__(
self,
resolution: int,
in_channels: int,
ch: int,
ch_mult: list[int],
num_res_blocks: int,
z_channels: int,
):
super().__init__()
self.ch = ch
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
# downsampling
self.conv_in = nn.Conv2d(
in_channels, self.ch, kernel_size=3, stride=1, padding=1
)
curr_res = resolution
in_ch_mult = (1,) + tuple(ch_mult)
self.in_ch_mult = in_ch_mult
self.down = []
block_in = self.ch
for i_level in range(self.num_resolutions):
block = []
attn = [] # TODO: Remove the attn, nobody appends anything to it
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for _ in range(self.num_res_blocks):
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
block_in = block_out
down = {}
down["block"] = block
down["attn"] = attn
if i_level != self.num_resolutions - 1:
down["downsample"] = Downsample(block_in)
curr_res = curr_res // 2
self.down.append(down)
# middle
self.mid = {}
self.mid["block_1"] = ResnetBlock(in_channels=block_in, out_channels=block_in)
self.mid["attn_1"] = AttnBlock(block_in)
self.mid["block_2"] = ResnetBlock(in_channels=block_in, out_channels=block_in)
# end
self.norm_out = nn.GroupNorm(
num_groups=32, dims=block_in, eps=1e-6, affine=True, pytorch_compatible=True
)
self.conv_out = nn.Conv2d(
block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1
)
def __call__(self, x: mx.array):
hs = [self.conv_in(x)]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level]["block"][i_block](hs[-1])
# TODO: Remove the attn
if len(self.down[i_level]["attn"]) > 0:
h = self.down[i_level]["attn"][i_block](h)
hs.append(h)
if i_level != self.num_resolutions - 1:
hs.append(self.down[i_level]["downsample"](hs[-1]))
# middle
h = hs[-1]
h = self.mid["block_1"](h)
h = self.mid["attn_1"](h)
h = self.mid["block_2"](h)
# end
h = self.norm_out(h)
h = nn.silu(h)
h = self.conv_out(h)
return h
class Decoder(nn.Module):
def __init__(
self,
ch: int,
out_ch: int,
ch_mult: list[int],
num_res_blocks: int,
in_channels: int,
resolution: int,
z_channels: int,
):
super().__init__()
self.ch = ch
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.ffactor = 2 ** (self.num_resolutions - 1)
# compute in_ch_mult, block_in and curr_res at lowest res
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
# z to block_in
self.conv_in = nn.Conv2d(
z_channels, block_in, kernel_size=3, stride=1, padding=1
)
# middle
self.mid = {}
self.mid["block_1"] = ResnetBlock(in_channels=block_in, out_channels=block_in)
self.mid["attn_1"] = AttnBlock(block_in)
self.mid["block_2"] = ResnetBlock(in_channels=block_in, out_channels=block_in)
# upsampling
self.up = []
for i_level in reversed(range(self.num_resolutions)):
block = []
attn = [] # TODO: Remove the attn, nobody appends anything to it
block_out = ch * ch_mult[i_level]
for _ in range(self.num_res_blocks + 1):
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
block_in = block_out
up = {}
up["block"] = block
up["attn"] = attn
if i_level != 0:
up["upsample"] = Upsample(block_in)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = nn.GroupNorm(
num_groups=32, dims=block_in, eps=1e-6, affine=True, pytorch_compatible=True
)
self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
def __call__(self, z: mx.array):
# z to block_in
h = self.conv_in(z)
# middle
h = self.mid["block_1"](h)
h = self.mid["attn_1"](h)
h = self.mid["block_2"](h)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level]["block"][i_block](h)
# TODO: Remove the attn
if len(self.up[i_level]["attn"]) > 0:
h = self.up[i_level]["attn"][i_block](h)
if i_level != 0:
h = self.up[i_level]["upsample"](h)
# end
h = self.norm_out(h)
h = nn.silu(h)
h = self.conv_out(h)
return h
class DiagonalGaussian(nn.Module):
def __call__(self, z: mx.array):
mean, logvar = mx.split(z, 2, axis=-1)
if self.training:
std = mx.exp(0.5 * logvar)
eps = mx.random.normal(shape=z.shape, dtype=z.dtype)
return mean + std * eps
else:
return mean
class AutoEncoder(nn.Module):
def __init__(self, params: AutoEncoderParams):
super().__init__()
self.encoder = Encoder(
resolution=params.resolution,
in_channels=params.in_channels,
ch=params.ch,
ch_mult=params.ch_mult,
num_res_blocks=params.num_res_blocks,
z_channels=params.z_channels,
)
self.decoder = Decoder(
resolution=params.resolution,
in_channels=params.in_channels,
ch=params.ch,
out_ch=params.out_ch,
ch_mult=params.ch_mult,
num_res_blocks=params.num_res_blocks,
z_channels=params.z_channels,
)
self.reg = DiagonalGaussian()
self.scale_factor = params.scale_factor
self.shift_factor = params.shift_factor
def sanitize(self, weights):
new_weights = {}
for k, w in weights.items():
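# Convolution weights arrive in PyTorch's (O, C, H, W) layout; transpose
# them to MLX's (O, H, W, C) and materialize the result. 1x1 kernels are
# squeezed so they load into the nn.Linear attention/shortcut layers.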
if w.ndim == 4:
w = w.transpose(0, 2, 3, 1)
w = w.reshape(-1).reshape(w.shape)
if w.shape[1:3] == (1, 1):
w = w.squeeze((1, 2))
new_weights[k] = w
return new_weights
def encode(self, x: mx.array):
z = self.reg(self.encoder(x))
z = self.scale_factor * (z - self.shift_factor)
return z
def decode(self, z: mx.array):
z = z / self.scale_factor + self.shift_factor
return self.decoder(z)
def __call__(self, x: mx.array):
return self.decode(self.encode(x))

flux/flux/clip.py
@@ -0,0 +1,154 @@
# Copyright © 2024 Apple Inc.
from dataclasses import dataclass
from typing import List, Optional
import mlx.core as mx
import mlx.nn as nn
_ACTIVATIONS = {"quick_gelu": nn.gelu_fast_approx, "gelu": nn.gelu}
@dataclass
class CLIPTextModelConfig:
num_layers: int = 23
model_dims: int = 1024
num_heads: int = 16
max_length: int = 77
vocab_size: int = 49408
hidden_act: str = "quick_gelu"
@classmethod
def from_dict(cls, config):
return cls(
num_layers=config["num_hidden_layers"],
model_dims=config["hidden_size"],
num_heads=config["num_attention_heads"],
max_length=config["max_position_embeddings"],
vocab_size=config["vocab_size"],
hidden_act=config["hidden_act"],
)
@dataclass
class CLIPOutput:
# The last_hidden_state indexed at the EOS token and possibly projected if
# the model has a projection layer
pooled_output: Optional[mx.array] = None
# The full sequence output of the transformer after the final layernorm
last_hidden_state: Optional[mx.array] = None
# A list of hidden states corresponding to the outputs of the transformer layers
hidden_states: Optional[List[mx.array]] = None
class CLIPEncoderLayer(nn.Module):
"""The transformer encoder layer from CLIP."""
def __init__(self, model_dims: int, num_heads: int, activation: str):
super().__init__()
self.layer_norm1 = nn.LayerNorm(model_dims)
self.layer_norm2 = nn.LayerNorm(model_dims)
self.attention = nn.MultiHeadAttention(model_dims, num_heads, bias=True)
self.linear1 = nn.Linear(model_dims, 4 * model_dims)
self.linear2 = nn.Linear(4 * model_dims, model_dims)
self.act = _ACTIVATIONS[activation]
def __call__(self, x, attn_mask=None):
y = self.layer_norm1(x)
y = self.attention(y, y, y, attn_mask)
x = y + x
y = self.layer_norm2(x)
y = self.linear1(y)
y = self.act(y)
y = self.linear2(y)
x = y + x
return x
class CLIPTextModel(nn.Module):
"""Implements the text encoder transformer from CLIP."""
def __init__(self, config: CLIPTextModelConfig):
super().__init__()
self.token_embedding = nn.Embedding(config.vocab_size, config.model_dims)
self.position_embedding = nn.Embedding(config.max_length, config.model_dims)
self.layers = [
CLIPEncoderLayer(config.model_dims, config.num_heads, config.hidden_act)
for i in range(config.num_layers)
]
self.final_layer_norm = nn.LayerNorm(config.model_dims)
def _get_mask(self, N, dtype):
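# Additive causal mask: token i may only attend to positions <= i; a
# large negative value (finite for fp16) stands in for -inf.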
indices = mx.arange(N)
mask = indices[:, None] < indices[None]
mask = mask.astype(dtype) * (-6e4 if dtype == mx.float16 else -1e9)
return mask
def sanitize(self, weights):
new_weights = {}
for key, w in weights.items():
# Remove prefixes
if key.startswith("text_model."):
key = key[11:]
if key.startswith("embeddings."):
key = key[11:]
if key.startswith("encoder."):
key = key[8:]
# Map attention layers
if "self_attn." in key:
key = key.replace("self_attn.", "attention.")
if "q_proj." in key:
key = key.replace("q_proj.", "query_proj.")
if "k_proj." in key:
key = key.replace("k_proj.", "key_proj.")
if "v_proj." in key:
key = key.replace("v_proj.", "value_proj.")
# Map ffn layers
if "mlp.fc1" in key:
key = key.replace("mlp.fc1", "linear1")
if "mlp.fc2" in key:
key = key.replace("mlp.fc2", "linear2")
new_weights[key] = w
return new_weights
def __call__(self, x):
# Extract some shapes
B, N = x.shape
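# CLIP's EOS token has the largest id in the vocabulary, so an argmax
# over the token ids gives the position of the EOS token per sequence.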
eos_tokens = x.argmax(-1)
# Compute the embeddings
x = self.token_embedding(x)
x = x + self.position_embedding.weight[:N]
# Compute the features from the transformer
mask = self._get_mask(N, x.dtype)
hidden_states = []
for l in self.layers:
x = l(x, mask)
hidden_states.append(x)
# Apply the final layernorm and return
x = self.final_layer_norm(x)
last_hidden_state = x
# Select the EOS token
pooled_output = x[mx.arange(len(x)), eos_tokens]
return CLIPOutput(
pooled_output=pooled_output,
last_hidden_state=last_hidden_state,
hidden_states=hidden_states,
)

flux/flux/datasets.py
@@ -0,0 +1,75 @@
import json
from pathlib import Path
from PIL import Image
class Dataset:
def __getitem__(self, index: int):
raise NotImplementedError()
def __len__(self):
raise NotImplementedError()
class LocalDataset(Dataset):
prompt_key = "prompt"
def __init__(self, dataset: str, data_file):
self.dataset_base = Path(dataset)
with open(data_file, "r") as fid:
self._data = [json.loads(l) for l in fid]
def __len__(self):
return len(self._data)
def __getitem__(self, index: int):
item = self._data[index]
image = Image.open(self.dataset_base / item["image"])
return image, item[self.prompt_key]
class LegacyDataset(LocalDataset):
prompt_key = "text"
def __init__(self, dataset: str):
self.dataset_base = Path(dataset)
with open(self.dataset_base / "index.json") as f:
self._data = json.load(f)["data"]
class HuggingFaceDataset(Dataset):
def __init__(self, dataset: str):
from datasets import load_dataset as hf_load_dataset
self._df = hf_load_dataset(dataset)["train"]
def __len__(self):
return len(self._df)
def __getitem__(self, index: int):
item = self._df[index]
return item["image"], item["prompt"]
def load_dataset(dataset: str):
dataset_base = Path(dataset)
data_file = dataset_base / "train.jsonl"
legacy_file = dataset_base / "index.json"
if data_file.exists():
print(f"Load the local dataset {data_file} .", flush=True)
dataset = LocalDataset(dataset, data_file)
elif legacy_file.exists():
print(f"Load the local dataset {legacy_file} .")
print()
print(" WARNING: 'index.json' is deprecated in favor of 'train.jsonl'.")
print(" See the README for details.")
print(flush=True)
dataset = LegacyDataset(dataset)
else:
print(f"Load the Hugging Face dataset {dataset} .", flush=True)
dataset = HuggingFaceDataset(dataset)
return dataset

flux/flux/flux.py
@@ -0,0 +1,246 @@
# Copyright © 2024 Apple Inc.
from typing import Tuple
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_unflatten
from tqdm import tqdm
from .lora import LoRALinear
from .sampler import FluxSampler
from .utils import (
load_ae,
load_clip,
load_clip_tokenizer,
load_flow_model,
load_t5,
load_t5_tokenizer,
)
class FluxPipeline:
def __init__(self, name: str, t5_padding: bool = True):
self.dtype = mx.bfloat16
self.name = name
self.t5_padding = t5_padding
self.ae = load_ae(name)
self.flow = load_flow_model(name)
self.clip = load_clip(name)
self.clip_tokenizer = load_clip_tokenizer(name)
self.t5 = load_t5(name)
self.t5_tokenizer = load_t5_tokenizer(name)
self.sampler = FluxSampler(name)
def ensure_models_are_loaded(self):
mx.eval(
self.ae.parameters(),
self.flow.parameters(),
self.clip.parameters(),
self.t5.parameters(),
)
def reload_text_encoders(self):
self.t5 = load_t5(self.name)
self.clip = load_clip(self.name)
def tokenize(self, text):
t5_tokens = self.t5_tokenizer.encode(text, pad=self.t5_padding)
clip_tokens = self.clip_tokenizer.encode(text)
return t5_tokens, clip_tokens
def _prepare_latent_images(self, x):
b, h, w, c = x.shape
# Pack the latent image to 2x2 patches
x = x.reshape(b, h // 2, 2, w // 2, 2, c)
x = x.transpose(0, 1, 3, 5, 2, 4).reshape(b, h * w // 4, c * 4)
# Create positions ids used to positionally encode each patch. Due to
# the way RoPE works, this results in an interesting positional
# encoding where parts of the feature are holding different positional
# information. Namely, the first part holds information independent of
# the spatial position (hence 0s), the 2nd part holds vertical spatial
# information and the last one horizontal.
i = mx.zeros((h // 2, w // 2), dtype=mx.int32)
j, k = mx.meshgrid(mx.arange(h // 2), mx.arange(w // 2), indexing="ij")
x_ids = mx.stack([i, j, k], axis=-1)
x_ids = mx.repeat(x_ids.reshape(1, h * w // 4, 3), b, 0)
return x, x_ids
def _prepare_conditioning(self, n_images, t5_tokens, clip_tokens):
# Prepare the text features
txt = self.t5(t5_tokens)
if len(txt) == 1 and n_images > 1:
txt = mx.broadcast_to(txt, (n_images, *txt.shape[1:]))
txt_ids = mx.zeros((n_images, txt.shape[1], 3), dtype=mx.int32)
# Prepare the clip text features
vec = self.clip(clip_tokens).pooled_output
if len(vec) == 1 and n_images > 1:
vec = mx.broadcast_to(vec, (n_images, *vec.shape[1:]))
return txt, txt_ids, vec
def _denoising_loop(
self,
x_t,
x_ids,
txt,
txt_ids,
vec,
num_steps: int = 35,
guidance: float = 4.0,
start: float = 1,
stop: float = 0,
):
B = len(x_t)
def scalar(x):
return mx.full((B,), x, dtype=self.dtype)
guidance = scalar(guidance)
timesteps = self.sampler.timesteps(
num_steps,
x_t.shape[1],
start=start,
stop=stop,
)
for i in range(num_steps):
t = timesteps[i]
t_prev = timesteps[i + 1]
pred = self.flow(
img=x_t,
img_ids=x_ids,
txt=txt,
txt_ids=txt_ids,
y=vec,
timesteps=scalar(t),
guidance=guidance,
)
x_t = self.sampler.step(pred, x_t, t, t_prev)
yield x_t
def generate_latents(
self,
text: str,
n_images: int = 1,
num_steps: int = 35,
guidance: float = 4.0,
latent_size: Tuple[int, int] = (64, 64),
seed=None,
):
# Set the PRNG state
if seed is not None:
mx.random.seed(seed)
# Create the latent variables
x_T = self.sampler.sample_prior((n_images, *latent_size, 16), dtype=self.dtype)
x_T, x_ids = self._prepare_latent_images(x_T)
# Get the conditioning
t5_tokens, clip_tokens = self.tokenize(text)
txt, txt_ids, vec = self._prepare_conditioning(n_images, t5_tokens, clip_tokens)
# Yield the conditioning for controlled evaluation by the caller
yield (x_T, x_ids, txt, txt_ids, vec)
# Yield the latent sequences from the denoising loop
yield from self._denoising_loop(
x_T, x_ids, txt, txt_ids, vec, num_steps=num_steps, guidance=guidance
)
def decode(self, x, latent_size: Tuple[int, int] = (64, 64)):
h, w = latent_size
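# Unpack the 2x2 patches back into a latent image, decode it with the
# VAE and map the output from [-1, 1] to [0, 1].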
x = x.reshape(len(x), h // 2, w // 2, -1, 2, 2)
x = x.transpose(0, 1, 4, 2, 5, 3).reshape(len(x), h, w, -1)
x = self.ae.decode(x)
return mx.clip(x + 1, 0, 2) * 0.5
def generate_images(
self,
text: str,
n_images: int = 1,
num_steps: int = 35,
guidance: float = 4.0,
latent_size: Tuple[int, int] = (64, 64),
seed=None,
reload_text_encoders: bool = True,
progress: bool = True,
):
latents = self.generate_latents(
text, n_images, num_steps, guidance, latent_size, seed
)
mx.eval(next(latents))
if reload_text_encoders:
self.reload_text_encoders()
for x_t in tqdm(latents, total=num_steps, disable=not progress, leave=True):
mx.eval(x_t)
images = []
for i in tqdm(range(len(x_t)), disable=not progress, desc="generate images"):
images.append(self.decode(x_t[i : i + 1]))
mx.eval(images[-1])
images = mx.concatenate(images, axis=0)
mx.eval(images)
return images
def training_loss(
self,
x_0: mx.array,
t5_features: mx.array,
clip_features: mx.array,
guidance: mx.array,
):
# Get the text conditioning
txt = t5_features
txt_ids = mx.zeros(txt.shape[:-1] + (3,), dtype=mx.int32)
vec = clip_features
# Prepare the latent input
x_0, x_ids = self._prepare_latent_images(x_0)
# Forward process
t = self.sampler.random_timesteps(*x_0.shape[:2], dtype=self.dtype)
eps = mx.random.normal(x_0.shape, dtype=self.dtype)
x_t = self.sampler.add_noise(x_0, t, noise=eps)
x_t = mx.stop_gradient(x_t)
# Do the denoising
pred = self.flow(
img=x_t,
img_ids=x_ids,
txt=txt,
txt_ids=txt_ids,
y=vec,
timesteps=t,
guidance=guidance,
)
return (pred + x_0 - eps).square().mean()
def linear_to_lora_layers(self, rank: int = 8, num_blocks: int = -1):
"""Swap the linear layers in the transformer blocks with LoRA layers."""
all_blocks = self.flow.double_blocks + self.flow.single_blocks
all_blocks.reverse()
num_blocks = num_blocks if num_blocks > 0 else len(all_blocks)
for i, block in zip(range(num_blocks), all_blocks):
loras = []
for name, module in block.named_modules():
if isinstance(module, nn.Linear):
loras.append((name, LoRALinear.from_base(module, r=rank)))
block.update_modules(tree_unflatten(loras))
def fuse_lora_layers(self):
fused_layers = []
for name, module in self.flow.named_modules():
if isinstance(module, LoRALinear):
fused_layers.append((name, module.fuse()))
self.flow.update_modules(tree_unflatten(fused_layers))

flux/flux/layers.py
@@ -0,0 +1,302 @@
# Copyright © 2024 Apple Inc.
import math
from dataclasses import dataclass
from functools import partial
from typing import List, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
def _rope(pos: mx.array, dim: int, theta: float):
scale = mx.arange(0, dim, 2, dtype=mx.float32) / dim
omega = 1.0 / (theta**scale)
x = pos[..., None] * omega
cosx = mx.cos(x)
sinx = mx.sin(x)
pe = mx.stack([cosx, -sinx, sinx, cosx], axis=-1)
pe = pe.reshape(*pe.shape[:-1], 2, 2)
return pe
@partial(mx.compile, shapeless=True)
def _ab_plus_cd(a, b, c, d):
return a * b + c * d
def _apply_rope(x, pe):
s = x.shape
x = x.reshape(*s[:-1], -1, 1, 2)
x = _ab_plus_cd(x[..., 0], pe[..., 0], x[..., 1], pe[..., 1])
return x.reshape(s)
def _attention(q: mx.array, k: mx.array, v: mx.array, pe: mx.array):
B, H, L, D = q.shape
q = _apply_rope(q, pe)
k = _apply_rope(k, pe)
x = mx.fast.scaled_dot_product_attention(q, k, v, scale=D ** (-0.5))
return x.transpose(0, 2, 1, 3).reshape(B, L, -1)
def timestep_embedding(
t: mx.array, dim: int, max_period: int = 10000, time_factor: float = 1000.0
):
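# Standard sinusoidal embedding: half of the dimensions hold cosines and
# the other half sines of the scaled timestep at geometrically spaced
# frequencies.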
half = dim // 2
freqs = mx.arange(0, half, dtype=mx.float32) / half
freqs = freqs * (-math.log(max_period))
freqs = mx.exp(freqs)
x = (time_factor * t)[:, None] * freqs[None]
x = mx.concatenate([mx.cos(x), mx.sin(x)], axis=-1)
return x.astype(t.dtype)
class EmbedND(nn.Module):
def __init__(self, dim: int, theta: int, axes_dim: List[int]):
super().__init__()
self.dim = dim
self.theta = theta
self.axes_dim = axes_dim
def __call__(self, ids: mx.array):
n_axes = ids.shape[-1]
pe = mx.concatenate(
[_rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
axis=-3,
)
return pe[:, None]
class MLPEmbedder(nn.Module):
def __init__(self, in_dim: int, hidden_dim: int):
super().__init__()
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
def __call__(self, x: mx.array) -> mx.array:
return self.out_layer(nn.silu(self.in_layer(x)))
class QKNorm(nn.Module):
def __init__(self, dim: int):
super().__init__()
self.query_norm = nn.RMSNorm(dim)
self.key_norm = nn.RMSNorm(dim)
def __call__(self, q: mx.array, k: mx.array) -> tuple[mx.array, mx.array]:
return self.query_norm(q), self.key_norm(k)
class SelfAttention(nn.Module):
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.norm = QKNorm(head_dim)
self.proj = nn.Linear(dim, dim)
def __call__(self, x: mx.array, pe: mx.array) -> mx.array:
H = self.num_heads
B, L, _ = x.shape
qkv = self.qkv(x)
q, k, v = mx.split(qkv, 3, axis=-1)
q = q.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
k = k.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
v = v.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
q, k = self.norm(q, k)
x = _attention(q, k, v, pe)
x = self.proj(x)
return x
@dataclass
class ModulationOut:
shift: mx.array
scale: mx.array
gate: mx.array
class Modulation(nn.Module):
def __init__(self, dim: int, double: bool):
super().__init__()
self.is_double = double
self.multiplier = 6 if double else 3
self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
def __call__(self, x: mx.array) -> Tuple[ModulationOut, Optional[ModulationOut]]:
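# Project the conditioning vector to (shift, scale, gate) triplets used
# for adaptive layer norm; double blocks get a second triplet for the
# MLP branch.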
x = self.lin(nn.silu(x))
xs = mx.split(x[:, None, :], self.multiplier, axis=-1)
mod1 = ModulationOut(*xs[:3])
mod2 = ModulationOut(*xs[3:]) if self.is_double else None
return mod1, mod2
class DoubleStreamBlock(nn.Module):
def __init__(
self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False
):
super().__init__()
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.num_heads = num_heads
self.hidden_size = hidden_size
self.img_mod = Modulation(hidden_size, double=True)
self.img_norm1 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
self.img_attn = SelfAttention(
dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
)
self.img_norm2 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
self.img_mlp = nn.Sequential(
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
nn.GELU(approx="tanh"),
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
)
self.txt_mod = Modulation(hidden_size, double=True)
self.txt_norm1 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
self.txt_attn = SelfAttention(
dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
)
self.txt_norm2 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
self.txt_mlp = nn.Sequential(
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
nn.GELU(approx="tanh"),
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
)
def __call__(
self, img: mx.array, txt: mx.array, vec: mx.array, pe: mx.array
) -> Tuple[mx.array, mx.array]:
B, L, _ = img.shape
_, S, _ = txt.shape
H = self.num_heads
img_mod1, img_mod2 = self.img_mod(vec)
txt_mod1, txt_mod2 = self.txt_mod(vec)
# prepare image for attention
img_modulated = self.img_norm1(img)
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
img_qkv = self.img_attn.qkv(img_modulated)
img_q, img_k, img_v = mx.split(img_qkv, 3, axis=-1)
img_q = img_q.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
img_k = img_k.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
img_v = img_v.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
img_q, img_k = self.img_attn.norm(img_q, img_k)
# prepare txt for attention
txt_modulated = self.txt_norm1(txt)
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
txt_qkv = self.txt_attn.qkv(txt_modulated)
txt_q, txt_k, txt_v = mx.split(txt_qkv, 3, axis=-1)
txt_q = txt_q.reshape(B, S, H, -1).transpose(0, 2, 1, 3)
txt_k = txt_k.reshape(B, S, H, -1).transpose(0, 2, 1, 3)
txt_v = txt_v.reshape(B, S, H, -1).transpose(0, 2, 1, 3)
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k)
# run actual attention
q = mx.concatenate([txt_q, img_q], axis=2)
k = mx.concatenate([txt_k, img_k], axis=2)
v = mx.concatenate([txt_v, img_v], axis=2)
attn = _attention(q, k, v, pe)
txt_attn, img_attn = mx.split(attn, [S], axis=1)
# calculate the img blocks
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
img = img + img_mod2.gate * self.img_mlp(
(1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
)
# calculate the txt blocks
txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
txt = txt + txt_mod2.gate * self.txt_mlp(
(1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
)
return img, txt
class SingleStreamBlock(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
mlp_ratio: float = 4.0,
qk_scale: Optional[float] = None,
):
super().__init__()
self.hidden_dim = hidden_size
self.num_heads = num_heads
head_dim = hidden_size // num_heads
self.scale = qk_scale or head_dim**-0.5
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
# qkv and mlp_in
self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
# proj and mlp_out
self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
self.norm = QKNorm(head_dim)
self.hidden_size = hidden_size
self.pre_norm = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
self.mlp_act = nn.GELU(approx="tanh")
self.modulation = Modulation(hidden_size, double=False)
def __call__(self, x: mx.array, vec: mx.array, pe: mx.array):
B, L, _ = x.shape
H = self.num_heads
mod, _ = self.modulation(vec)
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
q, k, v, mlp = mx.split(
self.linear1(x_mod),
[self.hidden_size, 2 * self.hidden_size, 3 * self.hidden_size],
axis=-1,
)
q = q.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
k = k.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
v = v.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
q, k = self.norm(q, k)
# compute attention
y = _attention(q, k, v, pe)
# compute activation in mlp stream, cat again and run second linear layer
y = self.linear2(mx.concatenate([y, self.mlp_act(mlp)], axis=2))
return x + mod.gate * y
class LastLayer(nn.Module):
def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
self.linear = nn.Linear(
hidden_size, patch_size * patch_size * out_channels, bias=True
)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
)
def __call__(self, x: mx.array, vec: mx.array):
shift, scale = mx.split(self.adaLN_modulation(vec), 2, axis=1)
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
x = self.linear(x)
return x

flux/flux/lora.py

@@ -0,0 +1,76 @@
# Copyright © 2024 Apple Inc.
import math
import mlx.core as mx
import mlx.nn as nn
class LoRALinear(nn.Module):
@staticmethod
def from_base(
linear: nn.Linear,
r: int = 8,
dropout: float = 0.0,
scale: float = 1.0,
):
output_dims, input_dims = linear.weight.shape
lora_lin = LoRALinear(
input_dims=input_dims,
output_dims=output_dims,
r=r,
dropout=dropout,
scale=scale,
)
lora_lin.linear = linear
return lora_lin
def fuse(self):
linear = self.linear
bias = "bias" in linear
weight = linear.weight
dtype = weight.dtype
output_dims, input_dims = weight.shape
fused_linear = nn.Linear(input_dims, output_dims, bias=bias)
lora_b = self.scale * self.lora_b.T
lora_a = self.lora_a.T
fused_linear.weight = weight + (lora_b @ lora_a).astype(dtype)
if bias:
fused_linear.bias = linear.bias
return fused_linear
def __init__(
self,
input_dims: int,
output_dims: int,
r: int = 8,
dropout: float = 0.0,
scale: float = 1.0,
bias: bool = False,
):
super().__init__()
# Regular linear layer weights
self.linear = nn.Linear(input_dims, output_dims, bias=bias)
self.dropout = nn.Dropout(p=dropout)
# Scale for low-rank update
self.scale = scale
# Low rank lora weights
scale = 1 / math.sqrt(input_dims)
self.lora_a = mx.random.uniform(
low=-scale,
high=scale,
shape=(input_dims, r),
)
self.lora_b = mx.zeros(shape=(r, output_dims))
def __call__(self, x):
y = self.linear(x)
z = (self.dropout(x) @ self.lora_a) @ self.lora_b
return y + (self.scale * z).astype(x.dtype)
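A minimal usage sketch (editorial, not part of the file; the toy layer and shapes are illustrative): wrap an existing `nn.Linear` with `LoRALinear.from_base`, run a forward pass, and fuse the low-rank update back into a plain linear layer.

```python
import mlx.core as mx
import mlx.nn as nn

# Hypothetical toy layer; any nn.Linear from a larger model works the same way.
base = nn.Linear(64, 64)
lora = LoRALinear.from_base(base, r=8, scale=1.0)

x = mx.random.normal((2, 64))
y = lora(x)          # base output plus the scaled low-rank update

fused = lora.fuse()  # plain nn.Linear with the update folded into its weight
assert mx.allclose(y, fused(x), atol=1e-5)
```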

flux/flux/model.py

@@ -0,0 +1,136 @@
# Copyright © 2024 Apple Inc.
from dataclasses import dataclass
from typing import Optional
import mlx.core as mx
import mlx.nn as nn
from .layers import (
DoubleStreamBlock,
EmbedND,
LastLayer,
MLPEmbedder,
SingleStreamBlock,
timestep_embedding,
)
@dataclass
class FluxParams:
in_channels: int
vec_in_dim: int
context_in_dim: int
hidden_size: int
mlp_ratio: float
num_heads: int
depth: int
depth_single_blocks: int
axes_dim: list[int]
theta: int
qkv_bias: bool
guidance_embed: bool
class Flux(nn.Module):
def __init__(self, params: FluxParams):
super().__init__()
self.params = params
self.in_channels = params.in_channels
self.out_channels = self.in_channels
if params.hidden_size % params.num_heads != 0:
raise ValueError(
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
)
pe_dim = params.hidden_size // params.num_heads
if sum(params.axes_dim) != pe_dim:
raise ValueError(
f"Got {params.axes_dim} but expected positional dim {pe_dim}"
)
self.hidden_size = params.hidden_size
self.num_heads = params.num_heads
self.pe_embedder = EmbedND(
dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim
)
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
self.guidance_in = (
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
if params.guidance_embed
else nn.Identity()
)
self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
self.double_blocks = [
DoubleStreamBlock(
self.hidden_size,
self.num_heads,
mlp_ratio=params.mlp_ratio,
qkv_bias=params.qkv_bias,
)
for _ in range(params.depth)
]
self.single_blocks = [
SingleStreamBlock(
self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio
)
for _ in range(params.depth_single_blocks)
]
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
def sanitize(self, weights):
new_weights = {}
for k, w in weights.items():
if k.startswith("model.diffusion_model."):
k = k[22:]
if k.endswith(".scale"):
k = k[:-6] + ".weight"
for seq in ["img_mlp", "txt_mlp", "adaLN_modulation"]:
if f".{seq}." in k:
k = k.replace(f".{seq}.", f".{seq}.layers.")
break
new_weights[k] = w
return new_weights
def __call__(
self,
img: mx.array,
img_ids: mx.array,
txt: mx.array,
txt_ids: mx.array,
timesteps: mx.array,
y: mx.array,
guidance: Optional[mx.array] = None,
) -> mx.array:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
img = self.img_in(img)
vec = self.time_in(timestep_embedding(timesteps, 256))
if self.params.guidance_embed:
if guidance is None:
raise ValueError(
"Didn't get guidance strength for guidance distilled model."
)
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
vec = vec + self.vector_in(y)
txt = self.txt_in(txt)
ids = mx.concatenate([txt_ids, img_ids], axis=1)
pe = self.pe_embedder(ids).astype(img.dtype)
for block in self.double_blocks:
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
img = mx.concatenate([txt, img], axis=1)
for block in self.single_blocks:
img = block(img, vec=vec, pe=pe)
img = img[:, txt.shape[1] :, ...]
img = self.final_layer(img, vec)
return img
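A rough smoke-test sketch with toy hyper-parameters (all values below are illustrative, not a real FLUX configuration); it mainly shows the constraints that `hidden_size` must be divisible by `num_heads` and that `sum(axes_dim)` must equal the resulting head dimension:

```python
import mlx.core as mx

# Toy configuration: hidden_size / num_heads = 64 == sum(axes_dim).
params = FluxParams(
    in_channels=16,
    vec_in_dim=32,
    context_in_dim=48,
    hidden_size=256,
    mlp_ratio=4.0,
    num_heads=4,
    depth=1,
    depth_single_blocks=1,
    axes_dim=[16, 24, 24],
    theta=10_000,
    qkv_bias=True,
    guidance_embed=False,
)
flux = Flux(params)

B, L, S = 1, 32, 8
out = flux(
    img=mx.random.normal((B, L, 16)),
    img_ids=mx.zeros((B, L, 3)),
    txt=mx.random.normal((B, S, 48)),
    txt_ids=mx.zeros((B, S, 3)),
    timesteps=mx.array([0.5]),
    y=mx.random.normal((B, 32)),
)
print(out.shape)  # (1, 32, 16)
```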

flux/flux/sampler.py

@@ -0,0 +1,57 @@
# Copyright © 2024 Apple Inc.
import math
from functools import lru_cache
import mlx.core as mx
class FluxSampler:
def __init__(self, name: str, base_shift: float = 0.5, max_shift: float = 1.15):
self._base_shift = base_shift
self._max_shift = max_shift
self._schnell = "schnell" in name
def _time_shift(self, x, t):
x1, x2 = 256, 4096
t1, t2 = self._base_shift, self._max_shift
exp_mu = math.exp((x - x1) * (t2 - t1) / (x2 - x1) + t1)
t = exp_mu / (exp_mu + (1 / t - 1))
return t
@lru_cache
def timesteps(
self, num_steps, image_sequence_length, start: float = 1, stop: float = 0
):
t = mx.linspace(start, stop, num_steps + 1)
if not self._schnell:
t = self._time_shift(image_sequence_length, t)
return t.tolist()
def random_timesteps(self, B, L, dtype=mx.float32, key=None):
if self._schnell:
# TODO: Should we upweight 1 and 0.75?
t = mx.random.randint(1, 5, shape=(B,), key=key)
t = t.astype(dtype) / 4
else:
t = mx.random.uniform(shape=(B,), dtype=dtype, key=key)
t = self._time_shift(L, t)
return t
def sample_prior(self, shape, dtype=mx.float32, key=None):
return mx.random.normal(shape, dtype=dtype, key=key)
def add_noise(self, x, t, noise=None, key=None):
noise = (
noise
if noise is not None
else mx.random.normal(x.shape, dtype=x.dtype, key=key)
)
t = t.reshape([-1] + [1] * (x.ndim - 1))
return x * (1 - t) + t * noise
def step(self, pred, x_t, t, t_prev):
return x_t + (t_prev - t) * pred
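A minimal sampling-loop sketch (the `flow_model` call is a placeholder for whatever predicts the flow; it is not defined in this file): draw a prior sample, walk the (optionally shifted) timestep schedule, and apply the Euler-style `step`.

```python
import mlx.core as mx

sampler = FluxSampler("flux-schnell")
ts = sampler.timesteps(num_steps=4, image_sequence_length=256)

x = sampler.sample_prior((1, 256, 64), key=mx.random.key(0))
for t, t_prev in zip(ts[:-1], ts[1:]):
    pred = flow_model(x, t)  # placeholder for the flow transformer prediction
    x = sampler.step(pred, x, t, t_prev)
```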

flux/flux/t5.py

@@ -0,0 +1,244 @@
# Copyright © 2024 Apple Inc.
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
_SHARED_REPLACEMENT_PATTERNS = [
(".block.", ".layers."),
(".k.", ".key_proj."),
(".o.", ".out_proj."),
(".q.", ".query_proj."),
(".v.", ".value_proj."),
("shared.", "wte."),
("lm_head.", "lm_head.linear."),
(".layer.0.layer_norm.", ".ln1."),
(".layer.1.layer_norm.", ".ln2."),
(".layer.2.layer_norm.", ".ln3."),
(".final_layer_norm.", ".ln."),
(
"layers.0.layer.0.SelfAttention.relative_attention_bias.",
"relative_attention_bias.embeddings.",
),
]
_ENCODER_REPLACEMENT_PATTERNS = [
(".layer.0.SelfAttention.", ".attention."),
(".layer.1.DenseReluDense.", ".dense."),
]
@dataclass
class T5Config:
vocab_size: int
num_layers: int
num_heads: int
relative_attention_num_buckets: int
d_kv: int
d_model: int
feed_forward_proj: str
tie_word_embeddings: bool
d_ff: Optional[int] = None
num_decoder_layers: Optional[int] = None
relative_attention_max_distance: int = 128
layer_norm_epsilon: float = 1e-6
@classmethod
def from_dict(cls, config):
return cls(
vocab_size=config["vocab_size"],
num_layers=config["num_layers"],
num_heads=config["num_heads"],
relative_attention_num_buckets=config["relative_attention_num_buckets"],
d_kv=config["d_kv"],
d_model=config["d_model"],
feed_forward_proj=config["feed_forward_proj"],
tie_word_embeddings=config["tie_word_embeddings"],
d_ff=config.get("d_ff", 4 * config["d_model"]),
num_decoder_layers=config.get("num_decoder_layers", config["num_layers"]),
relative_attention_max_distance=config.get(
"relative_attention_max_distance", 128
),
layer_norm_epsilon=config.get("layer_norm_epsilon", 1e-6),
)
class RelativePositionBias(nn.Module):
def __init__(self, config: T5Config, bidirectional: bool):
self.bidirectional = bidirectional
self.num_buckets = config.relative_attention_num_buckets
self.max_distance = config.relative_attention_max_distance
self.n_heads = config.num_heads
self.embeddings = nn.Embedding(self.num_buckets, self.n_heads)
@staticmethod
def _relative_position_bucket(rpos, bidirectional, num_buckets, max_distance):
num_buckets = num_buckets // 2 if bidirectional else num_buckets
max_exact = num_buckets // 2
abspos = rpos.abs()
is_small = abspos < max_exact
scale = (num_buckets - max_exact) / math.log(max_distance / max_exact)
buckets_large = (mx.log(abspos / max_exact) * scale).astype(mx.int16)
buckets_large = mx.minimum(max_exact + buckets_large, num_buckets - 1)
buckets = mx.where(is_small, abspos, buckets_large)
if bidirectional:
buckets = buckets + (rpos > 0) * num_buckets
else:
buckets = buckets * (rpos < 0)
return buckets
def __call__(self, query_length: int, key_length: int, offset: int = 0):
"""Compute binned relative position bias"""
context_position = mx.arange(offset, query_length)[:, None]
memory_position = mx.arange(key_length)[None, :]
# shape (query_length, key_length)
relative_position = memory_position - context_position
relative_position_bucket = self._relative_position_bucket(
relative_position,
bidirectional=self.bidirectional,
num_buckets=self.num_buckets,
max_distance=self.max_distance,
)
# shape (query_length, key_length, num_heads)
values = self.embeddings(relative_position_bucket)
# shape (num_heads, query_length, key_length)
return values.transpose(2, 0, 1)
class MultiHeadAttention(nn.Module):
def __init__(self, config: T5Config):
super().__init__()
inner_dim = config.d_kv * config.num_heads
self.num_heads = config.num_heads
self.query_proj = nn.Linear(config.d_model, inner_dim, bias=False)
self.key_proj = nn.Linear(config.d_model, inner_dim, bias=False)
self.value_proj = nn.Linear(config.d_model, inner_dim, bias=False)
self.out_proj = nn.Linear(inner_dim, config.d_model, bias=False)
def __call__(
self,
queries: mx.array,
keys: mx.array,
values: mx.array,
mask: Optional[mx.array],
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
queries = self.query_proj(queries)
keys = self.key_proj(keys)
values = self.value_proj(values)
num_heads = self.num_heads
B, L, _ = queries.shape
_, S, _ = keys.shape
queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
key_cache, value_cache = cache
keys = mx.concatenate([key_cache, keys], axis=3)
values = mx.concatenate([value_cache, values], axis=2)
values_hat = mx.fast.scaled_dot_product_attention(
queries, keys, values, scale=1.0, mask=mask.astype(queries.dtype)
)
values_hat = values_hat.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.out_proj(values_hat), (keys, values)
class DenseActivation(nn.Module):
def __init__(self, config: T5Config):
super().__init__()
mlp_dims = config.d_ff or config.d_model * 4
self.gated = config.feed_forward_proj.startswith("gated")
if self.gated:
self.wi_0 = nn.Linear(config.d_model, mlp_dims, bias=False)
self.wi_1 = nn.Linear(config.d_model, mlp_dims, bias=False)
else:
self.wi = nn.Linear(config.d_model, mlp_dims, bias=False)
self.wo = nn.Linear(mlp_dims, config.d_model, bias=False)
activation = config.feed_forward_proj.removeprefix("gated-")
if activation == "relu":
self.act = nn.relu
elif activation == "gelu":
self.act = nn.gelu
elif activation == "silu":
self.act = nn.silu
else:
raise ValueError(f"Unknown activation: {activation}")
def __call__(self, x):
if self.gated:
hidden_act = self.act(self.wi_0(x))
hidden_linear = self.wi_1(x)
x = hidden_act * hidden_linear
else:
x = self.act(self.wi(x))
return self.wo(x)
class TransformerEncoderLayer(nn.Module):
def __init__(self, config: T5Config):
super().__init__()
self.attention = MultiHeadAttention(config)
self.ln1 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.ln2 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dense = DenseActivation(config)
def __call__(self, x, mask):
y = self.ln1(x)
y, _ = self.attention(y, y, y, mask=mask)
x = x + y
y = self.ln2(x)
y = self.dense(y)
return x + y
class TransformerEncoder(nn.Module):
def __init__(self, config: T5Config):
super().__init__()
self.layers = [
TransformerEncoderLayer(config) for i in range(config.num_layers)
]
self.ln = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.relative_attention_bias = RelativePositionBias(config, bidirectional=True)
def __call__(self, x: mx.array):
pos_bias = self.relative_attention_bias(x.shape[1], x.shape[1])
pos_bias = pos_bias.astype(x.dtype)
for layer in self.layers:
x = layer(x, mask=pos_bias)
return self.ln(x)
class T5Encoder(nn.Module):
def __init__(self, config: T5Config):
self.wte = nn.Embedding(config.vocab_size, config.d_model)
self.encoder = TransformerEncoder(config)
def sanitize(self, weights):
new_weights = {}
for k, w in weights.items():
for old, new in _SHARED_REPLACEMENT_PATTERNS:
k = k.replace(old, new)
if k.startswith("encoder."):
for old, new in _ENCODER_REPLACEMENT_PATTERNS:
k = k.replace(old, new)
new_weights[k] = w
return new_weights
def __call__(self, inputs: mx.array):
return self.encoder(self.wte(inputs))
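A small sketch with a toy configuration (the values are illustrative and much smaller than the T5 encoder FLUX actually uses):

```python
import mlx.core as mx

config = T5Config.from_dict(
    {
        "vocab_size": 128,
        "num_layers": 2,
        "num_heads": 4,
        "relative_attention_num_buckets": 32,
        "d_kv": 16,
        "d_model": 64,
        "feed_forward_proj": "gated-gelu",
        "tie_word_embeddings": False,
    }
)
encoder = T5Encoder(config)

tokens = mx.array([[3, 7, 11, 2]])  # toy token ids
hidden = encoder(tokens)
print(hidden.shape)  # (1, 4, 64)
```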

flux/flux/tokenizers.py

@@ -0,0 +1,185 @@
# Copyright © 2024 Apple Inc.
import mlx.core as mx
import regex
from sentencepiece import SentencePieceProcessor
class CLIPTokenizer:
"""A simple port of CLIPTokenizer from https://github.com/huggingface/transformers/ ."""
def __init__(self, bpe_ranks, vocab, max_length=77):
self.max_length = max_length
self.bpe_ranks = bpe_ranks
self.vocab = vocab
self.pat = regex.compile(
r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
regex.IGNORECASE,
)
self._cache = {self.bos: self.bos, self.eos: self.eos}
@property
def bos(self):
return "<|startoftext|>"
@property
def bos_token(self):
return self.vocab[self.bos]
@property
def eos(self):
return "<|endoftext|>"
@property
def eos_token(self):
return self.vocab[self.eos]
def bpe(self, text):
if text in self._cache:
return self._cache[text]
unigrams = list(text[:-1]) + [text[-1] + "</w>"]
unique_bigrams = set(zip(unigrams, unigrams[1:]))
if not unique_bigrams:
return unigrams
# In every iteration try to merge the two most likely bigrams. If none
# was merged we are done.
#
# Ported from https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/tokenization_clip.py
while unique_bigrams:
bigram = min(
unique_bigrams, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))
)
if bigram not in self.bpe_ranks:
break
new_unigrams = []
skip = False
for a, b in zip(unigrams, unigrams[1:]):
if skip:
skip = False
continue
if (a, b) == bigram:
new_unigrams.append(a + b)
skip = True
else:
new_unigrams.append(a)
if not skip:
new_unigrams.append(b)
unigrams = new_unigrams
unique_bigrams = set(zip(unigrams, unigrams[1:]))
self._cache[text] = unigrams
return unigrams
def tokenize(self, text, prepend_bos=True, append_eos=True):
if isinstance(text, list):
return [self.tokenize(t, prepend_bos, append_eos) for t in text]
# Lower case cleanup and split according to self.pat. Hugging Face does
# a much more thorough job here but this should suffice for 95% of
# cases.
clean_text = regex.sub(r"\s+", " ", text.lower())
tokens = regex.findall(self.pat, clean_text)
# Split the tokens according to the byte-pair merge file
bpe_tokens = [ti for t in tokens for ti in self.bpe(t)]
# Map to token ids and return
tokens = [self.vocab[t] for t in bpe_tokens]
if prepend_bos:
tokens = [self.bos_token] + tokens
if append_eos:
tokens.append(self.eos_token)
if len(tokens) > self.max_length:
tokens = tokens[: self.max_length]
if append_eos:
tokens[-1] = self.eos_token
return tokens
def encode(self, text):
if not isinstance(text, list):
return self.encode([text])
tokens = self.tokenize(text)
length = max(len(t) for t in tokens)
for t in tokens:
t.extend([self.eos_token] * (length - len(t)))
return mx.array(tokens)
class T5Tokenizer:
def __init__(self, model_file, max_length=512):
self._tokenizer = SentencePieceProcessor(model_file)
self.max_length = max_length
@property
def pad(self):
try:
return self._tokenizer.id_to_piece(self.pad_token)
except IndexError:
return None
@property
def pad_token(self):
return self._tokenizer.pad_id()
@property
def bos(self):
try:
return self._tokenizer.id_to_piece(self.bos_token)
except IndexError:
return None
@property
def bos_token(self):
return self._tokenizer.bos_id()
@property
def eos(self):
try:
return self._tokenizer.id_to_piece(self.eos_token)
except IndexError:
return None
@property
def eos_token(self):
return self._tokenizer.eos_id()
def tokenize(self, text, prepend_bos=True, append_eos=True, pad=True):
if isinstance(text, list):
return [self.tokenize(t, prepend_bos, append_eos, pad) for t in text]
tokens = self._tokenizer.encode(text)
if prepend_bos and self.bos_token >= 0:
tokens = [self.bos_token] + tokens
if append_eos and self.eos_token >= 0:
tokens.append(self.eos_token)
if pad and len(tokens) < self.max_length and self.pad_token >= 0:
tokens += [self.pad_token] * (self.max_length - len(tokens))
return tokens
def encode(self, text, pad=True):
if not isinstance(text, list):
return self.encode([text], pad=pad)
pad_token = self.pad_token if self.pad_token >= 0 else 0
tokens = self.tokenize(text, pad=pad)
length = max(len(t) for t in tokens)
for t in tokens:
t.extend([pad_token] * (length - len(t)))
return mx.array(tokens)

flux/flux/trainer.py

@@ -0,0 +1,98 @@
import mlx.core as mx
import numpy as np
from PIL import Image, ImageFile
from tqdm import tqdm
from .datasets import Dataset
from .flux import FluxPipeline
class Trainer:
def __init__(self, flux: FluxPipeline, dataset: Dataset, args):
self.flux = flux
self.dataset = dataset
self.args = args
self.latents = []
self.t5_features = []
self.clip_features = []
def _random_crop_resize(self, img):
resolution = self.args.resolution
width, height = img.size
a, b, c, d = mx.random.uniform(shape=(4,), stream=mx.cpu).tolist()
# Random crop the input image between 0.8 to 1.0 of its original dimensions
crop_size = (
max((0.8 + 0.2 * a) * width, resolution[0]),
max((0.8 + 0.2 * b) * height, resolution[1]),
)
pan = (width - crop_size[0], height - crop_size[1])
img = img.crop(
(
pan[0] * c,
pan[1] * d,
crop_size[0] + pan[0] * c,
crop_size[1] + pan[1] * d,
)
)
# Fit the largest rectangle with the ratio of resolution in the image
# rectangle.
width, height = crop_size
ratio = resolution[0] / resolution[1]
r1 = (height * ratio, height)
r2 = (width, width / ratio)
r = r1 if r1[0] <= width else r2
img = img.crop(
(
(width - r[0]) / 2,
(height - r[1]) / 2,
(width + r[0]) / 2,
(height + r[1]) / 2,
)
)
# Finally resize the image to resolution
img = img.resize(resolution, Image.LANCZOS)
return mx.array(np.array(img))
def _encode_image(self, input_img: ImageFile.ImageFile, num_augmentations: int):
for i in range(num_augmentations):
img = self._random_crop_resize(input_img)
img = (img[:, :, :3].astype(self.flux.dtype) / 255) * 2 - 1
x_0 = self.flux.ae.encode(img[None])
x_0 = x_0.astype(self.flux.dtype)
mx.eval(x_0)
self.latents.append(x_0)
def _encode_prompt(self, prompt):
t5_tok, clip_tok = self.flux.tokenize([prompt])
t5_feat = self.flux.t5(t5_tok)
clip_feat = self.flux.clip(clip_tok).pooled_output
mx.eval(t5_feat, clip_feat)
self.t5_features.append(t5_feat)
self.clip_features.append(clip_feat)
def encode_dataset(self):
"""Encode the images & prompt in the latent space to prepare for training."""
self.flux.ae.eval()
for image, prompt in tqdm(self.dataset, desc="encode dataset"):
self._encode_image(image, self.args.num_augmentations)
self._encode_prompt(prompt)
def iterate(self, batch_size):
xs = mx.concatenate(self.latents)
t5 = mx.concatenate(self.t5_features)
clip = mx.concatenate(self.clip_features)
mx.eval(xs, t5, clip)
n_aug = self.args.num_augmentations
while True:
x_indices = mx.random.permutation(len(self.latents))
c_indices = x_indices // n_aug
for i in range(0, len(self.latents), batch_size):
x_i = x_indices[i : i + batch_size]
c_i = c_indices[i : i + batch_size]
yield xs[x_i], t5[c_i], clip[c_i]
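A rough sketch of how the trainer is meant to be driven (the `flux`, `dataset`, and `args` objects are assumed to come from the rest of the package and are not constructed here):

```python
# Assumed: a FluxPipeline, a Dataset, and an args namespace providing
# `resolution` and `num_augmentations`.
trainer = Trainer(flux, dataset, args)

# One-time pass: encode images and prompts into latents and text features.
trainer.encode_dataset()

# Infinite iterator over (latents, T5 features, CLIP features) batches.
batches = trainer.iterate(batch_size=1)
x_0, t5_feat, clip_feat = next(batches)
```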

flux/flux/utils.py

@@ -0,0 +1,230 @@
# Copyright © 2024 Apple Inc.
import json
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Union
import mlx.core as mx
from huggingface_hub import hf_hub_download
from .autoencoder import AutoEncoder, AutoEncoderParams
from .clip import CLIPTextModel, CLIPTextModelConfig
from .model import Flux, FluxParams
from .t5 import T5Config, T5Encoder
from .tokenizers import CLIPTokenizer, T5Tokenizer
@dataclass
class ModelSpec:
params: FluxParams
ae_params: AutoEncoderParams
ckpt_path: Optional[str]
ae_path: Optional[str]
repo_id: Optional[str]
repo_flow: Optional[str]
repo_ae: Optional[str]
configs = {
"flux-dev": ModelSpec(
repo_id="black-forest-labs/FLUX.1-dev",
repo_flow="flux1-dev.safetensors",
repo_ae="ae.safetensors",
ckpt_path=os.getenv("FLUX_DEV"),
params=FluxParams(
in_channels=64,
vec_in_dim=768,
context_in_dim=4096,
hidden_size=3072,
mlp_ratio=4.0,
num_heads=24,
depth=19,
depth_single_blocks=38,
axes_dim=[16, 56, 56],
theta=10_000,
qkv_bias=True,
guidance_embed=True,
),
ae_path=os.getenv("AE"),
ae_params=AutoEncoderParams(
resolution=256,
in_channels=3,
ch=128,
out_ch=3,
ch_mult=[1, 2, 4, 4],
num_res_blocks=2,
z_channels=16,
scale_factor=0.3611,
shift_factor=0.1159,
),
),
"flux-schnell": ModelSpec(
repo_id="black-forest-labs/FLUX.1-schnell",
repo_flow="flux1-schnell.safetensors",
repo_ae="ae.safetensors",
ckpt_path=os.getenv("FLUX_SCHNELL"),
params=FluxParams(
in_channels=64,
vec_in_dim=768,
context_in_dim=4096,
hidden_size=3072,
mlp_ratio=4.0,
num_heads=24,
depth=19,
depth_single_blocks=38,
axes_dim=[16, 56, 56],
theta=10_000,
qkv_bias=True,
guidance_embed=False,
),
ae_path=os.getenv("AE"),
ae_params=AutoEncoderParams(
resolution=256,
in_channels=3,
ch=128,
out_ch=3,
ch_mult=[1, 2, 4, 4],
num_res_blocks=2,
z_channels=16,
scale_factor=0.3611,
shift_factor=0.1159,
),
),
}
def load_flow_model(name: str, hf_download: bool = True):
# Get the safetensors file to load
ckpt_path = configs[name].ckpt_path
# Download if needed
if (
ckpt_path is None
and configs[name].repo_id is not None
and configs[name].repo_flow is not None
and hf_download
):
ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)
# Make the model
model = Flux(configs[name].params)
# Load the checkpoint if needed
if ckpt_path is not None:
weights = mx.load(ckpt_path)
weights = model.sanitize(weights)
model.load_weights(list(weights.items()))
return model
def load_ae(name: str, hf_download: bool = True):
# Get the safetensors file to load
ckpt_path = configs[name].ae_path
# Download if needed
if (
ckpt_path is None
and configs[name].repo_id is not None
and configs[name].repo_ae is not None
and hf_download
):
ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_ae)
# Make the autoencoder
ae = AutoEncoder(configs[name].ae_params)
# Load the checkpoint if needed
if ckpt_path is not None:
weights = mx.load(ckpt_path)
weights = ae.sanitize(weights)
ae.load_weights(list(weights.items()))
return ae
def load_clip(name: str):
# Load the config
config_path = hf_hub_download(configs[name].repo_id, "text_encoder/config.json")
with open(config_path) as f:
config = CLIPTextModelConfig.from_dict(json.load(f))
# Make the clip text encoder
clip = CLIPTextModel(config)
# Load the weights
ckpt_path = hf_hub_download(configs[name].repo_id, "text_encoder/model.safetensors")
weights = mx.load(ckpt_path)
weights = clip.sanitize(weights)
clip.load_weights(list(weights.items()))
return clip
def load_t5(name: str):
# Load the config
config_path = hf_hub_download(configs[name].repo_id, "text_encoder_2/config.json")
with open(config_path) as f:
config = T5Config.from_dict(json.load(f))
# Make the T5 model
t5 = T5Encoder(config)
# Load the weights
model_index = hf_hub_download(
configs[name].repo_id, "text_encoder_2/model.safetensors.index.json"
)
weight_files = set()
with open(model_index) as f:
for _, w in json.load(f)["weight_map"].items():
weight_files.add(w)
weights = {}
for w in weight_files:
w = f"text_encoder_2/{w}"
w = hf_hub_download(configs[name].repo_id, w)
weights.update(mx.load(w))
weights = t5.sanitize(weights)
t5.load_weights(list(weights.items()))
return t5
def load_clip_tokenizer(name: str):
vocab_file = hf_hub_download(configs[name].repo_id, "tokenizer/vocab.json")
with open(vocab_file, encoding="utf-8") as f:
vocab = json.load(f)
merges_file = hf_hub_download(configs[name].repo_id, "tokenizer/merges.txt")
with open(merges_file, encoding="utf-8") as f:
bpe_merges = f.read().strip().split("\n")[1 : 49152 - 256 - 2 + 1]
bpe_merges = [tuple(m.split()) for m in bpe_merges]
bpe_ranks = dict(map(reversed, enumerate(bpe_merges)))
return CLIPTokenizer(bpe_ranks, vocab, max_length=77)
def load_t5_tokenizer(name: str, pad: bool = True):
model_file = hf_hub_download(configs[name].repo_id, "tokenizer_2/spiece.model")
return T5Tokenizer(model_file, 256 if "schnell" in name else 512)
def save_config(
config: dict,
config_path: Union[str, Path],
) -> None:
"""Save the model configuration to the ``config_path``.
The final configuration will be sorted before saving for better readability.
Args:
config (dict): The model configuration.
config_path (Union[str, Path]): Model configuration file path.
"""
# Sort the config for better readability
config = dict(sorted(config.items()))
# Write the config to the provided file
with open(config_path, "w") as fid:
json.dump(config, fid, indent=4)
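A sketch of pulling the text-encoding components together with the helpers above (weights are downloaded from the Hugging Face Hub on first use):

```python
# Load the text encoders and tokenizers for the schnell variant.
t5 = load_t5("flux-schnell")
clip = load_clip("flux-schnell")
t5_tok = load_t5_tokenizer("flux-schnell")
clip_tok = load_clip_tokenizer("flux-schnell")

prompt = "a photo of a dog"
t5_features = t5(t5_tok.encode([prompt]))
clip_features = clip(clip_tok.encode([prompt])).pooled_output
```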

flux/requirements.txt

@@ -0,0 +1,7 @@
mlx>=0.18.1
huggingface-hub
regex
numpy
tqdm
Pillow
sentencepiece

Binary image files added under flux/static/ (including flux/static/dog6.png); contents not shown.

flux/txt2image.py

@@ -0,0 +1,150 @@
# Copyright © 2024 Apple Inc.
import argparse
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from PIL import Image
from tqdm import tqdm
from flux import FluxPipeline
def to_latent_size(image_size):
h, w = image_size
h = ((h + 15) // 16) * 16
w = ((w + 15) // 16) * 16
if (h, w) != image_size:
print(
"Warning: The image dimensions need to be divisible by 16px. "
f"Changing size to {h}x{w}."
)
return (h // 8, w // 8)
def quantization_predicate(name, m):
return hasattr(m, "to_quantized") and m.weight.shape[1] % 512 == 0
def load_adapter(flux, adapter_file, fuse=False):
weights, lora_config = mx.load(adapter_file, return_metadata=True)
rank = int(lora_config["lora_rank"])
num_blocks = int(lora_config["lora_blocks"])
flux.linear_to_lora_layers(rank, num_blocks)
flux.flow.load_weights(list(weights.items()), strict=False)
if fuse:
flux.fuse_lora_layers()
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Generate images from a textual prompt using stable diffusion"
)
parser.add_argument("prompt")
parser.add_argument("--model", choices=["schnell", "dev"], default="schnell")
parser.add_argument("--n-images", type=int, default=4)
parser.add_argument(
"--image-size", type=lambda x: tuple(map(int, x.split("x"))), default=(512, 512)
)
parser.add_argument("--steps", type=int)
parser.add_argument("--guidance", type=float, default=4.0)
parser.add_argument("--n-rows", type=int, default=1)
parser.add_argument("--decoding-batch-size", type=int, default=1)
parser.add_argument("--quantize", "-q", action="store_true")
parser.add_argument("--preload-models", action="store_true")
parser.add_argument("--output", default="out.png")
parser.add_argument("--save-raw", action="store_true")
parser.add_argument("--seed", type=int)
parser.add_argument("--verbose", "-v", action="store_true")
parser.add_argument("--adapter")
parser.add_argument("--fuse-adapter", action="store_true")
parser.add_argument("--no-t5-padding", dest="t5_padding", action="store_false")
args = parser.parse_args()
# Load the models
flux = FluxPipeline("flux-" + args.model, t5_padding=args.t5_padding)
args.steps = args.steps or (50 if args.model == "dev" else 2)
if args.adapter:
load_adapter(flux, args.adapter, fuse=args.fuse_adapter)
if args.quantize:
nn.quantize(flux.flow, class_predicate=quantization_predicate)
nn.quantize(flux.t5, class_predicate=quantization_predicate)
nn.quantize(flux.clip, class_predicate=quantization_predicate)
if args.preload_models:
flux.ensure_models_are_loaded()
# Make the generator
latent_size = to_latent_size(args.image_size)
latents = flux.generate_latents(
args.prompt,
n_images=args.n_images,
num_steps=args.steps,
latent_size=latent_size,
guidance=args.guidance,
seed=args.seed,
)
# First we get and eval the conditioning
conditioning = next(latents)
mx.eval(conditioning)
peak_mem_conditioning = mx.metal.get_peak_memory() / 1024**3
mx.metal.reset_peak_memory()
# The following is not necessary but it may help in memory constrained
# systems by reusing the memory kept by the text encoders.
del flux.t5
del flux.clip
# Actual denoising loop
for x_t in tqdm(latents, total=args.steps):
mx.eval(x_t)
# The following is not necessary but it may help in memory constrained
# systems by reusing the memory kept by the flow transformer.
del flux.flow
peak_mem_generation = mx.metal.get_peak_memory() / 1024**3
mx.metal.reset_peak_memory()
# Decode them into images
decoded = []
for i in tqdm(range(0, args.n_images, args.decoding_batch_size)):
decoded.append(flux.decode(x_t[i : i + args.decoding_batch_size], latent_size))
mx.eval(decoded[-1])
peak_mem_decoding = mx.metal.get_peak_memory() / 1024**3
peak_mem_overall = max(
peak_mem_conditioning, peak_mem_generation, peak_mem_decoding
)
if args.save_raw:
*name, suffix = args.output.split(".")
name = ".".join(name)
x = mx.concatenate(decoded, axis=0)
x = (x * 255).astype(mx.uint8)
for i in range(len(x)):
im = Image.fromarray(np.array(x[i]))
im.save(".".join([name, str(i), suffix]))
else:
# Arrange them on a grid
x = mx.concatenate(decoded, axis=0)
x = mx.pad(x, [(0, 0), (4, 4), (4, 4), (0, 0)])
B, H, W, C = x.shape
x = x.reshape(args.n_rows, B // args.n_rows, H, W, C).transpose(0, 2, 1, 3, 4)
x = x.reshape(args.n_rows * H, B // args.n_rows * W, C)
x = (x * 255).astype(mx.uint8)
# Save them to disk
im = Image.fromarray(np.array(x))
im.save(args.output)
# Report the peak memory used during generation
if args.verbose:
print(f"Peak memory used for the text: {peak_mem_conditioning:.3f}GB")
print(f"Peak memory used for the generation: {peak_mem_generation:.3f}GB")
print(f"Peak memory used for the decoding: {peak_mem_decoding:.3f}GB")
print(f"Peak memory used overall: {peak_mem_overall:.3f}GB")


@@ -79,10 +79,10 @@ def load_image(image_source):
def prepare_inputs(processor, image, prompt):
if isinstance(image, str):
image = load_image(image)
inputs = processor(prompt, image, return_tensors="np")
inputs = processor(image, prompt, return_tensors="np")
pixel_values = mx.array(inputs["pixel_values"])
input_ids = mx.array(inputs["input_ids"])
return input_ids, pixel_values
return pixel_values, input_ids
def load_model(model_path, tokenizer_config={}):
@@ -126,8 +126,7 @@ def main():
processor, model = load_model(args.model, tokenizer_config)
prompt = codecs.decode(args.prompt, "unicode_escape")
input_ids, pixel_values = prepare_inputs(processor, args.image, prompt)
pixel_values, input_ids = prepare_inputs(processor, args.image, prompt)
print(prompt)
generated_text = generate_text(


@@ -68,11 +68,10 @@ class LlavaModel(nn.Module):
input_ids: Optional[mx.array] = None,
pixel_values: Optional[mx.array] = None,
):
if pixel_values is None:
return self.language_model(input_ids)
# Get the input embeddings from the language model
inputs_embeds = self.language_model.model.embed_tokens(input_ids)
if pixel_values is None:
return inputs_embeds
# Get the output hidden states from the vision model
*_, hidden_states = self.vision_tower(
@@ -105,31 +104,21 @@ class LlavaModel(nn.Module):
self, image_features, inputs_embeds, input_ids
):
image_token_index = self.config.image_token_index
num_images, num_image_patches, embed_dim = image_features.shape
batch_size, num_image_patches, embed_dim = image_features.shape
# Positions of <image> tokens in input_ids, assuming batch size is 1
image_positions = np.where(input_ids[0] == image_token_index)[0].tolist()
image_positions = mx.array(
np.where(input_ids[0] == image_token_index)[0], mx.uint32
)
if len(image_positions) != num_images:
if len(image_positions) != num_image_patches:
raise ValueError(
f"The number of image tokens ({len(image_positions)}) does not "
f" match the number of image inputs ({num_images})."
f" match the number of image patches ({num_image_patches})."
)
text_segments = []
start_idx = 0
for position in image_positions:
text_segments.append(inputs_embeds[:, start_idx:position])
start_idx = position + 1
image_embeddings = mx.split(image_features, image_features.shape[0])
final_embeddings = [v for p in zip(text_segments, image_embeddings) for v in p]
final_embeddings += [inputs_embeds[:, start_idx:]]
# Create a final embedding of shape
# (1, num_image_patches*num_images + sequence_len, embed_dim)
return mx.concatenate(final_embeddings, axis=1)
inputs_embeds[0, image_positions] = image_features
return inputs_embeds
def __call__(self, input_ids: mx.array, pixel_values: mx.array, cache=None):
input_embddings = self.get_input_embeddings(input_ids, pixel_values)


@@ -16,10 +16,35 @@ conda install -c conda-forge mlx-lm
The `mlx-lm` package also has:
- [LoRA and QLoRA fine-tuning](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md)
- [LoRA, QLoRA, and full fine-tuning](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md)
- [Merging models](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/MERGE.md)
- [HTTP model serving](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/SERVER.md)
### Quick Start
To generate text with an LLM use:
```bash
mlx_lm.generate --prompt "Hi!"
```
To chat with an LLM use:
```bash
mlx_lm.chat
```
This will give you a chat REPL that you can use to interact with the LLM. The
chat context is preserved during the lifetime of the REPL.
Commands in `mlx-lm` typically take command line options which let you specify
the model, sampling parameters, and more. Use `-h` to see a list of available
options for a command, e.g.:
```bash
mlx_lm.generate -h
```
### Python API
You can use `mlx-lm` as a module:
@@ -29,7 +54,14 @@ from mlx_lm import load, generate
model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
response = generate(model, tokenizer, prompt="hello", verbose=True)
prompt = "Write a story about Einstein"
messages = [{"role": "user", "content": prompt}]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
text = generate(model, tokenizer, prompt=prompt, verbose=True)
```
To see a description of all the arguments you can do:
@@ -38,10 +70,14 @@ To see a description of all the arguments you can do:
>>> help(generate)
```
Check out the [generation
example](https://github.com/ml-explore/mlx-examples/tree/main/llms/mlx_lm/examples/generate_response.py)
to see how to use the API in more detail.
The `mlx-lm` package also comes with functionality to quantize and optionally
upload models to the Hugging Face Hub.
You can convert models in the Python API with:
You can convert models using the Python API:
```python
from mlx_lm import convert
@@ -64,8 +100,10 @@ To see a description of all the arguments you can do:
#### Streaming
For streaming generation, use the `stream_generate` function. This returns a
generator object which streams the output text. For example,
For streaming generation, use the `stream_generate` function. This yields
a generation response object.
For example,
```python
from mlx_lm import load, stream_generate
@@ -75,8 +113,13 @@ model, tokenizer = load(repo)
prompt = "Write a story about Einstein"
for t in stream_generate(model, tokenizer, prompt, max_tokens=512):
print(t, end="", flush=True)
messages = [{"role": "user", "content": prompt}]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
for response in stream_generate(model, tokenizer, prompt, max_tokens=512):
print(response.text, end="", flush=True)
print()
```
@@ -120,10 +163,54 @@ mlx_lm.convert \
--upload-repo mlx-community/my-4bit-mistral
```
Models can also be converted and quantized directly in the
[mlx-my-repo](https://huggingface.co/spaces/mlx-community/mlx-my-repo) Hugging
Face Space.
### Long Prompts and Generations
`mlx-lm` has some tools to scale efficiently to long prompts and generations:
- A rotating fixed-size key-value cache.
- Prompt caching
To use the rotating key-value cache pass the argument `--max-kv-size n` where
`n` can be any integer. Smaller values like `512` will use very little RAM but
result in worse quality. Larger values like `4096` or higher will use more RAM
but have better quality.
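In the Python API the same limit can be passed when building a prompt cache. A minimal sketch (assuming `make_prompt_cache` accepts a `max_kv_size` keyword, matching its use elsewhere in this repo):

```python
from mlx_lm import load, stream_generate
from mlx_lm.models.cache import make_prompt_cache

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

# Rotating fixed-size KV cache: keep roughly 512 tokens of context in memory.
prompt_cache = make_prompt_cache(model, max_kv_size=512)

for response in stream_generate(
    model, tokenizer, "Write a story about Einstein", prompt_cache=prompt_cache
):
    print(response.text, end="", flush=True)
```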
Caching prompts can substantially speed up reusing the same long context with
different queries. To cache a prompt use `mlx_lm.cache_prompt`. For example:
```bash
cat prompt.txt | mlx_lm.cache_prompt \
--model mistralai/Mistral-7B-Instruct-v0.3 \
--prompt - \
--prompt-cache-file mistral_prompt.safetensors
```
Then use the cached prompt with `mlx_lm.generate`:
```
mlx_lm.generate \
--prompt-cache-file mistral_prompt.safetensors \
--prompt "\nSummarize the above text."
```
The cached prompt is treated as a prefix to the supplied prompt. Also note
that when using a cached prompt, the model is read from the cache and need not
be supplied explicitly.
Prompt caching can also be used in the Python API to avoid
recomputing the prompt. This is useful in multi-turn dialogues or across
requests that use the same context. See the
[example](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/examples/chat.py)
for more usage details.
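For example, a minimal multi-turn sketch (the model repo is illustrative):

```python
from mlx_lm import load, stream_generate
from mlx_lm.models.cache import make_prompt_cache

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
prompt_cache = make_prompt_cache(model)

def ask(question):
    prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": question}],
        tokenize=False,
        add_generation_prompt=True,
    )
    # Reusing the same cache object lets later turns skip re-processing
    # the earlier context.
    for response in stream_generate(model, tokenizer, prompt, prompt_cache=prompt_cache):
        print(response.text, end="", flush=True)
    print()

ask("Summarize the theory of relativity in one sentence.")
ask("Now explain it to a five year old.")
```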
### Supported Models
The example supports Hugging Face format Mistral, Llama, and Phi-2 style
models. If the model you want to run is not supported, file an
`mlx-lm` supports thousands of Hugging Face format LLMs. If the model you want to
run is not supported, file an
[issue](https://github.com/ml-explore/mlx-examples/issues/new) or better yet,
submit a pull request.
@@ -140,6 +227,7 @@ Here are a few examples of Hugging Face models that work with this example:
- [pfnet/plamo-13b-instruct](https://huggingface.co/pfnet/plamo-13b-instruct)
- [stabilityai/stablelm-2-zephyr-1_6b](https://huggingface.co/stabilityai/stablelm-2-zephyr-1_6b)
- [internlm/internlm2-7b](https://huggingface.co/internlm/internlm2-7b)
- [tiiuae/falcon-mamba-7b-instruct](https://huggingface.co/tiiuae/falcon-mamba-7b-instruct)
Most
[Mistral](https://huggingface.co/models?library=transformers,safetensors&other=mistral&sort=trending),
@@ -167,3 +255,28 @@ model, tokenizer = load(
tokenizer_config={"eos_token": "<|endoftext|>", "trust_remote_code": True},
)
```
### Large Models
> [!NOTE]
> This requires macOS 15.0 or higher to work.

Models which are large relative to the total RAM available on the machine can
be slow. `mlx-lm` will attempt to make them faster by wiring the memory
occupied by the model and cache.
If you see the following warning message:
> [WARNING] Generating with a model that requires ...
then the model will likely be slow on the given machine. If the model fits in
RAM then it can often be sped up by increasing the system wired memory limit.
To increase the limit, set the following `sysctl`:
```bash
sudo sysctl iogpu.wired_limit_mb=N
```
The value `N` should be larger than the size of the model in megabytes but
smaller than the memory size of the machine.


@@ -57,6 +57,9 @@ mlx_lm.lora \
--iters 600
```
To fine-tune the full model weights, add the `--fine-tune-type full` flag.
Currently supported fine-tuning types are `lora` (default), `dora`, and `full`.
The `--data` argument must specify a path to a `train.jsonl`, `valid.jsonl`
when using `--train` and a path to a `test.jsonl` when using `--test`. For more
details on the data format see the section on [Data](#Data).
@@ -67,8 +70,8 @@ mistralai/Mistral-7B-v0.1`.
If `--model` points to a quantized model, then the training will use QLoRA,
otherwise it will use regular LoRA.
By default, the adapter config and weights are saved in `adapters/`. You can
specify the output location with `--adapter-path`.
By default, the adapter config and learned weights are saved in `adapters/`.
You can specify the output location with `--adapter-path`.
You can resume fine-tuning with an existing adapter with
`--resume-adapter-file <path_to_adapters.safetensors>`.
@@ -118,7 +121,7 @@ mlx_lm.fuse --model <path_to_model>
```
This will by default load the adapters from `adapters/`, and save the fused
model in the path `lora_fused_model/`. All of these are configurable.
model in the path `fused_model/`. All of these are configurable.
To upload a fused model, supply the `--upload-repo` and `--hf-path` arguments
to `mlx_lm.fuse`. The latter is the repo name of the original model, which is
@@ -141,7 +144,7 @@ mlx_lm.fuse \
--export-gguf
```
This will save the GGUF model in `lora_fused_model/ggml-model-f16.gguf`. You
This will save the GGUF model in `fused_model/ggml-model-f16.gguf`. You
can specify the file name with `--gguf-path`.
## Data
@@ -151,59 +154,146 @@ Examples GitHub repo has an [example of the WikiSQL
data](https://github.com/ml-explore/mlx-examples/tree/main/lora/data) in the
correct format.
Datasets can be specified in `*.jsonl` files locally or loaded from Hugging
Face.
### Local Datasets
For fine-tuning (`--train`), the data loader expects a `train.jsonl` and a
`valid.jsonl` to be in the data directory. For evaluation (`--test`), the data
loader expects a `test.jsonl` in the data directory.
loader expects a `test.jsonl` in the data directory.
Currently, `*.jsonl` files support three data formats: `chat`,
`completions`, and `text`. Here are three examples of these formats:
Currently, `*.jsonl` files support `chat`, `tools`, `completions`, and `text`
data formats. Here are examples of these formats:
`chat`:
```jsonl
{"messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Hello."}, {"role": "assistant", "content": "How can I assistant you today."}]}
```
`tools`:
```jsonl
{"messages":[{"role":"user","content":"What is the weather in San Francisco?"},{"role":"assistant","tool_calls":[{"id":"call_id","type":"function","function":{"name":"get_current_weather","arguments":"{\"location\": \"San Francisco, USA\", \"format\": \"celsius\"}"}}]}],"tools":[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","properties":{"location":{"type":"string","description":"The city and country, eg. San Francisco, USA"},"format":{"type":"string","enum":["celsius","fahrenheit"]}},"required":["location","format"]}}}]}
```
<details>
<summary>View the expanded single data tool format</summary>
```jsonl
{
"messages": [
{
"role": "system",
"content": "You are a helpful assistant."
},
{
"role": "user",
"content": "Hello."
},
{
"role": "assistant",
"content": "How can I assistant you today."
}
]
"messages": [
{ "role": "user", "content": "What is the weather in San Francisco?" },
{
"role": "assistant",
"tool_calls": [
{
"id": "call_id",
"type": "function",
"function": {
"name": "get_current_weather",
"arguments": "{\"location\": \"San Francisco, USA\", \"format\": \"celsius\"}"
}
}
]
}
],
"tools": [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and country, eg. San Francisco, USA"
},
"format": { "type": "string", "enum": ["celsius", "fahrenheit"] }
},
"required": ["location", "format"]
}
}
}
]
}
```
The format for the `arguments` field in a function varies for different models.
Common formats include JSON strings and dictionaries. The example provided
follows the format used by
[OpenAI](https://platform.openai.com/docs/guides/fine-tuning/fine-tuning-examples)
and [Mistral
AI](https://github.com/mistralai/mistral-finetune?tab=readme-ov-file#instruct).
A dictionary format is used in Hugging Face's [chat
templates](https://huggingface.co/docs/transformers/main/en/chat_templating#a-complete-tool-use-example).
Refer to the documentation for the model you are fine-tuning for more details.
</details>
`completions`:
```jsonl
{
"prompt": "What is the capital of France?",
"completion": "Paris."
}
{"prompt": "What is the capital of France?", "completion": "Paris."}
```
`text`:
```jsonl
{
"text": "This is an example for the model."
}
{"text": "This is an example for the model."}
```
Note that the format is automatically determined by the dataset. Also note that
keys in each line not expected by the loader will be ignored.
For the `chat` and `completions` formats, Hugging Face [chat
templates](https://huggingface.co/blog/chat-templates) are used. This applies
the model's chat template by default. If the model does not have a chat
template, then Hugging Face will use a default. For example, the final text in
the `chat` example above with Hugging Face's default template becomes:
> [!NOTE]
> Each example in the datasets must be on a single line. Do not put more than
> one example per line and do not split an example across multiple lines.
### Hugging Face Datasets
To use Hugging Face datasets, first install the `datasets` package:
```
pip install datasets
```
If the Hugging Face dataset is already in a supported format, you can specify
it on the command line. For example, pass `--data mlx-community/wikisql` to
train on the pre-formatted WikiSQL data.
Otherwise, provide a mapping of keys in the dataset to the features MLX LM
expects. Use a YAML config to specify the Hugging Face dataset arguments. For
example:
```
hf_dataset:
name: "billsum"
prompt_feature: "text"
completion_feature: "summary"
```
- Use `prompt_feature` and `completion_feature` to specify keys for a
`completions` dataset. Use `text_feature` to specify the key for a `text`
dataset.
- To specify the train, valid, or test splits, set the corresponding
`{train,valid,test}_split` argument.
- Arguments specified in `config` will be passed as keyword arguments to
[`datasets.load_dataset`](https://huggingface.co/docs/datasets/v2.20.0/en/package_reference/loading_methods#datasets.load_dataset).
In general, for the `chat`, `tools` and `completions` formats, Hugging Face
[chat
templates](https://huggingface.co/docs/transformers/main/en/chat_templating)
are used. This applies the model's chat template by default. If the model does
not have a chat template, then Hugging Face will use a default. For example,
the final text in the `chat` example above with Hugging Face's default template
becomes:
```text
<|im_start|>system
@@ -231,7 +321,7 @@ of memory. Here are some tips to reduce memory use should you need to do so:
setting this to `2` or `1` will reduce memory consumption. This may slow
things down a little, but will also reduce the memory use.
3. Reduce the number of layers to fine-tune with `--lora-layers`. The default
3. Reduce the number of layers to fine-tune with `--num-layers`. The default
is `16`, so you can try `8` or `4`. This reduces the amount of memory
needed for back propagation. It may also reduce the quality of the
fine-tuned model if you are fine-tuning with a lot of data.
@@ -253,7 +343,7 @@ mlx_lm.lora \
--model mistralai/Mistral-7B-v0.1 \
--train \
--batch-size 1 \
--lora-layers 4 \
--num-layers 4 \
--data wikisql
```
@@ -263,4 +353,5 @@ tokens-per-second, using the MLX Example
data set.
[^lora]: Refer to the [arXiv paper](https://arxiv.org/abs/2106.09685) for more details on LoRA.
[^qlora]: Refer to the paper [QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314)


@@ -17,7 +17,7 @@ mlx_lm.server --model <path_to_model_or_hf_repo>
For example:
```shell
mlx_lm.server --model mistralai/Mistral-7B-Instruct-v0.1
mlx_lm.server --model mlx-community/Mistral-7B-Instruct-v0.3-4bit
```
This will start a text generation server on port `8080` of the `localhost`
@@ -50,7 +50,7 @@ curl localhost:8080/v1/chat/completions \
- `role_mapping`: (Optional) A dictionary to customize the role prefixes in
the generated prompt. If not provided, the default mappings are used.
- `stop`: (Optional) An array of strings or a single string. Thesse are
- `stop`: (Optional) An array of strings or a single string. These are
sequences of tokens on which the generation should stop.
- `max_tokens`: (Optional) An integer specifying the maximum number of tokens
@@ -73,4 +73,59 @@ curl localhost:8080/v1/chat/completions \
applying repetition penalty. Defaults to `20`.
- `logit_bias`: (Optional) A dictionary mapping token IDs to their bias
values. Defaults to `None`.
values. Defaults to `None`.
- `logprobs`: (Optional) An integer specifying the number of top tokens and
corresponding log probabilities to return for each output in the generated
sequence. If set, this can be any value between 1 and 10, inclusive.
- `model`: (Optional) A string path to a local model or Hugging Face repo id.
If the path is local it must be relative to the directory the server was
started in.
- `adapters`: (Optional) A string path to low-rank adapters. The path must be
relative to the directory the server was started in.
### Response Fields
- `id`: A unique identifier for the chat.
- `system_fingerprint`: A unique identifier for the system.
- `object`: Any of "chat.completion", "chat.completion.chunk" (for
streaming), or "text.completion".
- `model`: The model repo or path (e.g. `"mlx-community/Llama-3.2-3B-Instruct-4bit"`).
- `created`: A time-stamp for when the request was processed.
- `choices`: A list of outputs. Each output is a dictionary containing the fields:
- `index`: The index in the list.
- `logprobs`: A dictionary containing the fields:
- `token_logprobs`: A list of the log probabilities for the generated
tokens.
- `tokens`: A list of the generated token ids.
- `top_logprobs`: A list of lists. Each list contains the `logprobs`
top tokens (if requested) with their corresponding probabilities.
- `finish_reason`: The reason the completion ended. This can be either of
`"stop"` or `"length"`.
- `message`: The text response from the model.
- `usage`: A dictionary containing the fields:
- `prompt_tokens`: The number of prompt tokens processed.
- `completion_tokens`: The number of tokens generated.
- `total_tokens`: The total number of tokens, i.e. the sum of the above two fields.
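For reference, a hedged Python sketch of issuing a request and reading these fields (assumes the server above is running on `localhost:8080`; the `requests` package is not a dependency of this example):

```python
import requests

resp = requests.post(
    "http://localhost:8080/v1/chat/completions",
    json={
        "messages": [{"role": "user", "content": "Hello"}],
        "max_tokens": 64,
    },
).json()

choice = resp["choices"][0]
print(choice["message"])        # the text response from the model
print(choice["finish_reason"])  # "stop" or "length"
print(resp["usage"]["total_tokens"])
```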
### List Models
Use the `v1/models` endpoint to list available models:
```shell
curl localhost:8080/v1/models -H "Content-Type: application/json"
```
This will return a list of locally available models where each model in the
list contains the following fields:
- `id`: The Hugging Face repo id.
- `created`: A time-stamp representing the model creation time.


@@ -1,4 +1,9 @@
# Copyright © 2023-2024 Apple Inc.
import os
from ._version import __version__
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
from .utils import convert, generate, load, stream_generate
from .version import __version__


@@ -1,3 +1,3 @@
# Copyright © 2023-2024 Apple Inc.
__version__ = "0.14.2"
__version__ = "0.20.4"

llms/mlx_lm/cache_prompt.py

@@ -0,0 +1,173 @@
# Copyright © 2024 Apple Inc.
import argparse
import json
import sys
import time
import mlx.core as mx
from .models.cache import make_prompt_cache, save_prompt_cache
from .utils import generate_step, load
DEFAULT_QUANTIZED_KV_START = 5000
def setup_arg_parser():
"""Set up and return the argument parser."""
parser = argparse.ArgumentParser(
description="Cache the state of a prompt to be reused with mlx_lm.generate"
)
parser.add_argument(
"--model",
type=str,
default="mlx_model",
help="The path to the local model directory or Hugging Face repo.",
)
parser.add_argument(
"--adapter-path",
type=str,
help="Optional path for the trained adapter weights and config.",
)
parser.add_argument(
"--trust-remote-code",
action="store_true",
help="Enable trusting remote code for tokenizer",
)
parser.add_argument(
"--eos-token",
type=str,
default=None,
help="End of sequence token for tokenizer",
)
parser.add_argument(
"--ignore-chat-template",
action="store_true",
help="Use the raw prompt without the tokenizer's chat template.",
)
parser.add_argument(
"--use-default-chat-template",
action="store_true",
help="Use the default chat template",
)
parser.add_argument(
"--max-kv-size",
type=int,
default=None,
help="Set the maximum key-value cache size",
)
parser.add_argument(
"--prompt-cache-file",
help="The file to save the prompt cache in",
required=True,
)
parser.add_argument(
"--prompt",
required=True,
help="Message to be processed by the model ('-' reads from stdin)",
)
parser.add_argument(
"--kv-bits",
type=int,
help="Number of bits for KV cache quantization. "
"Defaults to no quantization.",
default=None,
)
parser.add_argument(
"--kv-group-size",
type=int,
help="Group size for KV cache quantization.",
default=64,
)
parser.add_argument(
"--quantized-kv-start",
help="When --kv-bits is set, start quantizing the KV cache "
"from this step onwards.",
type=int,
default=DEFAULT_QUANTIZED_KV_START,
)
return parser
def main():
parser = setup_arg_parser()
args = parser.parse_args()
# Building tokenizer_config
tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}
if args.eos_token is not None:
tokenizer_config["eos_token"] = args.eos_token
model, tokenizer = load(
args.model,
adapter_path=args.adapter_path,
tokenizer_config=tokenizer_config,
)
args.prompt = sys.stdin.read() if args.prompt == "-" else args.prompt
if args.use_default_chat_template:
if tokenizer.chat_template is None:
tokenizer.chat_template = tokenizer.default_chat_template
if not args.ignore_chat_template and (
hasattr(tokenizer, "apply_chat_template")
and tokenizer.chat_template is not None
):
messages = [{"role": "user", "content": args.prompt}]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
# Treat the prompt as a prefix assuming that the suffix will be
# provided at generation time.
test_prompt = tokenizer.apply_chat_template(
[{"role": "user", "content": "<query>"}],
tokenize=False,
add_generation_prompt=True,
)
n = len(test_prompt) - test_prompt.index("<query>") - len("<query>")
prompt = prompt[:-n]
else:
prompt = args.prompt
cache = make_prompt_cache(model, args.max_kv_size)
y = mx.array(tokenizer.encode(prompt))
# Process the prompt
start = time.time()
max_msg_len = 0
def callback(processed, total_tokens):
current = time.time()
speed = processed / (current - start)
msg = f"\rProcessed {processed:6d} tokens ({speed:6.2f} tok/s)"
nonlocal max_msg_len
max_msg_len = max(max_msg_len, len(msg))
print(msg + " " * (max_msg_len - len(msg)), end="", flush=True)
for _ in generate_step(
y,
model,
max_tokens=0,
prompt_cache=cache,
kv_bits=args.kv_bits,
kv_group_size=args.kv_group_size,
quantized_kv_start=args.quantized_kv_start,
prompt_progress_callback=callback,
):
pass
print()
print(f"Peak memory: {mx.metal.get_peak_memory() / 1e9:.3f} GB")
print("Saving...")
metadata = {}
metadata["model"] = args.model
metadata["chat_template"] = tokenizer.chat_template
metadata["tokenizer_config"] = json.dumps(tokenizer_config)
save_prompt_cache(args.prompt_cache_file, cache, metadata)
if __name__ == "__main__":
main()
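Usage sketch (the model name and file names are illustrative, and the usual mlx_lm console entry points are assumed): cache a long prompt once with this script, then reuse the saved KV cache from mlx_lm.generate via --prompt-cache-file; generate can recover the model name from the cache metadata.
cat long_context.txt | mlx_lm.cache_prompt \
    --model mlx-community/Llama-3.2-3B-Instruct-4bit \
    --prompt - \
    --prompt-cache-file context_cache.safetensors
mlx_lm.generate \
    --prompt-cache-file context_cache.safetensors \
    --prompt "Summarize the document above."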

llms/mlx_lm/chat.py (new file, 91 lines)
@@ -0,0 +1,91 @@
# Copyright © 2023-2024 Apple Inc.
import argparse
import json
import mlx.core as mx
from .models.cache import make_prompt_cache
from .sample_utils import make_sampler
from .utils import load, stream_generate
DEFAULT_TEMP = 0.0
DEFAULT_TOP_P = 1.0
DEFAULT_SEED = 0
DEFAULT_MAX_TOKENS = 256
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
def setup_arg_parser():
"""Set up and return the argument parser."""
parser = argparse.ArgumentParser(description="Chat with an LLM")
parser.add_argument(
"--model",
type=str,
help="The path to the local model directory or Hugging Face repo.",
default=DEFAULT_MODEL,
)
parser.add_argument(
"--adapter-path",
type=str,
help="Optional path for the trained adapter weights and config.",
)
parser.add_argument(
"--temp", type=float, default=DEFAULT_TEMP, help="Sampling temperature"
)
parser.add_argument(
"--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p"
)
parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed")
parser.add_argument(
"--max-kv-size",
type=int,
help="Set the maximum key-value cache size",
default=None,
)
parser.add_argument(
"--max-tokens",
"-m",
type=int,
default=DEFAULT_MAX_TOKENS,
help="Maximum number of tokens to generate",
)
return parser
def main():
parser = setup_arg_parser()
args = parser.parse_args()
mx.random.seed(args.seed)
model, tokenizer = load(
args.model,
adapter_path=args.adapter_path,
tokenizer_config={"trust_remote_code": True},
)
print(f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.")
prompt_cache = make_prompt_cache(model, args.max_kv_size)
while True:
query = input(">> ")
if query == "q":
break
messages = [{"role": "user", "content": query}]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
for response in stream_generate(
model,
tokenizer,
prompt,
max_tokens=args.max_tokens,
sampler=make_sampler(args.temp, args.top_p),
prompt_cache=prompt_cache,
):
print(response.text, flush=True, end="")
print()
if __name__ == "__main__":
main()
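A minimal invocation sketch (entry point name assumed; flags as defined above, and omitting --model falls back to the default model):
mlx_lm.chat --max-tokens 512 --temp 0.7 --max-kv-size 4096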


@@ -29,9 +29,15 @@ def configure_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--q-bits", help="Bits per weight for quantization.", type=int, default=4
)
parser.add_argument(
"--q-type",
choices=["affine", "affine-packed"],
default="affine",
help="The type of quantization to apply",
)
parser.add_argument(
"--dtype",
help="Type to save the parameters, ignored if -q is given.",
help="Type to save the non-quantized parameters.",
type=str,
choices=["float16", "bfloat16", "float32"],
default="float16",

llms/mlx_lm/evaluate.py (new file, 355 lines)
@@ -0,0 +1,355 @@
# Adapted from a PyTorch implementation by David Grangier
import argparse
import json
import logging
import os
from importlib.metadata import version
from pathlib import Path
from typing import Optional
import lm_eval
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
from tqdm import tqdm
from .models.cache import make_prompt_cache
from .utils import load, stream_generate
PAD = 0
def _len_longest_common_prefix(a, b):
l = 0
for item_a, item_b in zip(a, b):
if item_a != item_b:
break
l += 1
return l
def _rstrip_until(s, untils):
"""Limit a string <s> to the first occurrence of any substring in untils."""
l = len(s)
f = [s.find(u) for u in untils]
f = [l if x < 0 else x for x in f]
return s[: min(f)]
def _pad_inputs(
inputs,
maxlen,
genlen=0,
pad_left=False,
pad_multiple=32,
truncate=False,
):
# pad the prompts to the left with at least genlen tokens.
actual_maxlen = max(len(p) for p in inputs) + genlen
if actual_maxlen > maxlen:
if not truncate:
raise ValueError("Inputs are too long.")
else:  # drop the beginning
actual_maxlen = maxlen
inputs = [p[max(0, len(p) - maxlen) :] for p in inputs]
if pad_multiple > 0:
maxlen = (actual_maxlen + pad_multiple - 1) // pad_multiple
maxlen *= pad_multiple
assert PAD == 0
lr = np.array((1, 0) if pad_left else (0, 1))
return np.stack(
[np.pad(np.array(x, np.int32), lr * (maxlen - len(x))) for x in inputs],
axis=0,
)
@register_model("mlxlm")
class MLXLM(LM):
def __init__(
self,
path_or_hf_repo: str,
batch_size: int = 16,
max_tokens: Optional[int] = None,
) -> None:
super().__init__()
self._batch_size = batch_size
self._model, self._tokenizer = load(path_or_hf_repo)
self._max_tokens = max_tokens or self._tokenizer.model_max_length
def _score_fn(self, inputs, tokenize=True, step_size=32):
if tokenize:
inputs = self._tokenizer.encode(inputs)
inputs = _pad_inputs(inputs, self._max_tokens, truncate=False)
inputs = mx.array(inputs)
inputs, targets = inputs[..., :-1], inputs[..., 1:]
cache = make_prompt_cache(self._model)
mask = targets != PAD
scores, is_greedy = [], []
for i in range(0, inputs.shape[1], step_size):
logits = self._model(inputs[:, i : i + step_size], cache=cache)
log_probs = nn.log_softmax(logits.astype(mx.float32))
score = mx.take_along_axis(
log_probs, targets[:, i : i + step_size, mx.newaxis], axis=-1
)[..., 0]
ig = mask[:, i : i + step_size] * (
targets[:, i : i + step_size] == mx.argmax(logits, axis=-1)
)
mx.eval(score, ig)
mx.metal.clear_cache()
is_greedy.append(ig)
scores.append(score)
scores = mx.concatenate(scores, axis=1)
is_greedy = mx.concatenate(is_greedy, axis=1)
return scores, mask.sum(axis=-1), is_greedy
def _loglikelihood(self, texts, score_spans=None, tokenize=True):
# sort by length to get batches with little padding.
sorted_indices = sorted(range(len(texts)), key=lambda i: -len(texts[i]))
sorted_inputs = [texts[sorted_indices[i]] for i in range(len(texts))]
sorted_spans = None
if score_spans is not None:
sorted_spans = [score_spans[sorted_indices[i]] for i in range(len(texts))]
results = []
for i in tqdm(range(0, len(sorted_inputs), self._batch_size)):
batch = sorted_inputs[i : i + self._batch_size]
scores, length, is_greedy = self._score_fn(batch, tokenize=tokenize)
for j in range(len(batch)):
if sorted_spans is None: # full sequence score
mask = mx.arange(scores[j].shape[-1]) < length
score = (scores[j].astype(mx.float32) * mask).sum(axis=-1)
ig = (is_greedy[j].astype(mx.int32) * mask).sum(axis=-1)
else: # subsequence score
start, end = sorted_spans[i + j]
score = scores[j][start:end].astype(mx.float32).sum()
ig = is_greedy[j][start:end].astype(mx.int32).sum()
length = end - start
results.append((score.item(), ig.item(), length))
# reorder the outputs
inv_sort = np.argsort(sorted_indices)
results = [results[inv_sort[i]] for i in range(len(results))]
return results
def _tokenize(self, texts):
return [tuple(self._tokenizer.encode(t)) for t in texts]
def loglikelihood(self, requests) -> list[tuple[float, bool]]:
"""Compute log-likelihood of generating a continuation from a context.
Downstream tasks should attempt to use loglikelihood instead of other
LM calls whenever possible.
:param requests: list[Instance]
A list of Instance objects, with property `args` which returns a tuple (context, continuation).
`context: str`
Context string. Implementations of LM must be able to handle an
empty context string.
`continuation: str`
The continuation over which log likelihood will be calculated. If
there is a word boundary, the space should be in the continuation.
For example, context="hello" continuation=" world" is correct.
:return: list[tuple[float, bool]]
A list of pairs (logprob, isgreedy)
`logprob: float`
The log probability of `continuation`.
`isgreedy`:
Whether `continuation` would be generated by greedy sampling from `context`.
"""
logging.info("Estimating loglikelihood for %d pairs." % len(requests))
# tokenize prefix and prefix + completion for all requests.
tokenized = self._tokenize(
[t for r in requests for t in [r.args[0], r.args[0] + r.args[1]]]
)
# max length (prefix + completion) and longest common prefix per question.
length_stats = {}
for prefix, completed in zip(tokenized[0::2], tokenized[1::2]):
max_completed_l, min_prefix_l = length_stats.get(prefix, (0, 1e8))
length_stats[prefix] = (
max(max_completed_l, len(completed)),
min(min_prefix_l, _len_longest_common_prefix(prefix, completed)),
)
# truncate requests for completed sequences longer than model context.
shortened = []
completion_spans = []
long_completions = 0
for prefix, completed in zip(tokenized[0::2], tokenized[1::2]):
max_completed_l, prefix_l = length_stats[prefix]
# compute truncation length
truncation = max(0, max_completed_l - self._max_tokens - 1)
prefix_l = prefix_l - truncation
if prefix_l <= 0:
# completion too long, prefix is eliminated for some requests.
long_completions += 1
truncation = max(0, len(completed) - self._max_tokens - 1)
prefix_l = 1
# truncate the completed sequence
completed = completed[truncation:]
shortened.append(completed)
# scores do not include the initial bos, subtract 1 from span bounds
completion_spans.append((prefix_l - 1, len(completed) - 1))
if long_completions > 0:
logging.info(
f"Prefix eliminated for {long_completions} requests with "
+ "completion longer than context."
)
# model scoring, returns num_requests x (logp, is_greedy, length).
results = self._loglikelihood(
shortened,
score_spans=completion_spans,
tokenize=False,
)
return [(r[0], r[1] == r[2]) for r in results]
def loglikelihood_rolling(self, requests) -> list[float]:
"""Compute full log-likelihood of a string, with no truncation, for perplexity computation
- We will use the full max context length of the model.
- For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
the max context length.
- IMPORTANT: Each document's loglikelihood/perplexity is computed *separately*, unlike other implementations
which may simply concatenate multiple documents together.
- IMPORTANT: We maximize the amount of context for each prediction. Specifically, for inputs that we break into
multiple chunks, the last input will still have a full-sized context.
Example:
Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ]
Prefix: EOT
Max context length: 4
Resulting input/prediction pairs:
INPUT: EOT 0 1 2
PRED: 0 1 2 3
INPUT: 3 4 5 6
PRED: 4 5 6 7
INPUT: 5 6 7 8
PRED: 8 9
Observe that:
1. Each token is predicted exactly once
2. For the last pair, we provide the full context, but only score the last two tokens
:param requests: list[Instance]
A list of Instance objects with property `args` which returns a tuple (context,).
string: str
String for which we are computing overall loglikelihood
:return: list[tuple[float]]
A list of tuples (logprob,)
logprob: float
The log probability of `context` conditioned on the EOT token.
"""
logging.info(
"Estimating loglikelihood rolling for %d sequences." % len(requests)
)
inputs = [req.args[0] for req in requests]
return [t[0] for t in self._loglikelihood(inputs)]
def generate_until(self, requests) -> list[str]:
"""Generate greedily until a stopping sequence
:param requests: list[Instance]
A list of Instance objects with property `args` which returns a tuple (context, until).
context: str
Context string
until: [str]
The string sequences to generate until. These string sequences
may each span across multiple tokens, or may be part of one token.
:return: list[str]
A list of strings continuation
continuation: str
The generated continuation.
"""
logging.info("Generating continuation for %d sequences." % len(requests))
contexts, options = zip(*[req.args for req in requests])
# contrary to the doc the second element of the tuple contains
# {'do_sample': False, 'until': ['\n\n'], 'temperature': 0}
keys = list(options[0].keys())
assert "until" in keys
untils = [x["until"] for x in options]
completions = []
for context, until in tqdm(zip(contexts, untils), total=len(contexts)):
if (
hasattr(self._tokenizer, "apply_chat_template")
and self._tokenizer.chat_template is not None
):
messages = [{"role": "user", "content": context}]
context = self._tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
max_tokens = min(
self._max_tokens,
self._tokenizer.model_max_length - len(self._tokenizer.encode(context)),
)
text = ""
for response in stream_generate(
self._model, self._tokenizer, prompt=context, max_tokens=max_tokens
):
text += response.text
if any(u in text for u in until):
text = _rstrip_until(text, until)
completions.append(text)
break
else:
completions.append(text)
return completions
def main():
parser = argparse.ArgumentParser(
"Evaluate an MLX model using lm-evaluation-harness."
)
parser.add_argument("--model", help="Model to evaluate", required=True)
parser.add_argument("--tasks", nargs="+", required=True)
parser.add_argument(
"--output-dir", default=".", help="Output directory for result files."
)
parser.add_argument("--batch-size", type=int, default=16, help="Batch size")
parser.add_argument("--num-shots", type=int, default=0, help="Number of shots")
parser.add_argument(
"--max-tokens",
type=int,
help="Maximum number of tokens to generate. Defaults to the model's max context length.",
)
parser.add_argument("--seed", type=int, default=123, help="Random seed.")
args = parser.parse_args()
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
# Silence tokenizer warnings
os.environ["TOKENIZERS_PARALLELISM"] = "false"
mx.random.seed(args.seed)
lm = MLXLM(args.model, batch_size=args.batch_size, max_tokens=args.max_tokens)
results = lm_eval.simple_evaluate(
model=lm,
tasks=args.tasks,
num_fewshot=args.num_shots,
random_seed=args.seed,
numpy_random_seed=args.seed,
torch_random_seed=args.seed,
fewshot_random_seed=args.seed,
)
model_name = args.model.replace("/", "_")
task_names = "_".join(args.tasks)
ver = version("lm_eval")
filename = f"eval_{model_name}_{task_names}_{args.num_shots:02d}_v_{ver}.json"
output_path = output_dir / filename
output_path.write_text(json.dumps(results["results"], indent=4))
print("Results:")
for result in results["results"].values():
print(json.dumps(result, indent=4))
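A usage sketch (entry point name assumed; the task names are ordinary lm-evaluation-harness tasks, not taken from this diff):
mlx_lm.evaluate \
    --model mlx-community/Llama-3.2-3B-Instruct-4bit \
    --tasks arc_easy hellaswag \
    --num-shots 5 \
    --batch-size 8 \
    --output-dir results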


@@ -0,0 +1,52 @@
# Copyright © 2024 Apple Inc.
"""
An example of a multi-turn chat with prompt caching.
"""
from mlx_lm import generate, load
from mlx_lm.models.cache import load_prompt_cache, make_prompt_cache, save_prompt_cache
model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
# Make the initial prompt cache for the model
prompt_cache = make_prompt_cache(model)
# User turn
prompt = "Hi my name is <Name>."
messages = [{"role": "user", "content": prompt}]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
# Assistant response
response = generate(
model,
tokenizer,
prompt=prompt,
verbose=True,
temp=0.0,
prompt_cache=prompt_cache,
)
# User turn
prompt = "What's my name?"
messages = [{"role": "user", "content": prompt}]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
# Assistant response
response = generate(
model,
tokenizer,
prompt=prompt,
verbose=True,
prompt_cache=prompt_cache,
)
# Save the prompt cache to disk to reuse it at a later time
save_prompt_cache("mistral_prompt.safetensors", prompt_cache)
# Load the prompt cache from disk
prompt_cache = load_prompt_cache("mistral_prompt.safetensors")


@@ -0,0 +1,33 @@
# Copyright © 2024 Apple Inc.
from mlx_lm import generate, load
# Specify the checkpoint
checkpoint = "mistralai/Mistral-7B-Instruct-v0.3"
# Load the corresponding model and tokenizer
model, tokenizer = load(path_or_hf_repo=checkpoint)
# Specify the prompt and conversation history
prompt = "Why is the sky blue?"
conversation = [{"role": "user", "content": prompt}]
# Transform the prompt into the chat template
prompt = tokenizer.apply_chat_template(
conversation=conversation, tokenize=False, add_generation_prompt=True
)
# Specify the maximum number of tokens
max_tokens = 1_000
# Specify if tokens and timing information will be printed
verbose = True
# Generate a response with the specified settings
response = generate(
model=model,
tokenizer=tokenizer,
prompt=prompt,
max_tokens=max_tokens,
verbose=verbose,
)


@@ -1,8 +1,12 @@
# The path to the local model directory or Hugging Face repo.
model: "mlx_model"
# Whether or not to train (boolean)
train: true
# The fine-tuning method: "lora", "dora", or "full".
fine_tune_type: lora
# Directory with {train, valid, test}.jsonl files
data: "/path/to/training/data"
@@ -10,7 +14,7 @@ data: "/path/to/training/data"
seed: 0
# Number of layers to fine-tune
lora_layers: 16
num_layers: 16
# Minibatch size.
batch_size: 4
@@ -51,9 +55,6 @@ max_seq_length: 2048
# Use gradient checkpointing to reduce memory use.
grad_checkpoint: false
# Use DoRA instead of LoRA.
use_dora: false
# LoRA parameters can only be specified in a config file
lora_parameters:
# The layer keys to apply LoRA to.
@@ -69,3 +70,11 @@ lora_parameters:
# warmup: 100 # 0 for no warmup
# warmup_init: 1e-7 # 0 if not specified
# arguments: [1e-5, 1000, 1e-7] # passed to scheduler
#hf_dataset:
# name: "billsum"
# train_split: "train[:1000]"
# valid_split: "train[-100:]"
# prompt_feature: "text"
# completion_feature: "summary"


@@ -6,9 +6,9 @@ from pathlib import Path
from mlx.utils import tree_flatten, tree_unflatten
from .gguf import convert_to_gguf
from .tuner.dora import DoRALinear
from .tuner.lora import LoRALinear, LoRASwitchLinear
from .tuner.utils import apply_lora_layers, dequantize
from .tuner.dora import DoRAEmbedding, DoRALinear
from .tuner.lora import LoRAEmbedding, LoRALinear, LoRASwitchLinear
from .tuner.utils import dequantize, load_adapters
from .utils import (
fetch_from_hub,
get_model_path,
@@ -29,7 +29,7 @@ def parse_arguments() -> argparse.Namespace:
)
parser.add_argument(
"--save-path",
default="lora_fused_model",
default="fused_model",
help="The path to save the fused model.",
)
parser.add_argument(
@@ -77,15 +77,14 @@ def main() -> None:
model, config, tokenizer = fetch_from_hub(model_path)
model.freeze()
model = apply_lora_layers(model, args.adapter_path)
model = load_adapters(model, args.adapter_path)
fused_linears = [
(n, m.to_linear())
for n, m in model.named_modules()
if isinstance(m, (LoRASwitchLinear, LoRALinear, DoRALinear))
(n, m.fuse()) for n, m in model.named_modules() if hasattr(m, "fuse")
]
model.update_modules(tree_unflatten(fused_linears))
if fused_linears:
model.update_modules(tree_unflatten(fused_linears))
if args.de_quantize:
print("De-quantizing model")
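A hedged sketch of fusing trained adapters into the base model with the updated flags (entry point name and paths are placeholders):
mlx_lm.fuse \
    --model mlx-community/Llama-3.2-3B-Instruct-4bit \
    --adapter-path adapters \
    --save-path fused_model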


@@ -1,17 +1,29 @@
# Copyright © 2023-2024 Apple Inc.
import argparse
import codecs
import json
import sys
import mlx.core as mx
from .models.cache import QuantizedKVCache, load_prompt_cache
from .sample_utils import make_sampler
from .utils import generate, load
DEFAULT_MODEL_PATH = "mlx_model"
DEFAULT_PROMPT = "hello"
DEFAULT_MAX_TOKENS = 100
DEFAULT_TEMP = 0.6
DEFAULT_TEMP = 0.0
DEFAULT_TOP_P = 1.0
DEFAULT_MIN_P = 0.0
DEFAULT_MIN_TOKENS_TO_KEEP = 1
DEFAULT_SEED = 0
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
DEFAULT_QUANTIZED_KV_START = 5000
def str2bool(string):
return string.lower() not in ["false", "f"]
def setup_arg_parser():
@@ -20,19 +32,17 @@ def setup_arg_parser():
parser.add_argument(
"--model",
type=str,
default="mlx_model",
help="The path to the local model directory or Hugging Face repo.",
help=(
"The path to the local model directory or Hugging Face repo. "
f"If no model is specified, then {DEFAULT_MODEL} is used."
),
default=None,
)
parser.add_argument(
"--adapter-path",
type=str,
help="Optional path for the trained adapter weights and config.",
)
parser.add_argument(
"--trust-remote-code",
action="store_true",
help="Enable trusting remote code for tokenizer",
)
parser.add_argument(
"--eos-token",
type=str,
@@ -40,7 +50,15 @@ def setup_arg_parser():
help="End of sequence token for tokenizer",
)
parser.add_argument(
"--prompt", default=DEFAULT_PROMPT, help="Message to be processed by the model"
"--system-prompt",
default=None,
help="System prompt to be used for the chat template",
)
parser.add_argument(
"--prompt",
"-p",
default=DEFAULT_PROMPT,
help="Message to be processed by the model ('-' reads from stdin)",
)
parser.add_argument(
"--max-tokens",
@@ -55,6 +73,15 @@ def setup_arg_parser():
parser.add_argument(
"--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p"
)
parser.add_argument(
"--min-p", type=float, default=DEFAULT_MIN_P, help="Sampling min-p"
)
parser.add_argument(
"--min-tokens-to-keep",
type=int,
default=DEFAULT_MIN_TOKENS_TO_KEEP,
help="Minimum tokens to keep for min-p sampling.",
)
parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed")
parser.add_argument(
"--ignore-chat-template",
@@ -67,63 +94,91 @@ def setup_arg_parser():
help="Use the default chat template",
)
parser.add_argument(
"--colorize",
action="store_true",
help="Colorize output based on T[0] probability",
"--verbose",
type=str2bool,
default=True,
help="Log verbose output when 'True' or 'T' or only print the response when 'False' or 'F'",
)
parser.add_argument(
"--cache-limit-gb",
"--max-kv-size",
type=int,
help="Set the maximum key-value cache size",
default=None,
help="Set the MLX cache limit in GB",
required=False,
)
parser.add_argument(
"--prompt-cache-file",
type=str,
default=None,
help="A file containing saved KV caches to avoid recomputing them",
)
parser.add_argument(
"--kv-bits",
type=int,
help="Number of bits for KV cache quantization. "
"Defaults to no quantization.",
default=None,
)
parser.add_argument(
"--kv-group-size",
type=int,
help="Group size for KV cache quantization.",
default=64,
)
parser.add_argument(
"--quantized-kv-start",
help="When --kv-bits is set, start quantizing the KV cache "
"from this step onwards.",
type=int,
default=DEFAULT_QUANTIZED_KV_START,
)
return parser
def colorprint(color, s):
color_codes = {
"black": 30,
"red": 31,
"green": 32,
"yellow": 33,
"blue": 34,
"magenta": 35,
"cyan": 36,
"white": 39,
}
ccode = color_codes.get(color, 30)
print(f"\033[1m\033[{ccode}m{s}\033[0m", end="", flush=True)
def colorprint_by_t0(s, t0):
if t0 > 0.95:
color = "white"
elif t0 > 0.70:
color = "green"
elif t0 > 0.30:
color = "yellow"
else:
color = "red"
colorprint(color, s)
def main():
parser = setup_arg_parser()
args = parser.parse_args()
mx.random.seed(args.seed)
if args.cache_limit_gb is not None:
mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024)
# Load the prompt cache and metadata if a cache file is provided
using_cache = args.prompt_cache_file is not None
if using_cache:
prompt_cache, metadata = load_prompt_cache(
args.prompt_cache_file,
return_metadata=True,
)
if isinstance(prompt_cache[0], QuantizedKVCache):
if args.kv_bits is not None and args.kv_bits != prompt_cache[0].bits:
raise ValueError(
"--kv-bits does not match the kv cache loaded from --prompt-cache-file."
)
if args.kv_group_size != prompt_cache[0].group_size:
raise ValueError(
"--kv-group-size does not match the kv cache loaded from --prompt-cache-file."
)
# Building tokenizer_config
tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}
tokenizer_config = (
{} if not using_cache else json.loads(metadata["tokenizer_config"])
)
tokenizer_config["trust_remote_code"] = True
if args.eos_token is not None:
tokenizer_config["eos_token"] = args.eos_token
model_path = args.model
if using_cache:
if model_path is None:
model_path = metadata["model"]
elif model_path != metadata["model"]:
raise ValueError(
f"Providing a different model ({model_path}) than that "
f"used to create the prompt cache ({metadata['model']}) "
"is an error."
)
model_path = model_path or DEFAULT_MODEL
model, tokenizer = load(
args.model,
model_path,
adapter_path=args.adapter_path,
tokenizer_config=tokenizer_config,
)
@@ -131,30 +186,56 @@ def main():
if args.use_default_chat_template:
if tokenizer.chat_template is None:
tokenizer.chat_template = tokenizer.default_chat_template
elif using_cache:
tokenizer.chat_template = metadata["chat_template"]
prompt = codecs.decode(args.prompt, "unicode_escape")
if not args.ignore_chat_template and (
hasattr(tokenizer, "apply_chat_template")
and tokenizer.chat_template is not None
):
messages = [{"role": "user", "content": args.prompt}]
if args.system_prompt is not None:
messages = [{"role": "system", "content": args.system_prompt}]
else:
messages = []
messages.append(
{
"role": "user",
"content": sys.stdin.read() if prompt == "-" else prompt,
}
)
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
else:
prompt = args.prompt
formatter = colorprint_by_t0 if args.colorize else None
# Treat the prompt as a suffix assuming that the prefix is in the
# stored kv cache.
if using_cache:
messages[-1]["content"] = "<query>"
test_prompt = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
prompt = prompt[test_prompt.index("<query>") :]
generate(
sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep)
response = generate(
model,
tokenizer,
prompt,
args.max_tokens,
verbose=True,
formatter=formatter,
temp=args.temp,
top_p=args.top_p,
max_tokens=args.max_tokens,
verbose=args.verbose,
sampler=sampler,
max_kv_size=args.max_kv_size,
prompt_cache=prompt_cache if using_cache else None,
kv_bits=args.kv_bits,
kv_group_size=args.kv_group_size,
quantized_kv_start=args.quantized_kv_start,
)
if not args.verbose:
print(response)
if __name__ == "__main__":
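A condensed, standalone sketch of the sentinel trick used above and in cache_prompt.py (the function name and structure are my own; the scripts inline this logic): render the chat template once with a "<query>" placeholder to measure the template suffix, cache everything before the suffix, and feed only the suffix at generation time.
def split_templated_prompt(tokenizer, messages):
    # Full templated prompt for the real messages.
    full = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # Probe prompt with a sentinel user message to locate the template suffix.
    probe = tokenizer.apply_chat_template(
        [{"role": "user", "content": "<query>"}],
        tokenize=False,
        add_generation_prompt=True,
    )
    suffix_len = len(probe) - probe.index("<query>") - len("<query>")
    prefix = full[: len(full) - suffix_len]  # what cache_prompt.py stores in the KV cache
    suffix = full[len(prefix) :]  # what generate.py runs on top of the cache
    return prefix, suffix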


@@ -59,7 +59,7 @@ class HfVocab:
for token_id in range(self.vocab_size_base):
if token_id in self.added_tokens_ids:
continue
token_text = reverse_vocab[token_id].encode("utf-8")
token_text = reverse_vocab[token_id]
yield token_text, self.get_token_score(token_id), self.get_token_type(
token_id, token_text, self.special_ids
)
@@ -67,7 +67,7 @@ class HfVocab:
def get_token_type(
self, token_id: int, token_text: bytes, special_ids: Set[int]
) -> TokenType:
if re.fullmatch(rb"<0x[0-9A-Fa-f]{2}>", token_text):
if re.fullmatch(r"<0x[0-9A-Fa-f]{2}>", token_text):
return TokenType.BYTE
return TokenType.CONTROL if token_id in special_ids else TokenType.NORMAL
@@ -77,14 +77,12 @@ class HfVocab:
def added_tokens(self) -> Iterable[Tuple[bytes, float, TokenType]]:
for text in self.added_tokens_list:
if text in self.specials:
toktype = self.get_token_type(
self.specials[text], b"", self.special_ids
)
toktype = self.get_token_type(self.specials[text], "", self.special_ids)
score = self.get_token_score(self.specials[text])
else:
toktype = TokenType.USER_DEFINED
score = -1000.0
yield text.encode("utf-8"), score, toktype
yield text, score, toktype
def has_newline_token(self):
return "<0x0A>" in self.tokenizer.vocab or "\n" in self.tokenizer.vocab
@@ -243,15 +241,18 @@ def prepare_metadata(config, vocab):
metadata["tokenizer.ggml.tokens"] = tokens
metadata["tokenizer.ggml.scores"] = mx.array(scores, dtype=mx.float32)
metadata["tokenizer.ggml.token_type"] = mx.array(toktypes, dtype=mx.uint32)
metadata["tokenizer.ggml.bos_token_id"] = mx.array(
vocab.tokenizer.bos_token_id, dtype=mx.uint32
)
metadata["tokenizer.ggml.eos_token_id"] = mx.array(
vocab.tokenizer.eos_token_id, dtype=mx.uint32
)
metadata["tokenizer.ggml.unknown_token_id"] = mx.array(
vocab.tokenizer.unk_token_id, dtype=mx.uint32
)
if vocab.tokenizer.bos_token_id is not None:
metadata["tokenizer.ggml.bos_token_id"] = mx.array(
vocab.tokenizer.bos_token_id, dtype=mx.uint32
)
if vocab.tokenizer.eos_token_id is not None:
metadata["tokenizer.ggml.eos_token_id"] = mx.array(
vocab.tokenizer.eos_token_id, dtype=mx.uint32
)
if vocab.tokenizer.unk_token_id is not None:
metadata["tokenizer.ggml.unknown_token_id"] = mx.array(
vocab.tokenizer.unk_token_id, dtype=mx.uint32
)
metadata = {k: v for k, v in metadata.items() if v is not None}
return metadata


@@ -15,9 +15,9 @@ from .tokenizer_utils import TokenizerWrapper
from .tuner.datasets import load_dataset
from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
from .tuner.utils import (
apply_lora_layers,
build_schedule,
linear_to_lora_layers,
load_adapters,
print_trainable_parameters,
)
from .utils import load, save_config
@@ -41,9 +41,10 @@ yaml_loader.add_implicit_resolver(
CONFIG_DEFAULTS = {
"model": "mlx_model",
"train": False,
"fine_tune_type": "lora",
"data": "data/",
"seed": 0,
"lora_layers": 16,
"num_layers": 16,
"batch_size": 4,
"iters": 1000,
"val_batches": 25,
@@ -58,7 +59,6 @@ CONFIG_DEFAULTS = {
"max_seq_length": 2048,
"lr_schedule": None,
"lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
"use_dora": False,
}
@@ -79,10 +79,20 @@ def build_parser():
parser.add_argument(
"--data",
type=str,
help="Directory with {train, valid, test}.jsonl files",
help=(
"Directory with {train, valid, test}.jsonl files or the name "
"of a Hugging Face dataset (e.g., 'mlx-community/wikisql')"
),
)
parser.add_argument(
"--lora-layers",
"--fine-tune-type",
type=str,
choices=["lora", "dora", "full"],
default="lora",
help="Type of fine-tuning to perform: lora, dora, or full.",
)
parser.add_argument(
"--num-layers",
type=int,
help="Number of layers to fine-tune. Default is 16, use -1 for all.",
)
@@ -107,12 +117,12 @@ def build_parser():
parser.add_argument(
"--resume-adapter-file",
type=str,
help="Load path to resume training with the given adapters.",
help="Load path to resume training from the given fine-tuned weights.",
)
parser.add_argument(
"--adapter-path",
type=str,
help="Save/load path for the adapters.",
help="Save/load path for the fine-tuned weights.",
)
parser.add_argument(
"--save-every",
@@ -148,9 +158,6 @@ def build_parser():
default=None,
)
parser.add_argument("--seed", type=int, default=None, help="The PRNG seed")
parser.add_argument(
"--use-dora", action="store_true", default=None, help="Use DoRA to finetune."
)
return parser
@@ -162,21 +169,31 @@ def train_model(
valid_set,
training_callback: TrainingCallback = None,
):
# Freeze all layers
model.freeze()
if args.fine_tune_type == "full":
for l in model.layers[-max(args.num_layers, 0) :]:
l.unfreeze()
elif args.fine_tune_type in ["lora", "dora"]:
# Convert linear layers to lora/dora layers and unfreeze in the process
linear_to_lora_layers(
model,
args.num_layers,
args.lora_parameters,
use_dora=(args.fine_tune_type == "dora"),
)
else:
raise ValueError(f"Received unknown fine-tune-type {args.fine_tune_type}")
# Convert linear layers to lora layers and unfreeze in the process
linear_to_lora_layers(model, args.lora_layers, args.lora_parameters)
# Resume training the given adapters.
# Resume from weights if provided
if args.resume_adapter_file is not None:
print(f"Loading pretrained adapters from {args.resume_adapter_file}")
print(f"Loading fine-tuned weights from {args.resume_adapter_file}")
model.load_weights(args.resume_adapter_file, strict=False)
print_trainable_parameters(model)
adapter_path = Path(args.adapter_path)
adapter_path.mkdir(parents=True, exist_ok=True)
adapter_file = adapter_path / "adapters.safetensors"
save_config(vars(args), adapter_path / "adapter_config.json")
@@ -240,7 +257,7 @@ def run(args, training_callback: TrainingCallback = None):
if args.test and not args.train:
# Allow testing without LoRA layers by providing empty path
if args.adapter_path != "":
apply_lora_layers(model, args.adapter_path)
load_adapters(model, args.adapter_path)
elif args.train:
print("Training")
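A sketch of the reworked flags in use (entry point name and dataset path are placeholders): --fine-tune-type replaces the removed --use-dora switch and --num-layers replaces --lora-layers.
mlx_lm.lora \
    --model mlx-community/Llama-3.2-3B-Instruct-4bit \
    --train \
    --data /path/to/training/data \
    --fine-tune-type dora \
    --num-layers 8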


@@ -1,46 +1,13 @@
# Copyright © 2023-2024 Apple Inc.
import inspect
from dataclasses import dataclass
from typing import Any, Optional
import mlx.core as mx
from mlx.utils import tree_map
def create_additive_causal_mask(N: int, offset: int = 0):
rinds = mx.arange(offset + N)
linds = mx.arange(offset, offset + N) if offset else rinds
mask = linds[:, None] < rinds[None]
return mask * -1e9
class KVCache:
def __init__(self, head_dim, n_kv_heads):
self.n_kv_heads = n_kv_heads
self.head_dim = head_dim
self.keys = None
self.values = None
self.offset = 0
self.step = 256
def update_and_fetch(self, keys, values):
prev = self.offset
if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]:
n_steps = (self.step + keys.shape[2] - 1) // self.step
shape = (1, self.n_kv_heads, n_steps * self.step, self.head_dim)
new_k = mx.zeros(shape, keys.dtype)
new_v = mx.zeros(shape, values.dtype)
if self.keys is not None:
if prev % self.step != 0:
self.keys = self.keys[..., :prev, :]
self.values = self.values[..., :prev, :]
self.keys = mx.concatenate([self.keys, new_k], axis=2)
self.values = mx.concatenate([self.values, new_v], axis=2)
else:
self.keys, self.values = new_k, new_v
self.offset += keys.shape[2]
self.keys[..., prev : self.offset, :] = keys
self.values[..., prev : self.offset, :] = values
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
from .cache import QuantizedKVCache
@dataclass
@@ -54,3 +21,93 @@ class BaseModelArgs:
if k in inspect.signature(cls).parameters
}
)
def create_causal_mask(N: int, offset: int = 0, window_size: Optional[int] = None):
rinds = mx.arange(offset + N)
linds = mx.arange(offset, offset + N) if offset else rinds
linds = linds[:, None]
rinds = rinds[None]
mask = linds < rinds
if window_size is not None:
mask = mask | (linds > rinds + window_size)
return mask * -1e9
def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
T = h.shape[1]
if T > 1:
window_size = None
offset = 0
if cache is not None and cache[0] is not None:
c = cache[0]
if hasattr(c, "max_size"):
offset = min(c.max_size, c.offset)
window_size = c.max_size
else:
offset = c.offset
mask = create_causal_mask(T, offset, window_size=window_size)
mask = mask.astype(h.dtype)
else:
mask = None
return mask
def quantized_scaled_dot_product_attention(
queries: mx.array,
q_keys: tuple[mx.array, mx.array, mx.array],
q_values: tuple[mx.array, mx.array, mx.array],
scale: float,
mask: Optional[mx.array],
group_size: int = 64,
bits: int = 8,
) -> mx.array:
B, n_q_heads, L, D = queries.shape
n_kv_heads = q_keys[0].shape[-3]
n_repeats = n_q_heads // n_kv_heads
queries *= scale
if n_repeats > 1:
queries = mx.reshape(queries, (B, n_kv_heads, n_repeats, L, D))
q_keys = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_keys)
q_values = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_values)
scores = mx.quantized_matmul(
queries, *q_keys, transpose=True, group_size=group_size, bits=bits
)
if mask is not None:
scores += mask
scores = mx.softmax(scores, axis=-1, precise=True)
out = mx.quantized_matmul(
scores, *q_values, transpose=False, group_size=group_size, bits=bits
)
if n_repeats > 1:
out = mx.reshape(out, (B, n_q_heads, L, D))
return out
def scaled_dot_product_attention(
queries,
keys,
values,
cache,
scale: float,
mask: Optional[mx.array],
) -> mx.array:
if isinstance(cache, QuantizedKVCache):
return quantized_scaled_dot_product_attention(
queries,
keys,
values,
scale=scale,
mask=mask,
group_size=cache.group_size,
bits=cache.bits,
)
else:
return mx.fast.scaled_dot_product_attention(
queries, keys, values, scale=scale, mask=mask
)
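A self-contained sketch of what the quantized path computes (shapes and values are made up; GQA head repetition is handled inside the function). scaled_dot_product_attention above routes here whenever the layer's cache is a QuantizedKVCache.
import mlx.core as mx
from mlx_lm.models.base import quantized_scaled_dot_product_attention
B, n_heads, S, D = 1, 8, 16, 64
queries = mx.random.normal((B, n_heads, 1, D))  # a single decode step
keys = mx.random.normal((B, n_heads, S, D))
values = mx.random.normal((B, n_heads, S, D))
# Quantize K/V the way QuantizedKVCache stores them: (packed weights, scales, biases).
q_keys = mx.quantize(keys, group_size=64, bits=8)
q_values = mx.quantize(values, group_size=64, bits=8)
out = quantized_scaled_dot_product_attention(
    queries, q_keys, q_values, scale=D**-0.5, mask=None, group_size=64, bits=8
)
print(out.shape)  # (1, 8, 1, 64)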

llms/mlx_lm/models/cache.py (new file, 438 lines)
@@ -0,0 +1,438 @@
# Copyright © 2023-2024 Apple Inc.
from typing import Any, Dict, List, Optional
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten, tree_map, tree_unflatten
def make_prompt_cache(
model: nn.Module,
max_kv_size: Optional[int] = None,
) -> List[Any]:
"""
Construct the model's cache for use when generating.
This function will defer the cache construction to the model if it has a
``make_cache`` method, otherwise it will make a default KV cache.
Args:
model (nn.Module): The language model.
max_kv_size (Optional[int]): If provided and the model does not have a
``make_cache`` method, a ``RotatingKVCache`` is used with a maximum
size of ``max_kv_size``
"""
if hasattr(model, "make_cache"):
return model.make_cache()
num_layers = len(model.layers)
if max_kv_size is not None:
return [
RotatingKVCache(max_size=max_kv_size, keep=4) for _ in range(num_layers)
]
else:
return [KVCache() for _ in range(num_layers)]
def save_prompt_cache(file_name: str, cache: List[Any], metadata: Dict[str, str] = {}):
"""
Save a pre-computed prompt cache to a file.
Args:
file_name (str): The ``.safetensors`` file name.
cache (List[Any]): The model state.
metadata (Dict[str, str]): Optional metadata to save along with model
state.
"""
cache_data = [c.state for c in cache]
cache_info = [c.meta_state for c in cache]
cache_data = dict(tree_flatten(cache_data))
cache_classes = [type(c).__name__ for c in cache]
cache_metadata = [cache_info, metadata, cache_classes]
cache_metadata = dict(tree_flatten(cache_metadata))
mx.save_safetensors(file_name, cache_data, cache_metadata)
def load_prompt_cache(file_name, return_metadata=False):
"""
Load a prompt cache from a file.
Args:
file_name (str): The ``.safetensors`` file name.
return_metadata (bool): Whether or not to return metadata.
Default: ``False``.
Returns:
List[Any] or Tuple[List[Any], Dict[str, str]]: The prompt cache and
the metadata if requested.
"""
arrays, cache_metadata = mx.load(file_name, return_metadata=True)
arrays = tree_unflatten(list(arrays.items()))
cache_metadata = tree_unflatten(list(cache_metadata.items()))
info, metadata, classes = cache_metadata
cache = [globals()[c]() for c in classes]
for c, state, meta_state in zip(cache, arrays, info):
c.state = state
c.meta_state = meta_state
if return_metadata:
return cache, metadata
return cache
def can_trim_prompt_cache(cache: List[Any]) -> bool:
"""
Check if model's cache can be trimmed.
"""
return all(c.is_trimmable() for c in cache)
def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]:
"""
Trim the model's cache by the given number of tokens.
This function will trim the cache if possible (in-place) and return the
number of tokens that were trimmed.
Args:
cache (List[Any]): The model's cache.
num_tokens (int): The number of tokens to trim.
Returns:
(int): The number of tokens that were trimmed.
"""
if not can_trim_prompt_cache(cache) or len(cache) == 0:
return 0
return [c.trim(num_tokens) for c in cache][0]
class _BaseCache:
@property
def state(self):
return []
@state.setter
def state(self, v):
if v is not None and v:
raise ValueError("This cache has no state but a state was set.")
@property
def meta_state(self):
return ""
@meta_state.setter
def meta_state(self, v):
if v is not None and v:
raise ValueError("This cache has no meta_state but a meta_state was set.")
def is_trimmable(self):
return False
class QuantizedKVCache(_BaseCache):
def __init__(self, group_size: int = 64, bits: int = 8):
self.keys = None
self.values = None
self.offset = 0
self.step = 256
self.group_size = group_size
self.bits = bits
def update_and_fetch(self, keys, values):
B, n_kv_heads, num_steps, k_head_dim = keys.shape
v_head_dim = values.shape[-1]
prev = self.offset
if self.keys is None or (prev + num_steps) > self.keys[0].shape[-2]:
el_per_int = 8 * mx.uint32.size // self.bits
new_steps = (self.step + num_steps - 1) // self.step * self.step
shape = (B, n_kv_heads, new_steps)
def init_quant(dim):
return (
mx.zeros((*shape, dim // el_per_int), dtype=mx.uint32),
mx.zeros((*shape, dim // self.group_size), dtype=keys.dtype),
mx.zeros((*shape, dim // self.group_size), dtype=keys.dtype),
)
def expand_quant(x):
new_x = mx.zeros((*shape, x.shape[-1]), dtype=x.dtype)
return mx.concatenate([x, new_x], axis=-2)
if self.keys is not None:
if prev % self.step != 0:
self.keys, self.values = tree_map(
lambda x: x[..., :prev, :], (self.keys, self.values)
)
self.keys, self.values = tree_map(
expand_quant, (self.keys, self.values)
)
else:
self.keys, self.values = init_quant(k_head_dim), init_quant(v_head_dim)
self.offset += num_steps
keys = mx.quantize(keys, group_size=self.group_size, bits=self.bits)
values = mx.quantize(values, group_size=self.group_size, bits=self.bits)
for i in range(len(self.keys)):
self.keys[i][..., prev : self.offset, :] = keys[i]
self.values[i][..., prev : self.offset, :] = values[i]
return tree_map(lambda x: x[..., : self.offset, :], (self.keys, self.values))
@property
def state(self):
if self.offset == self.keys[0].shape[2]:
return self.keys, self.values
else:
return tree_map(
lambda x: x[..., : self.offset, :], (self.keys, self.values)
)
@state.setter
def state(self, v):
self.keys, self.values = v
@property
def meta_state(self):
return tuple(map(str, (self.step, self.offset, self.group_size, self.bits)))
@meta_state.setter
def meta_state(self, v):
self.step, self.offset, self.group_size, self.bits = map(int, v)
def is_trimmable(self):
return True
def trim(self, n):
n = min(self.offset, n)
self.offset -= n
return n
class KVCache(_BaseCache):
def __init__(self):
self.keys = None
self.values = None
self.offset = 0
self.step = 256
def update_and_fetch(self, keys, values):
prev = self.offset
if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]:
B, n_kv_heads, _, k_head_dim = keys.shape
v_head_dim = values.shape[3]
n_steps = (self.step + keys.shape[2] - 1) // self.step
k_shape = (B, n_kv_heads, n_steps * self.step, k_head_dim)
v_shape = (B, n_kv_heads, n_steps * self.step, v_head_dim)
new_k = mx.zeros(k_shape, keys.dtype)
new_v = mx.zeros(v_shape, values.dtype)
if self.keys is not None:
if prev % self.step != 0:
self.keys = self.keys[..., :prev, :]
self.values = self.values[..., :prev, :]
self.keys = mx.concatenate([self.keys, new_k], axis=2)
self.values = mx.concatenate([self.values, new_v], axis=2)
else:
self.keys, self.values = new_k, new_v
self.offset += keys.shape[2]
self.keys[..., prev : self.offset, :] = keys
self.values[..., prev : self.offset, :] = values
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
@property
def state(self):
if self.offset == self.keys.shape[2]:
return self.keys, self.values
else:
return (
self.keys[..., : self.offset, :],
self.values[..., : self.offset, :],
)
@state.setter
def state(self, v):
self.keys, self.values = v
self.offset = self.keys.shape[2]
def is_trimmable(self):
return True
def trim(self, n):
n = min(self.offset, n)
self.offset -= n
return n
def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache:
quant_cache = QuantizedKVCache(group_size=group_size, bits=bits)
quant_cache.offset = self.offset
if self.keys is not None:
quant_cache.keys = mx.quantize(self.keys, group_size=group_size, bits=bits)
quant_cache.values = mx.quantize(
self.values, group_size=group_size, bits=bits
)
return quant_cache
class RotatingKVCache(_BaseCache):
def __init__(self, max_size=None, keep=0, step=256):
self.keep = keep
self.keys = None
self.values = None
self.offset = 0
self.max_size = max_size
self.step = step
self._idx = 0
def _trim(self, trim_size, v, append=None):
to_cat = []
if trim_size > 0:
to_cat = [v[..., : self.keep, :], v[..., trim_size + self.keep :, :]]
else:
to_cat = [v]
if append is not None:
to_cat.append(append)
return mx.concatenate(to_cat, axis=2)
def _temporal_order(self, v):
"""
Rearrange the cache into temporal order, slicing off the end if unused.
"""
if self._idx == v.shape[2]:
return v
elif self._idx < self.offset:
return mx.concatenate(
[
v[..., : self.keep, :],
v[..., self._idx :, :],
v[..., self.keep : self._idx, :],
],
axis=2,
)
else:
return v[..., : self._idx, :]
def _update_concat(self, keys, values):
if self.keys is None:
self.keys = keys
self.values = values
else:
# Put the keys/values in temporal order to
# preserve context
self.keys = self._temporal_order(self.keys)
self.values = self._temporal_order(self.values)
# The largest size is self.max_size + S to ensure
# every token gets at least self.max_size context
trim_size = self._idx - self.max_size
self.keys = self._trim(trim_size, self.keys, keys)
self.values = self._trim(trim_size, self.values, values)
self.offset += keys.shape[2]
self._idx = self.keys.shape[2]
return self.keys, self.values
def _update_in_place(self, keys, values):
# May not have hit the max size yet, so potentially
# keep growing the cache
B, n_kv_heads, S, k_head_dim = keys.shape
prev = self.offset
if self.keys is None or (
prev >= self.keys.shape[2] and self.keys.shape[2] < self.max_size
):
v_head_dim = values.shape[3]
new_size = min(self.step, self.max_size - prev)
k_shape = (B, n_kv_heads, new_size, k_head_dim)
v_shape = (B, n_kv_heads, new_size, v_head_dim)
new_k = mx.zeros(k_shape, keys.dtype)
new_v = mx.zeros(v_shape, values.dtype)
if self.keys is not None:
self.keys = mx.concatenate([self.keys, new_k], axis=2)
self.values = mx.concatenate([self.values, new_v], axis=2)
else:
self.keys, self.values = new_k, new_v
self._idx = prev
# Trim if needed
trim_size = self.keys.shape[2] - self.max_size
if trim_size > 0:
self.keys = self._trim(trim_size, self.keys)
self.values = self._trim(trim_size, self.values)
self._idx = self.max_size
# Rotate
if self._idx == self.max_size:
self._idx = self.keep
# Assign
self.keys[..., self._idx : self._idx + S, :] = keys
self.values[..., self._idx : self._idx + S, :] = values
self.offset += S
self._idx += S
# If the buffer is not full, slice off the end
if self.offset < self.max_size:
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
return self.keys, self.values
def update_and_fetch(self, keys, values):
if keys.shape[2] == 1:
return self._update_in_place(keys, values)
return self._update_concat(keys, values)
@property
def state(self):
if self.offset < self.keys.shape[2]:
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
else:
return self.keys, self.values
@state.setter
def state(self, v):
self.keys, self.values = v
@property
def meta_state(self):
return tuple(
map(str, (self.keep, self.max_size, self.step, self.offset, self._idx))
)
@meta_state.setter
def meta_state(self, v):
self.keep, self.max_size, self.step, self.offset, self._idx = map(
int,
v,
)
def is_trimmable(self):
return self.offset < self.max_size
def trim(self, n):
n = min(self.offset, n)
self.offset -= n
self._idx -= n
return n
def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache:
raise NotImplementedError("RotatingKVCache Quantization NYI")
class MambaCache(_BaseCache):
def __init__(self):
self.cache = [None, None]
def __setitem__(self, idx, value):
self.cache[idx] = value
def __getitem__(self, idx):
return self.cache[idx]
@property
def state(self):
return self.cache
@state.setter
def state(self, v):
self.cache = v
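A short sketch (the model name is illustrative) tying the helpers above together: make a per-layer cache, trim it, and convert plain KVCache entries to their quantized form.
from mlx_lm import load
from mlx_lm.models.cache import (
    can_trim_prompt_cache,
    make_prompt_cache,
    trim_prompt_cache,
)
model, tokenizer = load("mlx-community/Llama-3.2-3B-Instruct-4bit")
cache = make_prompt_cache(model)  # one KVCache per layer; RotatingKVCache if max_kv_size is given
# ... run prompt processing / generation that fills `cache` in place ...
if can_trim_prompt_cache(cache):
    trim_prompt_cache(cache, 8)  # drop the last 8 cached tokens from every layer
# KVCache supports lossy conversion to QuantizedKVCache (RotatingKVCache does not):
cache = [c.to_quantized(group_size=64, bits=4) for c in cache]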


@@ -1,10 +1,12 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Optional, Tuple
from typing import Any, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass
@@ -67,7 +69,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@@ -91,8 +93,8 @@ class Attention(nn.Module):
queries = self.rope(queries)
keys = self.rope(keys)
output = mx.fast.scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
@@ -127,7 +129,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
h = self.input_layernorm(x)
attn_h = self.self_attn(h, mask, cache)
@@ -157,10 +159,7 @@ class CohereModel(nn.Module):
):
h = self.embed_tokens(inputs)
mask = None
if h.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
mask = mask.astype(h.dtype)
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
@@ -191,11 +190,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_key_value_heads


@@ -0,0 +1,207 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_causal_mask, scaled_dot_product_attention
from .cache import KVCache, RotatingKVCache
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int = 4096
head_dim: int = 128
num_hidden_layers: int = 32
intermediate_size: int = 14336
num_attention_heads: int = 32
num_key_value_heads: int = 8
rope_theta: float = 50000.0
vocab_size: int = 256000
layer_norm_eps: float = 1e-05
logit_scale: float = 0.0625
attention_bias: bool = False
layer_norm_bias: bool = False
sliding_window: int = 4096
sliding_window_pattern: int = 4
class Attention(nn.Module):
def __init__(self, args: ModelArgs, layer_idx: int):
super().__init__()
self.args = args
self.layer_idx = layer_idx
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
self.head_dim = head_dim = args.head_dim
if (head_dim * n_heads) != dim:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {dim}"
f" and `num_heads`: {n_heads})."
)
self.scale = head_dim**-0.5
attention_bias = args.attention_bias
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias)
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)
self.rope = nn.RoPE(head_dim, traditional=True, base=args.rope_theta)
self.use_sliding_window = (layer_idx + 1) % args.sliding_window_pattern != 0
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
B, L, D = x.shape
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
# Apply RoPE only if sliding window is enabled
if self.use_sliding_window:
if cache is None:
queries = self.rope(queries)
keys = self.rope(keys)
else:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
if cache is not None:
keys, values = cache.update_and_fetch(keys, values)
if self.use_sliding_window and mask is not None:
key_len = keys.shape[-2]
if mask.shape[-1] != key_len:
mask = mask[..., -key_len:]
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
class MLP(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
def __call__(self, x):
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs, layer_idx: int):
super().__init__()
self.hidden_size = args.hidden_size
self.n_heads = args.num_attention_heads
self.self_attn = Attention(args, layer_idx)
self.mlp = MLP(args.hidden_size, args.intermediate_size)
self.input_layernorm = nn.LayerNorm(
args.hidden_size, eps=args.layer_norm_eps, bias=args.layer_norm_bias
)
self.args = args
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
h = self.input_layernorm(x)
attn_h = self.self_attn(h, mask, cache)
ff_h = self.mlp(h)
return attn_h + ff_h + x
class CohereModel(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
self.num_hidden_layers = args.num_hidden_layers
assert self.vocab_size > 0
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [
TransformerBlock(args=args, layer_idx=i)
for i in range(args.num_hidden_layers)
]
self.norm = nn.LayerNorm(
args.hidden_size, eps=args.layer_norm_eps, bias=args.layer_norm_bias
)
def __call__(
self,
inputs: mx.array,
cache=None,
):
h = self.embed_tokens(inputs)
T = h.shape[1]
if T > 1:
offset = cache[0].offset if cache else 0
mask = create_causal_mask(T, offset).astype(h.dtype)
else:
mask = None
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)
return self.norm(h)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.model_type = args.model_type
self.model = CohereModel(args)
self.args = args
def __call__(
self,
inputs: mx.array,
cache=None,
):
out = self.model(inputs, cache)
out = self.model.embed_tokens.as_linear(out)
out = out * self.model.args.logit_scale
return out
def make_cache(self):
caches = []
for i in range(self.args.num_hidden_layers):
if (
i % self.args.sliding_window_pattern
== self.args.sliding_window_pattern - 1
):
caches.append(KVCache())
else:
caches.append(
RotatingKVCache(max_size=self.args.sliding_window, keep=0)
)
return caches
@property
def layers(self):
return self.model.layers


@@ -1,11 +1,13 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Optional, Tuple
from typing import Any, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from .base import BaseModelArgs
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass
@@ -47,7 +49,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
qkv = self.Wqkv(x)
@@ -72,8 +74,8 @@ class Attention(nn.Module):
queries = self.rope(queries)
keys = self.rope(keys)
output = mx.fast.scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.out_proj(output)
@@ -90,7 +92,7 @@ class NormAttnNorm(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
h = self.attn(self.norm_1(x), mask=mask, cache=cache)
x = h + x
@@ -177,7 +179,7 @@ class DecoderLayer(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
r, h = self.norm_attn_norm(x, mask, cache)
out = self.ffn(h) + r
@@ -199,11 +201,7 @@ class DBRX(nn.Module):
):
h = self.wte(inputs)
mask = None
T = h.shape[1]
if T > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
mask = mask.astype(h.dtype)
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.blocks)
@@ -251,11 +249,3 @@ class Model(nn.Module):
experts = [(s, sv.T) for s, sv in experts]
new_weights.update(experts)
return new_weights
@property
def head_dim(self):
return self.args.d_model // self.args.n_heads
@property
def n_kv_heads(self):
return self.args.attn_config["kv_n_heads"]


@@ -0,0 +1,258 @@
from dataclasses import dataclass
from typing import Any, Dict, Optional
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .switch_layers import SwitchGLU
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str = "deepseek"
vocab_size: int = 102400
hidden_size: int = 4096
intermediate_size: int = 11008
moe_intermediate_size: int = 1407
num_hidden_layers: int = 30
num_attention_heads: int = 32
num_key_value_heads: int = 32
n_shared_experts: Optional[int] = None
n_routed_experts: Optional[int] = None
num_experts_per_tok: Optional[int] = None
moe_layer_freq: int = 1
first_k_dense_replace: int = 0
max_position_embeddings: int = 2048
rms_norm_eps: float = 1e-6
rope_theta: float = 10000.0
rope_scaling: Optional[Dict] = None
attention_bias: bool = False
class DeepseekAttention(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.num_attention_heads = config.num_attention_heads
self.num_kv_heads = config.num_key_value_heads
self.head_dim = config.hidden_size // config.num_attention_heads
self.scale = self.head_dim**-0.5
attention_bias = getattr(config, "attention_bias", False)
self.q_proj = nn.Linear(
self.hidden_size,
config.num_attention_heads * self.head_dim,
bias=attention_bias,
)
self.k_proj = nn.Linear(
self.hidden_size,
config.num_key_value_heads * self.head_dim,
bias=attention_bias,
)
self.v_proj = nn.Linear(
self.hidden_size,
config.num_key_value_heads * self.head_dim,
bias=attention_bias,
)
self.o_proj = nn.Linear(
self.hidden_size,
config.num_attention_heads * self.head_dim,
bias=attention_bias,
)
rope_scale = 1.0
if config.rope_scaling and config.rope_scaling["type"] == "linear":
assert isinstance(config.rope_scaling["factor"], float)
rope_scale = 1 / config.rope_scaling["factor"]
self.rope = nn.RoPE(
self.head_dim,
base=config.rope_theta,
scale=rope_scale,
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, _ = x.shape
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
queries = queries.reshape(B, L, self.num_attention_heads, -1).transpose(
0, 2, 1, 3
)
keys = keys.reshape(B, L, self.num_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.num_kv_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
class DeepseekMLP(nn.Module):
def __init__(
self,
config: ModelArgs,
hidden_size: Optional[int] = None,
intermediate_size: Optional[int] = None,
):
super().__init__()
self.config = config
self.hidden_size = hidden_size or config.hidden_size
self.intermediate_size = intermediate_size or config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = nn.silu
def __call__(self, x: mx.array) -> mx.array:
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
class MoEGate(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.config = config
self.top_k = config.num_experts_per_tok
self.n_routed_experts = config.n_routed_experts
self.weight = mx.zeros((self.n_routed_experts, config.hidden_size))
def __call__(self, x):
gates = x @ self.weight.T
scores = mx.softmax(gates, axis=-1, precise=True)
k = self.top_k
inds = mx.stop_gradient(mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k])
scores = mx.take_along_axis(scores, inds, axis=-1)
return inds, scores
class DeepseekMoE(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.config = config
self.switch_mlp = SwitchGLU(
config.hidden_size, config.moe_intermediate_size, config.n_routed_experts
)
self.gate = MoEGate(config)
if config.n_shared_experts is not None:
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
self.shared_experts = DeepseekMLP(
config=config, intermediate_size=intermediate_size
)
def __call__(self, x):
inds, scores = self.gate(x)
y = self.switch_mlp(x, inds)
y = (y * scores[..., None]).sum(axis=-2)
if self.config.n_shared_experts is not None:
y = y + self.shared_experts(x)
return y
class DeepseekDecoderLayer(nn.Module):
def __init__(self, config: ModelArgs, layer_idx: int):
super().__init__()
self.self_attn = DeepseekAttention(config)
self.mlp = (
DeepseekMoE(config)
if (
config.n_routed_experts is not None
and layer_idx >= config.first_k_dense_replace
and layer_idx % config.moe_layer_freq == 0
)
else DeepseekMLP(config)
)
self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = nn.RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
out = h + r
return out
class DeepseekModel(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.config = config
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = [
DeepseekDecoderLayer(config, idx) for idx in range(config.num_hidden_layers)
]
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def __call__(
self,
x: mx.array,
cache: Optional[Any] = None,
) -> mx.array:
h = self.embed_tokens(x)
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)
return self.norm(h)
class Model(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.args = config
self.model_type = config.model_type
self.model = DeepseekModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
cache: Optional[Any] = None,
):
out = self.model(inputs, cache)
return self.lm_head(out)
def sanitize(self, weights):
for l in range(self.args.num_hidden_layers):
prefix = f"model.layers.{l}"
for m in ["gate_proj", "down_proj", "up_proj"]:
for k in ["weight", "scales", "biases"]:
if f"{prefix}.mlp.experts.0.{m}.{k}" in weights:
to_join = [
weights.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}")
for e in range(self.args.n_routed_experts)
]
weights[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join)
return weights
@property
def layers(self):
return self.model.layers
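MoEGate above routes each token to its top-k experts with argpartition and gathers the matching scores with take_along_axis. A toy sketch of that routing step with made-up sizes (not the committed code):

import mlx.core as mx

num_experts, top_k = 4, 2
x = mx.random.normal((1, 3, 8))              # (batch, tokens, hidden)
weight = mx.random.normal((num_experts, 8))  # one routing vector per expert

scores = mx.softmax(x @ weight.T, axis=-1)                        # (1, 3, 4)
inds = mx.argpartition(-scores, kth=top_k - 1, axis=-1)[..., :top_k]
picked = mx.take_along_axis(scores, inds, axis=-1)                # (1, 3, 2)

# Each token then runs only its `inds` experts; their outputs are summed with `picked`
# as mixture weights, mirroring `(y * scores[..., None]).sum(axis=-2)` above.
print(inds)
print(picked)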

View File

@@ -0,0 +1,417 @@
# Copyright © 2023-2024 Apple Inc.
import math
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .switch_layers import SwitchGLU
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str = "deepseek_v2"
vocab_size: int = 102400
hidden_size: int = 4096
intermediate_size: int = 11008
moe_intermediate_size: int = 1407
num_hidden_layers: int = 30
num_attention_heads: int = 32
num_key_value_heads: int = 32
n_shared_experts: Optional[int] = None
n_routed_experts: Optional[int] = None
routed_scaling_factor: float = 1.0
kv_lora_rank: int = 512
q_lora_rank: int = 1536
qk_rope_head_dim: int = 64
v_head_dim: int = 128
qk_nope_head_dim: int = 128
topk_method: str = "gready"
n_group: Optional[int] = None
topk_group: Optional[int] = None
num_experts_per_tok: Optional[int] = None
moe_layer_freq: int = 1
first_k_dense_replace: int = 0
max_position_embeddings: int = 2048
rms_norm_eps: float = 1e-6
rope_theta: float = 10000.0
rope_scaling: Dict = None
attention_bias: bool = False
def yarn_find_correction_dim(
num_rotations, dim, base=10000, max_position_embeddings=2048
):
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
2 * math.log(base)
)
def yarn_find_correction_range(
low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
):
low = math.floor(
yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
)
high = math.ceil(
yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
)
return max(low, 0), min(high, dim - 1)
def yarn_get_mscale(scale=1, mscale=1):
if scale <= 1:
return 1.0
return 0.1 * mscale * math.log(scale) + 1.0
def yarn_linear_ramp_mask(min_val, max_val, dim):
if min_val == max_val:
max_val += 0.001 # Prevent singularity
linear_func = (mx.arange(dim, dtype=mx.float32) - min_val) / (max_val - min_val)
return mx.clip(linear_func, 0, 1)
class DeepseekV2YarnRotaryEmbedding(nn.Module):
def __init__(
self,
dim,
max_position_embeddings=2048,
base=10000,
scaling_factor=1.0,
original_max_position_embeddings=4096,
beta_fast=32,
beta_slow=1,
mscale=1,
mscale_all_dim=0,
):
super().__init__()
self.mscale = yarn_get_mscale(scaling_factor, mscale) / yarn_get_mscale(
scaling_factor, mscale_all_dim
)
freq_extra = base ** (mx.arange(0, dim, 2, dtype=mx.float32) / dim)
freq_inter = scaling_factor * base ** (
mx.arange(0, dim, 2, dtype=mx.float32) / dim
)
low, high = yarn_find_correction_range(
beta_fast,
beta_slow,
dim,
base,
original_max_position_embeddings,
)
freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2)
self._freqs = (freq_inter * freq_extra) / (
freq_inter * freq_mask + freq_extra * (1 - freq_mask)
)
def __call__(self, x, offset=0):
if self.mscale != 1.0:
x = self.mscale * x
return mx.fast.rope(
x,
x.shape[-1],
traditional=True,
base=None,
scale=1.0,
offset=offset,
freqs=self._freqs,
)
class DeepseekV2Attention(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.q_lora_rank = config.q_lora_rank
self.qk_rope_head_dim = config.qk_rope_head_dim
self.kv_lora_rank = config.kv_lora_rank
self.v_head_dim = config.v_head_dim
self.qk_nope_head_dim = config.qk_nope_head_dim
self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
self.scale = self.q_head_dim**-0.5
if self.q_lora_rank is None:
self.q_proj = nn.Linear(
self.hidden_size, self.num_heads * self.q_head_dim, bias=False
)
else:
self.q_a_proj = nn.Linear(
self.hidden_size, self.q_lora_rank, bias=config.attention_bias
)
self.q_a_layernorm = nn.RMSNorm(self.q_lora_rank)
self.q_b_proj = nn.Linear(
self.q_lora_rank, self.num_heads * self.q_head_dim, bias=False
)
self.kv_a_proj_with_mqa = nn.Linear(
self.hidden_size,
self.kv_lora_rank + self.qk_rope_head_dim,
bias=config.attention_bias,
)
self.kv_a_layernorm = nn.RMSNorm(self.kv_lora_rank)
self.kv_b_proj = nn.Linear(
self.kv_lora_rank,
self.num_heads
* (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
bias=False,
)
self.o_proj = nn.Linear(
self.num_heads * self.v_head_dim,
self.hidden_size,
bias=config.attention_bias,
)
mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
scaling_factor = self.config.rope_scaling["factor"]
if mscale_all_dim:
mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
self.scale = self.scale * mscale * mscale
rope_kwargs = {
key: self.config.rope_scaling[key]
for key in [
"original_max_position_embeddings",
"beta_fast",
"beta_slow",
"mscale",
"mscale_all_dim",
]
if key in self.config.rope_scaling
}
self.rope = DeepseekV2YarnRotaryEmbedding(
dim=self.qk_rope_head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta,
**rope_kwargs,
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
if self.q_lora_rank is None:
q = self.q_proj(x)
else:
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(x)))
q = q.reshape(B, L, self.num_heads, self.q_head_dim).transpose(0, 2, 1, 3)
q_nope, q_pe = mx.split(q, [self.qk_nope_head_dim], axis=-1)
compressed_kv = self.kv_a_proj_with_mqa(x)
compressed_kv, k_pe = mx.split(compressed_kv, [self.kv_lora_rank], axis=-1)
k_pe = k_pe.reshape(B, L, 1, self.qk_rope_head_dim).transpose(0, 2, 1, 3)
kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
kv = kv.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1)
if cache is not None:
q_pe = self.rope(q_pe, cache.offset)
k_pe = self.rope(k_pe, cache.offset)
k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
keys, values = cache.update_and_fetch(
mx.concatenate([k_nope, k_pe], axis=-1), values
)
else:
q_pe = self.rope(q_pe)
k_pe = self.rope(k_pe)
k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
keys = mx.concatenate([k_nope, k_pe], axis=-1)
queries = mx.concatenate([q_nope, q_pe], axis=-1)
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
class DeepseekV2MLP(nn.Module):
def __init__(
self, config: ModelArgs, hidden_size: int = None, intermediate_size: int = None
):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
self.intermediate_size = (
config.intermediate_size if intermediate_size is None else intermediate_size
)
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
def __call__(self, x):
down_proj = self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
return down_proj
class MoEGate(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.config = config
self.top_k = config.num_experts_per_tok
self.n_routed_experts = config.n_routed_experts
self.routed_scaling_factor = config.routed_scaling_factor
self.topk_method = config.topk_method
self.n_group = config.n_group
self.topk_group = config.topk_group
self.weight = mx.zeros((self.n_routed_experts, config.hidden_size))
def __call__(self, x):
gates = x @ self.weight.T
scores = mx.softmax(gates, axis=-1, precise=True)
if self.topk_method == "group_limited_greedy":
bsz, seq_len = x.shape[:2]
scores = scores.reshape(bsz, seq_len, self.n_group, -1)
group_scores = scores.max(axis=-1)
k = self.n_group - self.topk_group
group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-1)[..., :k]
batch_idx = mx.expand_dims(mx.arange(bsz), (1, 2))
seq_idx = mx.expand_dims(mx.arange(seq_len), (0, 2))
scores[batch_idx, seq_idx, group_idx] = 0.0
scores = scores.reshape(bsz, seq_len, -1)
k = self.top_k
inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]
scores = mx.take_along_axis(scores, inds, axis=-1)
scores = scores * self.routed_scaling_factor
return inds, scores
class DeepseekV2MoE(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.config = config
self.num_experts_per_tok = config.num_experts_per_tok
self.switch_mlp = SwitchGLU(
config.hidden_size, config.moe_intermediate_size, config.n_routed_experts
)
self.gate = MoEGate(config)
if config.n_shared_experts is not None:
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
self.shared_experts = DeepseekV2MLP(
config=config, intermediate_size=intermediate_size
)
def __call__(self, x):
inds, scores = self.gate(x)
y = self.switch_mlp(x, inds)
y = (y * scores[..., None]).sum(axis=-2)
if self.config.n_shared_experts is not None:
y = y + self.shared_experts(x)
return y
class DeepseekV2DecoderLayer(nn.Module):
def __init__(self, config: ModelArgs, layer_idx: int):
super().__init__()
self.self_attn = DeepseekV2Attention(config)
self.mlp = (
DeepseekV2MoE(config)
if (
config.n_routed_experts is not None
and layer_idx >= config.first_k_dense_replace
and layer_idx % config.moe_layer_freq == 0
)
else DeepseekV2MLP(config)
)
self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = nn.RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
out = h + r
return out
class DeepseekV2Model(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = [
DeepseekV2DecoderLayer(config, idx)
for idx in range(config.num_hidden_layers)
]
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def __call__(
self,
x: mx.array,
cache: Optional[Any] = None,
) -> mx.array:
h = self.embed_tokens(x)
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)
return self.norm(h)
class Model(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.args = config
self.model_type = config.model_type
self.model = DeepseekV2Model(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
cache: Optional[Any] = None,
):
out = self.model(inputs, cache)
return self.lm_head(out)
def sanitize(self, weights):
for l in range(self.args.num_hidden_layers):
prefix = f"model.layers.{l}"
for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]:
for k in ["weight", "scales", "biases"]:
if f"{prefix}.mlp.experts.0.{m}.{k}" in weights:
to_join = [
weights.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}")
for e in range(self.args.n_routed_experts)
]
weights[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join)
return weights
@property
def layers(self):
return self.model.layers
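The YaRN helpers above decide which rotary dimensions get interpolated and how much to rescale attention. A small numeric sketch reusing the same formulas with hypothetical values (dim=64, base=10000, original context 4096, 8x stretch):

import math

def yarn_find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
    return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
        2 * math.log(base)
    )

def yarn_find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
    low = math.floor(yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
    high = math.ceil(yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings))
    return max(low, 0), min(high, dim - 1)

def yarn_get_mscale(scale=1, mscale=1):
    return 1.0 if scale <= 1 else 0.1 * mscale * math.log(scale) + 1.0

low, high = yarn_find_correction_range(32, 1, 64, base=10000, max_position_embeddings=4096)
print(low, high)              # dimensions below `low` stay high-frequency, above `high` are interpolated
print(yarn_get_mscale(8, 1))  # attention-scale correction for an 8x context stretch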

View File

@@ -0,0 +1,163 @@
# Copyright © 2024 Apple Inc.
from dataclasses import dataclass
from typing import Any, Dict, Optional, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .rope_utils import initialize_rope
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int
num_layers: int
intermediate_size: int
num_attention_heads: int
vocab_size: int
rope_theta: float
layer_norm_epsilon: float
num_key_value_heads: int
head_dim: Optional[int] = None
max_position_embeddings: Optional[int] = None
rope_traditional: bool = False
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
tie_word_embeddings: bool = True
attention_bias: bool = False
mlp_bias: bool = False
class AttentionModule(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
self.head_dim = head_dim = args.head_dim or (dim // n_heads)
self.scale = head_dim**-0.5
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.attention_bias)
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias)
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias)
self.out_proj = nn.Linear(n_heads * head_dim, dim, bias=args.attention_bias)
self.rope = initialize_rope(
self.head_dim,
args.rope_theta,
args.rope_traditional,
args.rope_scaling,
args.max_position_embeddings,
)
def __call__(
self, x: mx.array, mask: Optional[mx.array] = None, cache: Optional[Any] = None
) -> mx.array:
B, L, D = x.shape
q = self.q_proj(x).reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
k = self.k_proj(x).reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
v = self.v_proj(x).reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
q = self.rope(q, offset=cache.offset)
k = self.rope(k, offset=cache.offset)
k, v = cache.update_and_fetch(k, v)
else:
q = self.rope(q)
k = self.rope(k)
out = scaled_dot_product_attention(
q, k, v, cache=cache, scale=self.scale, mask=mask
)
out = out.transpose(0, 2, 1, 3).reshape(B, L, D)
return self.out_proj(out)
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.attention = AttentionModule(args)
class MLP(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.hidden_size
hidden_dim = args.intermediate_size
self.c_fc_0 = nn.Linear(dim, hidden_dim, bias=args.mlp_bias)
self.c_fc_1 = nn.Linear(dim, hidden_dim, bias=args.mlp_bias)
self.c_proj = nn.Linear(hidden_dim, dim, bias=args.mlp_bias)
def __call__(self, x: mx.array) -> mx.array:
return self.c_proj(nn.silu(self.c_fc_0(x)) * self.c_fc_1(x))
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.ln_1 = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
self.attn = Attention(args)
self.ln_2 = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
self.mlp = MLP(args)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
h = x + self.attn.attention(self.ln_1(x), mask, cache)
out = h + self.mlp(self.ln_2(h))
return out
class ExaoneModel(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.wte = nn.Embedding(args.vocab_size, args.hidden_size)
self.h = [TransformerBlock(args) for _ in range(args.num_layers)]
self.ln_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
def __call__(
self,
inputs: mx.array,
cache=None,
):
h = self.wte(inputs)
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.h)
for layer, c in zip(self.h, cache):
h = layer(h, mask, cache=c)
return self.ln_f(h)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.transformer = ExaoneModel(args)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
cache=None,
):
out = self.transformer(inputs, cache)
if self.args.tie_word_embeddings:
out = self.transformer.wte.as_linear(out)
else:
out = self.lm_head(out)
return out
@property
def layers(self):
return self.transformer.h
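With tie_word_embeddings=True, the EXAONE model above projects hidden states back to the vocabulary through the embedding matrix rather than a separate lm_head. A minimal sketch with hypothetical sizes:

import mlx.core as mx
import mlx.nn as nn

vocab, dim = 10, 4
wte = nn.Embedding(vocab, dim)

tokens = mx.array([[1, 2, 3]])
h = wte(tokens)              # (1, 3, dim) hidden states, standing in for the transformer output
logits = wte.as_linear(h)    # (1, 3, vocab): h @ wte.weight.T, no extra parameters
print(logits.shape)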

View File

@@ -1,10 +1,12 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Optional, Tuple
from typing import Any, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass
@@ -58,7 +60,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@@ -77,8 +79,8 @@ class Attention(nn.Module):
queries = self.rope(queries)
keys = self.rope(keys)
output = mx.fast.scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
@@ -111,7 +113,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
@@ -141,10 +143,7 @@ class GemmaModel(nn.Module):
h = self.embed_tokens(inputs)
h = h * (self.args.hidden_size**0.5)
mask = None
if h.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
mask = mask.astype(h.dtype)
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
@@ -174,11 +173,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.head_dim
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@@ -0,0 +1,200 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Any, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int
num_hidden_layers: int
intermediate_size: int
num_attention_heads: int
head_dim: int
rms_norm_eps: float
vocab_size: int
num_key_value_heads: int
rope_theta: float = 10000
rope_traditional: bool = False
attn_logit_softcapping: float = 50.0
final_logit_softcapping: float = 30.0
query_pre_attn_scalar: float = 144.0
class RMSNorm(nn.Module):
def __init__(self, dims: int, eps: float = 1e-5):
super().__init__()
self.weight = mx.ones((dims,))
self.eps = eps
def __call__(self, x):
return mx.fast.rms_norm(x, 1.0 + self.weight, self.eps)
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
self.repeats = n_heads // n_kv_heads
self.head_dim = head_dim = args.head_dim
self.scale = 1.0 / (args.query_pre_attn_scalar**0.5)
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
self.attn_logit_softcapping = args.attn_logit_softcapping
self.rope = nn.RoPE(
head_dim,
traditional=args.rope_traditional,
base=args.rope_theta,
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
queries = queries * self.scale
if self.repeats > 1:
queries = queries.reshape(
B, self.n_kv_heads, self.repeats, L, self.head_dim
)
keys = mx.expand_dims(keys, 2)
values = mx.expand_dims(values, 2)
scores = queries @ keys.swapaxes(-1, -2)
scores = mx.tanh(scores / self.attn_logit_softcapping)
scores *= self.attn_logit_softcapping
if mask is not None:
scores = scores + mask
scores = mx.softmax(scores, precise=True, axis=-1)
output = scores @ values
if self.repeats > 1:
output = output.reshape(B, self.n_heads, L, self.head_dim)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
class MLP(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
def __call__(self, x) -> mx.array:
return self.down_proj(nn.gelu_approx(self.gate_proj(x)) * self.up_proj(x))
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.num_attention_heads = args.num_attention_heads
self.hidden_size = args.hidden_size
self.self_attn = Attention(args)
self.mlp = MLP(args.hidden_size, args.intermediate_size)
self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.pre_feedforward_layernorm = RMSNorm(
args.hidden_size, eps=args.rms_norm_eps
)
self.post_feedforward_layernorm = RMSNorm(
args.hidden_size, eps=args.rms_norm_eps
)
self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.args = args
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + self.post_attention_layernorm(r)
r = self.mlp(self.pre_feedforward_layernorm(h))
out = h + self.post_feedforward_layernorm(r)
return out
class GemmaModel(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
self.num_hidden_layers = args.num_hidden_layers
assert self.vocab_size > 0
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
]
self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
def __call__(
self,
inputs: mx.array,
cache=None,
):
h = self.embed_tokens(inputs)
h = h * (self.args.hidden_size**0.5)
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)
return self.norm(h)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.model_type = args.model_type
self.final_logit_softcapping = args.final_logit_softcapping
self.model = GemmaModel(args)
self.args = args
def __call__(
self,
inputs: mx.array,
cache=None,
):
out = self.model(inputs, cache)
out = self.model.embed_tokens.as_linear(out)
out = mx.tanh(out / self.final_logit_softcapping)
out = out * self.final_logit_softcapping
return out
@property
def layers(self):
return self.model.layers
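The model above soft-caps attention scores (attn_logit_softcapping=50.0) and final logits (final_logit_softcapping=30.0) with a tanh squash. A quick sketch of the effect:

import mlx.core as mx

def softcap(x, cap):
    # Smoothly bounds values to (-cap, cap); near zero it is approximately the identity.
    return mx.tanh(x / cap) * cap

x = mx.array([-200.0, -10.0, 0.0, 10.0, 200.0])
print(softcap(x, 30.0))  # extremes saturate near +/-30, small values pass through almost unchanged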

View File

@@ -1,11 +1,13 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
from typing import Any, Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from .base import BaseModelArgs, create_additive_causal_mask
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass
@@ -44,7 +46,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@@ -59,8 +61,8 @@ class Attention(nn.Module):
if cache is not None:
keys, values = cache.update_and_fetch(keys, values)
output = mx.fast.scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
@@ -98,7 +100,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.attn(self.ln_1(x), mask, cache)
h = x + r
@@ -136,10 +138,7 @@ class GPT2Model(nn.Module):
position_ids = mx.array(np.arange(L))
hidden_states += self.wpe(position_ids)
mask = create_additive_causal_mask(
hidden_states.shape[1], cache[0].offset if cache is not None else 0
)
mask = mask.astype(hidden_states.dtype)
mask = create_attention_mask(hidden_states, cache)
if cache is None:
cache = [None] * len(self.h)
@@ -197,11 +196,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.h
@property
def head_dim(self):
return self.args.n_embd // self.args.n_head
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@@ -1,11 +1,13 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
from typing import Any, Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from .base import BaseModelArgs, create_additive_causal_mask
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass
@@ -55,7 +57,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@@ -72,8 +74,8 @@ class Attention(nn.Module):
if cache is not None:
keys, values = cache.update_and_fetch(keys, values)
output = mx.fast.scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.c_proj(output)
@@ -112,7 +114,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.attn(self.ln_1(x), mask, cache)
h = x + r
@@ -147,10 +149,7 @@ class GPTBigCodeModel(nn.Module):
position_ids = mx.array(np.arange(L))
hidden_states += self.wpe(position_ids)
mask = create_additive_causal_mask(
hidden_states.shape[1], cache[0].offset if cache is not None else 0
)
mask = mask.astype(hidden_states.dtype)
mask = create_attention_mask(hidden_states, cache)
if cache is None:
cache = [None] * len(self.h)
@@ -185,11 +184,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.transformer.h
@property
def head_dim(self):
return self.args.n_embd // self.args.n_head
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@@ -0,0 +1,216 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
# Based on the transformers implementation at:
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
max_position_embeddings: int
hidden_size: int
num_attention_heads: int
num_hidden_layers: int
layer_norm_eps: float
vocab_size: int
rotary_emb_base: int
rotary_pct: float
num_key_value_heads: int = None
def __post_init__(self):
if self.num_key_value_heads is None:
self.num_key_value_heads = self.num_attention_heads
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
assert (
args.hidden_size % args.num_attention_heads == 0
), "hidden_size must be divisible by num_attention_heads"
self.hidden_size = args.hidden_size
self.num_attention_heads = args.num_attention_heads
self.head_dim = self.hidden_size // self.num_attention_heads
self.rope = nn.RoPE(
dims=int(self.head_dim * args.rotary_pct),
traditional=False,
base=args.rotary_emb_base,
)
self.scale = self.head_dim**-0.5
self.query_key_value = nn.Linear(
self.hidden_size, 3 * self.hidden_size, bias=True
)
self.dense = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
qkv = self.query_key_value(x)
new_qkv_shape = qkv.shape[:-1] + (self.num_attention_heads, 3 * self.head_dim)
qkv = qkv.reshape(*new_qkv_shape)
queries, keys, values = [x.transpose(0, 2, 1, 3) for x in qkv.split(3, -1)]
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.dense(output)
class MLP(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.hidden_size = args.hidden_size
self.dense_h_to_4h = nn.Linear(self.hidden_size, 4 * self.hidden_size)
self.dense_4h_to_h = nn.Linear(4 * self.hidden_size, self.hidden_size)
def __call__(self, x) -> mx.array:
# gelu_approx corresponds to FastGELUActivation in transformers.
return self.dense_4h_to_h(nn.gelu_approx(self.dense_h_to_4h(x)))
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.hidden_size = args.hidden_size
self.layer_norm_eps = args.layer_norm_eps
self.attention = Attention(args)
self.mlp = MLP(args)
self.input_layernorm = nn.LayerNorm(
self.hidden_size,
eps=self.layer_norm_eps,
)
self.post_attention_layernorm = nn.LayerNorm(
self.hidden_size, eps=self.layer_norm_eps
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
residual = x
# NeoX runs attention and feedforward network in parallel.
attn = self.attention(self.input_layernorm(x), mask, cache)
ffn = self.mlp(self.post_attention_layernorm(x))
out = attn + ffn + residual
return out
class GPTNeoXModel(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.hidden_size = args.hidden_size
self.vocab_size = args.vocab_size
self.num_hidden_layers = args.num_hidden_layers
self.layer_norm_eps = args.layer_norm_eps
assert self.vocab_size > 0
self.embed_in = nn.Embedding(self.vocab_size, self.hidden_size)
self.embed_out = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
self.h = [TransformerBlock(args=args) for _ in range(self.num_hidden_layers)]
self.final_layer_norm = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
def __call__(
self,
inputs: mx.array,
cache=None,
):
_, L = inputs.shape
hidden_states = self.embed_in(inputs)
mask = create_attention_mask(hidden_states, cache)
if cache is None:
cache = [None] * len(self.h)
for layer, c in zip(self.h, cache):
hidden_states = layer(hidden_states, mask, cache=c)
out = self.final_layer_norm(hidden_states)
out = self.embed_out(out)
return out
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.model = GPTNeoXModel(args)
def __call__(
self,
inputs: mx.array,
cache=None,
):
out = self.model(inputs, cache)
return out
def sanitize(self, weights):
new_weights = {}
for w_key, w_value in weights.items():
# Created through register_buffer in Pytorch, not needed here.
ignore_suffixes = [
".attention.bias",
".attention.masked_bias",
".attention.rotary_emb.inv_freq",
]
skip_weight = False
for ignored_suffix in ignore_suffixes:
if w_key.endswith(ignored_suffix):
skip_weight = True
break
if skip_weight:
continue
if not w_key.startswith("model."):
w_key = f"model.{w_key}"
w_key = w_key.replace(".gpt_neox.layers.", ".h.")
w_key = w_key.replace(".gpt_neox.", ".")
new_weights[w_key] = w_value
return new_weights
@property
def layers(self):
return self.model.h
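The sanitize method above rewrites Hugging Face checkpoint keys into this module tree and drops PyTorch-only buffers. A toy sketch of that remapping with made-up keys:

fake_hf_keys = [
    "gpt_neox.layers.0.attention.query_key_value.weight",
    "gpt_neox.layers.0.attention.bias",   # registered buffer in PyTorch, dropped here
    "embed_out.weight",
]

ignore_suffixes = (".attention.bias", ".attention.masked_bias", ".attention.rotary_emb.inv_freq")
remapped = []
for k in fake_hf_keys:
    if k.endswith(ignore_suffixes):
        continue
    if not k.startswith("model."):
        k = f"model.{k}"
    k = k.replace(".gpt_neox.layers.", ".h.").replace(".gpt_neox.", ".")
    remapped.append(k)

print(remapped)
# ['model.h.0.attention.query_key_value.weight', 'model.embed_out.weight']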

View File

@@ -0,0 +1,291 @@
# Copyright © 2023-2024 Apple Inc.
import math
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .switch_layers import SwitchGLU
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
vocab_size: int
hidden_size: int
num_hidden_layers: int
intermediate_size: int
num_attention_heads: int
num_key_value_heads: int
attention_bias: bool
moe_topk: int
num_experts: int
num_shared_expert: int
use_mixed_mlp_moe: bool
use_qk_norm: bool
rms_norm_eps: float
rope_theta: float
use_cla: bool
cla_share_factor: 2
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
tie_word_embeddings: bool = False
def __post_init__(self):
if self.rope_scaling:
required_keys = {"factor", "type"}
if not all(key in self.rope_scaling for key in required_keys):
raise ValueError(f"rope_scaling must contain keys {required_keys}")
class DynamicNTKAlphaRoPE(nn.Module):
def __init__(
self,
dims: int,
base: float = 10000,
scaling_alpha: float = 1.0,
):
super().__init__()
self.dims = dims
base = base * scaling_alpha ** (dims / (dims - 2))
self._freqs = base ** (mx.arange(0, self.dims, 2) / self.dims)
def __call__(self, x, offset: int = 0):
return mx.fast.rope(
x,
self.dims,
traditional=False,
base=None,
scale=1.0,
offset=offset,
freqs=self._freqs,
)
class Attention(nn.Module):
def __init__(self, kv_proj: bool, args: ModelArgs):
super().__init__()
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
assert args.num_key_value_heads is not None
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
head_dim = args.hidden_size // n_heads
self.scale = head_dim**-0.5
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.attention_bias)
if kv_proj:
self.k_proj = nn.Linear(
dim, n_kv_heads * head_dim, bias=args.attention_bias
)
self.v_proj = nn.Linear(
dim, n_kv_heads * head_dim, bias=args.attention_bias
)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=args.attention_bias)
self.use_qk_norm = args.use_qk_norm
if self.use_qk_norm:
self.query_layernorm = nn.RMSNorm(head_dim, args.rms_norm_eps)
self.key_layernorm = nn.RMSNorm(head_dim, args.rms_norm_eps)
self.rope = DynamicNTKAlphaRoPE(
head_dim,
base=args.rope_theta,
scaling_alpha=args.rope_scaling["alpha"],
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
kv_states=None,
) -> mx.array:
B, L, D = x.shape
queries = self.q_proj(x)
if kv_states is None:
keys, values = self.k_proj(x), self.v_proj(x)
kv_states = keys, values
else:
keys, values = kv_states
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
offset = cache.offset if cache else 0
queries = self.rope(queries, offset=offset)
keys = self.rope(keys, offset=offset)
if self.use_qk_norm:
queries = self.query_layernorm(queries)
keys = self.key_layernorm(keys)
if cache is not None:
keys, values = cache.update_and_fetch(keys, values)
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output), kv_states
class MLP(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
def __call__(self, x) -> mx.array:
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
class Gate(nn.Module):
def __init__(self, dim, num_experts):
super().__init__()
self.wg = nn.Linear(dim, num_experts, bias=False)
def __call__(self, x) -> mx.array:
return self.wg(x)
class MoeBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.hidden_size
intermediate_size = args.intermediate_size
self.use_shared_mlp = args.use_mixed_mlp_moe
if args.use_mixed_mlp_moe:
self.shared_mlp = MLP(dim, intermediate_size * args.num_shared_expert)
self.num_experts = num_experts = args.num_experts
self.top_k = args.moe_topk
self.gate = Gate(dim, num_experts)
self.switch_mlp = SwitchGLU(dim, intermediate_size, num_experts)
def __call__(
self,
x: mx.array,
):
gates = self.gate(x)
gates = mx.softmax(gates, axis=-1, precise=True)
k = self.top_k
inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1)[..., :k])
scores = mx.take_along_axis(gates, inds, axis=-1)
y = self.switch_mlp(x, inds)
y = (y * scores[..., None]).sum(axis=-2)
if self.use_shared_mlp:
shared_expert_output = self.shared_mlp(x)
y = y + shared_expert_output
return y
class DecoderLayer(nn.Module):
def __init__(self, args: ModelArgs, kv_proj: bool):
super().__init__()
self.hidden_size = args.hidden_size
self.self_attn = Attention(kv_proj, args)
self.mlp = MoeBlock(args)
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.post_attention_layernorm = nn.RMSNorm(
args.hidden_size, eps=args.rms_norm_eps
)
self.args = args
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
shared_kv_states: Optional[Tuple[mx.array, mx.array]] = None,
):
r, shared_kv_states = self.self_attn(
self.input_layernorm(x), mask, cache, shared_kv_states
)
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
out = h + r
return out, shared_kv_states
class HunYuanModel(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
self.num_hidden_layers = args.num_hidden_layers
assert self.vocab_size > 0
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [
DecoderLayer(args=args, kv_proj=(i % args.cla_share_factor) == 0)
for i in range(args.num_hidden_layers)
]
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
def __call__(
self,
inputs: mx.array,
cache=None,
):
h = self.embed_tokens(inputs)
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
for i, (layer, c) in enumerate(zip(self.layers, cache)):
if i % self.args.cla_share_factor == 0:
shared_kv_states = None
h, shared_kv_states = layer(h, mask, c, shared_kv_states)
return self.norm(h)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.model = HunYuanModel(args)
def __call__(
self,
inputs: mx.array,
cache=None,
):
out = self.model(inputs, cache)
return self.model.embed_tokens.as_linear(out)
def sanitize(self, weights):
if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights:
return weights
for l in range(self.args.num_hidden_layers):
prefix = f"model.layers.{l}"
for n in ["up_proj", "down_proj", "gate_proj"]:
for k in ["weight", "scales", "biases"]:
if f"{prefix}.mlp.experts.0.{n}.{k}" in weights:
to_join = [
weights.pop(f"{prefix}.mlp.experts.{e}.{n}.{k}")
for e in range(self.args.num_experts)
]
weights[f"{prefix}.mlp.switch_mlp.{n}.{k}"] = mx.stack(to_join)
return weights
@property
def layers(self):
return self.model.layers
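HunYuanModel above implements cross-layer attention (CLA): only every cla_share_factor-th layer projects fresh keys and values, and the layers in between reuse them. A small sketch, assuming a hypothetical six-layer model with cla_share_factor=2:

num_hidden_layers, cla_share_factor = 6, 2

shared_kv_states = None
for i in range(num_hidden_layers):
    if i % cla_share_factor == 0:
        shared_kv_states = f"K/V from layer {i}"        # kv_proj=True: k_proj/v_proj exist
        print(f"layer {i}: computes {shared_kv_states}")
    else:
        print(f"layer {i}: reuses {shared_kv_states}")  # kv_proj=False: no K/V projections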

View File

@@ -1,10 +1,12 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
from typing import Any, Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass
@@ -17,6 +19,7 @@ class ModelArgs(BaseModelArgs):
rms_norm_eps: float
vocab_size: int
bias: bool = True
max_position_embeddings: int = 32768
num_key_value_heads: int = None
rope_theta: float = 10000
rope_traditional: bool = False
@@ -32,8 +35,50 @@ class ModelArgs(BaseModelArgs):
if not all(key in self.rope_scaling for key in required_keys):
raise ValueError(f"rope_scaling must contain keys {required_keys}")
if self.rope_scaling["type"] != "linear":
raise ValueError("rope_scaling 'type' currently only supports 'linear'")
if self.rope_scaling["type"] not in ["linear", "dynamic"]:
raise ValueError(
"rope_scaling 'type' currently only supports 'linear' or 'dynamic"
)
class DynamicNTKScalingRoPE(nn.Module):
"""Implements the rotary positional encoding with Dynamic NTK scaling."""
def __init__(
self,
dims: int,
max_position_embeddings: int = 2048,
traditional: bool = False,
base: float = 10000,
scale: float = 1.0,
):
super().__init__()
self.max_position_embeddings = max_position_embeddings
self.original_base = base
self.dims = dims
self.traditional = traditional
self.scale = scale
def extra_repr(self):
return f"{self.dims}, traditional={self.traditional}, max_position_embeddings={self.max_position_embeddings}, scaling_factor={self.scaling_factor}"
def __call__(self, x, offset: int = 0):
seq_len = x.shape[1] + offset
if seq_len > self.max_position_embeddings:
base = self.original_base * (
(self.scale * seq_len / self.max_position_embeddings) - (self.scale - 1)
) ** (self.dims / (self.dims - 2))
else:
base = self.original_base
return mx.fast.rope(
x,
self.dims,
traditional=self.traditional,
base=base,
scale=self.scale,
offset=offset,
)
class Attention(nn.Module):
@@ -56,10 +101,12 @@ class Attention(nn.Module):
rope_scale = (
1 / args.rope_scaling["factor"]
if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
else 1
else 2.0
)
self.rope = nn.RoPE(
self.rope = DynamicNTKScalingRoPE(
head_dim,
max_position_embeddings=args.max_position_embeddings,
traditional=args.rope_traditional,
base=args.rope_theta,
scale=rope_scale,
@@ -69,7 +116,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@@ -94,8 +141,8 @@ class Attention(nn.Module):
queries = self.rope(queries)
keys = self.rope(keys)
output = mx.fast.scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.wo(output)
@@ -124,7 +171,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.attention(self.attention_norm(x), mask, cache)
h = x + r
@@ -150,10 +197,7 @@ class InternLM2Model(nn.Module):
):
h = self.tok_embeddings(inputs)
mask = None
if h.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
mask = mask.astype(h.dtype)
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
@@ -185,14 +229,10 @@ class Model(nn.Module):
out = self.output(out)
return out
def sanitize(self, weights):
# Remove unused precomputed rotary freqs
return {k: v for k, v in weights.items() if "attention.rope.inv_freq" not in k}
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_key_value_heads
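DynamicNTKScalingRoPE above inflates the RoPE base once the sequence outgrows max_position_embeddings. A numeric sketch with hypothetical numbers (dims=128, base 10000, max_position_embeddings=32768):

import math

dims, original_base, scale, max_position_embeddings = 128, 10000.0, 1.0, 32768

for seq_len in (4096, 32768, 65536):
    if seq_len > max_position_embeddings:
        base = original_base * (
            (scale * seq_len / max_position_embeddings) - (scale - 1)
        ) ** (dims / (dims - 2))
    else:
        base = original_base
    print(seq_len, round(base, 1))
# The base stays at 10000 up to the original context, then grows to roughly 2x for a 2x overshoot.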

View File

@@ -1,10 +1,13 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
from typing import Any, Dict, Optional, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, KVCache, create_additive_causal_mask
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .rope_utils import initialize_rope
@dataclass
@@ -16,6 +19,8 @@ class ModelArgs(BaseModelArgs):
num_attention_heads: int
rms_norm_eps: float
vocab_size: int
head_dim: Optional[int] = None
max_position_embeddings: Optional[int] = None
num_key_value_heads: Optional[int] = None
attention_bias: bool = False
mlp_bias: bool = False
@@ -28,14 +33,6 @@ class ModelArgs(BaseModelArgs):
if self.num_key_value_heads is None:
self.num_key_value_heads = self.num_attention_heads
if self.rope_scaling:
required_keys = {"factor", "type"}
if not all(key in self.rope_scaling for key in required_keys):
raise ValueError(f"rope_scaling must contain keys {required_keys}")
if self.rope_scaling["type"] != "linear":
raise ValueError("rope_scaling 'type' currently only supports 'linear'")
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
@@ -45,7 +42,8 @@ class Attention(nn.Module):
self.n_heads = n_heads = args.num_attention_heads
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
head_dim = args.hidden_size // n_heads
self.head_dim = head_dim = args.head_dim or args.hidden_size // n_heads
self.scale = head_dim**-0.5
if hasattr(args, "attention_bias"):
attention_bias = args.attention_bias
@@ -57,23 +55,19 @@ class Attention(nn.Module):
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)
rope_scale = (
1 / args.rope_scaling["factor"]
if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
else 1
)
self.rope = nn.RoPE(
head_dim,
traditional=args.rope_traditional,
base=args.rope_theta,
scale=rope_scale,
self.rope = initialize_rope(
self.head_dim,
args.rope_theta,
args.rope_traditional,
args.rope_scaling,
args.max_position_embeddings,
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@@ -92,9 +86,10 @@ class Attention(nn.Module):
queries = self.rope(queries)
keys = self.rope(keys)
output = mx.fast.scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
@@ -135,7 +130,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
@@ -164,12 +159,7 @@ class LlamaModel(nn.Module):
):
h = self.embed_tokens(inputs)
mask = None
if h.shape[1] > 1:
mask = create_additive_causal_mask(
h.shape[1], cache[0].offset if cache is not None else 0
)
mask = mask.astype(h.dtype)
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
@@ -210,11 +200,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

llms/mlx_lm/models/mamba.py
View File

@@ -0,0 +1,228 @@
# Copyright © 2024 Apple Inc.
import math
from dataclasses import dataclass
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .cache import MambaCache
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
vocab_size: int
hidden_size: int
intermediate_size: int
state_size: int
num_hidden_layers: int
conv_kernel: int
use_bias: bool
use_conv_bias: bool
time_step_rank: int
tie_word_embeddings: bool = True
use_bcdt_rms: bool = False
mixer_rms_eps: float = 1e-6
def __post_init__(self):
if not hasattr(self, "hidden_size") and hasattr(self, "d_model"):
self.hidden_size = self.d_model
if not hasattr(self, "intermediate_size") and hasattr(self, "d_inner"):
self.intermediate_size = self.d_inner
if not hasattr(self, "state_size") and hasattr(self, "d_state"):
self.state_size = self.d_state
if not hasattr(self, "num_hidden_layers") and hasattr(self, "n_layer"):
self.num_hidden_layers = self.n_layer
if not hasattr(self, "num_hidden_layers") and hasattr(self, "n_layers"):
self.num_hidden_layers = self.n_layers
if not hasattr(self, "conv_kernel") and hasattr(self, "d_conv"):
self.conv_kernel = self.d_conv
if not hasattr(self, "use_bias") and hasattr(self, "bias"):
self.use_bias = self.bias
if not hasattr(self, "use_conv_bias") and hasattr(self, "conv_bias"):
self.use_conv_bias = self.conv_bias
if self.time_step_rank == "auto":
self.time_step_rank = math.ceil(self.hidden_size / 16)
if self.model_type == "falcon_mamba":
self.use_bcdt_rms = True
class DepthWiseConv1d(nn.Module):
def __init__(self, channels, kernel_size, bias=True, padding=0):
super().__init__()
self.channels = channels
self.kernel_size = kernel_size
self.padding = padding
self.weight = mx.random.normal((self.channels, kernel_size, 1))
self.bias = mx.zeros((channels,)) if bias else None
def __call__(self, x, cache=None):
B, L, C = x.shape
groups, K, _ = self.weight.shape
if cache is not None:
x = mx.concatenate([cache, x], axis=1)
else:
x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)])
y = mx.conv_general(x, self.weight, groups=groups)
if self.bias is not None:
y = y + self.bias
return y, x[:, -K + 1 :, :]
class MambaBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.hidden_size = args.hidden_size
self.ssm_state_size = args.state_size
self.conv_kernel_size = args.conv_kernel
self.intermediate_size = args.intermediate_size
self.time_step_rank = int(args.time_step_rank)
self.use_conv_bias = args.use_conv_bias
self.use_bcdt_rms = args.use_bcdt_rms
if self.use_bcdt_rms:
self.mixer_norm = lambda x: mx.fast.rms_norm(
x, mx.ones(x.shape[-1], x.dtype), eps=args.mixer_rms_eps
)
self.in_proj = nn.Linear(
self.hidden_size, self.intermediate_size * 2, bias=args.use_bias
)
self.conv1d = DepthWiseConv1d(
channels=self.intermediate_size,
kernel_size=self.conv_kernel_size,
bias=self.use_conv_bias,
padding=self.conv_kernel_size - 1,
)
self.x_proj = nn.Linear(
self.intermediate_size,
self.time_step_rank + 2 * self.ssm_state_size,
bias=False,
)
self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True)
A = mx.repeat(
mx.arange(1.0, self.ssm_state_size + 1.0).reshape([1, self.ssm_state_size]),
repeats=self.intermediate_size,
axis=0,
)
self.A_log = mx.log(A)
self.D = mx.ones([self.intermediate_size])
self.out_proj = nn.Linear(
self.intermediate_size, self.hidden_size, bias=args.use_bias
)
def ssm_step(self, x, state=None):
A = -mx.exp(self.A_log)
D = self.D
deltaBC = self.x_proj(x)
delta, B, C = mx.split(
deltaBC,
indices_or_sections=[
self.time_step_rank,
self.time_step_rank + self.ssm_state_size,
],
axis=-1,
)
if self.use_bcdt_rms:
delta, B, C = map(self.mixer_norm, (delta, B, C))
delta = nn.softplus(self.dt_proj(delta))
new_state = mx.expand_dims(delta * x, -1) * mx.expand_dims(B, 1)
if state is not None:
new_state += state * mx.exp(mx.expand_dims(delta, -1) * A)
y = (new_state @ mx.expand_dims(C, -1)).squeeze(2)
y = y + D * x
return y, new_state
def __call__(self, x, cache):
B, T, D = x.shape
if cache is None:
cache = [None, None]
outputs = []
for t in range(T):
xt = x[:, t, :]
xz = self.in_proj(xt)
x_t, z_t = xz.split(indices_or_sections=2, axis=1)
conv_out, cache[0] = self.conv1d(mx.expand_dims(x_t, 1), cache[0])
x_t = conv_out.squeeze(1)
x_t = nn.silu(x_t)
y_t, cache[1] = self.ssm_step(x_t, cache[1])
z_t = nn.silu(z_t)
output_t = y_t * z_t
output_t = self.out_proj(output_t)
outputs.append(output_t)
output = mx.stack(outputs, axis=1)
return output
class ResidualBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.mixer = MambaBlock(args)
self.norm = nn.RMSNorm(args.hidden_size)
def __call__(self, x: mx.array, cache):
return self.mixer(self.norm(x), cache) + x
class Mamba(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)]
self.norm_f = nn.RMSNorm(args.hidden_size)
def __call__(self, x: mx.array, cache):
x = self.embeddings(x)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
x = layer(x, c)
return self.norm_f(x)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.backbone = Mamba(args)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(self, inputs: mx.array, cache=None):
B, T = inputs.shape
x = self.backbone(inputs, cache)
if self.args.tie_word_embeddings:
logits = self.backbone.embeddings.as_linear(x)
else:
logits = self.lm_head(x)
return logits
def sanitize(self, weights):
for k, v in weights.items():
if "conv1d.weight" in k and v.shape[-1] != 1:
weights[k] = v.moveaxis(2, 1)
return weights
def make_cache(self):
return [MambaCache() for _ in range(len(self.layers))]
@property
def layers(self):
return self.backbone.layers
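The sketch below is an illustrative, self-contained rendering of the single-token state-space update performed by ssm_step above; the toy sizes and variable names are assumptions for demonstration, not repository code.

import mlx.core as mx

B, C, N = 1, 4, 8  # batch, channels (intermediate_size), SSM state size
A = -mx.repeat(mx.arange(1.0, N + 1.0).reshape(1, N), repeats=C, axis=0)  # (C, N)

def ssm_single_step(x_t, B_t, C_t, delta, state):
    # state' = exp(delta * A) * state + (delta * x_t) outer B_t
    new_state = mx.expand_dims(delta * x_t, -1) * mx.expand_dims(B_t, 1)
    new_state += state * mx.exp(mx.expand_dims(delta, -1) * A)
    # readout plus the skip term (D is initialized to ones in MambaBlock)
    y = (new_state @ mx.expand_dims(C_t, -1)).squeeze(2) + x_t
    return y, new_state

state = mx.zeros((B, C, N))
x_t, delta = mx.random.normal((B, C)), mx.random.uniform(shape=(B, C))
B_t, C_t = mx.random.normal((B, N)), mx.random.normal((B, N))
y, state = ssm_single_step(x_t, B_t, C_t, delta, state)
print(y.shape, state.shape)  # (1, 4) (1, 4, 8)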


@@ -1,11 +1,13 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
from typing import Any, Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from .base import BaseModelArgs
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass
@@ -83,7 +85,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
):
B, L, _ = x.shape
@@ -103,8 +105,8 @@ class Attention(nn.Module):
queries = self.rope(queries)
keys = self.rope(keys)
attn_output = mx.fast.scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask
attn_output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
attn_output = attn_output.transpose(0, 2, 1, 3).reshape(B, L, -1)
@@ -133,7 +135,7 @@ class DecoderLayer(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r * (self.scale_depth / np.sqrt(self.num_hidden_layers))
@@ -160,10 +162,7 @@ class MiniCPMModel(nn.Module):
):
h = self.embed_tokens(inputs) * self.args.scale_emb
mask = None
if h.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
mask = mask.astype(h.dtype)
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
@@ -206,11 +205,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_key_value_heads
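Every hunk in this batch makes the same substitution: the per-model nn.MultiHeadAttention.create_additive_causal_mask block is replaced by the shared create_attention_mask(h, cache) helper, and the cache is threaded through scaled_dot_product_attention from .base. The standalone snippet below is only a rough illustration of what a cache-aware causal mask must account for; it is not the repository helper, and the offset handling shown is an assumption.

import mlx.core as mx

def causal_mask_with_offset(T: int, offset: int, dtype=mx.float32):
    # New queries may attend to all `offset` cached tokens plus a causal
    # window over the T new positions (illustrative only).
    rows = mx.arange(offset, offset + T)[:, None]
    cols = mx.arange(offset + T)[None, :]
    return mx.where(cols <= rows, 0.0, -mx.inf).astype(dtype)

print(causal_mask_with_offset(4, 0).shape)  # (4, 4): prompt processing, no cache
print(causal_mask_with_offset(1, 7).shape)  # (1, 8): one new token, 7 cached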


@@ -1,11 +1,13 @@
# Copyright © 2023-2024 Apple Inc.
import math
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
from typing import Any, Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .switch_layers import SwitchGLU
@@ -64,7 +66,7 @@ class MixtralAttention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@@ -85,8 +87,8 @@ class MixtralAttention(nn.Module):
queries = self.rope(queries)
keys = self.rope(keys)
output = mx.fast.scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
@@ -136,7 +138,7 @@ class MixtralDecoderLayer(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
@@ -164,11 +166,7 @@ class MixtralModel(nn.Module):
):
h = self.embed_tokens(inputs)
mask = None
T = h.shape[1]
if T > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
mask = mask.astype(h.dtype)
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
@@ -217,11 +215,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_key_value_heads


@@ -0,0 +1,217 @@
# Copyright © 2024 Apple Inc.
from dataclasses import dataclass
from functools import partial
from typing import Any, Dict, Optional, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int
hidden_act: str
num_hidden_layers: int
intermediate_size: int
num_attention_heads: int
norm_eps: float
vocab_size: int
num_key_value_heads: int
head_dim: Optional[int] = None
max_position_embeddings: Optional[int] = None
attention_bias: bool = False
mlp_bias: bool = False
partial_rotary_factor: float = 0.5
rope_theta: float = 10000.0
rope_traditional: bool = False
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
tie_word_embeddings: bool = False
def __post_init__(self):
if self.rope_scaling:
if not "factor" in self.rope_scaling:
raise ValueError(f"rope_scaling must contain 'factor'")
rope_type = self.rope_scaling.get("type") or self.rope_scaling.get(
"rope_type"
)
if rope_type is None:
raise ValueError(
f"rope_scaling must contain either 'type' or 'rope_type'"
)
if rope_type not in ["linear"]:
raise ValueError("rope_scaling 'type' currently only supports 'linear'")
@partial(mx.compile, shapeless=True)
def relu_squared(x):
return nn.relu(x).square()
class NemotronLayerNorm1P(nn.LayerNorm):
def __call__(self, x):
weight = self.weight + 1 if "weight" in self else None
bias = self.bias if "bias" in self else None
return mx.fast.layer_norm(x, weight, bias, self.eps)
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
self.head_dim = head_dim = args.head_dim or args.hidden_size // n_heads
self.partial_rotary_factor = args.partial_rotary_factor
self.scale = head_dim**-0.5
if hasattr(args, "attention_bias"):
attention_bias = args.attention_bias
else:
attention_bias = False
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias)
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)
rope_scale = 1.0
if args.rope_scaling and args.rope_scaling["type"] == "linear":
assert isinstance(args.rope_scaling["factor"], float)
rope_scale = 1 / args.rope_scaling["factor"]
self.rope = nn.RoPE(
int(self.partial_rotary_factor * self.head_dim),
base=args.rope_theta,
scale=rope_scale,
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, _ = x.shape
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
class MLP(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.hidden_size
hidden_dim = args.intermediate_size
mlp_bias = args.mlp_bias
self.down_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias)
self.up_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
def __call__(self, x) -> mx.array:
return self.down_proj(relu_squared(self.up_proj(x)))
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.num_attention_heads = args.num_attention_heads
self.hidden_size = args.hidden_size
self.self_attn = Attention(args)
self.mlp = MLP(args)
self.input_layernorm = NemotronLayerNorm1P(args.hidden_size, eps=args.norm_eps)
self.post_attention_layernorm = NemotronLayerNorm1P(
args.hidden_size, eps=args.norm_eps
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
out = h + r
return out
class NemotronModel(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
self.num_hidden_layers = args.num_hidden_layers
assert self.vocab_size > 0
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
]
self.norm = NemotronLayerNorm1P(args.hidden_size, eps=args.norm_eps)
def __call__(
self,
inputs: mx.array,
cache=None,
):
h = self.embed_tokens(inputs)
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
h = layer(h, mask, cache=c)
return self.norm(h)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.model = NemotronModel(args)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
cache=None,
):
out = self.model(inputs, cache)
if self.args.tie_word_embeddings:
out = self.model.embed_tokens.as_linear(out)
else:
out = self.lm_head(out)
return out
@property
def layers(self):
return self.model.layers
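As a quick numeric check of the two Nemotron-specific pieces defined above, the snippet below evaluates the squared-ReLU activation and shows that the "1P" layer norm, which adds 1 to the learned weight before calling mx.fast.layer_norm, behaves like a plain layer norm at initialization. Toy values only; not repository code.

import mlx.core as mx
import mlx.nn as nn

x = mx.array([[-1.0, 0.5, 2.0]])
print(nn.relu(x).square())                    # [[0.0, 0.25, 4.0]] -- relu_squared

w = mx.zeros((3,))                            # a freshly initialized LayerNorm weight
y = mx.fast.layer_norm(x, w + 1, None, 1e-5)  # "1P": weight + 1, no bias
print(y)                                      # same as layer_norm with unit weight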


@@ -1,17 +1,19 @@
# Copyright © 2023-2024 Apple Inc.
import sys
from dataclasses import dataclass
from sys import exit
from typing import Optional, Tuple
from typing import Any, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .base import BaseModelArgs, create_attention_mask
try:
import hf_olmo
except ImportError:
print("To run olmo install ai2-olmo: pip install ai2-olmo")
exit(1)
sys.exit(1)
@dataclass
@@ -66,7 +68,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@@ -96,7 +98,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.attend(self.att_norm(x), mask, cache)
h = x + r
@@ -126,10 +128,7 @@ class Transformer(nn.Module):
):
h = self.wte(inputs)
mask = None
if h.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
mask = mask.astype(h.dtype)
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.blocks)
@@ -175,11 +174,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.transformer.blocks
@property
def head_dim(self):
return self.args.d_model // self.args.n_heads
@property
def n_kv_heads(self):
return self.args.n_heads

llms/mlx_lm/models/olmo2.py (new file, 209 lines)

@@ -0,0 +1,209 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Any, Dict, Optional, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .rope_utils import initialize_rope
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int
num_hidden_layers: int
intermediate_size: int
num_attention_heads: int
rms_norm_eps: float
vocab_size: int
head_dim: Optional[int] = None
max_position_embeddings: Optional[int] = None
num_key_value_heads: Optional[int] = None
attention_bias: bool = False
mlp_bias: bool = False
rope_theta: float = 10000
rope_traditional: bool = False
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
tie_word_embeddings: bool = True
def __post_init__(self):
if self.num_key_value_heads is None:
self.num_key_value_heads = self.num_attention_heads
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
self.head_dim = head_dim = args.head_dim or args.hidden_size // n_heads
self.scale = head_dim**-0.5
if hasattr(args, "attention_bias"):
attention_bias = args.attention_bias
else:
attention_bias = False
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias)
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)
self.rope = initialize_rope(
self.head_dim,
args.rope_theta,
args.rope_traditional,
args.rope_scaling,
args.max_position_embeddings,
)
self.q_norm = nn.RMSNorm(n_heads * head_dim, args.rms_norm_eps)
self.k_norm = nn.RMSNorm(n_kv_heads * head_dim, args.rms_norm_eps)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
queries = self.q_norm(queries)
keys = self.k_norm(keys)
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
class MLP(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.hidden_size
hidden_dim = args.intermediate_size
if hasattr(args, "mlp_bias"):
mlp_bias = args.mlp_bias
else:
mlp_bias = False
self.gate_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
self.down_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias)
self.up_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
def __call__(self, x) -> mx.array:
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.num_attention_heads = args.num_attention_heads
self.hidden_size = args.hidden_size
self.self_attn = Attention(args)
self.mlp = MLP(args)
self.post_attention_layernorm = nn.RMSNorm(
args.hidden_size, eps=args.rms_norm_eps
)
self.post_feedforward_layernorm = nn.RMSNorm(
args.hidden_size, eps=args.rms_norm_eps
)
self.args = args
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.post_attention_layernorm(self.self_attn(x, mask, cache))
h = x + r
r = self.post_feedforward_layernorm(self.mlp(h))
out = h + r
return out
class LlamaModel(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
self.num_hidden_layers = args.num_hidden_layers
assert self.vocab_size > 0
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
]
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
def __call__(
self,
inputs: mx.array,
cache=None,
):
h = self.embed_tokens(inputs)
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
h = layer(h, mask, cache=c)
return self.norm(h)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.model = LlamaModel(args)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
cache=None,
):
out = self.model(inputs, cache)
if self.args.tie_word_embeddings:
out = self.model.embed_tokens.as_linear(out)
else:
out = self.lm_head(out)
return out
def sanitize(self, weights):
# Remove unused precomputed rotary freqs
return {
k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
}
@property
def layers(self):
return self.model.layers
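The OLMo2 block above norms the output of each sub-block before the residual add (post-norm) rather than its input, and additionally applies RMSNorm to the projected queries and keys. A minimal sketch of that residual wiring, with placeholder sub-blocks standing in for the real attention and MLP, is shown below.

import mlx.core as mx
import mlx.nn as nn

dim = 8
post_attn_norm = nn.RMSNorm(dim)
post_ffn_norm = nn.RMSNorm(dim)
attn = lambda x: 0.1 * x   # placeholder for self_attn(x, mask, cache)
mlp = lambda x: 0.2 * x    # placeholder for the gated MLP

x = mx.random.normal((1, 4, dim))
h = x + post_attn_norm(attn(x))   # norm the attention *output*, then add
out = h + post_ffn_norm(mlp(h))   # norm the MLP *output*, then add
print(out.shape)                  # (1, 4, 8)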


@@ -1,10 +1,12 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass
@@ -78,7 +80,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@@ -105,8 +107,8 @@ class Attention(nn.Module):
queries = self.rope(queries)
keys = self.rope(keys)
output = mx.fast.scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
@@ -150,7 +152,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.attn(self.attn_norm(x), mask, cache)
h = x + r
@@ -180,10 +182,7 @@ class OpenELMModel(nn.Module):
):
h = self.token_embeddings(inputs)
mask = None
if h.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
mask = mask.astype(h.dtype)
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
@@ -219,11 +218,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.transformer.layers
@property
def head_dim(self):
return self.args.head_dim
@property
def n_kv_heads(self):
return self.args.num_kv_heads


@@ -1,182 +0,0 @@
from dataclasses import dataclass
from typing import Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_additive_causal_mask
@dataclass
class ParamsArgs(BaseModelArgs):
dim: int
ffn_type: str
n_heads: int
n_layers: int
norm_eps: float
positional_embedding_type: str
post_embed_norm: bool
qk_norm: bool
vocab_size: int
weight_tying: bool
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
params_args_dict: ParamsArgs
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.dim = args.dim
self.n_heads = args.n_heads
self.head_dim = self.dim // self.n_heads
self.qk_norm = args.qk_norm
self.scale = self.head_dim**-0.5
self.in_proj = nn.Linear(self.dim, 3 * self.dim, bias=False)
self.out_proj = nn.Linear(self.dim, self.dim, bias=False)
if self.qk_norm:
self.q_norm = nn.LayerNorm(args.dim, eps=args.norm_eps, bias=False)
self.k_norm = nn.LayerNorm(args.dim, eps=args.norm_eps, bias=False)
self.rope = nn.RoPE(
self.head_dim,
traditional=False,
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
B, L, D = x.shape
queries, keys, values = self.in_proj(x).split(3, axis=-1)
if self.qk_norm:
queries = self.q_norm(queries)
keys = self.q_norm(keys)
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
output = mx.fast.scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.out_proj(output)
class MLP(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
# https://github.com/mlfoundations/open_lm/blob/c65b43042ff31c0fe26f930decf1ccab1b03ab4b/open_lm/model.py#L254C2-L254C3
hidden_dim = 256 * ((int(2 * 4 * args.dim / 3) + 256 - 1) // 256)
self.w12 = nn.Linear(args.dim, 2 * hidden_dim, bias=False)
self.w3 = nn.Linear(hidden_dim, args.dim, bias=False)
def __call__(self, x) -> mx.array:
gate, x = self.w12(x).split(2, axis=-1)
return self.w3(nn.silu(gate) * x)
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.attention = Attention(args)
self.feed_forward = MLP(args)
self.ffn_norm = nn.LayerNorm(args.dim, eps=args.norm_eps, bias=False)
self.attention_norm = nn.LayerNorm(args.dim, eps=args.norm_eps, bias=False)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
r = self.attention(self.attention_norm(x), mask, cache)
h = x + r
r = self.feed_forward(self.ffn_norm(h))
out = h + r
return out
class OpenLM(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
self.layers = [TransformerBlock(args=args) for _ in range(args.n_layers)]
self.norm = nn.LayerNorm(args.dim, eps=args.norm_eps, bias=False)
self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
cache=None,
):
_, L = inputs.shape
h = self.tok_embeddings(inputs)
mask = None
if h.shape[1] > 1:
mask = create_additive_causal_mask(
h.shape[1], cache[0].offset if cache is not None else 0
)
mask = mask.astype(h.dtype)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
h = layer(h, mask, cache=c)
return self.output(self.norm(h))
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
args.params_args_dict = ParamsArgs.from_dict(args.params_args_dict)
self.args = args.params_args_dict
self.model_type = args.model_type
self.model = OpenLM(self.args)
def __call__(
self,
inputs: mx.array,
cache=None,
):
out = self.model(inputs, cache)
return out
def sanitize(self, weights):
# Remove unused precomputed rotary freqs
return {k: v for k, v in weights.items() if "inv_freq" not in k}
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.dim // self.args.n_heads
@property
def n_kv_heads(self):
return self.args.n_heads


@@ -1,3 +1,5 @@
# Copyright © 2023-2024 Apple Inc.
import math
from dataclasses import dataclass
from typing import Tuple
@@ -5,7 +7,7 @@ from typing import Tuple
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass
@@ -91,8 +93,13 @@ class PhiAttention(nn.Module):
keys = self.rope(keys)
scale = math.sqrt(1 / queries.shape[-1])
output = mx.fast.scaled_dot_product_attention(
queries.astype(mx.float32), keys, values, scale=scale, mask=mask
output = scaled_dot_product_attention(
queries.astype(mx.float32),
keys,
values,
cache=cache,
scale=scale,
mask=mask,
).astype(values.dtype)
output = output.moveaxis(2, 1).reshape(B, L, -1)
@@ -138,14 +145,12 @@ class PhiModel(nn.Module):
def __call__(self, x, cache):
x = self.embed_tokens(x)
mask = create_attention_mask(x, cache)
if cache is None:
cache = [None] * len(self.layers)
mask = None
if x.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
mask = mask.astype(x.dtype)
for layer, c in zip(self.layers, cache):
x = layer(x, mask, c)
return self.final_layernorm(x)
@@ -162,19 +167,11 @@ class Model(nn.Module):
def __call__(
self,
x: mx.array,
cache: mx.array = None,
) -> Tuple[mx.array, mx.array]:
cache=None,
) -> mx.array:
y = self.model(x, cache)
return self.lm_head(y)
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_key_value_heads


@@ -1,10 +1,12 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .su_rope import SuScaledRotaryEmbedding
@@ -17,10 +19,10 @@ class ModelArgs(BaseModelArgs):
num_attention_heads: int
rms_norm_eps: float
vocab_size: int
num_key_value_heads: int = None
num_key_value_heads: Optional[int] = None
rope_theta: float = 10000
rope_traditional: bool = False
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
rope_scaling: Optional[Dict[str, Union[float, List[float]]]] = None
max_position_embeddings: int = 131072
original_max_position_embeddings: int = 4096
@@ -33,9 +35,9 @@ class ModelArgs(BaseModelArgs):
if not all(key in self.rope_scaling for key in required_keys):
raise ValueError(f"rope_scaling must contain keys {required_keys}")
if self.rope_scaling["type"] not in ["su", "linear"]:
if self.rope_scaling["type"] not in ["longrope", "su", "linear"]:
print(
"[WARNING] rope_scaling 'type' currently only supports 'linear' and 'su'; setting rope scaling to false."
"[WARNING] rope_scaling 'type' currently only supports 'linear', 'su', and 'longrope'; setting rope scaling to false."
)
self.rope_scaling = None
@@ -46,6 +48,7 @@ class Attention(nn.Module):
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
assert args.num_key_value_heads is not None
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
self.num_hidden_layers = args.num_hidden_layers
@@ -56,20 +59,19 @@ class Attention(nn.Module):
self.qkv_proj = nn.Linear(dim, op_size, bias=False)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
rope_scale = 1.0
if args.rope_scaling and args.rope_scaling["type"] == "su":
if args.rope_scaling and args.rope_scaling["type"] in ["longrope", "su"]:
self.rope = SuScaledRotaryEmbedding(
head_dim,
traditional=False,
base=args.rope_theta,
scale=rope_scale,
max_position_embeddings=args.max_position_embeddings,
original_max_position_embeddings=args.original_max_position_embeddings,
short_factor=args.rope_scaling["short_factor"],
long_factor=args.rope_scaling["long_factor"],
)
else:
rope_scale = 1.0
if args.rope_scaling and args.rope_scaling["type"] == "linear":
assert isinstance(args.rope_scaling["factor"], float)
rope_scale = 1 / args.rope_scaling["factor"]
self.rope = nn.RoPE(
head_dim,
@@ -82,7 +84,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@@ -105,8 +107,8 @@ class Attention(nn.Module):
queries = self.rope(queries)
keys = self.rope(keys)
output = mx.fast.scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
@@ -141,7 +143,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
@@ -170,10 +172,7 @@ class Phi3Model(nn.Module):
):
h = self.embed_tokens(inputs)
mask = None
if h.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
mask = mask.astype(h.dtype)
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
@@ -203,11 +202,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_key_value_heads


@@ -1,11 +1,14 @@
# Copyright © 2023-2024 Apple Inc.
import math
from dataclasses import dataclass
from functools import partial
from typing import Dict, Optional, Tuple, Union
from typing import Any, Optional
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass
@@ -19,14 +22,14 @@ class ModelArgs(BaseModelArgs):
num_attention_heads: int
layer_norm_epsilon: float
vocab_size: int
num_key_value_heads: int = None
num_key_value_heads: int
mup_attn_multiplier: float = 1.0
mup_use_scaling: bool = True
mup_embedding_multiplier: float = 10.0
mup_width_multiplier: float = 8.0
rope_embedding_base: float = 1000000
rope_position_scale: float = 1.0
blocksparse_block_size: int = (64,)
blocksparse_block_size: int = 64
blocksparse_num_local_blocks: int = 16
blocksparse_vert_stride: int = 8
@@ -157,7 +160,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@@ -185,8 +188,8 @@ class Attention(nn.Module):
queries, keys, values, scale=self.scale, mask=mask
)
else:
output = mx.fast.scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.dense(output)
@@ -226,7 +229,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
@@ -261,10 +264,7 @@ class Phi3Model(nn.Module):
if self.mup_embedding_multiplier:
h = self.mup_embedding_multiplier * h
mask = None
if h.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
mask = mask.astype(h.dtype)
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
@@ -303,16 +303,8 @@ class Model(nn.Module):
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
def sanitize(self, weights):
# Remove unused precomputed rotary freqs
return {
k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
}
@property
def n_kv_heads(self):
return self.args.num_key_value_heads


@@ -0,0 +1,211 @@
# Copyright © 2024 Apple Inc.
import math
from dataclasses import dataclass
from typing import Dict, List, Optional, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .su_rope import SuScaledRotaryEmbedding
from .switch_layers import SwitchGLU
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str = "phimoe"
vocab_size: int = 32064
hidden_size: int = 4096
intermediate_size: int = 6400
num_hidden_layers: int = 32
num_attention_heads: int = 32
num_key_value_heads: int = 8
max_position_embeddings: int = 131072
original_max_position_embeddings: int = 4096
rms_norm_eps: float = 1e-6
rope_scaling: Dict[str, Union[float, List[float]]] = None
num_local_experts: int = 16
num_experts_per_tok: int = 2
rope_theta: float = 10000.0
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
head_dim = args.hidden_size // n_heads
self.scale = head_dim**-0.5
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=True)
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=True)
self.rope = SuScaledRotaryEmbedding(
head_dim,
base=args.rope_theta,
max_position_embeddings=args.max_position_embeddings,
original_max_position_embeddings=args.original_max_position_embeddings,
short_factor=args.rope_scaling["short_factor"],
long_factor=args.rope_scaling["long_factor"],
short_mscale=args.rope_scaling["short_mscale"],
long_mscale=args.rope_scaling["long_mscale"],
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache=None,
) -> mx.array:
B, L, D = x.shape
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
class PhiMoESparseMoeBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.hidden_dim = args.hidden_size
self.ffn_dim = args.intermediate_size
self.num_experts = args.num_local_experts
self.top_k = args.num_experts_per_tok
self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
self.switch_mlp = SwitchGLU(self.hidden_dim, self.ffn_dim, self.num_experts)
def __call__(self, x: mx.array) -> mx.array:
gates = self.gate(x)
k = self.top_k
inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1)[..., :k])
scores = mx.take_along_axis(gates, inds, axis=-1)
scores = mx.softmax(scores, axis=-1, precise=True)
y = self.switch_mlp(x, inds)
y = (y * scores[..., None]).sum(axis=-2)
return y
class PhiMoEDecoderLayer(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.hidden_size = args.hidden_size
self.self_attn = Attention(args)
self.block_sparse_moe = PhiMoESparseMoeBlock(args)
self.input_layernorm = nn.LayerNorm(args.hidden_size, eps=args.rms_norm_eps)
self.post_attention_layernorm = nn.LayerNorm(
args.hidden_size, eps=args.rms_norm_eps
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache=None,
) -> mx.array:
residual = x
hidden_states = self.input_layernorm(x)
hidden_states = self.self_attn(hidden_states, mask=mask, cache=cache)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.block_sparse_moe(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class PhiMoEModel(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [PhiMoEDecoderLayer(args) for _ in range(args.num_hidden_layers)]
self.norm = nn.LayerNorm(args.hidden_size, eps=args.rms_norm_eps)
def __call__(
self,
inputs: mx.array,
cache=None,
) -> mx.array:
h = self.embed_tokens(inputs)
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)
return self.norm(h)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.model_type = args.model_type
self.args = args
self.model = PhiMoEModel(args)
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=True)
def __call__(
self,
inputs: mx.array,
cache=None,
):
out = self.model(inputs, cache)
return self.lm_head(out)
def sanitize(self, weights):
if "model.layers.0.block_sparse_moe.experts.0.w1.weight" not in weights:
return weights
for l in range(self.args.num_hidden_layers):
prefix = f"model.layers.{l}"
for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]:
for k in ["weight", "scales", "biases"]:
if f"{prefix}.block_sparse_moe.experts.0.{n}.{k}" in weights:
to_join = [
weights.pop(
f"{prefix}.block_sparse_moe.experts.{e}.{n}.{k}"
)
for e in range(self.args.num_local_experts)
]
weights[f"{prefix}.block_sparse_moe.switch_mlp.{m}.{k}"] = (
mx.stack(to_join)
)
return weights
@property
def layers(self):
return self.model.layers
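The routing math inside PhiMoESparseMoeBlock above reduces to: pick the top-k expert indices per token with argpartition, then softmax only over those k gate scores. The standalone snippet below reproduces just that step with toy sizes; SwitchGLU itself is not reproduced here.

import mlx.core as mx

tokens, num_experts, k = 5, 16, 2
gates = mx.random.normal((tokens, num_experts))

inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1)[..., :k])
scores = mx.softmax(mx.take_along_axis(gates, inds, axis=-1), axis=-1, precise=True)

print(inds.shape, scores.shape)   # (5, 2) (5, 2)
print(scores.sum(axis=-1))        # each token's expert weights sum to 1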


@@ -1,3 +1,5 @@
# Copyright © 2023-2024 Apple Inc.
import inspect
import math
from dataclasses import dataclass
@@ -6,6 +8,7 @@ from typing import Tuple
import mlx.core as mx
import mlx.nn as nn
from .base import create_attention_mask, scaled_dot_product_attention
from .switch_layers import SwitchMLP
@@ -68,8 +71,13 @@ class RoPEAttention(nn.Module):
# Finally perform the attention computation
scale = math.sqrt(1 / queries.shape[-1])
output = mx.fast.scaled_dot_product_attention(
queries.astype(mx.float32), keys, values, scale=scale, mask=mask
output = scaled_dot_product_attention(
queries.astype(mx.float32),
keys,
values,
cache=cache,
scale=scale,
mask=mask,
).astype(values.dtype)
output = output.moveaxis(2, 1).reshape(B, L, -1)
@@ -165,12 +173,9 @@ class Model(nn.Module):
self,
x: mx.array,
mask: mx.array = None,
cache: mx.array = None,
) -> Tuple[mx.array, mx.array]:
mask = None
if x.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
mask = mask.astype(x.dtype)
cache=None,
) -> mx.array:
mask = create_attention_mask(x, cache)
y = self.transformer(x, mask, cache)
return self.lm_head(y)
@@ -193,11 +198,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.transformer.h
@property
def head_dim(self):
return self.args.model_dim // self.args.num_heads
@property
def n_kv_heads(self):
return self.args.num_heads


@@ -1,11 +1,13 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Any, List, Optional, Tuple, Union
from typing import Any, Optional
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from .base import BaseModelArgs
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass
@@ -60,8 +62,8 @@ class Attention(nn.Module):
self,
hidden_states: mx.array,
attention_mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
cache: Optional[Any] = None,
) -> mx.array:
bsz, q_len, _ = hidden_states.shape
queries = self.q_proj(hidden_states)
@@ -87,10 +89,14 @@ class Attention(nn.Module):
queries = self.rotary_emb(queries)
keys = self.rotary_emb(keys)
output = mx.fast.scaled_dot_product_attention(
keys = mx.tile(keys, [1, self.config.n_shared_head, 1, 1])
values = mx.tile(values, [1, self.config.n_shared_head, 1, 1])
output = scaled_dot_product_attention(
queries,
keys,
values,
cache=cache,
scale=self.scale,
mask=attention_mask,
)
@@ -125,8 +131,8 @@ class PlamoDecoderLayer(nn.Module):
self,
hidden_states: mx.array,
attention_mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> Tuple[Any, ...]:
cache: Optional[Any] = None,
):
# from LlamaDecoder
residual = hidden_states
@@ -167,14 +173,11 @@ class PlamoModel(nn.Module):
def __call__(
self,
inputs: mx.array,
cache: Optional[List[Union[Tuple[mx.array, mx.array], None]]] = None,
) -> Tuple[mx.array, Optional[List[Union[Tuple[mx.array, mx.array], None]]]]:
cache: Optional[Any] = None,
) -> mx.array:
h = self.embed_tokens(inputs)
mask = None
if h.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
mask = mask.astype(self.embed_tokens.weight.dtype)
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None for _ in range(len(self.layers.layers))]
@@ -198,19 +201,11 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
cache: Optional[List[Tuple[mx.array, mx.array]]] = None,
) -> Tuple[mx.array, mx.array]:
cache: Optional[Any] = None,
) -> mx.array:
out = self.model(inputs, cache)
return self.lm_head(out)
@property
def layers(self):
return self.model.layers.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_attention_heads // self.args.n_shared_head


@@ -1,10 +1,11 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Tuple
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass
@@ -63,8 +64,8 @@ class Attention(nn.Module):
queries = self.rotary_emb(queries)
keys = self.rotary_emb(keys)
output = mx.fast.scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
@@ -122,11 +123,7 @@ class QwenModel(nn.Module):
def __call__(self, inputs, mask=None, cache=None):
x = self.wte(inputs)
mask = None
T = x.shape[1]
if T > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
mask = mask.astype(x.dtype)
mask = create_attention_mask(x, cache)
if cache is None:
cache = [None] * len(self.h)
@@ -151,19 +148,11 @@ class Model(nn.Module):
self,
x: mx.array,
mask: mx.array = None,
cache: mx.array = None,
) -> Tuple[mx.array, mx.array]:
cache=None,
) -> mx.array:
y = self.transformer(x, mask, cache)
return self.lm_head(y)
@property
def layers(self):
return self.transformer.h
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_attention_heads


@@ -1,10 +1,12 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
from typing import Any, Dict, Optional, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass
@@ -16,7 +18,7 @@ class ModelArgs(BaseModelArgs):
num_attention_heads: int
rms_norm_eps: float
vocab_size: int
num_key_value_heads: int = None
num_key_value_heads: Optional[int] = None
rope_theta: float = 1000000
rope_traditional: bool = False
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
@@ -41,6 +43,7 @@ class Attention(nn.Module):
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
assert args.num_key_value_heads is not None
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
head_dim = args.hidden_size // n_heads
@@ -67,7 +70,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@@ -86,8 +89,8 @@ class Attention(nn.Module):
queries = self.rope(queries)
keys = self.rope(keys)
output = mx.fast.scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
@@ -121,7 +124,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
@@ -150,10 +153,7 @@ class Qwen2Model(nn.Module):
):
h = self.embed_tokens(inputs)
mask = None
if h.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
mask = mask.astype(h.dtype)
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
@@ -196,11 +196,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_key_value_heads


@@ -1,11 +1,13 @@
# Copyright © 2023-2024 Apple Inc.
import math
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
from typing import Any, Dict, Optional, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .switch_layers import SwitchGLU
@@ -22,7 +24,7 @@ class ModelArgs(BaseModelArgs):
shared_expert_intermediate_size: int
rms_norm_eps: float
vocab_size: int
num_key_value_heads: int = None
num_key_value_heads: Optional[int] = None
rope_theta: float = 1000000
rope_traditional: bool = False
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
@@ -47,6 +49,7 @@ class Attention(nn.Module):
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
assert args.num_key_value_heads is not None
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
head_dim = args.hidden_size // n_heads
@@ -67,7 +70,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@@ -86,8 +89,8 @@ class Attention(nn.Module):
queries = self.rope(queries)
keys = self.rope(keys)
output = mx.fast.scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
@@ -159,7 +162,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
@@ -188,10 +191,7 @@ class Qwen2MoeModel(nn.Module):
):
h = self.embed_tokens(inputs)
mask = None
if h.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
mask = mask.astype(h.dtype)
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
@@ -236,11 +236,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_key_value_heads


@@ -0,0 +1,456 @@
# Copyright © 2023-2024 Apple Inc.
import math
from dataclasses import dataclass
from typing import List, Literal, Optional
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .cache import MambaCache, RotatingKVCache
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
attention_bias: bool
conv1d_width: int
hidden_size: int
intermediate_size: int
logits_soft_cap: float
num_attention_heads: int
num_hidden_layers: int
num_key_value_heads: int
rms_norm_eps: float
rope_theta: float
attention_window_size: int
vocab_size: int
embeddings_scale_by_sqrt_dim: bool = True
block_types: Optional[List[str]] = None
_block_types: Optional[List[str]] = None
def __post_init__(self):
# For some reason these have different names in 2B and 9B
if self.block_types is None:
self.block_types = self._block_types
class RMSNorm(nn.Module):
def __init__(self, dims: int, eps: float = 1e-5):
super().__init__()
self.weight = mx.ones((dims,))
self.eps = eps
def __call__(self, x):
return mx.fast.rms_norm(x, 1.0 + self.weight, self.eps)
def rnn_scan(x, a, h0):
assert x.ndim == 3
assert a.shape == x.shape[-a.ndim :]
assert a.dtype == x.dtype
if x.shape[1] == 1:
# Using scan in sampling mode.
if h0 is None:
return x, x[:, 0]
else:
y = a * h0[:, None] + x
return y, y[:, -1]
else:
# Using scan in linear mode.
if h0 is not None:
h_t = h0
else:
B, _, D = x.shape
h_t = mx.zeros((B, D), dtype=x.dtype)
y = mx.zeros_like(x)
for t in range(x.shape[1]):
h_t = a[:, t] * h_t + x[:, t]
y[:, t] = h_t
return y, h_t
class Conv1d(nn.Module):
def __init__(
self,
channels: int,
kernel_size: int,
):
super().__init__()
self.weight = mx.zeros((channels, kernel_size, 1))
self.bias = mx.zeros((channels,))
def __call__(self, x, cache=None):
B, L, C = x.shape
groups, K, _ = self.weight.shape
if cache is not None:
x = mx.concatenate([cache, x], axis=1)
else:
x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)])
y = mx.conv_general(x, self.weight, groups=groups)
y = y + self.bias
return y, x[:, -K + 1 :, :]
class RGLRU(nn.Module):
"""A Real-Gated Linear Recurrent Unit (RG-LRU) layer."""
def __init__(
self,
width: int,
num_heads: int,
):
super().__init__()
self.width = width
self.num_heads = num_heads
self.head_dim = self.width // self.num_heads
self.recurrent_param = mx.zeros((self.width,))
self.input_gate_weight = mx.zeros(
(self.num_heads, self.head_dim, self.head_dim),
)
self.input_gate_bias = mx.zeros((self.num_heads, self.head_dim))
self.recurrent_gate_weight = mx.zeros(
(self.num_heads, self.head_dim, self.head_dim),
)
self.recurrent_gate_bias = mx.zeros((self.num_heads, self.head_dim))
def __call__(
self,
x: mx.array,
cache=None,
):
B, L, _ = x.shape
def apply_block_linear(h, w, b):
h = h.reshape((B, L, self.num_heads, self.head_dim))
h = (h.swapaxes(1, 2) @ w).swapaxes(1, 2) + b
return mx.sigmoid(h.flatten(2, 3))
# Gates for x and a.
gate_x = apply_block_linear(x, self.input_gate_weight, self.input_gate_bias)
gate_a = apply_block_linear(
x, self.recurrent_gate_weight, self.recurrent_gate_bias
)
# Compute the parameter `A` of the recurrence.
log_a = -8.0 * gate_a * nn.softplus(self.recurrent_param)
a = mx.exp(log_a)
a_square = mx.exp(2 * log_a)
# Gate the input.
gated_x = x * gate_x
# Apply gamma normalization to the input.
multiplier = mx.sqrt(1 - a_square)
if cache is None:
multiplier[:, 0, :] = 1.0
normalized_x = gated_x * multiplier.astype(x.dtype)
y, last_h = rnn_scan(
x=normalized_x,
a=a,
h0=cache,
)
return y, last_h
class RecurrentBlock(nn.Module):
def __init__(
self,
width: int,
num_heads: int,
lru_width: int = None,
conv1d_temporal_width: int = 4,
):
super().__init__()
self.width = width
self.num_heads = num_heads
self.lru_width = lru_width or width
self.conv1d_temporal_width = conv1d_temporal_width
self.linear_y = nn.Linear(width, self.lru_width)
self.linear_x = nn.Linear(width, self.lru_width)
self.linear_out = nn.Linear(self.lru_width, width)
self.conv_1d = Conv1d(
channels=self.lru_width,
kernel_size=self.conv1d_temporal_width,
)
self.rg_lru = RGLRU(
width=self.lru_width,
num_heads=self.num_heads,
)
def __call__(
self,
x: mx.array,
cache=None,
mask=None,
):
# y branch.
y = self.linear_y(x)
y = nn.gelu_approx(y)
# x branch.
x = self.linear_x(x)
if cache is None:
cache = [None, None]
x, cache[0] = self.conv_1d(x=x, cache=cache[0])
x, cache[1] = self.rg_lru(x=x, cache=cache[1])
x = x * y
x = self.linear_out(x)
return x
class LocalAttentionBlock(nn.Module):
def __init__(
self,
width: int,
num_heads: int,
window_size: int,
):
super().__init__()
self.width = width
self.num_heads = num_heads
self.window_size = window_size
self.scale = (width // num_heads) ** (-0.5)
self.head_dim = self.width // self.num_heads
self.q_proj = nn.Linear(self.width, self.width, bias=False)
self.k_proj = nn.Linear(self.width, self.head_dim, bias=False)
self.v_proj = nn.Linear(self.width, self.head_dim, bias=False)
self.o_proj = nn.Linear(self.width, self.width, bias=True)
self.rope = nn.RoPE(
self.head_dim // 2,
traditional=False,
)
def __call__(
self,
x: mx.array,
cache=None,
mask=None,
):
B, L, D = x.shape
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
queries = queries.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, 1, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, 1, -1).transpose(0, 2, 1, 3)
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
class MLPBlock(nn.Module):
def __init__(self, width: int, expanded_width: int):
super().__init__()
self.up_proj = nn.Linear(width, expanded_width // 2)
self.gate_proj = nn.Linear(width, expanded_width // 2)
self.down_proj = nn.Linear(expanded_width // 2, width)
def __call__(self, x: mx.array):
gate = self.gate_proj(x)
x = self.up_proj(x)
return self.down_proj(nn.gelu_approx(gate) * x)
class ResidualBlock(nn.Module):
def __init__(
self,
width: int,
mlp_expanded_width: int,
num_heads: int,
attention_window_size: int,
temporal_block_type: str,
lru_width: Optional[int] = None,
conv1d_temporal_width: int = 4,
):
"""Initializes the residual block.
Args:
width: The width of the block.
mlp_expanded_width: The width of the expansion inside the MLP block.
num_heads: The number of heads for the Attention or the RG-LRU.
attention_window_size: The window size for the local attention block.
temporal_block_type: Either "recurrent" or "attention", specifying the
type of temporal block to use.
lru_width: The width of the RG-LRU if different from `width`.
conv1d_temporal_width: The width of the temporal convolution.
"""
super().__init__()
self.width = width
self.mlp_expanded_width = mlp_expanded_width
self.num_heads = num_heads
self.attention_window_size = attention_window_size
self.temporal_block_type = temporal_block_type
self.lru_width = lru_width
self.conv1d_temporal_width = conv1d_temporal_width
self.temporal_pre_norm = RMSNorm(width)
if self.temporal_block_type == "recurrent":
self.temporal_block = RecurrentBlock(
width=self.width,
num_heads=self.num_heads,
lru_width=self.lru_width,
conv1d_temporal_width=self.conv1d_temporal_width,
)
else:
self.temporal_block = LocalAttentionBlock(
width=self.width,
num_heads=self.num_heads,
window_size=self.attention_window_size,
)
self.channel_pre_norm = RMSNorm(width)
self.mlp_block = MLPBlock(
width=self.width,
expanded_width=self.mlp_expanded_width,
)
def __call__(
self,
x: mx.array,
cache=None,
mask=None,
):
raw_x = x
inputs_normalized = self.temporal_pre_norm(raw_x)
x = self.temporal_block(inputs_normalized, cache=cache, mask=mask)
residual = x + raw_x
x = self.channel_pre_norm(residual)
x = self.mlp_block(x)
x = x + residual
return x
class Griffin(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.embed_tokens = nn.Embedding(
config.vocab_size,
config.hidden_size,
)
self.scale_by_sqrt_dim = config.embeddings_scale_by_sqrt_dim
block_types = config.block_types
self.layers = [
ResidualBlock(
width=config.hidden_size,
mlp_expanded_width=config.intermediate_size,
num_heads=config.num_attention_heads,
attention_window_size=config.attention_window_size,
temporal_block_type=block_types[i % len(block_types)],
lru_width=None,
)
for i in range(config.num_hidden_layers)
]
self.final_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def __call__(
self,
tokens,
cache=None,
):
x = self.embed_tokens(tokens)
if self.scale_by_sqrt_dim:
x = x * math.sqrt(x.shape[-1])
if cache is None:
cache = [None] * len(self.layers)
for i, block in enumerate(self.layers):
if block.temporal_block_type != "recurrent":
mask_cache = [cache[i]]
mask = create_attention_mask(x, mask_cache)
for i, block in enumerate(self.layers):
x = block(x, mask=mask, cache=cache[i])
return self.final_norm(x)
class Model(nn.Module):
def __init__(self, config):
self.args = config
self.model = Griffin(config)
self.model_type = config.model_type
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
def __call__(self, tokens: mx.array, cache=None) -> mx.array:
"""
Args:
tokens: Sequence of input tokens.
"""
logits = self.model(tokens, cache=cache)
if "lm_head" in self:
logits = self.lm_head(logits)
else:
logits = self.model.embed_tokens.as_linear(logits)
c = self.args.logits_soft_cap
if c:
logits = mx.tanh(logits / c) * c
return logits
@property
def layers(self):
return self.model.layers
def sanitize(self, weights):
for k, v in weights.items():
if "conv_1d.weight" in k and v.shape[-1] != 1:
weights[k] = v.moveaxis(2, 1)
if "lm_head.weight" not in weights:
self.pop("lm_head")
return weights
def make_cache(self):
cache = []
for layer in self.layers:
if layer.temporal_block_type == "recurrent":
cache.append(MambaCache())
else:
cache.append(RotatingKVCache(max_size=self.args.attention_window_size))
return cache
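# Example (not part of the diff): how the layer pattern above maps onto the per-layer
# caches built by `make_cache` -- recurrent layers carry MambaCache-style state while
# attention layers get a RotatingKVCache bounded by the attention window. The
# `block_types` pattern and layer count below are illustrative, not taken from a config.
example_block_types = ["recurrent", "recurrent", "attention"]
example_num_layers = 6
example_layout = [
    (i, example_block_types[i % len(example_block_types)])
    for i in range(example_num_layers)
]
print(example_layout)
# [(0, 'recurrent'), (1, 'recurrent'), (2, 'attention'),
#  (3, 'recurrent'), (4, 'recurrent'), (5, 'attention')]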


@@ -0,0 +1,91 @@
# Copyright © 2023-2024 Apple Inc.
from typing import Optional
import mlx.core as mx
import mlx.nn as nn
class Llama3RoPE(nn.Module):
def __init__(
self,
dims: int,
max_position_embeddings: int = 2048,
traditional: bool = False,
base: float = 10000,
scaling_config: dict = None,
):
super().__init__()
self.dims = dims
self.max_position_embeddings = max_position_embeddings
self.traditional = traditional
factor = scaling_config["factor"]
low_freq_factor = scaling_config.get("low_freq_factor", 1.0)
high_freq_factor = scaling_config.get("high_freq_factor", 4.0)
old_context_len = scaling_config.get(
"original_max_position_embeddings",
8192,
)
low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
freqs = base ** (mx.arange(0, dims, 2) / dims)
wavelens = 2 * mx.pi * freqs
freqs = mx.where(wavelens > low_freq_wavelen, freqs * factor, freqs)
is_medium_freq = (wavelens > high_freq_wavelen) & (wavelens < low_freq_wavelen)
smooth_factors = (old_context_len / wavelens - low_freq_factor) / (
high_freq_factor - low_freq_factor
)
smooth_freqs = freqs / ((1 - smooth_factors) / factor + smooth_factors)
self._freqs = mx.where(is_medium_freq, smooth_freqs, freqs)
def extra_repr(self):
return (
f"{self.dims}, traditional={self.traditional}, "
f"max_position_embeddings={self.max_position_embeddings}"
)
def __call__(self, x, offset: int = 0):
return mx.fast.rope(
x,
self.dims,
traditional=self.traditional,
base=None,
scale=1.0,
offset=offset,
freqs=self._freqs,
)
def initialize_rope(
dims,
base,
traditional,
scaling_config: Optional[dict] = None,
max_position_embeddings: Optional[int] = None,
):
if scaling_config is not None:
rope_type = scaling_config.get("type") or scaling_config.get(
"rope_type", "default"
)
else:
rope_type = "default"
if rope_type in ["default", "linear"]:
scale = 1 / scaling_config["factor"] if rope_type == "linear" else 1.0
return nn.RoPE(dims, traditional=traditional, base=base, scale=scale)
elif rope_type == "llama3":
return Llama3RoPE(
dims=dims,
max_position_embeddings=max_position_embeddings,
traditional=traditional,
base=base,
scaling_config=scaling_config,
)
else:
raise ValueError(f"Unsupported RoPE type {rope_type}")


@@ -1,11 +1,12 @@
# Copyright © 2023-2024 Apple Inc.
import math
from dataclasses import dataclass
from typing import Tuple
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass
@@ -119,8 +120,8 @@ class Attention(nn.Module):
# Finally perform the attention computation
scale = math.sqrt(1 / queries.shape[-1])
output = mx.fast.scaled_dot_product_attention(
queries, keys, values, scale=scale, mask=mask
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=scale, mask=mask
).astype(values.dtype)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
@@ -196,24 +197,12 @@ class Model(nn.Module):
self,
x: mx.array,
mask: mx.array = None,
cache: mx.array = None,
) -> Tuple[mx.array, mx.array]:
mask = None
if x.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
mask = mask.astype(x.dtype)
cache=None,
) -> mx.array:
mask = create_attention_mask(x, cache)
y = self.model(x, mask, cache)
return self.lm_head(y)
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_key_value_heads


@@ -1,10 +1,12 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Optional, Tuple
from typing import Any, Optional
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass
@@ -43,7 +45,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@@ -62,8 +64,8 @@ class Attention(nn.Module):
queries = self.rope(queries)
keys = self.rope(keys)
output = mx.fast.scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
@@ -98,7 +100,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
@@ -127,10 +129,7 @@ class Starcoder2Model(nn.Module):
):
h = self.embed_tokens(inputs)
mask = None
if h.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
mask = mask.astype(h.dtype)
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
@@ -165,11 +164,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_key_value_heads


@@ -1,30 +1,30 @@
# Copyright © 2023-2024 Apple Inc.
import math
from typing import List, Union
import mlx.core as mx
import mlx.nn as nn
class SuScaledRotaryEmbedding:
class SuScaledRotaryEmbedding(nn.Module):
def __init__(
self,
dims: int,
traditional: bool = False,
base: float = 10000.0,
scale: float = 1.0,
max_position_embeddings: int = 131072,
original_max_position_embeddings: int = 4096,
short_factor: Union[List[float], float] = 1.0,
long_factor: Union[List[float], float] = 1.0,
short_mscale: float = None,
long_mscale: float = None,
):
"""
Phi3Su Scaled Rotary Embedding layer for Phi-3 models.
Args:
dims (int): The feature dimensions to be rotated.
traditional (bool, optional): Unused. Default: ``False``.
base (int, optional): Base for the exponential scaling.
scale (float, optional): The scale used to scale the positions.
Default: ``1.0``.
max_position_embeddings (int, optional): The maximum sequence
length that this model was trained with. This is used to determine
the size of the original RoPE embeddings when using long scaling.
@@ -39,41 +39,26 @@ class SuScaledRotaryEmbedding:
long_factor (float or list[float], optional): List of scaling
factors for sequences of length greater than
``original_max_position_embeddings``. Default: ``1.0``.
short_mscale (float, optional): Scale the input prior to embedding.
long_mscale (float, optional): Scale the input prior to embedding.
"""
self.inv_freq_short = 1.0 / (
mx.array(short_factor, dtype=mx.float32)
* base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims)
)
self.inv_freq_long = 1.0 / (
scale
* mx.array(long_factor, dtype=mx.float32)
* base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims)
)
super().__init__()
freqs = base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims)
self._freqs = mx.array(long_factor, dtype=mx.float32) * freqs
self.original_max_position_embeddings = original_max_position_embeddings
self.scaling_factor = math.sqrt(
self.scale = long_mscale or math.sqrt(
1
+ math.log(max_position_embeddings / original_max_position_embeddings)
/ math.log(original_max_position_embeddings)
)
def _get_cos_sin(self, offset, L):
position_ids = mx.arange(offset, offset + L, dtype=mx.float32)
inv_freq = (
self.inv_freq_long
if (offset + L) > self.original_max_position_embeddings
else self.inv_freq_short
)
freqs = position_ids[:, None] * inv_freq[None, :]
emb = mx.concatenate([freqs, freqs], axis=-1)
cos = mx.cos(emb) * self.scaling_factor
sin = mx.sin(emb) * self.scaling_factor
return cos, sin
def __call__(self, x, offset: int = 0):
def _rotate_half(_x):
midpoint = _x.shape[-1] // 2
x1, x2 = _x[..., :midpoint], _x[..., midpoint:]
return mx.concatenate([-x2, x1], axis=-1)
cos, sin = self._get_cos_sin(offset, x.shape[2])
return (x * cos) + (_rotate_half(x) * sin)
return mx.fast.rope(
self.scale * x,
x.shape[-1],
traditional=False,
base=None,
scale=1.0,
offset=offset,
freqs=self._freqs,
)
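# Example (not part of the diff): the default attention scale used above when
# `long_mscale` is not given is sqrt(1 + log(max_pos / original_max) / log(original_max)).
# With the constructor defaults of 131072 and 4096 this works out to roughly 1.19.
example_scale = math.sqrt(
    1 + math.log(131072 / 4096) / math.log(4096)
)
print(round(example_scale, 4))  # 1.1902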


@@ -1,3 +1,5 @@
# Copyright © 2023-2024 Apple Inc.
import math
import mlx.core as mx


@@ -1,6 +1,6 @@
mlx>=0.14.1
mlx>=0.19.2
numpy
transformers>=4.39.3
transformers[sentencepiece]>=4.39.3
protobuf
pyyaml
jinja2


@@ -1,6 +1,142 @@
# Copyright © 2023-2024 Apple Inc.
import math
from functools import partial
from typing import Callable, Dict, Optional
import mlx.core as mx
def make_sampler(
temp: float = 0.0,
top_p: float = 0.0,
min_p: float = 0.0,
min_tokens_to_keep: int = 1,
) -> Callable[[mx.array], mx.array]:
"""
Make a sampler function for use with ``generate_step``.
Args:
temp (float): The temperature for sampling, if 0 the argmax is used.
Default: ``0``.
top_p (float, optional): Nucleus sampling; higher values mean the model
considers more of the less likely words.
min_p (float, optional): The minimum value (scaled by the top token's
probability) that a token probability must have to be considered.
min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
be filtered by min_p sampling.
Returns:
Callable[[mx.array], mx.array]:
A sampler which takes log-probabilities and returns tokens.
"""
if temp == 0:
return lambda x: mx.argmax(x, axis=-1)
elif top_p > 0 and top_p < 1.0:
return lambda x: top_p_sampling(x, top_p, temp)
elif min_p != 0.0:
return lambda x: min_p_sampling(x, min_p, min_tokens_to_keep, temp)
else:
return lambda x: categorical_sampling(x, temp)
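# Example (not part of the diff): a hedged usage sketch of `make_sampler` above,
# assuming the definitions in this module are in scope. The toy distribution is made up;
# with temp=0.8 and top_p=0.9 the returned callable routes to `top_p_sampling`.
example_logprobs = mx.log(mx.array([[0.5, 0.3, 0.15, 0.05]]))  # batch of one
example_sampler = make_sampler(temp=0.8, top_p=0.9)
print(example_sampler(example_logprobs).item())  # a sampled token id in [0, 3]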
def make_logits_processors(
logit_bias: Optional[Dict[int, float]] = None,
repetition_penalty: Optional[float] = None,
repetition_context_size: Optional[int] = 20,
):
"""
Make logits processors for use with ``generate_step``.
Args:
repetition_penalty (float, optional): The penalty factor for repeating
tokens.
repetition_context_size (int, optional): The number of tokens to
consider for repetition penalty. Default: ``20``.
logit_bias (dictionary, optional): Additive logit bias.
Returns:
List[Callable[[mx.array, mx.array], mx.array]]:
A list of logits processors. Each processor in the list is a
callable which takes an array of tokens and an array of logits
and returns the updated logits.
"""
logits_processors = []
if logit_bias:
indices = mx.array(list(logit_bias.keys()))
values = mx.array(list(logit_bias.values()))
def logit_bias_processor(_, logits):
logits[:, indices] += values
return logits
logits_processors.append(logit_bias_processor)
if repetition_penalty and repetition_penalty != 0.0:
logits_processors.append(
make_repetition_penalty(repetition_penalty, repetition_context_size)
)
return logits_processors
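# Example (not part of the diff): a hedged sketch of the processors returned above,
# assuming the definitions in this module are in scope. Token ids, the bias values and
# the vocabulary size are made up; each processor maps (tokens, logits) -> logits.
example_processors = make_logits_processors(
    logit_bias={0: 5.0, 3: -5.0},
    repetition_penalty=1.2,
    repetition_context_size=20,
)
example_tokens = mx.array([1, 2, 2])    # previously generated tokens
example_logits = mx.zeros((1, 8))       # toy vocabulary of 8
for example_proc in example_processors:
    example_logits = example_proc(example_tokens, example_logits)
print(example_logits)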
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
def min_p_sampling(
logprobs: mx.array,
min_p: float,
min_tokens_to_keep: int = 1,
temperature=1.0,
) -> mx.array:
"""
Apply min-p sampling to the logits.
Min-p keeps all tokens that are above a minimum probability, scaled by the
probability of the most likely token. As a result, the filter is more
aggressive given a very high-probability token.
Args:
logprobs: A vector of log probabilities.
min_p (float): Minimum token probability. Typical values are in the
0.01-0.2 range, which is comparable in selectivity to setting `top_p` in
the 0.99-0.8 range.
min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
be filtered. Default: ``1``.
"""
if not (0 <= min_p <= 1.0):
raise ValueError(
f"`min_p` has to be a float in the [0, 1] interval, but is {min_p}"
)
if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 1):
raise ValueError(
f"`min_tokens_to_keep` has to be a positive integer, but is {min_tokens_to_keep}"
)
# reference implementation: https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L531-L605
logprobs = logprobs * (1 / temperature)
# Indices sorted in decreasing order
sorted_indices = mx.argsort(-logprobs).squeeze(0)
sorted_logprobs = logprobs[..., sorted_indices]
# Top probability
top_logprobs = logprobs[..., sorted_indices[0]]
# Calculate the min_p threshold
scaled_min_p = top_logprobs + math.log(min_p)
# Mask tokens that have a probability less than the scaled min_p
tokens_to_remove = sorted_logprobs < scaled_min_p
tokens_to_remove[..., :min_tokens_to_keep] = False
# Create pool of tokens with probability less than scaled min_p
selected_logprobs = mx.where(tokens_to_remove, -float("inf"), sorted_logprobs)
# Return sampled token
sorted_token = mx.random.categorical(selected_logprobs)
return sorted_indices[sorted_token]
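# Example (not part of the diff): the min-p threshold described above, worked in
# probability space with made-up numbers. With min_p = 0.1 and a top-token probability
# of 0.6, only tokens with probability >= 0.06 survive; in log space that is
# log(0.6) + log(0.1), which is exactly the `scaled_min_p` computed above.
print(round(math.exp(math.log(0.6) + math.log(0.1)), 6))  # 0.06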
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.array:
"""
Apply top-p (nucleus) sampling to logits.
@@ -13,7 +149,7 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.arr
token selected based on the top-p criterion.
"""
# referenced implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L449-L460
probs = mx.softmax(logits / temperature, axis=-1)
probs = mx.softmax(logits * (1 / temperature), axis=-1)
# sort probs in ascending order
sorted_indices = mx.argsort(probs, axis=-1)
@@ -25,10 +161,48 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.arr
top_probs = mx.where(
cumulative_probs > 1 - top_p,
sorted_probs,
mx.zeros_like(sorted_probs),
0,
)
sorted_token = mx.random.categorical(mx.log(top_probs))
token = sorted_indices.squeeze(0)[sorted_token]
return token
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
def categorical_sampling(logits, temp):
return mx.random.categorical(logits * (1 / temp))
def make_repetition_penalty(penalty: float, context_size: int = 20):
"""
Make repetition penalty processor.
Paper: https://arxiv.org/abs/1909.05858
Args:
penalty (float): The repetition penalty factor to be applied.
context_size (int): The number of previous tokens to use.
Default: ``20``.
Returns:
Callable[[mx.array, List[int]], mx.array]:
The repetition penalty processor.
"""
if not isinstance(penalty, (int, float)) or penalty < 0:
raise ValueError(f"penalty must be a non-negative float, got {penalty}")
def repetition_penalty_processor(tokens, logits):
if len(tokens) > 0:
tokens = tokens[-context_size:]
selected_logits = logits[:, tokens]
selected_logits = mx.where(
selected_logits < 0,
selected_logits * penalty,
selected_logits / penalty,
)
logits[:, tokens] = selected_logits
return logits
return repetition_penalty_processor


@@ -3,17 +3,37 @@
import argparse
import json
import logging
import platform
import time
import uuid
import warnings
from dataclasses import dataclass, field
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import List, Literal, NamedTuple, Optional, Union
from pathlib import Path
from typing import (
Any,
Dict,
List,
Literal,
NamedTuple,
Optional,
Sequence,
Tuple,
Union,
)
import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import scan_cache_dir
from .tokenizer_utils import TokenizerWrapper
from .utils import generate_step, load
from ._version import __version__
from .models.cache import make_prompt_cache
from .sample_utils import make_logits_processors, make_sampler
from .utils import load, stream_generate
def get_system_fingerprint():
gpu_arch = mx.metal.device_info()["architecture"] if mx.metal.is_available() else ""
return f"{__version__}-{mx.__version__}-{platform.platform()}-{gpu_arch}"
class StopCondition(NamedTuple):
@@ -27,21 +47,25 @@ def stopping_criteria(
eos_token_id: Union[int, None],
) -> StopCondition:
"""
Determines whether the token generation should stop based on predefined conditions.
Determines whether the token generation should stop based on predefined
conditions.
Args:
tokens (List[int]): The current sequence of generated tokens.
stop_id_sequences (List[List[[int]]): A list of integer lists, each representing a sequence of token IDs.
If the end of the `tokens` list matches any of these sequences, the generation should stop.
eos_token_id (Union[int, None]): The token ID that represents the end-of-sequence. If the last token in `tokens` matches this,
the generation should stop.
stop_id_sequences (List[List[int]]): A list of integer lists, each
representing a sequence of token IDs. If the end of the `tokens`
list matches any of these sequences, the generation should stop.
eos_token_id (Union[int, None]): The token ID that represents the
end-of-sequence. If the last token in `tokens` matches this, the
generation should stop.
Returns:
StopCondition: A named tuple indicating whether the stop condition has been met (`stop_met`)
and how many tokens should be trimmed from the end if it has (`trim_length`).
StopCondition: A named tuple indicating whether the stop condition has
been met (`stop_met`) and how many tokens should be trimmed from the
end if it has (`trim_length`).
"""
if tokens and tokens[-1] == eos_token_id:
return StopCondition(stop_met=True, trim_length=1)
return StopCondition(stop_met=True, trim_length=0)
for stop_ids in stop_id_sequences:
if len(tokens) >= len(stop_ids):
@@ -51,9 +75,27 @@ def stopping_criteria(
return StopCondition(stop_met=False, trim_length=0)
def sequence_overlap(s1: Sequence, s2: Sequence) -> bool:
"""
Checks if a suffix of s1 has overlap with a prefix of s2
Args:
s1 (Sequence): The first sequence
s2 (Sequence): The second sequence
Returns:
bool: If the two sequences have overlap
"""
max_overlap = min(len(s1), len(s2))
return any(s1[-i:] == s2[:i] for i in range(1, max_overlap + 1))
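# Example (not part of the diff): how `sequence_overlap` above behaves, assuming the
# definition is in scope; the token ids are made up. A True result is what makes the
# streaming path hold text back until the potential stop sequence resolves.
print(sequence_overlap([5, 6, 7], [7, 8, 9]))  # True  -> keep buffering
print(sequence_overlap([5, 6, 7], [8, 9]))     # False -> safe to flush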
def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None):
default_role_mapping = {
"system_prompt": "A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.",
"system_prompt": (
"A chat between a curious user and an artificial intelligence "
"assistant. The assistant follows the given rules no matter what."
),
"system": "ASSISTANT's RULE: ",
"user": "USER: ",
"assistant": "ASSISTANT: ",
@@ -72,14 +114,90 @@ def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None):
return prompt.rstrip()
@dataclass
class PromptCache:
cache: List[Any] = field(default_factory=list)
model_key: Tuple[str, Optional[str]] = ("", None)
tokens: List[int] = field(default_factory=list)
class ModelProvider:
def __init__(self, cli_args: argparse.Namespace):
"""Load models on demand and persist them across the whole process."""
self.cli_args = cli_args
self.model_key = None
self.model = None
self.tokenizer = None
# Preload the default model if it is provided
if self.cli_args.model is not None:
self.load("default_model")
def _validate_model_path(self, model_path: str):
model_path = Path(model_path)
if model_path.exists() and not model_path.is_relative_to(Path.cwd()):
raise RuntimeError(
"Local models must be relative to the current working dir."
)
# Added in adapter_path to load dynamically
def load(self, model_path, adapter_path=None):
if self.model_key == (model_path, adapter_path):
return self.model, self.tokenizer
# Remove the old model if it exists.
self.model = None
self.tokenizer = None
self.model_key = None
# Building tokenizer_config
tokenizer_config = {
"trust_remote_code": True if self.cli_args.trust_remote_code else None
}
if self.cli_args.chat_template:
tokenizer_config["chat_template"] = self.cli_args.chat_template
if model_path == "default_model" and self.cli_args.model is not None:
model, tokenizer = load(
self.cli_args.model,
adapter_path=(
adapter_path if adapter_path else self.cli_args.adapter_path
), # if the user doesn't change the model but adds an adapter path
tokenizer_config=tokenizer_config,
)
else:
self._validate_model_path(model_path)
model, tokenizer = load(
model_path, adapter_path=adapter_path, tokenizer_config=tokenizer_config
)
if self.cli_args.use_default_chat_template:
if tokenizer.chat_template is None:
tokenizer.chat_template = tokenizer.default_chat_template
self.model_key = (model_path, adapter_path)
self.model = model
self.tokenizer = tokenizer
return self.model, self.tokenizer
class APIHandler(BaseHTTPRequestHandler):
def __init__(self, model: nn.Module, tokenizer: TokenizerWrapper, *args, **kwargs):
def __init__(
self,
model_provider: ModelProvider,
*args,
prompt_cache: Optional[PromptCache] = None,
system_fingerprint: Optional[str] = None,
**kwargs,
):
"""
Create static request specific metadata
"""
self.model = model
self.tokenizer = tokenizer
self.created = int(time.time())
self.model_provider = model_provider
self.prompt_cache = prompt_cache or PromptCache()
self.system_fingerprint = system_fingerprint or get_system_fingerprint()
super().__init__(*args, **kwargs)
def _set_cors_headers(self):
@@ -109,6 +227,7 @@ class APIHandler(BaseHTTPRequestHandler):
endpoints = {
"/v1/completions": self.handle_text_completions,
"/v1/chat/completions": self.handle_chat_completions,
"/chat/completions": self.handle_chat_completions,
}
if self.path not in endpoints:
@@ -129,18 +248,34 @@ class APIHandler(BaseHTTPRequestHandler):
# Extract request parameters from the body
self.stream = self.body.get("stream", False)
self.stream_options = self.body.get("stream_options", None)
self.requested_model = self.body.get("model", "default_model")
self.max_tokens = self.body.get("max_tokens", 100)
self.temperature = self.body.get("temperature", 1.0)
self.adapter = self.body.get("adapters", None)
self.max_tokens = self.body.get("max_completion_tokens", None)
if self.max_tokens is None:
self.max_tokens = self.body.get("max_tokens", 512)
self.temperature = self.body.get("temperature", 0.0)
self.top_p = self.body.get("top_p", 1.0)
self.repetition_penalty = self.body.get("repetition_penalty", 1.0)
self.repetition_context_size = self.body.get("repetition_context_size", 20)
self.logit_bias = self.body.get("logit_bias", None)
self.logprobs = self.body.get("logprobs", -1)
self.validate_model_parameters()
# Load the model if needed
try:
self.model, self.tokenizer = self.model_provider.load(
self.requested_model, self.adapter
)
except:
self._set_completion_headers(404)
self.end_headers()
self.wfile.write(b"Not Found")
return
# Get stop id sequences, if provided
stop_words = self.body.get("stop", [])
stop_words = self.body.get("stop")
stop_words = stop_words or []
stop_words = [stop_words] if isinstance(stop_words, str) else stop_words
stop_id_sequences = [
self.tokenizer.encode(stop_word, add_special_tokens=False)
@@ -156,10 +291,7 @@ class APIHandler(BaseHTTPRequestHandler):
# Call endpoint specific method
prompt = endpoints[self.path]()
# Call method based on response type
method = self.handle_stream if self.stream else self.handle_completion
method(prompt, stop_id_sequences)
self.handle_completion(prompt, stop_id_sequences)
def validate_model_parameters(self):
"""
@@ -171,18 +303,23 @@ class APIHandler(BaseHTTPRequestHandler):
if not isinstance(self.max_tokens, int) or self.max_tokens < 0:
raise ValueError("max_tokens must be a non-negative integer")
if not isinstance(self.temperature, float) or self.temperature < 0:
if not isinstance(self.temperature, (float, int)) or self.temperature < 0:
raise ValueError("temperature must be a non-negative float")
if not isinstance(self.top_p, float) or self.top_p < 0 or self.top_p > 1:
if not isinstance(self.top_p, (float, int)) or self.top_p < 0 or self.top_p > 1:
raise ValueError("top_p must be a float between 0 and 1")
if (
not isinstance(self.repetition_penalty, float)
not isinstance(self.repetition_penalty, (float, int))
or self.repetition_penalty < 0
):
raise ValueError("repetition_penalty must be a non-negative float")
if self.logprobs != -1 and not (0 < self.logprobs <= 10):
raise ValueError(
f"logprobs must be between 1 and 10 but got {self.logprobs:,}"
)
if (
not isinstance(self.repetition_context_size, int)
or self.repetition_context_size < 0
@@ -200,6 +337,8 @@ class APIHandler(BaseHTTPRequestHandler):
if not isinstance(self.requested_model, str):
raise ValueError("model must be a string")
if self.adapter is not None and not isinstance(self.adapter, str):
raise ValueError("adapter must be a string")
def generate_response(
self,
@@ -207,36 +346,50 @@ class APIHandler(BaseHTTPRequestHandler):
finish_reason: Union[Literal["length", "stop"], None],
prompt_token_count: Optional[int] = None,
completion_token_count: Optional[int] = None,
token_logprobs: Optional[List[float]] = None,
top_tokens: Optional[List[Dict[int, float]]] = None,
tokens: Optional[List[int]] = None,
) -> dict:
"""
Generate a single response packet based on response type (stream or not), completion type and parameters.
Generate a single response packet based on response type (stream or
not), completion type and parameters.
Args:
text (str): Text generated by model
finish_reason (Union[Literal["length", "stop"], None]):
The reason the response is being sent: "length", "stop" or None
prompt_token_count (Optional[int]):
The amount of tokens in the prompt,
used to populate the "usage" field (not used when stream)
completion_token_count (Optional[int]):
The amount of tokens in the response,
used to populate the "usage" field (not used when stream)
finish_reason (Union[Literal["length", "stop"], None]): The reason the
response is being sent: "length", "stop" or `None`.
prompt_token_count (Optional[int]): The number of tokens in the prompt,
used to populate the "usage" field (not used when stream).
completion_token_count (Optional[int]): The number of tokens in the
response, used to populate the "usage" field (not used when stream).
token_logprobs (Optional[List[float]]): The log probabilities per token,
in token order.
top_tokens (Optional[List[Dict[int, float]]]): List of dictionaries mapping
tokens to logprobs for the top N tokens at each token position.
tokens (Optional[List[int]]): List of tokens to return with logprobs structure
Returns:
dict: A dictionary containing the response, imitating OpenAI's API
dict: A dictionary containing the response, in the same format as
OpenAI's API.
"""
token_logprobs = token_logprobs if token_logprobs else []
top_logprobs = top_tokens if top_tokens else []
# Static response
response = {
"id": self.request_id,
"system_fingerprint": f"fp_{uuid.uuid4()}",
"system_fingerprint": self.system_fingerprint,
"object": self.object_type,
"model": self.requested_model,
"created": self.created,
"choices": [
{
"index": 0,
"logprobs": None,
"logprobs": {
"token_logprobs": token_logprobs,
"top_logprobs": top_logprobs,
"tokens": tokens,
},
"finish_reason": finish_reason,
}
],
@@ -270,40 +423,77 @@ class APIHandler(BaseHTTPRequestHandler):
return response
def get_prompt_cache(self, prompt):
cache_len = len(self.prompt_cache.tokens)
if (
self.prompt_cache.model_key != self.model_provider.model_key
or cache_len >= len(prompt)
or self.prompt_cache.tokens != prompt[:cache_len]
):
self.prompt_cache.model_key = self.model_provider.model_key
self.prompt_cache.cache = make_prompt_cache(self.model_provider.model)
else:
prompt = prompt[cache_len:]
self.prompt_cache.tokens.extend(prompt)
return prompt
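# Example (not part of the diff): the prefix-reuse rule implemented by
# `get_prompt_cache` above, restated with made-up token ids. When the cached tokens are
# a strict prefix of the new prompt, only the uncached suffix is returned for evaluation;
# otherwise the cache is rebuilt and the full prompt is processed.
example_cached = [1, 2, 3]
example_prompt = [1, 2, 3, 4, 5]
if len(example_cached) < len(example_prompt) and example_cached == example_prompt[: len(example_cached)]:
    example_to_process = example_prompt[len(example_cached):]
else:
    example_to_process = example_prompt  # cache miss: start over
print(example_to_process)  # [4, 5]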
def handle_completion(
self,
prompt: mx.array,
prompt: List[int],
stop_id_sequences: List[List[int]],
):
"""
Generate a response to a prompt and send it to the client in a single batch.
Args:
prompt (mx.array): The prompt, in token form inside of a mlx array
stop_id_sequences (List[List[int]]):
A list of stop words passed to the stopping_criteria function
prompt (List[int]): The tokenized prompt.
stop_id_sequences (List[List[int]]): A list of stop words passed
to the stopping_criteria function
"""
detokenizer = self.tokenizer.detokenizer
detokenizer.reset()
tokens = []
finish_reason = "length"
stop_sequence_suffix = None
logging.debug(f"Starting completion:")
for (token, _), _ in zip(
generate_step(
prompt=prompt,
model=self.model,
temp=self.temperature,
top_p=self.top_p,
repetition_penalty=self.repetition_penalty,
repetition_context_size=self.repetition_context_size,
logit_bias=self.logit_bias,
),
range(self.max_tokens),
if self.stream:
self.end_headers()
logging.debug(f"Starting stream:")
else:
logging.debug(f"Starting completion:")
token_logprobs = []
top_tokens = []
prompt = self.get_prompt_cache(prompt)
text = ""
tic = time.perf_counter()
sampler = make_sampler(self.temperature, top_p=self.top_p)
logits_processors = make_logits_processors(
self.logit_bias, self.repetition_penalty, self.repetition_context_size
)
for gen_response in stream_generate(
model=self.model,
tokenizer=self.tokenizer,
prompt=prompt,
max_tokens=self.max_tokens,
sampler=sampler,
logits_processors=logits_processors,
prompt_cache=self.prompt_cache.cache,
):
detokenizer.add_token(token)
logging.debug(detokenizer.text)
segment = gen_response.text
text += segment
logging.debug(text)
token = gen_response.token
logprobs = gen_response.logprobs
tokens.append(token)
if self.logprobs > 0:
sorted_indices = mx.argpartition(-logprobs, kth=self.logprobs - 1)
top_indices = sorted_indices[: self.logprobs]
top_logprobs = logprobs[top_indices]
top_token_info = zip(top_indices.tolist(), top_logprobs.tolist())
top_tokens.append(tuple(top_token_info))
token_logprobs.append(logprobs[token].item())
stop_condition = stopping_criteria(
tokens, stop_id_sequences, self.tokenizer.eos_token_id
)
@@ -313,107 +503,81 @@ class APIHandler(BaseHTTPRequestHandler):
stop_sequence_suffix = self.tokenizer.decode(
tokens[-stop_condition.trim_length :]
)
text = text[: -len(stop_sequence_suffix)]
break
detokenizer.finalize()
text = (
detokenizer.text
if stop_sequence_suffix is None
else detokenizer.text[: -len(stop_sequence_suffix)]
)
response = self.generate_response(text, finish_reason, len(prompt), len(tokens))
response_json = json.dumps(response).encode()
indent = "\t" # Backslashes can't be inside of f-strings
logging.debug(f"Outgoing Response: {json.dumps(response, indent=indent)}")
# Send an additional Content-Length header when it is known
self.send_header("Content-Length", str(len(response_json)))
self.end_headers()
self.wfile.write(response_json)
self.wfile.flush()
def handle_stream(
self,
prompt: mx.array,
stop_id_sequences: List[List[int]],
):
"""
Generate a response to the prompt and forward it to the client using a Server Sent Events (SSE) stream.
Args:
prompt (mx.array): The prompt, in token form inside of a mlx array
stop_id_sequences (List[List[int]]):
A list of stop words passed to the stopping_criteria function
"""
# No additional headers are needed, call end_headers
self.end_headers()
detokenizer = self.tokenizer.detokenizer
detokenizer.reset()
tokens = []
max_stop_id_sequence_len = len(max(stop_id_sequences, default=[]))
# Buffer to store the last `max_stop_id_sequence_len` tokens
# to check for stop conditions before writing to the stream.
stop_sequence_buffer = []
stop_sequence_suffix = None
logging.debug(f"Starting stream:")
for (token, _), _ in zip(
generate_step(
prompt=prompt,
model=self.model,
temp=self.temperature,
top_p=self.top_p,
repetition_penalty=self.repetition_penalty,
repetition_context_size=self.repetition_context_size,
),
range(self.max_tokens),
):
detokenizer.add_token(token)
logging.debug(detokenizer.text)
tokens.append(token)
stop_sequence_buffer.append(token)
# Continue generating tokens until buffer is as large as the longest stop_id_sequence
if len(stop_sequence_buffer) < max_stop_id_sequence_len:
continue
stop_condition = stopping_criteria(
tokens,
stop_id_sequences,
self.tokenizer.eos_token_id,
)
if stop_condition.stop_met:
if stop_condition.trim_length:
stop_sequence_suffix = self.tokenizer.decode(
tokens[-stop_condition.trim_length :]
if self.stream:
# If the end of tokens overlaps with a stop sequence, generate new
# tokens until we know if the stop sequence is hit or not
if any(
(
sequence_overlap(tokens, sequence)
for sequence in stop_id_sequences
)
break
):
continue
elif segment:
response = self.generate_response(segment, None)
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
self.wfile.flush()
new_text = detokenizer.last_segment
response = self.generate_response(new_text, None)
self.prompt_cache.tokens.extend(tokens)
logging.debug(f"Prompt: {gen_response.prompt_tps:.3f} tokens-per-sec")
logging.debug(f"Generation: {gen_response.generation_tps:.3f} tokens-per-sec")
logging.debug(f"Peak memory: {gen_response.peak_memory:.3f} GB")
if self.stream:
response = self.generate_response(segment, finish_reason)
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
self.wfile.flush()
stop_sequence_buffer = []
# check if there is any remaining text to send
if stop_sequence_buffer:
next_chunk = (
detokenizer.last_segment
if stop_sequence_suffix is None
else detokenizer.last_segment[: -len(stop_sequence_suffix)]
if self.stream_options is not None and self.stream_options["include_usage"]:
response = self.completion_usage_response(len(prompt), len(tokens))
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
self.wfile.flush()
self.wfile.write("data: [DONE]\n\n".encode())
self.wfile.flush()
else:
response = self.generate_response(
text,
finish_reason,
len(prompt),
len(tokens),
token_logprobs=token_logprobs,
top_tokens=top_tokens,
tokens=tokens,
)
response = self.generate_response(next_chunk, "length")
response_json = json.dumps(response).encode()
indent = "\t" # Backslashes can't be inside of f-strings
logging.debug(f"Outgoing Response: {json.dumps(response, indent=indent)}")
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
# Send an additional Content-Length header when it is known
self.send_header("Content-Length", str(len(response_json)))
self.end_headers()
self.wfile.write(response_json)
self.wfile.flush()
self.wfile.write("data: [DONE]\n\n".encode())
self.wfile.flush()
def completion_usage_response(
self,
prompt_token_count: Optional[int] = None,
completion_token_count: Optional[int] = None,
):
response = {
"id": self.request_id,
"system_fingerprint": self.system_fingerprint,
"object": "chat.completion",
"model": self.requested_model,
"created": self.created,
"choices": [],
"usage": {
"prompt_tokens": prompt_token_count,
"completion_tokens": completion_token_count,
"total_tokens": prompt_token_count + completion_token_count,
},
}
return response
def handle_chat_completions(self) -> mx.array:
def handle_chat_completions(self) -> List[int]:
"""
Handle a chat completion request.
@@ -425,16 +589,14 @@ class APIHandler(BaseHTTPRequestHandler):
# Determine response type
self.request_id = f"chatcmpl-{uuid.uuid4()}"
self.object_type = (
"chat.completions.chunk" if self.stream else "chat.completions"
)
self.object_type = "chat.completion.chunk" if self.stream else "chat.completion"
if (
hasattr(self.tokenizer, "apply_chat_template")
and self.tokenizer.chat_template
):
prompt = self.tokenizer.apply_chat_template(
body["messages"],
body.get("tools", None),
tokenize=True,
add_generation_prompt=True,
)
@@ -442,9 +604,9 @@ class APIHandler(BaseHTTPRequestHandler):
prompt = convert_chat(body["messages"], body.get("role_mapping"))
prompt = self.tokenizer.encode(prompt)
return mx.array(prompt)
return prompt
def handle_text_completions(self) -> mx.array:
def handle_text_completions(self) -> List[int]:
"""
Handle a text completion request.
@@ -454,26 +616,68 @@ class APIHandler(BaseHTTPRequestHandler):
# Determine response type
self.request_id = f"cmpl-{uuid.uuid4()}"
self.object_type = "text_completion"
assert "prompt" in self.body, "Request did not contain a prompt"
prompt_text = self.body["prompt"]
return self.tokenizer.encode(self.body["prompt"])
prompt = self.tokenizer.encode(prompt_text)
return mx.array(prompt)
def do_GET(self):
"""
Respond to a GET request from a client.
"""
if self.path == "/v1/models":
self.handle_models_request()
else:
self._set_completion_headers(404)
self.end_headers()
self.wfile.write(b"Not Found")
def handle_models_request(self):
"""
Handle a GET request for the /v1/models endpoint.
"""
self._set_completion_headers(200)
self.end_headers()
# Scan the cache directory for downloaded mlx models
hf_cache_info = scan_cache_dir()
downloaded_models = [
repo for repo in hf_cache_info.repos if "mlx" in repo.repo_id
]
# Create a list of available models
models = [
{
"id": repo.repo_id,
"object": "model",
"created": self.created,
}
for repo in downloaded_models
]
response = {"object": "list", "data": models}
response_json = json.dumps(response).encode()
self.wfile.write(response_json)
self.wfile.flush()
def run(
host: str,
port: int,
model: nn.Module,
tokenizer: TokenizerWrapper,
model_provider: ModelProvider,
server_class=HTTPServer,
handler_class=APIHandler,
):
server_address = (host, port)
prompt_cache = PromptCache()
httpd = server_class(
server_address,
lambda *args, **kwargs: handler_class(model, tokenizer, *args, **kwargs),
lambda *args, **kwargs: handler_class(
model_provider,
prompt_cache=prompt_cache,
system_fingerprint=get_system_fingerprint(),
*args,
**kwargs,
),
)
warnings.warn(
"mlx_lm.server is not recommended for production as "
@@ -488,7 +692,6 @@ def main():
parser.add_argument(
"--model",
type=str,
required=True,
help="The path to the MLX model weights, tokenizer, and config",
)
parser.add_argument(
@@ -527,6 +730,18 @@ def main():
help="Set the MLX cache limit in GB",
required=False,
)
parser.add_argument(
"--chat-template",
type=str,
default="",
help="Specify a chat template for the tokenizer",
required=False,
)
parser.add_argument(
"--use-default-chat-template",
action="store_true",
help="Use the default chat template",
)
args = parser.parse_args()
logging.basicConfig(
@@ -538,13 +753,7 @@ def main():
logging.debug(f"Setting cache limit to {args.cache_limit_gb} GB")
mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024)
# Building tokenizer_config
tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}
model, tokenizer = load(
args.model, adapter_path=args.adapter_path, tokenizer_config=tokenizer_config
)
run(args.host, args.port, model, tokenizer)
run(args.host, args.port, ModelProvider(args))
if __name__ == "__main__":
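# Example (not part of the diff): a hedged client-side request exercising the body
# fields the handler above reads ("max_completion_tokens", "temperature", "top_p",
# "logprobs", "stop"). It assumes a server started via `mlx_lm.server` listening on
# localhost:8080; the model name and message content are illustrative.
import json
from urllib import request

example_payload = {
    "model": "default_model",
    "messages": [{"role": "user", "content": "Say hello in one word."}],
    "max_completion_tokens": 16,
    "temperature": 0.0,
    "top_p": 1.0,
    "logprobs": 3,
    "stop": ["\n\n"],
}
example_req = request.Request(
    "http://localhost:8080/v1/chat/completions",
    data=json.dumps(example_payload).encode(),
    headers={"Content-Type": "application/json"},
)
with request.urlopen(example_req) as example_resp:
    print(json.loads(example_resp.read())["choices"][0])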


@@ -3,14 +3,6 @@ from functools import partial
from transformers import AutoTokenizer
REPLACEMENT_CHAR = "\ufffd"
def _remove_space(x):
if x and x[0] == " ":
return x[1:]
return x
class StreamingDetokenizer:
"""The streaming detokenizer interface so that we can detokenize one token at a time.
@@ -57,11 +49,9 @@ class StreamingDetokenizer:
def last_segment(self):
"""Return the last segment of readable text since last time this property was accessed."""
text = self.text
if text and text[-1] != REPLACEMENT_CHAR:
segment = text[self.offset :]
self.offset = len(text)
return segment
return ""
segment = text[self.offset :]
self.offset = len(text)
return segment
class NaiveStreamingDetokenizer(StreamingDetokenizer):
@@ -79,16 +69,16 @@ class NaiveStreamingDetokenizer(StreamingDetokenizer):
def reset(self):
self.offset = 0
self._tokens = []
self.tokens = []
self._text = ""
self._current_tokens = []
self._current_text = ""
def add_token(self, token):
self._current_tokens.append(token)
self.tokens.append(token)
def finalize(self):
self._tokens.extend(self._current_tokens)
self._text += self._tokenizer.decode(self._current_tokens)
self._current_tokens = []
self._current_text = ""
@@ -97,17 +87,17 @@ class NaiveStreamingDetokenizer(StreamingDetokenizer):
def text(self):
if self._current_tokens:
self._current_text = self._tokenizer.decode(self._current_tokens)
if (
self._tokenizer.clean_up_tokenization_spaces
and self._current_text[-1] == " "
):
self._current_text = self._current_text[:-1]
if self._current_text and self._current_text[-1] == "\n":
self._tokens.extend(self._current_tokens)
self._text += self._current_text
self._current_tokens.clear()
self._current_text = ""
return self._text + self._current_text
@property
def tokens(self):
return self._tokens
class SPMStreamingDetokenizer(StreamingDetokenizer):
"""A streaming detokenizer for SPM models.
@@ -118,42 +108,43 @@ class SPMStreamingDetokenizer(StreamingDetokenizer):
def __init__(self, tokenizer, trim_space=True):
self.trim_space = trim_space
self._sep = "\u2581".encode()
# Extract the tokens in a list from id to text
self.tokenmap = [None] * len(tokenizer.vocab)
self.tokenmap = [""] * (max(tokenizer.vocab.values()) + 1)
for value, tokenid in tokenizer.vocab.items():
self.tokenmap[tokenid] = value
# Replace bytes with their value
for i in range(len(self.tokenmap)):
if self.tokenmap[i].startswith("<0x"):
self.tokenmap[i] = chr(int(self.tokenmap[i][3:5], 16))
if value.startswith("<0x"):
# Replace bytes with their value
self.tokenmap[tokenid] = bytes([int(value[3:5], 16)])
else:
self.tokenmap[tokenid] = value.encode()
self.reset()
def reset(self):
self.offset = 0
self._unflushed = ""
self._unflushed = b""
self.text = ""
self.tokens = []
def _try_flush(self, force=False):
text = self._unflushed.replace(self._sep, b" ").decode("utf-8", "replace")
if not force and text.endswith("\ufffd"):
return
if not self.text and self.trim_space and text and text[0] == " ":
text = text[1:]
self.text += text
self._unflushed = b""
def add_token(self, token):
self.tokens.append(token)
v = self.tokenmap[token]
if v[0] == "\u2581":
if self.text or not self.trim_space:
self.text += self._unflushed.replace("\u2581", " ")
else:
self.text = _remove_space(self._unflushed.replace("\u2581", " "))
self._unflushed = v
else:
self._unflushed += v
self._unflushed += v
self._try_flush()
def finalize(self):
if self.text or not self.trim_space:
self.text += self._unflushed.replace("\u2581", " ")
else:
self.text = _remove_space(self._unflushed.replace("\u2581", " "))
self._unflushed = ""
self._try_flush(force=True)
self._unflushed = b""
class BPEStreamingDetokenizer(StreamingDetokenizer):
@@ -164,9 +155,10 @@ class BPEStreamingDetokenizer(StreamingDetokenizer):
"""
_byte_decoder = None
_space_matches = (".", "?", "!", ",", "n't", "'m", "'s", "'ve", "'re")
def __init__(self, tokenizer, trim_space=False):
self.trim_space = trim_space
def __init__(self, tokenizer):
self.clean_spaces = tokenizer.clean_up_tokenization_spaces
# Extract the tokens in a list from id to text
self.tokenmap = [None] * len(tokenizer.vocab)
@@ -185,29 +177,47 @@ class BPEStreamingDetokenizer(StreamingDetokenizer):
self.text = ""
self.tokens = []
def add_token(self, token):
v = self.tokenmap[token]
# if the token starts with space
if self._byte_decoder[v[0]] == 32:
current_text = bytearray(
self._byte_decoder[c] for c in self._unflushed
).decode("utf-8")
if self.text or not self.trim_space:
self.text += current_text
def _decode_bytes(self, seq):
barr = bytearray()
for c in seq:
res = self._byte_decoder.get(c, False)
if res:
barr.append(res)
else:
self.text += _remove_space(current_text)
self._unflushed = v
else:
self._unflushed += v
barr.extend(bytes(c, "utf-8"))
return barr.decode("utf-8", "replace")
def _maybe_trim_space(self, current_text):
if len(current_text) == 0:
return current_text
elif current_text[0] != " ":
return current_text
elif not self.text:
return current_text[1:]
elif self.clean_spaces and current_text[1:].startswith(self._space_matches):
return current_text[1:]
return current_text
def add_token(self, token):
self.tokens.append(token)
v = self.tokenmap[token]
self._unflushed += v
text = self._decode_bytes(self._unflushed)
# For multi-byte utf-8 wait until they are complete
# For single spaces wait until the next token to clean it if needed
if not text.endswith("\ufffd") and not (
len(v) == 1 and self._byte_decoder[v[0]] == 32
):
self.text += self._maybe_trim_space(text)
self._unflushed = ""
def finalize(self):
current_text = bytearray(self._byte_decoder[c] for c in self._unflushed).decode(
"utf-8"
"utf-8",
"replace",
)
if self.text or not self.trim_space:
self.text += current_text
else:
self.text += _remove_space(current_text)
self.text += self._maybe_trim_space(current_text)
self._unflushed = ""
@classmethod
@@ -245,16 +255,38 @@ class TokenizerWrapper:
huggingface tokenizer.
"""
def __init__(self, tokenizer, detokenizer_class=NaiveStreamingDetokenizer):
def __init__(
self, tokenizer, detokenizer_class=NaiveStreamingDetokenizer, eos_token_ids=None
):
self._tokenizer = tokenizer
self._detokenizer = detokenizer_class(tokenizer)
self._eos_token_ids = (
set(eos_token_ids)
if eos_token_ids is not None
else {tokenizer.eos_token_id}
)
def __getattr__(self, attr):
if attr == "detokenizer":
return self._detokenizer
elif attr == "eos_token_ids":
return self._eos_token_ids
elif attr.startswith("_"):
return self.__getattribute__(attr)
else:
return getattr(self._tokenizer, attr)
def __setattr__(self, attr, value):
if attr in {"detokenizer", "eos_token_ids"}:
if attr == "detokenizer":
raise AttributeError("Cannot set the detokenizer.")
elif attr == "eos_token_ids":
self._eos_token_ids = set(value) if value is not None else set()
elif attr.startswith("_"):
super().__setattr__(attr, value)
else:
setattr(self._tokenizer, attr, value)
def _match(a, b):
if type(a) != type(b):
@@ -293,17 +325,10 @@ def _is_spm_decoder_no_space(decoder):
def _is_bpe_decoder(decoder):
_target_description = {
"type": "ByteLevel",
"add_prefix_space": False,
"trim_offsets": False,
"use_regex": False,
}
return _match(_target_description, decoder)
return isinstance(decoder, dict) and decoder.get("type", None) == "ByteLevel"
def load_tokenizer(model_path, tokenizer_config_extra={}):
def load_tokenizer(model_path, tokenizer_config_extra={}, eos_token_ids=None):
"""Load a huggingface tokenizer and try to infer the type of streaming
detokenizer to use.
@@ -324,7 +349,10 @@ def load_tokenizer(model_path, tokenizer_config_extra={}):
elif _is_bpe_decoder(tokenizer_content["decoder"]):
detokenizer_class = BPEStreamingDetokenizer
if isinstance(eos_token_ids, int):
eos_token_ids = [eos_token_ids]
return TokenizerWrapper(
AutoTokenizer.from_pretrained(model_path, **tokenizer_config_extra),
detokenizer_class,
eos_token_ids=eos_token_ids,
)


@@ -1,20 +1,21 @@
import json
from pathlib import Path
from typing import Dict, List
from transformers import PreTrainedTokenizer
class Dataset:
"""
Light-weight wrapper to hold lines from a jsonl file
Light-weight wrapper to hold a dataset.
"""
def __init__(self, path: Path):
with open(path, "r") as fid:
self._data = [json.loads(l) for l in fid]
def __init__(self, data: List[Dict[str, str]], text_key: str = "text"):
self._text_key = text_key
self._data = data
def __getitem__(self, idx: int):
return self._data[idx]["text"]
return self._data[idx][self._text_key]
def __len__(self):
if self._data is None:
@@ -28,14 +29,17 @@ class ChatDataset(Dataset):
https://platform.openai.com/docs/guides/fine-tuning/example-format
"""
def __init__(self, path: Path, tokenizer: PreTrainedTokenizer):
super().__init__(path)
def __init__(self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer):
super().__init__(data)
self._tokenizer = tokenizer
def __getitem__(self, idx: int):
messages = self._data[idx]["messages"]
text = self._tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
messages,
tools=self._data[idx].get("tools", None),
tokenize=False,
add_generation_prompt=True,
)
return text
@@ -43,19 +47,28 @@ class ChatDataset(Dataset):
class CompletionsDataset(Dataset):
"""
A dataset for prompt-completion data in the format of {"prompt": ..., "completion": ...}
or using user-provided keys for prompt and completion values
https://platform.openai.com/docs/guides/fine-tuning/example-format
"""
def __init__(self, path: Path, tokenizer: PreTrainedTokenizer):
super().__init__(path)
def __init__(
self,
data: List[Dict[str, str]],
tokenizer: PreTrainedTokenizer,
prompt_key: str = "prompt",
completion_key: str = "completion",
):
super().__init__(data)
self._tokenizer = tokenizer
self._prompt_key = prompt_key
self._completion_key = completion_key
def __getitem__(self, idx: int):
data = self._data[idx]
text = self._tokenizer.apply_chat_template(
[
{"role": "user", "content": data["prompt"]},
{"role": "assistant", "content": data["completion"]},
{"role": "user", "content": data[self._prompt_key]},
{"role": "assistant", "content": data[self._completion_key]},
],
tokenize=False,
add_generation_prompt=True,
@@ -63,19 +76,15 @@ class CompletionsDataset(Dataset):
return text
def create_dataset(path: Path, tokenizer: PreTrainedTokenizer = None):
# Return empty dataset for non-existent paths
if not path.exists():
return []
with open(path, "r") as fid:
first_line = next(fid)
first_obj = json.loads(first_line)
if "messages" in first_obj:
return ChatDataset(path, tokenizer)
elif "prompt" in first_obj and "completion" in first_obj:
return CompletionsDataset(path, tokenizer)
elif "text" in first_obj:
return Dataset(path)
def create_dataset(data, tokenizer: PreTrainedTokenizer = None):
sample = data[0]
if "messages" in sample:
return ChatDataset(data, tokenizer)
elif "prompt" in sample and "completion" in sample:
return CompletionsDataset(data, tokenizer)
elif "text" in sample:
return Dataset(data)
else:
raise ValueError(
"Unsupported data format, check the supported formats here:\n"
@@ -83,12 +92,90 @@ def create_dataset(path: Path, tokenizer: PreTrainedTokenizer = None):
)
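# Example (not part of the diff): made-up records for the three formats dispatched
# above, assuming the definitions in this file are in scope. `create_dataset` picks
# ChatDataset for "messages", CompletionsDataset for "prompt"/"completion", and the
# plain Dataset wrapper for "text".
example_chat = {"messages": [{"role": "user", "content": "hi"},
                             {"role": "assistant", "content": "hello"}]}
example_completion = {"prompt": "What is 2 + 2?", "completion": "4"}
example_text = {"text": "plain pre-tokenization text"}
print(type(create_dataset([example_completion])).__name__)  # CompletionsDataset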
def load_dataset(args, tokenizer: PreTrainedTokenizer):
def load_local_dataset(data_path: Path, tokenizer: PreTrainedTokenizer):
def load_subset(path):
if not path.exists():
return []
with open(path, "r") as fid:
data = [json.loads(l) for l in fid]
return create_dataset(data, tokenizer)
names = ("train", "valid", "test")
data_path = Path(args.data)
train, valid, test = [
create_dataset(data_path / f"{n}.jsonl", tokenizer) for n in names
]
train, valid, test = [load_subset(data_path / f"{n}.jsonl") for n in names]
return train, valid, test
def load_hf_dataset(data_id: str, tokenizer: PreTrainedTokenizer):
from datasets import exceptions, load_dataset
try:
dataset = load_dataset(data_id)
names = ("train", "valid", "test")
train, valid, test = [
create_dataset(dataset[n], tokenizer) if n in dataset.keys() else []
for n in names
]
except exceptions.DatasetNotFoundError:
raise ValueError(f"Not found Hugging Face dataset: {data_id} .")
return train, valid, test
def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
import datasets
hf_args = args.hf_dataset
dataset_name = hf_args["name"]
print(f"Loading Hugging Face dataset {dataset_name}.")
text_feature = hf_args.get("text_feature")
prompt_feature = hf_args.get("prompt_feature")
completion_feature = hf_args.get("completion_feature")
def create_hf_dataset(split: str = None):
ds = datasets.load_dataset(
dataset_name,
split=split,
**hf_args.get("config", {}),
)
if prompt_feature and completion_feature:
return CompletionsDataset(ds, tokenizer, prompt_feature, completion_feature)
elif text_feature:
return Dataset(ds, text_key=text_feature)
else:
raise ValueError(
"Specify either a prompt and completion feature or a text "
"feature for the Hugging Face dataset."
)
if args.train:
train_split = hf_args.get("train_split", "train[:80%]")
valid_split = hf_args.get("valid_split", "train[-10%:]")
train = create_hf_dataset(split=train_split)
valid = create_hf_dataset(split=valid_split)
else:
train, valid = [], []
if args.test:
test = create_hf_dataset(split=hf_args.get("test_split"))
else:
test = []
return train, valid, test
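# Example (not part of the diff): an illustrative `args.hf_dataset` mapping consumed by
# `load_custom_hf_dataset` above, e.g. as parsed from a fine-tuning YAML config. The
# dataset id and feature names are hypothetical.
example_hf_dataset = {
    "name": "mlx-community/example-sft-data",   # hypothetical dataset id
    "prompt_feature": "question",
    "completion_feature": "answer",
    "train_split": "train[:90%]",
    "valid_split": "train[-10%:]",
    "config": {"revision": "main"},
}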
def load_dataset(args, tokenizer: PreTrainedTokenizer):
if getattr(args, "hf_dataset", None) is not None:
train, valid, test = load_custom_hf_dataset(args, tokenizer)
else:
data_path = Path(args.data)
if data_path.exists():
train, valid, test = load_local_dataset(data_path, tokenizer)
else:
print(f"Loading Hugging Face dataset {args.data}.")
train, valid, test = load_hf_dataset(args.data, tokenizer)
if args.train and len(train) == 0:
raise ValueError(
"Training set not found or empty. Must provide training set for fine-tuning."

Some files were not shown because too many files have changed in this diff.