generalize lora finetuning for llama and mistral

Awni Hannun
2023-12-09 14:13:55 -08:00
parent 46c6bbe0a1
commit b8332a1e66
5 changed files with 354 additions and 293 deletions


@@ -1,28 +1,28 @@
 # Copyright © 2023 Apple Inc.
 import argparse
 import json
 import math
 import numpy as np
 from pathlib import Path
 from sentencepiece import SentencePieceProcessor
 import time
 from typing import Optional, Tuple, List
 import mlx.core as mx
 import mlx.nn as nn
 import mlx.optimizers as optim
-from mlx.utils import tree_flatten
-from llama import LoRALinear, load_model
+from mlx.utils import tree_map, tree_flatten, tree_unflatten
+from models import ModelArgs, Model, LoRALinear
 import wikisql
 def build_parser():
-    parser = argparse.ArgumentParser(description="Llama LoRA finetuning")
+    parser = argparse.ArgumentParser(description="LoRA finetuning with Llama or Mistral")
     parser.add_argument(
-        "--model", required=True, help="The model file containing MLX weights"
-    )
-    parser.add_argument(
-        "--tokenizer", required=True, help="The sentencepiece tokenizer"
+        "--model", required=True, help="A path to the model files containing the tokenizer, weights, config."
     )
     # Generation args
     parser.add_argument(
@@ -73,6 +73,12 @@ def build_parser():
         default=200,
         help="Number of training steps between validations.",
     )
+    parser.add_argument(
+        "--resume_adapter_file",
+        type=str,
+        default=None,
+        help="Load path to resume training with the given adapter weights.",
+    )
     parser.add_argument(
         "--adapter_file",
         type=str,
@@ -94,9 +100,30 @@ def build_parser():
     return parser
+class Tokenizer:
+    def __init__(self, model_path: str):
+        assert Path(model_path).exists(), model_path
+        self._model = SentencePieceProcessor(model_file=model_path)
+        self._sep = "▁"
+        assert self._model.vocab_size() == self._model.get_piece_size()
+    def encode(self, s: str) -> List[int]:
+        return [self._model.bos_id(), *self._model.encode(s)]
+    def decode(self, t: List[int]) -> str:
+        out = self._model.decode(t)
+        if t and self._model.id_to_piece(t[0])[0] == self._sep:
+            return " " + out
+        return out
+    @property
+    def vocab_size(self) -> int:
+        return self._model.vocab_size()
 def loss(model, inputs, targets, lengths):
     # Run model on inputs
-    logits = model(inputs)
+    logits, _ = model(inputs)
     # Mask padding tokens
     length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]
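
The hunk stops just after the padding mask is built. For reference, a minimal sketch of how such a masked cross-entropy is typically finished in MLX (the masked_loss name is illustrative; this is not necessarily the exact code in the file):

    import mlx.core as mx
    import mlx.nn as nn

    def masked_loss(model, inputs, targets, lengths):
        # Forward pass; the generalized models return (logits, cache).
        logits, _ = model(inputs)

        # True at real token positions, False at padding positions.
        length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]

        # Per-token cross entropy, zeroed on padding, averaged over real tokens.
        ce = nn.losses.cross_entropy(logits, targets) * length_mask
        ntoks = length_mask.sum()
        return ce.sum() / ntoks, ntoks
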
@@ -117,7 +144,7 @@ def iterate_batches(dset, tokenizer, batch_size, shuffle=True):
     # Collect batches from dataset
     for i in range(0, len(indices) - batch_size + 1, batch_size):
         # Encode batch
-        batch = tokenizer.encode([dset[indices[i + j]] for j in range(batch_size)])
+        batch = [tokenizer.encode(dset[indices[i + j]]) for j in range(batch_size)]
         lengths = [len(x) for x in batch]
         # Pad to the max length
@@ -195,40 +222,55 @@ def train(model, train_set, val_set, optimizer, loss, tokenizer, args):
 def generate(model, prompt, tokenizer, args):
-    # Encode prompt
-    x = mx.array([[tokenizer.bos_id()] + tokenizer.encode(prompt)])
+    print(args.prompt, end="", flush=True)
+    prompt = mx.array(tokenizer.encode(args.prompt))
+    def generate_step():
+        temp = args.temp
+        def sample(logits):
+            if temp == 0:
+                return mx.argmax(logits, axis=-1)
+            else:
+                return mx.random.categorical(logits * (1 / temp))
+        logits, cache = model(prompt[None])
+        y = sample(logits[:, -1, :])
+        yield y
+        while True:
+            logits, cache = model(y[:, None], cache)
+            y = sample(logits.squeeze(1))
+            yield y
-    skip = 0
-    prompt_processing = None
     tokens = []
-    # Generation loop
-    start = time.perf_counter()
-    for token in model.generate(x, args.temp):
+    for token, _ in zip(generate_step(), range(args.num_tokens)):
         tokens.append(token)
-        if len(tokens) == 1:
-            # Actually perform the computation to measure the prompt processing time
-            mx.eval(token)
-            prompt_processing = time.perf_counter() - start
-        if len(tokens) >= args.num_tokens:
-            break
-        if (len(tokens) % args.write_every) == 0:
+        if (len(tokens) % 10) == 0:
             mx.eval(tokens)
             s = tokenizer.decode([t.item() for t in tokens])
-            print(s[skip:], end="", flush=True)
-            skip = len(s)
+            print(s, end="", flush=True)
+            tokens = []
     mx.eval(tokens)
-    full_gen = time.perf_counter() - start
     s = tokenizer.decode([t.item() for t in tokens])
-    print(s[skip:], end="", flush=True)
-    print()
-    print(f"Prompt processing took: {prompt_processing:.3f} s")
-    print(f"Full generation took: {full_gen:.3f} s")
+    print(s, flush=True)
+def load_model(folder: str, dtype=mx.float32):
+    model_path = Path(folder)
+    tokenizer = Tokenizer(str(model_path / "tokenizer.model"))
+    with open(model_path / "params.json", "r") as f:
+        config = json.loads(f.read())
+        if config.get("vocab_size", -1) < 0:
+            config["vocab_size"] = tokenizer.vocab_size
+        model_args = ModelArgs(**config)
+    weights = mx.load(str(model_path / "weights.npz"))
+    weights = tree_unflatten(list(weights.items()))
+    weights = tree_map(lambda p: p.astype(dtype), weights)
+    model = Model(model_args)
+    model.update(weights)
+    return model, tokenizer
if __name__ == "__main__":
@@ -237,17 +279,14 @@ if __name__ == "__main__":
     np.random.seed(args.seed)
-    print("Loading tokenizer")
-    tokenizer = SentencePieceProcessor(model_file=args.tokenizer)
     print("Loading pretrained model")
-    model = load_model(args.model)
+    model, tokenizer = load_model(args.model)
     # Freeze all layers other than LORA linears
     model.freeze()
     for l in model.layers[16:32]:
-        l.attention.query_proj = LoRALinear.from_linear(l.attention.query_proj)
-        l.attention.value_proj = LoRALinear.from_linear(l.attention.value_proj)
+        l.attention.wq = LoRALinear.from_linear(l.attention.wq)
+        l.attention.wv = LoRALinear.from_linear(l.attention.wv)
     p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6
     print(f"Total parameters {p:.3f}M")
@@ -257,6 +296,11 @@ if __name__ == "__main__":
print("Loading datasets")
train_set, valid_set, test_set = wikisql.load()
# Resume training the given adapters.
if args.resume_adapter_file is not None:
print(f"Loading pretrained adapters from {args.resume_adapter_file}")
model.load_weights(args.resume_adapter_file)
if args.train:
print("Training")
opt = optim.Adam(learning_rate=args.learning_rate)
@@ -287,5 +331,4 @@ if __name__ == "__main__":
     if args.prompt is not None:
         print("Generating")
         generate(model, args.prompt, tokenizer, args)
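
Taken together, the --model argument now points at a single folder holding tokenizer.model, params.json, and weights.npz, and the same code path serves both Llama and Mistral exports. A short sketch of driving the same pieces from Python rather than through the CLI (the folder name, adapter file name, and argument values are placeholders):

    from types import SimpleNamespace

    from lora import LoRALinear, generate, load_model

    # Placeholder paths and settings, not values from this commit.
    args = SimpleNamespace(prompt="table: ...", num_tokens=100, temp=0.8)

    model, tokenizer = load_model("mistral-7B-v0.1")

    # Re-apply the LoRA wrappers before loading trained adapter weights.
    model.freeze()
    for l in model.layers[16:32]:
        l.attention.wq = LoRALinear.from_linear(l.attention.wq)
        l.attention.wv = LoRALinear.from_linear(l.attention.wv)
    model.load_weights("adapters.npz")

    generate(model, args.prompt, tokenizer, args)
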