generalize lora finetuning for llama and mistral

This commit is contained in:
Awni Hannun
2023-12-09 14:13:55 -08:00
parent 07cdcef452
commit 8094503a68
5 changed files with 354 additions and 293 deletions

View File

@@ -1,7 +1,8 @@
# LoRA
This is an example of using MLX to fine-tune a Llama 7B[^llama] model with low
rank adaptation (LoRA)[^lora] for a target task.
This is an example of using MLX to fine-tune either a Llama 7B[^llama] or a
Mistral 7B[^mistral] model with low rank adaptation (LoRA)[^lora] for a target
task.
In this example we'll use the WikiSQL[^wikisql] dataset to train the LLM to
generate SQL queries from natural language. However, the example is intended to
@@ -15,19 +16,27 @@ Install the dependencies:
pip install -r requirements.txt
```
Next, download and convert the model. If you do not have access to the model
weights you will need to [request
Next, download and convert the model. The Mistral weights can be downloaded with:
```
curl -O https://files.mistral-7b-v0-1.mistral.ai/mistral-7B-v0.1.tar
tar -xf mistral-7B-v0.1.tar
```
If you do not have access to the Llama weights you will need to [request
access](https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform)
from Meta.
Convert the weights with:
```
python convert.py <path_to_torch_weights> mlx_llama_7B.npz
python convert.py <path_to_torch_weights> <path_to_mlx_weights.npz>
```
## Run
#### Fine-tune
The main script is `lora.py`. To see a full list of options run
```
@@ -37,28 +46,34 @@ python lora.py --help
To fine-tune a model use:
```
python lora.py --model mlx_llama_7B.npz \
--tokenizer tokenizer.model \
python lora.py --model <path_to_model> \
--train \
--iters 600 \
--iters 600
```
Note, the model path should have the MLX weights, the tokenizer, and the
`params.json` configuration which will all be output by the `conver.py` script.
By default, the adapter weights are saved in `adapters.npz`. You can specify
the output location with `--adapter_file`.
#### Evaluate
To compute test set perplexity use
```
python lora.py --model mlx_llama_7B.npz \
--tokenizer tokenizer.model \
python lora.py --model <path_to_model> \
--adapter_file <path_to_adapters.npz> \
--test
```
#### Generate
For generation use
```
python lora.py --model mlx_llama_7B.npz \
--tokenizer tokenizer.model \
python lora.py --model <path_to_model> \
--adapter_file <path_to_adapters.npz> \
--num-tokens 50 \
--prompt "table: 1-10015132-16
columns: Player, No., Nationality, Position, Years in Toronto, School/Club Team
@@ -87,4 +102,5 @@ The model trains at around 475 tokens per second on an M2 Ultra.
[^lora]: Refer to the [arXiv paper](https://arxiv.org/abs/2106.09685) for more details on LoRA.
[^llama]: Refer to the [arXiv paper](https://arxiv.org/abs/2302.13971) and [blog post](https://ai.meta.com/blog/large-language-model-llama-meta-ai/) for more details.
[^mistral]: Refer to the [blog post](https://mistral.ai/news/announcing-mistral-7b/) and [github repository](https://github.com/mistralai/mistral-src) for more details.
[^wikisql]: Refer to the [GitHub repo](https://github.com/salesforce/WikiSQL/tree/master) for more information about WikiSQL.