diff --git a/README.md b/README.md index 0c56324b..5d4a6ff7 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ Some more useful examples are listed below. [Mistral](llms/mistral), [Phi-2](llms/phi2), and more in the [LLMs](llms) directory. - A mixture-of-experts (MoE) language model with [Mixtral 8x7B](llms/mixtral). -- Parameter efficient fine-tuning with [LoRA](lora). +- Parameter efficient fine-tuning with [LoRA or QLoRA](lora). - Text-to-text multi-task Transformers with [T5](t5). - Bidirectional language understanding with [BERT](bert). diff --git a/llms/deepseek-coder/README.md b/llms/deepseek-coder/README.md index cd309863..086b1960 100644 --- a/llms/deepseek-coder/README.md +++ b/llms/deepseek-coder/README.md @@ -32,6 +32,10 @@ page](https://huggingface.co/deepseek-ai) to see a list of available models. By default, the conversion script will save the converted `weights.npz`, tokenizer, and `config.json` in the `mlx_model` directory. +> [!TIP] Alternatively, you can also download a few converted checkpoints from +> the [MLX Community](https://huggingface.co/mlx-community) organization on +> Hugging Face and skip the conversion step. + ### Run Once you've converted the weights, you can interact with the Deepseek coder diff --git a/llms/llama/convert.py b/llms/llama/convert.py index c5e6e773..d8f2c8e6 100644 --- a/llms/llama/convert.py +++ b/llms/llama/convert.py @@ -14,11 +14,13 @@ import torch from llama import Llama, ModelArgs, sanitize_config from mlx.utils import tree_flatten, tree_map, tree_unflatten + def torch_to_mx(a: torch.Tensor, *, dtype: str) -> mx.array: # bfloat16 is not numpy convertible. Upcast to float32 to avoid precision loss - a = a.to(torch.float32) if dtype == 'bfloat16' else a.to(getattr(torch, dtype)) + a = a.to(torch.float32) if dtype == "bfloat16" else a.to(getattr(torch, dtype)) return mx.array(a.numpy(), getattr(mx, dtype)) + def llama(model_path, *, dtype: str): SHARD_FIRST = ["wv", "wq", "wk", "w1", "w3", "output"] SHARD_SECOND = ["tok_embeddings", "wo", "w2"] @@ -48,7 +50,7 @@ def llama(model_path, *, dtype: str): state = torch.load(wf, map_location=torch.device("cpu")) for k, v in state.items(): v = torch_to_mx(v, dtype=dtype) - state[k] = None # free memory + state[k] = None # free memory if shard_key(k) in SHARD_WEIGHTS: weights[k].append(v) else: @@ -204,7 +206,7 @@ if __name__ == "__main__": parser.add_argument( "--dtype", help="dtype for loading the torch model and input for quantization or saving the converted model. " - "The original weights are stored in bfloat16.", + "The original weights are stored in bfloat16.", type=str, default="float16", ) diff --git a/lora/README.md b/lora/README.md index 87a65ee0..c2980086 100644 --- a/lora/README.md +++ b/lora/README.md @@ -1,8 +1,8 @@ -# LoRA +# Fine-Tuning with LoRA or QLoRA This is an example of using MLX to fine-tune either a Llama 7B[^llama] or a Mistral 7B[^mistral] model with low rank adaptation (LoRA)[^lora] for a target -task. +task. The example also supports quantized LoRA (QLoRA).[^qlora] In this example we'll use the WikiSQL[^wikisql] dataset to train the LLM to generate SQL queries from natural language. However, the example is intended to @@ -43,10 +43,13 @@ Convert the model with: ``` python convert.py \ - --torch-model \ - --mlx-model + --torch-path \ + --mlx-path ``` +If you wish to use QLoRA, then convert the model with 4-bit quantization using +the `-q` option. + ## Run The main script is `lora.py`. 
To see a full list of options run @@ -65,8 +68,11 @@ python lora.py --model \ --iters 600 ``` +If `--model` points to a quantized model, then the training will use QLoRA, +otherwise it will use regular LoRA. + Note, the model path should have the MLX weights, the tokenizer, and the -`params.json` configuration which will all be output by the `convert.py` script. +`config.json` which will all be output by the `convert.py` script. By default, the adapter weights are saved in `adapters.npz`. You can specify the output location with `--adapter-file`. @@ -137,16 +143,20 @@ Note other keys will be ignored by the loader. Fine-tuning a large model with LoRA requires a machine with a decent amount of memory. Here are some tips to reduce memory use should you need to do so: -1. Try using a smaller batch size with `--batch-size`. The default is `4` so +1. Try quantization (QLoRA). You can use QLoRA by generating a quantized model + with `convert.py` and the `-q` flag. See the [Setup](#setup) section for + more details. + +2. Try using a smaller batch size with `--batch-size`. The default is `4` so setting this to `2` or `1` will reduce memory consumption. This may slow things down a little, but will also reduce the memory use. -2. Reduce the number of layers to fine-tune with `--lora-layers`. The default +3. Reduce the number of layers to fine-tune with `--lora-layers`. The default is `16`, so you can try `8` or `4`. This reduces the amount of memory needed for back propagation. It may also reduce the quality of the fine-tuned model if you are fine-tuning with a lot of data. -3. Longer examples require more memory. If it makes sense for your data, one thing +4. Longer examples require more memory. If it makes sense for your data, one thing you can do is break your examples into smaller sequences when making the `{train, valid, test}.jsonl` files. @@ -164,6 +174,7 @@ The above command on an M1 Max with 32 GB runs at about 250 tokens-per-second. [^lora]: Refer to the [arXiv paper](https://arxiv.org/abs/2106.09685) for more details on LoRA. +[^qlora]: Refer to the paper [QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314) [^llama]: Refer to the [arXiv paper](https://arxiv.org/abs/2302.13971) and [blog post](https://ai.meta.com/blog/large-language-model-llama-meta-ai/) for more details. [^mistral]: Refer to the [blog post](https://mistral.ai/news/announcing-mistral-7b/) and [github repository](https://github.com/mistralai/mistral-src) for more details. [^wikisql]: Refer to the [GitHub repo](https://github.com/salesforce/WikiSQL/tree/master) for more information about WikiSQL. diff --git a/lora/convert.py b/lora/convert.py index 6f71d772..4a353903 100644 --- a/lora/convert.py +++ b/lora/convert.py @@ -1,69 +1,125 @@ # Copyright © 2023 Apple Inc. 
 import argparse
+import copy
 import json
-import os
 import shutil
 from pathlib import Path
 
+import mlx.core as mx
+import mlx.nn as nn
 import numpy as np
 import torch
+from mlx.utils import tree_flatten, tree_map, tree_unflatten
+
+from lora import Model, ModelArgs
+
+
+def quantize(weights, config, args):
+    quantized_config = copy.deepcopy(config)
+
+    # Load the model:
+    model = Model(ModelArgs(**config))
+    weights = tree_map(mx.array, weights)
+    model.update(tree_unflatten(list(weights.items())))
+
+    # Quantize the model:
+    nn.QuantizedLinear.quantize_module(
+        model,
+        args.q_group_size,
+        args.q_bits,
+        linear_class_predicate=lambda m: isinstance(m, nn.Linear)
+        and m.weight.shape[0] != config["vocab_size"],
+    )
+
+    # Update the config:
+    quantized_config["quantization"] = {
+        "group_size": args.q_group_size,
+        "bits": args.q_bits,
+    }
+    quantized_weights = dict(tree_flatten(model.parameters()))
+
+    return quantized_weights, quantized_config
+
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         description="Convert Mistral or Llama models to MLX.",
     )
     parser.add_argument(
-        "--torch-model",
+        "--torch-path",
         type=str,
         default="mistral-7B-v0.1/",
-        help="The torch model directory",
+        help="Path to the torch model directory",
     )
     parser.add_argument(
-        "--mlx-model",
+        "--mlx-path",
        type=str,
-        default="mlx-mistral-7B-v0.1/",
+        default="mlx_model/",
         help="The directory to store the mlx model",
     )
+    parser.add_argument(
+        "-q",
+        "--quantize",
+        help="Generate a quantized model.",
+        action="store_true",
+    )
+    parser.add_argument(
+        "--q-group-size",
+        help="Group size for quantization.",
+        type=int,
+        default=64,
+    )
+    parser.add_argument(
+        "--q-bits",
+        help="Bits per weight for quantization.",
+        type=int,
+        default=4,
+    )
     args = parser.parse_args()
 
-    torch_path = Path(args.torch_model)
-    if not os.path.exists(args.mlx_model):
-        os.makedirs(args.mlx_model)
-    mlx_path = Path(args.mlx_model)
+    torch_path = Path(args.torch_path)
+    mlx_path = Path(args.mlx_path)
+    mlx_path.mkdir(parents=True, exist_ok=True)
 
     # Copy the tokenizer
     tokenizer_path = torch_path / "tokenizer.model"
     if not tokenizer_path.exists():
-        print(f"Make sure there is a file tokenizer.model in {args.torch_model}")
+        print(f"Make sure there is a file tokenizer.model in {args.torch_path}")
         exit(0)
     shutil.copyfile(
         str(tokenizer_path),
         str(mlx_path / "tokenizer.model"),
     )
 
-    # Copy the model weights
-    state = torch.load(str(torch_path / "consolidated.00.pth"))
-    np.savez(
-        str(mlx_path / "weights.npz"),
-        **{k: v.to(torch.float16).numpy() for k, v in state.items()},
-    )
+    # Load the torch model weights to numpy:
+    weights = torch.load(str(torch_path / "consolidated.00.pth"))
+    for k, v in weights.items():
+        weights[k] = v.to(torch.float16).numpy()
 
-    # Copy the params
+    # Standardize the params
     with open(torch_path / "params.json", "r") as f:
         config = json.loads(f.read())
-    unused = ["multiple_of"]
+    unused = ["multiple_of", "sliding_window"]
     for k in unused:
-        if k in config:
-            config.pop(k)
+        config.pop(k, None)
     n_heads = config["n_heads"]
-    if "sliding_window" in config:
-        config.pop("sliding_window")
     if "n_kv_heads" not in config:
         config["n_kv_heads"] = n_heads
     if "head_dim" not in config:
         config["head_dim"] = config["dim"] // n_heads
     if "hidden_dim" not in config:
-        config["hidden_dim"] = state["layers.0.feed_forward.w1.weight"].shape[0]
-    with open(mlx_path / "params.json", "w") as outfile:
+        config["hidden_dim"] = weights["layers.0.feed_forward.w1.weight"].shape[0]
+    if 
config.get("vocab_size", -1) < 0: + config["vocab_size"] = weights["output.weight"].shape[0] + + if args.quantize: + print("[INFO] Quantizing") + weights, config = quantize(weights, config, args) + + np.savez(str(mlx_path / "weights.npz"), **weights) + + with open(mlx_path / "config.json", "w") as outfile: json.dump(config, outfile, indent=4) diff --git a/lora/lora.py b/lora/lora.py index 528cf506..8a35a6d4 100644 --- a/lora/lora.py +++ b/lora/lora.py @@ -17,12 +17,10 @@ from sentencepiece import SentencePieceProcessor def build_parser(): - parser = argparse.ArgumentParser( - description="LoRA finetuning with Llama or Mistral" - ) + parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.") parser.add_argument( "--model", - required=True, + default="mlx_model", help="A path to the model files containing the tokenizer, weights, config.", ) # Generation args @@ -332,18 +330,22 @@ def generate(model, prompt, tokenizer, args): print(s, flush=True) -def load_model(folder: str, dtype=mx.float16): +def load_model(folder: str): model_path = Path(folder) tokenizer = Tokenizer(str(model_path / "tokenizer.model")) - with open(model_path / "params.json", "r") as f: + with open(model_path / "config.json", "r") as f: config = json.loads(f.read()) - if config.get("vocab_size", -1) < 0: - config["vocab_size"] = tokenizer.vocab_size + quantization = config.pop("quantization", None) model_args = ModelArgs(**config) + model = Model(model_args) + if quantization is not None: + quantization["linear_class_predicate"] = lambda m: isinstance( + m, nn.Linear + ) and (m.weight.shape[0] != model_args.vocab_size) + nn.QuantizedLinear.quantize_module(model, **quantization) + weights = mx.load(str(model_path / "weights.npz")) weights = tree_unflatten(list(weights.items())) - weights = tree_map(lambda p: p.astype(dtype), weights) - model = Model(model_args) model.update(weights) return model, tokenizer @@ -374,7 +376,7 @@ if __name__ == "__main__": # Resume training the given adapters. if args.resume_adapter_file is not None: print(f"Loading pretrained adapters from {args.resume_adapter_file}") - model.load_weights(args.resume_adapter_file) + model.load_weights(args.resume_adapter_file, strict=False) if args.train: print("Training") @@ -387,7 +389,12 @@ if __name__ == "__main__": mx.savez(args.adapter_file, **dict(tree_flatten(model.trainable_parameters()))) # Load the LoRA adapter weights which we assume should exist by this point - model.load_weights(args.adapter_file) + if not Path(args.adapter_file).is_file(): + raise ValueError( + f"Adapter file {args.adapter_file} missing. " + "Use --train to learn and save the adapters.npz." + ) + model.load_weights(args.adapter_file, strict=False) if args.test: print("Testing") diff --git a/lora/models.py b/lora/models.py index 31eef654..3208de35 100644 --- a/lora/models.py +++ b/lora/models.py @@ -1,5 +1,4 @@ # Copyright © 2023 Apple Inc. 
- import math from dataclasses import dataclass from typing import List, Optional, Tuple @@ -24,7 +23,11 @@ class ModelArgs: class LoRALinear(nn.Module): @staticmethod def from_linear(linear: nn.Linear, rank: int = 8): + # TODO remove when input_dims and output_dims are attributes + # on linear and quantized linear output_dims, input_dims = linear.weight.shape + if isinstance(linear, nn.QuantizedLinear): + input_dims *= 32 // linear.bits lora_lin = LoRALinear(input_dims, output_dims, rank) lora_lin.linear = linear return lora_lin @@ -47,7 +50,10 @@ class LoRALinear(nn.Module): self.lora_b = mx.zeros(shape=(lora_rank, output_dims)) def __call__(self, x): - y = self.linear(x.astype(self.linear.weight.dtype)) + dtype = self.linear.weight.dtype + if isinstance(self.linear, nn.QuantizedLinear): + dtype = self.linear.scales.dtype + y = self.linear(x.astype(dtype)) z = (x @ self.lora_a) @ self.lora_b return y + 2.0 * z diff --git a/lora/requirements.txt b/lora/requirements.txt index 7111f1d4..70a195e9 100644 --- a/lora/requirements.txt +++ b/lora/requirements.txt @@ -1,3 +1,3 @@ -mlx +mlx>=0.0.7 sentencepiece torch
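Below is a small, self-contained sketch (not part of the diff above) of how the QLoRA pieces fit together at runtime: quantize a module's linear layers in place, then wrap them with the `LoRALinear` adapter added in `lora/models.py`. The toy module, its sizes, and the assumption that it runs from the `lora/` directory (so `models` resolves to `lora/models.py`) are for illustration only; `nn.QuantizedLinear.quantize_module`, `LoRALinear.from_linear`, `freeze`, and `trainable_parameters` are the same calls used in `lora/convert.py` and `lora/lora.py`.

```python
# Hypothetical toy example, not part of this diff. Assumes mlx>=0.0.7 (per
# lora/requirements.txt) and that it is run from the lora/ directory.
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten

from models import LoRALinear  # the adapter wrapper added in lora/models.py


class ToyBlock(nn.Module):
    """Stand-in for a single projection layer of the transformer."""

    def __init__(self, dims: int = 256):
        super().__init__()
        self.proj = nn.Linear(dims, dims, bias=False)

    def __call__(self, x):
        return self.proj(x)


model = ToyBlock()

# 4-bit quantize the base weights in place, as `python convert.py -q` does
# for the full model:
nn.QuantizedLinear.quantize_module(model, group_size=64, bits=4)

# Freeze the (now quantized) base weights, then swap in the LoRA wrapper so
# only the low-rank matrices are trainable, mirroring the order in lora.py:
model.freeze()
model.proj = LoRALinear.from_linear(model.proj)

# Only lora_a and lora_b remain trainable: 2 * 256 * 8 = 4096 values at rank 8.
n_trainable = sum(v.size for _, v in tree_flatten(model.trainable_parameters()))
print(f"Trainable parameters: {n_trainable}")

# A forward pass combines the frozen quantized projection with the LoRA update.
y = model(mx.random.normal(shape=(1, 256)))
```

Note that `convert.py` above deliberately excludes the vocabulary-sized output projection from quantization (the `m.weight.shape[0] != config["vocab_size"]` predicate), keeping that layer in full precision; the toy block here has no such layer, so the default predicate suffices.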