mlx-examples/lora
2023-12-09 14:13:55 -08:00
..
convert.py generalize lora finetuning for llama and mistral 2023-12-09 14:13:55 -08:00
lora.py generalize lora finetuning for llama and mistral 2023-12-09 14:13:55 -08:00
models.py generalize lora finetuning for llama and mistral 2023-12-09 14:13:55 -08:00
README.md generalize lora finetuning for llama and mistral 2023-12-09 14:13:55 -08:00
requirements.txt lora 2023-11-29 14:14:11 -08:00
wikisql.py add copyright in source 2023-11-30 11:08:53 -08:00

LoRA

This is an example of using MLX to fine-tune either a Llama 7B1 or a Mistral 7B2 model with low rank adaptation (LoRA)3 for a target task.

In this example we'll use the WikiSQL4 dataset to train the LLM to generate SQL queries from natural language. However, the example is intended to be general should you wish to modify the task.

Setup

Install the dependencies:

pip install -r requirements.txt

Next, download and convert the model. The Mistral weights can be downloaded with:

curl -O https://files.mistral-7b-v0-1.mistral.ai/mistral-7B-v0.1.tar
tar -xf mistral-7B-v0.1.tar

If you do not have access to the Llama weights you will need to request access from Meta.

Convert the weights with:

python convert.py <path_to_torch_weights> <path_to_mlx_weights.npz>

Run

Fine-tune

The main script is lora.py. To see a full list of options run

python lora.py --help

To fine-tune a model use:

python lora.py --model <path_to_model> \
               --train \
               --iters 600

Note, the model path should have the MLX weights, the tokenizer, and the params.json configuration which will all be output by the conver.py script.

By default, the adapter weights are saved in adapters.npz. You can specify the output location with --adapter_file.

Evaluate

To compute test set perplexity use

python lora.py --model <path_to_model> \
               --adapter_file <path_to_adapters.npz> \
               --test 

Generate

For generation use

python lora.py --model <path_to_model> \
               --adapter_file <path_to_adapters.npz> \
               --num-tokens 50 \
               --prompt "table: 1-10015132-16
columns: Player, No., Nationality, Position, Years in Toronto, School/Club Team
Q: What is terrence ross' nationality
A: "

Results

The initial validation loss for Llama 7B on the WikiSQL is 2.66 and the final validation loss after 1000 iterations is 1.23. The table below shows the training and validation loss at a few points over the course of training.

Iteration Train Loss Validation Loss
1 N/A 2.659
200 1.264 1.405
400 1.201 1.303
600 1.123 1.274
800 1.017 1.255
1000 1.070 1.230

After training for 1000 iterations, the validation perplexity reduces to XX.

The model trains at around 475 tokens per second on an M2 Ultra.


  1. Refer to the arXiv paper and blog post for more details. ↩︎

  2. Refer to the blog post and github repository for more details. ↩︎

  3. Refer to the arXiv paper for more details on LoRA. ↩︎

  4. Refer to the GitHub repo for more information about WikiSQL. ↩︎