From 3cf436b529ea58d6c0c0a29c0dd799908cd4497d Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Thu, 21 Dec 2023 12:59:37 -0800
Subject: [PATCH] Quantize example (#162)

* testing quantization
* conversion + quantization working
* one config processor
* quantization in mistral / nits in llama
* args for quantization
* llama / mistral conversion in good shape
* phi2 quantized
* mixtral
* qwen conversion
---
 llms/llama/README.md    |  18 +++--
 llms/llama/convert.py   |  82 ++++++++++++++++++---
 llms/llama/llama.py     |  64 +++++++++-------
 llms/mistral/README.md  |  13 +++-
 llms/mistral/convert.py |  88 +++++++++++++++++++---
 llms/mistral/mistral.py |  14 ++--
 llms/mixtral/README.md  |  12 ++-
 llms/mixtral/convert.py | 157 ++++++++++++++++++++++++++++++++--------
 llms/mixtral/mixtral.py |  13 +++-
 llms/phi2/.gitignore    |   1 -
 llms/phi2/README.md     |   9 ++-
 llms/phi2/convert.py    |  70 +++++++++++++++++-
 llms/phi2/phi2.py       |  13 +++-
 llms/qwen/.gitignore    |   2 -
 llms/qwen/README.md     |  12 ++-
 llms/qwen/convert.py    |  87 +++++++++++++++++++---
 llms/qwen/qwen.py       |  24 ++++--
 17 files changed, 553 insertions(+), 126 deletions(-)
 delete mode 100644 llms/phi2/.gitignore
 delete mode 100644 llms/qwen/.gitignore

diff --git a/llms/llama/README.md b/llms/llama/README.md
index ffa8105d..3d9a97f1 100644
--- a/llms/llama/README.md
+++ b/llms/llama/README.md
@@ -30,24 +30,32 @@ Face](https://huggingface.co/TinyLlama).
 Convert the weights with:
 
 ```
-python convert.py --model-path
+python convert.py --torch-path
+```
+
+To generate a 4-bit quantized model, use the `-q` flag:
+
+```
+python convert.py --torch-path -q
 ```
 
 For TinyLlama use
 
 ```
-python convert.py --model-path --model-name tiny_llama
+python convert.py --torch-path --model-name tiny_llama
 ```
 
-The conversion script will save the converted weights in the same location.
+By default, the conversion script will make the directory `mlx_model` and save
+the converted `weights.npz`, `tokenizer.model`, and `config.json` there.
+
 
 ### Run
 
 Once you've converted the weights to MLX format, you can interact with the
-LlaMA model:
+Llama model:
 
 ```
-python llama.py --prompt "hello"
+python llama.py --prompt "hello"
 ```
 
 Run `python llama.py --help` for more details.
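Every example touched by this patch follows the same two-sided pattern: `convert.py -q` quantizes the linear layers with `nn.QuantizedLinear.quantize_module` and records the settings under a `"quantization"` key in the saved `config.json`, and the matching inference script re-applies the same call before loading the weights. The following is a minimal load-side sketch of that handshake, borrowing the llama example's `Llama`, `ModelArgs`, and `sanitize_config` names; it is illustrative only, not a verbatim excerpt of the patch:

```python
import json
from pathlib import Path

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_unflatten

from llama import Llama, ModelArgs, sanitize_config

model_path = Path("mlx_model")
weights = mx.load(str(model_path / "weights.npz"))
with open(model_path / "config.json", "r") as f:
    config = sanitize_config(json.loads(f.read()), weights)

# convert.py writes this key only when -q is passed,
# e.g. {"group_size": 64, "bits": 4}.
quantization = config.pop("quantization", None)

model = Llama(ModelArgs(**config))
if quantization is not None:
    # Swap nn.Linear layers for nn.QuantizedLinear before loading the
    # quantized weights, scales, and biases stored in weights.npz.
    nn.QuantizedLinear.quantize_module(model, **quantization)
model.update(tree_unflatten(list(weights.items())))
```

The convert scripts in the diffs below implement the matching write side: build the float model, quantize the module in place, then store `group_size` and `bits` in the config next to the weights.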
diff --git a/llms/llama/convert.py b/llms/llama/convert.py index 69fe1af8..618c3070 100644 --- a/llms/llama/convert.py +++ b/llms/llama/convert.py @@ -2,12 +2,18 @@ import argparse import collections +import copy import glob import json +import shutil from pathlib import Path +import mlx.core as mx +import mlx.nn as nn import numpy as np import torch +from llama import Llama, ModelArgs, sanitize_config +from mlx.utils import tree_flatten, tree_map, tree_unflatten def llama(model_path): @@ -57,9 +63,7 @@ def tiny_llama(model_path): except ImportError as e: print("The transformers package must be installed for this model conversion:") print("pip install transformers") - import sys - - sys.exit(0) + exit(0) model = transformers.AutoModelForCausalLM.from_pretrained( str(model_path) @@ -114,11 +118,40 @@ def tiny_llama(model_path): return weights, params +def quantize(weights, config, args): + quantized_config = copy.deepcopy(config) + + # Load the model: + config = sanitize_config(config, weights) + model = Llama(ModelArgs(**config)) + weights = tree_map(mx.array, weights) + model.update(tree_unflatten(list(weights.items()))) + + # Quantize the model: + nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits) + + # Update the config: + quantized_config["quantization"] = { + "group_size": args.q_group_size, + "bits": args.q_bits, + } + quantized_weights = dict(tree_flatten(model.parameters())) + + return quantized_weights, quantized_config + + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Convert Llama weights to MLX") parser.add_argument( - "--model-path", - help="Path to the model. The MLX weights will also be saved there.", + "--torch-path", + type=str, + help="Path to the PyTorch model.", + ) + parser.add_argument( + "--mlx-path", + type=str, + default="mlx_model", + help="Path to save the MLX model.", ) parser.add_argument( "--model-name", @@ -130,12 +163,43 @@ if __name__ == "__main__": choices=["tiny_llama", "llama"], default="llama", ) + parser.add_argument( + "-q", + "--quantize", + help="Generate a quantized model.", + action="store_true", + ) + parser.add_argument( + "--q_group_size", + help="Group size for quantization.", + type=int, + default=64, + ) + parser.add_argument( + "--q_bits", + help="Bits per weight for quantization.", + type=int, + default=4, + ) args = parser.parse_args() - model_path = Path(args.model_path) - weights, params = globals()[args.model_name](model_path) + torch_path = Path(args.torch_path) + mlx_path = Path(args.mlx_path) + mlx_path.mkdir(parents=True, exist_ok=True) + + print("[INFO] Loading") + weights, params = globals()[args.model_name](torch_path) params["model_type"] = "llama" - np.savez(str(model_path / "weights.npz"), **weights) - with open(model_path / "config.json", "w") as fid: + if args.quantize: + print("[INFO] Quantizing") + weights, params = quantize(weights, params, args) + + print("[INFO] Saving") + shutil.copyfile( + str(torch_path / "tokenizer.model"), + str(mlx_path / "tokenizer.model"), + ) + np.savez(str(mlx_path / "weights.npz"), **weights) + with open(mlx_path / "config.json", "w") as fid: json.dump(params, fid, indent=4) diff --git a/llms/llama/llama.py b/llms/llama/llama.py index 293f7210..6a7352f3 100644 --- a/llms/llama/llama.py +++ b/llms/llama/llama.py @@ -178,6 +178,12 @@ class Llama(nn.Module): return self.output(x) def generate(self, x, temp=1.0): + def sample(logits): + if temp == 0: + return mx.argmax(logits, axis=-1) + else: + return mx.random.categorical(logits * (1 / temp)) + 
cache = [] # Make an additive causal mask. We will need that to process the prompt. @@ -194,7 +200,7 @@ class Llama(nn.Module): x = self.norm(x) # We only care about the last logits that generate the next token y = self.output(x[:, -1]) - y = mx.random.categorical(y * (1 / temp)) + y = sample(y) # y now has size [1] # Since MLX is lazily evaluated nothing is computed yet. @@ -218,8 +224,7 @@ class Llama(nn.Module): # old cache the moment it is not needed anymore. x, cache[i] = self.layers[i](x, mask=None, cache=cache[i]) x = self.norm(x) - y = self.output(x[:, -1]) - y = mx.random.categorical(y * (1 / temp)) + y = sample(self.output(x[:, -1])) yield y @@ -326,38 +331,46 @@ def few_shot_generate(args): print() +def sanitize_config(config, weights): + config.pop("model_type", None) + n_heads = config["n_heads"] + if "n_kv_heads" not in config: + config["n_kv_heads"] = n_heads + if "head_dim" not in config: + config["head_dim"] = config["dim"] // n_heads + if "hidden_dim" not in config: + config["hidden_dim"] = weights["layers.0.feed_forward.w1.weight"].shape[0] + if config.get("vocab_size", -1) < 0: + config["vocab_size"] = weights["output.weight"].shape[-1] + if "rope_theta" not in config: + config["rope_theta"] = 10000 + unused = ["multiple_of", "ffn_dim_multiplier"] + for k in unused: + config.pop(k, None) + return config + + def load_model(model_path): model_path = Path(model_path) weights = mx.load(str(model_path / "weights.npz")) with open(model_path / "config.json", "r") as f: - config = json.loads(f.read()) - config.pop("model_type", None) - n_heads = config["n_heads"] - if "n_kv_heads" not in config: - config["n_kv_heads"] = n_heads - if "head_dim" not in config: - config["head_dim"] = config["dim"] // n_heads - if "hidden_dim" not in config: - config["hidden_dim"] = weights["layers.0.feed_forward.w1.weight"].shape[0] - if config.get("vocab_size", -1) < 0: - config["vocab_size"] = weights["output.weight"].shape[-1] - if "rope_theta" not in config: - config["rope_theta"] = 10000 - unused = ["multiple_of", "ffn_dim_multiplier"] - for k in unused: - if k in config: - config.pop(k) + config = sanitize_config(json.loads(f.read()), weights) + quantization = config.pop("quantization", None) model = Llama(ModelArgs(**config)) + if quantization is not None: + nn.QuantizedLinear.quantize_module(model, **quantization) model.update(tree_unflatten(list(weights.items()))) - return model + tokenizer = SentencePieceProcessor(model_file=str(model_path / "tokenizer.model")) + return model, tokenizer if __name__ == "__main__": parser = argparse.ArgumentParser(description="Llama inference script") parser.add_argument( - "model", help="Path to the model directory containing the MLX weights" + "--model-path", + help="Path to the model directory containing the MLX weights", + default="mlx_model", ) - parser.add_argument("tokenizer", help="The sentencepiece tokenizer") parser.add_argument( "--prompt", help="The message to be processed by the model. 
Ignored when --few-shot is provided.", @@ -374,7 +387,7 @@ if __name__ == "__main__": "--write-every", type=int, default=1, help="After how many tokens to detokenize" ) parser.add_argument( - "--temp", type=float, default=0.8, help="The sampling temperature" + "--temp", type=float, default=0.0, help="The sampling temperature" ) parser.add_argument("--seed", type=int, default=0, help="The PRNG seed") @@ -382,9 +395,8 @@ if __name__ == "__main__": mx.random.seed(args.seed) - tokenizer = SentencePieceProcessor(model_file=args.tokenizer) print("[INFO] Loading model from disk.") - model = load_model(args.model) + model, tokenizer = load_model(args.model_path) if args.few_shot: few_shot_generate(args) else: diff --git a/llms/mistral/README.md b/llms/mistral/README.md index 5da5ace0..2a34497f 100644 --- a/llms/mistral/README.md +++ b/llms/mistral/README.md @@ -23,10 +23,17 @@ tar -xf mistral-7B-v0.1.tar Then, convert the weights with: ``` -python convert.py +python convert.py --torch-path ``` -The conversion script will save the converted weights in the same location. +To generate a 4-bit quantized model, use ``-q``. For a full list of options: + +``` +python convert.py --help +``` + +By default, the conversion script will make the directory `mlx_model` and save +the converted `weights.npz`, `tokenizer.model`, and `config.json` there. > [!TIP] > Alternatively, you can also download a few converted checkpoints from the @@ -40,7 +47,7 @@ Once you've converted the weights to MLX format, you can generate text with the Mistral model: ``` -python mistral.py --prompt "It is a truth universally acknowledged," --temp 0 +python mistral.py --prompt "It is a truth universally acknowledged," ``` Run `python mistral.py --help` for more details. diff --git a/llms/mistral/convert.py b/llms/mistral/convert.py index 792731db..808c3405 100644 --- a/llms/mistral/convert.py +++ b/llms/mistral/convert.py @@ -1,32 +1,98 @@ # Copyright © 2023 Apple Inc. import argparse +import copy import json +import shutil from pathlib import Path +import mlx.core as mx +import mlx.nn as nn import numpy as np import torch +from mistral import Mistral, ModelArgs +from mlx.utils import tree_flatten, tree_map, tree_unflatten + + +def quantize(weights, config, args): + quantized_config = copy.deepcopy(config) + + # Load the model: + config.pop("sliding_window", None) + model = Mistral(ModelArgs(**config)) + weights = tree_map(mx.array, weights) + model.update(tree_unflatten(list(weights.items()))) + + # Quantize the model: + nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits) + + # Update the config: + quantized_config["quantization"] = { + "group_size": args.q_group_size, + "bits": args.q_bits, + } + quantized_weights = dict(tree_flatten(model.parameters())) + + return quantized_weights, quantized_config + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Convert Mistral weights to MLX.") parser.add_argument( - "--model-path", + "--torch-path", type=str, - default="mistral-7B-v0.1/", - help="The path to the Mistral model. 
The MLX weights will also be saved there.", + default="mistral-7B-v0.1", + help="The path to the PyTorch model.", + ) + parser.add_argument( + "--mlx-path", + type=str, + default="mlx_model", + help="The path to save the MLX model.", + ) + parser.add_argument( + "-q", + "--quantize", + help="Generate a quantized model.", + action="store_true", + ) + parser.add_argument( + "--q_group_size", + help="Group size for quantization.", + type=int, + default=64, + ) + parser.add_argument( + "--q_bits", + help="Bits per weight for quantization.", + type=int, + default=4, ) args = parser.parse_args() - model_path = Path(args.model_path) - state = torch.load(str(model_path / "consolidated.00.pth")) - np.savez( - str(model_path / "weights.npz"), - **{k: v.to(torch.float16).numpy() for k, v in state.items()} + torch_path = Path(args.torch_path) + state = torch.load(str(torch_path / "consolidated.00.pth")) + mlx_path = Path(args.mlx_path) + mlx_path.mkdir(parents=True, exist_ok=True) + + weights = {k: v.to(torch.float16).numpy() for k, v in state.items()} + with open(torch_path / "params.json", "r") as f: + config = json.loads(f.read()) + + if args.quantize: + print("[INFO] Quantizing") + weights, config = quantize(weights, config, args) + + # Save weights + np.savez(str(mlx_path / "weights.npz"), **weights) + + # Copy tokenizer + shutil.copyfile( + str(torch_path / "tokenizer.model"), + str(mlx_path / "tokenizer.model"), ) # Save config.json with model_type - with open(model_path / "params.json", "r") as f: - config = json.loads(f.read()) + with open(mlx_path / "config.json", "w") as f: config["model_type"] = "mistral" - with open(model_path / "config.json", "w") as f: json.dump(config, f, indent=4) diff --git a/llms/mistral/mistral.py b/llms/mistral/mistral.py index 44d3fe91..f023ae02 100644 --- a/llms/mistral/mistral.py +++ b/llms/mistral/mistral.py @@ -8,7 +8,7 @@ from typing import List, Optional, Tuple import mlx.core as mx import mlx.nn as nn -from mlx.utils import tree_map, tree_unflatten +from mlx.utils import tree_unflatten from sentencepiece import SentencePieceProcessor @@ -189,18 +189,20 @@ class Tokenizer: return out -def load_model(folder: str, dtype=mx.float16): +def load_model(folder: str): model_path = Path(folder) tokenizer = Tokenizer(str(model_path / "tokenizer.model")) with open(model_path / "config.json", "r") as f: config = json.loads(f.read()) config.pop("sliding_window", None) config.pop("model_type", None) + quantization = config.pop("quantization", None) model_args = ModelArgs(**config) weights = mx.load(str(model_path / "weights.npz")) weights = tree_unflatten(list(weights.items())) - weights = tree_map(lambda p: p.astype(dtype), weights) model = Mistral(model_args) + if quantization is not None: + nn.QuantizedLinear.quantize_module(model, **quantization) model.update(weights) return model, tokenizer @@ -227,7 +229,7 @@ if __name__ == "__main__": parser.add_argument( "--model-path", type=str, - default="mistral-7B-v0.1", + default="mlx_model", help="The path to the model weights and tokenizer", ) parser.add_argument( @@ -236,7 +238,7 @@ if __name__ == "__main__": default="In the beginning the Universe was created.", ) parser.add_argument( - "--max_tokens", + "--max-tokens", "-m", type=int, default=100, @@ -246,7 +248,7 @@ if __name__ == "__main__": "--temp", help="The sampling temperature.", type=float, - default=1.0, + default=0.0, ) parser.add_argument( "--tokens_per_eval", diff --git a/llms/mixtral/README.md b/llms/mixtral/README.md index 858711dd..49e50c91 100644 --- 
a/llms/mixtral/README.md
+++ b/llms/mixtral/README.md
@@ -43,10 +43,18 @@ Now from `mlx-examples/mixtral` convert and save the weights as NumPy arrays so
 MLX can read them:
 
 ```
-python convert.py --model-path $MIXTRAL_MODEL/
+python convert.py --torch-path $MIXTRAL_MODEL/
 ```
 
-The conversion script will save the converted weights in the same location.
+To generate a 4-bit quantized model, use ``-q``. For a full list of options:
+
+```
+python convert.py --help
+```
+
+By default, the conversion script will make the directory `mlx_model` and save
+the converted `weights.npz`, `tokenizer.model`, and `config.json` there.
+
 
 ### Generate
 
diff --git a/llms/mixtral/convert.py b/llms/mixtral/convert.py
index 3069cf5a..023bf9a5 100644
--- a/llms/mixtral/convert.py
+++ b/llms/mixtral/convert.py
@@ -1,59 +1,152 @@
 # Copyright © 2023 Apple Inc.
 
 import argparse
+import copy
 import glob
 import json
+import shutil
 from pathlib import Path
 
+import mlx.core as mx
+import mlx.nn as nn
 import numpy as np
 import torch
+from mixtral import Mixtral, ModelArgs
+from mlx.utils import tree_flatten, tree_map, tree_unflatten
 
 
-def convert(k, v, config):
-    v = v.to(torch.float16).numpy()
-    if "block_sparse_moe" not in k:
-        return [(k, v)]
-    if "gate" in k:
-        return [(k.replace("block_sparse_moe", "feed_forward"), v)]
+def convert(tf, config):
+    def convert_single(k, v):
+        v = v.to(torch.float16).numpy()
+        if "block_sparse_moe" not in k:
+            return [(k, v)]
+        if "gate" in k:
+            return [(k.replace("block_sparse_moe", "feed_forward"), v)]
 
-    # From: layers.N.block_sparse_moe.w
-    # To: layers.N.experts.M.w
-    num_experts = args["moe"]["num_experts"]
-    key_path = k.split(".")
-    v = np.split(v, num_experts, axis=0)
-    if key_path[-1] == "w2":
-        v = [u.T for u in v]
+        # From: layers.N.block_sparse_moe.w
+        # To: layers.N.experts.M.w
+        num_experts = config["moe"]["num_experts"]
+        key_path = k.split(".")
+        v = np.split(v, num_experts, axis=0)
+        if key_path[-1] == "w2":
+            v = [u.T for u in v]
 
-    w_name = key_path.pop()
-    key_path[-1] = "feed_forward.experts"
-    return [
-        (".".join(key_path + [str(e), w_name, "weight"]), u) for e, u in enumerate(v)
-    ]
+        w_name = key_path.pop()
+        key_path[-1] = "feed_forward.experts"
+        return [
+            (".".join(key_path + [str(e), w_name, "weight"]), u)
+            for e, u in enumerate(v)
+        ]
+
+    state = torch.load(tf)
+    weights = {}
+    for k, v in state.items():
+        weights.update(convert_single(k, v))
+    return weights
+
+
+def quantize(weights, config, args):
+    quantized_config = copy.deepcopy(config)
+
+    # Load the model and update with the subset of weights:
+    config.pop("quantization", None)
+    model = Mixtral(ModelArgs(**config))
+    all_weights = dict(tree_flatten(model.parameters()))
+
+    weights = tree_map(mx.array, weights)
+
+    all_weights.update(weights)
+    all_weights = tree_unflatten(list(all_weights.items()))
+    model.update(all_weights)
+
+    # Quantize the model:
+    nn.QuantizedLinear.quantize_module(
+        model,
+        args.q_group_size,
+        args.q_bits,
+        # TODO: Quantize gate matrices when < 32 tiles supported
+        linear_class_predicate=lambda m: isinstance(m, nn.Linear)
+        and m.weight.shape[0] != 8,
+    )
+
+    # Extract the subset of quantized weights:
+    all_weights = dict(tree_flatten(model.parameters()))
+    quantized_weights = {}
+    for k, v in all_weights.items():
+        if k not in weights:
+            continue
+        quantized_weights[k] = v
+        prefix = k.split(".")[:-1]
+        for qw in ["scales", "biases"]:
+            if (k := ".".join(prefix + [qw])) in all_weights:
+                quantized_weights[k] = all_weights[k]
+
+    # Update the config:
+
quantized_config["quantization"] = { + "group_size": args.q_group_size, + "bits": args.q_bits, + } + return quantized_weights, quantized_config if __name__ == "__main__": parser = argparse.ArgumentParser(description="Convert Mixtral weights to MLX.") parser.add_argument( - "--model-path", + "--torch-path", type=str, - default="Mixtral-8x7B-v0.1/", - help="The path to the Mixtral model. The MLX model weights will also be saved there.", + default="Mixtral-8x7B-v0.1", + help="The path to the PyTorch model.", + ) + parser.add_argument( + "--mlx-path", + type=str, + default="mlx_model", + help="The path to save the MLX model.", + ) + parser.add_argument( + "-q", + "--quantize", + help="Generate a quantized model.", + action="store_true", + ) + parser.add_argument( + "--q_group_size", + help="Group size for quantization.", + type=int, + default=64, + ) + parser.add_argument( + "--q_bits", + help="Bits per weight for quantization.", + type=int, + default=4, ) args = parser.parse_args() - model_path = Path(args.model_path) + torch_path = Path(args.torch_path) + mlx_path = Path(args.mlx_path) + mlx_path.mkdir(parents=True, exist_ok=True) with open("params.json") as fid: - args = json.load(fid) - args["model_type"] = "mixtral" - with open(model_path / "config.json", "w") as f: - json.dump(args, f, indent=4) + config = json.load(fid) - torch_files = glob.glob(str(model_path / "consolidated.*.pt")) + # Copy tokenizer + shutil.copyfile( + str(torch_path / "tokenizer.model"), + str(mlx_path / "tokenizer.model"), + ) + + # Convert and save model in shards + torch_files = glob.glob(str(torch_path / "consolidated.*.pt")) torch_files = sorted(torch_files, key=lambda tf: int(tf.split(".")[-2])) for e, tf in enumerate(torch_files): print(f"[INFO] Converting file {e + 1}/{len(torch_files)}") - state = torch.load(tf) - new_state = {} - for k, v in state.items(): - new_state.update(convert(k, v, args)) - np.savez(str(model_path / f"weights.{e}.npz"), **new_state) + weights = convert(tf, config) + if args.quantize: + print("[INFO] Quantizing") + weights, config = quantize(weights, config, args) + np.savez(str(mlx_path / f"weights.{e}.npz"), **weights) + + # Save updated config + with open(mlx_path / "config.json", "w") as f: + config["model_type"] = "mixtral" + json.dump(config, f, indent=4) diff --git a/llms/mixtral/mixtral.py b/llms/mixtral/mixtral.py index fc6c95e7..30fa8d8b 100644 --- a/llms/mixtral/mixtral.py +++ b/llms/mixtral/mixtral.py @@ -244,20 +244,27 @@ class Tokenizer: return out -def load_model(folder: str, dtype=mx.float16): +def load_model(folder: str): model_path = Path(folder) tokenizer = Tokenizer(str(model_path / "tokenizer.model")) with open(model_path / "config.json", "r") as f: config = json.loads(f.read()) config.pop("model_type", None) + quantization = config.pop("quantization", None) model_args = ModelArgs(**config) weight_files = glob.glob(str(model_path / "weights.*.npz")) weights = {} for wf in weight_files: weights.update(mx.load(wf).items()) weights = tree_unflatten(list(weights.items())) - weights = tree_map(lambda p: p.astype(dtype), weights) model = Mixtral(model_args) + if quantization is not None: + # TODO: Quantize gate matrices when < 32 tiles supported + quantization["linear_class_predicate"] = ( + lambda m: isinstance(m, nn.Linear) and m.weight.shape[0] != 8 + ) + nn.QuantizedLinear.quantize_module(model, **quantization) + model.update(weights) return model, tokenizer @@ -284,7 +291,7 @@ if __name__ == "__main__": parser.add_argument( "--model-path", type=str, - 
default="Mixtral-8x7B-v0.1", + default="mlx_model", help="The path to the model weights, tokenizer, and config", ) parser.add_argument( diff --git a/llms/phi2/.gitignore b/llms/phi2/.gitignore deleted file mode 100644 index 258ec872..00000000 --- a/llms/phi2/.gitignore +++ /dev/null @@ -1 +0,0 @@ -weights.npz diff --git a/llms/phi2/README.md b/llms/phi2/README.md index b02d017f..c79dd5e8 100644 --- a/llms/phi2/README.md +++ b/llms/phi2/README.md @@ -15,7 +15,14 @@ Download and convert the model: python convert.py ``` -This will make the `weights.npz` file which MLX can read. +To generate a 4-bit quantized model use the `-q` flag: + +``` +python convert.py -q +``` + +By default, the conversion script will make the directory `mlx_model` and save +the converted `weights.npz`, and `config.json` there. > [!TIP] Alternatively, you can also download a few converted checkpoints from > the [MLX Community](https://huggingface.co/mlx-community) organization on diff --git a/llms/phi2/convert.py b/llms/phi2/convert.py index 819e0363..bf6ff937 100644 --- a/llms/phi2/convert.py +++ b/llms/phi2/convert.py @@ -1,7 +1,37 @@ +import argparse +import copy +import json +from pathlib import Path + +import mlx.core as mx +import mlx.nn as nn import numpy as np +from mlx.utils import tree_flatten, tree_map, tree_unflatten +from phi2 import ModelArgs, Phi2 from transformers import AutoModelForCausalLM +def quantize(weights, config, args): + quantized_config = copy.deepcopy(config) + + # Load the model: + model = Phi2(ModelArgs()) + weights = tree_map(mx.array, weights) + model.update(tree_unflatten(list(weights.items()))) + + # Quantize the model: + nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits) + + # Update the config: + quantized_config["quantization"] = { + "group_size": args.q_group_size, + "bits": args.q_bits, + } + quantized_weights = dict(tree_flatten(model.parameters())) + + return quantized_weights, quantized_config + + def replace_key(key: str) -> str: if "wte.weight" in key: key = "wte.weight" @@ -12,12 +42,50 @@ def replace_key(key: str) -> str: def convert(): + parser = argparse.ArgumentParser(description="Convert Phi-2 weights to MLX") + parser.add_argument( + "--mlx-path", + type=str, + default="mlx_model", + help="The path to save the MLX model.", + ) + parser.add_argument( + "-q", + "--quantize", + help="Generate a quantized model.", + action="store_true", + ) + parser.add_argument( + "--q_group_size", + help="Group size for quantization.", + type=int, + default=64, + ) + parser.add_argument( + "--q_bits", + help="Bits per weight for quantization.", + type=int, + default=4, + ) + args = parser.parse_args() + + mlx_path = Path(args.mlx_path) + mlx_path.mkdir(parents=True, exist_ok=True) + model = AutoModelForCausalLM.from_pretrained( "microsoft/phi-2", torch_dtype="auto", trust_remote_code=True ) state_dict = model.state_dict() weights = {replace_key(k): v.numpy() for k, v in state_dict.items()} - np.savez("weights.npz", **weights) + params = {} + if args.quantize: + print("[INFO] Quantizing") + weights, params = quantize(weights, params, args) + + np.savez(str(mlx_path / "weights.npz"), **weights) + with open(mlx_path / "config.json", "w") as fid: + params["model_type"] = "phi2" + json.dump(params, fid, indent=4) if __name__ == "__main__": diff --git a/llms/phi2/phi2.py b/llms/phi2/phi2.py index bcca4209..f824549d 100644 --- a/llms/phi2/phi2.py +++ b/llms/phi2/phi2.py @@ -1,4 +1,5 @@ import argparse +import json import math from dataclasses import dataclass from pathlib 
import Path
@@ -158,8 +159,16 @@ def generate(prompt: mx.array, model: Phi2, temp: Optional[float] = 0.0):
 def load_model(model_path: str):
     model = Phi2(ModelArgs())
     model_path = Path(model_path)
+    with open(model_path / "config.json", "r") as f:
+        config = json.loads(f.read())
+        config.pop("model_type", None)
+        quantization = config.pop("quantization", None)
     weights = mx.load(str(model_path / "weights.npz"))
-    model.update(tree_unflatten(list(weights.items())))
+    weights = tree_unflatten(list(weights.items()))
+    if quantization is not None:
+        nn.QuantizedLinear.quantize_module(model, **quantization)
+    model.update(weights)
+
     tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
     return model, tokenizer
 
@@ -169,7 +178,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--model-path",
         type=str,
-        default=".",
+        default="mlx_model",
         help="The path to the model weights",
     )
     parser.add_argument(
diff --git a/llms/qwen/.gitignore b/llms/qwen/.gitignore
deleted file mode 100644
index 0c68f15d..00000000
--- a/llms/qwen/.gitignore
+++ /dev/null
@@ -1,2 +0,0 @@
-weights.npz
-config.json
diff --git a/llms/qwen/README.md b/llms/qwen/README.md
index f9276098..75154325 100644
--- a/llms/qwen/README.md
+++ b/llms/qwen/README.md
@@ -11,11 +11,15 @@ First download and convert the model with:
 ```sh
 python convert.py
 ```
-The script downloads the model from Hugging Face. The default model is
-`Qwen/Qwen-1_8B`. Check out the [Hugging Face page](https://huggingface.co/Qwen) to see a list of available models.
-The conversion script will make the `weights.npz` and `config.json` files in
-the working directory.
+To generate a 4-bit quantized model, use ``-q``. Run `python convert.py --help` for a full list of options.
+
+The script downloads the model from Hugging Face. The default model is
+`Qwen/Qwen-1_8B`. Check out the [Hugging Face
+page](https://huggingface.co/Qwen) to see a list of available models.
+
+By default, the conversion script will make the directory `mlx_model` and save
+the converted `weights.npz` and `config.json` there.
## Generate diff --git a/llms/qwen/convert.py b/llms/qwen/convert.py index c5fec060..88135208 100644 --- a/llms/qwen/convert.py +++ b/llms/qwen/convert.py @@ -1,8 +1,14 @@ import argparse +import copy import json +from pathlib import Path +import mlx.core as mx +import mlx.nn as nn import numpy as np import torch +from mlx.utils import tree_flatten, tree_map, tree_unflatten +from qwen import ModelArgs, Qwen from transformers import AutoModelForCausalLM @@ -14,19 +20,58 @@ def replace_key(key: str) -> str: return key -def convert(model_path: str = "Qwen/Qwen-1_8B"): +def quantize(weights, config, args): + quantized_config = copy.deepcopy(config) + + # Load the model: + model_args = ModelArgs() + model_args.vocab_size = config["vocab_size"] + model_args.hidden_size = config["hidden_size"] + model_args.num_attention_heads = config["num_attention_heads"] + model_args.num_hidden_layers = config["num_hidden_layers"] + model_args.kv_channels = config["kv_channels"] + model_args.max_position_embeddings = config["max_position_embeddings"] + model_args.layer_norm_epsilon = config["layer_norm_epsilon"] + model_args.intermediate_size = config["intermediate_size"] + model_args.no_bias = config["no_bias"] + model = Qwen(model_args) + + weights = tree_map(mx.array, weights) + model.update(tree_unflatten(list(weights.items()))) + + # Quantize the model: + nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits) + + # Update the config: + quantized_config["quantization"] = { + "group_size": args.q_group_size, + "bits": args.q_bits, + } + quantized_weights = dict(tree_flatten(model.parameters())) + + return quantized_weights, quantized_config + + +def convert(args): + mlx_path = Path(args.mlx_path) + mlx_path.mkdir(parents=True, exist_ok=True) + model = AutoModelForCausalLM.from_pretrained( - model_path, trust_remote_code=True, torch_dtype=torch.float16 + args.model, trust_remote_code=True, torch_dtype=torch.float16 ) state_dict = model.state_dict() weights = {replace_key(k): v.numpy() for k, v in state_dict.items()} - np.savez("weights.npz", **weights) + config = model.config.to_dict() + + if args.quantize: + print("[INFO] Quantizing") + weights, config = quantize(weights, config, args) + + np.savez(str(mlx_path / "weights.npz"), **weights) # write config - config = model.config - config_dict = config.to_dict() - with open("config.json", "w") as f: - json.dump(config_dict, f, indent=4) + with open(mlx_path / "config.json", "w") as f: + json.dump(config, f, indent=4) if __name__ == "__main__": @@ -37,7 +82,29 @@ if __name__ == "__main__": help="The huggingface model to be converted", default="Qwen/Qwen-1_8B", ) - + parser.add_argument( + "--mlx-path", + type=str, + default="mlx_model", + help="The path to save the MLX model.", + ) + parser.add_argument( + "-q", + "--quantize", + help="Generate a quantized model.", + action="store_true", + ) + parser.add_argument( + "--q_group_size", + help="Group size for quantization.", + type=int, + default=64, + ) + parser.add_argument( + "--q_bits", + help="Bits per weight for quantization.", + type=int, + default=4, + ) args = parser.parse_args() - - convert(args.model) + convert(args) diff --git a/llms/qwen/qwen.py b/llms/qwen/qwen.py index 3b153eb9..532a8031 100644 --- a/llms/qwen/qwen.py +++ b/llms/qwen/qwen.py @@ -1,6 +1,7 @@ import argparse import json from dataclasses import dataclass +from pathlib import Path import mlx.core as mx import mlx.nn as nn @@ -175,12 +176,11 @@ def generate(prompt: mx.array, model: Qwen, temp: 0.0): yield y -def 
load_model( - tokenizer_path: str = "Qwen/Qwen-1_8B", config_path: str = "config.json" -): +def load_model(model_path: str, tokenizer_path: str = "Qwen/Qwen-1_8B"): model_args = ModelArgs() - with open(config_path, "r") as f: + model_path = Path(model_path) + with open(model_path / "config.json", "r") as f: config = json.load(f) model_args.vocab_size = config["vocab_size"] model_args.hidden_size = config["hidden_size"] @@ -193,9 +193,11 @@ def load_model( model_args.no_bias = config["no_bias"] model = Qwen(model_args) - - weights = mx.load("weights.npz") + weights = mx.load(str(model_path / "weights.npz")) + if quantization := config.get("quantization", False): + nn.QuantizedLinear.quantize_module(model, **quantization) model.update(tree_unflatten(list(weights.items()))) + tokenizer = AutoTokenizer.from_pretrained( tokenizer_path, trust_remote_code=True, eos_token="<|endoftext|>" ) @@ -204,6 +206,12 @@ def load_model( if __name__ == "__main__": parser = argparse.ArgumentParser(description="Qwen inference script") + parser.add_argument( + "--model-path", + type=str, + default="mlx_model", + help="The path to the model weights and config", + ) parser.add_argument( "--tokenizer", help="The tokenizer to be used, defaults to Qwen/Qwen-1_8B", @@ -216,7 +224,7 @@ if __name__ == "__main__": default="蒙古国的首都是乌兰巴托(Ulaanbaatar)\n冰岛的首都是雷克雅未克(Reykjavik)\n埃塞俄比亚的首都是", ) parser.add_argument( - "--max_tokens", + "--max-tokens", "-m", type=int, default=100, @@ -233,7 +241,7 @@ if __name__ == "__main__": mx.random.seed(args.seed) - model, tokenizer = load_model(args.tokenizer) + model, tokenizer = load_model(args.model_path, args.tokenizer) prompt = tokenizer( args.prompt,