Create executables for generate, lora, server, merge, convert (#682)

* feat: create executables mlx_lm.<cmd>

* nits in docs

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Author: Phúc H. Lê Khắc
Date: 2024-04-17 00:08:49 +01:00
Committed by: GitHub
Parent: 7d7e236061
Commit: 35206806ac
10 changed files with 54 additions and 27 deletions
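The packaging change that actually registers the new executables is among the changed files but is not shown in the hunks below. As a rough sketch only, console scripts of the form `mlx_lm.<cmd>` are typically declared via setuptools entry points wired to the `main()` functions this commit introduces; the file contents here are an assumption, not part of the diff:

```python
# setup.py (sketch) -- assumed console-script registration; the real packaging
# change in this commit is not reproduced in the hunks shown here.
from setuptools import setup

setup(
    # ... name, version, packages, install_requires, etc. ...
    entry_points={
        "console_scripts": [
            "mlx_lm.convert = mlx_lm.convert:main",
            "mlx_lm.generate = mlx_lm.generate:main",
            "mlx_lm.lora = mlx_lm.lora:main",
            "mlx_lm.merge = mlx_lm.merge:main",
            "mlx_lm.server = mlx_lm.server:main",
        ]
    },
)
```

Each script name maps to a module-level `main()` that parses its own arguments, which is why the Python hunks below move `parser.parse_args()` out of the `if __name__ == "__main__":` blocks.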

View File

@@ -27,7 +27,7 @@ LoRA (QLoRA).[^qlora] LoRA fine-tuning works with the following model families:
The main command is `mlx_lm.lora`. To see a full list of command-line options run:
```shell
-python -m mlx_lm.lora --help
+mlx_lm.lora --help
```
Note, in the following the `--model` argument can be any compatible Hugging
@@ -37,7 +37,7 @@ You can also specify a YAML config with `-c`/`--config`. For more on the format
[example YAML](examples/lora_config.yaml). For example:
```shell
-python -m mlx_lm.lora --config /path/to/config.yaml
+mlx_lm.lora --config /path/to/config.yaml
```
If command-line flags are also used, they will override the corresponding
@@ -48,7 +48,7 @@ values in the config.
To fine-tune a model use:
```shell
-python -m mlx_lm.lora \
+mlx_lm.lora \
    --model <path_to_model> \
    --train \
    --data <path_to_data> \
@@ -76,7 +76,7 @@ You can resume fine-tuning with an existing adapter with
To compute test set perplexity use:
```shell
-python -m mlx_lm.lora \
+mlx_lm.lora \
    --model <path_to_model> \
    --adapter-path <path_to_adapters> \
    --data <path_to_data> \
@@ -88,7 +88,7 @@ python -m mlx_lm.lora \
For generation use `mlx_lm.generate`:
```shell
-python -m mlx_lm.generate \
+mlx_lm.generate \
    --model <path_to_model> \
    --adapter-path <path_to_adapters> \
    --prompt "<your_model_prompt>"
@@ -106,13 +106,13 @@ You can generate a model fused with the low-rank adapters using the
To see supported options run:
```shell
-python -m mlx_lm.fuse --help
+mlx_lm.fuse --help
```
To generate the fused model run:
```shell
-python -m mlx_lm.fuse --model <path_to_model>
+mlx_lm.fuse --model <path_to_model>
```
This will by default load the adapters from `adapters/`, and save the fused
@@ -125,7 +125,7 @@ useful for the sake of attribution and model versioning.
For example, to fuse and upload a model derived from Mistral-7B-v0.1, run:
```shell
-python -m mlx_lm.fuse \
+mlx_lm.fuse \
    --model mistralai/Mistral-7B-v0.1 \
    --upload-repo mlx-community/my-4bit-lora-mistral \
    --hf-path mistralai/Mistral-7B-v0.1
@@ -134,7 +134,7 @@ python -m mlx_lm.fuse \
To export a fused model to GGUF, run:
```shell
-python -m mlx_lm.fuse \
+mlx_lm.fuse \
    --model mistralai/Mistral-7B-v0.1 \
    --export-gguf
```

View File

@@ -6,14 +6,14 @@ Face hub or save them locally for LoRA fine tuning.
The main command is `mlx_lm.merge`:
```shell
-python -m mlx_lm.merge --config config.yaml
+mlx_lm.merge --config config.yaml
```
The merged model will be saved by default in `mlx_merged_model`. To see a
full list of options run:
```shell
-python -m mlx_lm.merge --help
+mlx_lm.merge --help
```
Here is an example `config.yaml`:
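The example config itself falls outside this hunk. Purely as an illustration of the kind of file `mlx_lm.merge --config` consumes, a slerp-style merge config might look like the following; the model names, method, and parameter values are assumptions, not the repository's actual example:

```yaml
# Illustrative merge config -- structure and values are assumed, not taken
# from the repository's example file.
models:
  - mistralai/Mistral-7B-v0.1
  - BioMistral/BioMistral-7B
method: slerp
parameters:
  t:
    - filter: self_attn
      value: [0.0, 0.5, 0.3, 0.7, 1.0]
    - filter: mlp
      value: [1.0, 0.5, 0.7, 0.3, 0.0]
    - value: 0.5
```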

View File

@@ -11,13 +11,13 @@ API](https://platform.openai.com/docs/api-reference).
Start the server with:
```shell
-python -m mlx_lm.server --model <path_to_model_or_hf_repo>
+mlx_lm.server --model <path_to_model_or_hf_repo>
```
For example:
```shell
-python -m mlx_lm.server --model mistralai/Mistral-7B-Instruct-v0.1
+mlx_lm.server --model mistralai/Mistral-7B-Instruct-v0.1
```
This will start a text generation server on port `8080` of the `localhost`
@@ -27,7 +27,7 @@ Hugging Face repo if it is not already in the local cache.
To see a full list of options run:
```shell
-python -m mlx_lm.server --help
+mlx_lm.server --help
```
You can make a request to the model by running:
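The request example is outside this hunk. Since the server follows the OpenAI API referenced earlier in this file, a minimal request sketch could look like the following; the endpoint path and JSON fields are assumed from the OpenAI chat-completions format, not copied from the doc:

```shell
# Assumed OpenAI-style chat request to the local server; the documentation's
# actual example is not shown in this hunk.
curl localhost:8080/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
        "messages": [{"role": "user", "content": "Say this is a test!"}],
        "temperature": 0.7
      }'
```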

View File

@@ -52,7 +52,11 @@ def configure_parser() -> argparse.ArgumentParser:
    return parser
-if __name__ == "__main__":
+def main():
    parser = configure_parser()
    args = parser.parse_args()
    convert(**vars(args))
+if __name__ == "__main__":
+    main()
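Unlike the other tools, `convert` has no documentation hunk in this diff showing the new invocation. A hypothetical call to the resulting executable would look like the following; the `--hf-path` and `-q` flags are assumptions based on mlx_lm.convert's usual options, which this diff does not show:

```shell
# Hypothetical invocation of the new console script; flags are assumed, not
# taken from this diff.
mlx_lm.convert --hf-path mistralai/Mistral-7B-v0.1 -q
```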

View File

@@ -101,7 +101,10 @@ def colorprint_by_t0(s, t0):
    colorprint(color, s)
-def main(args):
+def main():
+    parser = setup_arg_parser()
+    args = parser.parse_args()
    mx.random.seed(args.seed)
    # Building tokenizer_config
@@ -143,6 +146,4 @@ def main(args):
if __name__ == "__main__":
-    parser = setup_arg_parser()
-    args = parser.parse_args()
-    main(args)
+    main()

View File

@@ -247,7 +247,7 @@ def run(args, training_callback: TrainingCallback = None):
print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
if __name__ == "__main__":
def main():
parser = build_parser()
args = parser.parse_args()
config = args.config
@@ -266,3 +266,7 @@ if __name__ == "__main__":
        if not args.get(k, None):
            args[k] = v
    run(types.SimpleNamespace(**args))
+if __name__ == "__main__":
+    main()

View File

@@ -162,7 +162,11 @@ def merge(
        upload_to_hub(mlx_path, upload_repo, base_hf_path)
-if __name__ == "__main__":
+def main():
    parser = configure_parser()
    args = parser.parse_args()
    merge(**vars(args))
+if __name__ == "__main__":
+    main()

View File

@@ -409,7 +409,7 @@ def run(host: str, port: int, server_class=HTTPServer, handler_class=APIHandler)
    httpd.serve_forever()
-if __name__ == "__main__":
+def main():
    parser = argparse.ArgumentParser(description="MLX Http Server.")
    parser.add_argument(
        "--model",
@@ -449,3 +449,7 @@ if __name__ == "__main__":
    )
    run(args.host, args.port)
+if __name__ == "__main__":
+    main()