Mirror of https://github.com/ml-explore/mlx-examples.git
	Quantize example (#162)
* testing quantization
* conversion + quantization working
* one config processor
* quantization in mistral / nits in llama
* args for quantization
* llama / mistral conversion in good shape
* phi2 quantized
* mixtral
* qwen conversion
mixtral/README.md

@@ -43,10 +43,18 @@ Now from `mlx-examples/mixtral` convert and save the weights as NumPy arrays so
MLX can read them:

```
python convert.py --model-path $MIXTRAL_MODEL/
python convert.py --torch-path $MIXTRAL_MODEL/
```

The conversion script will save the converted weights in the same location.
To generate a 4-bit quantized model, use ``-q``. For a full list of options:

```
python convert.py --help
```

By default, the conversion script will make the directory `mlx_model` and save
the converted `weights.npz`, `tokenizer.model`, and `config.json` there.
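With this change, a quantized conversion is one extra flag. As a sketch, assuming `$MIXTRAL_MODEL` points at the downloaded PyTorch checkpoint (the group size and bit width shown are simply the script's defaults made explicit):

```
python convert.py --torch-path $MIXTRAL_MODEL/ --mlx-path mlx_model -q --q_group_size 64 --q_bits 4
```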


### Generate

mixtral/convert.py

@@ -1,59 +1,152 @@
# Copyright © 2023 Apple Inc.

import argparse
import copy
import glob
import json
import shutil
from pathlib import Path

import mlx.core as mx
import mlx.nn as nn
import numpy as np
import torch
from mixtral import Mixtral, ModelArgs
from mlx.utils import tree_flatten, tree_map, tree_unflatten


def convert(k, v, config):
    v = v.to(torch.float16).numpy()
    if "block_sparse_moe" not in k:
        return [(k, v)]
    if "gate" in k:
        return [(k.replace("block_sparse_moe", "feed_forward"), v)]
def convert(tf, config):
    def convert_single(k, v):
        v = v.to(torch.float16).numpy()
        if "block_sparse_moe" not in k:
            return [(k, v)]
        if "gate" in k:
            return [(k.replace("block_sparse_moe", "feed_forward"), v)]

    # From: layers.N.block_sparse_moe.w
    # To: layers.N.experts.M.w
    num_experts = args["moe"]["num_experts"]
    key_path = k.split(".")
    v = np.split(v, num_experts, axis=0)
    if key_path[-1] == "w2":
        v = [u.T for u in v]
        # From: layers.N.block_sparse_moe.w
        # To: layers.N.experts.M.w
        num_experts = config["moe"]["num_experts"]
        key_path = k.split(".")
        v = np.split(v, num_experts, axis=0)
        if key_path[-1] == "w2":
            v = [u.T for u in v]

    w_name = key_path.pop()
    key_path[-1] = "feed_forward.experts"
    return [
        (".".join(key_path + [str(e), w_name, "weight"]), u) for e, u in enumerate(v)
    ]
        w_name = key_path.pop()
        key_path[-1] = "feed_forward.experts"
        return [
            (".".join(key_path + [str(e), w_name, "weight"]), u)
            for e, u in enumerate(v)
        ]

    state = torch.load(tf)
    weights = {}
    for k, v in state.items():
        weights.update(convert_single(k, v))
    return weights
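The renaming that `convert_single` performs on the stacked expert tensors is easiest to see on a concrete example. A minimal sketch with made-up shapes (the real tensors come from the `consolidated.*.pt` shards):

```python
import numpy as np

# Hypothetical stacked expert weight; checkpoints store all experts in one tensor.
num_experts, d_model, hidden = 8, 16, 64
k = "layers.0.block_sparse_moe.w2"
v = np.zeros((num_experts * hidden, d_model), dtype=np.float16)

parts = np.split(v, num_experts, axis=0)  # one slice per expert
parts = [u.T for u in parts]              # w2 is transposed on the way out
keys = [f"layers.0.feed_forward.experts.{e}.w2.weight" for e in range(num_experts)]
print(keys[0], parts[0].shape)            # layers.0.feed_forward.experts.0.w2.weight (16, 64)
```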


def quantize(weights, config, args):
    quantized_config = copy.deepcopy(config)

    # Load the model and update with the subset of weights:
    config.pop("quantization", None)
    model = Mixtral(ModelArgs(**config))
    all_weights = dict(tree_flatten(model.parameters()))

    weights = tree_map(mx.array, weights)

    all_weights.update(weights)
    all_weights = tree_unflatten(list(all_weights.items()))
    model.update(all_weights)

    # Quantize the model:
    nn.QuantizedLinear.quantize_module(
        model,
        args.q_group_size,
        args.q_bits,
        # TODO: Quantize gate matrices when < 32 tiles supported
        linear_class_predicate=lambda m: isinstance(m, nn.Linear)
        and m.weight.shape[0] != 8,
    )

    # Extract the subset of quantized weights:
    all_weights = dict(tree_flatten(model.parameters()))
    quantized_weights = {}
    for k, v in all_weights.items():
        if k not in weights:
            continue
        quantized_weights[k] = v
        prefix = k.split(".")[:-1]
        for qw in ["scales", "biases"]:
            if (k := ".".join(prefix + [qw])) in all_weights:
                quantized_weights[k] = all_weights[k]

    # Update the config:
    quantized_config["quantization"] = {
        "group_size": args.q_group_size,
        "bits": args.q_bits,
    }
    return quantized_weights, quantized_config
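The `linear_class_predicate` is what keeps the 8-way MoE router (gate) in full precision while every other linear layer is quantized. A small sketch of the same call on a toy module, mirroring the arguments used in this commit (the module and layer names below are invented for illustration):

```python
import mlx.nn as nn

class ToyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.gate = nn.Linear(128, 8)    # router over 8 experts: left unquantized
        self.proj = nn.Linear(128, 256)  # ordinary projection: quantized

model = ToyBlock()
nn.QuantizedLinear.quantize_module(
    model,
    group_size=64,
    bits=4,
    linear_class_predicate=lambda m: isinstance(m, nn.Linear)
    and m.weight.shape[0] != 8,
)
print(type(model.gate).__name__, type(model.proj).__name__)
# Linear QuantizedLinear
```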


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert Mixtral weights to MLX.")
    parser.add_argument(
        "--model-path",
        "--torch-path",
        type=str,
        default="Mixtral-8x7B-v0.1/",
        help="The path to the Mixtral model. The MLX model weights will also be saved there.",
        default="Mixtral-8x7B-v0.1",
        help="The path to the PyTorch model.",
    )
    parser.add_argument(
        "--mlx-path",
        type=str,
        default="mlx_model",
        help="The path to save the MLX model.",
    )
    parser.add_argument(
        "-q",
        "--quantize",
        help="Generate a quantized model.",
        action="store_true",
    )
    parser.add_argument(
        "--q_group_size",
        help="Group size for quantization.",
        type=int,
        default=64,
    )
    parser.add_argument(
        "--q_bits",
        help="Bits per weight for quantization.",
        type=int,
        default=4,
    )
    args = parser.parse_args()
    model_path = Path(args.model_path)
    torch_path = Path(args.torch_path)
    mlx_path = Path(args.mlx_path)
    mlx_path.mkdir(parents=True, exist_ok=True)

    with open("params.json") as fid:
        args = json.load(fid)
        args["model_type"] = "mixtral"
    with open(model_path / "config.json", "w") as f:
        json.dump(args, f, indent=4)
        config = json.load(fid)

    torch_files = glob.glob(str(model_path / "consolidated.*.pt"))
    # Copy tokenizer
    shutil.copyfile(
        str(torch_path / "tokenizer.model"),
        str(mlx_path / "tokenizer.model"),
    )

    # Convert and save model in shards
    torch_files = glob.glob(str(torch_path / "consolidated.*.pt"))
    torch_files = sorted(torch_files, key=lambda tf: int(tf.split(".")[-2]))
    for e, tf in enumerate(torch_files):
        print(f"[INFO] Converting file {e + 1}/{len(torch_files)}")
        state = torch.load(tf)
        new_state = {}
        for k, v in state.items():
            new_state.update(convert(k, v, args))
        np.savez(str(model_path / f"weights.{e}.npz"), **new_state)
        weights = convert(tf, config)
        if args.quantize:
            print("[INFO] Quantizing")
            weights, config = quantize(weights, config, args)
        np.savez(str(mlx_path / f"weights.{e}.npz"), **weights)

    # Save updated config
    with open(mlx_path / "config.json", "w") as f:
        config["model_type"] = "mixtral"
        json.dump(config, f, indent=4)

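After a quantized conversion, each saved shard holds the packed weights alongside their quantization parameters. A quick way to check, assuming the default `mlx_model` output directory (the key names in the comments are only indicative):

```python
import numpy as np

shard = np.load("mlx_model/weights.0.npz")
# Each quantized linear layer contributes three arrays, e.g.:
#   layers.0.attention.wq.weight   packed low-bit weights
#   layers.0.attention.wq.scales   per-group scales
#   layers.0.attention.wq.biases   per-group biases
print(len(shard.files), shard.files[:3])
```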
mixtral/mixtral.py

@@ -244,20 +244,27 @@ class Tokenizer:
        return out


def load_model(folder: str, dtype=mx.float16):
def load_model(folder: str):
    model_path = Path(folder)
    tokenizer = Tokenizer(str(model_path / "tokenizer.model"))
    with open(model_path / "config.json", "r") as f:
        config = json.loads(f.read())
        config.pop("model_type", None)
        quantization = config.pop("quantization", None)
        model_args = ModelArgs(**config)
    weight_files = glob.glob(str(model_path / "weights.*.npz"))
    weights = {}
    for wf in weight_files:
        weights.update(mx.load(wf).items())
    weights = tree_unflatten(list(weights.items()))
    weights = tree_map(lambda p: p.astype(dtype), weights)
    model = Mixtral(model_args)
    if quantization is not None:
        # TODO: Quantize gate matrices when < 32 tiles supported
        quantization["linear_class_predicate"] = (
            lambda m: isinstance(m, nn.Linear) and m.weight.shape[0] != 8
        )
        nn.QuantizedLinear.quantize_module(model, **quantization)

    model.update(weights)
    return model, tokenizer
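Because the quantization metadata lives in `config.json`, loading is the same call whether or not the model was quantized. A hypothetical usage sketch, assuming you are inside `mlx-examples/mixtral` and converted into the default `mlx_model` directory:

```python
from mixtral import load_model  # load_model is defined in this file

model, tokenizer = load_model("mlx_model")
# If config.json carries a "quantization" entry, the Linear layers (except the
# 8-way gates) were swapped for QuantizedLinear before model.update ran.
print(type(model).__name__)  # Mixtral
```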

@@ -284,7 +291,7 @@ if __name__ == "__main__":
    parser.add_argument(
        "--model-path",
        type=str,
        default="Mixtral-8x7B-v0.1",
        default="mlx_model",
        help="The path to the model weights, tokenizer, and config",
    )
    parser.add_argument(
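Since the default `--model-path` now points at the converted output, generation after this change should be runnable as, for instance (other flags of `mixtral.py` are not shown in this diff, so treat this as a sketch):

```
python mixtral.py --model-path mlx_model
```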
Awni Hannun