mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 12:49:50 +08:00
31
lora/lora.py
31
lora/lora.py
@@ -17,12 +17,10 @@ from sentencepiece import SentencePieceProcessor
|
||||
|
||||
|
||||
def build_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="LoRA finetuning with Llama or Mistral"
|
||||
)
|
||||
parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
required=True,
|
||||
default="mlx_model",
|
||||
help="A path to the model files containing the tokenizer, weights, config.",
|
||||
)
|
||||
# Generation args
|
||||
@@ -332,18 +330,22 @@ def generate(model, prompt, tokenizer, args):
|
||||
print(s, flush=True)
|
||||
|
||||
|
||||
def load_model(folder: str, dtype=mx.float16):
|
||||
def load_model(folder: str):
|
||||
model_path = Path(folder)
|
||||
tokenizer = Tokenizer(str(model_path / "tokenizer.model"))
|
||||
with open(model_path / "params.json", "r") as f:
|
||||
with open(model_path / "config.json", "r") as f:
|
||||
config = json.loads(f.read())
|
||||
if config.get("vocab_size", -1) < 0:
|
||||
config["vocab_size"] = tokenizer.vocab_size
|
||||
quantization = config.pop("quantization", None)
|
||||
model_args = ModelArgs(**config)
|
||||
model = Model(model_args)
|
||||
if quantization is not None:
|
||||
quantization["linear_class_predicate"] = lambda m: isinstance(
|
||||
m, nn.Linear
|
||||
) and (m.weight.shape[0] != model_args.vocab_size)
|
||||
nn.QuantizedLinear.quantize_module(model, **quantization)
|
||||
|
||||
weights = mx.load(str(model_path / "weights.npz"))
|
||||
weights = tree_unflatten(list(weights.items()))
|
||||
weights = tree_map(lambda p: p.astype(dtype), weights)
|
||||
model = Model(model_args)
|
||||
model.update(weights)
|
||||
return model, tokenizer
|
||||
|
||||
@@ -374,7 +376,7 @@ if __name__ == "__main__":
|
||||
# Resume training the given adapters.
|
||||
if args.resume_adapter_file is not None:
|
||||
print(f"Loading pretrained adapters from {args.resume_adapter_file}")
|
||||
model.load_weights(args.resume_adapter_file)
|
||||
model.load_weights(args.resume_adapter_file, strict=False)
|
||||
|
||||
if args.train:
|
||||
print("Training")
|
||||
@@ -387,7 +389,12 @@ if __name__ == "__main__":
|
||||
mx.savez(args.adapter_file, **dict(tree_flatten(model.trainable_parameters())))
|
||||
|
||||
# Load the LoRA adapter weights which we assume should exist by this point
|
||||
model.load_weights(args.adapter_file)
|
||||
if not Path(args.adapter_file).is_file():
|
||||
raise ValueError(
|
||||
f"Adapter file {args.adapter_file} missing. "
|
||||
"Use --train to learn and save the adapters.npz."
|
||||
)
|
||||
model.load_weights(args.adapter_file, strict=False)
|
||||
|
||||
if args.test:
|
||||
print("Testing")
|
||||
|
Reference in New Issue
Block a user