mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 21:01:32 +08:00
generalize lora finetuning for llama and mistral
This commit is contained in:
@@ -1,7 +1,8 @@
|
||||
# LoRA
|
||||
|
||||
This is an example of using MLX to fine-tune a Llama 7B[^llama] model with low
|
||||
rank adaptation (LoRA)[^lora] for a target task.
|
||||
This is an example of using MLX to fine-tune either a Llama 7B[^llama] or a
|
||||
Mistral 7B[^mistral] model with low rank adaptation (LoRA)[^lora] for a target
|
||||
task.
|
||||
|
||||
In this example we'll use the WikiSQL[^wikisql] dataset to train the LLM to
|
||||
generate SQL queries from natural language. However, the example is intended to
|
||||
@@ -15,19 +16,27 @@ Install the dependencies:
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
Next, download and convert the model. If you do not have access to the model
|
||||
weights you will need to [request
|
||||
Next, download and convert the model. The Mistral weights can be downloaded with:
|
||||
|
||||
```
|
||||
curl -O https://files.mistral-7b-v0-1.mistral.ai/mistral-7B-v0.1.tar
|
||||
tar -xf mistral-7B-v0.1.tar
|
||||
```
|
||||
|
||||
If you do not have access to the Llama weights you will need to [request
|
||||
access](https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform)
|
||||
from Meta.
|
||||
|
||||
Convert the weights with:
|
||||
|
||||
```
|
||||
python convert.py <path_to_torch_weights> mlx_llama_7B.npz
|
||||
python convert.py <path_to_torch_weights> <path_to_mlx_weights.npz>
|
||||
```
|
||||
|
||||
## Run
|
||||
|
||||
#### Fine-tune
|
||||
|
||||
The main script is `lora.py`. To see a full list of options run
|
||||
|
||||
```
|
||||
@@ -37,28 +46,34 @@ python lora.py --help
|
||||
To fine-tune a model use:
|
||||
|
||||
```
|
||||
python lora.py --model mlx_llama_7B.npz \
|
||||
--tokenizer tokenizer.model \
|
||||
python lora.py --model <path_to_model> \
|
||||
--train \
|
||||
--iters 600 \
|
||||
--iters 600
|
||||
```
|
||||
|
||||
Note, the model path should have the MLX weights, the tokenizer, and the
|
||||
`params.json` configuration which will all be output by the `conver.py` script.
|
||||
|
||||
By default, the adapter weights are saved in `adapters.npz`. You can specify
|
||||
the output location with `--adapter_file`.
|
||||
|
||||
#### Evaluate
|
||||
|
||||
To compute test set perplexity use
|
||||
|
||||
```
|
||||
python lora.py --model mlx_llama_7B.npz \
|
||||
--tokenizer tokenizer.model \
|
||||
python lora.py --model <path_to_model> \
|
||||
--adapter_file <path_to_adapters.npz> \
|
||||
--test
|
||||
```
|
||||
|
||||
#### Generate
|
||||
|
||||
For generation use
|
||||
|
||||
```
|
||||
python lora.py --model mlx_llama_7B.npz \
|
||||
--tokenizer tokenizer.model \
|
||||
python lora.py --model <path_to_model> \
|
||||
--adapter_file <path_to_adapters.npz> \
|
||||
--num-tokens 50 \
|
||||
--prompt "table: 1-10015132-16
|
||||
columns: Player, No., Nationality, Position, Years in Toronto, School/Club Team
|
||||
@@ -87,4 +102,5 @@ The model trains at around 475 tokens per second on an M2 Ultra.
|
||||
|
||||
[^lora]: Refer to the [arXiv paper](https://arxiv.org/abs/2106.09685) for more details on LoRA.
|
||||
[^llama]: Refer to the [arXiv paper](https://arxiv.org/abs/2302.13971) and [blog post](https://ai.meta.com/blog/large-language-model-llama-meta-ai/) for more details.
|
||||
[^mistral]: Refer to the [blog post](https://mistral.ai/news/announcing-mistral-7b/) and [github repository](https://github.com/mistralai/mistral-src) for more details.
|
||||
[^wikisql]: Refer to the [GitHub repo](https://github.com/salesforce/WikiSQL/tree/master) for more information about WikiSQL.
|
||||
|
Reference in New Issue
Block a user