From 35206806acc47e73eef80a55772dc22d3bf5d8c0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ph=C3=BAc=20H=2E=20L=C3=AA=20Kh=E1=BA=AFc?=
Date: Wed, 17 Apr 2024 00:08:49 +0100
Subject: [PATCH] Create executables for generate, lora, server, merge, convert (#682)

* feat: create executables mlx_lm.

* nits in docs

---------

Co-authored-by: Awni Hannun
---
 llms/README.md          | 10 +++++-----
 llms/mlx_lm/LORA.md     | 18 +++++++++---------
 llms/mlx_lm/MERGE.md    |  4 ++--
 llms/mlx_lm/SERVER.md   |  6 +++---
 llms/mlx_lm/convert.py  |  6 +++++-
 llms/mlx_lm/generate.py |  9 +++++----
 llms/mlx_lm/lora.py     |  6 +++++-
 llms/mlx_lm/merge.py    |  6 +++++-
 llms/mlx_lm/server.py   |  6 +++++-
 llms/setup.py           | 10 ++++++++++
 10 files changed, 54 insertions(+), 27 deletions(-)

diff --git a/llms/README.md b/llms/README.md
index 27348e04..a0c17972 100644
--- a/llms/README.md
+++ b/llms/README.md
@@ -66,7 +66,7 @@ To see a description of all the arguments you can do:
 You can also use `mlx-lm` from the command line with:
 
 ```
-python -m mlx_lm.generate --model mistralai/Mistral-7B-Instruct-v0.1 --prompt "hello"
+mlx_lm.generate --model mistralai/Mistral-7B-Instruct-v0.1 --prompt "hello"
 ```
 
 This will download a Mistral 7B model from the Hugging Face Hub and generate
@@ -75,19 +75,19 @@ text using the given prompt.
 For a full list of options run:
 
 ```
-python -m mlx_lm.generate --help
+mlx_lm.generate --help
 ```
 
 To quantize a model from the command line run:
 
 ```
-python -m mlx_lm.convert --hf-path mistralai/Mistral-7B-Instruct-v0.1 -q
+mlx_lm.convert --hf-path mistralai/Mistral-7B-Instruct-v0.1 -q
 ```
 
 For more options run:
 
 ```
-python -m mlx_lm.convert --help
+mlx_lm.convert --help
 ```
 
 You can upload new models to Hugging Face by specifying `--upload-repo` to
@@ -95,7 +95,7 @@ You can upload new models to Hugging Face by specifying `--upload-repo` to
 [MLX Hugging Face community](https://huggingface.co/mlx-community) you can do:
 
 ```
-python -m mlx_lm.convert \
+mlx_lm.convert \
     --hf-path mistralai/Mistral-7B-v0.1 \
     -q \
     --upload-repo mlx-community/my-4bit-mistral
diff --git a/llms/mlx_lm/LORA.md b/llms/mlx_lm/LORA.md
index 04d00ead..6d9392d5 100644
--- a/llms/mlx_lm/LORA.md
+++ b/llms/mlx_lm/LORA.md
@@ -27,7 +27,7 @@ LoRA (QLoRA).[^qlora] LoRA fine-tuning works with the following model families:
 The main command is `mlx_lm.lora`. To see a full list of command-line options run:
 
 ```shell
-python -m mlx_lm.lora --help
+mlx_lm.lora --help
 ```
 
 Note, in the following the `--model` argument can be any compatible Hugging
@@ -37,7 +37,7 @@ You can also specify a YAML config with `-c`/`--config`. For more on the format
 [example YAML](examples/lora_config.yaml). For example:
 
 ```shell
-python -m mlx_lm.lora --config /path/to/config.yaml
+mlx_lm.lora --config /path/to/config.yaml
 ```
 
 If command-line flags are also used, they will override the corresponding
@@ -48,7 +48,7 @@ values in the config.
 To fine-tune a model use:
 
 ```shell
-python -m mlx_lm.lora \
+mlx_lm.lora \
     --model <path_to_model> \
     --train \
     --data <path_to_data> \
@@ -76,7 +76,7 @@ You can resume fine-tuning with an existing adapter with
 To compute test set perplexity use:
 
 ```shell
-python -m mlx_lm.lora \
+mlx_lm.lora \
     --model <path_to_model> \
     --adapter-path <path_to_adapters> \
     --data <path_to_data> \
@@ -88,7 +88,7 @@ python -m mlx_lm.lora \
 For generation use `mlx_lm.generate`:
 
 ```shell
-python -m mlx_lm.generate \
+mlx_lm.generate \
     --model <path_to_model> \
     --adapter-path <path_to_adapters> \
     --prompt "<your_model_prompt>"
@@ -106,13 +106,13 @@ You can generate a model fused with the low-rank adapters using the
 To see supported options run:
 
 ```shell
-python -m mlx_lm.fuse --help
+mlx_lm.fuse --help
 ```
 
 To generate the fused model run:
 
 ```shell
-python -m mlx_lm.fuse --model <path_to_model>
+mlx_lm.fuse --model <path_to_model>
 ```
 
 This will by default load the adapters from `adapters/`, and save the fused
@@ -125,7 +125,7 @@ useful for the sake of attribution and model versioning.
 For example, to fuse and upload a model derived from Mistral-7B-v0.1, run:
 
 ```shell
-python -m mlx_lm.fuse \
+mlx_lm.fuse \
     --model mistralai/Mistral-7B-v0.1 \
     --upload-repo mlx-community/my-4bit-lora-mistral \
     --hf-path mistralai/Mistral-7B-v0.1
@@ -134,7 +134,7 @@ python -m mlx_lm.fuse \
 To export a fused model to GGUF, run:
 
 ```shell
-python -m mlx_lm.fuse \
+mlx_lm.fuse \
     --model mistralai/Mistral-7B-v0.1 \
     --export-gguf
 ```
diff --git a/llms/mlx_lm/MERGE.md b/llms/mlx_lm/MERGE.md
index 2ee2414c..093c7ed6 100644
--- a/llms/mlx_lm/MERGE.md
+++ b/llms/mlx_lm/MERGE.md
@@ -6,14 +6,14 @@ Face hub or save them locally for LoRA fine tuning.
 The main command is `mlx_lm.merge`:
 
 ```shell
-python -m mlx_lm.merge --config config.yaml
+mlx_lm.merge --config config.yaml
 ```
 
 The merged model will be saved by default in `mlx_merged_model`. To see a
 full list of options run:
 
 ```shell
-python -m mlx_lm.merge --help
+mlx_lm.merge --help
 ```
 
 Here is an example `config.yaml`:
diff --git a/llms/mlx_lm/SERVER.md b/llms/mlx_lm/SERVER.md
index 68bb3545..edea5457 100644
--- a/llms/mlx_lm/SERVER.md
+++ b/llms/mlx_lm/SERVER.md
@@ -11,13 +11,13 @@ API](https://platform.openai.com/docs/api-reference).
 Start the server with:
 
 ```shell
-python -m mlx_lm.server --model <path_to_model_or_hf_repo>
+mlx_lm.server --model <path_to_model_or_hf_repo>
 ```
 
 For example:
 
 ```shell
-python -m mlx_lm.server --model mistralai/Mistral-7B-Instruct-v0.1
+mlx_lm.server --model mistralai/Mistral-7B-Instruct-v0.1
 ```
 
 This will start a text generation server on port `8080` of the `localhost`
@@ -27,7 +27,7 @@ Hugging Face repo if it is not already in the local cache.
 To see a full list of options run:
 
 ```shell
-python -m mlx_lm.server --help
+mlx_lm.server --help
 ```
 
 You can make a request to the model by running:
diff --git a/llms/mlx_lm/convert.py b/llms/mlx_lm/convert.py
index 5f2f3adf..a3f43f71 100644
--- a/llms/mlx_lm/convert.py
+++ b/llms/mlx_lm/convert.py
@@ -52,7 +52,11 @@ def configure_parser() -> argparse.ArgumentParser:
     return parser
 
 
-if __name__ == "__main__":
+def main():
     parser = configure_parser()
     args = parser.parse_args()
     convert(**vars(args))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py
index 6d859c3c..da94eef2 100644
--- a/llms/mlx_lm/generate.py
+++ b/llms/mlx_lm/generate.py
@@ -101,7 +101,10 @@ def colorprint_by_t0(s, t0):
     colorprint(color, s)
 
 
-def main(args):
+def main():
+    parser = setup_arg_parser()
+    args = parser.parse_args()
+
     mx.random.seed(args.seed)
 
     # Building tokenizer_config
@@ -143,6 +146,4 @@ def main(args):
 
 
 if __name__ == "__main__":
-    parser = setup_arg_parser()
-    args = parser.parse_args()
-    main(args)
+    main()
diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index 36343262..18840cf4 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -247,7 +247,7 @@ def run(args, training_callback: TrainingCallback = None):
     print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
 
 
-if __name__ == "__main__":
+def main():
     parser = build_parser()
     args = parser.parse_args()
     config = args.config
@@ -266,3 +266,7 @@ if __name__ == "__main__":
         if not args.get(k, None):
             args[k] = v
     run(types.SimpleNamespace(**args))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/llms/mlx_lm/merge.py b/llms/mlx_lm/merge.py
index 9c88970e..a009338e 100644
--- a/llms/mlx_lm/merge.py
+++ b/llms/mlx_lm/merge.py
@@ -162,7 +162,11 @@ def merge(
         upload_to_hub(mlx_path, upload_repo, base_hf_path)
 
 
-if __name__ == "__main__":
+def main():
     parser = configure_parser()
     args = parser.parse_args()
     merge(**vars(args))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py
index e717f324..482ee00c 100644
--- a/llms/mlx_lm/server.py
+++ b/llms/mlx_lm/server.py
@@ -409,7 +409,7 @@ def run(host: str, port: int, server_class=HTTPServer, handler_class=APIHandler)
     httpd.serve_forever()
 
 
-if __name__ == "__main__":
+def main():
     parser = argparse.ArgumentParser(description="MLX Http Server.")
     parser.add_argument(
         "--model",
@@ -449,3 +449,7 @@ if __name__ == "__main__":
     )
 
     run(args.host, args.port)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/llms/setup.py b/llms/setup.py
index 58d02291..26e1a3b8 100644
--- a/llms/setup.py
+++ b/llms/setup.py
@@ -24,4 +24,14 @@ setup(
     install_requires=requirements,
     packages=["mlx_lm", "mlx_lm.models", "mlx_lm.tuner"],
     python_requires=">=3.8",
+    entry_points={
+        "console_scripts": [
+            "mlx_lm.convert = mlx_lm.convert:main",
+            "mlx_lm.fuse = mlx_lm.fuse:main",
+            "mlx_lm.generate = mlx_lm.generate:main",
+            "mlx_lm.lora = mlx_lm.lora:main",
+            "mlx_lm.merge = mlx_lm.merge:main",
+            "mlx_lm.server = mlx_lm.server:main",
+        ]
+    },
 )
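
The `entry_points` block above is what turns each refactored `main()` into an installable command. Below is a minimal sketch of how those console scripts can be inspected once the package is installed; it assumes an editable install such as `pip install -e llms/` and Python 3.10+ for the `importlib.metadata` selection API, and it is illustrative rather than part of the patch:

```python
# Sketch: list the mlx_lm console scripts registered via setup.py and show the
# "module:function" target each one resolves to. Assumes the package has been
# installed (e.g. `pip install -e llms/`) and Python 3.10+ for entry_points(group=...).
from importlib.metadata import entry_points

for ep in entry_points(group="console_scripts"):
    if ep.name.startswith("mlx_lm."):
        # e.g. "mlx_lm.generate -> mlx_lm.generate:main"
        print(f"{ep.name} -> {ep.value}")

# Each target is a zero-argument main() that parses sys.argv itself, so loading
# and calling one is equivalent to invoking the installed command:
#     entry_points(group="console_scripts", name="mlx_lm.generate")[0].load()()
```

Registering dotted names such as `mlx_lm.generate` keeps the commands grouped under the package prefix while mapping one-to-one onto the new `module:main` targets, which is why each script in the diff gains a `main()` wrapper around its former `if __name__ == "__main__":` block.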