qlora
Awni Hannun 2024-01-04 21:05:59 -08:00 committed by GitHub
parent 4fa659acbd
commit 37b41cec60
8 changed files with 137 additions and 51 deletions


@ -14,7 +14,7 @@ Some more useful examples are listed below.
[Mistral](llms/mistral), [Phi-2](llms/phi2), and more in the [LLMs](llms)
directory.
- A mixture-of-experts (MoE) language model with [Mixtral 8x7B](llms/mixtral).
- Parameter efficient fine-tuning with [LoRA](lora).
- Parameter efficient fine-tuning with [LoRA or QLoRA](lora).
- Text-to-text multi-task Transformers with [T5](t5).
- Bidirectional language understanding with [BERT](bert).


@ -32,6 +32,10 @@ page](https://huggingface.co/deepseek-ai) to see a list of available models.
By default, the conversion script will save the converted `weights.npz`,
tokenizer, and `config.json` in the `mlx_model` directory.
> [!TIP] Alternatively, you can download a few converted checkpoints from
> the [MLX Community](https://huggingface.co/mlx-community) organization on
> Hugging Face and skip the conversion step.
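For example, a minimal sketch using the `huggingface_hub` package; the repository id below is a placeholder, so pick a real converted Deepseek Coder checkpoint from the MLX Community page:

```
from huggingface_hub import snapshot_download

# Download a pre-converted checkpoint into the directory the example expects.
# The repo id is a placeholder; choose a real one from the MLX Community page.
path = snapshot_download(
    repo_id="mlx-community/<converted-deepseek-coder-model>",
    local_dir="mlx_model",
)
print("Checkpoint files in:", path)
```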
### Run
Once you've converted the weights, you can interact with the Deepseek coder


@ -14,11 +14,13 @@ import torch
from llama import Llama, ModelArgs, sanitize_config
from mlx.utils import tree_flatten, tree_map, tree_unflatten
def torch_to_mx(a: torch.Tensor, *, dtype: str) -> mx.array:
# bfloat16 is not numpy convertible. Upcast to float32 to avoid precision loss
a = a.to(torch.float32) if dtype == 'bfloat16' else a.to(getattr(torch, dtype))
a = a.to(torch.float32) if dtype == "bfloat16" else a.to(getattr(torch, dtype))
return mx.array(a.numpy(), getattr(mx, dtype))
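As a quick illustration of the helper above (a standalone sketch, not part of the conversion script), converting a bfloat16 tensor goes through the float32 upcast before landing back in an MLX bfloat16 array:

```
import mlx.core as mx
import torch

# bfloat16 tensors cannot be handed to .numpy() directly, so torch_to_mx
# upcasts to float32 and lets mx.array apply the requested dtype.
w = torch.randn(4, 4, dtype=torch.bfloat16)
w_mx = torch_to_mx(w, dtype="bfloat16")  # torch_to_mx as defined above
print(w_mx.dtype)  # expected: mlx.core.bfloat16
```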
def llama(model_path, *, dtype: str):
SHARD_FIRST = ["wv", "wq", "wk", "w1", "w3", "output"]
SHARD_SECOND = ["tok_embeddings", "wo", "w2"]
@ -48,7 +50,7 @@ def llama(model_path, *, dtype: str):
state = torch.load(wf, map_location=torch.device("cpu"))
for k, v in state.items():
v = torch_to_mx(v, dtype=dtype)
state[k] = None # free memory
state[k] = None # free memory
if shard_key(k) in SHARD_WEIGHTS:
weights[k].append(v)
else:
@ -204,7 +206,7 @@ if __name__ == "__main__":
parser.add_argument(
"--dtype",
help="dtype for loading the torch model and input for quantization or saving the converted model. "
"The original weights are stored in bfloat16.",
"The original weights are stored in bfloat16.",
type=str,
default="float16",
)


@ -1,8 +1,8 @@
# LoRA
# Fine-Tuning with LoRA or QLoRA
This is an example of using MLX to fine-tune either a Llama 7B[^llama] or a
Mistral 7B[^mistral] model with low rank adaptation (LoRA)[^lora] for a target
task.
task. The example also supports quantized LoRA (QLoRA).[^qlora]
In this example we'll use the WikiSQL[^wikisql] dataset to train the LLM to
generate SQL queries from natural language. However, the example is intended to
@ -43,10 +43,13 @@ Convert the model with:
```
python convert.py \
--torch-model <path_to_torch_model> \
--mlx-model <path_to_mlx_model>
--torch-path <path_to_torch_model> \
--mlx-path <path_to_mlx_model>
```
If you wish to use QLoRA, then convert the model with 4-bit quantization using
the `-q` option.
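For example, assuming the same paths as above, a QLoRA conversion with the script's default 4-bit, group-size-64 settings would look like:

```
python convert.py \
    --torch-path <path_to_torch_model> \
    --mlx-path <path_to_mlx_model> \
    -q
```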
## Run
The main script is `lora.py`. To see a full list of options run
@ -65,8 +68,11 @@ python lora.py --model <path_to_model> \
--iters 600
```
If `--model` points to a quantized model, then the training will use QLoRA,
otherwise it will use regular LoRA.
Note, the model path should have the MLX weights, the tokenizer, and the
`params.json` configuration which will all be output by the `convert.py` script.
`config.json` which will all be output by the `convert.py` script.
By default, the adapter weights are saved in `adapters.npz`. You can specify
the output location with `--adapter-file`.
@ -137,16 +143,20 @@ Note other keys will be ignored by the loader.
Fine-tuning a large model with LoRA requires a machine with a decent amount
of memory. Here are some tips to reduce memory use should you need to do so:
1. Try using a smaller batch size with `--batch-size`. The default is `4` so
1. Try quantization (QLoRA). You can use QLoRA by generating a quantized model
with `convert.py` and the `-q` flag. See the [Setup](#setup) section for
more details.
2. Try using a smaller batch size with `--batch-size`. The default is `4` so
setting this to `2` or `1` will reduce memory consumption, though it may also
slow training down a little.
2. Reduce the number of layers to fine-tune with `--lora-layers`. The default
3. Reduce the number of layers to fine-tune with `--lora-layers`. The default
is `16`, so you can try `8` or `4`. This reduces the amount of memory
needed for back propagation. It may also reduce the quality of the
fine-tuned model if you are fine-tuning with a lot of data.
3. Longer examples require more memory. If it makes sense for your data, one thing
4. Longer examples require more memory. If it makes sense for your data, one thing
you can do is break your examples into smaller
sequences when making the `{train, valid, test}.jsonl` files, as sketched below.
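A minimal sketch of that last tip, assuming each line of the JSONL files carries the training text under a `text` key (adjust the key, chunk length, and file names to your data):

```
import json

# Split overly long examples into shorter chunks before writing train.jsonl.
# max_chars is an illustrative cap; tune it to your tokenizer and memory budget.
def write_chunked(examples, path, max_chars=512):
    with open(path, "w") as f:
        for text in examples:
            for start in range(0, len(text), max_chars):
                f.write(json.dumps({"text": text[start : start + max_chars]}) + "\n")

write_chunked(["one very long training example ..." * 40], "train.jsonl")
```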
@ -164,6 +174,7 @@ The above command on an M1 Max with 32 GB runs at about 250 tokens-per-second.
[^lora]: Refer to the [arXiv paper](https://arxiv.org/abs/2106.09685) for more details on LoRA.
[^qlora]: Refer to the paper [QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314) for more details.
[^llama]: Refer to the [arXiv paper](https://arxiv.org/abs/2302.13971) and [blog post](https://ai.meta.com/blog/large-language-model-llama-meta-ai/) for more details.
[^mistral]: Refer to the [blog post](https://mistral.ai/news/announcing-mistral-7b/) and [github repository](https://github.com/mistralai/mistral-src) for more details.
[^wikisql]: Refer to the [GitHub repo](https://github.com/salesforce/WikiSQL/tree/master) for more information about WikiSQL.


@ -1,69 +1,125 @@
# Copyright © 2023 Apple Inc.
import argparse
import copy
import json
import os
import shutil
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
import numpy as np
import torch
from mlx.utils import tree_flatten, tree_map, tree_unflatten
from lora import Model, ModelArgs
def quantize(weights, config, args):
quantized_config = copy.deepcopy(config)
# Load the model:
model = Model(ModelArgs(**config))
weights = tree_map(mx.array, weights)
model.update(tree_unflatten(list(weights.items())))
# Quantize the model:
nn.QuantizedLinear.quantize_module(
model,
args.q_group_size,
args.q_bits,
linear_class_predicate=lambda m: isinstance(m, nn.Linear)
and m.weight.shape[0] != config["vocab_size"],
)
# Update the config:
quantized_config["quantization"] = {
"group_size": args.q_group_size,
"bits": args.q_bits,
}
quantized_weights = dict(tree_flatten(model.parameters()))
return quantized_weights, quantized_config
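For reference, a condensed sketch of the same quantization step applied to a toy module; it uses the `nn.QuantizedLinear.quantize_module` call that `quantize` relies on, with the script's default group size and bit width (the toy layers and the all-Linear predicate are made up for illustration):

```
import mlx.core as mx
import mlx.nn as nn

# Toy model; in convert.py the real Model is rebuilt from config.json first.
model = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 64))

# Swap nn.Linear layers for 4-bit quantized equivalents in place.
# Unlike convert.py, this toy predicate quantizes every Linear layer
# (the script excludes the layer whose first dim equals vocab_size).
nn.QuantizedLinear.quantize_module(
    model,
    group_size=64,
    bits=4,
    linear_class_predicate=lambda m: isinstance(m, nn.Linear),
)

y = model(mx.random.normal((1, 64)))  # forward pass now uses quantized layers
print(y.shape)
```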
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Convert Mistral or Llama models to MLX.",
)
parser.add_argument(
"--torch-model",
"--torch-path",
type=str,
default="mistral-7B-v0.1/",
help="The torch model directory",
help="Path to the torch model directory",
)
parser.add_argument(
"--mlx-model",
"--mlx-path",
type=str,
default="mlx-mistral-7B-v0.1/",
default="mlx_model/",
help="The directory to store the mlx model",
)
parser.add_argument(
"-q",
"--quantize",
help="Generate a quantized model.",
action="store_true",
)
parser.add_argument(
"--q-group-size",
help="Group size for quantization.",
type=int,
default=64,
)
parser.add_argument(
"--q-bits",
help="Bits per weight for quantization.",
type=int,
default=4,
)
args = parser.parse_args()
torch_path = Path(args.torch_model)
if not os.path.exists(args.mlx_model):
os.makedirs(args.mlx_model)
mlx_path = Path(args.mlx_model)
args = parser.parse_args()
torch_path = Path(args.torch_path)
mlx_path = Path(args.mlx_path)
mlx_path.mkdir(parents=True, exist_ok=True)
# Copy the tokenizer
tokenizer_path = torch_path / "tokenizer.model"
if not tokenizer_path.exists():
print(f"Make sure there is a file tokenizer.model in {args.torch_model}")
print(f"Make sure there is a file tokenizer.model in {args.torch-path}")
exit(0)
shutil.copyfile(
str(tokenizer_path),
str(mlx_path / "tokenizer.model"),
)
# Copy the model weights
state = torch.load(str(torch_path / "consolidated.00.pth"))
np.savez(
str(mlx_path / "weights.npz"),
**{k: v.to(torch.float16).numpy() for k, v in state.items()},
)
# Load the torch model weights to numpy:
weights = torch.load(str(torch_path / "consolidated.00.pth"))
for k, v in weights.items():
weights[k] = v.to(torch.float16).numpy()
# Copy the params
# Standardize the params
with open(torch_path / "params.json", "r") as f:
config = json.loads(f.read())
unused = ["multiple_of"]
unused = ["multiple_of", "sliding_window"]
for k in unused:
if k in config:
config.pop(k)
config.pop(k, None)
n_heads = config["n_heads"]
if "sliding_window" in config:
config.pop("sliding_window")
if "n_kv_heads" not in config:
config["n_kv_heads"] = n_heads
if "head_dim" not in config:
config["head_dim"] = config["dim"] // n_heads
if "hidden_dim" not in config:
config["hidden_dim"] = state["layers.0.feed_forward.w1.weight"].shape[0]
with open(mlx_path / "params.json", "w") as outfile:
config["hidden_dim"] = weights["layers.0.feed_forward.w1.weight"].shape[0]
if config.get("vocab_size", -1) < 0:
config["vocab_size"] = weights["output.weight"].shape[0]
if args.quantize:
print("[INFO] Quantizing")
weights, config = quantize(weights, config, args)
np.savez(str(mlx_path / "weights.npz"), **weights)
with open(mlx_path / "config.json", "w") as outfile:
json.dump(config, outfile, indent=4)
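After a successful run, the output directory should contain the three files the LoRA scripts expect; a quick sanity check, assuming the converter's default output path:

```
from pathlib import Path

# convert.py writes weights.npz and config.json and copies tokenizer.model.
mlx_dir = Path("mlx_model")
for name in ("weights.npz", "config.json", "tokenizer.model"):
    print(name, "found" if (mlx_dir / name).exists() else "MISSING")
```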


@ -17,12 +17,10 @@ from sentencepiece import SentencePieceProcessor
def build_parser():
parser = argparse.ArgumentParser(
description="LoRA finetuning with Llama or Mistral"
)
parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")
parser.add_argument(
"--model",
required=True,
default="mlx_model",
help="A path to the model files containing the tokenizer, weights, config.",
)
# Generation args
@ -332,18 +330,22 @@ def generate(model, prompt, tokenizer, args):
print(s, flush=True)
def load_model(folder: str, dtype=mx.float16):
def load_model(folder: str):
model_path = Path(folder)
tokenizer = Tokenizer(str(model_path / "tokenizer.model"))
with open(model_path / "params.json", "r") as f:
with open(model_path / "config.json", "r") as f:
config = json.loads(f.read())
if config.get("vocab_size", -1) < 0:
config["vocab_size"] = tokenizer.vocab_size
quantization = config.pop("quantization", None)
model_args = ModelArgs(**config)
model = Model(model_args)
if quantization is not None:
quantization["linear_class_predicate"] = lambda m: isinstance(
m, nn.Linear
) and (m.weight.shape[0] != model_args.vocab_size)
nn.QuantizedLinear.quantize_module(model, **quantization)
weights = mx.load(str(model_path / "weights.npz"))
weights = tree_unflatten(list(weights.items()))
weights = tree_map(lambda p: p.astype(dtype), weights)
model = Model(model_args)
model.update(weights)
return model, tokenizer
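A minimal usage sketch of the updated loader (run from the example directory with a converted model in the default `mlx_model` path); it relies on `load_model` re-applying quantization whenever the `quantization` key is present in `config.json`:

```
# load_model and tree_flatten as defined/imported in this file.
model, tokenizer = load_model("mlx_model")
n_params = sum(v.size for _, v in tree_flatten(model.parameters()))
print(f"Loaded model with {n_params} parameters")
```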
@ -374,7 +376,7 @@ if __name__ == "__main__":
# Resume training the given adapters.
if args.resume_adapter_file is not None:
print(f"Loading pretrained adapters from {args.resume_adapter_file}")
model.load_weights(args.resume_adapter_file)
model.load_weights(args.resume_adapter_file, strict=False)
if args.train:
print("Training")
@ -387,7 +389,12 @@ if __name__ == "__main__":
mx.savez(args.adapter_file, **dict(tree_flatten(model.trainable_parameters())))
# Load the LoRA adapter weights which we assume should exist by this point
model.load_weights(args.adapter_file)
if not Path(args.adapter_file).is_file():
raise ValueError(
f"Adapter file {args.adapter_file} missing. "
"Use --train to learn and save the adapters.npz."
)
model.load_weights(args.adapter_file, strict=False)
if args.test:
print("Testing")


@ -1,5 +1,4 @@
# Copyright © 2023 Apple Inc.
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple
@ -24,7 +23,11 @@ class ModelArgs:
class LoRALinear(nn.Module):
@staticmethod
def from_linear(linear: nn.Linear, rank: int = 8):
# TODO remove when input_dims and output_dims are attributes
# on linear and quantized linear
output_dims, input_dims = linear.weight.shape
if isinstance(linear, nn.QuantizedLinear):
input_dims *= 32 // linear.bits
lora_lin = LoRALinear(input_dims, output_dims, rank)
lora_lin.linear = linear
return lora_lin
@ -47,7 +50,10 @@ class LoRALinear(nn.Module):
self.lora_b = mx.zeros(shape=(lora_rank, output_dims))
def __call__(self, x):
y = self.linear(x.astype(self.linear.weight.dtype))
dtype = self.linear.weight.dtype
if isinstance(self.linear, nn.QuantizedLinear):
dtype = self.linear.scales.dtype
y = self.linear(x.astype(dtype))
z = (x @ self.lora_a) @ self.lora_b
return y + 2.0 * z
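To make the two code paths above concrete, here is a hedged sketch (run alongside the definitions in this file) that wraps both a plain and a 4-bit quantized linear layer; the shapes are arbitrary:

```
import mlx.core as mx
import mlx.nn as nn

# Plain layer: input_dims is read directly from linear.weight.shape.
lora = LoRALinear.from_linear(nn.Linear(256, 256), rank=8)

# Quantized layer: the packed weight's second dim is 32 // bits times smaller,
# so from_linear scales input_dims back up before building the adapter.
qlora = LoRALinear.from_linear(
    nn.QuantizedLinear(256, 256, bits=4, group_size=64), rank=8
)

x = mx.random.normal((1, 256))
print(lora(x).shape, qlora(x).shape)  # both should be (1, 256)
```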


@ -1,3 +1,3 @@
mlx
mlx>=0.0.7
sentencepiece
torch