# Fine-Tuning with LoRA or QLoRA

This is an example of using MLX to fine-tune an LLM with low rank adaptation
(LoRA) for a target task.[^lora] The example also supports quantized LoRA
(QLoRA).[^qlora] The example works with Llama and Mistral style models
available on Hugging Face.

> [!TIP]
> For a more fully featured LLM package, check out [MLX
> LM](https://github.com/ml-explore/mlx-examples/tree/main/llms/mlx_lm).

In this example we'll use the WikiSQL[^wikisql] dataset to train the LLM to
generate SQL queries from natural language. However, the example is intended to
be general should you wish to use a custom dataset.

## Contents

* [Setup](#Setup)
  * [Convert](#convert)
* [Run](#Run)
  * [Fine-tune](#Fine-tune)
  * [Evaluate](#Evaluate)
  * [Generate](#Generate)
* [Results](#Results)
* [Fuse and Upload](#Fuse-and-Upload)
* [Custom Data](#Custom-Data)
* [Memory Issues](#Memory-Issues)

## Setup

Install the dependencies:

```
pip install -r requirements.txt
```

### Convert

This step is only needed if you want to quantize a model (for QLoRA) or change
the default data type of a pre-existing model.

You convert models using the `convert.py` script. This script takes a Hugging
Face repo as input and outputs a model directory (which you can optionally also
upload to Hugging Face).

To make a 4-bit quantized model, run:

```
python convert.py --hf-path <hf_repo> -q
```

For example, the following will make a 4-bit quantized Mistral 7B and by default
store it in `mlx_model`:

```
python convert.py --hf-path mistralai/Mistral-7B-v0.1 -q
```

For more options run:

```
python convert.py --help
```

You can upload new models to the [Hugging Face MLX
Community](https://huggingface.co/mlx-community) by specifying `--upload-name`
to `convert.py`.
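
For example, something like the following should make a 4-bit model and upload
it in one step (`My-4-bit-Mistral-7B` is just a placeholder repo name):

```
python convert.py --hf-path mistralai/Mistral-7B-v0.1 -q --upload-name My-4-bit-Mistral-7B
```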

## Run

The main script is `lora.py`. To see a full list of options run:

```
python lora.py --help
```

Note that in the following, the `--model` argument can be any compatible
Hugging Face repo or a local path to a converted model.

### Fine-tune

To fine-tune a model use:

```
python lora.py --model <path_to_model> \
               --train \
               --iters 600
```

If `--model` points to a quantized model, then the training will use QLoRA;
otherwise it will use regular LoRA.
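
For example, a QLoRA run end to end might look like the following: first
quantize the model (stored in `mlx_model` by default), then train against that
path:

```
python convert.py --hf-path mistralai/Mistral-7B-v0.1 -q
python lora.py --model mlx_model --train --iters 600
```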

By default, the adapter weights are saved in `adapters.npz`. You can specify
the output location with `--adapter-file`.

You can resume fine-tuning with an existing adapter with `--resume-adapter-file
<path_to_adapters.npz>`.
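
For example, to pick up training from a previously saved adapter file, a run
might look like:

```
python lora.py --model <path_to_model> \
               --train \
               --resume-adapter-file adapters.npz \
               --iters 600
```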

### Evaluate

To compute test set perplexity, use:

```
python lora.py --model <path_to_model> \
               --adapter-file <path_to_adapters.npz> \
               --test
```

### Generate

For generation, use:

```
python lora.py --model <path_to_model> \
               --adapter-file <path_to_adapters.npz> \
               --max-tokens 50 \
               --prompt "table: 1-10015132-16
columns: Player, No., Nationality, Position, Years in Toronto, School/Club Team
Q: What is terrence ross' nationality
A: "
```

## Results

The initial validation loss for Llama 7B on the WikiSQL dataset is 2.66 and the
final validation loss after 1000 iterations is 1.23. The table below shows the
training and validation loss at a few points over the course of training.

| Iteration | Train Loss | Validation Loss |
| --------- | ---------- | --------------- |
| 1         |    N/A     |      2.659      |
| 200       |    1.264   |      1.405      |
| 400       |    1.201   |      1.303      |
| 600       |    1.123   |      1.274      |
| 800       |    1.017   |      1.255      |
| 1000      |    1.070   |      1.230      |

The model trains at around 475 tokens per second on an M2 Ultra.

## Fuse and Upload

You can generate a fused model with the low-rank adapters included using the
`fuse.py` script. This script also optionally allows you to upload the fused
model to the [Hugging Face MLX
Community](https://huggingface.co/mlx-community).

To generate the fused model run:

```
python fuse.py
```

This will by default load the base model from `mlx_model/`, the adapters from
`adapters.npz`, and save the fused model in the path `lora_fused_model/`. All
of these are configurable. You can see the list of options with:

```
python fuse.py --help
```
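
As a sketch, assuming `fuse.py` exposes `--model`, `--adapter-file`, and
`--save-path` options for these paths (check `python fuse.py --help` for the
exact flag names), an explicit invocation might look like:

```
python fuse.py --model mlx_model \
               --adapter-file adapters.npz \
               --save-path lora_fused_model
```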

To upload a fused model, supply the `--upload-name` and `--hf-path` arguments
to `fuse.py`. The latter is the repo name of the original model, which is
useful for attribution and model versioning.

For example, to fuse and upload a model derived from Mistral-7B-v0.1, run:

```
python fuse.py --upload-name My-4-bit-model --hf-path mistralai/Mistral-7B-v0.1
```

## Custom Data

You can make your own dataset for fine-tuning with LoRA. You can specify the
dataset with `--data=<my_data_directory>`. Check the subdirectory `data/` to
see the expected format.

For fine-tuning (`--train`), the data loader expects a `train.jsonl` and a
`valid.jsonl` to be in the data directory. For evaluation (`--test`), the data
loader expects a `test.jsonl` in the data directory. Each line in the `*.jsonl`
file should look like:

```
{"text": "This is an example for the model."}
```

Note that other keys will be ignored by the loader.
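
As a minimal sketch (using a hypothetical `my_data` directory), you could put a
tiny dataset together and point the trainer at it like this:

```
mkdir -p my_data

# Each line is a standalone JSON object with a "text" key; other keys are ignored.
cat > my_data/train.jsonl << 'EOF'
{"text": "This is an example for the model."}
{"text": "This is another example for the model."}
EOF

# In practice valid.jsonl should hold genuinely held-out examples; the copy here
# is only to make the sketch runnable.
cp my_data/train.jsonl my_data/valid.jsonl

python lora.py --model <path_to_model> --train --data=my_data
```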

## Memory Issues

Fine-tuning a large model with LoRA requires a machine with a decent amount
of memory. Here are some tips to reduce memory use should you need to do so:

1. Try quantization (QLoRA). You can use QLoRA by generating a quantized model
   with `convert.py` and the `-q` flag. See the [Setup](#setup) section for
   more details.

2. Try using a smaller batch size with `--batch-size`. The default is `4`, so
   setting this to `2` or `1` will reduce memory consumption at the cost of
   slightly slower training.

3. Reduce the number of layers to fine-tune with `--lora-layers`. The default
   is `16`, so you can try `8` or `4`. This reduces the amount of memory
   needed for backpropagation. It may also reduce the quality of the
   fine-tuned model if you are fine-tuning with a lot of data.

4. Longer examples require more memory. If it makes sense for your data, break
   your examples into smaller sequences when making the
   `{train, valid, test}.jsonl` files.

For example, for a machine with 32 GB, the following should run reasonably fast:

```
python lora.py \
   --model mistralai/Mistral-7B-v0.1 \
   --train \
   --batch-size 1 \
   --lora-layers 4
```

The above command on an M1 Max with 32 GB runs at about 250 tokens per second.

[^lora]: Refer to the [arXiv paper](https://arxiv.org/abs/2106.09685) for more details on LoRA.
[^qlora]: Refer to the paper [QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314).
[^wikisql]: Refer to the [GitHub repo](https://github.com/salesforce/WikiSQL/tree/master) for more information about WikiSQL.