# Copyright © 2023 Apple Inc.
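"""Fuse trained LoRA adapters back into the base model's linear layers.

Example invocation (a sketch; assumes this script is saved as ``fuse.py``
and that the adapters were trained with the accompanying LoRA training
script):

    python fuse.py --model mlx_model \
        --adapter-file adapters.npz \
        --save-path lora_fused_model
"""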

import argparse
from pathlib import Path

import mlx.core as mx
import mlx.nn as nn
import utils
from mlx.utils import tree_flatten, tree_unflatten
from models.lora import LoRALinear

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Fuse trained LoRA or QLoRA adapters into the base model."
    )
    parser.add_argument(
        "--model",
        default="mlx_model",
        help="The path to the local model directory or Hugging Face repo.",
    )
    parser.add_argument(
        "--save-path",
        default="lora_fused_model",
        help="The path to save the fused model.",
    )
    parser.add_argument(
        "--adapter-file",
        type=str,
        default="adapters.npz",
        help="Path to the trained adapter weights (npz or safetensors).",
    )
    parser.add_argument(
        "--hf-path",
        help=(
            "Path to the original Hugging Face model. This is "
            "required for upload if --model is a local directory."
        ),
        type=str,
        default=None,
    )
    parser.add_argument(
        "--upload-name",
        help="The name of the model to upload to the Hugging Face MLX Community.",
        type=str,
        default=None,
    )
    parser.add_argument(
        "-d",
        "--de-quantize",
        help="Generate a de-quantized model.",
        action="store_true",
    )

    args = parser.parse_args()

    print("Loading pretrained model")
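    # Load the base model, tokenizer, and config. For QLoRA the loaded
    # model may contain quantized linear layers.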
    model, tokenizer, config = utils.load(args.model)

    # Load the adapters and count the LoRA layers from the per-layer q_proj keys
    adapters = list(mx.load(args.adapter_file).items())
    lora_layers = len([m for m in adapters if "q_proj.lora_a" in m[0]])

    # Freeze all layers other than the LoRA linears
    model.freeze()
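    # Re-wrap the last `lora_layers` blocks in LoRALinear so the parameter
    # paths match the keys in the adapter file. The explicit start index
    # (rather than `[-lora_layers:]`) keeps the slice empty when
    # lora_layers == 0.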
    for l in model.model.layers[len(model.model.layers) - lora_layers :]:
        l.self_attn.q_proj = LoRALinear.from_linear(l.self_attn.q_proj)
        l.self_attn.v_proj = LoRALinear.from_linear(l.self_attn.v_proj)
        if hasattr(l, "block_sparse_moe"):
            l.block_sparse_moe.gate = LoRALinear.from_linear(l.block_sparse_moe.gate)

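    # Load the trained adapter weights into the LoRA wrappers.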
    model.update(tree_unflatten(adapters))
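    # Fuse each LoRA layer back into a single linear layer by merging the
    # low-rank update into the base weights.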
    fused_linears = [
        (n, m.to_linear())
        for n, m in model.named_modules()
        if isinstance(m, LoRALinear)
    ]

    model.update_modules(tree_unflatten(fused_linears))

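    # Optionally replace quantized linear layers with regular float16 linears.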
    if args.de_quantize:
        de_quantize_layers = []
        for n, m in model.named_modules():
            if isinstance(m, nn.QuantizedLinear):
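                # Recover a float16 weight matrix from the packed quantized
                # weights, scales, and biases.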
                bias = "bias" in m
                weight = mx.dequantize(
                    m.weight,
                    m.scales,
                    m.biases,
                    m.group_size,
                    m.bits,
                ).astype(mx.float16)
                output_dims, input_dims = weight.shape
                linear = nn.Linear(input_dims, output_dims, bias=bias)
                linear.weight = weight
                if bias:
                    linear.bias = m.bias
                de_quantize_layers.append((n, linear))

        model.update_modules(tree_unflatten(de_quantize_layers))

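    # Flatten the fused parameters and save them with the tokenizer and config.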
    weights = dict(tree_flatten(model.parameters()))
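    # The de-quantized weights are plain float16, so drop the quantization
    # metadata from the config.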
    if args.de_quantize:
        config.pop("quantization", None)
    utils.save_model(args.save_path, weights, tokenizer, config)

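    # Optionally upload the fused model, resolving the original Hugging
    # Face repo required for the upload.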
    if args.upload_name is not None:
        hf_path = args.hf_path
        if not Path(args.model).exists():
            # If the model path doesn't exist, assume it's an HF repo
            hf_path = args.model
        elif hf_path is None:
            raise ValueError(
                "Must provide the original Hugging Face repo to upload a local model."
            )
        utils.upload_to_hub(args.save_path, args.upload_name, hf_path)