Mirror of https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 21:01:32 +08:00
Move lora example to use the same model format / conversion as hf_llm (#252)

* Hugging Face-ify the lora example to allow more models
* fixes
* comments
* more readme nits
* fusion + works better for qlora
* nits
* comments
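Taken together, this delegates model loading, tokenization, and generation to the models module shared with the hf_llm example. A rough sketch of the resulting setup, assembled only from calls visible in the diff below ("mlx_model" is the argument's default; the layer count of 16 is a placeholder):

import models

# models.load accepts a local model directory or a Hugging Face repo and
# returns the model, its tokenizer, and the config (ignored here).
model, tokenizer, _ = models.load("mlx_model")

# Freeze everything, then swap LoRA adapters into the attention
# projections of the last N transformer blocks, as main() now does.
model.freeze()
for layer in model.model.layers[-16:]:
    layer.self_attn.q_proj = models.LoRALinear.from_linear(layer.self_attn.q_proj)
    layer.self_attn.v_proj = models.LoRALinear.from_linear(layer.self_attn.v_proj)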
lora/lora.py (105 changed lines)
@@ -5,15 +5,13 @@ import json
 import math
 import time
 from pathlib import Path
-from typing import List
 
 import mlx.core as mx
 import mlx.nn as nn
 import mlx.optimizers as optim
+import models
 import numpy as np
 from mlx.utils import tree_flatten, tree_unflatten
-from models import LoRALinear, Model, ModelArgs
-from sentencepiece import SentencePieceProcessor
 
 
 def build_parser():
@@ -21,7 +19,7 @@ def build_parser():
     parser.add_argument(
         "--model",
         default="mlx_model",
-        help="A path to the model files containing the tokenizer, weights, config.",
+        help="The path to the local model directory or Hugging Face repo.",
     )
     # Generation args
     parser.add_argument(
@@ -111,34 +109,6 @@ def build_parser():
     return parser
 
 
-class Tokenizer:
-    def __init__(self, model_path: str):
-        assert Path(model_path).exists(), model_path
-        self._model = SentencePieceProcessor(model_file=model_path)
-        self._sep = "▁"
-        assert self._model.vocab_size() == self._model.get_piece_size()
-
-    def encode(self, s: str, eos: bool = False) -> List[int]:
-        toks = [self._model.bos_id(), *self._model.encode(s)]
-        if eos:
-            toks.append(self.eos_id)
-        return toks
-
-    @property
-    def eos_id(self) -> int:
-        return self._model.eos_id()
-
-    def decode(self, t: List[int]) -> str:
-        out = self._model.decode(t)
-        if t and self._model.id_to_piece(t[0])[0] == self._sep:
-            return " " + out
-        return out
-
-    @property
-    def vocab_size(self) -> int:
-        return self._model.vocab_size()
-
-
 class Dataset:
     """
     Light-weight wrapper to hold lines from a jsonl file
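The deleted wrapper has no in-file replacement: models.load (see the final hunk) returns a ready-made tokenizer exposing encode, decode, and eos_token_id. One detail of the removed code worth noting is the "▁" handling: SentencePiece marks word boundaries with that character, and decode re-attached the leading space so that separately decoded chunks concatenate cleanly. A toy illustration of that behavior with a stubbed piece table (a hypothetical stand-in, no sentencepiece dependency):

SEP = "\u2581"  # "▁", SentencePiece's word-boundary marker

# Tiny stand-in for SentencePieceProcessor's piece table.
pieces = {1: SEP + "Hello", 2: SEP + "world"}

def decode(ids):
    out = "".join(pieces[i] for i in ids).replace(SEP, " ").lstrip()
    # Restore the leading space when the first piece starts a word,
    # mirroring the removed Tokenizer.decode.
    if ids and pieces[ids[0]].startswith(SEP):
        return " " + out
    return out

print(repr(decode([1, 2])))  # ' Hello world'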
@@ -295,56 +265,27 @@ def train(model, train_set, val_set, optimizer, loss, tokenizer, args):
 
 def generate(model, prompt, tokenizer, args):
     print(args.prompt, end="", flush=True)
-    prompt = mx.array(tokenizer.encode(args.prompt))
 
-    def generate_step():
-        temp = args.temp
-
-        def sample(logits):
-            if temp == 0:
-                return mx.argmax(logits, axis=-1)
-            else:
-                return mx.random.categorical(logits * (1 / temp))
-
-        logits, cache = model(prompt[None])
-        y = sample(logits[:, -1, :])
-        yield y
-
-        while True:
-            logits, cache = model(y[:, None], cache)
-            y = sample(logits.squeeze(1))
-            yield y
+    prompt = tokenizer.encode(args.prompt)
 
     tokens = []
-    for token, _ in zip(generate_step(), range(args.num_tokens)):
-        tokens.append(token)
+    skip = 0
+    for token, n in zip(
+        models.generate(prompt, model, args.temp),
+        range(args.max_tokens),
+    ):
+        if token == tokenizer.eos_token_id:
+            break
 
-        if (len(tokens) % 10) == 0:
-            mx.eval(tokens)
-            s = tokenizer.decode([t.item() for t in tokens])
-            print(s, end="", flush=True)
-            tokens = []
-
-    mx.eval(tokens)
-    s = tokenizer.decode([t.item() for t in tokens])
-    print(s, flush=True)
-
-
-def load_model(folder: str):
-    model_path = Path(folder)
-    tokenizer = Tokenizer(str(model_path / "tokenizer.model"))
-    with open(model_path / "config.json", "r") as f:
-        config = json.loads(f.read())
-        quantization = config.pop("quantization", None)
-        model_args = ModelArgs(**config)
-    model = Model(model_args)
-    if quantization is not None:
-        nn.QuantizedLinear.quantize_module(model, **quantization)
-
-    weights = mx.load(str(model_path / "weights.npz"))
-    weights = tree_unflatten(list(weights.items()))
-    model.update(weights)
-    return model, tokenizer
+        tokens.append(token.item())
+        s = tokenizer.decode(tokens)
+        print(s[skip:], end="", flush=True)
+        skip = len(s)
+    print(tokenizer.decode(tokens)[skip:], flush=True)
+    print("=" * 10)
+    if len(tokens) == 0:
+        print("No tokens generated for this prompt")
+        return
 
 
 if __name__ == "__main__":
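The rewritten loop also changes how text is streamed: the old version flushed every 10 tokens and reset its buffer, which could split output at awkward subword boundaries, while the new one re-decodes the full token list each step and prints only the suffix not yet shown. A self-contained sketch of that pattern with a toy tokenizer (the real one comes from models.load):

class ToyTokenizer:
    # Stand-in tokenizer; only decode matters for this pattern.
    vocab = {0: "Hel", 1: "lo", 2: " wor", 3: "ld", 4: "!"}

    def decode(self, tokens):
        return "".join(self.vocab[t] for t in tokens)

tokenizer = ToyTokenizer()
tokens, skip = [], 0
for tok in [0, 1, 2, 3, 4]:  # stands in for models.generate(...)
    tokens.append(tok)
    s = tokenizer.decode(tokens)  # decode everything generated so far
    print(s[skip:], end="", flush=True)  # print only the unseen suffix
    skip = len(s)
print()  # the loop prints "Hello world!" incrementally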
@@ -354,13 +295,13 @@ if __name__ == "__main__":
     np.random.seed(args.seed)
 
     print("Loading pretrained model")
-    model, tokenizer = load_model(args.model)
+    model, tokenizer, _ = models.load(args.model)
 
     # Freeze all layers other than LORA linears
     model.freeze()
-    for l in model.layers[-args.lora_layers :]:
-        l.attention.wq = LoRALinear.from_linear(l.attention.wq)
-        l.attention.wv = LoRALinear.from_linear(l.attention.wv)
+    for l in model.model.layers[len(model.model.layers) - args.lora_layers :]:
+        l.self_attn.q_proj = models.LoRALinear.from_linear(l.self_attn.q_proj)
+        l.self_attn.v_proj = models.LoRALinear.from_linear(l.self_attn.v_proj)
 
     p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6
     print(f"Total parameters {p:.3f}M")