Commit Graph

177 Commits

Author SHA1 Message Date
Madroid Ma
710c552731 add huggingface repo url print (#534) 2024-03-05 21:51:31 -08:00
Muhtasham Oblokulov
5de7c2ac33 Add tips on porting LLMs from HuggingFace (#523)
* Add tips on porting LLMs from HuggingFace

* Add CONTRIBUTING.md  to mlx-examples-llms

* Refactor imports and update comment in starcoder2.py

* Update llms/mlx_lm/models/starcoder2.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* nits

* nits

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-03-05 17:43:15 -08:00
Prince Canuma
3fdf85e79d Starcoder2: Update config and change GQA to use repeat (#520)
* update config

* change gqa to use repeat instead of concante

* contribution
2024-03-03 06:12:03 -08:00
Anchen
1e3daea3bb chore(mlx-lm): add missing model_type for starcoder2 (#522) 2024-03-03 06:07:45 -08:00
Anchen
3655bfc3bd chore(mlx-lm): fix broken server.py script (#519) 2024-03-03 06:04:54 -08:00
Muhtasham Oblokulov
81e2a80026 Add Starcoder 2 (#502)
* Add Starcoder2 model and update utils.py

* Refactor model arguments and modules in starcoder2.py

* Refactor FeedForward class to MLP in starcoder2.py

* Fix typo

* pre-commit

* Refactor starcoder2.py: Update model arguments and modules

* Fix LM head and MLP layers

* Rename  input layer norm

* Update bias in linear layers

* Refactor token embeddings in Starcoder2Model

* Rename to standard HF attention layer name

* Add LayerNorm

* Add transposed token embeddings (like in Gemma)

* Refactor MLP and TransformerBlock classes

* Add tie_word_embeddings option to ModelArgs and update Model implementation

* Add conditional check for tying word embeddings in Starcoder2Model

* Fix bias in lm_head linear layer

* Remove unused LayerNorm in stablelm

* Update transformers dependency to use GitHub repository

* fix lm head bug, revert transformer req

* Update RoPE initialization in Attention class

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-03-02 19:39:23 -08:00
Miller Liang
5b1043a458 llms: convert() add 'revision' argument (#506)
* llms: convert() add 'revision' argument

* Update README.md

* Update utils.py

* Update README.md

* Update llms/mlx_lm/utils.py

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-03-02 06:28:26 -08:00
Ashish
261f1280f6 Update to StableLM code (#514)
* StableLM now part of Transformers as stablelm rather than stablelm_epoch; changed config to match new changes

* removing old file

* reference new stablelm
2024-03-01 09:53:38 -08:00
Sugato Ray
3acc1ec84e fix: string indentation with textwrap.dedent (#510)
* fix: string indentation with textwrap.dedent

* update formatting
2024-02-29 22:23:01 -08:00
Madroid Ma
f03c8a7b44 LoRA: adapter file Support path information (#505)
* LoRA: adapter file Support path information

* fix pre-commit lint

* from os.path to pathlib.Path

* Update llms/mlx_lm/tuner/trainer.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* rename check_checkpoints_path to checkpoints_path

* fix pre-commit lint

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-02-29 22:20:49 -08:00
Awni Hannun
ae48563378 Remove gc (#509)
* remove gc

* remove debug
2024-02-29 09:40:04 -08:00
Anchen
13794a05da chore(mlx-lm): add adapter support in generate.py (#494)
* chore(mlx-lm): add adapter support in generate.py

* chore: remove generate from lora.py and raise error to let user use mlx_lm.generate instead
2024-02-28 07:49:25 -08:00
Alex Ishida
ab0f1dd1b6 Add metadata when saving safetensors (#496)
* Add metadata when saving safetensors

Add metadata format="pt" for safetensors so that model's are accessible to `transformers` users as well.

* save with metadata format mlx

Save the model weights with metadata format of "mlx".

* Updated llms/mlx_lm/generate.py
2024-02-28 07:29:00 -08:00
Y4hL
ea92f623d6 Prevent llms/mlx_lm from serving the local directory as a webserver (#498)
* Don't serve local directory

BaseHTTPRequestHandler serves the current directory by default. Definitely not intended behaviour. Remove the "do_HEAD" and "do_GET" methods.

* Fix typo in method name

I assume hanlde_stream was intended to be called handle_stream

* Fix outdated typehint

load_model returns nn.Module, however fetch_from_hub was not updated to reflect the change

* Add some more type hints

* Add warnings for using in prod

Add a warning to README and runtime, discouraging use in production. The warning is the same as on the python docs for HTTPServer https://docs.python.org/3/library/http.server.html

* format

* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-02-27 19:40:42 -08:00
Y4hL
676e574eff Add missing import (#497)
* Add missing import

* format

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-02-27 13:27:08 -08:00
Awni Hannun
95f82e67a2 Fix import warning (#479)
* fix import warning
* fix version import
* remove api, move convert to utils
* also update circle to run external PRs
2024-02-27 08:47:56 -08:00
Anchen
82f3f31d93 chore(mlx-lm): refactor server.py to utilize generate_step from utils for consistency (#491)
* chore(mlx-lm): refactor server.py to utilize generate_step from utils for consistency

* chore(mlx-lm): update server doc

* chore: remove unused generate func
2024-02-27 06:25:24 -08:00
Anchen
19a21bfce4 chore: add /v1/completions for server (#489) 2024-02-26 20:59:33 -08:00
Madroid Ma
e5dfef5d9a LoRA: Extract the run function for easy use in scripts file (#482)
* LoRA: Extract the run_lora function for easy use in scripts

* LoRA: run_lora function adds a TrainingCallback pass.

* LoRA: change run_lora to run
2024-02-26 19:35:04 -08:00
peterjc123
ccb278bcbd Add top-p sampling for text generation (#486) 2024-02-26 06:18:11 -08:00
Awni Hannun
f24edfa9dc [mlx-lm] Add precompiled normalizations (#451)
* add precompiled normalizations

* nits
2024-02-22 12:40:55 -08:00
Awni Hannun
97c09a863d bump version and include in package (#475) 2024-02-21 09:40:36 -08:00
Awni Hannun
ab9172baac Gemma support (#474)
* gemma support

* format

* lora support for gemma
2024-02-21 08:47:13 -08:00
Angelos Katharopoulos
dc4f2e0a6b Lazy loading models for faster convert and merge (#462) 2024-02-20 13:36:55 -08:00
Madroid Ma
8eee4399f4 LoRA: Add printing and callbacks for learning rate during training (#457)
* LoRA:Refactor TrainingCallback to enhance flexibility and extensibility

This commit refactors the TrainingCallback class to accept a dictionary parameter for both on_train_loss_report and on_val_loss_report methods. By switching from multiple parameters to a single dict parameter, this change significantly improves the class's flexibility and makes it easier to extend with new training or validation metrics in the future without altering the method signatures. This approach simplifies the addition of new information to be logged or processed and aligns with best practices for scalable and maintainable code design.

* LoRA: Add printing and callbacks for learning rate during training
2024-02-20 13:07:21 -08:00
Awni Hannun
8fd953ee2b Support for slerp merging models (#455)
* support for slerp merging models

* docs

* update docs

* format'
2024-02-19 20:37:15 -08:00
Anchen
88458c4e40 feat(mlx-lm): add openAI like api server (#429)
* feat(mlx-lm): add openAI like api server

* chore: fix sse format

* chore: add top_p support

* chore: fix the load import

* chore: add workground for missing space in stream decoding

* chore: fix typo

* chore: add error handling for streaming

* chore: using slicing instead of replace

* chore: set host, port via args and improve handle stream token logic

* chore: refactor stop sequence function

* chore: rename stopping_criteria

* fix: unable to load kernel contiguous_scan_inclusive_sum_bfloat16_bfloat16

* chore: fix the streaming unicode issue

* Update llms/mlx_lm/server.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* refacotr: move stopping_criteria out of generate func

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-02-18 14:01:28 -08:00
devonthomas35
cc671cd1c7 Mixtral: Fix non-default arg follows default exception (#450)
Mixtral models throw the following exception
```
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/mlx_lm/generate.py", line 119, in <module>
    main(args)
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/mlx_lm/generate.py", line 96, in main
    model, tokenizer = load(args.model, tokenizer_config=tokenizer_config)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/mlx_lm/utils.py", line 278, in load
    model = load_model(model_path)
            ^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/mlx_lm/utils.py", line 221, in load_model
    model_class, model_args_class = _get_classes(config=config)
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/mlx_lm/utils.py", line 46, in _get_classes
    arch = importlib.import_module(f"mlx_lm.models.{model_type}")
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/importlib/__init__.py", line 126, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<frozen importlib._bootstrap>", line 1204, in _gcd_import
  File "<frozen importlib._bootstrap>", line 1176, in _find_and_load
  File "<frozen importlib._bootstrap>", line 1147, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 690, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 940, in exec_module
  File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/mlx_lm/models/mixtral.py", line 11, in <module>
    @dataclass
     ^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/dataclasses.py", line 1230, in dataclass
    return wrap(cls)
           ^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/dataclasses.py", line 1220, in wrap
    return _process_class(cls, init, repr, eq, order, unsafe_hash,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/dataclasses.py", line 1027, in _process_class
    _init_fn(all_init_fields,
  File "/opt/homebrew/anaconda3/lib/python3.11/dataclasses.py", line 545, in _init_fn
    raise TypeError(f'non-default argument {f.name!r} '
TypeError: non-default argument 'model_type' follows default argument
```
2024-02-18 13:30:26 -08:00
Ivan Fioravanti
b05907c87e Change argument name in lora.py (#453)
The argument name "--max_seq_length" was updated to "--max-seq-length" in the code to maintain a consistent naming convention across the program.
2024-02-18 06:04:49 -08:00
Awni Hannun
e4d5630698 Basic CircleCI (#449)
* basic style checks for circleci

* format

* fix config
2024-02-16 22:13:55 -08:00
vishal-14069
21e19b5b5a Add Repetitive penalty to LLM inference - mlx-lm (#399)
* feat: add repetition penalty

* fix: generate function argument fix

* typo fixes

* update repetitive penalty

* update generate_step and generate

* resolve conflicts in generate

* merge latest oull origin master

* update generate

* update generate and generate_step

* update repetition list - rename variable

* refactor token count

* update generate step and generate

* move repetition_context in generate_step

* update generate step

* update generate_step
2024-02-16 21:58:17 -08:00
Madroid Ma
0ba466369f LoRA: add training callbacks (#414)
* LoRA: add training callbacks

* LoRA: add trained tokens print & callback
2024-02-16 06:04:57 -08:00
Madroid Ma
726b1ddec0 fix: check LoRA layers number error (#446) 2024-02-16 06:03:33 -08:00
Angelos Katharopoulos
f71e965d57 Change gqa to use repeat instead of concatenate (#443) 2024-02-14 17:40:11 -08:00
Awni Hannun
06ddb8414d Fix Qwen2 and SD (#441)
* fix qwen2

* version bump

* fix list shape
2024-02-14 13:43:12 -08:00
Chime Ogbuji
e446598f62 Passing parameterized loss and batching to trainer (#391) 2024-02-13 07:03:25 -08:00
Madroid Ma
954aa50c54 LoRA: Improve validation error for LoRA layer count exceeding model layer (#427)
* LoRA: Improve validation error for LoRA layer count exceeding model layer

This commit enhances the error handling when the specified LoRA layer count exceeds the total number of layers in the model. It clarifies the error message to provide actionable feedback for users, guiding them to adjust their input parameters accordingly.

* format + nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-02-13 06:56:27 -08:00
Awni Hannun
d4666615bb Lazy import + refactor Lora layer addition (#426)
* lazy model import in mlx_lm

* change lora loading

* fix olmo lora

* remove a bunch of unused stuff from plamo

* move phixtral to mlx-lm and out of llms/
2024-02-12 10:51:02 -08:00
Ivan Fioravanti
4576946151 Add checkpoints directory for adapter weights (#431)
* Add checkpoints directory for adapter weights

The code was modified to create a checkpoints directory if it doesn't exist yet. Adapter weights are now saved to this checkpoints directory during the training iterations.
Corrected indentation of Save adapter weights code because it was part of "if eval"

* Fixing a blank added by mistake
2024-02-12 10:50:05 -08:00
Nripesh Niketan
f1ef378a58 Feat: update pre-commit rev (#432) 2024-02-11 07:23:27 -08:00
Awni Hannun
f45a1ab83c Update a few examples to use compile (#420)
* update a few examples to use compile

* update mnist

* add compile to vae and rename some stuff for simplicity

* update reqs

* use state in eval

* GCN example with RNG + dropout

* add a bit of prefetching
2024-02-08 13:00:41 -08:00
Anchen
da7adae5ec fix(mlx-m): lazy load hf_olmo (#424) 2024-02-08 09:02:43 -08:00
Markus Enzweiler
9b387007ab Example of a Convolutional Variational Autoencoder (CVAE) on MNIST (#264)
* initial commit

* style fixes

* update of ACKNOWLEDGMENTS

* fixed comment

* minor refactoring; removed unused imports

* added cifar and cvae to top-level README.md

* removed mention of cuda/mps in argparse

* fixed training status output

* load_weights() with strict=True

* pretrained model update

* fixed imports and style

* requires mlx>=0.0.9

* updated with results using mlx 0.0.9

* removed mention of private repo

* simplify and combine to one file, more consistency with other exmaples

* few more nits

* nits

* spell

* format

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-02-06 20:02:27 -08:00
Chris McMaster
2303238e44 Update olmo.py (#419)
exit should be imported outside of interactive mode
2024-02-06 16:16:46 -08:00
Anchen
8b77677c05 chore(mlx-lm): add model weight index in save_weights (#413)
* chore(mlx-lm): add model weight index in save_weights

* Update llms/mlx_lm/utils.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update llms/mlx_lm/utils.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* chore: save total siZe as param size isntead of file size

* chore: clean up format

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-02-06 05:32:15 -08:00
Anchen
a7d139f484 fix(mlx-lm): olmo 1b model (#417) 2024-02-06 05:27:05 -08:00
Awni Hannun
aa7447efa2 Olmo in MLX LM (#415)
* run olmo

* format
2024-02-05 21:13:49 -08:00
Ivan Fioravanti
7fbca214b1 Add max sequence length argument in lora.py (#408)
A new argument "--max_seq_length" has been added to the command-line parser and passed as a parameter to the main function of the lora.py script. This allows users to specify and control the maximum sequence length during training.
2024-02-04 12:28:21 -08:00
Junyang Lin
9d0dd34403 add qwen2 (#411) 2024-02-04 08:31:38 -08:00
Madroid Ma
ba3a9355d1 LoRA: Remove unnecessary model type judgments (#388)
* LoRA: Remove unnecessary model type judgments

1. Supported models are already checked in the load_model function in utils, no need to repeat the check in lora
2. The checks in lora are not synchronized with those in utils

* LoRA: add LoRA supported models in mlx_lm utils
2024-01-31 11:55:27 -08:00