diff --git a/lora/README.md b/lora/README.md
new file mode 100644
index 00000000..ef1acd14
--- /dev/null
+++ b/lora/README.md
@@ -0,0 +1,90 @@
+# LoRA
+
+This is an example of using MLX to fine-tune a Llama 7B[^llama] model with low
+rank adaptation (LoRA)[^lora] for a target task.
+
+In this example we'll use the WikiSQL[^wikisql] dataset to train the LLM to
+generate SQL queries from natural language. However, the example is intended
+to be general, so you can easily adapt it to a different task.
+
+## Setup
+
+Install the dependencies:
+
+```
+pip install -r requirements.txt
+```
+
+Next, download and convert the model. If you do not have access to the model
+weights you will need to [request
+access](https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform)
+from Meta.
+
+Convert the weights with:
+
+```
+python convert.py <path_to_torch_weights>/consolidated.00.pth mlx_llama_7B.npz
+```
+
+## Run
+
+The main script is `lora.py`. To see a full list of options run
+
+```
+python lora.py --help
+```
+
+To fine-tune a model use:
+
+```
+python lora.py --model mlx_llama_7B.npz \
+               --tokenizer tokenizer.model \
+               --train \
+               --iters 600
+```
+
+By default, the adapter weights are saved in `adapters.npz`. You can specify
+the output location with `--adapter_file`.
+
+To compute test set perplexity use
+
+```
+python lora.py --model mlx_llama_7B.npz \
+               --tokenizer tokenizer.model \
+               --test
+```
+
+For generation use
+
+```
+python lora.py --model mlx_llama_7B.npz \
+               --tokenizer tokenizer.model \
+               --num-tokens 50 \
+               --prompt "table: 1-10015132-16
+columns: Player, No., Nationality, Position, Years in Toronto, School/Club Team
+Q: What is terrence ross' nationality
+A: "
+```
+
+## Results
+
+The initial validation loss for Llama 7B on the WikiSQL dataset is 2.66 and the
+final validation loss after 1000 iterations is 1.23. The table below shows the
+training and validation loss at a few points over the course of training.
+
+| Iteration | Train Loss | Validation Loss |
+| --------- | ---------- | --------------- |
+| 1         | N/A        | 2.659           |
+| 200       | 1.264      | 1.405           |
+| 400       | 1.201      | 1.303           |
+| 600       | 1.123      | 1.274           |
+| 800       | 1.017      | 1.255           |
+| 1000      | 1.070      | 1.230           |
+
+After training for 1000 iterations, the validation perplexity reduces to exp(1.23) ≈ 3.4.
+
+The model trains at around 475 tokens per second on an M2 Ultra.
+
+[^lora]: Refer to the [arXiv paper](https://arxiv.org/abs/2106.09685) for more details on LoRA.
+[^llama]: Refer to the [arXiv paper](https://arxiv.org/abs/2302.13971) and [blog post](https://ai.meta.com/blog/large-language-model-llama-meta-ai/) for more details.
+[^wikisql]: Refer to the [GitHub repo](https://github.com/salesforce/WikiSQL/tree/master) for more information about WikiSQL.
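The adapter file only contains the small `lora_a`/`lora_b` matrices, so at inference time every adapted projection still performs the extra low-rank matmuls. If you would rather deploy a plain model, the LoRA update can be folded back into the frozen weights. The snippet below is a minimal sketch of that idea, not part of the example scripts: it assumes the `LoRALinear` layout from `llama.py` (`y = x @ W.T + 2.0 * (x @ lora_a) @ lora_b`), and the `fuse_lora_linear` helper name is made up for illustration.

```
import mlx.core as mx
import mlx.nn as nn

from llama import LoRALinear


def fuse_lora_linear(lora_lin: LoRALinear) -> nn.Linear:
    # Hypothetical helper (not part of this example). Since
    #   y = x @ W.T + 2 * (x @ A) @ B == x @ (W + 2 * (A @ B).T).T,
    # adding the low-rank product into the frozen weight once lets us keep a
    # plain nn.Linear for inference.
    fused = lora_lin.linear
    delta = 2.0 * (lora_lin.lora_a @ lora_lin.lora_b)  # (input_dims, output_dims)
    fused.weight = fused.weight + mx.transpose(delta)  # weight is (output_dims, input_dims)
    return fused


# Usage sketch, mirroring the layers that lora.py adapts:
# for l in model.layers[16:32]:
#     l.attention.query_proj = fuse_lora_linear(l.attention.query_proj)
#     l.attention.value_proj = fuse_lora_linear(l.attention.value_proj)
```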
diff --git a/lora/convert.py b/lora/convert.py
new file mode 100644
index 00000000..5cad31a0
--- /dev/null
+++ b/lora/convert.py
@@ -0,0 +1,46 @@
+import argparse
+from itertools import starmap
+
+import numpy as np
+import torch
+
+
+def map_torch_to_mlx(key, value):
+    if "tok_embedding" in key:
+        key = "embedding.weight"
+
+    elif "norm" in key:
+        key = key.replace("attention_norm", "norm1").replace("ffn_norm", "norm2")
+
+    elif "wq" in key or "wk" in key or "wv" in key or "wo" in key:
+        key = key.replace("wq", "query_proj")
+        key = key.replace("wk", "key_proj")
+        key = key.replace("wv", "value_proj")
+        key = key.replace("wo", "out_proj")
+
+    elif "w1" in key or "w2" in key or "w3" in key:
+        # The FFN is a separate submodule in PyTorch
+        key = key.replace("feed_forward.w1", "linear1")
+        key = key.replace("feed_forward.w3", "linear2")
+        key = key.replace("feed_forward.w2", "linear3")
+
+    elif "output" in key:
+        key = key.replace("output", "out_proj")
+
+    elif "rope" in key:
+        return None, None
+
+    return key, value.numpy()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Convert Llama weights to MLX")
+    parser.add_argument("torch_weights")
+    parser.add_argument("output_file")
+    args = parser.parse_args()
+
+    state = torch.load(args.torch_weights)
+    np.savez(
+        args.output_file,
+        **{k: v for k, v in starmap(map_torch_to_mlx, state.items()) if k is not None}
+    )
diff --git a/lora/llama.py b/lora/llama.py
new file mode 100644
index 00000000..d8933e22
--- /dev/null
+++ b/lora/llama.py
@@ -0,0 +1,197 @@
+import math
+
+import mlx.core as mx
+import mlx.nn as nn
+from mlx.utils import tree_unflatten
+
+
+class LoRALinear(nn.Module):
+    @staticmethod
+    def from_linear(linear: nn.Linear, rank: int = 8):
+        output_dims, input_dims = linear.weight.shape
+        lora_lin = LoRALinear(input_dims, output_dims, rank)
+        lora_lin.linear = linear
+        return lora_lin
+
+    def __init__(
+        self, input_dims: int, output_dims: int, lora_rank: int = 8, bias: bool = False
+    ):
+        super().__init__()
+
+        # Regular linear layer weights
+        self.linear = nn.Linear(input_dims, output_dims, bias=bias)
+
+        # Low rank lora weights
+        scale = 1 / math.sqrt(input_dims)
+        self.lora_a = mx.random.uniform(
+            low=-scale,
+            high=scale,
+            shape=(input_dims, lora_rank),
+        )
+        self.lora_b = mx.zeros(shape=(lora_rank, output_dims))
+
+    def __call__(self, x):
+        y = self.linear(x)
+        z = (x @ self.lora_a) @ self.lora_b
+        return y + 2.0 * z
+
+
+class LlamaAttention(nn.Module):
+    def __init__(self, dims: int, num_heads: int):
+        super().__init__()
+
+        self.num_heads = num_heads
+
+        self.rope = nn.RoPE(dims // num_heads, traditional=True)
+
+        self.query_proj = nn.Linear(dims, dims, bias=False)
+        self.key_proj = nn.Linear(dims, dims, bias=False)
+        self.value_proj = nn.Linear(dims, dims, bias=False)
+        self.out_proj = nn.Linear(dims, dims, bias=False)
+
+    def __call__(self, queries, keys, values, mask=None, cache=None):
+        queries = self.query_proj(queries)
+        keys = self.key_proj(keys)
+        values = self.value_proj(values)
+
+        # Extract some shapes
+        num_heads = self.num_heads
+        B, L, D = queries.shape
+
+        # Prepare the queries, keys and values for the attention computation
+        queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
+        keys = keys.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
+        values = values.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
+
+        # Add RoPE to the queries and keys and combine them with the cache
+        if cache is not None:
+            key_cache, value_cache = cache
+            queries = self.rope(queries, offset=key_cache.shape[2])
+            keys = self.rope(keys, offset=key_cache.shape[2])
+            keys = mx.concatenate([key_cache, keys], axis=2)
+            values = mx.concatenate([value_cache, values], axis=2)
+        else:
+            queries = self.rope(queries)
+            keys = self.rope(keys)
+
+        # Finally perform the attention computation
+        scale = math.sqrt(1 / queries.shape[-1])
+        scores = (queries * scale) @ keys.transpose(0, 1, 3, 2)
+        if mask is not None:
+            scores = scores + mask
+        scores = mx.softmax(scores, axis=-1)
+        values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
+
+        # Note that we return the keys and values to possibly be used as a cache
+        return self.out_proj(values_hat), (keys, values)
+
+
+class LlamaEncoderLayer(nn.Module):
+    def __init__(self, dims: int, mlp_dims: int, num_heads: int):
+        super().__init__()
+
+        self.attention = LlamaAttention(dims, num_heads)
+
+        self.norm1 = nn.RMSNorm(dims)
+        self.norm2 = nn.RMSNorm(dims)
+
+        self.linear1 = nn.Linear(dims, mlp_dims, bias=False)
+        self.linear2 = nn.Linear(dims, mlp_dims, bias=False)
+        self.linear3 = nn.Linear(mlp_dims, dims, bias=False)
+
+    def __call__(self, x, mask=None, cache=None):
+        y = self.norm1(x)
+        y, cache = self.attention(y, y, y, mask, cache)
+        x = x + y
+
+        y = self.norm2(x)
+        a = self.linear1(y)
+        b = self.linear2(y)
+        y = a * mx.sigmoid(a) * b
+        y = self.linear3(y)
+        x = x + y
+
+        return x, cache
+
+
+class Llama(nn.Module):
+    def __init__(
+        self, num_layers: int, vocab_size: int, dims: int, mlp_dims: int, num_heads: int
+    ):
+        super().__init__()
+
+        self.embedding = nn.Embedding(vocab_size, dims)
+        self.layers = [
+            LlamaEncoderLayer(dims, mlp_dims, num_heads) for _ in range(num_layers)
+        ]
+        self.norm = nn.RMSNorm(dims)
+        self.out_proj = nn.Linear(dims, vocab_size, bias=False)
+
+    def __call__(self, x):
+        mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
+
+        x = self.embedding(x)
+
+        for l in self.layers:
+            x, _ = l(x, mask)
+
+        x = self.norm(x)
+
+        return self.out_proj(x)
+
+    def generate(self, x, temp=1.0):
+        cache = []
+        try:
+            # Make an additive causal mask. We will need that to process the prompt.
+            mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
+            mask = mask.astype(self.embedding.weight.dtype)
+
+            # First we process the prompt x the same way as in __call__ but
+            # save the caches in cache
+            x = self.embedding(x)
+            for l in self.layers:
+                x, c = l(x, mask=mask)
+                # We store the per layer cache in a simple python list
+                cache.append(c)
+            x = self.norm(x)
+            # We only care about the last logits that generate the next token
+            y = self.out_proj(x[:, -1])
+            y = mx.random.categorical(y * (1 / temp))
+
+            # y now has size [1]
+            yield y
+
+            # Now that we have parsed the prompt and generated the first token,
+            # we need to feed it back into the model and loop to generate the
+            # rest.
+            while True:
+                # Unsqueezing the last dimension to add a sequence length
+                # dimension of 1
+                x = y[:, None]
+
+                x = self.embedding(x)
+                for i in range(len(cache)):
+                    # We are overwriting the arrays in the cache list. When
+                    # the computation happens, MLX discards the old cache the
+                    # moment it is no longer needed.
+                    x, cache[i] = self.layers[i](x, mask=None, cache=cache[i])
+                x = self.norm(x)
+                y = self.out_proj(x[:, -1])
+                y = mx.random.categorical(y * (1 / temp))
+
+                yield y
+
+        finally:
+            del cache
+
+
+def load_model(model_path):
+    weights = mx.load(model_path)
+    mlp_dims, dims = weights["layers.0.linear1.weight"].shape
+    num_heads = dims // 128
+    num_layers = max(int(l.split(".")[1]) for l in weights.keys() if "layers" in l) + 1
+    vocab_size = weights["out_proj.weight"].shape[0]
+    model = Llama(num_layers, vocab_size, dims, mlp_dims, num_heads)
+    model.update(tree_unflatten(list(weights.items())))
+    mx.eval(model.parameters())
+    return model
diff --git a/lora/lora.py b/lora/lora.py
new file mode 100644
index 00000000..5c34de0c
--- /dev/null
+++ b/lora/lora.py
@@ -0,0 +1,289 @@
+import argparse
+import math
+import numpy as np
+from sentencepiece import SentencePieceProcessor
+import time
+
+import mlx.core as mx
+import mlx.nn as nn
+import mlx.optimizers as optim
+from mlx.utils import tree_flatten
+
+
+from llama import LoRALinear, load_model
+import wikisql
+
+
+def build_parser():
+    parser = argparse.ArgumentParser(description="Llama LoRA finetuning")
+    parser.add_argument(
+        "--model", required=True, help="The model file containing MLX weights"
+    )
+    parser.add_argument(
+        "--tokenizer", required=True, help="The sentencepiece tokenizer"
+    )
+    # Generation args
+    parser.add_argument(
+        "--num-tokens", "-n", type=int, default=100, help="How many tokens to generate"
+    )
+    parser.add_argument(
+        "--write-every", type=int, default=1, help="After how many tokens to detokenize"
+    )
+    parser.add_argument(
+        "--temp", type=float, default=0.8, help="The sampling temperature"
+    )
+    parser.add_argument(
+        "--prompt",
+        "-p",
+        type=str,
+        help="The prompt for generation",
+        default=None,
+    )
+
+    # Training args
+    parser.add_argument(
+        "--train",
+        action="store_true",
+        help="Do training",
+    )
+    parser.add_argument("--batch_size", type=int, default=4, help="Minibatch size.")
+    parser.add_argument(
+        "--iters", type=int, default=1000, help="Iterations to train for."
+    )
+    parser.add_argument(
+        "--val_batches",
+        type=int,
+        default=100,
+        help="Number of validation batches, -1 uses the entire validation set.",
+    )
+    parser.add_argument(
+        "--learning_rate", type=float, default=1e-5, help="Adam learning rate."
+    )
+    parser.add_argument(
+        "--steps_per_report",
+        type=int,
+        default=10,
+        help="Number of training steps between loss reporting.",
+    )
+    parser.add_argument(
+        "--steps_per_eval",
+        type=int,
+        default=200,
+        help="Number of training steps between validations.",
+    )
+    parser.add_argument(
+        "--adapter_file",
+        type=str,
+        default="adapters.npz",
+        help="Save/load path for the trained adapter weights.",
+    )
+    parser.add_argument(
+        "--test",
+        action="store_true",
+        help="Evaluate on the test set after training",
+    )
+    parser.add_argument(
+        "--test_batches",
+        type=int,
+        default=500,
+        help="Number of test set batches, -1 uses the entire test set.",
+    )
+    parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
+    return parser
+
+
+def loss(model, inputs, targets, lengths):
+    # Run model on inputs
+    logits = model(inputs)
+
+    # Mask padding tokens
+    length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]
+
+    # Calculate the loss
+    ce = nn.losses.cross_entropy(logits, targets) * length_mask
+    ntoks = length_mask.sum()
+    ce = ce.sum() / ntoks
+    return ce, ntoks
+
+
+def iterate_batches(dset, tokenizer, batch_size, shuffle=True):
+    # Shuffle indices
+    indices = np.arange(len(dset))
+    if shuffle:
+        indices = np.random.permutation(indices)
+
+    # Collect batches from dataset
+    for i in range(0, len(indices) - batch_size + 1, batch_size):
+        # Encode batch
+        batch = tokenizer.encode([dset[indices[i + j]] for j in range(batch_size)])
+        lengths = [len(x) for x in batch]
+
+        # Pad to the max length
+        batch_arr = np.zeros((batch_size, max(lengths)), np.int32)
+        for j in range(batch_size):
+            batch_arr[j, : lengths[j]] = batch[j]
+        batch = mx.array(batch_arr)
+        yield batch[:, :-1], batch[:, 1:], mx.array(lengths)
+
+
+def evaluate(model, dataset, loss, tokenizer, batch_size, num_batches):
+    all_losses = []
+    ntokens = 0
+    for it, batch in zip(
+        range(num_batches),
+        iterate_batches(dataset, tokenizer, batch_size, shuffle=False),
+    ):
+        losses, toks = loss(model, *batch)
+        all_losses.append((losses * toks).item())
+        ntokens += toks.item()
+
+    return np.sum(all_losses) / ntokens
+
+
+def train(model, train_set, val_set, optimizer, loss, tokenizer, args):
+    # Create value and grad function for loss
+    loss_value_and_grad = nn.value_and_grad(model, loss)
+
+    losses = []
+    n_tokens = 0
+
+    # Main training loop
+    start = time.perf_counter()
+    for it, batch in zip(
+        range(args.iters), iterate_batches(train_set, tokenizer, args.batch_size)
+    ):
+        # Forward and backward pass
+        (lvalue, toks), grad = loss_value_and_grad(model, *batch)
+
+        # Model update
+        optimizer.update(model, grad)
+        mx.eval(model.parameters(), optimizer.state, lvalue)
+
+        # Record loss
+        losses.append(lvalue.item())
+        n_tokens += toks.item()
+
+        # Report training loss if needed
+        if (it + 1) % args.steps_per_report == 0:
+            train_loss = np.mean(losses)
+
+            stop = time.perf_counter()
+            print(
+                f"Iter {it + 1}: Train loss {train_loss:.3f}, "
+                f"It/sec {args.steps_per_report / (stop - start):.3f}, "
+                f"Tokens/sec {float(n_tokens) / (stop - start):.3f}"
+            )
+            losses = []
+            n_tokens = 0
+            start = time.perf_counter()
+
+        # Report validation loss if needed
+        if it == 0 or (it + 1) % args.steps_per_eval == 0:
+            stop = time.perf_counter()
+            val_loss = evaluate(
+                model, val_set, loss, tokenizer, args.batch_size, args.val_batches
+            )
+            print(
+                f"Iter {it + 1}: "
+                f"Val loss {val_loss:.3f}, "
+                f"Val took {(time.perf_counter() - stop):.3f}s"
+            )
+
+            start = time.perf_counter()
+
+
+def generate(model, prompt, tokenizer, args):
+    # Encode prompt
+    x = mx.array([[tokenizer.bos_id()] + tokenizer.encode(prompt)])
+
+    skip = 0
+    prompt_processing = None
+    tokens = []
+
+    # Generation loop
+    start = time.perf_counter()
+    for token in model.generate(x, args.temp):
+        tokens.append(token)
+
+        if len(tokens) == 1:
+            # Actually perform the computation to measure the prompt processing time
+            mx.eval(token)
+            prompt_processing = time.perf_counter() - start
+
+        if len(tokens) >= args.num_tokens:
+            break
+
+        if (len(tokens) % args.write_every) == 0:
+            mx.eval(tokens)
+            s = tokenizer.decode([t.item() for t in tokens])
+            print(s[skip:], end="", flush=True)
+            skip = len(s)
+
+    mx.eval(tokens)
+    full_gen = time.perf_counter() - start
+
+    s = tokenizer.decode([t.item() for t in tokens])
+    print(s[skip:], end="", flush=True)
+    print()
+    print(f"Prompt processing took: {prompt_processing:.3f} s")
+    print(f"Full generation took: {full_gen:.3f} s")
+
+
+if __name__ == "__main__":
+    parser = build_parser()
+    args = parser.parse_args()
+
+    np.random.seed(args.seed)
+
+    print("Loading tokenizer")
+    tokenizer = SentencePieceProcessor(model_file=args.tokenizer)
+
+    print("Loading pretrained model")
+    model = load_model(args.model)
+
+    # Freeze all layers other than the LoRA linears
+    model.freeze()
+    for l in model.layers[16:32]:
+        l.attention.query_proj = LoRALinear.from_linear(l.attention.query_proj)
+        l.attention.value_proj = LoRALinear.from_linear(l.attention.value_proj)
+
+    p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6
+    print(f"Total parameters {p:.3f}M")
+    p = sum(v.size for _, v in tree_flatten(model.trainable_parameters())) / 10**6
+    print(f"Trainable parameters {p:.3f}M")
+
+    print("Loading datasets")
+    train_set, valid_set, test_set = wikisql.load()
+
+    if args.train:
+        print("Training")
+        opt = optim.Adam(learning_rate=args.learning_rate)
+
+        # Train model
+        train(model, train_set, valid_set, opt, loss, tokenizer, args)
+
+        # Save adapter weights
+        mx.savez(args.adapter_file, **dict(tree_flatten(model.trainable_parameters())))
+
+    # Load the LoRA adapter weights, which should exist by this point
+    model.load_weights(args.adapter_file)
+
+    if args.test:
+        print("Testing")
+
+        test_loss = evaluate(
+            model,
+            test_set,
+            loss,
+            tokenizer,
+            args.batch_size,
+            num_batches=args.test_batches,
+        )
+        test_ppl = math.exp(test_loss)
+
+        print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
+
+    if args.prompt is not None:
+        print("Generating")
+
+        generate(model, args.prompt, tokenizer, args)
diff --git a/lora/requirements.txt b/lora/requirements.txt
new file mode 100644
index 00000000..c036fa59
--- /dev/null
+++ b/lora/requirements.txt
@@ -0,0 +1,4 @@
+mlx
+numpy
+sentencepiece
+torch
diff --git a/lora/wikisql.py b/lora/wikisql.py
new file mode 100644
index 00000000..a6989ed0
--- /dev/null
+++ b/lora/wikisql.py
@@ -0,0 +1,101 @@
+"""
+Code to preprocess the WikiSQL dataset adapted from
+https://github.com/salesforce/WikiSQL and
+https://huggingface.co/sqllama/sqllama-V0/blob/main/wikisql.ipynb .
+"""
+
+
+import json
+import os
+
+
+def load():
+    """
+    Load all three splits of the WikiSQL dataset.
+    """
+    return (WikiSQL(dn) for dn in ["train", "dev", "test"])
+
+
+class WikiSQL:
+    def __init__(self, dataset, save_dir="/tmp"):
+        valid_sets = ("train", "dev", "test")
+        if dataset not in valid_sets:
+            raise ValueError(f"Dataset must be in {valid_sets}, got {dataset}")
+        data_dir = os.path.join(save_dir, "wikisql")
+        self._maybe_download(data_dir)
+
+        self._parse_tables(os.path.join(data_dir, f"data/{dataset}.tables.jsonl"))
+        self._parse_queries(os.path.join(data_dir, f"data/{dataset}.jsonl"))
+
+    def _maybe_download(self, data_dir):
+        if not os.path.exists(data_dir):
+            import io
+            from urllib import request
+            import tarfile
+
+            url = "https://raw.githubusercontent.com/salesforce/WikiSQL/master/data.tar.bz2"
+            r = request.urlopen(url)
+            with tarfile.open(fileobj=io.BytesIO(r.read())) as tf:
+                tf.extractall(data_dir)
+
+    def _parse_tables(self, tables):
+        self._tables = {}
+        with open(tables) as f:
+            for line in f:
+                table = json.loads(line)
+                self._tables[table["id"]] = {
+                    "columns": table["header"],
+                    "types": table["types"],
+                    "desc": f"table: {table['id']}\ncolumns: {', '.join(table['header'])}",
+                }
+
+    def _parse_queries(self, queries):
+        self._queries = []
+        with open(queries) as f:
+            for line in f:
+                query = json.loads(line)
+                table = self._tables[query["table_id"]]
+                question = query["question"]
+                answer = self.query_to_text(
+                    query["sql"], query["table_id"], table["columns"], table["types"]
+                )
+                self._queries.append(
+                    f"{table['desc']}\nQ: {question}\nA: {answer}"
+                )
+
+    def query_to_text(self, query, table, columns, types):
+        aggregation_ops = ["", "MAX", "MIN", "COUNT", "SUM", "AVG"]
+        condition_ops = ["=", ">", "<", "OP"]
+        column = columns[query["sel"]]
+        aggregation = (aggregation_ops[query["agg"]] + " ") if query["agg"] > 0 else ""
+        sql = f"SELECT {aggregation}{column} FROM {table}"
+
+        conditions = query["conds"]
+        if conditions:
+            cs = []
+            for i, o, v in conditions:
+                column = columns[i]
+                op = condition_ops[o]
+
+                if types[i] == "text":
+                    value = f"'{v}'"
+                else:
+                    value = v
+                cs.append(f"{column} {op} {value}")
+
+            sql += " WHERE " + " AND ".join(cs)
+
+        return sql
+
+    def __getitem__(self, idx):
+        return self._queries[idx]
+
+    def __len__(self):
+        return len(self._queries)
+
+
+if __name__ == "__main__":
+    datanames = ["train", "dev", "test"]
+    sizes = [56355, 8421, 15878]
+    for dataname, size in zip(datanames, sizes):
+        assert len(WikiSQL(dataname)) == size, f"Wrong {dataname} set size."
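To make the preprocessing concrete, here is roughly what `query_to_text` produces for a record shaped like the prompt in the README. The structured query below is made up for illustration (it is not an actual WikiSQL entry), and constructing `WikiSQL("dev")` downloads the dataset to `/tmp/wikisql` on first use.

```
from wikisql import WikiSQL

# Hypothetical WikiSQL-style record, modeled on the README prompt.
columns = ["Player", "No.", "Nationality", "Position",
           "Years in Toronto", "School/Club Team"]
types = ["text", "text", "text", "text", "text", "text"]
sql_struct = {"sel": 2, "agg": 0, "conds": [[0, 0, "Terrence Ross"]]}

dev = WikiSQL("dev")  # downloads and parses the dev split on first use
print(dev.query_to_text(sql_struct, "1-10015132-16", columns, types))
# SELECT Nationality FROM 1-10015132-16 WHERE Player = 'Terrence Ross'
```

Each training sample in `WikiSQL._queries` is then the table description, the natural language question, and this SQL string joined in the `table: ... columns: ... Q: ... A: ...` format that the README prompt follows.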