mirror of https://github.com/ml-explore/mlx-examples.git
generalize lora finetuning for llama and mistral
lora/lora.py (125 changed lines)
@@ -1,28 +1,28 @@
 # Copyright © 2023 Apple Inc.
 
 import argparse
 import json
 import math
 import numpy as np
 from pathlib import Path
 from sentencepiece import SentencePieceProcessor
 import time
+from typing import Optional, Tuple, List
 
 import mlx.core as mx
 import mlx.nn as nn
 import mlx.optimizers as optim
-from mlx.utils import tree_flatten
+from mlx.utils import tree_map, tree_flatten, tree_unflatten
 
-from llama import LoRALinear, load_model
+from models import ModelArgs, Model, LoRALinear
 import wikisql
 
 
 def build_parser():
-    parser = argparse.ArgumentParser(description="Llama LoRA finetuning")
+    parser = argparse.ArgumentParser(description="LoRA finetuning with Llama or Mistral")
     parser.add_argument(
-        "--model", required=True, help="The model file containing MLX weights"
-    )
-    parser.add_argument(
-        "--tokenizer", required=True, help="The sentencepiece tokenizer"
+        "--model", required=True, help="A path to the model files containing the tokenizer, weights, config."
     )
     # Generation args
     parser.add_argument(
@@ -73,6 +73,12 @@ def build_parser():
         default=200,
         help="Number of training steps between validations.",
     )
+    parser.add_argument(
+        "--resume_adapter_file",
+        type=str,
+        default=None,
+        help="Load path to resume training with the given adapter weights.",
+    )
     parser.add_argument(
         "--adapter_file",
         type=str,
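The new --resume_adapter_file flag lets a run continue from previously saved adapter weights instead of starting from scratch. A hedged sketch of how the new arguments parse (only flags visible in this diff are used; the values are hypothetical):

```python
# Illustrative only: "mistral-7B-v0.1" and "adapters.npz" are hypothetical values.
args = build_parser().parse_args(
    ["--model", "mistral-7B-v0.1", "--resume_adapter_file", "adapters.npz"]
)
print(args.model, args.resume_adapter_file)
```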
@@ -94,9 +100,30 @@ def build_parser():
     return parser
 
 
+class Tokenizer:
+    def __init__(self, model_path: str):
+        assert Path(model_path).exists(), model_path
+        self._model = SentencePieceProcessor(model_file=model_path)
+        self._sep = "▁"
+        assert self._model.vocab_size() == self._model.get_piece_size()
+
+    def encode(self, s: str) -> List[int]:
+        return [self._model.bos_id(), *self._model.encode(s)]
+
+    def decode(self, t: List[int]) -> str:
+        out = self._model.decode(t)
+        if t and self._model.id_to_piece(t[0])[0] == self._sep:
+            return " " + out
+        return out
+
+    @property
+    def vocab_size(self) -> int:
+        return self._model.vocab_size()
+
+
 def loss(model, inputs, targets, lengths):
     # Run model on inputs
-    logits = model(inputs)
+    logits, _ = model(inputs)
 
     # Mask padding tokens
     length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]
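Note on the new Tokenizer class: SentencePiece marks word-initial pieces with "▁" and drops that marker when decoding from the start of a sequence, which is why decode() re-inserts a leading space when the first piece begins with "▁" (this keeps streamed chunks concatenating cleanly during generation). A minimal usage sketch, assuming a hypothetical tokenizer path:

```python
# Hedged sketch, not part of the diff; the tokenizer path is hypothetical.
tok = Tokenizer("mistral-7B-v0.1/tokenizer.model")
ids = tok.encode("hello world")      # encode() prepends the BOS id
print(tok.decode(ids[1:]))           # " hello world" -- leading space restored
```

The one-line change in loss() follows from the new models.py interface: the model now returns a (logits, cache) pair rather than logits alone.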
@@ -117,7 +144,7 @@ def iterate_batches(dset, tokenizer, batch_size, shuffle=True):
     # Collect batches from dataset
     for i in range(0, len(indices) - batch_size + 1, batch_size):
         # Encode batch
-        batch = tokenizer.encode([dset[indices[i + j]] for j in range(batch_size)])
+        batch = [tokenizer.encode(dset[indices[i + j]]) for j in range(batch_size)]
         lengths = [len(x) for x in batch]
 
         # Pad to the max length
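The new Tokenizer encodes one string at a time, hence the list comprehension replacing the single batched encode call. A hedged sketch of the pad-and-mask pattern the surrounding code relies on (names mirror loss() and iterate_batches, but this block is illustrative, not the diff's code):

```python
import mlx.core as mx

batch = [[1, 5, 9], [1, 7], [1, 2, 3, 4]]               # token ids per example
lengths = [len(x) for x in batch]
max_len = max(lengths)
padded = [x + [0] * (max_len - len(x)) for x in batch]  # right-pad with zeros
inputs = mx.array(padded)
# True only at real (non-pad) positions; this is the mask loss() builds
length_mask = mx.arange(max_len)[None, :] < mx.array(lengths)[:, None]
```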
@@ -195,40 +222,55 @@ def train(model, train_set, val_set, optimizer, loss, tokenizer, args):
 
 
 def generate(model, prompt, tokenizer, args):
-    # Encode prompt
-    x = mx.array([[tokenizer.bos_id()] + tokenizer.encode(prompt)])
+    print(args.prompt, end="", flush=True)
+    prompt = mx.array(tokenizer.encode(args.prompt))
 
-    skip = 0
-    prompt_processing = None
-    tokens = []
+    def generate_step():
+        temp = args.temp
 
-    # Genertation loop
-    start = time.perf_counter()
-    for token in model.generate(x, args.temp):
-        tokens.append(token)
+        def sample(logits):
+            if temp == 0:
+                return mx.argmax(logits, axis=-1)
+            else:
+                return mx.random.categorical(logits * (1 / temp))
 
-        if len(tokens) == 1:
-            # Actually perform the computation to measure the prompt processing time
-            mx.eval(token)
-            prompt_processing = time.perf_counter() - start
+        logits, cache = model(prompt[None])
+        y = sample(logits[:, -1, :])
+        yield y
 
-        if len(tokens) >= args.num_tokens:
-            break
+        while True:
+            logits, cache = model(y[:, None], cache)
+            y = sample(logits.squeeze(1))
+            yield y
 
-        if (len(tokens) % args.write_every) == 0:
+    tokens = []
+    for token, _ in zip(generate_step(), range(args.num_tokens)):
+        tokens.append(token)
+
+        if (len(tokens) % 10) == 0:
             mx.eval(tokens)
             s = tokenizer.decode([t.item() for t in tokens])
-            print(s[skip:], end="", flush=True)
-            skip = len(s)
+            print(s, end="", flush=True)
+            tokens = []
 
     mx.eval(tokens)
-    full_gen = time.perf_counter() - start
-
     s = tokenizer.decode([t.item() for t in tokens])
-    print(s[skip:], end="", flush=True)
-    print()
-    print(f"Prompt processing took: {prompt_processing:.3f} s")
-    print(f"Full generation took: {full_gen:.3f} s")
+    print(s, flush=True)
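The rewritten generate streams tokens from a local generator: the first model call processes the whole prompt and returns a key-value cache, and each later call feeds back only the sampled token plus that cache. The sampling rule is the standard temperature scheme; a self-contained sketch of just that piece (illustrative, not new diff code):

```python
import mlx.core as mx

def sample(logits: mx.array, temp: float) -> mx.array:
    if temp == 0:
        return mx.argmax(logits, axis=-1)               # greedy decoding
    return mx.random.categorical(logits * (1 / temp))   # flatter distribution as temp grows

logits = mx.array([[2.0, 1.0, 0.5]])
print(sample(logits, 0.0).item())   # always 0, the argmax
print(sample(logits, 0.8).item())   # stochastic, biased toward index 0
```

Note the simpler bookkeeping: instead of the skip-based incremental printing and the prompt-processing timers of the old version, the new loop evaluates and prints every 10 tokens and then resets the token buffer.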
+
+
+def load_model(folder: str, dtype=mx.float32):
+    model_path = Path(folder)
+    tokenizer = Tokenizer(str(model_path / "tokenizer.model"))
+    with open(model_path / "params.json", "r") as f:
+        config = json.loads(f.read())
+        if config.get("vocab_size", -1) < 0:
+            config["vocab_size"] = tokenizer.vocab_size
+        model_args = ModelArgs(**config)
+    weights = mx.load(str(model_path / "weights.npz"))
+    weights = tree_unflatten(list(weights.items()))
+    weights = tree_map(lambda p: p.astype(dtype), weights)
+    model = Model(model_args)
+    model.update(weights)
+    return model, tokenizer
 
 
 if __name__ == "__main__":
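In the added load_model, the vocab_size fallback must run before ModelArgs(**config) is constructed, so the check is ordered first above. The function returns the model and tokenizer together, which is what lets the CLI drop --tokenizer: the folder is expected to hold tokenizer.model, params.json, and weights.npz side by side. A hedged usage sketch (the folder name is hypothetical):

```python
import mlx.core as mx

model, tokenizer = load_model("mistral-7B-v0.1")
ids = tokenizer.encode("SELECT count(*) FROM")   # BOS id is prepended by encode()
logits, _ = model(mx.array(ids)[None])           # models return (logits, cache)
print(logits.shape)                              # (1, len(ids), vocab_size)
```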
@@ -237,17 +279,14 @@ if __name__ == "__main__":
 
     np.random.seed(args.seed)
 
-    print("Loading tokenizer")
-    tokenizer = SentencePieceProcessor(model_file=args.tokenizer)
-
     print("Loading pretrained model")
-    model = load_model(args.model)
+    model, tokenizer = load_model(args.model)
 
     # Freeze all layers other than LORA linears
     model.freeze()
     for l in model.layers[16:32]:
-        l.attention.query_proj = LoRALinear.from_linear(l.attention.query_proj)
-        l.attention.value_proj = LoRALinear.from_linear(l.attention.value_proj)
+        l.attention.wq = LoRALinear.from_linear(l.attention.wq)
+        l.attention.wv = LoRALinear.from_linear(l.attention.wv)
 
     p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6
     print(f"Total parameters {p:.3f}M")
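Only the query and value projections of the upper 16 transformer layers are wrapped in LoRA adapters (renamed wq/wv to match the generalized models.py attention module); everything else stays frozen, so the trainable fraction of the parameter count printed above is tiny. LoRALinear itself lives in models.py and is not shown in this diff; the sketch below is a minimal illustration of the technique under that assumption, not the repository's implementation:

```python
import math
import mlx.core as mx
import mlx.nn as nn

class LoRALinearSketch(nn.Module):
    """y = linear(x) + (x @ A) @ B, with A and B low-rank and the base weight frozen."""

    def __init__(self, linear: nn.Linear, rank: int = 8):
        super().__init__()
        out_dims, in_dims = linear.weight.shape
        self.linear = linear                                   # frozen base projection
        scale = 1 / math.sqrt(in_dims)
        self.lora_a = mx.random.uniform(-scale, scale, (in_dims, rank))
        self.lora_b = mx.zeros((rank, out_dims))               # zero init: no-op at start

    def __call__(self, x):
        return self.linear(x) + (x @ self.lora_a) @ self.lora_b
```

Because lora_b starts at zero, wrapping a layer leaves the model's outputs unchanged until training updates the adapters.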
@@ -257,6 +296,11 @@ if __name__ == "__main__":
     print("Loading datasets")
     train_set, valid_set, test_set = wikisql.load()
 
+    # Resume training the given adapters.
+    if args.resume_adapter_file is not None:
+        print(f"Loading pretrained adapters from {args.resume_adapter_file}")
+        model.load_weights(args.resume_adapter_file)
+
     if args.train:
         print("Training")
         opt = optim.Adam(learning_rate=args.learning_rate)
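Resuming works because the adapter file holds just the flat set of trainable parameters. A hedged sketch of the round trip (the file name is hypothetical; it assumes the LoRA weights are the only unfrozen parameters, as arranged above):

```python
import mlx.core as mx
from mlx.utils import tree_flatten

# Save only the trainable (LoRA) parameters as a flat .npz...
adapters = dict(tree_flatten(model.trainable_parameters()))
mx.savez("adapters.npz", **adapters)

# ...and load them back into a freshly loaded, frozen model later.
model.load_weights("adapters.npz")
```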
@@ -287,5 +331,4 @@ if __name__ == "__main__":
 
     if args.prompt is not None:
         print("Generating")
-
         generate(model, args.prompt, tokenizer, args)