# Copyright © 2023 Apple Inc.

import argparse
import math
import time

import numpy as np
from sentencepiece import SentencePieceProcessor

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx.utils import tree_flatten

from llama import LoRALinear, load_model
import wikisql
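

# Example invocations (the paths in angle brackets are illustrative):
#   python lora.py --model <weights.npz> --tokenizer <tokenizer.model> --train
#   python lora.py --model <weights.npz> --tokenizer <tokenizer.model> --test
#   python lora.py --model <weights.npz> --tokenizer <tokenizer.model> \
#       --prompt "<prompt text>"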
def build_parser():
    parser = argparse.ArgumentParser(description="Llama LoRA finetuning")
    parser.add_argument(
        "--model", required=True, help="The model file containing MLX weights"
    )
    parser.add_argument(
        "--tokenizer", required=True, help="The sentencepiece tokenizer"
    )

    # Generation args
    parser.add_argument(
        "--num-tokens", "-n", type=int, default=100, help="How many tokens to generate"
    )
    parser.add_argument(
        "--write-every", type=int, default=1, help="After how many tokens to detokenize"
    )
    parser.add_argument(
        "--temp", type=float, default=0.8, help="The sampling temperature"
    )
    parser.add_argument(
        "--prompt",
        "-p",
        type=str,
        help="The prompt for generation",
        default=None,
    )

    # Training args
    parser.add_argument(
        "--train",
        action="store_true",
        help="Do training",
    )
    parser.add_argument("--batch_size", type=int, default=4, help="Minibatch size.")
    parser.add_argument(
        "--iters", type=int, default=1000, help="Iterations to train for."
    )
    parser.add_argument(
        "--val_batches",
        type=int,
        default=100,
        help="Number of validation batches, -1 uses the entire validation set.",
    )
    parser.add_argument(
        "--learning_rate", type=float, default=1e-5, help="Adam learning rate."
    )
    parser.add_argument(
        "--steps_per_report",
        type=int,
        default=10,
        help="Number of training steps between loss reporting.",
    )
    parser.add_argument(
        "--steps_per_eval",
        type=int,
        default=200,
        help="Number of training steps between validations.",
    )
    parser.add_argument(
        "--adapter_file",
        type=str,
        default="adapters.npz",
        help="Save/load path for the trained adapter weights.",
    )
    parser.add_argument(
        "--test",
        action="store_true",
        help="Evaluate on the test set after training",
    )
    parser.add_argument(
        "--test_batches",
        type=int,
        default=500,
        help="Number of test set batches, -1 uses the entire test set.",
    )
    parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
    return parser
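

# Masked cross-entropy: positions past each sequence's true length are zeroed
# out of the loss, so the mean is taken over real tokens only.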
def loss(model, inputs, targets, lengths):
    # Run model on inputs
    logits = model(inputs)

    # Mask padding tokens
    length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]

    # Calculate the loss
    ce = nn.losses.cross_entropy(logits, targets) * length_mask
    ntoks = length_mask.sum()
    ce = ce.sum() / ntoks
    return ce, ntoks
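

# Each batch is padded to its longest sequence and yielded as a next-token
# prediction pair: for a padded row [a, b, c, d], the model input is [a, b, c]
# and the target is [b, c, d], i.e. the sequence shifted by one token.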
def iterate_batches(dset, tokenizer, batch_size, shuffle=True):
    # Shuffle indices
    indices = np.arange(len(dset))
    if shuffle:
        indices = np.random.permutation(indices)

    # Collect batches from dataset
    for i in range(0, len(indices) - batch_size + 1, batch_size):
        # Encode batch
        batch = tokenizer.encode([dset[indices[i + j]] for j in range(batch_size)])
        lengths = [len(x) for x in batch]

        # Pad to the max length
        batch_arr = np.zeros((batch_size, max(lengths)), np.int32)
        for j in range(batch_size):
            batch_arr[j, : lengths[j]] = batch[j]
        batch = mx.array(batch_arr)
        yield batch[:, :-1], batch[:, 1:], mx.array(lengths)
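

# The validation/test loss is the token-weighted mean cross-entropy over
# `num_batches` batches; math.exp of this value is the perplexity reported in
# the test pass at the bottom of the script.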
def evaluate(model, dataset, loss, tokenizer, batch_size, num_batches):
    all_losses = []
    ntokens = 0
    for it, batch in zip(
        range(num_batches),
        iterate_batches(dataset, tokenizer, batch_size, shuffle=False),
    ):
        losses, toks = loss(model, *batch)
        all_losses.append((losses * toks).item())
        ntokens += toks.item()

    return np.sum(all_losses) / ntokens
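

# MLX builds computation graphs lazily; the mx.eval call in the loop below is
# what actually runs the forward/backward pass and applies the parameter
# update for each step.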
def train(model, train_set, val_set, optimizer, loss, tokenizer, args):
    # Create value and grad function for loss
    loss_value_and_grad = nn.value_and_grad(model, loss)

    losses = []
    n_tokens = 0

    # Main training loop
    start = time.perf_counter()
    for it, batch in zip(
        range(args.iters), iterate_batches(train_set, tokenizer, args.batch_size)
    ):
        # Forward and backward pass
        (lvalue, toks), grad = loss_value_and_grad(model, *batch)

        # Model update
        optimizer.update(model, grad)
        mx.eval(model.parameters(), optimizer.state, lvalue)

        # Record loss
        losses.append(lvalue.item())
        n_tokens += toks.item()

        # Report training loss if needed
        if (it + 1) % args.steps_per_report == 0:
            train_loss = np.mean(losses)

            stop = time.perf_counter()
            print(
                f"Iter {it + 1}: Train loss {train_loss:.3f}, "
                f"It/sec {args.steps_per_report / (stop - start):.3f}, "
                f"Tokens/sec {float(n_tokens) / (stop - start):.3f}"
            )
            losses = []
            n_tokens = 0
            start = time.perf_counter()

        # Report validation loss if needed
        if it == 0 or (it + 1) % args.steps_per_eval == 0:
            stop = time.perf_counter()
            val_loss = evaluate(
                model, val_set, loss, tokenizer, args.batch_size, args.val_batches
            )
            print(
                f"Iter {it + 1}: "
                f"Val loss {val_loss:.3f}, "
                f"Val took {(time.perf_counter() - stop):.3f}s"
            )

            start = time.perf_counter()
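

# Generation streams text as it is produced: `skip` tracks how many characters
# have already been printed, so each flush emits only the newly decoded suffix.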
def generate(model, prompt, tokenizer, args):
    # Encode prompt
    x = mx.array([[tokenizer.bos_id()] + tokenizer.encode(prompt)])

    skip = 0
    prompt_processing = None
    tokens = []

    # Generation loop
    start = time.perf_counter()
    for token in model.generate(x, args.temp):
        tokens.append(token)

        if len(tokens) == 1:
            # Actually perform the computation to measure the prompt processing time
            mx.eval(token)
            prompt_processing = time.perf_counter() - start

        if len(tokens) >= args.num_tokens:
            break

        if (len(tokens) % args.write_every) == 0:
            mx.eval(tokens)
            s = tokenizer.decode([t.item() for t in tokens])
            print(s[skip:], end="", flush=True)
            skip = len(s)

    mx.eval(tokens)
    full_gen = time.perf_counter() - start

    s = tokenizer.decode([t.item() for t in tokens])
    print(s[skip:], end="", flush=True)
    print()
    print(f"Prompt processing took: {prompt_processing:.3f} s")
    print(f"Full generation took: {full_gen:.3f} s")


if __name__ == "__main__":
    parser = build_parser()
    args = parser.parse_args()

    np.random.seed(args.seed)

    print("Loading tokenizer")
    tokenizer = SentencePieceProcessor(model_file=args.tokenizer)

    print("Loading pretrained model")
    model = load_model(args.model)

    # Freeze all layers other than the LoRA linears, then adapt only the query
    # and value projections of layers 16-31 (the second half of a 32-layer Llama)
    model.freeze()
    for l in model.layers[16:32]:
        l.attention.query_proj = LoRALinear.from_linear(l.attention.query_proj)
        l.attention.value_proj = LoRALinear.from_linear(l.attention.value_proj)

    p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6
    print(f"Total parameters {p:.3f}M")
    p = sum(v.size for _, v in tree_flatten(model.trainable_parameters())) / 10**6
    print(f"Trainable parameters {p:.3f}M")

    print("Loading datasets")
    train_set, valid_set, test_set = wikisql.load()

    if args.train:
        print("Training")
        opt = optim.Adam(learning_rate=args.learning_rate)

        # Train model
        train(model, train_set, valid_set, opt, loss, tokenizer, args)

        # Save adapter weights
        mx.savez(args.adapter_file, **dict(tree_flatten(model.trainable_parameters())))

    # Load the LoRA adapter weights which we assume should exist by this point
    model.load_weights(args.adapter_file)

    if args.test:
        print("Testing")

        test_loss = evaluate(
            model,
            test_set,
            loss,
            tokenizer,
            args.batch_size,
            num_batches=args.test_batches,
        )
        test_ppl = math.exp(test_loss)

        print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")

    if args.prompt is not None:
        print("Generating")
        generate(model, args.prompt, tokenizer, args)