mirror of
				https://github.com/ml-explore/mlx-examples.git
				synced 2025-10-23 14:08:07 +08:00 
			
		
		
		
	feat(lora): add de-quantized support for fuse.py (#351)
* feat(lora): add de-quantized support for fuse.py * address comments
This commit is contained in:
		
							
								
								
									
										34
									
								
								lora/fuse.py
									
									
									
									
									
								
							
							
						
						
									
										34
									
								
								lora/fuse.py
									
									
									
									
									
								
							| @@ -4,6 +4,7 @@ import argparse | |||||||
| from pathlib import Path | from pathlib import Path | ||||||
|  |  | ||||||
| import mlx.core as mx | import mlx.core as mx | ||||||
|  | import mlx.nn as nn | ||||||
| import utils | import utils | ||||||
| from mlx.utils import tree_flatten, tree_unflatten | from mlx.utils import tree_flatten, tree_unflatten | ||||||
| from models.lora import LoRALinear | from models.lora import LoRALinear | ||||||
| @@ -41,6 +42,12 @@ if __name__ == "__main__": | |||||||
|         type=str, |         type=str, | ||||||
|         default=None, |         default=None, | ||||||
|     ) |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "-d", | ||||||
|  |         "--de-quantize", | ||||||
|  |         help="Generate a de-quantized model.", | ||||||
|  |         action="store_true", | ||||||
|  |     ) | ||||||
|  |  | ||||||
|     print("Loading pretrained model") |     print("Loading pretrained model") | ||||||
|     args = parser.parse_args() |     args = parser.parse_args() | ||||||
| @@ -53,7 +60,7 @@ if __name__ == "__main__": | |||||||
|  |  | ||||||
|     # Freeze all layers other than LORA linears |     # Freeze all layers other than LORA linears | ||||||
|     model.freeze() |     model.freeze() | ||||||
|     for l in model.model.layers[-lora_layers:]: |     for l in model.model.layers[len(model.model.layers) - lora_layers :]: | ||||||
|         l.self_attn.q_proj = LoRALinear.from_linear(l.self_attn.q_proj) |         l.self_attn.q_proj = LoRALinear.from_linear(l.self_attn.q_proj) | ||||||
|         l.self_attn.v_proj = LoRALinear.from_linear(l.self_attn.v_proj) |         l.self_attn.v_proj = LoRALinear.from_linear(l.self_attn.v_proj) | ||||||
|         if hasattr(l, "block_sparse_moe"): |         if hasattr(l, "block_sparse_moe"): | ||||||
| @@ -67,7 +74,32 @@ if __name__ == "__main__": | |||||||
|     ] |     ] | ||||||
|  |  | ||||||
|     model.update_modules(tree_unflatten(fused_linears)) |     model.update_modules(tree_unflatten(fused_linears)) | ||||||
|  |  | ||||||
|  |     if args.de_quantize: | ||||||
|  |         de_quantize_layers = [] | ||||||
|  |         for n, m in model.named_modules(): | ||||||
|  |             if isinstance(m, nn.QuantizedLinear): | ||||||
|  |                 bias = "bias" in m | ||||||
|  |                 weight = m.weight | ||||||
|  |                 weight = mx.dequantize( | ||||||
|  |                     weight, | ||||||
|  |                     m.scales, | ||||||
|  |                     m.biases, | ||||||
|  |                     m.group_size, | ||||||
|  |                     m.bits, | ||||||
|  |                 ).astype(mx.float16) | ||||||
|  |                 output_dims, input_dims = weight.shape | ||||||
|  |                 linear = nn.Linear(input_dims, output_dims, bias=bias) | ||||||
|  |                 linear.weight = weight | ||||||
|  |                 if bias: | ||||||
|  |                     linear.bias = m.bias | ||||||
|  |                 de_quantize_layers.append((n, linear)) | ||||||
|  |  | ||||||
|  |         model.update_modules(tree_unflatten(de_quantize_layers)) | ||||||
|  |  | ||||||
|     weights = dict(tree_flatten(model.parameters())) |     weights = dict(tree_flatten(model.parameters())) | ||||||
|  |     if args.de_quantize: | ||||||
|  |         config.pop("quantization", None) | ||||||
|     utils.save_model(args.save_path, weights, tokenizer, config) |     utils.save_model(args.save_path, weights, tokenizer, config) | ||||||
|  |  | ||||||
|     if args.upload_name is not None: |     if args.upload_name is not None: | ||||||
|   | |||||||
| @@ -16,7 +16,7 @@ class LoRALinear(nn.Module): | |||||||
|         lora_lin.linear = linear |         lora_lin.linear = linear | ||||||
|         return lora_lin |         return lora_lin | ||||||
|  |  | ||||||
|     def to_linear(self): |     def to_linear(self, de_quantize: bool = False): | ||||||
|         linear = self.linear |         linear = self.linear | ||||||
|         bias = "bias" in linear |         bias = "bias" in linear | ||||||
|         weight = linear.weight |         weight = linear.weight | ||||||
| @@ -43,7 +43,7 @@ class LoRALinear(nn.Module): | |||||||
|         if bias: |         if bias: | ||||||
|             fused_linear.bias = linear.bias |             fused_linear.bias = linear.bias | ||||||
|  |  | ||||||
|         if is_quantized: |         if is_quantized and not de_quantize: | ||||||
|             fused_linear = nn.QuantizedLinear.from_linear( |             fused_linear = nn.QuantizedLinear.from_linear( | ||||||
|                 fused_linear, |                 fused_linear, | ||||||
|                 linear.group_size, |                 linear.group_size, | ||||||
|   | |||||||
| @@ -115,11 +115,21 @@ def make_shards(weights: dict, max_file_size_gibibyte: int = 15): | |||||||
| def save_model(save_dir: str, weights, tokenizer, config): | def save_model(save_dir: str, weights, tokenizer, config): | ||||||
|     save_dir = Path(save_dir) |     save_dir = Path(save_dir) | ||||||
|     save_dir.mkdir(parents=True, exist_ok=True) |     save_dir.mkdir(parents=True, exist_ok=True) | ||||||
|     shards = make_shards(weights) |  | ||||||
|  |     shards = make_shards(weights, max_file_size_gibibyte=5) | ||||||
|  |     shards_count = len(shards) | ||||||
|  |     shard_file_format = ( | ||||||
|  |         "model-{:05d}-of-{:05d}.safetensors" | ||||||
|  |         if shards_count > 1 | ||||||
|  |         else "model.safetensors" | ||||||
|  |     ) | ||||||
|  |  | ||||||
|     for i, shard in enumerate(shards): |     for i, shard in enumerate(shards): | ||||||
|         # TODO use HF file name scheme for simplicity |         shard_name = shard_file_format.format(i + 1, shards_count) | ||||||
|         mx.save_safetensors(str(save_dir / f"weights.{i:02d}.safetensors"), shard) |         mx.save_safetensors(str(save_dir / shard_name), shard) | ||||||
|  |  | ||||||
|     tokenizer.save_pretrained(save_dir) |     tokenizer.save_pretrained(save_dir) | ||||||
|  |  | ||||||
|     with open(save_dir / "config.json", "w") as fid: |     with open(save_dir / "config.json", "w") as fid: | ||||||
|         json.dump(config, fid, indent=4) |         json.dump(config, fid, indent=4) | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Anchen
					Anchen