# Copyright © 2023 Apple Inc.

"""Fuse trained LoRA adapters into a base model, save it, and optionally upload it."""

import argparse
from pathlib import Path

import mlx.core as mx
import utils
from mlx.utils import tree_flatten, tree_unflatten
from models.lora import LoRALinear


def _build_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for the fuse script."""
    parser = argparse.ArgumentParser(
        description="Fuse LoRA or QLoRA adapters into the base model."
    )
    parser.add_argument(
        "--model",
        default="mlx_model",
        help="The path to the local model directory or Hugging Face repo.",
    )
    parser.add_argument(
        "--save-path",
        default="lora_fused_model",
        help="The path to save the fused model.",
    )
    parser.add_argument(
        "--adapter-file",
        type=str,
        default="adapters.npz",
        help="Path to the trained adapter weights (npz or safetensors).",
    )
    parser.add_argument(
        "--hf-path",
        help=(
            "Path to the original Hugging Face model. This is "
            "required for upload if --model is a local directory."
        ),
        type=str,
        default=None,
    )
    parser.add_argument(
        "--upload-name",
        help="The name of model to upload to Hugging Face MLX Community",
        type=str,
        default=None,
    )
    return parser


def _resolve_hf_path(args: argparse.Namespace) -> str | None:
    """Return the Hugging Face repo to credit on upload, or None if not uploading.

    Raises:
        ValueError: if uploading a local model without ``--hf-path``.
    """
    if args.upload_name is None:
        return None
    if not Path(args.model).exists():
        # If the model path doesn't exist, assume it's an HF repo.
        return args.model
    if args.hf_path is None:
        raise ValueError(
            "Must provide original Hugging Face repo to upload local model."
        )
    return args.hf_path


def main() -> None:
    """Load the base model, merge the LoRA adapters, then save and optionally upload.

    Raises:
        ValueError: if the adapter file contains no ``q_proj.lora_a`` weights,
            or if an upload of a local model lacks ``--hf-path``.
    """
    args = _build_parser().parse_args()

    # Validate upload arguments before doing any expensive work.
    hf_path = _resolve_hf_path(args)

    # The number of "q_proj.lora_a" entries is taken as the number of
    # trailing layers that were adapted.
    adapters = list(mx.load(args.adapter_file).items())
    lora_layers = sum(1 for name, _ in adapters if "q_proj.lora_a" in name)
    if lora_layers == 0:
        # layers[-0:] would select *every* layer, so fail explicitly instead.
        raise ValueError(
            f"No LoRA adapter weights (q_proj.lora_a) found in {args.adapter_file}."
        )

    print("Loading pretrained model")
    model, tokenizer, config = utils.load(args.model)

    # Freeze all layers other than LoRA linears.
    model.freeze()
    for layer in model.model.layers[-lora_layers:]:
        layer.self_attn.q_proj = LoRALinear.from_linear(layer.self_attn.q_proj)
        layer.self_attn.v_proj = LoRALinear.from_linear(layer.self_attn.v_proj)
        if hasattr(layer, "block_sparse_moe"):
            layer.block_sparse_moe.gate = LoRALinear.from_linear(
                layer.block_sparse_moe.gate
            )

    model.update(tree_unflatten(adapters))

    # Replace every LoRALinear with the plain linear layer returned by to_linear().
    fused_linears = [
        (name, module.to_linear())
        for name, module in model.named_modules()
        if isinstance(module, LoRALinear)
    ]
    model.update_modules(tree_unflatten(fused_linears))

    weights = dict(tree_flatten(model.parameters()))
    utils.save_model(args.save_path, weights, tokenizer, config)

    if args.upload_name is not None:
        utils.upload_to_hub(args.save_path, args.upload_name, hf_path)


if __name__ == "__main__":
    main()