Move lora example to use the same model format / conversion as hf_llm (#252)

* use the Hugging Face format in the lora example to allow more models

* fixes

* comments

* more readme nits

* fusion + works better for qlora

* nits

* comments
Author: Awni Hannun
Date: 2024-01-09 11:14:52 -08:00
Committed by: GitHub
Parent: bbd7172eef
Commit: 7b258f33ac
10 changed files with 521 additions and 224 deletions


@@ -5,15 +5,13 @@ import json
import math
import time
from pathlib import Path
from typing import List
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import models
import numpy as np
from mlx.utils import tree_flatten, tree_unflatten
from models import LoRALinear, Model, ModelArgs
from sentencepiece import SentencePieceProcessor
def build_parser():
@@ -21,7 +19,7 @@ def build_parser():
parser.add_argument(
"--model",
default="mlx_model",
help="A path to the model files containing the tokenizer, weights, config.",
help="The path to the local model directory or Hugging Face repo.",
)
# Generation args
parser.add_argument(
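Note: with the updated help text, --model may name either a local directory or a Hugging Face repo. A minimal sketch of how such a loader typically resolves the argument (an assumption for illustration, not the actual models.load code; huggingface_hub is used for the download):

# Hypothetical sketch: prefer a local directory, otherwise pull the repo
# from the Hugging Face Hub.
from pathlib import Path

from huggingface_hub import snapshot_download

def resolve_model_path(path_or_repo: str) -> Path:
    path = Path(path_or_repo)
    if not path.exists():
        # Fetch config, tokenizer, and weight files for the named repo.
        path = Path(
            snapshot_download(
                repo_id=path_or_repo,
                allow_patterns=["*.json", "*.safetensors", "tokenizer.model"],
            )
        )
    return path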
@@ -111,34 +109,6 @@ def build_parser():
return parser
class Tokenizer:
def __init__(self, model_path: str):
assert Path(model_path).exists(), model_path
self._model = SentencePieceProcessor(model_file=model_path)
self._sep = ""
assert self._model.vocab_size() == self._model.get_piece_size()
def encode(self, s: str, eos: bool = False) -> List[int]:
toks = [self._model.bos_id(), *self._model.encode(s)]
if eos:
toks.append(self.eos_id)
return toks
@property
def eos_id(self) -> int:
return self._model.eos_id()
def decode(self, t: List[int]) -> str:
out = self._model.decode(t)
if t and self._model.id_to_piece(t[0])[0] == self._sep:
return " " + out
return out
@property
def vocab_size(self) -> int:
return self._model.vocab_size()
class Dataset:
"""
Light-weight wrapper to hold lines from a jsonl file
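Note: the custom SentencePiece Tokenizer above is dropped because models.load now returns the tokenizer together with the model. Assuming that tokenizer is a standard Hugging Face one (an assumption; transformers.AutoTokenizer and the repo name below are illustrative), the old methods map roughly as follows:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")  # placeholder repo
ids = tokenizer.encode("hello world")   # replaces Tokenizer.encode
text = tokenizer.decode(ids)            # replaces Tokenizer.decode
eos = tokenizer.eos_token_id            # replaces the eos_id property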
@@ -295,56 +265,27 @@ def train(model, train_set, val_set, optimizer, loss, tokenizer, args):
def generate(model, prompt, tokenizer, args):
print(args.prompt, end="", flush=True)
prompt = mx.array(tokenizer.encode(args.prompt))
def generate_step():
temp = args.temp
def sample(logits):
if temp == 0:
return mx.argmax(logits, axis=-1)
else:
return mx.random.categorical(logits * (1 / temp))
logits, cache = model(prompt[None])
y = sample(logits[:, -1, :])
yield y
while True:
logits, cache = model(y[:, None], cache)
y = sample(logits.squeeze(1))
yield y
prompt = tokenizer.encode(args.prompt)
tokens = []
for token, _ in zip(generate_step(), range(args.num_tokens)):
tokens.append(token)
skip = 0
for token, n in zip(
models.generate(prompt, model, args.temp),
range(args.max_tokens),
):
if token == tokenizer.eos_token_id:
break
if (len(tokens) % 10) == 0:
mx.eval(tokens)
s = tokenizer.decode([t.item() for t in tokens])
print(s, end="", flush=True)
tokens = []
mx.eval(tokens)
s = tokenizer.decode([t.item() for t in tokens])
print(s, flush=True)
def load_model(folder: str):
model_path = Path(folder)
tokenizer = Tokenizer(str(model_path / "tokenizer.model"))
with open(model_path / "config.json", "r") as f:
config = json.loads(f.read())
quantization = config.pop("quantization", None)
model_args = ModelArgs(**config)
model = Model(model_args)
if quantization is not None:
nn.QuantizedLinear.quantize_module(model, **quantization)
weights = mx.load(str(model_path / "weights.npz"))
weights = tree_unflatten(list(weights.items()))
model.update(weights)
return model, tokenizer
tokens.append(token.item())
s = tokenizer.decode(tokens)
print(s[skip:], end="", flush=True)
skip = len(s)
print(tokenizer.decode(tokens)[skip:], flush=True)
print("=" * 10)
if len(tokens) == 0:
print("No tokens generated for this prompt")
return
if __name__ == "__main__":
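Note on the new printing loop above: instead of flushing every ten tokens, the full token list is re-decoded each step and only the suffix past skip is printed, so tokens that merge into earlier text still render correctly. A minimal standalone sketch of the pattern (token_stream, tokenizer, and max_tokens are stand-ins):

# Incremental decode-and-print with a skip offset.
def stream_print(token_stream, tokenizer, max_tokens):
    tokens, skip = [], 0
    for token, _ in zip(token_stream, range(max_tokens)):
        tokens.append(token)
        s = tokenizer.decode(tokens)        # decode the full sequence so far
        print(s[skip:], end="", flush=True) # print only the newly produced text
        skip = len(s)
    print(tokenizer.decode(tokens)[skip:], flush=True)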
@@ -354,13 +295,13 @@ if __name__ == "__main__":
np.random.seed(args.seed)
print("Loading pretrained model")
model, tokenizer = load_model(args.model)
model, tokenizer, _ = models.load(args.model)
# Freeze all layers other than LORA linears
model.freeze()
for l in model.layers[-args.lora_layers :]:
l.attention.wq = LoRALinear.from_linear(l.attention.wq)
l.attention.wv = LoRALinear.from_linear(l.attention.wv)
for l in model.model.layers[len(model.model.layers) - args.lora_layers :]:
l.self_attn.q_proj = models.LoRALinear.from_linear(l.self_attn.q_proj)
l.self_attn.v_proj = models.LoRALinear.from_linear(l.self_attn.v_proj)
p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6
print(f"Total parameters {p:.3f}M")