Awni Hannun 2023-11-29 14:14:11 -08:00
parent 4950e9f374
commit 5d6353aab7
6 changed files with 726 additions and 0 deletions

lora/README.md Normal file

@@ -0,0 +1,91 @@
# LoRA
This is an example of using MLX to fine-tune a Llama 7B[^llama] model with
low-rank adaptation (LoRA)[^lora] for a target task.
In this example we'll use the WikiSQL[^wikisql] dataset to train the LLM to
generate SQL queries from natural language. However, the example is general
enough to be adapted to other tasks.
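LoRA keeps the pretrained weights frozen and learns a low-rank update on top
of them. The sketch below is illustrative (it mirrors the `LoRALinear` module
in `llama.py`): because `lora_b` starts at zero, the adapted layer initially
computes exactly the same function as the frozen one.
```
import math
import mlx.core as mx

d_in, d_out, rank = 4096, 4096, 8
x = mx.random.normal((1, d_in))

# Frozen pretrained weight (a random stand-in for a real one)
W = mx.random.normal((d_out, d_in))

# Trainable low-rank factors, initialized as in LoRALinear
scale = 1 / math.sqrt(d_in)
lora_a = mx.random.uniform(low=-scale, high=scale, shape=(d_in, rank))
lora_b = mx.zeros((rank, d_out))

# Same computation as LoRALinear.__call__ in llama.py
y = x @ W.T + 2.0 * (x @ lora_a) @ lora_b
```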
## Setup
Install the dependencies:
```
pip install -r requirements.txt
```
Next, download and convert the model. If you do not have access to the model
weights you will need to [request
access](https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform)
from Meta.
Convert the weights with:
```
python convert.py <path_to_torch_weights> mlx_llama_7B.npz
```
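To sanity check the conversion, note that the output is a regular `.npz` file
which can be loaded directly with MLX (a quick sketch; the key names follow
the mapping in `convert.py`):
```
import mlx.core as mx

weights = mx.load("mlx_llama_7B.npz")
print(len(weights), "arrays")
print(weights["layers.0.linear1.weight"].shape)  # (mlp_dims, dims)
```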
## Run
The main script is `lora.py`. To see a full list of options run:
```
python lora.py --help
```
To fine-tune a model use:
```
python lora.py --model mlx_llama_7B.npz \
               --tokenizer tokenizer.model \
               --train \
               --iters 600
```
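During training, the script reports the training loss every
`--steps_per_report` iterations and the validation loss every
`--steps_per_eval` iterations, in lines of the following form (the numbers
here are illustrative, not measured):
```
Iter 10: Train loss 1.513, It/sec 1.212, Tokens/sec 451.220
```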
By default, the adapter weights are saved in `adapters.npz`. You can specify
the output location with `--adapter_file`.
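The adapter file is itself a regular `.npz` containing only the trainable
parameters, so it can be inspected like any other set of weights (a quick
sketch; the exact keys depend on which layers were adapted):
```
import mlx.core as mx

adapters = mx.load("adapters.npz")
for name, arr in adapters.items():
    print(name, arr.shape)  # e.g. layers.16.attention.query_proj.lora_a
```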
To compute the test set perplexity use:
```
python lora.py --model mlx_llama_7B.npz \
               --tokenizer tokenizer.model \
               --test
```
For generation use:
```
python lora.py --model mlx_llama_7B.npz \
--tokenizer tokenizer.model \
--num-tokens 50 \
--prompt "table: 1-10015132-16
columns: Player, No., Nationality, Position, Years in Toronto, School/Club Team
Q: What is terrence ross' nationality
A: "
```
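The prompt follows the same format that `wikisql.py` uses to build the
training examples: the table description, the question after `Q:`, and
everything up to `A: `, which the model then completes with a SQL query. You
can inspect the training format directly (a quick sketch):
```
import wikisql

train, dev, test = wikisql.load()
print(train[0])
```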
## Results
The initial validation loss for Llama 7B on WikiSQL is 2.66, and the final
validation loss after 1000 iterations is 1.23. The table below shows the
training and validation loss at a few points over the course of training.
| Iteration | Train Loss | Validation Loss |
| --------- | ---------- | --------------- |
| 1 | N/A | 2.659 |
| 200 | 1.264 | 1.405 |
| 400 | 1.201 | 1.303 |
| 600 | 1.123 | 1.274 |
| 800 | 1.017 | 1.255 |
| 1000 | 1.070 | 1.230 |
After training for 1000 iterations, the final validation loss of 1.230
corresponds to a validation perplexity of roughly exp(1.230) ≈ 3.4.
The model trains at around 475 tokens per second on an M2 Ultra.
[^lora]: Refer to the [arXiv paper](https://arxiv.org/abs/2106.09685) for more details on LoRA.
[^llama]: Refer to the [arXiv paper](https://arxiv.org/abs/2302.13971) and [blog post](https://ai.meta.com/blog/large-language-model-llama-meta-ai/) for more details.
[^wikisql]: Refer to the [GitHub repo](https://github.com/salesforce/WikiSQL/tree/master) for more information about WikiSQL.

lora/convert.py Normal file

@@ -0,0 +1,46 @@
import argparse
from itertools import starmap
import numpy as np
import torch
def map_torch_to_mlx(key, value):
if "tok_embedding" in key:
key = "embedding.weight"
elif "norm" in key:
key = key.replace("attention_norm", "norm1").replace("ffn_norm", "norm2")
elif "wq" in key or "wk" in key or "wv" in key or "wo" in key:
key = key.replace("wq", "query_proj")
key = key.replace("wk", "key_proj")
key = key.replace("wv", "value_proj")
key = key.replace("wo", "out_proj")
elif "w1" in key or "w2" in key or "w3" in key:
# The FFN is a separate submodule in PyTorch
key = key.replace("feed_forward.w1", "linear1")
key = key.replace("feed_forward.w3", "linear2")
key = key.replace("feed_forward.w2", "linear3")
elif "output" in key:
key = key.replace("output", "out_proj")
elif "rope" in key:
return None, None
return key, value.numpy()
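# For example (illustrative), the key "layers.0.attention.wq.weight" is
# renamed to "layers.0.attention.query_proj.weight".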
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert Llama weights to MLX")
parser.add_argument("torch_weights")
parser.add_argument("output_file")
args = parser.parse_args()
    # Load on the CPU so the conversion also works on machines without CUDA
    state = torch.load(args.torch_weights, map_location="cpu")
np.savez(
args.output_file,
**{k: v for k, v in starmap(map_torch_to_mlx, state.items()) if k is not None}
)

lora/llama.py Normal file

@@ -0,0 +1,197 @@
import math
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_unflatten
class LoRALinear(nn.Module):
@staticmethod
def from_linear(linear: nn.Linear, rank: int = 8):
        # nn.Linear stores its weight with shape (output_dims, input_dims)
        output_dims, input_dims = linear.weight.shape
lora_lin = LoRALinear(input_dims, output_dims, rank)
lora_lin.linear = linear
return lora_lin
def __init__(
self, input_dims: int, output_dims: int, lora_rank: int = 8, bias: bool = False
):
super().__init__()
# Regular linear layer weights
self.linear = nn.Linear(input_dims, output_dims, bias=bias)
# Low rank lora weights
scale = 1 / math.sqrt(input_dims)
self.lora_a = mx.random.uniform(
low=-scale,
high=scale,
shape=(input_dims, lora_rank),
)
self.lora_b = mx.zeros(shape=(lora_rank, output_dims))
    def __call__(self, x):
        # The frozen linear output plus the scaled low-rank update. Since
        # lora_b is zero-initialized, the adapted layer initially computes
        # exactly the same function as the pretrained one.
        y = self.linear(x)
        z = (x @ self.lora_a) @ self.lora_b
        return y + 2.0 * z
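# Example (as in lora.py): replace a pretrained projection with its
# LoRA-adapted version in place:
#   layer.attention.query_proj = LoRALinear.from_linear(layer.attention.query_proj)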
class LlamaAttention(nn.Module):
def __init__(self, dims: int, num_heads: int):
super().__init__()
self.num_heads = num_heads
self.rope = nn.RoPE(dims // num_heads, traditional=True)
self.query_proj = nn.Linear(dims, dims, bias=False)
self.key_proj = nn.Linear(dims, dims, bias=False)
self.value_proj = nn.Linear(dims, dims, bias=False)
self.out_proj = nn.Linear(dims, dims, bias=False)
def __call__(self, queries, keys, values, mask=None, cache=None):
queries = self.query_proj(queries)
keys = self.key_proj(keys)
values = self.value_proj(values)
# Extract some shapes
num_heads = self.num_heads
B, L, D = queries.shape
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
# Add RoPE to the queries and keys and combine them with the cache
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = mx.concatenate([key_cache, keys], axis=2)
values = mx.concatenate([value_cache, values], axis=2)
else:
queries = self.rope(queries)
keys = self.rope(keys)
# Finally perform the attention computation
scale = math.sqrt(1 / queries.shape[-1])
scores = (queries * scale) @ keys.transpose(0, 1, 3, 2)
if mask is not None:
scores = scores + mask
scores = mx.softmax(scores, axis=-1)
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
# Note that we return the keys and values to possibly be used as a cache
return self.out_proj(values_hat), (keys, values)
class LlamaEncoderLayer(nn.Module):
def __init__(self, dims: int, mlp_dims: int, num_heads: int):
super().__init__()
self.attention = LlamaAttention(dims, num_heads)
self.norm1 = nn.RMSNorm(dims)
self.norm2 = nn.RMSNorm(dims)
self.linear1 = nn.Linear(dims, mlp_dims, bias=False)
self.linear2 = nn.Linear(dims, mlp_dims, bias=False)
self.linear3 = nn.Linear(mlp_dims, dims, bias=False)
def __call__(self, x, mask=None, cache=None):
y = self.norm1(x)
y, cache = self.attention(y, y, y, mask, cache)
x = x + y
y = self.norm2(x)
a = self.linear1(y)
b = self.linear2(y)
y = a * mx.sigmoid(a) * b
y = self.linear3(y)
x = x + y
return x, cache
class Llama(nn.Module):
def __init__(
self, num_layers: int, vocab_size: int, dims: int, mlp_dims: int, num_heads: int
):
super().__init__()
self.embedding = nn.Embedding(vocab_size, dims)
self.layers = [
LlamaEncoderLayer(dims, mlp_dims, num_heads) for _ in range(num_layers)
]
self.norm = nn.RMSNorm(dims)
self.out_proj = nn.Linear(dims, vocab_size, bias=False)
def __call__(self, x):
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
x = self.embedding(x)
for l in self.layers:
x, _ = l(x, mask)
x = self.norm(x)
return self.out_proj(x)
def generate(self, x, temp=1.0):
cache = []
try:
# Make an additive causal mask. We will need that to process the prompt.
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
mask = mask.astype(self.embedding.weight.dtype)
            # First we process the prompt x the same way as in __call__ but
# save the caches in cache
x = self.embedding(x)
for l in self.layers:
x, c = l(x, mask=mask)
# We store the per layer cache in a simple python list
cache.append(c)
x = self.norm(x)
# We only care about the last logits that generate the next token
y = self.out_proj(x[:, -1])
y = mx.random.categorical(y * (1 / temp))
# y now has size [1]
yield y
            # Now that we have processed the prompt and generated the first
            # token, we feed it back into the model and loop to generate the
            # rest.
while True:
# Unsqueezing the last dimension to add a sequence length
# dimension of 1
x = y[:, None]
x = self.embedding(x)
for i in range(len(cache)):
                    # We are overwriting the arrays in the cache list. When
                    # the computation happens, MLX discards the old cache as
                    # soon as it is no longer needed.
x, cache[i] = self.layers[i](x, mask=None, cache=cache[i])
x = self.norm(x)
y = self.out_proj(x[:, -1])
y = mx.random.categorical(y * (1 / temp))
yield y
finally:
del cache
def load_model(model_path):
    weights = mx.load(model_path)
    # Infer the architecture from the weight shapes: linear1 maps
    # dims -> mlp_dims, and Llama 7B uses 128-dimensional attention heads
    mlp_dims, dims = weights["layers.0.linear1.weight"].shape
    num_heads = dims // 128
    num_layers = max(int(l.split(".")[1]) for l in weights.keys() if "layers" in l) + 1
    vocab_size = weights["out_proj.weight"].shape[-1]
model = Llama(num_layers, vocab_size, dims, mlp_dims, num_heads)
model.update(tree_unflatten(list(weights.items())))
mx.eval(model.parameters())
return model
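# Quick smoke test (illustrative):
#   model = load_model("mlx_llama_7B.npz")
#   logits = model(mx.array([[1, 2, 3]]))  # shape (1, 3, vocab_size)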

lora/lora.py Normal file

@@ -0,0 +1,289 @@
import argparse
import math
import numpy as np
from sentencepiece import SentencePieceProcessor
import time
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx.utils import tree_flatten
from llama import LoRALinear, load_model
import wikisql
def build_parser():
parser = argparse.ArgumentParser(description="Llama LoRA finetuning")
parser.add_argument(
"--model", required=True, help="The model file containing MLX weights"
)
parser.add_argument(
"--tokenizer", required=True, help="The sentencepiece tokenizer"
)
# Generation args
parser.add_argument(
"--num-tokens", "-n", type=int, default=100, help="How many tokens to generate"
)
parser.add_argument(
"--write-every", type=int, default=1, help="After how many tokens to detokenize"
)
parser.add_argument(
"--temp", type=float, default=0.8, help="The sampling temperature"
)
parser.add_argument(
"--prompt",
"-p",
type=str,
help="The prompt for generation",
default=None,
)
# Training args
parser.add_argument(
"--train",
action="store_true",
help="Do training",
)
parser.add_argument("--batch_size", type=int, default=4, help="Minibatch size.")
parser.add_argument(
"--iters", type=int, default=1000, help="Iterations to train for."
)
parser.add_argument(
"--val_batches",
type=int,
default=100,
help="Number of validation batches, -1 uses the entire validation set.",
)
parser.add_argument(
"--learning_rate", type=float, default=1e-5, help="Adam learning rate."
)
parser.add_argument(
"--steps_per_report",
type=int,
default=10,
help="Number of training steps between loss reporting.",
)
parser.add_argument(
"--steps_per_eval",
type=int,
default=200,
help="Number of training steps between validations.",
)
parser.add_argument(
"--adapter_file",
type=str,
default="adapters.npz",
help="Save/load path for the trained adapter weights.",
)
parser.add_argument(
"--test",
action="store_true",
help="Evaluate on the test set after training",
)
parser.add_argument(
"--test_batches",
type=int,
default=500,
help="Number of test set batches, -1 uses the entire test set.",
)
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
return parser
def loss(model, inputs, targets, lengths):
# Run model on inputs
logits = model(inputs)
# Mask padding tokens
length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]
# Calculate the loss
ce = nn.losses.cross_entropy(logits, targets) * length_mask
ntoks = length_mask.sum()
ce = ce.sum() / ntoks
return ce, ntoks
def iterate_batches(dset, tokenizer, batch_size, shuffle=True):
# Shuffle indices
indices = np.arange(len(dset))
if shuffle:
indices = np.random.permutation(indices)
# Collect batches from dataset
for i in range(0, len(indices) - batch_size + 1, batch_size):
# Encode batch
batch = tokenizer.encode([dset[indices[i + j]] for j in range(batch_size)])
lengths = [len(x) for x in batch]
# Pad to the max length
batch_arr = np.zeros((batch_size, max(lengths)), np.int32)
for j in range(batch_size):
batch_arr[j, : lengths[j]] = batch[j]
batch = mx.array(batch_arr)
yield batch[:, :-1], batch[:, 1:], mx.array(lengths)
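# Note: iterate_batches yields (inputs, targets, lengths), where the targets
# are the inputs shifted by one token (next-token prediction) and the lengths
# let the loss mask out the zero padding.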
def evaluate(model, dataset, loss, tokenizer, batch_size, num_batches):
all_losses = []
ntokens = 0
for it, batch in zip(
range(num_batches),
iterate_batches(dataset, tokenizer, batch_size, shuffle=False),
):
losses, toks = loss(model, *batch)
all_losses.append((losses * toks).item())
ntokens += toks.item()
return np.sum(all_losses) / ntokens
def train(model, train_set, val_set, optimizer, loss, tokenizer, args):
# Create value and grad function for loss
loss_value_and_grad = nn.value_and_grad(model, loss)
losses = []
n_tokens = 0
# Main training loop
start = time.perf_counter()
for it, batch in zip(
range(args.iters), iterate_batches(train_set, tokenizer, args.batch_size)
):
# Forward and backward pass
(lvalue, toks), grad = loss_value_and_grad(model, *batch)
# Model update
optimizer.update(model, grad)
mx.eval(model.parameters(), optimizer.state, lvalue)
# Record loss
losses.append(lvalue.item())
n_tokens += toks.item()
# Report training loss if needed
if (it + 1) % args.steps_per_report == 0:
train_loss = np.mean(losses)
stop = time.perf_counter()
print(
f"Iter {it + 1}: Train loss {train_loss:.3f}, "
f"It/sec {args.steps_per_report / (stop - start):.3f}, "
f"Tokens/sec {float(n_tokens) / (stop - start):.3f}"
)
losses = []
n_tokens = 0
start = time.perf_counter()
# Report validation loss if needed
if it == 0 or (it + 1) % args.steps_per_eval == 0:
stop = time.perf_counter()
val_loss = evaluate(
model, val_set, loss, tokenizer, args.batch_size, args.val_batches
)
print(
f"Iter {it + 1}: "
f"Val loss {val_loss:.3f}, "
f"Val took {(time.perf_counter() - stop):.3f}s"
)
start = time.perf_counter()
def generate(model, prompt, tokenizer, args):
# Encode prompt
x = mx.array([[tokenizer.bos_id()] + tokenizer.encode(prompt)])
skip = 0
prompt_processing = None
tokens = []
    # Generation loop
start = time.perf_counter()
for token in model.generate(x, args.temp):
tokens.append(token)
if len(tokens) == 1:
# Actually perform the computation to measure the prompt processing time
mx.eval(token)
prompt_processing = time.perf_counter() - start
if len(tokens) >= args.num_tokens:
break
if (len(tokens) % args.write_every) == 0:
mx.eval(tokens)
s = tokenizer.decode([t.item() for t in tokens])
print(s[skip:], end="", flush=True)
skip = len(s)
mx.eval(tokens)
full_gen = time.perf_counter() - start
s = tokenizer.decode([t.item() for t in tokens])
print(s[skip:], end="", flush=True)
print()
print(f"Prompt processing took: {prompt_processing:.3f} s")
print(f"Full generation took: {full_gen:.3f} s")
if __name__ == "__main__":
parser = build_parser()
args = parser.parse_args()
np.random.seed(args.seed)
print("Loading tokenizer")
tokenizer = SentencePieceProcessor(model_file=args.tokenizer)
print("Loading pretrained model")
model = load_model(args.model)
    # Freeze all layers, then replace the query and value projections of the
    # last 16 transformer layers with trainable LoRA linears
    model.freeze()
    for l in model.layers[16:32]:
l.attention.query_proj = LoRALinear.from_linear(l.attention.query_proj)
l.attention.value_proj = LoRALinear.from_linear(l.attention.value_proj)
p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6
print(f"Total parameters {p:.3f}M")
p = sum(v.size for _, v in tree_flatten(model.trainable_parameters())) / 10**6
print(f"Trainable parameters {p:.3f}M")
print("Loading datasets")
train_set, valid_set, test_set = wikisql.load()
if args.train:
print("Training")
opt = optim.Adam(learning_rate=args.learning_rate)
# Train model
train(model, train_set, valid_set, opt, loss, tokenizer, args)
# Save adapter weights
mx.savez(args.adapter_file, **dict(tree_flatten(model.trainable_parameters())))
# Load the LoRA adapter weights which we assume should exist by this point
model.load_weights(args.adapter_file)
if args.test:
print("Testing")
test_loss = evaluate(
model,
test_set,
loss,
tokenizer,
args.batch_size,
num_batches=args.test_batches,
)
test_ppl = math.exp(test_loss)
print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
if args.prompt is not None:
print("Generating")
generate(model, args.prompt, tokenizer, args)

lora/requirements.txt Normal file

@@ -0,0 +1,2 @@
mlx
numpy
sentencepiece
torch

lora/wikisql.py Normal file

@@ -0,0 +1,101 @@
"""
Code to preprocess the WikiSQL dataset adapted from
https://github.com/salesforce/WikiSQL and
https://huggingface.co/sqllama/sqllama-V0/blob/main/wikisql.ipynb .
"""
import json
import os
def load():
"""
Load all three splits of the WikiSQL dataset.
"""
return (WikiSQL(dn) for dn in ["train", "dev", "test"])
class WikiSQL:
def __init__(self, dataset, save_dir="/tmp"):
valid_sets = ("train", "dev", "test")
if dataset not in valid_sets:
raise ValueError(f"Dataset must be in {valid_sets}, got {dataset}")
data_dir = os.path.join(save_dir, "wikisql")
self._maybe_download(data_dir)
self._parse_tables(os.path.join(data_dir, f"data/{dataset}.tables.jsonl"))
self._parse_queries(os.path.join(data_dir, f"data/{dataset}.jsonl"))
def _maybe_download(self, data_dir):
if not os.path.exists(data_dir):
import io
from urllib import request
import tarfile
url = "https://raw.githubusercontent.com/salesforce/WikiSQL/master/data.tar.bz2"
r = request.urlopen(url)
with tarfile.open(fileobj=io.BytesIO(r.read())) as tf:
tf.extractall(data_dir)
def _parse_tables(self, tables):
self._tables = {}
with open(tables) as f:
for line in f:
table = json.loads(line)
self._tables[table["id"]] = {
"columns": table["header"],
"types": table["types"],
"desc": f"table: {table['id']}\ncolumns: {', '.join(table['header'])}",
}
def _parse_queries(self, queries):
self._queries = []
with open(queries) as f:
for line in f:
query = json.loads(line)
table = self._tables[query["table_id"]]
question = query["question"]
answer = self.query_to_text(
query["sql"], query["table_id"], table["columns"], table["types"]
)
self._queries.append(
f"<s>{table['desc']}\nQ: {question}\nA: {answer}</s>"
)
    def query_to_text(self, query, table, columns, types):
        # Render a structured WikiSQL query as a SQL string, e.g.
        # (illustrative): SELECT COUNT Player FROM 1-10015132-16
        # WHERE Position = 'Guard'
        aggregation_ops = ["", "MAX", "MIN", "COUNT", "SUM", "AVG"]
        condition_ops = ["=", ">", "<", "OP"]
column = columns[query["sel"]]
aggregation = (aggregation_ops[query["agg"]] + " ") if query["agg"] > 0 else ""
sql = f"SELECT {aggregation}{column} FROM {table}"
conditions = query["conds"]
if conditions:
cs = []
for i, o, v in conditions:
column = columns[i]
op = condition_ops[o]
if types[i] == "text":
value = f"'{v}'"
else:
value = v
cs.append(f"{column} {op} {value}")
sql += " WHERE " + " AND ".join(cs)
return sql
def __getitem__(self, idx):
return self._queries[idx]
def __len__(self):
return len(self._queries)
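# Example usage (downloads the data to /tmp/wikisql on first use):
#   dev = WikiSQL("dev")
#   print(len(dev), dev[0])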
if __name__ == "__main__":
datanames = ["train", "dev", "test"]
sizes = [56355, 8421, 15878]
for dataname, size in zip(datanames, sizes):
        assert len(WikiSQL(dataname)) == size, f"Wrong {dataname} set size."