Move lora example to use the same model format / conversion as hf_llm (#252)

* hugging face the lora example to allow more models

* fixes

* comments

* more readme nits

* fusion + works better for qlora

* nits

* comments
Awni Hannun
2024-01-09 11:14:52 -08:00
committed by GitHub
parent bbd7172eef
commit 7b258f33ac
10 changed files with 521 additions and 224 deletions


# Fine-Tuning with LoRA or QLoRA
This is an example of using MLX to fine-tune an LLM with low rank adaptation
(LoRA) for a target task.[^lora] The example also supports quantized LoRA
(QLoRA).[^qlora] The example works with Llama and Mistral style
models available on Hugging Face.
In this example we'll use the WikiSQL[^wikisql] dataset to train the LLM to
generate SQL queries from natural language. However, the example is intended to
be general should you wish to use a custom dataset.
## Contents
* [Setup](#Setup)
* [Convert](#convert)
* [Run](#Run)
* [Fine-tune](#Fine-tune)
* [Evaluate](#Evaluate)
* [Generate](#Generate)
* [Results](#Results)
* [Fuse and Upload](#Fuse-and-Upload)
* [Custom Data](#Custom-Data)
* [Memory Issues](#Memory-Issues)
Install the dependencies:

```
pip install -r requirements.txt
```
### Convert
This step is optional. Use it if you want to quantize a model (for QLoRA) or
change the default data type of a pre-existing model.
You convert models using the `convert.py` script. This script takes a Hugging
Face repo as input and outputs a model directory (which you can optionally also
upload to Hugging Face).
To make a 4-bit quantized model, run:
```
python convert.py --hf-path <hf_repo> -q
```
For example, the following will make a 4-bit quantized Mistral 7B and by default
store it in `mlx_model`:
```
python convert.py --hf-path mistralai/Mistral-7B-v0.1 -q
```
For more options run:
```
python convert.py --help
```
You can upload new models to the [Hugging Face MLX
Community](https://huggingface.co/mlx-community) by specifying `--upload-name`
to `convert.py`.
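For example, a conversion that quantizes and uploads in one step might look like the following sketch (the upload name is a hypothetical placeholder):

```
python convert.py --hf-path mistralai/Mistral-7B-v0.1 -q \
    --upload-name My-Mistral-7B-v0.1-4bit
```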
## Run
The main script is `lora.py`. To see a full list of options run:
```
python lora.py --help
```
Note: in the following, the `--model` argument can be any compatible Hugging
Face repo or a local path to a converted model.
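For example, either of the following forms works, assuming `mlx_model` is a directory produced by `convert.py` (a sketch, not a complete training recipe):

```
# Point --model at a Hugging Face repo ...
python lora.py --model mistralai/Mistral-7B-v0.1 --train

# ... or at a locally converted model directory.
python lora.py --model mlx_model --train
```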
### Fine-tune
To fine-tune a model use:

```
python lora.py --model <path_to_model> \
               --train \
               --iters 600
```
If `--model` points to a quantized model, then the training will use QLoRA;
otherwise, it will use regular LoRA.
By default, the adapter weights are saved in `adapters.npz`. You can specify
the output location with `--adapter-file`.

You can resume fine-tuning with an existing adapter with `--resume-adapter-file`.
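For example, a resumed training run might look like this sketch (the adapter path is a placeholder):

```
python lora.py --model <path_to_model> \
               --train \
               --resume-adapter-file adapters.npz
```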
### Evaluate
To compute test set perplexity use:
```
python lora.py --model <path_to_model> \
               --adapter-file <path_to_adapters.npz> \
               --test
```
@@ -92,7 +105,7 @@ python lora.py --model <path_to_model> \
### Generate
For generation use:
```
python lora.py --model <path_to_model> \
               --adapter-file <path_to_adapters.npz> \
               --prompt "<your prompt>"
```
## Results

The training and validation loss is reported at a few points over the course of training.
The model trains at around 475 tokens per second on an M2 Ultra.
## Fuse and Upload
You can generate a fused model with the low-rank adapters included using the
`fuse.py` script. This script also optionally allows you to upload the fused
model to the [Hugging Face MLX
Community](https://huggingface.co/mlx-community).
To generate the fused model run:
```
python fuse.py
```
This will by default load the base model from `mlx_model/`, the adapters from
`adapters.npz`, and save the fused model in the path `lora_fused_model/`. All
of these are configurable. You can see the list of options with:
```
python fuse.py --help
```
To upload a fused model, supply the `--upload-name` and `--hf-path` arguments
to `fuse.py`. The latter is the repo name of the original model, which is
useful for the sake of attribution and model versioning.
For example, to fuse and upload a model derived from Mistral-7B-v0.1, run:
```
python fuse.py --upload-name My-4-bit-model --hf-path mistralai/Mistral-7B-v0.1
```
## Custom Data
You can make your own dataset for fine-tuning with LoRA. You can specify the
dataset directory with `--data`.
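As a rough sketch of a custom dataset, assuming the directory layout and `"text"` field follow the format of the bundled `data/` example (check that directory for the authoritative format):

```
# Hypothetical dataset directory for --data: one JSON object per line,
# each with a "text" field containing a full training example.
mkdir -p my_data
cat > my_data/train.jsonl << 'EOF'
{"text": "table: ...\nQ: ...\nA: SELECT ..."}
EOF
# valid.jsonl and test.jsonl use the same format.
python lora.py --model <path_to_model> --train --data my_data
```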
## Memory Issues

Fine-tuning large models can use a lot of memory. Quantizing the model (QLoRA),
reducing the batch size, and training fewer LoRA layers all help reduce memory
use. For example, for a machine with 32 GB the following should run reasonably fast:
```
python lora.py \
--model mistralai/Mistral-7B-v0.1 \
--train \
--batch-size 1 \
--lora-layers 4
```

The above command on an M1 Max with 32 GB runs at about 250 tokens-per-second.
[^lora]: Refer to the [arXiv paper](https://arxiv.org/abs/2106.09685) for more details on LoRA.
[^qlora]: Refer to the paper [QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314).
[^wikisql]: Refer to the [GitHub repo](https://github.com/salesforce/WikiSQL/tree/master) for more information about WikiSQL.