Commit Graph

648 Commits

Author SHA1 Message Date
Chime Ogbuji
95e1f22812 Incorporate use of response template for completion masking
Follow the example of trl's DataCollatorForCompletionOnlyLM: use the response template to identify the beginning of the completion/continuation tokens, so that all other tokens can be masked out during loss calculation
2025-02-09 07:43:04 -08:00
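The masking scheme this commit describes can be sketched roughly as follows. This is a minimal, hypothetical Python sketch in the spirit of trl's DataCollatorForCompletionOnlyLM, not the actual mlx-lm or trl code; `IGNORE_INDEX`, `find_sublist`, and `mask_prompt_tokens` are illustrative names:

```python
IGNORE_INDEX = -100  # conventional label value excluded from loss computation

def find_sublist(haystack, needle):
    """Return the start index of the first occurrence of `needle` in `haystack`, or -1."""
    n = len(needle)
    for i in range(len(haystack) - n + 1):
        if haystack[i : i + n] == needle:
            return i
    return -1

def mask_prompt_tokens(input_ids, response_template_ids):
    """Copy input_ids to labels, masking everything up to and including the
    response template so only the completion tokens contribute to the loss."""
    labels = list(input_ids)
    start = find_sublist(input_ids, response_template_ids)
    if start != -1:
        end = start + len(response_template_ids)
        labels[:end] = [IGNORE_INDEX] * end
    return labels
```

If the template is not found, the labels are left unmasked, which mirrors the fall-through behavior a collator would need for malformed examples.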
Chime Ogbuji
cb87f6f22c Add response template (or token) argument
For use in calculating the mask covering everything up to and including the response prompt, so that only what follows it (i.e., the continuation/completion) contributes to the loss
2025-02-09 07:43:01 -08:00
Chime Ogbuji
6df285ef6c Sync use of special tokens with iterate_batches 2025-02-09 07:41:24 -08:00
Chime Ogbuji
f989401881 Default for hf_datasets configuration 2025-02-09 07:41:24 -08:00
Chime Ogbuji
5ce58e4b6a Update documentation 2025-02-09 07:41:24 -08:00
Chime Ogbuji
3f08dfc762 Don't dupe BOS
Ensure completion batching doesn't duplicate BOS for instruction/chat models whose tokenizer configurations have `add_bos_token = True` (see: 1095)
2025-02-09 07:41:24 -08:00
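The BOS-dedup guard above can be illustrated with a small sketch. This is hypothetical, not the actual mlx-lm code; `drop_duplicate_bos` is an illustrative name for the check a batching routine would apply when a chat template already emits BOS and the tokenizer adds one as well:

```python
def drop_duplicate_bos(tokens, bos_token_id):
    """If the sequence starts with two copies of BOS (one from the chat template
    plus one from add_bos_token=True), drop the duplicate."""
    if len(tokens) >= 2 and tokens[0] == bos_token_id == tokens[1]:
        return tokens[1:]
    return tokens
```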
Chime Ogbuji
69282ab7fc Minor fix 2025-02-09 07:41:24 -08:00
Chime Ogbuji
4890870053 Add ability to fetch raw prompt and completion text from completion datasets 2025-02-09 07:41:23 -08:00
Chime Ogbuji
a5b866cf73 Fix index calculation 2025-02-09 07:41:01 -08:00
Chime Ogbuji
a4a86ad898 Fix iteration over HF dataset collection 2025-02-09 07:41:01 -08:00
Chime Ogbuji
78c33e5037 Fix keyword argument invocation 2025-02-09 07:41:00 -08:00
Chime Ogbuji
387c45efa2 Fixes to references to hf_datasets 2025-02-09 07:40:09 -08:00
Chime Ogbuji
214c79be9c Fixes to config format in documentation 2025-02-09 07:38:41 -08:00
Chime Ogbuji
8ec802f468 Updates to LoRA documentation 2025-02-09 07:38:41 -08:00
Chime Ogbuji
14a75f3f03 Generalize HF datasets to a collection of HF datasets via datasets, add support for custom chat HF datasets (#1088), and fix (#1087) 2025-02-09 07:38:40 -08:00
Chime Ogbuji
3496cbea46 Add input masking for fine-tuning in documentation
Renamed the batch iteration function (iterate_delineated_batches -> iterate_completion_batches).
2025-02-09 07:12:54 -08:00
Chime Ogbuji
71d9f8cc38 Fix 2025-02-09 07:12:54 -08:00
Chime Ogbuji
02abeeade4 Update sublist search and calculation of input id length 2025-02-09 07:12:54 -08:00
Chime Ogbuji
30fd5af843 Fix variable reference 2025-02-09 07:12:54 -08:00
Chime Ogbuji
27cd361d76 Updates CL lora tuner with input masking that uses default_loss (and iterate_batches) by default. 2025-02-09 07:12:54 -08:00
Chime Ogbuji
84fc1bde48 Minor documentation update 2025-02-09 07:12:54 -08:00
Chime Ogbuji
79a042768f Replace iterate_input_masked_batches with iterate_delineated_batches, an updated attempt to better sync with iterate_batches logic 2025-02-09 07:12:54 -08:00
Chime Ogbuji
604be3cec9 Add input_masked loss calculation and batching w/ padding 2025-02-09 07:12:54 -08:00
Awni Hannun
f58c7de901
Some improvements to speedup alignment computation in MLX Whisper (#1259)
* some improvements to speedup alignment computation in MLX Whisper

* fix alignment
2025-02-08 15:47:00 -08:00
Awni Hannun
1503bd4f55
support hunyuan 7b (#1263) 2025-02-08 15:46:47 -08:00
Awni Hannun
31611b62d7
Add IBM granite model (#1265)
* add granite

* add thinking option
2025-02-08 15:46:15 -08:00
Awni Hannun
6120a5f376
Faster DSv2/3 expert score computation (#1257)
* fix deepseek sharding (#1242)

* compile and use put along axis in deep seek routing function
2025-02-07 10:24:57 -08:00
Awni Hannun
52c41b5b5a
Fix prompt cache for models without chat template (#1250)
* fix deepseek sharding (#1242)

* fix prompt cache with no chat template
2025-02-06 11:10:58 -08:00
Nripesh Niketan
747c08e202
Chore: pre-commit bump (#1253) 2025-02-06 09:06:31 -08:00
Pedro Cuenca
e2e5478da5
READMEs: fix typo in link, minor update. (#1246) 2025-02-04 11:52:32 -08:00
Awni Hannun
21d0ab6e8a
fix deepseek sharding (#1242) 2025-02-03 16:59:50 -08:00
Gökdeniz Gülmez
0989c073b0
Optimizations for mamba1 (#1213)
* added mx.einsum() operations: before: 41.293 tokens-per-sec, after: 57.822 tokens-per-sec

* Fused operations in `delta, B, C = ...`. Before: 57.822 tokens-per-sec, after: 83.890 tokens-per-sec

* Pre-computing A_log. Before: 83.890 tokens-per-sec, after: 85.848 tokens-per-sec

* Update MambaBlock: batched input processing, improved cache handling, pre-computed constants, cleaner state management, explicit return values. Before: 82.442 tokens-per-sec, after: 129.130 tokens-per-sec.

* cleaning up and adding apple copyright to helium modelfile

* update Copyright to this year

* nits + even faster

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2025-02-03 13:36:08 -08:00
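The first bullet above credits much of the speedup to replacing loops with fused `einsum` calls. A generic illustration of that kind of rewrite (using NumPy here for portability; the shapes and names are hypothetical and not the actual mamba1 code, which uses `mx.einsum`):

```python
import numpy as np

def weighted_sum_loop(x, w):
    """Explicit Python loops: one multiply-accumulate at a time."""
    out = np.zeros(x.shape[0])
    for b in range(x.shape[0]):        # batch dimension
        for d in range(x.shape[1]):    # feature dimension
            out[b] += x[b, d] * w[d]
    return out

def weighted_sum_einsum(x, w):
    """Same contraction expressed as a single fused einsum call."""
    return np.einsum("bd,d->b", x, w)
```

The einsum form lets the array library dispatch one fused kernel instead of interpreting Python bytecode per element, which is the general mechanism behind the tokens-per-sec gains listed above.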
Awni Hannun
d9924d08d1
Fix no validation in lora (#1241) 2025-02-03 09:55:24 -08:00
Awni Hannun
9c2ef38d4d
only download local shard (#1240) 2025-02-02 13:58:44 -08:00
Awni Hannun
e8afb59de4
better overflow correction (#1229) 2025-01-28 14:37:30 -08:00
Anchen
7a83077cd7
chore(mlx-lm): support text type content in messages (#1225)
* chore(mlx-lm): support text type content

* chore: optimize the message content processing

* nits + format

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-01-27 17:13:50 -08:00
Awni Hannun
f44a52e2dc
batched min p and fix spec gen sampling (#1222) 2025-01-27 15:40:31 -08:00
Gökdeniz Gülmez
77faa14ba4
adding support for kyutai's helium (#1208)
* initial commit

* adding helium into training

* Update ACKNOWLEDGMENTS.md

* nits

* nits

* fixes / nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-01-26 07:19:07 -08:00
Awni Hannun
9a3ddc3e65
some fixes for pipeline parallel deep seek r1 (#1216) 2025-01-21 19:40:29 -08:00
Victor Nogueira
df1406735b
Fix dataset variable name, in datasets.py (#1212) 2025-01-21 14:12:43 -08:00
Jarrett
07f88f8057
fix(lora): add back store_true default args (#1205) 2025-01-16 11:15:42 -08:00
Awni Hannun
50f0a7f6d9
add internlm3 (#1206) 2025-01-15 14:55:41 -08:00
Ivan Fioravanti
6ae6c72c2e
reduction moved to CPU in case of distributed training (#1200) 2025-01-14 17:20:42 -08:00
Awni Hannun
c117af83b8
fix gpt bigcode (#1204) 2025-01-13 10:22:32 -08:00
Chime Ogbuji
0228c46434
Custom local dataset features (#1085)
* Generalize prompt_feature and completion_feature for use in local datasets to facilitate compatibility with many other training dataset formats.

* Persist configured prompt/completion key

* rebase + nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-01-13 10:01:18 -08:00
Prince Canuma
bf2da36fc6
Fix Cohere2: mask shape error (long context) (#1202)
* fix mask shape error (long context)

* Update llms/mlx_lm/models/cohere2.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* revert layer_idx

* black formatting

* Update cohere2.py

* format

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2025-01-12 12:58:08 -08:00
Xingjun.Wang
514502da22
Support snapshot_download for ModelScope (#1194)
* add MLX_USE_MODELSCOPE env

* update

* update snapshot_download

* update

* remove modelscope dependency and add import check

* update

* nits

* fix

---------

Co-authored-by: wangxingjun778 <jason@U-C7X6TX5G-2239.local>
Co-authored-by: Awni Hannun <awni@apple.com>
2025-01-10 15:29:34 -08:00
Awni Hannun
93c5cfd781
Add a speculative decoding generator (#1155)
* add a speculative decoding generator

* fix

* fixes

* optional kwarg pop
2025-01-10 15:27:08 -08:00
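The idea behind a speculative decoding generator can be sketched in a few lines. This is a minimal greedy sketch of the general technique, not the mlx-lm API; `draft_next`/`target_next` are hypothetical callables standing in for the draft and target models:

```python
def speculative_step(draft_next, target_next, prefix, k):
    """One speculative step: the cheap draft model proposes k tokens, the target
    model verifies them in order, and we keep the target's token at the first
    mismatch (greedy acceptance; real implementations verify in one batched pass)."""
    # Draft phase: propose k tokens autoregressively with the cheap model.
    proposed, ctx = [], list(prefix)
    for _ in range(k):
        t = draft_next(ctx)
        proposed.append(t)
        ctx.append(t)
    # Verify phase: accept matching tokens, substitute the target's token on mismatch.
    accepted, ctx = [], list(prefix)
    for t in proposed:
        v = target_next(ctx)
        if v == t:
            accepted.append(t)
            ctx.append(t)
        else:
            accepted.append(v)  # take the target's correction and stop
            break
    else:
        accepted.append(target_next(ctx))  # bonus token when all drafts accepted
    return accepted
```

When the draft agrees with the target, each step yields up to k+1 tokens for roughly one target-model pass, which is where the speedup comes from.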
Awni Hannun
5cae0a60e6
deepseek v3 model with pipeline parallelism (#1191)
* deepseekv3

* use upload_large_file instead of deprecated multi commit

* add pipeline generation and example

* comment

* get fp16 working

* use mlx==0.22
2025-01-09 15:55:53 -08:00
Jarrett
40b88eff48
fix(lora): config yaml & arg default merge bug (#1196) 2025-01-09 11:33:54 -08:00