mirror of
				https://github.com/ml-explore/mlx-examples.git
				synced 2025-10-31 02:48:07 +08:00 
			
		
		
		
	Quantize example (#162)
* testing quantization
* conversion + quantization working
* one config processor
* quantization in mistral / nits in llama
* args for quantization
* llama / mistral conversion in good shape
* phi2 quantized
* mixtral
* qwen conversion
This commit is contained in:
		
							
								
								
									
										1
									
								
								llms/phi2/.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								llms/phi2/.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -1 +0,0 @@ | ||||
| weights.npz | ||||
| @@ -15,7 +15,14 @@ Download and convert the model: | ||||
| python convert.py | ||||
| ``` | ||||
|  | ||||
| This will make the `weights.npz` file which MLX can read. | ||||
| To generate a 4-bit quantized model use the `-q` flag: | ||||
|  | ||||
| ``` | ||||
| python convert.py -q | ||||
| ``` | ||||
|  | ||||
| By default, the conversion script will make the directory `mlx_model` and save | ||||
| the converted `weights.npz` and `config.json` there. | ||||
|  | ||||
| > [!TIP] Alternatively, you can also download a few converted checkpoints from | ||||
| > the [MLX Community](https://huggingface.co/mlx-community) organization on | ||||
|   | ||||
| @@ -1,7 +1,37 @@ | ||||
| import argparse | ||||
| import copy | ||||
| import json | ||||
| from pathlib import Path | ||||
|  | ||||
| import mlx.core as mx | ||||
| import mlx.nn as nn | ||||
| import numpy as np | ||||
| from mlx.utils import tree_flatten, tree_map, tree_unflatten | ||||
| from phi2 import ModelArgs, Phi2 | ||||
| from transformers import AutoModelForCausalLM | ||||
|  | ||||
|  | ||||
| def quantize(weights, config, args): | ||||
|     quantized_config = copy.deepcopy(config) | ||||
|  | ||||
|     # Load the model: | ||||
|     model = Phi2(ModelArgs()) | ||||
|     weights = tree_map(mx.array, weights) | ||||
|     model.update(tree_unflatten(list(weights.items()))) | ||||
|  | ||||
|     # Quantize the model: | ||||
|     nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits) | ||||
|  | ||||
|     # Update the config: | ||||
|     quantized_config["quantization"] = { | ||||
|         "group_size": args.q_group_size, | ||||
|         "bits": args.q_bits, | ||||
|     } | ||||
|     quantized_weights = dict(tree_flatten(model.parameters())) | ||||
|  | ||||
|     return quantized_weights, quantized_config | ||||
|  | ||||
|  | ||||
| def replace_key(key: str) -> str: | ||||
|     if "wte.weight" in key: | ||||
|         key = "wte.weight" | ||||
| @@ -12,12 +42,50 @@ def replace_key(key: str) -> str: | ||||
|  | ||||
|  | ||||
| def convert(): | ||||
|     parser = argparse.ArgumentParser(description="Convert Phi-2 weights to MLX") | ||||
|     parser.add_argument( | ||||
|         "--mlx-path", | ||||
|         type=str, | ||||
|         default="mlx_model", | ||||
|         help="The path to save the MLX model.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "-q", | ||||
|         "--quantize", | ||||
|         help="Generate a quantized model.", | ||||
|         action="store_true", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--q_group_size", | ||||
|         help="Group size for quantization.", | ||||
|         type=int, | ||||
|         default=64, | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--q_bits", | ||||
|         help="Bits per weight for quantization.", | ||||
|         type=int, | ||||
|         default=4, | ||||
|     ) | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     mlx_path = Path(args.mlx_path) | ||||
|     mlx_path.mkdir(parents=True, exist_ok=True) | ||||
|  | ||||
|     model = AutoModelForCausalLM.from_pretrained( | ||||
|         "microsoft/phi-2", torch_dtype="auto", trust_remote_code=True | ||||
|     ) | ||||
|     state_dict = model.state_dict() | ||||
|     weights = {replace_key(k): v.numpy() for k, v in state_dict.items()} | ||||
|     np.savez("weights.npz", **weights) | ||||
|     params = {} | ||||
|     if args.quantize: | ||||
|         print("[INFO] Quantizing") | ||||
|         weights, params = quantize(weights, params, args) | ||||
|  | ||||
|     np.savez(str(mlx_path / "weights.npz"), **weights) | ||||
|     with open(mlx_path / "config.json", "w") as fid: | ||||
|         params["model_type"] = "phi2" | ||||
|         json.dump(params, fid, indent=4) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|   | ||||
| @@ -1,4 +1,5 @@ | ||||
| import argparse | ||||
| import json | ||||
| import math | ||||
| from dataclasses import dataclass | ||||
| from pathlib import Path | ||||
| @@ -158,8 +159,16 @@ def generate(prompt: mx.array, model: Phi2, temp: Optional[float] = 0.0): | ||||
| def load_model(model_path: str): | ||||
|     model = Phi2(ModelArgs()) | ||||
|     model_path = Path(model_path) | ||||
|     with open(model_path / "config.json", "r") as f: | ||||
|         config = json.loads(f.read()) | ||||
|         config.pop("model_type", None) | ||||
|         quantization = config.pop("quantization", None) | ||||
|     weights = mx.load(str(model_path / "weights.npz")) | ||||
|     model.update(tree_unflatten(list(weights.items()))) | ||||
|     weights = tree_unflatten(list(weights.items())) | ||||
|     if quantization is not None: | ||||
|         nn.QuantizedLinear.quantize_module(model, **quantization) | ||||
|     model.update(weights) | ||||
|  | ||||
|     tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True) | ||||
|     return model, tokenizer | ||||
|  | ||||
| @@ -169,7 +178,7 @@ if __name__ == "__main__": | ||||
|     parser.add_argument( | ||||
|         "--model-path", | ||||
|         type=str, | ||||
|         default=".", | ||||
|         default="mlx_model", | ||||
|         help="The path to the model weights", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun