mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-10-31 19:18:09 +08:00)

Commit f20e68fcc0
* save format for transformers compatibility
* save format for transformers compatibility arg
* hardcode mlx
* hardcode mlx format

115 lines · 3.7 KiB · Python

# Copyright © 2023-2024 Apple Inc.

import argparse
from pathlib import Path

import mlx.core as mx
import mlx.nn as nn
import utils
from mlx.utils import tree_flatten, tree_unflatten
from models import LoRALinear

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Fuse fine-tuned LoRA adapters into the base model."
    )
    parser.add_argument(
        "--model",
        default="mlx_model",
        help="The path to the local model directory or Hugging Face repo.",
    )
    parser.add_argument(
        "--save-path",
        default="lora_fused_model",
        help="The path to save the fused model.",
    )
    parser.add_argument(
        "--adapter-file",
        type=str,
        default="adapters.npz",
        help="Path to the trained adapter weights (npz or safetensors).",
    )
    parser.add_argument(
        "--hf-path",
        help=(
            "Path to the original Hugging Face model. This is "
            "required for upload if --model is a local directory."
        ),
        type=str,
        default=None,
    )
    parser.add_argument(
        "--upload-name",
        help="The name of the model to upload to the Hugging Face MLX Community.",
        type=str,
        default=None,
    )
    parser.add_argument(
        "-d",
        "--de-quantize",
        help="Generate a de-quantized model.",
        action="store_true",
    )

|     print("Loading pretrained model")
 | |
|     args = parser.parse_args()
 | |
| 
 | |
|     model, tokenizer, config = utils.load(args.model)
 | |
| 
 | |
    # Load adapters and count the LoRA layers from the q_proj adapter keys
    adapters = list(mx.load(args.adapter_file).items())
    lora_layers = len([m for m in adapters if "q_proj.lora_a" in m[0]])

    # Freeze all layers other than the LoRA linears
    model.freeze()
    for layer in model.model.layers[len(model.model.layers) - lora_layers :]:
        layer.self_attn.q_proj = LoRALinear.from_linear(layer.self_attn.q_proj)
        layer.self_attn.v_proj = LoRALinear.from_linear(layer.self_attn.v_proj)
        if hasattr(layer, "block_sparse_moe"):
            layer.block_sparse_moe.gate = LoRALinear.from_linear(
                layer.block_sparse_moe.gate
            )

    # Load the trained adapter weights into the LoRA layers
    model.update(tree_unflatten(adapters))
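
    # Fuse the LoRA updates back into plain Linear layers so the saved model
    # needs no adapters at inference time (to_linear comes from models.py)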
    fused_linears = [
        (n, m.to_linear())
        for n, m in model.named_modules()
        if isinstance(m, LoRALinear)
    ]

    model.update_modules(tree_unflatten(fused_linears))

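    # Optionally replace each QuantizedLinear with an equivalent float16 Linear
    # so the saved model can be loaded without its quantization config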
|     if args.de_quantize:
 | |
|         de_quantize_layers = []
 | |
|         for n, m in model.named_modules():
 | |
|             if isinstance(m, nn.QuantizedLinear):
 | |
|                 bias = "bias" in m
 | |
|                 weight = m.weight
 | |
|                 weight = mx.dequantize(
 | |
|                     weight,
 | |
|                     m.scales,
 | |
|                     m.biases,
 | |
|                     m.group_size,
 | |
|                     m.bits,
 | |
|                 ).astype(mx.float16)
 | |
|                 output_dims, input_dims = weight.shape
 | |
|                 linear = nn.Linear(input_dims, output_dims, bias=bias)
 | |
|                 linear.weight = weight
 | |
|                 if bias:
 | |
|                     linear.bias = m.bias
 | |
|                 de_quantize_layers.append((n, linear))
 | |
| 
 | |
|         model.update_modules(tree_unflatten(de_quantize_layers))
 | |
| 
 | |
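    # Gather the fused parameters and save them with the tokenizer and config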
    weights = dict(tree_flatten(model.parameters()))
    if args.de_quantize:
        # The float16 weights no longer match the quantization settings
        config.pop("quantization", None)
    utils.save_model(args.save_path, weights, tokenizer, config)

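    # Optionally upload the fused model to the Hugging Face Hub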
    if args.upload_name is not None:
        hf_path = args.hf_path
        if not Path(args.model).exists():
            # If the model path doesn't exist, assume it's an HF repo
            hf_path = args.model
        elif hf_path is None:
            raise ValueError(
                "Must provide original Hugging Face repo to upload local model."
            )
        utils.upload_to_hub(args.save_path, args.upload_name, hf_path)
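
Example usage (a minimal sketch: the script name fuse.py and all paths and model names below are illustrative, not taken from this page):

    # Fuse the trained adapters into the base model
    python fuse.py --model mlx_model --adapter-file adapters.npz --save-path lora_fused_model

    # Fuse, de-quantize, and upload; --hf-path is required when --model is a local directory
    python fuse.py --model mlx_model -d --upload-name my-fused-model --hf-path some-org/original-model

For reference, the fusing step relies on LoRALinear.to_linear() folding the low-rank update into the base weight. A minimal sketch of that computation, assuming the common LoRA shape convention (the repo's models.py may arrange lora_a and lora_b differently):

    import mlx.core as mx

    def fuse_lora_weight(weight, lora_a, lora_b, scale):
        # Hypothetical shapes: weight (out_dims, in_dims); lora_a (in_dims, rank);
        # lora_b (rank, out_dims). The adapted layer computes
        #   y = x @ weight.T + scale * (x @ lora_a) @ lora_b,
        # which equals x @ (weight + scale * (lora_a @ lora_b).T).T.
        return weight + scale * (lora_a @ lora_b).T

    # Tiny shape check with random inputs; expect (8, 16)
    W = mx.random.normal((8, 16))
    A = mx.random.normal((16, 4))
    B = mx.random.normal((4, 8))
    print(fuse_lora_weight(W, A, B, scale=2.0).shape)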