# Fine-Tuning with LoRA or QLoRA

This is an example of using MLX to fine-tune an LLM with low-rank adaptation
(LoRA) for a target task.[^lora] The example also supports quantized LoRA
(QLoRA).[^qlora] The example works with Llama and Mistral style models
available on Hugging Face.

> [!TIP]
> For a more fully featured LLM package, check out [MLX
> LM](https://github.com/ml-explore/mlx-examples/tree/main/llms/mlx_lm).

In this example we'll use the WikiSQL[^wikisql] dataset to train the LLM to
generate SQL queries from natural language. However, the example is intended to
be general should you wish to use a custom dataset.

## Contents

* [Setup](#setup)
  * [Convert](#convert)
* [Run](#run)
  * [Fine-tune](#fine-tune)
  * [Evaluate](#evaluate)
  * [Generate](#generate)
* [Results](#results)
* [Fuse and Upload](#fuse-and-upload)
* [Custom Data](#custom-data)
* [Memory Issues](#memory-issues)

## Setup

Install the dependencies:

```
pip install -r requirements.txt
```

### Convert

This step is only needed if you want to quantize a model (for QLoRA) or change
its default data type; otherwise you can pass a Hugging Face repo directly to
`lora.py`.

You convert models using the `convert.py` script. This script takes a Hugging
Face repo as input and outputs a model directory (which you can optionally also
upload to Hugging Face).

To make a 4-bit quantized model, run:

```
python convert.py --hf-path <hf_repo> -q
```

For example, the following will make a 4-bit quantized Mistral 7B and by
default store it in `mlx_model`:

```
python convert.py --hf-path mistralai/Mistral-7B-v0.1 -q
```

For more options run:

```
python convert.py --help
```

You can upload new models to the [Hugging Face MLX
Community](https://huggingface.co/mlx-community) by specifying `--upload-name`
to `convert.py`.

## Run

The main script is `lora.py`. To see a full list of options run:

```
python lora.py --help
```

Note that in the following commands, the `--model` argument can be any
compatible Hugging Face repo or a local path to a converted model.

### Fine-tune

To fine-tune a model use:

```
python lora.py --model <path_to_model> \
               --train \
               --iters 600
```

If `--model` points to a quantized model, then the training will use QLoRA;
otherwise it will use regular LoRA.

By default, the adapter weights are saved in `adapters.npz`. You can specify
the output location with `--adapter-file`.

You can resume fine-tuning with an existing adapter with
`--resume-adapter-file <path_to_adapters.npz>`.

### Evaluate

To compute test set perplexity use:

```
python lora.py --model <path_to_model> \
               --adapter-file <path_to_adapters.npz> \
               --test
```
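
The reported perplexity is the exponential of the average per-token
cross-entropy loss on the test set, so lower is better.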
					
						

### Generate

For generation use:

```
python lora.py --model <path_to_model> \
               --adapter-file <path_to_adapters.npz> \
               --max-tokens 50 \
               --prompt "table: 1-10015132-16
columns: Player, No., Nationality, Position, Years in Toronto, School/Club Team
Q: What is terrence ross' nationality
A: "
```

## Results

The initial validation loss for Llama 7B on the WikiSQL dataset is 2.66 and the
final validation loss after 1000 iterations is 1.23. The table below shows the
training and validation loss at a few points over the course of training.

| Iteration | Train Loss | Validation Loss |
| --------- | ---------- | --------------- |
| 1         |    N/A     |      2.659      |
| 200       |    1.264   |      1.405      |
| 400       |    1.201   |      1.303      |
| 600       |    1.123   |      1.274      |
| 800       |    1.017   |      1.255      |
| 1000      |    1.070   |      1.230      |

The model trains at around 475 tokens per second on an M2 Ultra.

## Fuse and Upload

You can generate a fused model with the low-rank adapters included using the
`fuse.py` script. This script also optionally allows you to upload the fused
model to the [Hugging Face MLX
Community](https://huggingface.co/mlx-community).

To generate the fused model run:

```
python fuse.py
```

This will by default load the base model from `mlx_model/`, the adapters from
`adapters.npz`, and save the fused model in the path `lora_fused_model/`. All
of these are configurable. You can see the list of options with:

```
python fuse.py --help
```

To upload a fused model, supply the `--upload-name` and `--hf-path` arguments
to `fuse.py`. The latter is the repo name of the original model, which is
useful for attribution and model versioning.

For example, to fuse and upload a model derived from Mistral-7B-v0.1, run:

```
python fuse.py --upload-name My-4-bit-model --hf-path mistralai/Mistral-7B-v0.1
```

## Custom Data

You can make your own dataset for fine-tuning with LoRA. Specify the dataset
with `--data=<my_data_directory>`. Check the subdirectory `data/` to see the
expected format.

For fine-tuning (`--train`), the data loader expects a `train.jsonl` and a
`valid.jsonl` to be in the data directory. For evaluation (`--test`), the data
loader expects a `test.jsonl` in the data directory. Each line in the `*.jsonl`
files should look like:

```
{"text": "This is an example for the model."}
```

Note that any other keys will be ignored by the loader.
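
As a rough sketch, assuming a WikiSQL-style prompt format like the one shown in
the [Generate](#generate) section (the file paths, template, and example content
below are purely illustrative), you could build these files with a few lines of
Python:

```
# prepare_data.py -- illustrative sketch only; adapt the template and paths
# to your own dataset.
import json

# Any plain text you want the model to learn to continue will work.
examples = [
    "table: 1-10015132-16\n"
    "columns: Player, No., Nationality, Position, Years in Toronto, School/Club Team\n"
    "Q: What is terrence ross' nationality\n"
    "A: SELECT Nationality FROM 1-10015132-16 WHERE Player = 'Terrence Ross'",
]

with open("data/train.jsonl", "w") as f:
    for text in examples:
        # The loader only reads the "text" key; other keys are ignored.
        f.write(json.dumps({"text": text}) + "\n")

# Write valid.jsonl (needed with --train) and test.jsonl (needed with --test)
# the same way.
```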
					
						

## Memory Issues

Fine-tuning a large model with LoRA requires a machine with a decent amount of
memory. Here are some tips to reduce memory use should you need to do so:

1. Try quantization (QLoRA). You can use QLoRA by generating a quantized model
   with `convert.py` and the `-q` flag. See the [Setup](#setup) section for
   more details.

2. Try using a smaller batch size with `--batch-size`. The default is `4`, so
   setting this to `2` or `1` will reduce memory consumption at the cost of
   slightly slower training.

3. Reduce the number of layers to fine-tune with `--lora-layers`. The default
   is `16`, so you can try `8` or `4`. This reduces the amount of memory
   needed for back propagation. It may also reduce the quality of the
   fine-tuned model if you are fine-tuning with a lot of data.

4. Longer examples require more memory. If it makes sense for your data, one
   thing you can do is break your examples into smaller sequences when making
   the `{train, valid, test}.jsonl` files, as in the sketch below.
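
A minimal sketch of that kind of split (the word-based length cap is an
arbitrary stand-in; splitting on token counts with the model's tokenizer would
be more precise):

```
# chunk_examples.py -- rough sketch only: split long texts into shorter pieces
# before writing them out in the same {"text": ...} format shown above.
import json

MAX_WORDS = 256  # arbitrary cap for illustration

def chunks(text, max_words=MAX_WORDS):
    words = text.split()
    for i in range(0, len(words), max_words):
        yield " ".join(words[i : i + max_words])

long_documents = ["..."]  # your long training texts

with open("data/train.jsonl", "w") as f:
    for doc in long_documents:
        for piece in chunks(doc):
            f.write(json.dumps({"text": piece}) + "\n")
```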
					
						

For example, for a machine with 32 GB the following should run reasonably fast:

```
python lora.py \
   --model mistralai/Mistral-7B-v0.1 \
   --train \
   --batch-size 1 \
   --lora-layers 4
```

The above command on an M1 Max with 32 GB runs at about 250 tokens per second.

[^lora]: Refer to the [arXiv paper](https://arxiv.org/abs/2106.09685) for more details on LoRA.
[^qlora]: Refer to the paper [QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314).
[^wikisql]: Refer to the [GitHub repo](https://github.com/salesforce/WikiSQL/tree/master) for more information about WikiSQL.